feat: Implement initial schema and add various API, service, and management command enhancements across the application.

This commit is contained in:
pacnpal
2026-01-01 15:13:01 -05:00
parent c95f99ca10
commit b243b17af7
413 changed files with 11164 additions and 17433 deletions

37
backend/.flake8 Normal file
View File

@@ -0,0 +1,37 @@
[flake8]
# Match Black and Ruff line length
max-line-length = 120
# Ignore rules that conflict with Black formatting or are handled by other tools
ignore =
# E203: whitespace before ':' - Black intentionally does this
E203,
# E501: line too long - handled by Black/Ruff
E501,
# W503: line break before binary operator - conflicts with Black
W503,
# E226: missing whitespace around arithmetic operator - Black style
E226,
# W391: blank line at end of file - not critical
W391,
# C901: function is too complex - these are intentional for complex business logic
C901,
# F401: imported but unused - star imports for choice registration are intentional
F401
# Exclude common directories
exclude =
.git,
__pycache__,
migrations,
.venv,
venv,
build,
dist,
*.egg-info,
node_modules,
htmlcov,
.pytest_cache
# Complexity threshold - set high since we have intentional complex functions
max-complexity = 50

View File

@@ -1,2 +1,2 @@
# Import choices to trigger registration # Import choices to trigger registration
from .choices import * from .choices import * # noqa: F403

View File

@@ -77,8 +77,6 @@ class UserProfileInline(admin.StackedInline):
) )
@admin.register(User) @admin.register(User)
class CustomUserAdmin(QueryOptimizationMixin, ExportActionMixin, UserAdmin): class CustomUserAdmin(QueryOptimizationMixin, ExportActionMixin, UserAdmin):
""" """
@@ -332,8 +330,9 @@ class CustomUserAdmin(QueryOptimizationMixin, ExportActionMixin, UserAdmin):
try: try:
profile = user.profile profile = user.profile
# Credits would be recalculated from ride history here # Credits would be recalculated from ride history here
profile.save(update_fields=["coaster_credits", "dark_ride_credits", profile.save(
"flat_ride_credits", "water_ride_credits"]) update_fields=["coaster_credits", "dark_ride_credits", "flat_ride_credits", "water_ride_credits"]
)
count += 1 count += 1
except UserProfile.DoesNotExist: except UserProfile.DoesNotExist:
pass pass
@@ -442,12 +441,14 @@ class UserProfileAdmin(QueryOptimizationMixin, ExportActionMixin, BaseModelAdmin
@admin.display(description="Completeness") @admin.display(description="Completeness")
def profile_completeness(self, obj): def profile_completeness(self, obj):
"""Display profile completeness indicator.""" """Display profile completeness indicator."""
fields_filled = sum([ fields_filled = sum(
bool(obj.display_name), [
bool(obj.avatar), bool(obj.display_name),
bool(obj.bio), bool(obj.avatar),
bool(obj.twitter or obj.instagram or obj.youtube or obj.discord), bool(obj.bio),
]) bool(obj.twitter or obj.instagram or obj.youtube or obj.discord),
]
)
percentage = (fields_filled / 4) * 100 percentage = (fields_filled / 4) * 100
color = "green" if percentage >= 75 else "orange" if percentage >= 50 else "red" color = "green" if percentage >= 75 else "orange" if percentage >= 50 else "red"
return format_html( return format_html(
@@ -529,12 +530,8 @@ class EmailVerificationAdmin(QueryOptimizationMixin, BaseModelAdmin):
def expiration_status(self, obj): def expiration_status(self, obj):
"""Display expiration status with color coding.""" """Display expiration status with color coding."""
if timezone.now() - obj.last_sent > timedelta(days=1): if timezone.now() - obj.last_sent > timedelta(days=1):
return format_html( return format_html('<span style="color: red; font-weight: bold;">Expired</span>')
'<span style="color: red; font-weight: bold;">Expired</span>' return format_html('<span style="color: green; font-weight: bold;">Valid</span>')
)
return format_html(
'<span style="color: green; font-weight: bold;">Valid</span>'
)
@admin.display(description="Can Resend", boolean=True) @admin.display(description="Can Resend", boolean=True)
def can_resend(self, obj): def can_resend(self, obj):
@@ -665,6 +662,3 @@ class PasswordResetAdmin(ReadOnlyAdminMixin, BaseModelAdmin):
"Cleanup old tokens", "Cleanup old tokens",
) )
return actions return actions

View File

@@ -26,7 +26,7 @@ user_roles = ChoiceGroup(
"css_class": "text-blue-600 bg-blue-50", "css_class": "text-blue-600 bg-blue-50",
"permissions": ["create_content", "create_reviews", "create_lists"], "permissions": ["create_content", "create_reviews", "create_lists"],
"sort_order": 1, "sort_order": 1,
} },
), ),
RichChoice( RichChoice(
value="MODERATOR", value="MODERATOR",
@@ -38,7 +38,7 @@ user_roles = ChoiceGroup(
"css_class": "text-green-600 bg-green-50", "css_class": "text-green-600 bg-green-50",
"permissions": ["moderate_content", "review_submissions", "manage_reports"], "permissions": ["moderate_content", "review_submissions", "manage_reports"],
"sort_order": 2, "sort_order": 2,
} },
), ),
RichChoice( RichChoice(
value="ADMIN", value="ADMIN",
@@ -50,7 +50,7 @@ user_roles = ChoiceGroup(
"css_class": "text-purple-600 bg-purple-50", "css_class": "text-purple-600 bg-purple-50",
"permissions": ["manage_users", "site_configuration", "advanced_moderation"], "permissions": ["manage_users", "site_configuration", "advanced_moderation"],
"sort_order": 3, "sort_order": 3,
} },
), ),
RichChoice( RichChoice(
value="SUPERUSER", value="SUPERUSER",
@@ -62,9 +62,9 @@ user_roles = ChoiceGroup(
"css_class": "text-red-600 bg-red-50", "css_class": "text-red-600 bg-red-50",
"permissions": ["full_access", "system_administration", "database_access"], "permissions": ["full_access", "system_administration", "database_access"],
"sort_order": 4, "sort_order": 4,
} },
), ),
] ],
) )
@@ -83,13 +83,9 @@ theme_preferences = ChoiceGroup(
"color": "yellow", "color": "yellow",
"icon": "sun", "icon": "sun",
"css_class": "text-yellow-600 bg-yellow-50", "css_class": "text-yellow-600 bg-yellow-50",
"preview_colors": { "preview_colors": {"background": "#ffffff", "text": "#1f2937", "accent": "#3b82f6"},
"background": "#ffffff",
"text": "#1f2937",
"accent": "#3b82f6"
},
"sort_order": 1, "sort_order": 1,
} },
), ),
RichChoice( RichChoice(
value="dark", value="dark",
@@ -99,15 +95,11 @@ theme_preferences = ChoiceGroup(
"color": "gray", "color": "gray",
"icon": "moon", "icon": "moon",
"css_class": "text-gray-600 bg-gray-50", "css_class": "text-gray-600 bg-gray-50",
"preview_colors": { "preview_colors": {"background": "#1f2937", "text": "#f9fafb", "accent": "#60a5fa"},
"background": "#1f2937",
"text": "#f9fafb",
"accent": "#60a5fa"
},
"sort_order": 2, "sort_order": 2,
} },
), ),
] ],
) )
@@ -133,7 +125,7 @@ unit_systems = ChoiceGroup(
"large_distance": "km", "large_distance": "km",
}, },
"sort_order": 1, "sort_order": 1,
} },
), ),
RichChoice( RichChoice(
value="imperial", value="imperial",
@@ -150,9 +142,9 @@ unit_systems = ChoiceGroup(
"large_distance": "mi", "large_distance": "mi",
}, },
"sort_order": 2, "sort_order": 2,
} },
), ),
] ],
) )
@@ -177,10 +169,10 @@ privacy_levels = ChoiceGroup(
"Profile visible to all users", "Profile visible to all users",
"Activity appears in public feeds", "Activity appears in public feeds",
"Searchable by search engines", "Searchable by search engines",
"Can be found by username search" "Can be found by username search",
], ],
"sort_order": 1, "sort_order": 1,
} },
), ),
RichChoice( RichChoice(
value="friends", value="friends",
@@ -196,10 +188,10 @@ privacy_levels = ChoiceGroup(
"Profile visible only to friends", "Profile visible only to friends",
"Activity hidden from public feeds", "Activity hidden from public feeds",
"Not searchable by search engines", "Not searchable by search engines",
"Requires friend request approval" "Requires friend request approval",
], ],
"sort_order": 2, "sort_order": 2,
} },
), ),
RichChoice( RichChoice(
value="private", value="private",
@@ -215,12 +207,12 @@ privacy_levels = ChoiceGroup(
"Profile completely hidden", "Profile completely hidden",
"No activity in any feeds", "No activity in any feeds",
"Not discoverable by other users", "Not discoverable by other users",
"Maximum privacy protection" "Maximum privacy protection",
], ],
"sort_order": 3, "sort_order": 3,
} },
), ),
] ],
) )
@@ -242,7 +234,7 @@ top_list_categories = ChoiceGroup(
"ride_category": "roller_coaster", "ride_category": "roller_coaster",
"typical_list_size": 10, "typical_list_size": 10,
"sort_order": 1, "sort_order": 1,
} },
), ),
RichChoice( RichChoice(
value="DR", value="DR",
@@ -255,7 +247,7 @@ top_list_categories = ChoiceGroup(
"ride_category": "dark_ride", "ride_category": "dark_ride",
"typical_list_size": 10, "typical_list_size": 10,
"sort_order": 2, "sort_order": 2,
} },
), ),
RichChoice( RichChoice(
value="FR", value="FR",
@@ -268,7 +260,7 @@ top_list_categories = ChoiceGroup(
"ride_category": "flat_ride", "ride_category": "flat_ride",
"typical_list_size": 10, "typical_list_size": 10,
"sort_order": 3, "sort_order": 3,
} },
), ),
RichChoice( RichChoice(
value="WR", value="WR",
@@ -281,7 +273,7 @@ top_list_categories = ChoiceGroup(
"ride_category": "water_ride", "ride_category": "water_ride",
"typical_list_size": 10, "typical_list_size": 10,
"sort_order": 4, "sort_order": 4,
} },
), ),
RichChoice( RichChoice(
value="PK", value="PK",
@@ -294,9 +286,9 @@ top_list_categories = ChoiceGroup(
"entity_type": "park", "entity_type": "park",
"typical_list_size": 10, "typical_list_size": 10,
"sort_order": 5, "sort_order": 5,
} },
), ),
] ],
) )
@@ -320,7 +312,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "push", "inapp"], "default_channels": ["email", "push", "inapp"],
"priority": "normal", "priority": "normal",
"sort_order": 1, "sort_order": 1,
} },
), ),
RichChoice( RichChoice(
value="submission_rejected", value="submission_rejected",
@@ -334,7 +326,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "push", "inapp"], "default_channels": ["email", "push", "inapp"],
"priority": "normal", "priority": "normal",
"sort_order": 2, "sort_order": 2,
} },
), ),
RichChoice( RichChoice(
value="submission_pending", value="submission_pending",
@@ -348,7 +340,7 @@ notification_types = ChoiceGroup(
"default_channels": ["inapp"], "default_channels": ["inapp"],
"priority": "low", "priority": "low",
"sort_order": 3, "sort_order": 3,
} },
), ),
# Review related # Review related
RichChoice( RichChoice(
@@ -363,7 +355,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "push", "inapp"], "default_channels": ["email", "push", "inapp"],
"priority": "normal", "priority": "normal",
"sort_order": 4, "sort_order": 4,
} },
), ),
RichChoice( RichChoice(
value="review_helpful", value="review_helpful",
@@ -377,7 +369,7 @@ notification_types = ChoiceGroup(
"default_channels": ["push", "inapp"], "default_channels": ["push", "inapp"],
"priority": "low", "priority": "low",
"sort_order": 5, "sort_order": 5,
} },
), ),
# Social related # Social related
RichChoice( RichChoice(
@@ -392,7 +384,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "push", "inapp"], "default_channels": ["email", "push", "inapp"],
"priority": "normal", "priority": "normal",
"sort_order": 6, "sort_order": 6,
} },
), ),
RichChoice( RichChoice(
value="friend_accepted", value="friend_accepted",
@@ -406,7 +398,7 @@ notification_types = ChoiceGroup(
"default_channels": ["push", "inapp"], "default_channels": ["push", "inapp"],
"priority": "low", "priority": "low",
"sort_order": 7, "sort_order": 7,
} },
), ),
RichChoice( RichChoice(
value="message_received", value="message_received",
@@ -420,7 +412,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "push", "inapp"], "default_channels": ["email", "push", "inapp"],
"priority": "normal", "priority": "normal",
"sort_order": 8, "sort_order": 8,
} },
), ),
RichChoice( RichChoice(
value="profile_comment", value="profile_comment",
@@ -434,7 +426,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "push", "inapp"], "default_channels": ["email", "push", "inapp"],
"priority": "normal", "priority": "normal",
"sort_order": 9, "sort_order": 9,
} },
), ),
# System related # System related
RichChoice( RichChoice(
@@ -449,7 +441,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "inapp"], "default_channels": ["email", "inapp"],
"priority": "normal", "priority": "normal",
"sort_order": 10, "sort_order": 10,
} },
), ),
RichChoice( RichChoice(
value="account_security", value="account_security",
@@ -463,7 +455,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "push", "inapp"], "default_channels": ["email", "push", "inapp"],
"priority": "high", "priority": "high",
"sort_order": 11, "sort_order": 11,
} },
), ),
RichChoice( RichChoice(
value="feature_update", value="feature_update",
@@ -477,7 +469,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "inapp"], "default_channels": ["email", "inapp"],
"priority": "low", "priority": "low",
"sort_order": 12, "sort_order": 12,
} },
), ),
RichChoice( RichChoice(
value="maintenance", value="maintenance",
@@ -491,7 +483,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "inapp"], "default_channels": ["email", "inapp"],
"priority": "normal", "priority": "normal",
"sort_order": 13, "sort_order": 13,
} },
), ),
# Achievement related # Achievement related
RichChoice( RichChoice(
@@ -506,7 +498,7 @@ notification_types = ChoiceGroup(
"default_channels": ["push", "inapp"], "default_channels": ["push", "inapp"],
"priority": "low", "priority": "low",
"sort_order": 14, "sort_order": 14,
} },
), ),
RichChoice( RichChoice(
value="milestone_reached", value="milestone_reached",
@@ -520,9 +512,9 @@ notification_types = ChoiceGroup(
"default_channels": ["push", "inapp"], "default_channels": ["push", "inapp"],
"priority": "low", "priority": "low",
"sort_order": 15, "sort_order": 15,
} },
), ),
] ],
) )
@@ -545,7 +537,7 @@ notification_priorities = ChoiceGroup(
"batch_eligible": True, "batch_eligible": True,
"delay_minutes": 60, "delay_minutes": 60,
"sort_order": 1, "sort_order": 1,
} },
), ),
RichChoice( RichChoice(
value="normal", value="normal",
@@ -559,7 +551,7 @@ notification_priorities = ChoiceGroup(
"batch_eligible": True, "batch_eligible": True,
"delay_minutes": 15, "delay_minutes": 15,
"sort_order": 2, "sort_order": 2,
} },
), ),
RichChoice( RichChoice(
value="high", value="high",
@@ -573,7 +565,7 @@ notification_priorities = ChoiceGroup(
"batch_eligible": False, "batch_eligible": False,
"delay_minutes": 0, "delay_minutes": 0,
"sort_order": 3, "sort_order": 3,
} },
), ),
RichChoice( RichChoice(
value="urgent", value="urgent",
@@ -588,9 +580,9 @@ notification_priorities = ChoiceGroup(
"delay_minutes": 0, "delay_minutes": 0,
"bypass_preferences": True, "bypass_preferences": True,
"sort_order": 4, "sort_order": 4,
} },
), ),
] ],
) )

View File

@@ -53,28 +53,34 @@ class UserExportService:
"dark_ride": profile.dark_ride_credits, "dark_ride": profile.dark_ride_credits,
"flat_ride": profile.flat_ride_credits, "flat_ride": profile.flat_ride_credits,
"water_ride": profile.water_ride_credits, "water_ride": profile.water_ride_credits,
} },
} }
# Reviews # Reviews
park_reviews = list(ParkReview.objects.filter(user=user).values( park_reviews = list(
"park__name", "rating", "review", "created_at", "updated_at", "is_published" ParkReview.objects.filter(user=user).values(
)) "park__name", "rating", "review", "created_at", "updated_at", "is_published"
)
)
ride_reviews = list(RideReview.objects.filter(user=user).values( ride_reviews = list(
"ride__name", "rating", "review", "created_at", "updated_at", "is_published" RideReview.objects.filter(user=user).values(
)) "ride__name", "rating", "review", "created_at", "updated_at", "is_published"
)
)
# Lists # Lists
user_lists = [] user_lists = []
for user_list in UserList.objects.filter(user=user): for user_list in UserList.objects.filter(user=user):
items = list(user_list.items.values("order", "content_type__model", "object_id", "comment")) items = list(user_list.items.values("order", "content_type__model", "object_id", "comment"))
user_lists.append({ user_lists.append(
"title": user_list.title, {
"description": user_list.description, "title": user_list.title,
"created_at": user_list.created_at, "description": user_list.description,
"items": items "created_at": user_list.created_at,
}) "items": items,
}
)
export_data = { export_data = {
"account": user_data, "account": user_data,
@@ -85,10 +91,7 @@ class UserExportService:
"ride_reviews": ride_reviews, "ride_reviews": ride_reviews,
"lists": user_lists, "lists": user_lists,
}, },
"export_info": { "export_info": {"generated_at": timezone.now(), "version": "1.0"},
"generated_at": timezone.now(),
"version": "1.0"
}
} }
return export_data return export_data

View File

@@ -99,8 +99,6 @@ class LoginHistory(models.Model):
# Default cleanup for entries older than the specified days # Default cleanup for entries older than the specified days
cutoff = timezone.now() - timedelta(days=days) cutoff = timezone.now() - timedelta(days=days)
deleted_count, _ = cls.objects.filter( deleted_count, _ = cls.objects.filter(login_timestamp__lt=cutoff).delete()
login_timestamp__lt=cutoff
).delete()
return deleted_count return deleted_count

View File

@@ -22,20 +22,14 @@ class Command(BaseCommand):
# Check SocialAccount # Check SocialAccount
self.stdout.write("\nChecking SocialAccount table:") self.stdout.write("\nChecking SocialAccount table:")
for account in SocialAccount.objects.all(): for account in SocialAccount.objects.all():
self.stdout.write( self.stdout.write(f"ID: {account.pk}, Provider: {account.provider}, UID: {account.uid}")
f"ID: {account.pk}, Provider: {account.provider}, UID: {account.uid}"
)
# Check SocialToken # Check SocialToken
self.stdout.write("\nChecking SocialToken table:") self.stdout.write("\nChecking SocialToken table:")
for token in SocialToken.objects.all(): for token in SocialToken.objects.all():
self.stdout.write( self.stdout.write(f"ID: {token.pk}, Account: {token.account}, App: {token.app}")
f"ID: {token.pk}, Account: {token.account}, App: {token.app}"
)
# Check Site # Check Site
self.stdout.write("\nChecking Site table:") self.stdout.write("\nChecking Site table:")
for site in Site.objects.all(): for site in Site.objects.all():
self.stdout.write( self.stdout.write(f"ID: {site.pk}, Domain: {site.domain}, Name: {site.name}")
f"ID: {site.pk}, Domain: {site.domain}, Name: {site.name}"
)

View File

@@ -17,6 +17,4 @@ class Command(BaseCommand):
self.stdout.write(f"Name: {app.name}") self.stdout.write(f"Name: {app.name}")
self.stdout.write(f"Client ID: {app.client_id}") self.stdout.write(f"Client ID: {app.client_id}")
self.stdout.write(f"Secret: {app.secret}") self.stdout.write(f"Secret: {app.secret}")
self.stdout.write( self.stdout.write(f"Sites: {', '.join(str(site.domain) for site in app.sites.all())}")
f"Sites: {', '.join(str(site.domain) for site in app.sites.all())}"
)

View File

@@ -15,14 +15,9 @@ class Command(BaseCommand):
# Remove migration records # Remove migration records
cursor.execute("DELETE FROM django_migrations WHERE app='socialaccount'") cursor.execute("DELETE FROM django_migrations WHERE app='socialaccount'")
cursor.execute( cursor.execute("DELETE FROM django_migrations WHERE app='accounts' " "AND name LIKE '%social%'")
"DELETE FROM django_migrations WHERE app='accounts' "
"AND name LIKE '%social%'"
)
# Reset sequences # Reset sequences
cursor.execute("DELETE FROM sqlite_sequence WHERE name LIKE '%social%'") cursor.execute("DELETE FROM sqlite_sequence WHERE name LIKE '%social%'")
self.stdout.write( self.stdout.write(self.style.SUCCESS("Successfully cleaned up social auth configuration"))
self.style.SUCCESS("Successfully cleaned up social auth configuration")
)

View File

@@ -18,24 +18,18 @@ class Command(BaseCommand):
self.stdout.write(self.style.SUCCESS(f"Deleted {count} test users")) self.stdout.write(self.style.SUCCESS(f"Deleted {count} test users"))
# Delete test reviews # Delete test reviews
reviews = ParkReview.objects.filter( reviews = ParkReview.objects.filter(user__username__in=["testuser", "moderator"])
user__username__in=["testuser", "moderator"]
)
count = reviews.count() count = reviews.count()
reviews.delete() reviews.delete()
self.stdout.write(self.style.SUCCESS(f"Deleted {count} test reviews")) self.stdout.write(self.style.SUCCESS(f"Deleted {count} test reviews"))
# Delete test photos - both park and ride photos # Delete test photos - both park and ride photos
park_photos = ParkPhoto.objects.filter( park_photos = ParkPhoto.objects.filter(uploader__username__in=["testuser", "moderator"])
uploader__username__in=["testuser", "moderator"]
)
park_count = park_photos.count() park_count = park_photos.count()
park_photos.delete() park_photos.delete()
self.stdout.write(self.style.SUCCESS(f"Deleted {park_count} test park photos")) self.stdout.write(self.style.SUCCESS(f"Deleted {park_count} test park photos"))
ride_photos = RidePhoto.objects.filter( ride_photos = RidePhoto.objects.filter(uploader__username__in=["testuser", "moderator"])
uploader__username__in=["testuser", "moderator"]
)
ride_count = ride_photos.count() ride_count = ride_photos.count()
ride_photos.delete() ride_photos.delete()
self.stdout.write(self.style.SUCCESS(f"Deleted {ride_count} test ride photos")) self.stdout.write(self.style.SUCCESS(f"Deleted {ride_count} test ride photos"))

View File

@@ -37,18 +37,12 @@ class Command(BaseCommand):
provider="google", provider="google",
defaults={ defaults={
"name": "Google", "name": "Google",
"client_id": ( "client_id": ("135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2." "apps.googleusercontent.com"),
"135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2."
"apps.googleusercontent.com"
),
"secret": "GOCSPX-Wd_0Ue0Ue0Ue0Ue0Ue0Ue0Ue0Ue", "secret": "GOCSPX-Wd_0Ue0Ue0Ue0Ue0Ue0Ue0Ue0Ue",
}, },
) )
if not created: if not created:
google_app.client_id = ( google_app.client_id = "135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2." "apps.googleusercontent.com"
"135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2."
"apps.googleusercontent.com"
)
google_app.secret = "GOCSPX-Wd_0Ue0Ue0Ue0Ue0Ue0Ue0Ue0Ue" google_app.secret = "GOCSPX-Wd_0Ue0Ue0Ue0Ue0Ue0Ue0Ue0Ue"
google_app.save() google_app.save()
google_app.sites.add(site) google_app.sites.add(site)

View File

@@ -14,9 +14,7 @@ class Command(BaseCommand):
) )
user.set_password("testpass123") user.set_password("testpass123")
user.save() user.save()
self.stdout.write( self.stdout.write(self.style.SUCCESS(f"Created test user: {user.get_username()}"))
self.style.SUCCESS(f"Created test user: {user.get_username()}")
)
else: else:
self.stdout.write(self.style.WARNING("Test user already exists")) self.stdout.write(self.style.WARNING("Test user already exists"))
@@ -47,11 +45,7 @@ class Command(BaseCommand):
# Add user to moderator group # Add user to moderator group
moderator.groups.add(moderator_group) moderator.groups.add(moderator_group)
self.stdout.write( self.stdout.write(self.style.SUCCESS(f"Created moderator user: {moderator.get_username()}"))
self.style.SUCCESS(
f"Created moderator user: {moderator.get_username()}"
)
)
else: else:
self.stdout.write(self.style.WARNING("Moderator user already exists")) self.stdout.write(self.style.WARNING("Moderator user already exists"))

View File

@@ -17,9 +17,7 @@ class Command(BaseCommand):
help = "Delete a user while preserving all their submissions" help = "Delete a user while preserving all their submissions"
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument( parser.add_argument("username", nargs="?", type=str, help="Username of the user to delete")
"username", nargs="?", type=str, help="Username of the user to delete"
)
parser.add_argument( parser.add_argument(
"--user-id", "--user-id",
type=str, type=str,
@@ -30,9 +28,7 @@ class Command(BaseCommand):
action="store_true", action="store_true",
help="Show what would be deleted without actually deleting", help="Show what would be deleted without actually deleting",
) )
parser.add_argument( parser.add_argument("--force", action="store_true", help="Skip confirmation prompt")
"--force", action="store_true", help="Skip confirmation prompt"
)
def handle(self, *args, **options): def handle(self, *args, **options):
username = options.get("username") username = options.get("username")
@@ -52,7 +48,7 @@ class Command(BaseCommand):
user = User.objects.get(username=username) if username else User.objects.get(user_id=user_id) user = User.objects.get(username=username) if username else User.objects.get(user_id=user_id)
except User.DoesNotExist: except User.DoesNotExist:
identifier = username or user_id identifier = username or user_id
raise CommandError(f'User "{identifier}" does not exist') raise CommandError(f'User "{identifier}" does not exist') from None
# Check if user can be deleted # Check if user can be deleted
can_delete, reason = UserDeletionService.can_delete_user(user) can_delete, reason = UserDeletionService.can_delete_user(user)
@@ -61,27 +57,13 @@ class Command(BaseCommand):
# Count submissions # Count submissions
submission_counts = { submission_counts = {
"park_reviews": getattr( "park_reviews": getattr(user, "park_reviews", user.__class__.objects.none()).count(),
user, "park_reviews", user.__class__.objects.none() "ride_reviews": getattr(user, "ride_reviews", user.__class__.objects.none()).count(),
).count(), "uploaded_park_photos": getattr(user, "uploaded_park_photos", user.__class__.objects.none()).count(),
"ride_reviews": getattr( "uploaded_ride_photos": getattr(user, "uploaded_ride_photos", user.__class__.objects.none()).count(),
user, "ride_reviews", user.__class__.objects.none() "top_lists": getattr(user, "top_lists", user.__class__.objects.none()).count(),
).count(), "edit_submissions": getattr(user, "edit_submissions", user.__class__.objects.none()).count(),
"uploaded_park_photos": getattr( "photo_submissions": getattr(user, "photo_submissions", user.__class__.objects.none()).count(),
user, "uploaded_park_photos", user.__class__.objects.none()
).count(),
"uploaded_ride_photos": getattr(
user, "uploaded_ride_photos", user.__class__.objects.none()
).count(),
"top_lists": getattr(
user, "top_lists", user.__class__.objects.none()
).count(),
"edit_submissions": getattr(
user, "edit_submissions", user.__class__.objects.none()
).count(),
"photo_submissions": getattr(
user, "photo_submissions", user.__class__.objects.none()
).count(),
} }
total_submissions = sum(submission_counts.values()) total_submissions = sum(submission_counts.values())
@@ -98,9 +80,7 @@ class Command(BaseCommand):
self.stdout.write(self.style.WARNING("\nSubmissions to preserve:")) self.stdout.write(self.style.WARNING("\nSubmissions to preserve:"))
for submission_type, count in submission_counts.items(): for submission_type, count in submission_counts.items():
if count > 0: if count > 0:
self.stdout.write( self.stdout.write(f' {submission_type.replace("_", " ").title()}: {count}')
f' {submission_type.replace("_", " ").title()}: {count}'
)
self.stdout.write(f"\nTotal submissions: {total_submissions}") self.stdout.write(f"\nTotal submissions: {total_submissions}")
@@ -111,9 +91,7 @@ class Command(BaseCommand):
) )
) )
else: else:
self.stdout.write( self.stdout.write(self.style.WARNING("\nNo submissions found for this user."))
self.style.WARNING("\nNo submissions found for this user.")
)
if dry_run: if dry_run:
self.stdout.write(self.style.SUCCESS("\n[DRY RUN] No changes were made.")) self.stdout.write(self.style.SUCCESS("\n[DRY RUN] No changes were made."))
@@ -136,11 +114,7 @@ class Command(BaseCommand):
try: try:
result = UserDeletionService.delete_user_preserve_submissions(user) result = UserDeletionService.delete_user_preserve_submissions(user)
self.stdout.write( self.stdout.write(self.style.SUCCESS(f'\nSuccessfully deleted user "{result["deleted_user"]["username"]}"'))
self.style.SUCCESS(
f'\nSuccessfully deleted user "{result["deleted_user"]["username"]}"'
)
)
preserved_count = sum(result["preserved_submissions"].values()) preserved_count = sum(result["preserved_submissions"].values())
if preserved_count > 0: if preserved_count > 0:
@@ -154,9 +128,7 @@ class Command(BaseCommand):
self.stdout.write(self.style.WARNING("\nPreservation Summary:")) self.stdout.write(self.style.WARNING("\nPreservation Summary:"))
for submission_type, count in result["preserved_submissions"].items(): for submission_type, count in result["preserved_submissions"].items():
if count > 0: if count > 0:
self.stdout.write( self.stdout.write(f' {submission_type.replace("_", " ").title()}: {count}')
f' {submission_type.replace("_", " ").title()}: {count}'
)
except Exception as e: except Exception as e:
raise CommandError(f"Error deleting user: {str(e)}") raise CommandError(f"Error deleting user: {str(e)}") from None

View File

@@ -7,12 +7,5 @@ class Command(BaseCommand):
def handle(self, *args, **kwargs): def handle(self, *args, **kwargs):
with connection.cursor() as cursor: with connection.cursor() as cursor:
cursor.execute( cursor.execute("DELETE FROM django_migrations WHERE app='rides' " "AND name='0001_initial';")
"DELETE FROM django_migrations WHERE app='rides' " self.stdout.write(self.style.SUCCESS("Successfully removed rides.0001_initial from migration history"))
"AND name='0001_initial';"
)
self.stdout.write(
self.style.SUCCESS(
"Successfully removed rides.0001_initial from migration history"
)
)

View File

@@ -34,6 +34,4 @@ class Command(BaseCommand):
secret=os.getenv("DISCORD_CLIENT_SECRET"), secret=os.getenv("DISCORD_CLIENT_SECRET"),
) )
discord_app.sites.add(site) discord_app.sites.add(site)
self.stdout.write( self.stdout.write(f"Created Discord app with client_id: {discord_app.client_id}")
f"Created Discord app with client_id: {discord_app.client_id}"
)

View File

@@ -47,9 +47,7 @@ class Command(BaseCommand):
help = "Generate avatars for letters A-Z and numbers 0-9" help = "Generate avatars for letters A-Z and numbers 0-9"
def handle(self, *args, **kwargs): def handle(self, *args, **kwargs):
characters = [chr(i) for i in range(65, 91)] + [ characters = [chr(i) for i in range(65, 91)] + [str(i) for i in range(10)] # A-Z and 0-9
str(i) for i in range(10)
] # A-Z and 0-9
for char in characters: for char in characters:
generate_avatar(char) generate_avatar(char)
self.stdout.write(self.style.SUCCESS(f"Generated avatar for {char}")) self.stdout.write(self.style.SUCCESS(f"Generated avatar for {char}"))

View File

@@ -11,6 +11,4 @@ class Command(BaseCommand):
for profile in profiles: for profile in profiles:
# This will trigger the avatar generation logic in the save method # This will trigger the avatar generation logic in the save method
profile.save() profile.save()
self.stdout.write( self.stdout.write(self.style.SUCCESS(f"Regenerated avatar for {profile.user.username}"))
self.style.SUCCESS(f"Regenerated avatar for {profile.user.username}")
)

View File

@@ -69,18 +69,18 @@ class Command(BaseCommand):
# Security: Using Django ORM instead of raw SQL for user creation # Security: Using Django ORM instead of raw SQL for user creation
user = User.objects.create_superuser( user = User.objects.create_superuser(
username='admin', username="admin",
email='admin@thrillwiki.com', email="admin@thrillwiki.com",
password='admin', password="admin",
role='SUPERUSER', role="SUPERUSER",
) )
# Create profile using ORM # Create profile using ORM
UserProfile.objects.create( UserProfile.objects.create(
user=user, user=user,
display_name='Admin', display_name="Admin",
pronouns='they/them', pronouns="they/them",
bio='ThrillWiki Administrator', bio="ThrillWiki Administrator",
) )
self.stdout.write("Superuser created.") self.stdout.write("Superuser created.")

View File

@@ -30,9 +30,7 @@ class Command(BaseCommand):
google_app = SocialApp.objects.create( google_app = SocialApp.objects.create(
provider="google", provider="google",
name="Google", name="Google",
client_id=( client_id=("135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2.apps.googleusercontent.com"),
"135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2.apps.googleusercontent.com"
),
secret="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", secret="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",
) )
google_app.sites.add(site) google_app.sites.add(site)

View File

@@ -12,13 +12,7 @@ class Command(BaseCommand):
cursor.execute("DELETE FROM socialaccount_socialapp_sites") cursor.execute("DELETE FROM socialaccount_socialapp_sites")
# Reset sequences # Reset sequences
cursor.execute( cursor.execute("DELETE FROM sqlite_sequence WHERE name='socialaccount_socialapp'")
"DELETE FROM sqlite_sequence WHERE name='socialaccount_socialapp'" cursor.execute("DELETE FROM sqlite_sequence WHERE name='socialaccount_socialapp_sites'")
)
cursor.execute(
"DELETE FROM sqlite_sequence WHERE name='socialaccount_socialapp_sites'"
)
self.stdout.write( self.stdout.write(self.style.SUCCESS("Successfully reset social auth configuration"))
self.style.SUCCESS("Successfully reset social auth configuration")
)

View File

@@ -30,9 +30,7 @@ class Command(BaseCommand):
user.is_staff = True user.is_staff = True
user.save() user.save()
self.stdout.write( self.stdout.write(self.style.SUCCESS("Successfully set up groups and permissions"))
self.style.SUCCESS("Successfully set up groups and permissions")
)
# Print summary # Print summary
for group in Group.objects.all(): for group in Group.objects.all():

View File

@@ -10,7 +10,5 @@ class Command(BaseCommand):
Site.objects.all().delete() Site.objects.all().delete()
# Create default site # Create default site
site = Site.objects.create( site = Site.objects.create(id=1, domain="localhost:8000", name="ThrillWiki Development")
id=1, domain="localhost:8000", name="ThrillWiki Development"
)
self.stdout.write(self.style.SUCCESS(f"Created site: {site.domain}")) self.stdout.write(self.style.SUCCESS(f"Created site: {site.domain}"))

View File

@@ -49,27 +49,15 @@ class Command(BaseCommand):
discord_client_secret, discord_client_secret,
] ]
): ):
self.stdout.write( self.stdout.write(self.style.ERROR("Missing required environment variables"))
self.style.ERROR("Missing required environment variables") self.stdout.write(f"DEBUG: google_client_id is None: {google_client_id is None}")
) self.stdout.write(f"DEBUG: google_client_secret is None: {google_client_secret is None}")
self.stdout.write( self.stdout.write(f"DEBUG: discord_client_id is None: {discord_client_id is None}")
f"DEBUG: google_client_id is None: {google_client_id is None}" self.stdout.write(f"DEBUG: discord_client_secret is None: {discord_client_secret is None}")
)
self.stdout.write(
f"DEBUG: google_client_secret is None: {google_client_secret is None}"
)
self.stdout.write(
f"DEBUG: discord_client_id is None: {discord_client_id is None}"
)
self.stdout.write(
f"DEBUG: discord_client_secret is None: {discord_client_secret is None}"
)
return return
# Get or create the default site # Get or create the default site
site, _ = Site.objects.get_or_create( site, _ = Site.objects.get_or_create(id=1, defaults={"domain": "localhost:8000", "name": "localhost"})
id=1, defaults={"domain": "localhost:8000", "name": "localhost"}
)
# Set up Google # Set up Google
google_app, created = SocialApp.objects.get_or_create( google_app, created = SocialApp.objects.get_or_create(
@@ -92,11 +80,7 @@ class Command(BaseCommand):
google_app.save() google_app.save()
self.stdout.write("DEBUG: Successfully updated Google app") self.stdout.write("DEBUG: Successfully updated Google app")
else: else:
self.stdout.write( self.stdout.write(self.style.ERROR("Google client_id or secret is None, skipping update."))
self.style.ERROR(
"Google client_id or secret is None, skipping update."
)
)
google_app.sites.add(site) google_app.sites.add(site)
# Set up Discord # Set up Discord
@@ -120,11 +104,7 @@ class Command(BaseCommand):
discord_app.save() discord_app.save()
self.stdout.write("DEBUG: Successfully updated Discord app") self.stdout.write("DEBUG: Successfully updated Discord app")
else: else:
self.stdout.write( self.stdout.write(self.style.ERROR("Discord client_id or secret is None, skipping update."))
self.style.ERROR(
"Discord client_id or secret is None, skipping update."
)
)
discord_app.sites.add(site) discord_app.sites.add(site)
self.stdout.write(self.style.SUCCESS("Successfully set up social auth apps")) self.stdout.write(self.style.SUCCESS("Successfully set up social auth apps"))

View File

@@ -42,6 +42,4 @@ class Command(BaseCommand):
for app in SocialApp.objects.all(): for app in SocialApp.objects.all():
self.stdout.write(f"- {app.name} ({app.provider}): {app.client_id}") self.stdout.write(f"- {app.name} ({app.provider}): {app.client_id}")
self.stdout.write( self.stdout.write(self.style.SUCCESS(f"\nTotal social apps: {SocialApp.objects.count()}"))
self.style.SUCCESS(f"\nTotal social apps: {SocialApp.objects.count()}")
)

View File

@@ -40,9 +40,7 @@ class Command(BaseCommand):
# Show callback URL # Show callback URL
callback_url = "http://localhost:8000/accounts/discord/login/callback/" callback_url = "http://localhost:8000/accounts/discord/login/callback/"
self.stdout.write( self.stdout.write("\nCallback URL to configure in Discord Developer Portal:")
"\nCallback URL to configure in Discord Developer Portal:"
)
self.stdout.write(callback_url) self.stdout.write(callback_url)
# Show frontend login URL # Show frontend login URL

View File

@@ -18,6 +18,4 @@ class Command(BaseCommand):
# Add all sites # Add all sites
for site in sites: for site in sites:
app.sites.add(site) app.sites.add(site)
self.stdout.write( self.stdout.write(f"Added sites: {', '.join(site.domain for site in sites)}")
f"Added sites: {', '.join(site.domain for site in sites)}"
)

View File

@@ -22,17 +22,13 @@ class Command(BaseCommand):
# Show callback URL # Show callback URL
callback_url = "http://localhost:8000/accounts/discord/login/callback/" callback_url = "http://localhost:8000/accounts/discord/login/callback/"
self.stdout.write( self.stdout.write("\nCallback URL to configure in Discord Developer Portal:")
"\nCallback URL to configure in Discord Developer Portal:"
)
self.stdout.write(callback_url) self.stdout.write(callback_url)
# Show OAuth2 settings # Show OAuth2 settings
self.stdout.write("\nOAuth2 settings in settings.py:") self.stdout.write("\nOAuth2 settings in settings.py:")
discord_settings = settings.SOCIALACCOUNT_PROVIDERS.get("discord", {}) discord_settings = settings.SOCIALACCOUNT_PROVIDERS.get("discord", {})
self.stdout.write( self.stdout.write(f"PKCE Enabled: {discord_settings.get('OAUTH_PKCE_ENABLED', False)}")
f"PKCE Enabled: {discord_settings.get('OAUTH_PKCE_ENABLED', False)}"
)
self.stdout.write(f"Scopes: {discord_settings.get('SCOPE', [])}") self.stdout.write(f"Scopes: {discord_settings.get('SCOPE', [])}")
except SocialApp.DoesNotExist: except SocialApp.DoesNotExist:

View File

@@ -38,9 +38,7 @@ class Migration(migrations.Migration):
), ),
( (
"last_login", "last_login",
models.DateTimeField( models.DateTimeField(blank=True, null=True, verbose_name="last login"),
blank=True, null=True, verbose_name="last login"
),
), ),
( (
"is_superuser", "is_superuser",
@@ -53,29 +51,21 @@ class Migration(migrations.Migration):
( (
"username", "username",
models.CharField( models.CharField(
error_messages={ error_messages={"unique": "A user with that username already exists."},
"unique": "A user with that username already exists."
},
help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.", help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
max_length=150, max_length=150,
unique=True, unique=True,
validators=[ validators=[django.contrib.auth.validators.UnicodeUsernameValidator()],
django.contrib.auth.validators.UnicodeUsernameValidator()
],
verbose_name="username", verbose_name="username",
), ),
), ),
( (
"first_name", "first_name",
models.CharField( models.CharField(blank=True, max_length=150, verbose_name="first name"),
blank=True, max_length=150, verbose_name="first name"
),
), ),
( (
"last_name", "last_name",
models.CharField( models.CharField(blank=True, max_length=150, verbose_name="last name"),
blank=True, max_length=150, verbose_name="last name"
),
), ),
( (
"email", "email",

View File

@@ -57,9 +57,7 @@ class Migration(migrations.Migration):
("password", models.CharField(max_length=128, verbose_name="password")), ("password", models.CharField(max_length=128, verbose_name="password")),
( (
"last_login", "last_login",
models.DateTimeField( models.DateTimeField(blank=True, null=True, verbose_name="last login"),
blank=True, null=True, verbose_name="last login"
),
), ),
( (
"is_superuser", "is_superuser",
@@ -72,34 +70,24 @@ class Migration(migrations.Migration):
( (
"username", "username",
models.CharField( models.CharField(
error_messages={ error_messages={"unique": "A user with that username already exists."},
"unique": "A user with that username already exists."
},
help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.", help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
max_length=150, max_length=150,
validators=[ validators=[django.contrib.auth.validators.UnicodeUsernameValidator()],
django.contrib.auth.validators.UnicodeUsernameValidator()
],
verbose_name="username", verbose_name="username",
), ),
), ),
( (
"first_name", "first_name",
models.CharField( models.CharField(blank=True, max_length=150, verbose_name="first name"),
blank=True, max_length=150, verbose_name="first name"
),
), ),
( (
"last_name", "last_name",
models.CharField( models.CharField(blank=True, max_length=150, verbose_name="last name"),
blank=True, max_length=150, verbose_name="last name"
),
), ),
( (
"email", "email",
models.EmailField( models.EmailField(blank=True, max_length=254, verbose_name="email address"),
blank=True, max_length=254, verbose_name="email address"
),
), ),
( (
"is_staff", "is_staff",
@@ -119,9 +107,7 @@ class Migration(migrations.Migration):
), ),
( (
"date_joined", "date_joined",
models.DateTimeField( models.DateTimeField(default=django.utils.timezone.now, verbose_name="date joined"),
default=django.utils.timezone.now, verbose_name="date joined"
),
), ),
( (
"user_id", "user_id",

View File

@@ -41,9 +41,7 @@ class Migration(migrations.Migration):
("created_at", models.DateTimeField(auto_now_add=True)), ("created_at", models.DateTimeField(auto_now_add=True)),
( (
"expires_at", "expires_at",
models.DateTimeField( models.DateTimeField(help_text="When this deletion request expires"),
help_text="When this deletion request expires"
),
), ),
( (
"email_sent_at", "email_sent_at",
@@ -55,9 +53,7 @@ class Migration(migrations.Migration):
), ),
( (
"attempts", "attempts",
models.PositiveIntegerField( models.PositiveIntegerField(default=0, help_text="Number of verification attempts made"),
default=0, help_text="Number of verification attempts made"
),
), ),
( (
"max_attempts", "max_attempts",
@@ -103,9 +99,7 @@ class Migration(migrations.Migration):
("created_at", models.DateTimeField(auto_now_add=True)), ("created_at", models.DateTimeField(auto_now_add=True)),
( (
"expires_at", "expires_at",
models.DateTimeField( models.DateTimeField(help_text="When this deletion request expires"),
help_text="When this deletion request expires"
),
), ),
( (
"email_sent_at", "email_sent_at",
@@ -117,9 +111,7 @@ class Migration(migrations.Migration):
), ),
( (
"attempts", "attempts",
models.PositiveIntegerField( models.PositiveIntegerField(default=0, help_text="Number of verification attempts made"),
default=0, help_text="Number of verification attempts made"
),
), ),
( (
"max_attempts", "max_attempts",
@@ -171,21 +163,15 @@ class Migration(migrations.Migration):
), ),
migrations.AddIndex( migrations.AddIndex(
model_name="userdeletionrequest", model_name="userdeletionrequest",
index=models.Index( index=models.Index(fields=["verification_code"], name="accounts_us_verific_94460d_idx"),
fields=["verification_code"], name="accounts_us_verific_94460d_idx"
),
), ),
migrations.AddIndex( migrations.AddIndex(
model_name="userdeletionrequest", model_name="userdeletionrequest",
index=models.Index( index=models.Index(fields=["expires_at"], name="accounts_us_expires_1d1dca_idx"),
fields=["expires_at"], name="accounts_us_expires_1d1dca_idx"
),
), ),
migrations.AddIndex( migrations.AddIndex(
model_name="userdeletionrequest", model_name="userdeletionrequest",
index=models.Index( index=models.Index(fields=["user", "is_used"], name="accounts_us_user_id_1ce18a_idx"),
fields=["user", "is_used"], name="accounts_us_user_id_1ce18a_idx"
),
), ),
pgtrigger.migrations.AddTrigger( pgtrigger.migrations.AddTrigger(
model_name="userdeletionrequest", model_name="userdeletionrequest",

View File

@@ -57,9 +57,7 @@ class Migration(migrations.Migration):
migrations.AddField( migrations.AddField(
model_name="user", model_name="user",
name="last_password_change", name="last_password_change",
field=models.DateTimeField( field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now),
auto_now_add=True, default=django.utils.timezone.now
),
preserve_default=False, preserve_default=False,
), ),
migrations.AddField( migrations.AddField(
@@ -185,9 +183,7 @@ class Migration(migrations.Migration):
migrations.AddField( migrations.AddField(
model_name="userevent", model_name="userevent",
name="last_password_change", name="last_password_change",
field=models.DateTimeField( field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now),
auto_now_add=True, default=django.utils.timezone.now
),
preserve_default=False, preserve_default=False,
), ),
migrations.AddField( migrations.AddField(

View File

@@ -454,9 +454,7 @@ class Migration(migrations.Migration):
), ),
migrations.AddIndex( migrations.AddIndex(
model_name="usernotification", model_name="usernotification",
index=models.Index( index=models.Index(fields=["user", "is_read"], name="accounts_us_user_id_785929_idx"),
fields=["user", "is_read"], name="accounts_us_user_id_785929_idx"
),
), ),
migrations.AddIndex( migrations.AddIndex(
model_name="usernotification", model_name="usernotification",
@@ -467,15 +465,11 @@ class Migration(migrations.Migration):
), ),
migrations.AddIndex( migrations.AddIndex(
model_name="usernotification", model_name="usernotification",
index=models.Index( index=models.Index(fields=["created_at"], name="accounts_us_created_a62f54_idx"),
fields=["created_at"], name="accounts_us_created_a62f54_idx"
),
), ),
migrations.AddIndex( migrations.AddIndex(
model_name="usernotification", model_name="usernotification",
index=models.Index( index=models.Index(fields=["expires_at"], name="accounts_us_expires_f267b1_idx"),
fields=["expires_at"], name="accounts_us_expires_f267b1_idx"
),
), ),
pgtrigger.migrations.AddTrigger( pgtrigger.migrations.AddTrigger(
model_name="usernotification", model_name="usernotification",

View File

@@ -26,25 +26,24 @@ def safe_add_avatar_field(apps, schema_editor):
""" """
# Check if the column already exists # Check if the column already exists
with schema_editor.connection.cursor() as cursor: with schema_editor.connection.cursor() as cursor:
cursor.execute(""" cursor.execute(
"""
SELECT column_name SELECT column_name
FROM information_schema.columns FROM information_schema.columns
WHERE table_name='accounts_userprofile' WHERE table_name='accounts_userprofile'
AND column_name='avatar_id' AND column_name='avatar_id'
""") """
)
column_exists = cursor.fetchone() is not None column_exists = cursor.fetchone() is not None
if not column_exists: if not column_exists:
# Column doesn't exist, add it # Column doesn't exist, add it
UserProfile = apps.get_model('accounts', 'UserProfile') UserProfile = apps.get_model("accounts", "UserProfile")
field = models.ForeignKey( field = models.ForeignKey(
'django_cloudflareimages_toolkit.CloudflareImage', "django_cloudflareimages_toolkit.CloudflareImage", on_delete=models.SET_NULL, null=True, blank=True
on_delete=models.SET_NULL,
null=True,
blank=True
) )
field.set_attributes_from_name('avatar') field.set_attributes_from_name("avatar")
schema_editor.add_field(UserProfile, field) schema_editor.add_field(UserProfile, field)
@@ -54,24 +53,23 @@ def reverse_safe_add_avatar_field(apps, schema_editor):
""" """
# Check if the column exists and remove it # Check if the column exists and remove it
with schema_editor.connection.cursor() as cursor: with schema_editor.connection.cursor() as cursor:
cursor.execute(""" cursor.execute(
"""
SELECT column_name SELECT column_name
FROM information_schema.columns FROM information_schema.columns
WHERE table_name='accounts_userprofile' WHERE table_name='accounts_userprofile'
AND column_name='avatar_id' AND column_name='avatar_id'
""") """
)
column_exists = cursor.fetchone() is not None column_exists = cursor.fetchone() is not None
if column_exists: if column_exists:
UserProfile = apps.get_model('accounts', 'UserProfile') UserProfile = apps.get_model("accounts", "UserProfile")
field = models.ForeignKey( field = models.ForeignKey(
'django_cloudflareimages_toolkit.CloudflareImage', "django_cloudflareimages_toolkit.CloudflareImage", on_delete=models.SET_NULL, null=True, blank=True
on_delete=models.SET_NULL,
null=True,
blank=True
) )
field.set_attributes_from_name('avatar') field.set_attributes_from_name("avatar")
schema_editor.remove_field(UserProfile, field) schema_editor.remove_field(UserProfile, field)
@@ -89,15 +87,13 @@ class Migration(migrations.Migration):
# First, remove the old avatar column (CloudflareImageField) # First, remove the old avatar column (CloudflareImageField)
migrations.RunSQL( migrations.RunSQL(
"ALTER TABLE accounts_userprofile DROP COLUMN IF EXISTS avatar;", "ALTER TABLE accounts_userprofile DROP COLUMN IF EXISTS avatar;",
reverse_sql="-- Cannot reverse this operation" reverse_sql="-- Cannot reverse this operation",
), ),
# Safely add the new avatar_id column for ForeignKey # Safely add the new avatar_id column for ForeignKey
migrations.RunPython( migrations.RunPython(
safe_add_avatar_field, safe_add_avatar_field,
reverse_safe_add_avatar_field, reverse_safe_add_avatar_field,
), ),
# Run the data migration # Run the data migration
migrations.RunPython( migrations.RunPython(
migrate_avatar_data, migrate_avatar_data,

View File

@@ -6,17 +6,16 @@ from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('accounts', '0010_auto_20250830_1657'), ("accounts", "0010_auto_20250830_1657"),
('django_cloudflareimages_toolkit', '0001_initial'), ("django_cloudflareimages_toolkit", "0001_initial"),
] ]
operations = [ operations = [
# Remove the old avatar field from the event table # Remove the old avatar field from the event table
migrations.RunSQL( migrations.RunSQL(
"ALTER TABLE accounts_userprofileevent DROP COLUMN IF EXISTS avatar;", "ALTER TABLE accounts_userprofileevent DROP COLUMN IF EXISTS avatar;",
reverse_sql="-- Cannot reverse this operation" reverse_sql="-- Cannot reverse this operation",
), ),
# Add the new avatar_id field to match the main table (only if it doesn't exist) # Add the new avatar_id field to match the main table (only if it doesn't exist)
migrations.RunSQL( migrations.RunSQL(
""" """
@@ -32,6 +31,6 @@ class Migration(migrations.Migration):
END IF; END IF;
END $$; END $$;
""", """,
reverse_sql="ALTER TABLE accounts_userprofileevent DROP COLUMN IF EXISTS avatar_id;" reverse_sql="ALTER TABLE accounts_userprofileevent DROP COLUMN IF EXISTS avatar_id;",
), ),
] ]

View File

@@ -13,28 +13,28 @@ from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('accounts', '0012_alter_toplist_category_and_more'), ("accounts", "0012_alter_toplist_category_and_more"),
] ]
operations = [ operations = [
# Add db_index to is_banned field # Add db_index to is_banned field
migrations.AlterField( migrations.AlterField(
model_name='user', model_name="user",
name='is_banned', name="is_banned",
field=models.BooleanField(default=False, db_index=True), field=models.BooleanField(default=False, db_index=True),
), ),
# Add composite index for common query patterns # Add composite index for common query patterns
migrations.AddIndex( migrations.AddIndex(
model_name='user', model_name="user",
index=models.Index(fields=['is_banned', 'role'], name='accounts_user_banned_role_idx'), index=models.Index(fields=["is_banned", "role"], name="accounts_user_banned_role_idx"),
), ),
# Add CheckConstraint for ban consistency # Add CheckConstraint for ban consistency
migrations.AddConstraint( migrations.AddConstraint(
model_name='user', model_name="user",
constraint=models.CheckConstraint( constraint=models.CheckConstraint(
name='user_ban_consistency', name="user_ban_consistency",
check=models.Q(is_banned=False) | models.Q(ban_date__isnull=False), check=models.Q(is_banned=False) | models.Q(ban_date__isnull=False),
violation_error_message='Banned users must have a ban_date set' violation_error_message="Banned users must have a ban_date set",
), ),
), ),
] ]

View File

@@ -18,7 +18,6 @@ class Migration(migrations.Migration):
] ]
operations = [ operations = [
migrations.AlterModelOptions( migrations.AlterModelOptions(
name="user", name="user",
options={"verbose_name": "User", "verbose_name_plural": "Users"}, options={"verbose_name": "User", "verbose_name_plural": "Users"},
@@ -58,9 +57,7 @@ class Migration(migrations.Migration):
migrations.AddField( migrations.AddField(
model_name="userprofile", model_name="userprofile",
name="location", name="location",
field=models.CharField( field=models.CharField(blank=True, help_text="User's location (City, Country)", max_length=100),
blank=True, help_text="User's location (City, Country)", max_length=100
),
), ),
migrations.AddField( migrations.AddField(
model_name="userprofile", model_name="userprofile",
@@ -78,9 +75,7 @@ class Migration(migrations.Migration):
migrations.AddField( migrations.AddField(
model_name="userprofileevent", model_name="userprofileevent",
name="location", name="location",
field=models.CharField( field=models.CharField(blank=True, help_text="User's location (City, Country)", max_length=100),
blank=True, help_text="User's location (City, Country)", max_length=100
),
), ),
migrations.AddField( migrations.AddField(
model_name="userprofileevent", model_name="userprofileevent",
@@ -98,23 +93,17 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="emailverification", model_name="emailverification",
name="created_at", name="created_at",
field=models.DateTimeField( field=models.DateTimeField(auto_now_add=True, help_text="When this verification was created"),
auto_now_add=True, help_text="When this verification was created"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="emailverification", model_name="emailverification",
name="last_sent", name="last_sent",
field=models.DateTimeField( field=models.DateTimeField(auto_now_add=True, help_text="When the verification email was last sent"),
auto_now_add=True, help_text="When the verification email was last sent"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="emailverification", model_name="emailverification",
name="token", name="token",
field=models.CharField( field=models.CharField(help_text="Verification token", max_length=64, unique=True),
help_text="Verification token", max_length=64, unique=True
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="emailverification", model_name="emailverification",
@@ -128,16 +117,12 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="emailverificationevent", model_name="emailverificationevent",
name="created_at", name="created_at",
field=models.DateTimeField( field=models.DateTimeField(auto_now_add=True, help_text="When this verification was created"),
auto_now_add=True, help_text="When this verification was created"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="emailverificationevent", model_name="emailverificationevent",
name="last_sent", name="last_sent",
field=models.DateTimeField( field=models.DateTimeField(auto_now_add=True, help_text="When the verification email was last sent"),
auto_now_add=True, help_text="When the verification email was last sent"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="emailverificationevent", model_name="emailverificationevent",
@@ -181,9 +166,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="passwordreset", model_name="passwordreset",
name="created_at", name="created_at",
field=models.DateTimeField( field=models.DateTimeField(auto_now_add=True, help_text="When this reset was requested"),
auto_now_add=True, help_text="When this reset was requested"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="passwordreset", model_name="passwordreset",
@@ -198,9 +181,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="passwordreset", model_name="passwordreset",
name="used", name="used",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether this token has been used"),
default=False, help_text="Whether this token has been used"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="passwordreset", model_name="passwordreset",
@@ -214,9 +195,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="passwordresetevent", model_name="passwordresetevent",
name="created_at", name="created_at",
field=models.DateTimeField( field=models.DateTimeField(auto_now_add=True, help_text="When this reset was requested"),
auto_now_add=True, help_text="When this reset was requested"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="passwordresetevent", model_name="passwordresetevent",
@@ -231,9 +210,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="passwordresetevent", model_name="passwordresetevent",
name="used", name="used",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether this token has been used"),
default=False, help_text="Whether this token has been used"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="passwordresetevent", model_name="passwordresetevent",
@@ -267,30 +244,22 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="allow_friend_requests", name="allow_friend_requests",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to allow friend requests"),
default=True, help_text="Whether to allow friend requests"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="allow_messages", name="allow_messages",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to allow direct messages"),
default=True, help_text="Whether to allow direct messages"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="allow_profile_comments", name="allow_profile_comments",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether to allow profile comments"),
default=False, help_text="Whether to allow profile comments"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="ban_date", name="ban_date",
field=models.DateTimeField( field=models.DateTimeField(blank=True, help_text="Date the user was banned", null=True),
blank=True, help_text="Date the user was banned", null=True
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
@@ -300,37 +269,27 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="email_notifications", name="email_notifications",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to send email notifications"),
default=True, help_text="Whether to send email notifications"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="is_banned", name="is_banned",
field=models.BooleanField( field=models.BooleanField(db_index=True, default=False, help_text="Whether this user is banned"),
db_index=True, default=False, help_text="Whether this user is banned"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="last_password_change", name="last_password_change",
field=models.DateTimeField( field=models.DateTimeField(auto_now_add=True, help_text="When the password was last changed"),
auto_now_add=True, help_text="When the password was last changed"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="login_history_retention", name="login_history_retention",
field=models.IntegerField( field=models.IntegerField(default=90, help_text="How long to retain login history (days)"),
default=90, help_text="How long to retain login history (days)"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="login_notifications", name="login_notifications",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to send login notifications"),
default=True, help_text="Whether to send login notifications"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
@@ -352,9 +311,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="push_notifications", name="push_notifications",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether to send push notifications"),
default=False, help_text="Whether to send push notifications"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
@@ -378,9 +335,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="search_visibility", name="search_visibility",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether profile appears in search results"),
default=True, help_text="Whether profile appears in search results"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
@@ -390,51 +345,37 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="show_email", name="show_email",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether to show email on profile"),
default=False, help_text="Whether to show email on profile"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="show_join_date", name="show_join_date",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to show join date on profile"),
default=True, help_text="Whether to show join date on profile"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="show_photos", name="show_photos",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to show photos on profile"),
default=True, help_text="Whether to show photos on profile"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="show_real_name", name="show_real_name",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to show real name on profile"),
default=True, help_text="Whether to show real name on profile"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="show_reviews", name="show_reviews",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to show reviews on profile"),
default=True, help_text="Whether to show reviews on profile"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="show_statistics", name="show_statistics",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to show statistics on profile"),
default=True, help_text="Whether to show statistics on profile"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="show_top_lists", name="show_top_lists",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to show top lists on profile"),
default=True, help_text="Whether to show top lists on profile"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
@@ -452,9 +393,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="user", model_name="user",
name="two_factor_enabled", name="two_factor_enabled",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether two-factor authentication is enabled"),
default=False, help_text="Whether two-factor authentication is enabled"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
@@ -476,30 +415,22 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="allow_friend_requests", name="allow_friend_requests",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to allow friend requests"),
default=True, help_text="Whether to allow friend requests"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="allow_messages", name="allow_messages",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to allow direct messages"),
default=True, help_text="Whether to allow direct messages"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="allow_profile_comments", name="allow_profile_comments",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether to allow profile comments"),
default=False, help_text="Whether to allow profile comments"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="ban_date", name="ban_date",
field=models.DateTimeField( field=models.DateTimeField(blank=True, help_text="Date the user was banned", null=True),
blank=True, help_text="Date the user was banned", null=True
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
@@ -509,37 +440,27 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="email_notifications", name="email_notifications",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to send email notifications"),
default=True, help_text="Whether to send email notifications"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="is_banned", name="is_banned",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether this user is banned"),
default=False, help_text="Whether this user is banned"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="last_password_change", name="last_password_change",
field=models.DateTimeField( field=models.DateTimeField(auto_now_add=True, help_text="When the password was last changed"),
auto_now_add=True, help_text="When the password was last changed"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="login_history_retention", name="login_history_retention",
field=models.IntegerField( field=models.IntegerField(default=90, help_text="How long to retain login history (days)"),
default=90, help_text="How long to retain login history (days)"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="login_notifications", name="login_notifications",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to send login notifications"),
default=True, help_text="Whether to send login notifications"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
@@ -561,9 +482,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="push_notifications", name="push_notifications",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether to send push notifications"),
default=False, help_text="Whether to send push notifications"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
@@ -586,9 +505,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="search_visibility", name="search_visibility",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether profile appears in search results"),
default=True, help_text="Whether profile appears in search results"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
@@ -598,51 +515,37 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="show_email", name="show_email",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether to show email on profile"),
default=False, help_text="Whether to show email on profile"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="show_join_date", name="show_join_date",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to show join date on profile"),
default=True, help_text="Whether to show join date on profile"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="show_photos", name="show_photos",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to show photos on profile"),
default=True, help_text="Whether to show photos on profile"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="show_real_name", name="show_real_name",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to show real name on profile"),
default=True, help_text="Whether to show real name on profile"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="show_reviews", name="show_reviews",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to show reviews on profile"),
default=True, help_text="Whether to show reviews on profile"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="show_statistics", name="show_statistics",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to show statistics on profile"),
default=True, help_text="Whether to show statistics on profile"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="show_top_lists", name="show_top_lists",
field=models.BooleanField( field=models.BooleanField(default=True, help_text="Whether to show top lists on profile"),
default=True, help_text="Whether to show top lists on profile"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
@@ -660,9 +563,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="userevent", model_name="userevent",
name="two_factor_enabled", name="two_factor_enabled",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether two-factor authentication is enabled"),
default=False, help_text="Whether two-factor authentication is enabled"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotification", model_name="usernotification",
@@ -678,23 +579,17 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="usernotification", model_name="usernotification",
name="email_sent", name="email_sent",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether email was sent"),
default=False, help_text="Whether email was sent"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotification", model_name="usernotification",
name="email_sent_at", name="email_sent_at",
field=models.DateTimeField( field=models.DateTimeField(blank=True, help_text="When email was sent", null=True),
blank=True, help_text="When email was sent", null=True
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotification", model_name="usernotification",
name="is_read", name="is_read",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether this notification has been read"),
default=False, help_text="Whether this notification has been read"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotification", model_name="usernotification",
@@ -704,30 +599,22 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="usernotification", model_name="usernotification",
name="object_id", name="object_id",
field=models.PositiveIntegerField( field=models.PositiveIntegerField(blank=True, help_text="ID of related object", null=True),
blank=True, help_text="ID of related object", null=True
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotification", model_name="usernotification",
name="push_sent", name="push_sent",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether push notification was sent"),
default=False, help_text="Whether push notification was sent"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotification", model_name="usernotification",
name="push_sent_at", name="push_sent_at",
field=models.DateTimeField( field=models.DateTimeField(blank=True, help_text="When push notification was sent", null=True),
blank=True, help_text="When push notification was sent", null=True
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotification", model_name="usernotification",
name="read_at", name="read_at",
field=models.DateTimeField( field=models.DateTimeField(blank=True, help_text="When this notification was read", null=True),
blank=True, help_text="When this notification was read", null=True
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotification", model_name="usernotification",
@@ -761,23 +648,17 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="usernotificationevent", model_name="usernotificationevent",
name="email_sent", name="email_sent",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether email was sent"),
default=False, help_text="Whether email was sent"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotificationevent", model_name="usernotificationevent",
name="email_sent_at", name="email_sent_at",
field=models.DateTimeField( field=models.DateTimeField(blank=True, help_text="When email was sent", null=True),
blank=True, help_text="When email was sent", null=True
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotificationevent", model_name="usernotificationevent",
name="is_read", name="is_read",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether this notification has been read"),
default=False, help_text="Whether this notification has been read"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotificationevent", model_name="usernotificationevent",
@@ -787,30 +668,22 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="usernotificationevent", model_name="usernotificationevent",
name="object_id", name="object_id",
field=models.PositiveIntegerField( field=models.PositiveIntegerField(blank=True, help_text="ID of related object", null=True),
blank=True, help_text="ID of related object", null=True
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotificationevent", model_name="usernotificationevent",
name="push_sent", name="push_sent",
field=models.BooleanField( field=models.BooleanField(default=False, help_text="Whether push notification was sent"),
default=False, help_text="Whether push notification was sent"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotificationevent", model_name="usernotificationevent",
name="push_sent_at", name="push_sent_at",
field=models.DateTimeField( field=models.DateTimeField(blank=True, help_text="When push notification was sent", null=True),
blank=True, help_text="When push notification was sent", null=True
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotificationevent", model_name="usernotificationevent",
name="read_at", name="read_at",
field=models.DateTimeField( field=models.DateTimeField(blank=True, help_text="When this notification was read", null=True),
blank=True, help_text="When this notification was read", null=True
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="usernotificationevent", model_name="usernotificationevent",
@@ -844,37 +717,27 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="userprofile", model_name="userprofile",
name="bio", name="bio",
field=models.TextField( field=models.TextField(blank=True, help_text="User biography", max_length=500),
blank=True, help_text="User biography", max_length=500
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userprofile", model_name="userprofile",
name="coaster_credits", name="coaster_credits",
field=models.IntegerField( field=models.IntegerField(default=0, help_text="Number of roller coasters ridden"),
default=0, help_text="Number of roller coasters ridden"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userprofile", model_name="userprofile",
name="dark_ride_credits", name="dark_ride_credits",
field=models.IntegerField( field=models.IntegerField(default=0, help_text="Number of dark rides ridden"),
default=0, help_text="Number of dark rides ridden"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userprofile", model_name="userprofile",
name="discord", name="discord",
field=models.CharField( field=models.CharField(blank=True, help_text="Discord username", max_length=100),
blank=True, help_text="Discord username", max_length=100
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userprofile", model_name="userprofile",
name="flat_ride_credits", name="flat_ride_credits",
field=models.IntegerField( field=models.IntegerField(default=0, help_text="Number of flat rides ridden"),
default=0, help_text="Number of flat rides ridden"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userprofile", model_name="userprofile",
@@ -884,9 +747,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="userprofile", model_name="userprofile",
name="pronouns", name="pronouns",
field=models.CharField( field=models.CharField(blank=True, help_text="User's preferred pronouns", max_length=50),
blank=True, help_text="User's preferred pronouns", max_length=50
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userprofile", model_name="userprofile",
@@ -906,9 +767,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="userprofile", model_name="userprofile",
name="water_ride_credits", name="water_ride_credits",
field=models.IntegerField( field=models.IntegerField(default=0, help_text="Number of water rides ridden"),
default=0, help_text="Number of water rides ridden"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userprofile", model_name="userprofile",
@@ -932,37 +791,27 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="userprofileevent", model_name="userprofileevent",
name="bio", name="bio",
field=models.TextField( field=models.TextField(blank=True, help_text="User biography", max_length=500),
blank=True, help_text="User biography", max_length=500
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userprofileevent", model_name="userprofileevent",
name="coaster_credits", name="coaster_credits",
field=models.IntegerField( field=models.IntegerField(default=0, help_text="Number of roller coasters ridden"),
default=0, help_text="Number of roller coasters ridden"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userprofileevent", model_name="userprofileevent",
name="dark_ride_credits", name="dark_ride_credits",
field=models.IntegerField( field=models.IntegerField(default=0, help_text="Number of dark rides ridden"),
default=0, help_text="Number of dark rides ridden"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userprofileevent", model_name="userprofileevent",
name="discord", name="discord",
field=models.CharField( field=models.CharField(blank=True, help_text="Discord username", max_length=100),
blank=True, help_text="Discord username", max_length=100
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userprofileevent", model_name="userprofileevent",
name="flat_ride_credits", name="flat_ride_credits",
field=models.IntegerField( field=models.IntegerField(default=0, help_text="Number of flat rides ridden"),
default=0, help_text="Number of flat rides ridden"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userprofileevent", model_name="userprofileevent",
@@ -972,9 +821,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="userprofileevent", model_name="userprofileevent",
name="pronouns", name="pronouns",
field=models.CharField( field=models.CharField(blank=True, help_text="User's preferred pronouns", max_length=50),
blank=True, help_text="User's preferred pronouns", max_length=50
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userprofileevent", model_name="userprofileevent",
@@ -996,9 +843,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name="userprofileevent", model_name="userprofileevent",
name="water_ride_credits", name="water_ride_credits",
field=models.IntegerField( field=models.IntegerField(default=0, help_text="Number of water rides ridden"),
default=0, help_text="Number of water rides ridden"
),
), ),
migrations.AlterField( migrations.AlterField(
model_name="userprofileevent", model_name="userprofileevent",

View File

@@ -1,6 +1,7 @@
""" """
Mixins for authentication views. Mixins for authentication views.
""" """
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from apps.core.utils.turnstile import get_client_ip, validate_turnstile_token from apps.core.utils.turnstile import get_client_ip, validate_turnstile_token
@@ -24,14 +25,14 @@ class TurnstileMixin:
token = None token = None
# Check POST data (form submissions) # Check POST data (form submissions)
if hasattr(request, 'POST'): if hasattr(request, "POST"):
token = request.POST.get("cf-turnstile-response") token = request.POST.get("cf-turnstile-response")
# Check JSON body (API requests) # Check JSON body (API requests)
if not token and hasattr(request, 'data'): if not token and hasattr(request, "data"):
data = getattr(request, 'data', {}) data = getattr(request, "data", {})
if hasattr(data, 'get'): if hasattr(data, "get"):
token = data.get('turnstile_token') or data.get('cf-turnstile-response') token = data.get("turnstile_token") or data.get("cf-turnstile-response")
# Get client IP # Get client IP
ip = get_client_ip(request) ip = get_client_ip(request)
@@ -39,6 +40,6 @@ class TurnstileMixin:
# Validate the token # Validate the token
result = validate_turnstile_token(token, ip) result = validate_turnstile_token(token, ip)
if not result.get('success'): if not result.get("success"):
error_msg = result.get('error', 'Captcha verification failed. Please try again.') error_msg = result.get("error", "Captcha verification failed. Please try again.")
raise ValidationError(error_msg) raise ValidationError(error_msg)

View File

@@ -41,10 +41,7 @@ class User(AbstractUser):
max_length=10, max_length=10,
unique=True, unique=True,
editable=False, editable=False,
help_text=( help_text=("Unique identifier for this user that remains constant even if the " "username changes"),
"Unique identifier for this user that remains constant even if the "
"username changes"
),
) )
role = RichChoiceField( role = RichChoiceField(
@@ -55,13 +52,9 @@ class User(AbstractUser):
db_index=True, db_index=True,
help_text="User role (user, moderator, admin)", help_text="User role (user, moderator, admin)",
) )
is_banned = models.BooleanField( is_banned = models.BooleanField(default=False, db_index=True, help_text="Whether this user is banned")
default=False, db_index=True, help_text="Whether this user is banned"
)
ban_reason = models.TextField(blank=True, help_text="Reason for ban") ban_reason = models.TextField(blank=True, help_text="Reason for ban")
ban_date = models.DateTimeField( ban_date = models.DateTimeField(null=True, blank=True, help_text="Date the user was banned")
null=True, blank=True, help_text="Date the user was banned"
)
pending_email = models.EmailField(blank=True, null=True) pending_email = models.EmailField(blank=True, null=True)
theme_preference = RichChoiceField( theme_preference = RichChoiceField(
choice_group="theme_preferences", choice_group="theme_preferences",
@@ -72,12 +65,8 @@ class User(AbstractUser):
) )
# Notification preferences # Notification preferences
email_notifications = models.BooleanField( email_notifications = models.BooleanField(default=True, help_text="Whether to send email notifications")
default=True, help_text="Whether to send email notifications" push_notifications = models.BooleanField(default=False, help_text="Whether to send push notifications")
)
push_notifications = models.BooleanField(
default=False, help_text="Whether to send push notifications"
)
# Privacy settings # Privacy settings
privacy_level = RichChoiceField( privacy_level = RichChoiceField(
@@ -87,39 +76,17 @@ class User(AbstractUser):
default="public", default="public",
help_text="Overall privacy level", help_text="Overall privacy level",
) )
show_email = models.BooleanField( show_email = models.BooleanField(default=False, help_text="Whether to show email on profile")
default=False, help_text="Whether to show email on profile" show_real_name = models.BooleanField(default=True, help_text="Whether to show real name on profile")
) show_join_date = models.BooleanField(default=True, help_text="Whether to show join date on profile")
show_real_name = models.BooleanField( show_statistics = models.BooleanField(default=True, help_text="Whether to show statistics on profile")
default=True, help_text="Whether to show real name on profile" show_reviews = models.BooleanField(default=True, help_text="Whether to show reviews on profile")
) show_photos = models.BooleanField(default=True, help_text="Whether to show photos on profile")
show_join_date = models.BooleanField( show_top_lists = models.BooleanField(default=True, help_text="Whether to show top lists on profile")
default=True, help_text="Whether to show join date on profile" allow_friend_requests = models.BooleanField(default=True, help_text="Whether to allow friend requests")
) allow_messages = models.BooleanField(default=True, help_text="Whether to allow direct messages")
show_statistics = models.BooleanField( allow_profile_comments = models.BooleanField(default=False, help_text="Whether to allow profile comments")
default=True, help_text="Whether to show statistics on profile" search_visibility = models.BooleanField(default=True, help_text="Whether profile appears in search results")
)
show_reviews = models.BooleanField(
default=True, help_text="Whether to show reviews on profile"
)
show_photos = models.BooleanField(
default=True, help_text="Whether to show photos on profile"
)
show_top_lists = models.BooleanField(
default=True, help_text="Whether to show top lists on profile"
)
allow_friend_requests = models.BooleanField(
default=True, help_text="Whether to allow friend requests"
)
allow_messages = models.BooleanField(
default=True, help_text="Whether to allow direct messages"
)
allow_profile_comments = models.BooleanField(
default=False, help_text="Whether to allow profile comments"
)
search_visibility = models.BooleanField(
default=True, help_text="Whether profile appears in search results"
)
activity_visibility = RichChoiceField( activity_visibility = RichChoiceField(
choice_group="privacy_levels", choice_group="privacy_levels",
domain="accounts", domain="accounts",
@@ -129,21 +96,11 @@ class User(AbstractUser):
) )
# Security settings # Security settings
two_factor_enabled = models.BooleanField( two_factor_enabled = models.BooleanField(default=False, help_text="Whether two-factor authentication is enabled")
default=False, help_text="Whether two-factor authentication is enabled" login_notifications = models.BooleanField(default=True, help_text="Whether to send login notifications")
) session_timeout = models.IntegerField(default=30, help_text="Session timeout in days")
login_notifications = models.BooleanField( login_history_retention = models.IntegerField(default=90, help_text="How long to retain login history (days)")
default=True, help_text="Whether to send login notifications" last_password_change = models.DateTimeField(auto_now_add=True, help_text="When the password was last changed")
)
session_timeout = models.IntegerField(
default=30, help_text="Session timeout in days"
)
login_history_retention = models.IntegerField(
default=90, help_text="How long to retain login history (days)"
)
last_password_change = models.DateTimeField(
auto_now_add=True, help_text="When the password was last changed"
)
# Display name - core user data for better performance # Display name - core user data for better performance
display_name = models.CharField( display_name = models.CharField(
@@ -179,13 +136,13 @@ class User(AbstractUser):
verbose_name = "User" verbose_name = "User"
verbose_name_plural = "Users" verbose_name_plural = "Users"
indexes = [ indexes = [
models.Index(fields=['is_banned', 'role'], name='accounts_user_banned_role_idx'), models.Index(fields=["is_banned", "role"], name="accounts_user_banned_role_idx"),
] ]
constraints = [ constraints = [
models.CheckConstraint( models.CheckConstraint(
name='user_ban_consistency', name="user_ban_consistency",
check=models.Q(is_banned=False) | models.Q(ban_date__isnull=False), check=models.Q(is_banned=False) | models.Q(ban_date__isnull=False),
violation_error_message='Banned users must have a ban_date set' violation_error_message="Banned users must have a ban_date set",
), ),
] ]
@@ -224,14 +181,10 @@ class UserProfile(models.Model):
related_name="user_profiles", related_name="user_profiles",
help_text="User's avatar image", help_text="User's avatar image",
) )
pronouns = models.CharField( pronouns = models.CharField(max_length=50, blank=True, help_text="User's preferred pronouns")
max_length=50, blank=True, help_text="User's preferred pronouns"
)
bio = models.TextField(max_length=500, blank=True, help_text="User biography") bio = models.TextField(max_length=500, blank=True, help_text="User biography")
location = models.CharField( location = models.CharField(max_length=100, blank=True, help_text="User's location (City, Country)")
max_length=100, blank=True, help_text="User's location (City, Country)"
)
unit_system = RichChoiceField( unit_system = RichChoiceField(
choice_group="unit_systems", choice_group="unit_systems",
domain="accounts", domain="accounts",
@@ -247,18 +200,10 @@ class UserProfile(models.Model):
discord = models.CharField(max_length=100, blank=True, help_text="Discord username") discord = models.CharField(max_length=100, blank=True, help_text="Discord username")
# Ride statistics # Ride statistics
coaster_credits = models.IntegerField( coaster_credits = models.IntegerField(default=0, help_text="Number of roller coasters ridden")
default=0, help_text="Number of roller coasters ridden" dark_ride_credits = models.IntegerField(default=0, help_text="Number of dark rides ridden")
) flat_ride_credits = models.IntegerField(default=0, help_text="Number of flat rides ridden")
dark_ride_credits = models.IntegerField( water_ride_credits = models.IntegerField(default=0, help_text="Number of water rides ridden")
default=0, help_text="Number of dark rides ridden"
)
flat_ride_credits = models.IntegerField(
default=0, help_text="Number of flat rides ridden"
)
water_ride_credits = models.IntegerField(
default=0, help_text="Number of water rides ridden"
)
def get_avatar_url(self): def get_avatar_url(self):
""" """
@@ -266,12 +211,12 @@ class UserProfile(models.Model):
""" """
if self.avatar and self.avatar.is_uploaded: if self.avatar and self.avatar.is_uploaded:
# Try to get avatar variant first, fallback to public # Try to get avatar variant first, fallback to public
avatar_url = self.avatar.get_url('avatar') avatar_url = self.avatar.get_url("avatar")
if avatar_url: if avatar_url:
return avatar_url return avatar_url
# Fallback to public variant # Fallback to public variant
public_url = self.avatar.get_url('public') public_url = self.avatar.get_url("public")
if public_url: if public_url:
return public_url return public_url
@@ -298,10 +243,10 @@ class UserProfile(models.Model):
variants = {} variants = {}
# Try to get specific variants # Try to get specific variants
thumbnail_url = self.avatar.get_url('thumbnail') thumbnail_url = self.avatar.get_url("thumbnail")
avatar_url = self.avatar.get_url('avatar') avatar_url = self.avatar.get_url("avatar")
large_url = self.avatar.get_url('large') large_url = self.avatar.get_url("large")
public_url = self.avatar.get_url('public') public_url = self.avatar.get_url("public")
# Use specific variants if available, otherwise fallback to public or first available # Use specific variants if available, otherwise fallback to public or first available
fallback_url = public_url fallback_url = public_url
@@ -354,18 +299,10 @@ class EmailVerification(models.Model):
on_delete=models.CASCADE, on_delete=models.CASCADE,
help_text="User this verification belongs to", help_text="User this verification belongs to",
) )
token = models.CharField( token = models.CharField(max_length=64, unique=True, help_text="Verification token")
max_length=64, unique=True, help_text="Verification token" created_at = models.DateTimeField(auto_now_add=True, help_text="When this verification was created")
) updated_at = models.DateTimeField(auto_now=True, help_text="When this verification was last updated")
created_at = models.DateTimeField( last_sent = models.DateTimeField(auto_now_add=True, help_text="When the verification email was last sent")
auto_now_add=True, help_text="When this verification was created"
)
updated_at = models.DateTimeField(
auto_now=True, help_text="When this verification was last updated"
)
last_sent = models.DateTimeField(
auto_now_add=True, help_text="When the verification email was last sent"
)
def __str__(self): def __str__(self):
return f"Email verification for {self.user.username}" return f"Email verification for {self.user.username}"
@@ -383,9 +320,7 @@ class PasswordReset(models.Model):
help_text="User requesting password reset", help_text="User requesting password reset",
) )
token = models.CharField(max_length=64, help_text="Reset token") token = models.CharField(max_length=64, help_text="Reset token")
created_at = models.DateTimeField( created_at = models.DateTimeField(auto_now_add=True, help_text="When this reset was requested")
auto_now_add=True, help_text="When this reset was requested"
)
expires_at = models.DateTimeField(help_text="When this reset token expires") expires_at = models.DateTimeField(help_text="When this reset token expires")
used = models.BooleanField(default=False, help_text="Whether this token has been used") used = models.BooleanField(default=False, help_text="Whether this token has been used")
@@ -397,8 +332,6 @@ class PasswordReset(models.Model):
verbose_name_plural = "Password Resets" verbose_name_plural = "Password Resets"
@pghistory.track() @pghistory.track()
class UserDeletionRequest(models.Model): class UserDeletionRequest(models.Model):
""" """
@@ -409,9 +342,7 @@ class UserDeletionRequest(models.Model):
provide the correct code. provide the correct code.
""" """
user = models.OneToOneField( user = models.OneToOneField(User, on_delete=models.CASCADE, related_name="deletion_request")
User, on_delete=models.CASCADE, related_name="deletion_request"
)
verification_code = models.CharField( verification_code = models.CharField(
max_length=32, max_length=32,
@@ -422,21 +353,13 @@ class UserDeletionRequest(models.Model):
created_at = models.DateTimeField(auto_now_add=True) created_at = models.DateTimeField(auto_now_add=True)
expires_at = models.DateTimeField(help_text="When this deletion request expires") expires_at = models.DateTimeField(help_text="When this deletion request expires")
email_sent_at = models.DateTimeField( email_sent_at = models.DateTimeField(null=True, blank=True, help_text="When the verification email was sent")
null=True, blank=True, help_text="When the verification email was sent"
)
attempts = models.PositiveIntegerField( attempts = models.PositiveIntegerField(default=0, help_text="Number of verification attempts made")
default=0, help_text="Number of verification attempts made"
)
max_attempts = models.PositiveIntegerField( max_attempts = models.PositiveIntegerField(default=5, help_text="Maximum number of verification attempts allowed")
default=5, help_text="Maximum number of verification attempts allowed"
)
is_used = models.BooleanField( is_used = models.BooleanField(default=False, help_text="Whether this deletion request has been used")
default=False, help_text="Whether this deletion request has been used"
)
class Meta: class Meta:
verbose_name = "User Deletion Request" verbose_name = "User Deletion Request"
@@ -466,9 +389,7 @@ class UserDeletionRequest(models.Model):
"""Generate a unique 8-character verification code.""" """Generate a unique 8-character verification code."""
while True: while True:
# Generate a random 8-character alphanumeric code # Generate a random 8-character alphanumeric code
code = "".join( code = "".join(secrets.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") for _ in range(8))
secrets.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") for _ in range(8)
)
# Ensure it's unique # Ensure it's unique
if not UserDeletionRequest.objects.filter(verification_code=code).exists(): if not UserDeletionRequest.objects.filter(verification_code=code).exists():
@@ -480,11 +401,7 @@ class UserDeletionRequest(models.Model):
def is_valid(self): def is_valid(self):
"""Check if this deletion request is still valid.""" """Check if this deletion request is still valid."""
return ( return not self.is_used and not self.is_expired() and self.attempts < self.max_attempts
not self.is_used
and not self.is_expired()
and self.attempts < self.max_attempts
)
def increment_attempts(self): def increment_attempts(self):
"""Increment the number of verification attempts.""" """Increment the number of verification attempts."""
@@ -499,9 +416,7 @@ class UserDeletionRequest(models.Model):
@classmethod @classmethod
def cleanup_expired(cls): def cleanup_expired(cls):
"""Remove expired deletion requests.""" """Remove expired deletion requests."""
expired_requests = cls.objects.filter( expired_requests = cls.objects.filter(expires_at__lt=timezone.now(), is_used=False)
expires_at__lt=timezone.now(), is_used=False
)
count = expired_requests.count() count = expired_requests.count()
expired_requests.delete() expired_requests.delete()
return count return count
@@ -541,9 +456,7 @@ class UserNotification(TrackedModel):
blank=True, blank=True,
help_text="Type of related object", help_text="Type of related object",
) )
object_id = models.PositiveIntegerField( object_id = models.PositiveIntegerField(null=True, blank=True, help_text="ID of related object")
null=True, blank=True, help_text="ID of related object"
)
related_object = GenericForeignKey("content_type", "object_id") related_object = GenericForeignKey("content_type", "object_id")
# Metadata # Metadata
@@ -555,24 +468,14 @@ class UserNotification(TrackedModel):
) )
# Status tracking # Status tracking
is_read = models.BooleanField( is_read = models.BooleanField(default=False, help_text="Whether this notification has been read")
default=False, help_text="Whether this notification has been read" read_at = models.DateTimeField(null=True, blank=True, help_text="When this notification was read")
)
read_at = models.DateTimeField(
null=True, blank=True, help_text="When this notification was read"
)
# Delivery tracking # Delivery tracking
email_sent = models.BooleanField(default=False, help_text="Whether email was sent") email_sent = models.BooleanField(default=False, help_text="Whether email was sent")
email_sent_at = models.DateTimeField( email_sent_at = models.DateTimeField(null=True, blank=True, help_text="When email was sent")
null=True, blank=True, help_text="When email was sent" push_sent = models.BooleanField(default=False, help_text="Whether push notification was sent")
) push_sent_at = models.DateTimeField(null=True, blank=True, help_text="When push notification was sent")
push_sent = models.BooleanField(
default=False, help_text="Whether push notification was sent"
)
push_sent_at = models.DateTimeField(
null=True, blank=True, help_text="When push notification was sent"
)
# Additional data (JSON field for flexibility) # Additional data (JSON field for flexibility)
extra_data = models.JSONField(default=dict, blank=True) extra_data = models.JSONField(default=dict, blank=True)
@@ -619,9 +522,7 @@ class UserNotification(TrackedModel):
@classmethod @classmethod
def mark_all_read_for_user(cls, user): def mark_all_read_for_user(cls, user):
"""Mark all notifications as read for a specific user.""" """Mark all notifications as read for a specific user."""
return cls.objects.filter(user=user, is_read=False).update( return cls.objects.filter(user=user, is_read=False).update(is_read=True, read_at=timezone.now())
is_read=True, read_at=timezone.now()
)
@pghistory.track() @pghistory.track()

View File

@@ -27,16 +27,10 @@ def user_profile_optimized(*, user_id: int) -> Any:
User.DoesNotExist: If user doesn't exist User.DoesNotExist: If user doesn't exist
""" """
return ( return (
User.objects.prefetch_related( User.objects.prefetch_related("park_reviews", "ride_reviews", "socialaccount_set")
"park_reviews", "ride_reviews", "socialaccount_set"
)
.annotate( .annotate(
park_review_count=Count( park_review_count=Count("park_reviews", filter=Q(park_reviews__is_published=True)),
"park_reviews", filter=Q(park_reviews__is_published=True) ride_review_count=Count("ride_reviews", filter=Q(ride_reviews__is_published=True)),
),
ride_review_count=Count(
"ride_reviews", filter=Q(ride_reviews__is_published=True)
),
total_review_count=F("park_review_count") + F("ride_review_count"), total_review_count=F("park_review_count") + F("ride_review_count"),
) )
.get(id=user_id) .get(id=user_id)
@@ -53,12 +47,8 @@ def active_users_with_stats() -> QuerySet:
return ( return (
User.objects.filter(is_active=True) User.objects.filter(is_active=True)
.annotate( .annotate(
park_review_count=Count( park_review_count=Count("park_reviews", filter=Q(park_reviews__is_published=True)),
"park_reviews", filter=Q(park_reviews__is_published=True) ride_review_count=Count("ride_reviews", filter=Q(ride_reviews__is_published=True)),
),
ride_review_count=Count(
"ride_reviews", filter=Q(ride_reviews__is_published=True)
),
total_review_count=F("park_review_count") + F("ride_review_count"), total_review_count=F("park_review_count") + F("ride_review_count"),
) )
.order_by("-total_review_count") .order_by("-total_review_count")
@@ -112,12 +102,8 @@ def top_reviewers(*, limit: int = 10) -> QuerySet:
return ( return (
User.objects.filter(is_active=True) User.objects.filter(is_active=True)
.annotate( .annotate(
park_review_count=Count( park_review_count=Count("park_reviews", filter=Q(park_reviews__is_published=True)),
"park_reviews", filter=Q(park_reviews__is_published=True) ride_review_count=Count("ride_reviews", filter=Q(ride_reviews__is_published=True)),
),
ride_review_count=Count(
"ride_reviews", filter=Q(ride_reviews__is_published=True)
),
total_review_count=F("park_review_count") + F("ride_review_count"), total_review_count=F("park_review_count") + F("ride_review_count"),
) )
.filter(total_review_count__gt=0) .filter(total_review_count__gt=0)
@@ -159,9 +145,9 @@ def users_by_registration_date(*, start_date, end_date) -> QuerySet:
Returns: Returns:
QuerySet of users registered in the date range QuerySet of users registered in the date range
""" """
return User.objects.filter( return User.objects.filter(date_joined__date__gte=start_date, date_joined__date__lte=end_date).order_by(
date_joined__date__gte=start_date, date_joined__date__lte=end_date "-date_joined"
).order_by("-date_joined") )
def user_search_autocomplete(*, query: str, limit: int = 10) -> QuerySet: def user_search_autocomplete(*, query: str, limit: int = 10) -> QuerySet:
@@ -176,8 +162,7 @@ def user_search_autocomplete(*, query: str, limit: int = 10) -> QuerySet:
QuerySet of matching users for autocomplete QuerySet of matching users for autocomplete
""" """
return User.objects.filter( return User.objects.filter(
Q(username__icontains=query) Q(username__icontains=query) | Q(display_name__icontains=query),
| Q(display_name__icontains=query),
is_active=True, is_active=True,
).order_by("username")[:limit] ).order_by("username")[:limit]
@@ -210,11 +195,7 @@ def user_statistics_summary() -> dict[str, Any]:
# Users with reviews # Users with reviews
users_with_reviews = ( users_with_reviews = (
User.objects.filter( User.objects.filter(Q(park_reviews__isnull=False) | Q(ride_reviews__isnull=False)).distinct().count()
Q(park_reviews__isnull=False) | Q(ride_reviews__isnull=False)
)
.distinct()
.count()
) )
# Recent registrations (last 30 days) # Recent registrations (last 30 days)
@@ -228,9 +209,7 @@ def user_statistics_summary() -> dict[str, Any]:
"staff_users": staff_users, "staff_users": staff_users,
"users_with_reviews": users_with_reviews, "users_with_reviews": users_with_reviews,
"recent_registrations": recent_registrations, "recent_registrations": recent_registrations,
"review_participation_rate": ( "review_participation_rate": ((users_with_reviews / total_users * 100) if total_users > 0 else 0),
(users_with_reviews / total_users * 100) if total_users > 0 else 0
),
} }
@@ -241,11 +220,7 @@ def users_needing_email_verification() -> QuerySet:
Returns: Returns:
QuerySet of users with unverified emails QuerySet of users with unverified emails
""" """
return ( return User.objects.filter(is_active=True, emailaddress__verified=False).distinct().order_by("date_joined")
User.objects.filter(is_active=True, emailaddress__verified=False)
.distinct()
.order_by("date_joined")
)
def users_by_review_activity(*, min_reviews: int = 1) -> QuerySet: def users_by_review_activity(*, min_reviews: int = 1) -> QuerySet:
@@ -260,12 +235,8 @@ def users_by_review_activity(*, min_reviews: int = 1) -> QuerySet:
""" """
return ( return (
User.objects.annotate( User.objects.annotate(
park_review_count=Count( park_review_count=Count("park_reviews", filter=Q(park_reviews__is_published=True)),
"park_reviews", filter=Q(park_reviews__is_published=True) ride_review_count=Count("ride_reviews", filter=Q(ride_reviews__is_published=True)),
),
ride_review_count=Count(
"ride_reviews", filter=Q(ride_reviews__is_published=True)
),
total_review_count=F("park_review_count") + F("ride_review_count"), total_review_count=F("park_review_count") + F("ride_review_count"),
) )
.filter(total_review_count__gte=min_reviews) .filter(total_review_count__gte=min_reviews)

View File

@@ -62,12 +62,8 @@ class LoginSerializer(serializers.Serializer):
Serializer for user login Serializer for user login
""" """
username = serializers.CharField( username = serializers.CharField(max_length=254, help_text="Username or email address")
max_length=254, help_text="Username or email address" password = serializers.CharField(max_length=128, style={"input_type": "password"}, trim_whitespace=False)
)
password = serializers.CharField(
max_length=128, style={"input_type": "password"}, trim_whitespace=False
)
def validate(self, attrs): def validate(self, attrs):
username = attrs.get("username") username = attrs.get("username")
@@ -89,9 +85,7 @@ class SignupSerializer(serializers.ModelSerializer):
validators=[validate_password], validators=[validate_password],
style={"input_type": "password"}, style={"input_type": "password"},
) )
password_confirm = serializers.CharField( password_confirm = serializers.CharField(write_only=True, style={"input_type": "password"})
write_only=True, style={"input_type": "password"}
)
class Meta: class Meta:
model = User model = User
@@ -118,9 +112,7 @@ class SignupSerializer(serializers.ModelSerializer):
def validate_username(self, value): def validate_username(self, value):
"""Validate username is unique""" """Validate username is unique"""
if UserModel.objects.filter(username=value).exists(): if UserModel.objects.filter(username=value).exists():
raise serializers.ValidationError( raise serializers.ValidationError("A user with this username already exists.")
"A user with this username already exists."
)
return value return value
def validate(self, attrs): def validate(self, attrs):
@@ -129,9 +121,7 @@ class SignupSerializer(serializers.ModelSerializer):
password_confirm = attrs.get("password_confirm") password_confirm = attrs.get("password_confirm")
if password != password_confirm: if password != password_confirm:
raise serializers.ValidationError( raise serializers.ValidationError({"password_confirm": "Passwords do not match."})
{"password_confirm": "Passwords do not match."}
)
return attrs return attrs
@@ -194,9 +184,7 @@ class PasswordResetSerializer(serializers.Serializer):
"site_name": site.name, "site_name": site.name,
} }
email_html = render_to_string( email_html = render_to_string("accounts/email/password_reset.html", context)
"accounts/email/password_reset.html", context
)
# Narrow and validate email type for the static checker # Narrow and validate email type for the static checker
email = getattr(self.user, "email", None) email = getattr(self.user, "email", None)
@@ -218,15 +206,11 @@ class PasswordChangeSerializer(serializers.Serializer):
Serializer for password change Serializer for password change
""" """
old_password = serializers.CharField( old_password = serializers.CharField(max_length=128, style={"input_type": "password"})
max_length=128, style={"input_type": "password"}
)
new_password = serializers.CharField( new_password = serializers.CharField(
max_length=128, validators=[validate_password], style={"input_type": "password"} max_length=128, validators=[validate_password], style={"input_type": "password"}
) )
new_password_confirm = serializers.CharField( new_password_confirm = serializers.CharField(max_length=128, style={"input_type": "password"})
max_length=128, style={"input_type": "password"}
)
def validate_old_password(self, value): def validate_old_password(self, value):
"""Validate old password is correct""" """Validate old password is correct"""
@@ -241,9 +225,7 @@ class PasswordChangeSerializer(serializers.Serializer):
new_password_confirm = attrs.get("new_password_confirm") new_password_confirm = attrs.get("new_password_confirm")
if new_password != new_password_confirm: if new_password != new_password_confirm:
raise serializers.ValidationError( raise serializers.ValidationError({"new_password_confirm": "New passwords do not match."})
{"new_password_confirm": "New passwords do not match."}
)
return attrs return attrs

View File

@@ -81,21 +81,15 @@ class AccountService:
""" """
# Verify old password # Verify old password
if not user.check_password(old_password): if not user.check_password(old_password):
logger.warning( logger.warning(f"Password change failed: incorrect current password for user {user.id}")
f"Password change failed: incorrect current password for user {user.id}" return {"success": False, "message": "Current password is incorrect", "redirect_url": None}
)
return {
'success': False,
'message': "Current password is incorrect",
'redirect_url': None
}
# Validate new password # Validate new password
if not AccountService.validate_password(new_password): if not AccountService.validate_password(new_password):
return { return {
'success': False, "success": False,
'message': "Password must be at least 8 characters and contain uppercase, lowercase, and numbers", "message": "Password must be at least 8 characters and contain uppercase, lowercase, and numbers",
'redirect_url': None "redirect_url": None,
} }
# Update password # Update password
@@ -111,9 +105,9 @@ class AccountService:
logger.info(f"Password changed successfully for user {user.id}") logger.info(f"Password changed successfully for user {user.id}")
return { return {
'success': True, "success": True,
'message': "Password changed successfully. Please check your email for confirmation.", "message": "Password changed successfully. Please check your email for confirmation.",
'redirect_url': None "redirect_url": None,
} }
@staticmethod @staticmethod
@@ -125,9 +119,7 @@ class AccountService:
"site_name": site.name, "site_name": site.name,
} }
email_html = render_to_string( email_html = render_to_string("accounts/email/password_change_confirmation.html", context)
"accounts/email/password_change_confirmation.html", context
)
try: try:
EmailService.send_email( EmailService.send_email(
@@ -166,26 +158,17 @@ class AccountService:
} }
""" """
if not new_email: if not new_email:
return { return {"success": False, "message": "New email is required"}
'success': False,
'message': "New email is required"
}
# Check if email is already in use # Check if email is already in use
if User.objects.filter(email=new_email).exclude(id=user.id).exists(): if User.objects.filter(email=new_email).exclude(id=user.id).exists():
return { return {"success": False, "message": "This email address is already in use"}
'success': False,
'message': "This email address is already in use"
}
# Generate verification token # Generate verification token
token = get_random_string(64) token = get_random_string(64)
# Create or update email verification record # Create or update email verification record
EmailVerification.objects.update_or_create( EmailVerification.objects.update_or_create(user=user, defaults={"token": token})
user=user,
defaults={"token": token}
)
# Store pending email # Store pending email
user.pending_email = new_email user.pending_email = new_email
@@ -196,18 +179,10 @@ class AccountService:
logger.info(f"Email change initiated for user {user.id} to {new_email}") logger.info(f"Email change initiated for user {user.id} to {new_email}")
return { return {"success": True, "message": "Verification email sent to your new email address"}
'success': True,
'message': "Verification email sent to your new email address"
}
@staticmethod @staticmethod
def _send_email_verification( def _send_email_verification(request: HttpRequest, user: User, new_email: str, token: str) -> None:
request: HttpRequest,
user: User,
new_email: str,
token: str
) -> None:
"""Send email verification for email change.""" """Send email verification for email change."""
from django.urls import reverse from django.urls import reverse
@@ -245,22 +220,14 @@ class AccountService:
Dictionary with success status and message Dictionary with success status and message
""" """
try: try:
verification = EmailVerification.objects.select_related("user").get( verification = EmailVerification.objects.select_related("user").get(token=token)
token=token
)
except EmailVerification.DoesNotExist: except EmailVerification.DoesNotExist:
return { return {"success": False, "message": "Invalid or expired verification token"}
'success': False,
'message': "Invalid or expired verification token"
}
user = verification.user user = verification.user
if not user.pending_email: if not user.pending_email:
return { return {"success": False, "message": "No pending email change found"}
'success': False,
'message': "No pending email change found"
}
# Update email # Update email
old_email = user.email old_email = user.email
@@ -273,10 +240,7 @@ class AccountService:
logger.info(f"Email changed for user {user.id} from {old_email} to {user.email}") logger.info(f"Email changed for user {user.id} from {old_email} to {user.email}")
return { return {"success": True, "message": "Email address updated successfully"}
'success': True,
'message': "Email address updated successfully"
}
class UserDeletionService: class UserDeletionService:
@@ -337,39 +301,17 @@ class UserDeletionService:
# Count submissions before transfer # Count submissions before transfer
submission_counts = { submission_counts = {
"park_reviews": getattr( "park_reviews": getattr(user, "park_reviews", user.__class__.objects.none()).count(),
user, "park_reviews", user.__class__.objects.none() "ride_reviews": getattr(user, "ride_reviews", user.__class__.objects.none()).count(),
).count(), "uploaded_park_photos": getattr(user, "uploaded_park_photos", user.__class__.objects.none()).count(),
"ride_reviews": getattr( "uploaded_ride_photos": getattr(user, "uploaded_ride_photos", user.__class__.objects.none()).count(),
user, "ride_reviews", user.__class__.objects.none() "top_lists": getattr(user, "top_lists", user.__class__.objects.none()).count(),
).count(), "edit_submissions": getattr(user, "edit_submissions", user.__class__.objects.none()).count(),
"uploaded_park_photos": getattr( "photo_submissions": getattr(user, "photo_submissions", user.__class__.objects.none()).count(),
user, "uploaded_park_photos", user.__class__.objects.none() "moderated_park_reviews": getattr(user, "moderated_park_reviews", user.__class__.objects.none()).count(),
).count(), "moderated_ride_reviews": getattr(user, "moderated_ride_reviews", user.__class__.objects.none()).count(),
"uploaded_ride_photos": getattr( "handled_submissions": getattr(user, "handled_submissions", user.__class__.objects.none()).count(),
user, "uploaded_ride_photos", user.__class__.objects.none() "handled_photos": getattr(user, "handled_photos", user.__class__.objects.none()).count(),
).count(),
"top_lists": getattr(
user, "top_lists", user.__class__.objects.none()
).count(),
"edit_submissions": getattr(
user, "edit_submissions", user.__class__.objects.none()
).count(),
"photo_submissions": getattr(
user, "photo_submissions", user.__class__.objects.none()
).count(),
"moderated_park_reviews": getattr(
user, "moderated_park_reviews", user.__class__.objects.none()
).count(),
"moderated_ride_reviews": getattr(
user, "moderated_ride_reviews", user.__class__.objects.none()
).count(),
"handled_submissions": getattr(
user, "handled_submissions", user.__class__.objects.none()
).count(),
"handled_photos": getattr(
user, "handled_photos", user.__class__.objects.none()
).count(),
} }
# Transfer all submissions to deleted user # Transfer all submissions to deleted user
@@ -440,11 +382,17 @@ class UserDeletionService:
return False, "Cannot delete the system deleted user placeholder" return False, "Cannot delete the system deleted user placeholder"
if user.is_superuser: if user.is_superuser:
return False, "Superuser accounts cannot be deleted for security reasons. Please contact system administrator or remove superuser privileges first." return (
False,
"Superuser accounts cannot be deleted for security reasons. Please contact system administrator or remove superuser privileges first.",
)
# Check if user has critical admin role # Check if user has critical admin role
if user.role == User.Roles.ADMIN and user.is_staff: if user.role == User.Roles.ADMIN and user.is_staff:
return False, "Admin accounts with staff privileges cannot be deleted. Please remove admin privileges first or contact system administrator." return (
False,
"Admin accounts with staff privileges cannot be deleted. Please remove admin privileges first or contact system administrator.",
)
# Add any other business rules here # Add any other business rules here
@@ -492,9 +440,7 @@ class UserDeletionService:
site = Site.objects.get_current() site = Site.objects.get_current()
except Site.DoesNotExist: except Site.DoesNotExist:
# Fallback to default site # Fallback to default site
site = Site.objects.get_or_create( site = Site.objects.get_or_create(id=1, defaults={"domain": "localhost:8000", "name": "localhost:8000"})[0]
id=1, defaults={"domain": "localhost:8000", "name": "localhost:8000"}
)[0]
# Prepare email context # Prepare email context
context = { context = {
@@ -502,9 +448,7 @@ class UserDeletionService:
"verification_code": deletion_request.verification_code, "verification_code": deletion_request.verification_code,
"expires_at": deletion_request.expires_at, "expires_at": deletion_request.expires_at,
"site_name": getattr(settings, "SITE_NAME", "ThrillWiki"), "site_name": getattr(settings, "SITE_NAME", "ThrillWiki"),
"frontend_domain": getattr( "frontend_domain": getattr(settings, "FRONTEND_DOMAIN", "http://localhost:3000"),
settings, "FRONTEND_DOMAIN", "http://localhost:3000"
),
} }
# Render email content # Render email content
@@ -564,11 +508,9 @@ The ThrillWiki Team
ValueError: If verification fails ValueError: If verification fails
""" """
try: try:
deletion_request = UserDeletionRequest.objects.get( deletion_request = UserDeletionRequest.objects.get(verification_code=verification_code)
verification_code=verification_code
)
except UserDeletionRequest.DoesNotExist: except UserDeletionRequest.DoesNotExist:
raise ValueError("Invalid verification code") raise ValueError("Invalid verification code") from None
# Check if request is still valid # Check if request is still valid
if not deletion_request.is_valid(): if not deletion_request.is_valid():

View File

@@ -8,4 +8,4 @@ including social provider management, user authentication, and profile services.
from .social_provider_service import SocialProviderService from .social_provider_service import SocialProviderService
from .user_deletion_service import UserDeletionService from .user_deletion_service import UserDeletionService
__all__ = ['SocialProviderService', 'UserDeletionService'] __all__ = ["SocialProviderService", "UserDeletionService"]

View File

@@ -139,7 +139,9 @@ class NotificationService:
UserNotification: The created notification UserNotification: The created notification
""" """
title = f"Your {submission_type} needs attention" title = f"Your {submission_type} needs attention"
message = f"Your {submission_type} submission has been reviewed and needs some changes before it can be approved." message = (
f"Your {submission_type} submission has been reviewed and needs some changes before it can be approved."
)
message += f"\n\nReason: {rejection_reason}" message += f"\n\nReason: {rejection_reason}"
if additional_message: if additional_message:
@@ -216,9 +218,7 @@ class NotificationService:
preferences = NotificationPreference.objects.create(user=user) preferences = NotificationPreference.objects.create(user=user)
# Send email notification if enabled # Send email notification if enabled
if preferences.should_send_notification( if preferences.should_send_notification(notification.notification_type, "email"):
notification.notification_type, "email"
):
NotificationService._send_email_notification(notification) NotificationService._send_email_notification(notification)
# Toast notifications are always created (the notification object itself) # Toast notifications are always created (the notification object itself)
@@ -261,14 +261,10 @@ class NotificationService:
notification.email_sent_at = timezone.now() notification.email_sent_at = timezone.now()
notification.save(update_fields=["email_sent", "email_sent_at"]) notification.save(update_fields=["email_sent", "email_sent_at"])
logger.info( logger.info(f"Email notification sent to {user.email} for notification {notification.id}")
f"Email notification sent to {user.email} for notification {notification.id}"
)
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to send email notification {notification.id}: {str(e)}")
f"Failed to send email notification {notification.id}: {str(e)}"
)
@staticmethod @staticmethod
def get_user_notifications( def get_user_notifications(
@@ -298,9 +294,7 @@ class NotificationService:
queryset = queryset.filter(notification_type__in=notification_types) queryset = queryset.filter(notification_type__in=notification_types)
# Exclude expired notifications # Exclude expired notifications
queryset = queryset.filter( queryset = queryset.filter(models.Q(expires_at__isnull=True) | models.Q(expires_at__gt=timezone.now()))
models.Q(expires_at__isnull=True) | models.Q(expires_at__gt=timezone.now())
)
if limit: if limit:
queryset = queryset[:limit] queryset = queryset[:limit]
@@ -308,9 +302,7 @@ class NotificationService:
return list(queryset) return list(queryset)
@staticmethod @staticmethod
def mark_notifications_read( def mark_notifications_read(user: User, notification_ids: list[int] | None = None) -> int:
user: User, notification_ids: list[int] | None = None
) -> int:
""" """
Mark notifications as read for a user. Mark notifications as read for a user.
@@ -341,9 +333,7 @@ class NotificationService:
""" """
cutoff_date = timezone.now() - timedelta(days=days) cutoff_date = timezone.now() - timedelta(days=days)
old_notifications = UserNotification.objects.filter( old_notifications = UserNotification.objects.filter(is_read=True, read_at__lt=cutoff_date)
is_read=True, read_at__lt=cutoff_date
)
count = old_notifications.count() count = old_notifications.count()
old_notifications.delete() old_notifications.delete()

View File

@@ -40,23 +40,20 @@ class SocialProviderService:
""" """
try: try:
# Count remaining social accounts after disconnection # Count remaining social accounts after disconnection
remaining_social_accounts = user.socialaccount_set.exclude( remaining_social_accounts = user.socialaccount_set.exclude(provider=provider).count()
provider=provider
).count()
# Check if user has email/password auth # Check if user has email/password auth
has_password_auth = ( has_password_auth = user.email and user.has_usable_password() and bool(user.password) # Not empty/unusable
user.email and
user.has_usable_password() and
bool(user.password) # Not empty/unusable
)
# Allow disconnection only if alternative auth exists # Allow disconnection only if alternative auth exists
can_disconnect = remaining_social_accounts > 0 or has_password_auth can_disconnect = remaining_social_accounts > 0 or has_password_auth
if not can_disconnect: if not can_disconnect:
if remaining_social_accounts == 0 and not has_password_auth: if remaining_social_accounts == 0 and not has_password_auth:
return False, "Cannot disconnect your only authentication method. Please set up a password or connect another social provider first." return (
False,
"Cannot disconnect your only authentication method. Please set up a password or connect another social provider first.",
)
elif not has_password_auth: elif not has_password_auth:
return False, "Please set up email/password authentication before disconnecting this provider." return False, "Please set up email/password authentication before disconnecting this provider."
else: else:
@@ -65,8 +62,7 @@ class SocialProviderService:
return True, "Provider can be safely disconnected." return True, "Provider can be safely disconnected."
except Exception as e: except Exception as e:
logger.error( logger.error(f"Error checking disconnect permission for user {user.id}, provider {provider}: {e}")
f"Error checking disconnect permission for user {user.id}, provider {provider}: {e}")
return False, "Unable to verify disconnection safety. Please try again." return False, "Unable to verify disconnection safety. Please try again."
@staticmethod @staticmethod
@@ -84,18 +80,16 @@ class SocialProviderService:
connected_providers = [] connected_providers = []
for social_account in user.socialaccount_set.all(): for social_account in user.socialaccount_set.all():
can_disconnect, reason = SocialProviderService.can_disconnect_provider( can_disconnect, reason = SocialProviderService.can_disconnect_provider(user, social_account.provider)
user, social_account.provider
)
provider_info = { provider_info = {
'provider': social_account.provider, "provider": social_account.provider,
'provider_name': social_account.get_provider().name, "provider_name": social_account.get_provider().name,
'uid': social_account.uid, "uid": social_account.uid,
'date_joined': social_account.date_joined, "date_joined": social_account.date_joined,
'can_disconnect': can_disconnect, "can_disconnect": can_disconnect,
'disconnect_reason': reason if not can_disconnect else None, "disconnect_reason": reason if not can_disconnect else None,
'extra_data': social_account.extra_data "extra_data": social_account.extra_data,
} }
connected_providers.append(provider_info) connected_providers.append(provider_info)
@@ -122,28 +116,25 @@ class SocialProviderService:
available_providers = [] available_providers = []
# Get all social apps configured for this site # Get all social apps configured for this site
social_apps = SocialApp.objects.filter(sites=site).order_by('provider') social_apps = SocialApp.objects.filter(sites=site).order_by("provider")
for social_app in social_apps: for social_app in social_apps:
try: try:
provider = registry.by_id(social_app.provider) provider = registry.by_id(social_app.provider)
provider_info = { provider_info = {
'id': social_app.provider, "id": social_app.provider,
'name': provider.name, "name": provider.name,
'auth_url': request.build_absolute_uri( "auth_url": request.build_absolute_uri(f"/accounts/{social_app.provider}/login/"),
f'/accounts/{social_app.provider}/login/' "connect_url": request.build_absolute_uri(
f"/api/v1/auth/social/connect/{social_app.provider}/"
), ),
'connect_url': request.build_absolute_uri(
f'/api/v1/auth/social/connect/{social_app.provider}/'
)
} }
available_providers.append(provider_info) available_providers.append(provider_info)
except Exception as e: except Exception as e:
logger.warning( logger.warning(f"Error processing provider {social_app.provider}: {e}")
f"Error processing provider {social_app.provider}: {e}")
continue continue
return available_providers return available_providers
@@ -166,8 +157,7 @@ class SocialProviderService:
""" """
try: try:
# First check if disconnection is allowed # First check if disconnection is allowed
can_disconnect, reason = SocialProviderService.can_disconnect_provider( can_disconnect, reason = SocialProviderService.can_disconnect_provider(user, provider)
user, provider)
if not can_disconnect: if not can_disconnect:
return False, reason return False, reason
@@ -182,8 +172,7 @@ class SocialProviderService:
deleted_count = social_accounts.count() deleted_count = social_accounts.count()
social_accounts.delete() social_accounts.delete()
logger.info( logger.info(f"User {user.id} disconnected {deleted_count} {provider} account(s)")
f"User {user.id} disconnected {deleted_count} {provider} account(s)")
return True, f"{provider.title()} account disconnected successfully." return True, f"{provider.title()} account disconnected successfully."
@@ -205,31 +194,24 @@ class SocialProviderService:
try: try:
connected_providers = SocialProviderService.get_connected_providers(user) connected_providers = SocialProviderService.get_connected_providers(user)
has_password_auth = ( has_password_auth = user.email and user.has_usable_password() and bool(user.password)
user.email and
user.has_usable_password() and
bool(user.password)
)
auth_methods_count = len(connected_providers) + \ auth_methods_count = len(connected_providers) + (1 if has_password_auth else 0)
(1 if has_password_auth else 0)
return { return {
'user_id': user.id, "user_id": user.id,
'username': user.username, "username": user.username,
'email': user.email, "email": user.email,
'has_password_auth': has_password_auth, "has_password_auth": has_password_auth,
'connected_providers': connected_providers, "connected_providers": connected_providers,
'total_auth_methods': auth_methods_count, "total_auth_methods": auth_methods_count,
'can_disconnect_any': auth_methods_count > 1, "can_disconnect_any": auth_methods_count > 1,
'requires_password_setup': not has_password_auth and len(connected_providers) == 1 "requires_password_setup": not has_password_auth and len(connected_providers) == 1,
} }
except Exception as e: except Exception as e:
logger.error(f"Error getting auth status for user {user.id}: {e}") logger.error(f"Error getting auth status for user {user.id}: {e}")
return { return {"error": "Unable to retrieve authentication status"}
'error': 'Unable to retrieve authentication status'
}
@staticmethod @staticmethod
def validate_provider_exists(provider: str) -> tuple[bool, str]: def validate_provider_exists(provider: str) -> tuple[bool, str]:

View File

@@ -59,7 +59,7 @@ class UserDeletionService:
return False, "Cannot delete staff accounts" return False, "Cannot delete staff accounts"
# Check for system users (if you have any special system accounts) # Check for system users (if you have any special system accounts)
if hasattr(user, 'role') and user.role in ['ADMIN', 'MODERATOR']: if hasattr(user, "role") and user.role in ["ADMIN", "MODERATOR"]:
return False, "Cannot delete admin or moderator accounts" return False, "Cannot delete admin or moderator accounts"
return True, None return True, None
@@ -84,8 +84,7 @@ class UserDeletionService:
raise ValueError(reason) raise ValueError(reason)
# Generate verification code # Generate verification code
verification_code = ''.join(secrets.choice( verification_code = "".join(secrets.choice(string.ascii_uppercase + string.digits) for _ in range(8))
string.ascii_uppercase + string.digits) for _ in range(8))
# Set expiration (24 hours from now) # Set expiration (24 hours from now)
expires_at = timezone.now() + timezone.timedelta(hours=24) expires_at = timezone.now() + timezone.timedelta(hours=24)
@@ -97,8 +96,7 @@ class UserDeletionService:
UserDeletionService._deletion_requests[verification_code] = deletion_request UserDeletionService._deletion_requests[verification_code] = deletion_request
# Send verification email # Send verification email
UserDeletionService._send_deletion_verification_email( UserDeletionService._send_deletion_verification_email(user, verification_code, expires_at)
user, verification_code, expires_at)
return deletion_request return deletion_request
@@ -136,10 +134,10 @@ class UserDeletionService:
del UserDeletionService._deletion_requests[verification_code] del UserDeletionService._deletion_requests[verification_code]
# Add verification info to result # Add verification info to result
result['deletion_request'] = { result["deletion_request"] = {
'verification_code': verification_code, "verification_code": verification_code,
'created_at': deletion_request.created_at, "created_at": deletion_request.created_at,
'verified_at': timezone.now(), "verified_at": timezone.now(),
} }
return result return result
@@ -180,13 +178,13 @@ class UserDeletionService:
""" """
# Get or create the "deleted_user" placeholder # Get or create the "deleted_user" placeholder
deleted_user_placeholder, created = User.objects.get_or_create( deleted_user_placeholder, created = User.objects.get_or_create(
username='deleted_user', username="deleted_user",
defaults={ defaults={
'email': 'deleted@thrillwiki.com', "email": "deleted@thrillwiki.com",
'first_name': 'Deleted', "first_name": "Deleted",
'last_name': 'User', "last_name": "User",
'is_active': False, "is_active": False,
} },
) )
# Count submissions before transfer # Count submissions before transfer
@@ -197,22 +195,22 @@ class UserDeletionService:
# Store user info before deletion # Store user info before deletion
deleted_user_info = { deleted_user_info = {
'username': user.username, "username": user.username,
'user_id': getattr(user, 'user_id', user.id), "user_id": getattr(user, "user_id", user.id),
'email': user.email, "email": user.email,
'date_joined': user.date_joined, "date_joined": user.date_joined,
} }
# Delete the user account # Delete the user account
user.delete() user.delete()
return { return {
'deleted_user': deleted_user_info, "deleted_user": deleted_user_info,
'preserved_submissions': submission_counts, "preserved_submissions": submission_counts,
'transferred_to': { "transferred_to": {
'username': deleted_user_placeholder.username, "username": deleted_user_placeholder.username,
'user_id': getattr(deleted_user_placeholder, 'user_id', deleted_user_placeholder.id), "user_id": getattr(deleted_user_placeholder, "user_id", deleted_user_placeholder.id),
} },
} }
@staticmethod @staticmethod
@@ -222,20 +220,13 @@ class UserDeletionService:
# Count different types of submissions # Count different types of submissions
# Note: These are placeholder counts - adjust based on your actual models # Note: These are placeholder counts - adjust based on your actual models
counts['park_reviews'] = getattr( counts["park_reviews"] = getattr(user, "park_reviews", user.__class__.objects.none()).count()
user, 'park_reviews', user.__class__.objects.none()).count() counts["ride_reviews"] = getattr(user, "ride_reviews", user.__class__.objects.none()).count()
counts['ride_reviews'] = getattr( counts["uploaded_park_photos"] = getattr(user, "uploaded_park_photos", user.__class__.objects.none()).count()
user, 'ride_reviews', user.__class__.objects.none()).count() counts["uploaded_ride_photos"] = getattr(user, "uploaded_ride_photos", user.__class__.objects.none()).count()
counts['uploaded_park_photos'] = getattr( counts["top_lists"] = getattr(user, "top_lists", user.__class__.objects.none()).count()
user, 'uploaded_park_photos', user.__class__.objects.none()).count() counts["edit_submissions"] = getattr(user, "edit_submissions", user.__class__.objects.none()).count()
counts['uploaded_ride_photos'] = getattr( counts["photo_submissions"] = getattr(user, "photo_submissions", user.__class__.objects.none()).count()
user, 'uploaded_ride_photos', user.__class__.objects.none()).count()
counts['top_lists'] = getattr(
user, 'top_lists', user.__class__.objects.none()).count()
counts['edit_submissions'] = getattr(
user, 'edit_submissions', user.__class__.objects.none()).count()
counts['photo_submissions'] = getattr(
user, 'photo_submissions', user.__class__.objects.none()).count()
return counts return counts
@@ -247,30 +238,30 @@ class UserDeletionService:
# Note: Adjust these based on your actual model relationships # Note: Adjust these based on your actual model relationships
# Park reviews # Park reviews
if hasattr(user, 'park_reviews'): if hasattr(user, "park_reviews"):
user.park_reviews.all().update(user=placeholder_user) user.park_reviews.all().update(user=placeholder_user)
# Ride reviews # Ride reviews
if hasattr(user, 'ride_reviews'): if hasattr(user, "ride_reviews"):
user.ride_reviews.all().update(user=placeholder_user) user.ride_reviews.all().update(user=placeholder_user)
# Uploaded photos # Uploaded photos
if hasattr(user, 'uploaded_park_photos'): if hasattr(user, "uploaded_park_photos"):
user.uploaded_park_photos.all().update(user=placeholder_user) user.uploaded_park_photos.all().update(user=placeholder_user)
if hasattr(user, 'uploaded_ride_photos'): if hasattr(user, "uploaded_ride_photos"):
user.uploaded_ride_photos.all().update(user=placeholder_user) user.uploaded_ride_photos.all().update(user=placeholder_user)
# Top lists # Top lists
if hasattr(user, 'top_lists'): if hasattr(user, "top_lists"):
user.top_lists.all().update(user=placeholder_user) user.top_lists.all().update(user=placeholder_user)
# Edit submissions # Edit submissions
if hasattr(user, 'edit_submissions'): if hasattr(user, "edit_submissions"):
user.edit_submissions.all().update(user=placeholder_user) user.edit_submissions.all().update(user=placeholder_user)
# Photo submissions # Photo submissions
if hasattr(user, 'photo_submissions'): if hasattr(user, "photo_submissions"):
user.photo_submissions.all().update(user=placeholder_user) user.photo_submissions.all().update(user=placeholder_user)
@staticmethod @staticmethod
@@ -278,18 +269,16 @@ class UserDeletionService:
"""Send verification email for account deletion.""" """Send verification email for account deletion."""
try: try:
context = { context = {
'user': user, "user": user,
'verification_code': verification_code, "verification_code": verification_code,
'expires_at': expires_at, "expires_at": expires_at,
'site_name': 'ThrillWiki', "site_name": "ThrillWiki",
'site_url': getattr(settings, 'SITE_URL', 'https://thrillwiki.com'), "site_url": getattr(settings, "SITE_URL", "https://thrillwiki.com"),
} }
subject = 'ThrillWiki: Confirm Account Deletion' subject = "ThrillWiki: Confirm Account Deletion"
html_message = render_to_string( html_message = render_to_string("emails/account_deletion_verification.html", context)
'emails/account_deletion_verification.html', context) plain_message = render_to_string("emails/account_deletion_verification.txt", context)
plain_message = render_to_string(
'emails/account_deletion_verification.txt', context)
send_mail( send_mail(
subject=subject, subject=subject,
@@ -303,6 +292,5 @@ class UserDeletionService:
logger.info(f"Deletion verification email sent to {user.email}") logger.info(f"Deletion verification email sent to {user.email}")
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to send deletion verification email to {user.email}: {str(e)}")
f"Failed to send deletion verification email to {user.email}: {str(e)}")
raise raise

View File

@@ -108,7 +108,7 @@ def sync_user_role_with_groups(sender, instance, **kwargs):
User.Roles.MODERATOR, User.Roles.MODERATOR,
]: ]:
instance.is_staff = True instance.is_staff = True
elif old_instance.role in [ elif old_instance.role in [ # noqa: SIM102
User.Roles.ADMIN, User.Roles.ADMIN,
User.Roles.MODERATOR, User.Roles.MODERATOR,
]: ]:
@@ -119,9 +119,7 @@ def sync_user_role_with_groups(sender, instance, **kwargs):
except User.DoesNotExist: except User.DoesNotExist:
pass pass
except Exception as e: except Exception as e:
print( print(f"Error syncing role with groups for user {instance.username}: {str(e)}")
f"Error syncing role with groups for user {instance.username}: {str(e)}"
)
def create_default_groups(): def create_default_groups():
@@ -200,19 +198,19 @@ def log_successful_login(sender, user, request, **kwargs):
""" """
try: try:
# Get IP address # Get IP address
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR') x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
ip_address = x_forwarded_for.split(',')[0].strip() if x_forwarded_for else request.META.get('REMOTE_ADDR') ip_address = x_forwarded_for.split(",")[0].strip() if x_forwarded_for else request.META.get("REMOTE_ADDR")
# Get user agent # Get user agent
user_agent = request.META.get('HTTP_USER_AGENT', '')[:500] user_agent = request.META.get("HTTP_USER_AGENT", "")[:500]
# Determine login method from session or request # Determine login method from session or request
login_method = 'PASSWORD' login_method = "PASSWORD"
if hasattr(request, 'session'): if hasattr(request, "session"):
sociallogin = getattr(request, '_sociallogin', None) sociallogin = getattr(request, "_sociallogin", None)
if sociallogin: if sociallogin:
provider = sociallogin.account.provider.upper() provider = sociallogin.account.provider.upper()
if provider in ['GOOGLE', 'DISCORD']: if provider in ["GOOGLE", "DISCORD"]:
login_method = provider login_method = provider
# Create login history entry # Create login history entry

View File

@@ -113,16 +113,10 @@ class SignalsTestCase(TestCase):
moderator_group = Group.objects.get(name=User.Roles.MODERATOR) moderator_group = Group.objects.get(name=User.Roles.MODERATOR)
self.assertIsNotNone(moderator_group) self.assertIsNotNone(moderator_group)
self.assertTrue( self.assertTrue(moderator_group.permissions.filter(codename="change_review").exists())
moderator_group.permissions.filter(codename="change_review").exists() self.assertFalse(moderator_group.permissions.filter(codename="change_user").exists())
)
self.assertFalse(
moderator_group.permissions.filter(codename="change_user").exists()
)
admin_group = Group.objects.get(name=User.Roles.ADMIN) admin_group = Group.objects.get(name=User.Roles.ADMIN)
self.assertIsNotNone(admin_group) self.assertIsNotNone(admin_group)
self.assertTrue( self.assertTrue(admin_group.permissions.filter(codename="change_review").exists())
admin_group.permissions.filter(codename="change_review").exists()
)
self.assertTrue(admin_group.permissions.filter(codename="change_user").exists()) self.assertTrue(admin_group.permissions.filter(codename="change_user").exists())

View File

@@ -150,6 +150,3 @@ class TestPasswordResetAdmin(TestCase):
request.user = UserModel(is_superuser=True) request.user = UserModel(is_superuser=True)
actions = self.admin.get_actions(request) actions = self.admin.get_actions(request)
assert "cleanup_old_tokens" in actions assert "cleanup_old_tokens" in actions

View File

@@ -85,16 +85,16 @@ class UserIndexTests(TestCase):
def test_is_banned_field_is_indexed(self): def test_is_banned_field_is_indexed(self):
"""Verify is_banned field has db_index=True.""" """Verify is_banned field has db_index=True."""
field = User._meta.get_field('is_banned') field = User._meta.get_field("is_banned")
self.assertTrue(field.db_index) self.assertTrue(field.db_index)
def test_role_field_is_indexed(self): def test_role_field_is_indexed(self):
"""Verify role field has db_index=True.""" """Verify role field has db_index=True."""
field = User._meta.get_field('role') field = User._meta.get_field("role")
self.assertTrue(field.db_index) self.assertTrue(field.db_index)
def test_composite_index_exists(self): def test_composite_index_exists(self):
"""Verify composite index on (is_banned, role) exists.""" """Verify composite index on (is_banned, role) exists."""
indexes = User._meta.indexes indexes = User._meta.indexes
index_names = [idx.name for idx in indexes] index_names = [idx.name for idx in indexes]
self.assertIn('accounts_user_banned_role_idx', index_names) self.assertIn("accounts_user_banned_role_idx", index_names)

View File

@@ -15,9 +15,7 @@ class UserDeletionServiceTest(TestCase):
def setUp(self): def setUp(self):
"""Set up test data.""" """Set up test data."""
# Create test users # Create test users
self.user = User.objects.create_user( self.user = User.objects.create_user(username="testuser", email="test@example.com", password="testpass123")
username="testuser", email="test@example.com", password="testpass123"
)
self.admin_user = User.objects.create_user( self.admin_user = User.objects.create_user(
username="admin", username="admin",
@@ -27,13 +25,9 @@ class UserDeletionServiceTest(TestCase):
) )
# Create user profiles # Create user profiles
UserProfile.objects.create( UserProfile.objects.create(user=self.user, display_name="Test User", bio="Test bio")
user=self.user, display_name="Test User", bio="Test bio"
)
UserProfile.objects.create( UserProfile.objects.create(user=self.admin_user, display_name="Admin User", bio="Admin bio")
user=self.admin_user, display_name="Admin User", bio="Admin bio"
)
def test_get_or_create_deleted_user(self): def test_get_or_create_deleted_user(self):
"""Test that deleted user placeholder is created correctly.""" """Test that deleted user placeholder is created correctly."""
@@ -108,9 +102,7 @@ class UserDeletionServiceTest(TestCase):
with self.assertRaises(ValueError) as context: with self.assertRaises(ValueError) as context:
UserDeletionService.delete_user_preserve_submissions(deleted_user) UserDeletionService.delete_user_preserve_submissions(deleted_user)
self.assertIn( self.assertIn("Cannot delete the system deleted user placeholder", str(context.exception))
"Cannot delete the system deleted user placeholder", str(context.exception)
)
def test_delete_user_with_submissions_transfers_correctly(self): def test_delete_user_with_submissions_transfers_correctly(self):
"""Test that user submissions are transferred to deleted user placeholder.""" """Test that user submissions are transferred to deleted user placeholder."""
@@ -141,7 +133,7 @@ class UserDeletionServiceTest(TestCase):
original_user_count = User.objects.count() original_user_count = User.objects.count()
# Mock a failure during the deletion process # Mock a failure during the deletion process
with self.assertRaises(Exception), transaction.atomic(): with self.assertRaises(Exception), transaction.atomic(): # noqa: B017
# Start the deletion process # Start the deletion process
UserDeletionService.get_or_create_deleted_user() UserDeletionService.get_or_create_deleted_user()

View File

@@ -61,11 +61,7 @@ class CustomLoginView(TurnstileMixin, LoginView):
context={"user_id": user.id, "username": user.username}, context={"user_id": user.id, "username": user.username},
request=self.request, request=self.request,
) )
return ( return HttpResponseClientRefresh() if getattr(self.request, "htmx", False) else response
HttpResponseClientRefresh()
if getattr(self.request, "htmx", False)
else response
)
def form_invalid(self, form): def form_invalid(self, form):
log_security_event( log_security_event(
@@ -116,11 +112,7 @@ class CustomSignupView(TurnstileMixin, SignupView):
}, },
request=self.request, request=self.request,
) )
return ( return HttpResponseClientRefresh() if getattr(self.request, "htmx", False) else response
HttpResponseClientRefresh()
if getattr(self.request, "htmx", False)
else response
)
def form_invalid(self, form): def form_invalid(self, form):
if getattr(self.request, "htmx", False): if getattr(self.request, "htmx", False):
@@ -260,9 +252,7 @@ class SettingsView(LoginRequiredMixin, TemplateView):
and bool(re.search(r"[0-9]", password)) and bool(re.search(r"[0-9]", password))
) )
def _send_password_change_confirmation( def _send_password_change_confirmation(self, request: HttpRequest, user: User) -> None:
self, request: HttpRequest, user: User
) -> None:
"""Send password change confirmation email.""" """Send password change confirmation email."""
site = get_current_site(request) site = get_current_site(request)
context = { context = {
@@ -270,9 +260,7 @@ class SettingsView(LoginRequiredMixin, TemplateView):
"site_name": site.name, "site_name": site.name,
} }
email_html = render_to_string( email_html = render_to_string("accounts/email/password_change_confirmation.html", context)
"accounts/email/password_change_confirmation.html", context
)
EmailService.send_email( EmailService.send_email(
to=user.email, to=user.email,
@@ -282,9 +270,7 @@ class SettingsView(LoginRequiredMixin, TemplateView):
html=email_html, html=email_html,
) )
def _handle_password_change( def _handle_password_change(self, request: HttpRequest) -> HttpResponseRedirect | None:
self, request: HttpRequest
) -> HttpResponseRedirect | None:
user = cast(User, request.user) user = cast(User, request.user)
old_password = request.POST.get("old_password", "") old_password = request.POST.get("old_password", "")
new_password = request.POST.get("new_password", "") new_password = request.POST.get("new_password", "")
@@ -327,9 +313,7 @@ class SettingsView(LoginRequiredMixin, TemplateView):
def _handle_email_change(self, request: HttpRequest) -> None: def _handle_email_change(self, request: HttpRequest) -> None:
if new_email := request.POST.get("new_email"): if new_email := request.POST.get("new_email"):
self._send_email_verification(request, new_email) self._send_email_verification(request, new_email)
messages.success( messages.success(request, "Verification email sent to your new email address")
request, "Verification email sent to your new email address"
)
else: else:
messages.error(request, "New email is required") messages.error(request, "New email is required")
@@ -385,9 +369,7 @@ def create_password_reset_token(user: User) -> str:
return token return token
def send_password_reset_email( def send_password_reset_email(user: User, site: Site | RequestSite, token: str) -> None:
user: User, site: Site | RequestSite, token: str
) -> None:
reset_url = reverse("password_reset_confirm", kwargs={"token": token}) reset_url = reverse("password_reset_confirm", kwargs={"token": token})
context = { context = {
"user": user, "user": user,
@@ -457,16 +439,12 @@ def handle_password_reset(
messages.success(request, "Password reset successfully") messages.success(request, "Password reset successfully")
def send_password_reset_confirmation( def send_password_reset_confirmation(user: User, site: Site | RequestSite) -> None:
user: User, site: Site | RequestSite
) -> None:
context = { context = {
"user": user, "user": user,
"site_name": site.name, "site_name": site.name,
} }
email_html = render_to_string( email_html = render_to_string("accounts/email/password_reset_complete.html", context)
"accounts/email/password_reset_complete.html", context
)
EmailService.send_email( EmailService.send_email(
to=user.email, to=user.email,
@@ -479,9 +457,7 @@ def send_password_reset_confirmation(
def reset_password(request: HttpRequest, token: str) -> HttpResponse: def reset_password(request: HttpRequest, token: str) -> HttpResponse:
try: try:
reset = PasswordReset.objects.select_related("user").get( reset = PasswordReset.objects.select_related("user").get(token=token, expires_at__gt=timezone.now(), used=False)
token=token, expires_at__gt=timezone.now(), used=False
)
if request.method == "POST": if request.method == "POST":
if new_password := request.POST.get("new_password"): if new_password := request.POST.get("new_password"):

File diff suppressed because it is too large Load Diff

View File

@@ -23,6 +23,7 @@ class UserProfileUpdateInputSerializer(serializers.ModelSerializer):
cloudflare_id = validated_data.pop("cloudflare_image_id", None) cloudflare_id = validated_data.pop("cloudflare_image_id", None)
if cloudflare_id: if cloudflare_id:
from django_cloudflareimages_toolkit.models import CloudflareImage from django_cloudflareimages_toolkit.models import CloudflareImage
image, _ = CloudflareImage.objects.get_or_create(cloudflare_id=cloudflare_id) image, _ = CloudflareImage.objects.get_or_create(cloudflare_id=cloudflare_id)
instance.avatar = image instance.avatar = image

View File

@@ -76,9 +76,7 @@ urlpatterns = [
name="update_privacy_settings", name="update_privacy_settings",
), ),
# Security settings endpoints # Security settings endpoints
path( path("settings/security/", views.get_security_settings, name="get_security_settings"),
"settings/security/", views.get_security_settings, name="get_security_settings"
),
path( path(
"settings/security/update/", "settings/security/update/",
views.update_security_settings, views.update_security_settings,
@@ -90,9 +88,7 @@ urlpatterns = [
path("top-lists/", views.get_user_top_lists, name="get_user_top_lists"), path("top-lists/", views.get_user_top_lists, name="get_user_top_lists"),
path("top-lists/create/", views.create_top_list, name="create_top_list"), path("top-lists/create/", views.create_top_list, name="create_top_list"),
path("top-lists/<int:list_id>/", views.update_top_list, name="update_top_list"), path("top-lists/<int:list_id>/", views.update_top_list, name="update_top_list"),
path( path("top-lists/<int:list_id>/delete/", views.delete_top_list, name="delete_top_list"),
"top-lists/<int:list_id>/delete/", views.delete_top_list, name="delete_top_list"
),
# Notification endpoints # Notification endpoints
path("notifications/", views.get_user_notifications, name="get_user_notifications"), path("notifications/", views.get_user_notifications, name="get_user_notifications"),
path( path(
@@ -114,18 +110,13 @@ urlpatterns = [
path("profile/avatar/upload/", views.upload_avatar, name="upload_avatar"), path("profile/avatar/upload/", views.upload_avatar, name="upload_avatar"),
path("profile/avatar/save/", views.save_avatar_image, name="save_avatar_image"), path("profile/avatar/save/", views.save_avatar_image, name="save_avatar_image"),
path("profile/avatar/delete/", views.delete_avatar, name="delete_avatar"), path("profile/avatar/delete/", views.delete_avatar, name="delete_avatar"),
# Login history endpoint # Login history endpoint
path("login-history/", views.get_login_history, name="get_login_history"), path("login-history/", views.get_login_history, name="get_login_history"),
# Magic Link (Login by Code) endpoints # Magic Link (Login by Code) endpoints
path("magic-link/request/", views_magic_link.request_magic_link, name="request_magic_link"), path("magic-link/request/", views_magic_link.request_magic_link, name="request_magic_link"),
path("magic-link/verify/", views_magic_link.verify_magic_link, name="verify_magic_link"), path("magic-link/verify/", views_magic_link.verify_magic_link, name="verify_magic_link"),
# Public Profile # Public Profile
path("profiles/<str:username>/", views.get_public_user_profile, name="get_public_user_profile"), path("profiles/<str:username>/", views.get_public_user_profile, name="get_public_user_profile"),
# ViewSet routes # ViewSet routes
path("", include(router.urls)), path("", include(router.urls)),
] ]

View File

@@ -69,8 +69,7 @@ logger = logging.getLogger(__name__)
200: { 200: {
"description": "User successfully deleted with submissions preserved", "description": "User successfully deleted with submissions preserved",
"example": { "example": {
"success": True, "detail": "User successfully deleted with submissions preserved",
"message": "User successfully deleted with submissions preserved",
"deleted_user": { "deleted_user": {
"username": "john_doe", "username": "john_doe",
"user_id": "1234", "user_id": "1234",
@@ -92,17 +91,16 @@ logger = logging.getLogger(__name__)
400: { 400: {
"description": "Bad request - user cannot be deleted", "description": "Bad request - user cannot be deleted",
"example": { "example": {
"success": False, "detail": "Cannot delete user: Cannot delete superuser accounts",
"error": "Cannot delete user: Cannot delete superuser accounts",
}, },
}, },
404: { 404: {
"description": "User not found", "description": "User not found",
"example": {"success": False, "error": "User not found"}, "example": {"detail": "User not found"},
}, },
403: { 403: {
"description": "Permission denied - admin access required", "description": "Permission denied - admin access required",
"example": {"success": False, "error": "Admin access required"}, "example": {"detail": "Admin access required"},
}, },
}, },
tags=["User Management"], tags=["User Management"],
@@ -137,7 +135,7 @@ def delete_user_preserve_submissions(request, user_id):
"is_superuser": user.is_superuser, "is_superuser": user.is_superuser,
"user_role": user.role, "user_role": user.role,
"rejection_reason": reason, "rejection_reason": reason,
} },
) )
# Determine error code based on reason # Determine error code based on reason
@@ -151,8 +149,7 @@ def delete_user_preserve_submissions(request, user_id):
return Response( return Response(
{ {
"success": False, "detail": f"Cannot delete user: {reason}",
"error": f"Cannot delete user: {reason}",
"error_code": error_code, "error_code": error_code,
"user_info": { "user_info": {
"username": user.username, "username": user.username,
@@ -174,7 +171,7 @@ def delete_user_preserve_submissions(request, user_id):
"target_user": user.username, "target_user": user.username,
"target_user_id": user_id, "target_user_id": user_id,
"action": "user_deletion", "action": "user_deletion",
} },
) )
# Perform the deletion # Perform the deletion
@@ -185,17 +182,16 @@ def delete_user_preserve_submissions(request, user_id):
f"Successfully deleted user {result['deleted_user']['username']} (ID: {user_id}) by admin {request.user.username}", f"Successfully deleted user {result['deleted_user']['username']} (ID: {user_id}) by admin {request.user.username}",
extra={ extra={
"admin_user": request.user.username, "admin_user": request.user.username,
"deleted_user": result['deleted_user']['username'], "deleted_user": result["deleted_user"]["username"],
"deleted_user_id": user_id, "deleted_user_id": user_id,
"preserved_submissions": result['preserved_submissions'], "preserved_submissions": result["preserved_submissions"],
"action": "user_deletion_completed", "action": "user_deletion_completed",
} },
) )
return Response( return Response(
{ {
"success": True, "detail": "User successfully deleted with submissions preserved",
"message": "User successfully deleted with submissions preserved",
**result, **result,
}, },
status=status.HTTP_200_OK, status=status.HTTP_200_OK,
@@ -208,16 +204,15 @@ def delete_user_preserve_submissions(request, user_id):
extra={ extra={
"admin_user": request.user.username, "admin_user": request.user.username,
"target_user_id": user_id, "target_user_id": user_id,
"error": str(e), "detail": str(e),
"action": "user_deletion_error", "action": "user_deletion_error",
}, },
exc_info=True exc_info=True,
) )
return Response( return Response(
{ {
"success": False, "detail": f"Error deleting user: {str(e)}",
"error": f"Error deleting user: {str(e)}",
"error_code": "DELETION_ERROR", "error_code": "DELETION_ERROR",
"help_text": "Please try again or contact system administrator if the problem persists.", "help_text": "Please try again or contact system administrator if the problem persists.",
}, },
@@ -259,8 +254,7 @@ def delete_user_preserve_submissions(request, user_id):
}, },
}, },
"example": { "example": {
"success": True, "detail": "Avatar saved successfully",
"message": "Avatar saved successfully",
"avatar_url": "https://imagedelivery.net/account-hash/image-id/avatar", "avatar_url": "https://imagedelivery.net/account-hash/image-id/avatar",
"avatar_variants": { "avatar_variants": {
"thumbnail": "https://imagedelivery.net/account-hash/image-id/thumbnail", "thumbnail": "https://imagedelivery.net/account-hash/image-id/thumbnail",
@@ -285,7 +279,7 @@ def save_avatar_image(request):
if not cloudflare_image_id: if not cloudflare_image_id:
return Response( return Response(
{"success": False, "error": "cloudflare_image_id is required"}, {"detail": "cloudflare_image_id is required"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -299,26 +293,25 @@ def save_avatar_image(request):
if not image_data: if not image_data:
return Response( return Response(
{"success": False, "error": "Image not found in Cloudflare"}, {"detail": "Image not found in Cloudflare"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
# Try to find existing CloudflareImage record by cloudflare_id # Try to find existing CloudflareImage record by cloudflare_id
cloudflare_image = None cloudflare_image = None
try: try:
cloudflare_image = CloudflareImage.objects.get( cloudflare_image = CloudflareImage.objects.get(cloudflare_id=cloudflare_image_id)
cloudflare_id=cloudflare_image_id)
# Update existing record with latest data from Cloudflare # Update existing record with latest data from Cloudflare
cloudflare_image.status = 'uploaded' cloudflare_image.status = "uploaded"
cloudflare_image.uploaded_at = timezone.now() cloudflare_image.uploaded_at = timezone.now()
cloudflare_image.metadata = image_data.get('meta', {}) cloudflare_image.metadata = image_data.get("meta", {})
# Extract variants from nested result structure # Extract variants from nested result structure
cloudflare_image.variants = image_data.get('result', {}).get('variants', []) cloudflare_image.variants = image_data.get("result", {}).get("variants", [])
cloudflare_image.cloudflare_metadata = image_data cloudflare_image.cloudflare_metadata = image_data
cloudflare_image.width = image_data.get('width') cloudflare_image.width = image_data.get("width")
cloudflare_image.height = image_data.get('height') cloudflare_image.height = image_data.get("height")
cloudflare_image.format = image_data.get('format', '') cloudflare_image.format = image_data.get("format", "")
cloudflare_image.save() cloudflare_image.save()
except CloudflareImage.DoesNotExist: except CloudflareImage.DoesNotExist:
@@ -326,25 +319,23 @@ def save_avatar_image(request):
cloudflare_image = CloudflareImage.objects.create( cloudflare_image = CloudflareImage.objects.create(
cloudflare_id=cloudflare_image_id, cloudflare_id=cloudflare_image_id,
user=user, user=user,
status='uploaded', status="uploaded",
upload_url='', # Not needed for uploaded images upload_url="", # Not needed for uploaded images
expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry
uploaded_at=timezone.now(), uploaded_at=timezone.now(),
metadata=image_data.get('meta', {}), metadata=image_data.get("meta", {}),
# Extract variants from nested result structure # Extract variants from nested result structure
variants=image_data.get('result', {}).get('variants', []), variants=image_data.get("result", {}).get("variants", []),
cloudflare_metadata=image_data, cloudflare_metadata=image_data,
width=image_data.get('width'), width=image_data.get("width"),
height=image_data.get('height'), height=image_data.get("height"),
format=image_data.get('format', ''), format=image_data.get("format", ""),
) )
except Exception as api_error: except Exception as api_error:
logger.error( logger.error(f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True)
f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True)
return Response( return Response(
{"success": False, {"detail": f"Failed to fetch image from Cloudflare: {str(api_error)}"},
"error": f"Failed to fetch image from Cloudflare: {str(api_error)}"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -391,8 +382,7 @@ def save_avatar_image(request):
return Response( return Response(
{ {
"success": True, "detail": "Avatar saved successfully",
"message": "Avatar saved successfully",
"avatar_url": avatar_url, "avatar_url": avatar_url,
"avatar_variants": avatar_variants, "avatar_variants": avatar_variants,
}, },
@@ -402,7 +392,7 @@ def save_avatar_image(request):
except Exception as e: except Exception as e:
logger.error(f"Error saving avatar image: {str(e)}", exc_info=True) logger.error(f"Error saving avatar image: {str(e)}", exc_info=True)
return Response( return Response(
{"success": False, "error": f"Failed to save avatar: {str(e)}"}, {"detail": f"Failed to save avatar: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -420,8 +410,7 @@ def save_avatar_image(request):
"avatar_url": {"type": "string"}, "avatar_url": {"type": "string"},
}, },
"example": { "example": {
"success": True, "detail": "Avatar deleted successfully",
"message": "Avatar deleted successfully",
"avatar_url": "https://ui-avatars.com/api/?name=J&size=200&background=random&color=fff&bold=true", "avatar_url": "https://ui-avatars.com/api/?name=J&size=200&background=random&color=fff&bold=true",
}, },
}, },
@@ -447,6 +436,7 @@ def delete_avatar(request):
# Delete from Cloudflare first, then from database # Delete from Cloudflare first, then from database
try: try:
from django_cloudflareimages_toolkit.services import CloudflareImagesService from django_cloudflareimages_toolkit.services import CloudflareImagesService
service = CloudflareImagesService() service = CloudflareImagesService()
service.delete_image(avatar_to_delete) service.delete_image(avatar_to_delete)
logger.info(f"Successfully deleted avatar from Cloudflare: {avatar_to_delete.cloudflare_id}") logger.info(f"Successfully deleted avatar from Cloudflare: {avatar_to_delete.cloudflare_id}")
@@ -461,8 +451,7 @@ def delete_avatar(request):
return Response( return Response(
{ {
"success": True, "detail": "Avatar deleted successfully",
"message": "Avatar deleted successfully",
"avatar_url": avatar_url, "avatar_url": avatar_url,
}, },
status=status.HTTP_200_OK, status=status.HTTP_200_OK,
@@ -471,8 +460,7 @@ def delete_avatar(request):
except UserProfile.DoesNotExist: except UserProfile.DoesNotExist:
return Response( return Response(
{ {
"success": True, "detail": "No avatar to delete",
"message": "No avatar to delete",
"avatar_url": f"https://ui-avatars.com/api/?name={user.username[0].upper()}&size=200&background=random&color=fff&bold=true", "avatar_url": f"https://ui-avatars.com/api/?name={user.username[0].upper()}&size=200&background=random&color=fff&bold=true",
}, },
status=status.HTTP_200_OK, status=status.HTTP_200_OK,
@@ -480,7 +468,7 @@ def delete_avatar(request):
except Exception as e: except Exception as e:
return Response( return Response(
{"success": False, "error": f"Failed to delete avatar: {str(e)}"}, {"detail": f"Failed to delete avatar: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -506,7 +494,7 @@ def request_account_deletion(request):
can_delete, reason = UserDeletionService.can_delete_user(user) can_delete, reason = UserDeletionService.can_delete_user(user)
if not can_delete: if not can_delete:
return Response( return Response(
{"success": False, "error": reason}, {"detail": reason},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -515,8 +503,7 @@ def request_account_deletion(request):
return Response( return Response(
{ {
"success": True, "detail": "Verification code sent to your email",
"message": "Verification code sent to your email",
"expires_at": deletion_request.expires_at, "expires_at": deletion_request.expires_at,
"email": user.email, "email": user.email,
}, },
@@ -534,7 +521,7 @@ def request_account_deletion(request):
"user_role": request.user.role, "user_role": request.user.role,
"rejection_reason": str(e), "rejection_reason": str(e),
"action": "self_deletion_rejected", "action": "self_deletion_rejected",
} },
) )
# Determine error code based on reason # Determine error code based on reason
@@ -549,8 +536,7 @@ def request_account_deletion(request):
return Response( return Response(
{ {
"success": False, "detail": error_message,
"error": error_message,
"error_code": error_code, "error_code": error_code,
"user_info": { "user_info": {
"username": request.user.username, "username": request.user.username,
@@ -570,16 +556,15 @@ def request_account_deletion(request):
extra={ extra={
"user": request.user.username, "user": request.user.username,
"user_id": request.user.user_id, "user_id": request.user.user_id,
"error": str(e), "detail": str(e),
"action": "self_deletion_error", "action": "self_deletion_error",
}, },
exc_info=True exc_info=True,
) )
return Response( return Response(
{ {
"success": False, "detail": f"Error creating deletion request: {str(e)}",
"error": f"Error creating deletion request: {str(e)}",
"error_code": "DELETION_REQUEST_ERROR", "error_code": "DELETION_REQUEST_ERROR",
"help_text": "Please try again or contact support if the problem persists.", "help_text": "Please try again or contact support if the problem persists.",
}, },
@@ -611,8 +596,7 @@ def request_account_deletion(request):
200: { 200: {
"description": "Account successfully deleted", "description": "Account successfully deleted",
"example": { "example": {
"success": True, "detail": "Account successfully deleted with submissions preserved",
"message": "Account successfully deleted with submissions preserved",
"deleted_user": { "deleted_user": {
"username": "john_doe", "username": "john_doe",
"user_id": "1234", "user_id": "1234",
@@ -637,7 +621,7 @@ def request_account_deletion(request):
}, },
400: { 400: {
"description": "Invalid or expired verification code", "description": "Invalid or expired verification code",
"example": {"success": False, "error": "Verification code has expired"}, "example": {"detail": "Verification code has expired"},
}, },
}, },
tags=["Self-Service Account Management"], tags=["Self-Service Account Management"],
@@ -663,7 +647,7 @@ def verify_account_deletion(request):
if not verification_code: if not verification_code:
return Response( return Response(
{"success": False, "error": "Verification code is required"}, {"detail": "Verification code is required"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -672,20 +656,17 @@ def verify_account_deletion(request):
return Response( return Response(
{ {
"success": True, "detail": "Account successfully deleted with submissions preserved",
"message": "Account successfully deleted with submissions preserved",
**result, **result,
}, },
status=status.HTTP_200_OK, status=status.HTTP_200_OK,
) )
except ValueError as e: except ValueError as e:
return Response( return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST)
{"success": False, "error": str(e)}, status=status.HTTP_400_BAD_REQUEST
)
except Exception as e: except Exception as e:
return Response( return Response(
{"success": False, "error": f"Error verifying deletion: {str(e)}"}, {"detail": f"Error verifying deletion: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@@ -701,14 +682,13 @@ def verify_account_deletion(request):
200: { 200: {
"description": "Deletion request cancelled or no request found", "description": "Deletion request cancelled or no request found",
"example": { "example": {
"success": True, "detail": "Deletion request cancelled",
"message": "Deletion request cancelled",
"had_pending_request": True, "had_pending_request": True,
}, },
}, },
401: { 401: {
"description": "Authentication required", "description": "Authentication required",
"example": {"success": False, "error": "Authentication required"}, "example": {"detail": "Authentication required"},
}, },
}, },
tags=["Self-Service Account Management"], tags=["Self-Service Account Management"],
@@ -732,12 +712,7 @@ def cancel_account_deletion(request):
return Response( return Response(
{ {
"success": True, "detail": ("Deletion request cancelled" if had_request else "No pending deletion request found"),
"message": (
"Deletion request cancelled"
if had_request
else "No pending deletion request found"
),
"had_pending_request": had_request, "had_pending_request": had_request,
}, },
status=status.HTTP_200_OK, status=status.HTTP_200_OK,
@@ -745,7 +720,7 @@ def cancel_account_deletion(request):
except Exception as e: except Exception as e:
return Response( return Response(
{"success": False, "error": f"Error cancelling deletion request: {str(e)}"}, {"detail": f"Error cancelling deletion request: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@@ -753,10 +728,7 @@ def cancel_account_deletion(request):
@extend_schema( @extend_schema(
operation_id="check_user_deletion_eligibility", operation_id="check_user_deletion_eligibility",
summary="Check if user can be deleted", summary="Check if user can be deleted",
description=( description=("Check if a user can be safely deleted and get a preview of " "what submissions would be preserved."),
"Check if a user can be safely deleted and get a preview of "
"what submissions would be preserved."
),
parameters=[ parameters=[
OpenApiParameter( OpenApiParameter(
name="user_id", name="user_id",
@@ -792,11 +764,11 @@ def cancel_account_deletion(request):
}, },
404: { 404: {
"description": "User not found", "description": "User not found",
"example": {"success": False, "error": "User not found"}, "example": {"detail": "User not found"},
}, },
403: { 403: {
"description": "Permission denied - admin access required", "description": "Permission denied - admin access required",
"example": {"success": False, "error": "Admin access required"}, "example": {"detail": "Admin access required"},
}, },
}, },
tags=["User Management"], tags=["User Management"],
@@ -821,27 +793,13 @@ def check_user_deletion_eligibility(request, user_id):
# Count submissions # Count submissions
submission_counts = { submission_counts = {
"park_reviews": getattr( "park_reviews": getattr(user, "park_reviews", user.__class__.objects.none()).count(),
user, "park_reviews", user.__class__.objects.none() "ride_reviews": getattr(user, "ride_reviews", user.__class__.objects.none()).count(),
).count(), "uploaded_park_photos": getattr(user, "uploaded_park_photos", user.__class__.objects.none()).count(),
"ride_reviews": getattr( "uploaded_ride_photos": getattr(user, "uploaded_ride_photos", user.__class__.objects.none()).count(),
user, "ride_reviews", user.__class__.objects.none() "top_lists": getattr(user, "user_lists", user.__class__.objects.none()).count(),
).count(), "edit_submissions": getattr(user, "edit_submissions", user.__class__.objects.none()).count(),
"uploaded_park_photos": getattr( "photo_submissions": getattr(user, "photo_submissions", user.__class__.objects.none()).count(),
user, "uploaded_park_photos", user.__class__.objects.none()
).count(),
"uploaded_ride_photos": getattr(
user, "uploaded_ride_photos", user.__class__.objects.none()
).count(),
"top_lists": getattr(
user, "user_lists", user.__class__.objects.none()
).count(),
"edit_submissions": getattr(
user, "edit_submissions", user.__class__.objects.none()
).count(),
"photo_submissions": getattr(
user, "photo_submissions", user.__class__.objects.none()
).count(),
} }
total_submissions = sum(submission_counts.values()) total_submissions = sum(submission_counts.values())
@@ -865,7 +823,7 @@ def check_user_deletion_eligibility(request, user_id):
except Exception as e: except Exception as e:
return Response( return Response(
{"success": False, "error": f"Error checking user: {str(e)}"}, {"detail": f"Error checking user: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@@ -912,9 +870,7 @@ def get_user_profile(request):
@permission_classes([IsAuthenticated]) @permission_classes([IsAuthenticated])
def update_user_account(request): def update_user_account(request):
"""Update basic account information.""" """Update basic account information."""
serializer = AccountUpdateSerializer( serializer = AccountUpdateSerializer(request.user, data=request.data, partial=True, context={"request": request})
request.user, data=request.data, partial=True, context={"request": request}
)
if serializer.is_valid(): if serializer.is_valid():
serializer.save() serializer.save()
@@ -944,9 +900,7 @@ def update_user_profile(request):
"""Update user profile information.""" """Update user profile information."""
profile, created = UserProfile.objects.get_or_create(user=request.user) profile, created = UserProfile.objects.get_or_create(user=request.user)
serializer = ProfileUpdateSerializer( serializer = ProfileUpdateSerializer(profile, data=request.data, partial=True, context={"request": request})
profile, data=request.data, partial=True, context={"request": request}
)
if serializer.is_valid(): if serializer.is_valid():
serializer.save() serializer.save()
@@ -1046,9 +1000,7 @@ def update_user_preferences(request):
@permission_classes([IsAuthenticated]) @permission_classes([IsAuthenticated])
def update_theme_preference(request): def update_theme_preference(request):
"""Update theme preference.""" """Update theme preference."""
serializer = ThemePreferenceSerializer( serializer = ThemePreferenceSerializer(request.user, data=request.data, partial=True)
request.user, data=request.data, partial=True
)
if serializer.is_valid(): if serializer.is_valid():
serializer.save() serializer.save()
@@ -1395,14 +1347,9 @@ def update_top_list(request, list_id):
try: try:
top_list = UserList.objects.get(id=list_id, user=request.user) top_list = UserList.objects.get(id=list_id, user=request.user)
except UserList.DoesNotExist: except UserList.DoesNotExist:
return Response( return Response({"detail": "Top list not found"}, status=status.HTTP_404_NOT_FOUND)
{"error": "Top list not found"},
status=status.HTTP_404_NOT_FOUND
)
serializer = UserListSerializer( serializer = UserListSerializer(top_list, data=request.data, partial=True, context={"request": request})
top_list, data=request.data, partial=True, context={"request": request}
)
if serializer.is_valid(): if serializer.is_valid():
serializer.save() serializer.save()
@@ -1430,10 +1377,7 @@ def delete_top_list(request, list_id):
top_list.delete() top_list.delete()
return Response(status=status.HTTP_204_NO_CONTENT) return Response(status=status.HTTP_204_NO_CONTENT)
except UserList.DoesNotExist: except UserList.DoesNotExist:
return Response( return Response({"detail": "Top list not found"}, status=status.HTTP_404_NOT_FOUND)
{"error": "Top list not found"},
status=status.HTTP_404_NOT_FOUND
)
# === NOTIFICATION ENDPOINTS === # === NOTIFICATION ENDPOINTS ===
@@ -1453,9 +1397,9 @@ def delete_top_list(request, list_id):
@permission_classes([IsAuthenticated]) @permission_classes([IsAuthenticated])
def get_user_notifications(request): def get_user_notifications(request):
"""Get user notifications.""" """Get user notifications."""
notifications = UserNotification.objects.filter( notifications = UserNotification.objects.filter(user=request.user).order_by("-created_at")[
user=request.user :50
).order_by("-created_at")[:50] # Limit to 50 most recent ] # Limit to 50 most recent
serializer = UserNotificationSerializer(notifications, many=True) serializer = UserNotificationSerializer(notifications, many=True)
return Response(serializer.data, status=status.HTTP_200_OK) return Response(serializer.data, status=status.HTTP_200_OK)
@@ -1483,19 +1427,16 @@ def mark_notifications_read(request):
mark_all = serializer.validated_data.get("mark_all", False) mark_all = serializer.validated_data.get("mark_all", False)
if mark_all: if mark_all:
UserNotification.objects.filter( UserNotification.objects.filter(user=request.user, is_read=False).update(
user=request.user, is_read=False is_read=True, read_at=timezone.now()
).update(is_read=True, read_at=timezone.now()) )
count = UserNotification.objects.filter(user=request.user).count() count = UserNotification.objects.filter(user=request.user).count()
else: else:
count = UserNotification.objects.filter( count = UserNotification.objects.filter(id__in=notification_ids, user=request.user, is_read=False).update(
id__in=notification_ids, user=request.user, is_read=False is_read=True, read_at=timezone.now()
).update(is_read=True, read_at=timezone.now()) )
return Response( return Response({"detail": f"Marked {count} notifications as read"}, status=status.HTTP_200_OK)
{"message": f"Marked {count} notifications as read"},
status=status.HTTP_200_OK
)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
@@ -1544,9 +1485,7 @@ def update_notification_preferences(request):
except NotificationPreference.DoesNotExist: except NotificationPreference.DoesNotExist:
preferences = NotificationPreference.objects.create(user=request.user) preferences = NotificationPreference.objects.create(user=request.user)
serializer = NotificationPreferenceSerializer( serializer = NotificationPreferenceSerializer(preferences, data=request.data, partial=True)
preferences, data=request.data, partial=True
)
if serializer.is_valid(): if serializer.is_valid():
serializer.save() serializer.save()
@@ -1578,10 +1517,7 @@ def upload_avatar(request):
if serializer.is_valid(): if serializer.is_valid():
# Handle avatar upload logic here # Handle avatar upload logic here
# This would typically involve saving the file and updating the user profile # This would typically involve saving the file and updating the user profile
return Response( return Response({"detail": "Avatar uploaded successfully"}, status=status.HTTP_200_OK)
{"message": "Avatar uploaded successfully"},
status=status.HTTP_200_OK
)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
@@ -1596,8 +1532,8 @@ def upload_avatar(request):
"example": { "example": {
"account": {"username": "user", "email": "user@example.com"}, "account": {"username": "user", "email": "user@example.com"},
"profile": {"display_name": "User"}, "profile": {"display_name": "User"},
"content": {"park_reviews": [], "lists": []} "content": {"park_reviews": [], "lists": []},
} },
}, },
401: {"description": "Authentication required"}, 401: {"description": "Authentication required"},
}, },
@@ -1612,10 +1548,7 @@ def export_user_data(request):
return Response(export_data, status=status.HTTP_200_OK) return Response(export_data, status=status.HTTP_200_OK)
except Exception as e: except Exception as e:
logger.error(f"Error exporting data for user {request.user.id}: {e}", exc_info=True) logger.error(f"Error exporting data for user {request.user.id}: {e}", exc_info=True)
return Response( return Response({"detail": "Failed to generate data export"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
{"error": "Failed to generate data export"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
@extend_schema( @extend_schema(
@@ -1690,20 +1623,25 @@ def get_login_history(request):
# Serialize # Serialize
results = [] results = []
for entry in entries: for entry in entries:
results.append({ results.append(
"id": entry.id, {
"ip_address": entry.ip_address, "id": entry.id,
"user_agent": entry.user_agent[:100] if entry.user_agent else None, # Truncate long user agents "ip_address": entry.ip_address,
"login_method": entry.login_method, "user_agent": entry.user_agent[:100] if entry.user_agent else None, # Truncate long user agents
"login_method_display": dict(LoginHistory._meta.get_field('login_method').choices).get(entry.login_method, entry.login_method), "login_method": entry.login_method,
"login_timestamp": entry.login_timestamp.isoformat(), "login_method_display": dict(LoginHistory._meta.get_field("login_method").choices).get(
"country": entry.country, entry.login_method, entry.login_method
"city": entry.city, ),
"success": entry.success, "login_timestamp": entry.login_timestamp.isoformat(),
}) "country": entry.country,
"city": entry.city,
return Response({ "success": entry.success,
"results": results, }
"count": len(results), )
})
return Response(
{
"results": results,
"count": len(results),
}
)

View File

@@ -15,22 +15,23 @@ class RideCreditViewSet(viewsets.ModelViewSet):
ViewSet for managing Ride Credits. ViewSet for managing Ride Credits.
Allows users to track rides they have ridden. Allows users to track rides they have ridden.
""" """
serializer_class = RideCreditSerializer serializer_class = RideCreditSerializer
permission_classes = [permissions.IsAuthenticatedOrReadOnly] permission_classes = [permissions.IsAuthenticatedOrReadOnly]
filter_backends = [DjangoFilterBackend, filters.OrderingFilter] filter_backends = [DjangoFilterBackend, filters.OrderingFilter]
filterset_fields = ['user__username', 'ride__park__slug', 'ride__manufacturer__slug'] filterset_fields = ["user__username", "ride__park__slug", "ride__manufacturer__slug"]
ordering_fields = ['first_ridden_at', 'last_ridden_at', 'created_at', 'count', 'rating', 'display_order'] ordering_fields = ["first_ridden_at", "last_ridden_at", "created_at", "count", "rating", "display_order"]
ordering = ['display_order', '-last_ridden_at'] ordering = ["display_order", "-last_ridden_at"]
def get_queryset(self): def get_queryset(self):
""" """
Return ride credits. Return ride credits.
Optionally filter by user via query param ?user=username Optionally filter by user via query param ?user=username
""" """
queryset = RideCredit.objects.all().select_related('ride', 'ride__park', 'user') queryset = RideCredit.objects.all().select_related("ride", "ride__park", "user")
# Filter by user if provided # Filter by user if provided
username = self.request.query_params.get('user') username = self.request.query_params.get("user")
if username: if username:
queryset = queryset.filter(user__username=username) queryset = queryset.filter(user__username=username)
@@ -40,64 +41,49 @@ class RideCreditViewSet(viewsets.ModelViewSet):
"""Associate the current user with the ride credit.""" """Associate the current user with the ride credit."""
serializer.save(user=self.request.user) serializer.save(user=self.request.user)
@action(detail=False, methods=['post'], permission_classes=[permissions.IsAuthenticated]) @action(detail=False, methods=["post"], permission_classes=[permissions.IsAuthenticated])
@extend_schema( @extend_schema(
summary="Reorder ride credits", summary="Reorder ride credits",
description="Bulk update the display order of ride credits. Send a list of {id, order} objects.", description="Bulk update the display order of ride credits. Send a list of {id, order} objects.",
request={ request={
'application/json': { "application/json": {
'type': 'object', "type": "object",
'properties': { "properties": {
'order': { "order": {
'type': 'array', "type": "array",
'items': { "items": {
'type': 'object', "type": "object",
'properties': { "properties": {"id": {"type": "integer"}, "order": {"type": "integer"}},
'id': {'type': 'integer'}, "required": ["id", "order"],
'order': {'type': 'integer'} },
},
'required': ['id', 'order']
}
} }
} },
} }
} },
) )
def reorder(self, request): def reorder(self, request):
""" """
Bulk update display_order for multiple credits. Bulk update display_order for multiple credits.
Expects: {"order": [{"id": 1, "order": 0}, {"id": 2, "order": 1}, ...]} Expects: {"order": [{"id": 1, "order": 0}, {"id": 2, "order": 1}, ...]}
""" """
order_data = request.data.get('order', []) order_data = request.data.get("order", [])
if not order_data: if not order_data:
return Response( return Response({"detail": "No order data provided"}, status=status.HTTP_400_BAD_REQUEST)
{'error': 'No order data provided'},
status=status.HTTP_400_BAD_REQUEST
)
# Validate that all credits belong to the current user # Validate that all credits belong to the current user
credit_ids = [item['id'] for item in order_data] credit_ids = [item["id"] for item in order_data]
user_credits = RideCredit.objects.filter( user_credits = RideCredit.objects.filter(id__in=credit_ids, user=request.user).values_list("id", flat=True)
id__in=credit_ids,
user=request.user
).values_list('id', flat=True)
if set(credit_ids) != set(user_credits): if set(credit_ids) != set(user_credits):
return Response( return Response({"detail": "You can only reorder your own credits"}, status=status.HTTP_403_FORBIDDEN)
{'error': 'You can only reorder your own credits'},
status=status.HTTP_403_FORBIDDEN
)
# Bulk update in a transaction # Bulk update in a transaction
with transaction.atomic(): with transaction.atomic():
for item in order_data: for item in order_data:
RideCredit.objects.filter( RideCredit.objects.filter(id=item["id"], user=request.user).update(display_order=item["order"])
id=item['id'],
user=request.user
).update(display_order=item['order'])
return Response({'status': 'reordered', 'count': len(order_data)}) return Response({"status": "reordered", "count": len(order_data)})
@extend_schema( @extend_schema(
summary="List ride credits", summary="List ride credits",
@@ -109,8 +95,7 @@ class RideCreditViewSet(viewsets.ModelViewSet):
type=OpenApiTypes.STR, type=OpenApiTypes.STR,
description="Filter by username", description="Filter by username",
), ),
] ],
) )
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs) return super().list(request, *args, **kwargs)

View File

@@ -4,6 +4,7 @@ Magic Link (Login by Code) API views.
Provides API endpoints for passwordless login via email code. Provides API endpoints for passwordless login via email code.
Uses django-allauth's built-in login-by-code functionality. Uses django-allauth's built-in login-by-code functionality.
""" """
from django.conf import settings from django.conf import settings
from drf_spectacular.utils import OpenApiExample, extend_schema from drf_spectacular.utils import OpenApiExample, extend_schema
from rest_framework import status from rest_framework import status
@@ -15,6 +16,7 @@ try:
from allauth.account.internal.flows.login_by_code import perform_login_by_code, request_login_code from allauth.account.internal.flows.login_by_code import perform_login_by_code, request_login_code
from allauth.account.models import EmailAddress from allauth.account.models import EmailAddress
from allauth.account.utils import user_email # noqa: F401 - imported to verify availability from allauth.account.utils import user_email # noqa: F401 - imported to verify availability
HAS_LOGIN_BY_CODE = True HAS_LOGIN_BY_CODE = True
except ImportError: except ImportError:
HAS_LOGIN_BY_CODE = False HAS_LOGIN_BY_CODE = False
@@ -24,27 +26,19 @@ except ImportError:
summary="Request magic link login code", summary="Request magic link login code",
description="Send a one-time login code to the user's email address.", description="Send a one-time login code to the user's email address.",
request={ request={
'application/json': { "application/json": {
'type': 'object', "type": "object",
'properties': { "properties": {"email": {"type": "string", "format": "email"}},
'email': {'type': 'string', 'format': 'email'} "required": ["email"],
},
'required': ['email']
} }
}, },
responses={ responses={
200: {'description': 'Login code sent successfully'}, 200: {"description": "Login code sent successfully"},
400: {'description': 'Invalid email or feature disabled'}, 400: {"description": "Invalid email or feature disabled"},
}, },
examples=[ examples=[OpenApiExample("Request login code", value={"email": "user@example.com"}, request_only=True)],
OpenApiExample(
'Request login code',
value={'email': 'user@example.com'},
request_only=True
)
]
) )
@api_view(['POST']) @api_view(["POST"])
@permission_classes([AllowAny]) @permission_classes([AllowAny])
def request_magic_link(request): def request_magic_link(request):
""" """
@@ -55,25 +49,18 @@ def request_magic_link(request):
2. If the email exists, a code is sent 2. If the email exists, a code is sent
3. User enters the code to complete login 3. User enters the code to complete login
""" """
if not getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_ENABLED', False): if not getattr(settings, "ACCOUNT_LOGIN_BY_CODE_ENABLED", False):
return Response( return Response({"detail": "Magic link login is not enabled"}, status=status.HTTP_400_BAD_REQUEST)
{'error': 'Magic link login is not enabled'},
status=status.HTTP_400_BAD_REQUEST
)
if not HAS_LOGIN_BY_CODE: if not HAS_LOGIN_BY_CODE:
return Response( return Response(
{'error': 'Login by code is not available in this version of allauth'}, {"detail": "Login by code is not available in this version of allauth"}, status=status.HTTP_400_BAD_REQUEST
status=status.HTTP_400_BAD_REQUEST
) )
email = request.data.get('email', '').lower().strip() email = request.data.get("email", "").lower().strip()
if not email: if not email:
return Response( return Response({"detail": "Email is required"}, status=status.HTTP_400_BAD_REQUEST)
{'error': 'Email is required'},
status=status.HTTP_400_BAD_REQUEST
)
# Check if email exists (don't reveal if it doesn't for security) # Check if email exists (don't reveal if it doesn't for security)
try: try:
@@ -83,40 +70,39 @@ def request_magic_link(request):
# Request the login code # Request the login code
request_login_code(request._request, user) request_login_code(request._request, user)
return Response({ return Response(
'success': True, {
'message': 'If an account exists with this email, a login code has been sent.', "detail": "If an account exists with this email, a login code has been sent.",
'timeout': getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_TIMEOUT', 300) "timeout": getattr(settings, "ACCOUNT_LOGIN_BY_CODE_TIMEOUT", 300),
}) }
)
except EmailAddress.DoesNotExist: except EmailAddress.DoesNotExist:
# Don't reveal that the email doesn't exist # Don't reveal that the email doesn't exist
return Response({ return Response(
'success': True, {
'message': 'If an account exists with this email, a login code has been sent.', "detail": "If an account exists with this email, a login code has been sent.",
'timeout': getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_TIMEOUT', 300) "timeout": getattr(settings, "ACCOUNT_LOGIN_BY_CODE_TIMEOUT", 300),
}) }
)
@extend_schema( @extend_schema(
summary="Verify magic link code", summary="Verify magic link code",
description="Verify the login code and complete the login process.", description="Verify the login code and complete the login process.",
request={ request={
'application/json': { "application/json": {
'type': 'object', "type": "object",
'properties': { "properties": {"email": {"type": "string", "format": "email"}, "code": {"type": "string"}},
'email': {'type': 'string', 'format': 'email'}, "required": ["email", "code"],
'code': {'type': 'string'}
},
'required': ['email', 'code']
} }
}, },
responses={ responses={
200: {'description': 'Login successful'}, 200: {"description": "Login successful"},
400: {'description': 'Invalid or expired code'}, 400: {"description": "Invalid or expired code"},
} },
) )
@api_view(['POST']) @api_view(["POST"])
@permission_classes([AllowAny]) @permission_classes([AllowAny])
def verify_magic_link(request): def verify_magic_link(request):
""" """
@@ -124,26 +110,17 @@ def verify_magic_link(request):
This is the second step of the magic link flow. This is the second step of the magic link flow.
""" """
if not getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_ENABLED', False): if not getattr(settings, "ACCOUNT_LOGIN_BY_CODE_ENABLED", False):
return Response( return Response({"detail": "Magic link login is not enabled"}, status=status.HTTP_400_BAD_REQUEST)
{'error': 'Magic link login is not enabled'},
status=status.HTTP_400_BAD_REQUEST
)
if not HAS_LOGIN_BY_CODE: if not HAS_LOGIN_BY_CODE:
return Response( return Response({"detail": "Login by code is not available"}, status=status.HTTP_400_BAD_REQUEST)
{'error': 'Login by code is not available'},
status=status.HTTP_400_BAD_REQUEST
)
email = request.data.get('email', '').lower().strip() email = request.data.get("email", "").lower().strip()
code = request.data.get('code', '').strip() code = request.data.get("code", "").strip()
if not email or not code: if not email or not code:
return Response( return Response({"detail": "Email and code are required"}, status=status.HTTP_400_BAD_REQUEST)
{'error': 'Email and code are required'},
status=status.HTTP_400_BAD_REQUEST
)
try: try:
email_address = EmailAddress.objects.get(email__iexact=email, verified=True) email_address = EmailAddress.objects.get(email__iexact=email, verified=True)
@@ -153,28 +130,20 @@ def verify_magic_link(request):
success = perform_login_by_code(request._request, user, code) success = perform_login_by_code(request._request, user, code)
if success: if success:
return Response({ return Response(
'success': True, {
'message': 'Login successful', "detail": "Login successful",
'user': { "user": {"id": user.id, "username": user.username, "email": user.email},
'id': user.id,
'username': user.username,
'email': user.email
} }
}) )
else: else:
return Response( return Response(
{'error': 'Invalid or expired code. Please request a new one.'}, {"detail": "Invalid or expired code. Please request a new one."}, status=status.HTTP_400_BAD_REQUEST
status=status.HTTP_400_BAD_REQUEST
) )
except EmailAddress.DoesNotExist: except EmailAddress.DoesNotExist:
return Response( return Response({"detail": "Invalid email or code"}, status=status.HTTP_400_BAD_REQUEST)
{'error': 'Invalid email or code'},
status=status.HTTP_400_BAD_REQUEST
)
except Exception: except Exception:
return Response( return Response(
{'error': 'Invalid or expired code. Please request a new one.'}, {"detail": "Invalid or expired code. Please request a new one."}, status=status.HTTP_400_BAD_REQUEST
status=status.HTTP_400_BAD_REQUEST
) )

View File

@@ -17,6 +17,7 @@ from rest_framework.response import Response
try: try:
import qrcode import qrcode
HAS_QRCODE = True HAS_QRCODE = True
except ImportError: except ImportError:
HAS_QRCODE = False HAS_QRCODE = False
@@ -59,12 +60,14 @@ def get_mfa_status(request):
except Authenticator.DoesNotExist: except Authenticator.DoesNotExist:
pass pass
return Response({ return Response(
"mfa_enabled": totp_enabled, {
"totp_enabled": totp_enabled, "mfa_enabled": totp_enabled,
"recovery_codes_enabled": recovery_enabled, "totp_enabled": totp_enabled,
"recovery_codes_count": recovery_count, "recovery_codes_enabled": recovery_enabled,
}) "recovery_codes_count": recovery_count,
}
)
@extend_schema( @extend_schema(
@@ -110,11 +113,13 @@ def setup_totp(request):
# Store secret in session for later verification # Store secret in session for later verification
request.session["pending_totp_secret"] = secret request.session["pending_totp_secret"] = secret
return Response({ return Response(
"secret": secret, {
"provisioning_uri": uri, "secret": secret,
"qr_code_base64": qr_code_base64, "provisioning_uri": uri,
}) "qr_code_base64": qr_code_base64,
}
)
@extend_schema( @extend_schema(
@@ -138,8 +143,7 @@ def setup_totp(request):
200: { 200: {
"description": "TOTP activated successfully", "description": "TOTP activated successfully",
"example": { "example": {
"success": True, "detail": "Two-factor authentication enabled",
"message": "Two-factor authentication enabled",
"recovery_codes": ["ABCD1234", "EFGH5678"], "recovery_codes": ["ABCD1234", "EFGH5678"],
}, },
}, },
@@ -160,7 +164,7 @@ def activate_totp(request):
if not code: if not code:
return Response( return Response(
{"success": False, "error": "Verification code is required"}, {"detail": "Verification code is required"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -168,21 +172,21 @@ def activate_totp(request):
secret = request.session.get("pending_totp_secret") secret = request.session.get("pending_totp_secret")
if not secret: if not secret:
return Response( return Response(
{"success": False, "error": "No pending TOTP setup. Please start setup again."}, {"detail": "No pending TOTP setup. Please start setup again."},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
# Verify the code # Verify the code
if not totp_auth.validate_totp_code(secret, code): if not totp_auth.validate_totp_code(secret, code):
return Response( return Response(
{"success": False, "error": "Invalid verification code"}, {"detail": "Invalid verification code"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
# Check if already has TOTP # Check if already has TOTP
if Authenticator.objects.filter(user=user, type=Authenticator.Type.TOTP).exists(): if Authenticator.objects.filter(user=user, type=Authenticator.Type.TOTP).exists():
return Response( return Response(
{"success": False, "error": "TOTP is already enabled"}, {"detail": "TOTP is already enabled"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -204,11 +208,12 @@ def activate_totp(request):
# Clear session # Clear session
del request.session["pending_totp_secret"] del request.session["pending_totp_secret"]
return Response({ return Response(
"success": True, {
"message": "Two-factor authentication enabled", "detail": "Two-factor authentication enabled",
"recovery_codes": codes, "recovery_codes": codes,
}) }
)
@extend_schema( @extend_schema(
@@ -230,7 +235,7 @@ def activate_totp(request):
responses={ responses={
200: { 200: {
"description": "TOTP disabled", "description": "TOTP disabled",
"example": {"success": True, "message": "Two-factor authentication disabled"}, "example": {"detail": "Two-factor authentication disabled"},
}, },
400: {"description": "Invalid password or MFA not enabled"}, 400: {"description": "Invalid password or MFA not enabled"},
}, },
@@ -248,26 +253,26 @@ def deactivate_totp(request):
# Verify password # Verify password
if not user.check_password(password): if not user.check_password(password):
return Response( return Response(
{"success": False, "error": "Invalid password"}, {"detail": "Invalid password"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
# Remove TOTP and recovery codes # Remove TOTP and recovery codes
deleted_count, _ = Authenticator.objects.filter( deleted_count, _ = Authenticator.objects.filter(
user=user, user=user, type__in=[Authenticator.Type.TOTP, Authenticator.Type.RECOVERY_CODES]
type__in=[Authenticator.Type.TOTP, Authenticator.Type.RECOVERY_CODES]
).delete() ).delete()
if deleted_count == 0: if deleted_count == 0:
return Response( return Response(
{"success": False, "error": "Two-factor authentication is not enabled"}, {"detail": "Two-factor authentication is not enabled"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
return Response({ return Response(
"success": True, {
"message": "Two-factor authentication disabled", "detail": "Two-factor authentication disabled",
}) }
)
@extend_schema( @extend_schema(
@@ -277,9 +282,7 @@ def deactivate_totp(request):
request={ request={
"application/json": { "application/json": {
"type": "object", "type": "object",
"properties": { "properties": {"code": {"type": "string", "description": "6-digit TOTP code"}},
"code": {"type": "string", "description": "6-digit TOTP code"}
},
"required": ["code"], "required": ["code"],
} }
}, },
@@ -301,7 +304,7 @@ def verify_totp(request):
if not code: if not code:
return Response( return Response(
{"success": False, "error": "Verification code is required"}, {"detail": "Verification code is required"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -313,12 +316,12 @@ def verify_totp(request):
return Response({"success": True}) return Response({"success": True})
else: else:
return Response( return Response(
{"success": False, "error": "Invalid verification code"}, {"detail": "Invalid verification code"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
except Authenticator.DoesNotExist: except Authenticator.DoesNotExist:
return Response( return Response(
{"success": False, "error": "TOTP is not enabled"}, {"detail": "TOTP is not enabled"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -330,9 +333,7 @@ def verify_totp(request):
request={ request={
"application/json": { "application/json": {
"type": "object", "type": "object",
"properties": { "properties": {"password": {"type": "string", "description": "Current password"}},
"password": {"type": "string", "description": "Current password"}
},
"required": ["password"], "required": ["password"],
} }
}, },
@@ -358,14 +359,14 @@ def regenerate_recovery_codes(request):
# Verify password # Verify password
if not user.check_password(password): if not user.check_password(password):
return Response( return Response(
{"success": False, "error": "Invalid password"}, {"detail": "Invalid password"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
# Check if TOTP is enabled # Check if TOTP is enabled
if not Authenticator.objects.filter(user=user, type=Authenticator.Type.TOTP).exists(): if not Authenticator.objects.filter(user=user, type=Authenticator.Type.TOTP).exists():
return Response( return Response(
{"success": False, "error": "Two-factor authentication is not enabled"}, {"detail": "Two-factor authentication is not enabled"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -379,7 +380,9 @@ def regenerate_recovery_codes(request):
defaults={"data": {"codes": codes}}, defaults={"data": {"codes": codes}},
) )
return Response({ return Response(
"success": True, {
"recovery_codes": codes, "success": True,
}) "recovery_codes": codes,
}
)

View File

@@ -38,8 +38,6 @@ class ModelChoices:
"""Model choices utility class.""" """Model choices utility class."""
# === AUTHENTICATION SERIALIZERS === # === AUTHENTICATION SERIALIZERS ===
@@ -95,12 +93,8 @@ class UserOutputSerializer(serializers.ModelSerializer):
class LoginInputSerializer(serializers.Serializer): class LoginInputSerializer(serializers.Serializer):
"""Input serializer for user login.""" """Input serializer for user login."""
username = serializers.CharField( username = serializers.CharField(max_length=254, help_text="Username or email address")
max_length=254, help_text="Username or email address" password = serializers.CharField(max_length=128, style={"input_type": "password"}, trim_whitespace=False)
)
password = serializers.CharField(
max_length=128, style={"input_type": "password"}, trim_whitespace=False
)
def validate(self, attrs): def validate(self, attrs):
username = attrs.get("username") username = attrs.get("username")
@@ -129,9 +123,7 @@ class SignupInputSerializer(serializers.ModelSerializer):
validators=[validate_password], validators=[validate_password],
style={"input_type": "password"}, style={"input_type": "password"},
) )
password_confirm = serializers.CharField( password_confirm = serializers.CharField(write_only=True, style={"input_type": "password"})
write_only=True, style={"input_type": "password"}
)
class Meta: class Meta:
model = UserModel model = UserModel
@@ -158,9 +150,7 @@ class SignupInputSerializer(serializers.ModelSerializer):
def validate_username(self, value): def validate_username(self, value):
"""Validate username is unique.""" """Validate username is unique."""
if UserModel.objects.filter(username=value).exists(): if UserModel.objects.filter(username=value).exists():
raise serializers.ValidationError( raise serializers.ValidationError("A user with this username already exists.")
"A user with this username already exists."
)
return value return value
def validate(self, attrs): def validate(self, attrs):
@@ -169,9 +159,7 @@ class SignupInputSerializer(serializers.ModelSerializer):
password_confirm = attrs.get("password_confirm") password_confirm = attrs.get("password_confirm")
if password != password_confirm: if password != password_confirm:
raise serializers.ValidationError( raise serializers.ValidationError({"password_confirm": "Passwords do not match."})
{"password_confirm": "Passwords do not match."}
)
return attrs return attrs
@@ -204,8 +192,7 @@ class SignupInputSerializer(serializers.ModelSerializer):
# Create or update email verification record # Create or update email verification record
verification, created = EmailVerification.objects.get_or_create( verification, created = EmailVerification.objects.get_or_create(
user=user, user=user, defaults={"token": get_random_string(64)}
defaults={'token': get_random_string(64)}
) )
if not created: if not created:
@@ -214,14 +201,12 @@ class SignupInputSerializer(serializers.ModelSerializer):
verification.save() verification.save()
# Get current site from request context # Get current site from request context
request = self.context.get('request') request = self.context.get("request")
if request: if request:
site = get_current_site(request._request) site = get_current_site(request._request)
# Build verification URL # Build verification URL
verification_url = request.build_absolute_uri( verification_url = request.build_absolute_uri(f"/api/v1/auth/verify-email/{verification.token}/")
f"/api/v1/auth/verify-email/{verification.token}/"
)
# Send verification email # Send verification email
try: try:
@@ -243,13 +228,11 @@ The ThrillWiki Team
) )
# Log the ForwardEmail email ID from the response # Log the ForwardEmail email ID from the response
email_id = response.get('id') if response else None email_id = response.get("id") if response else None
if email_id: if email_id:
logger.info( logger.info(f"Verification email sent successfully to {user.email}. ForwardEmail ID: {email_id}")
f"Verification email sent successfully to {user.email}. ForwardEmail ID: {email_id}")
else: else:
logger.info( logger.info(f"Verification email sent successfully to {user.email}. No email ID in response.")
f"Verification email sent successfully to {user.email}. No email ID in response.")
except Exception as e: except Exception as e:
# Log the error but don't fail registration # Log the error but don't fail registration
@@ -312,17 +295,13 @@ class PasswordResetOutputSerializer(serializers.Serializer):
class PasswordChangeInputSerializer(serializers.Serializer): class PasswordChangeInputSerializer(serializers.Serializer):
"""Input serializer for password change.""" """Input serializer for password change."""
old_password = serializers.CharField( old_password = serializers.CharField(max_length=128, style={"input_type": "password"})
max_length=128, style={"input_type": "password"}
)
new_password = serializers.CharField( new_password = serializers.CharField(
max_length=128, max_length=128,
validators=[validate_password], validators=[validate_password],
style={"input_type": "password"}, style={"input_type": "password"},
) )
new_password_confirm = serializers.CharField( new_password_confirm = serializers.CharField(max_length=128, style={"input_type": "password"})
max_length=128, style={"input_type": "password"}
)
def validate_old_password(self, value): def validate_old_password(self, value):
"""Validate old password is correct.""" """Validate old password is correct."""
@@ -337,9 +316,7 @@ class PasswordChangeInputSerializer(serializers.Serializer):
new_password_confirm = attrs.get("new_password_confirm") new_password_confirm = attrs.get("new_password_confirm")
if new_password != new_password_confirm: if new_password != new_password_confirm:
raise serializers.ValidationError( raise serializers.ValidationError({"new_password_confirm": "New passwords do not match."})
{"new_password_confirm": "New passwords do not match."}
)
return attrs return attrs
@@ -471,6 +448,3 @@ class UserProfileUpdateInputSerializer(serializers.Serializer):
dark_ride_credits = serializers.IntegerField(required=False) dark_ride_credits = serializers.IntegerField(required=False)
flat_ride_credits = serializers.IntegerField(required=False) flat_ride_credits = serializers.IntegerField(required=False)
water_ride_credits = serializers.IntegerField(required=False) water_ride_credits = serializers.IntegerField(required=False)

View File

@@ -19,13 +19,13 @@ from .social import (
__all__ = [ __all__ = [
# Social authentication serializers # Social authentication serializers
'ConnectedProviderSerializer', "ConnectedProviderSerializer",
'AvailableProviderSerializer', "AvailableProviderSerializer",
'SocialAuthStatusSerializer', "SocialAuthStatusSerializer",
'ConnectProviderInputSerializer', "ConnectProviderInputSerializer",
'ConnectProviderOutputSerializer', "ConnectProviderOutputSerializer",
'DisconnectProviderOutputSerializer', "DisconnectProviderOutputSerializer",
'SocialProviderListOutputSerializer', "SocialProviderListOutputSerializer",
'ConnectedProvidersListOutputSerializer', "ConnectedProvidersListOutputSerializer",
'SocialProviderErrorSerializer', "SocialProviderErrorSerializer",
] ]

View File

@@ -14,74 +14,36 @@ User = get_user_model()
class ConnectedProviderSerializer(serializers.Serializer): class ConnectedProviderSerializer(serializers.Serializer):
"""Serializer for connected social provider information.""" """Serializer for connected social provider information."""
provider = serializers.CharField( provider = serializers.CharField(help_text="Provider ID (e.g., 'google', 'discord')")
help_text="Provider ID (e.g., 'google', 'discord')" provider_name = serializers.CharField(help_text="Human-readable provider name")
) uid = serializers.CharField(help_text="User ID on the social provider")
provider_name = serializers.CharField( date_joined = serializers.DateTimeField(help_text="When this provider was connected")
help_text="Human-readable provider name" can_disconnect = serializers.BooleanField(help_text="Whether this provider can be safely disconnected")
)
uid = serializers.CharField(
help_text="User ID on the social provider"
)
date_joined = serializers.DateTimeField(
help_text="When this provider was connected"
)
can_disconnect = serializers.BooleanField(
help_text="Whether this provider can be safely disconnected"
)
disconnect_reason = serializers.CharField( disconnect_reason = serializers.CharField(
allow_null=True, allow_null=True, required=False, help_text="Reason why provider cannot be disconnected (if applicable)"
required=False,
help_text="Reason why provider cannot be disconnected (if applicable)"
)
extra_data = serializers.JSONField(
required=False,
help_text="Additional data from the social provider"
) )
extra_data = serializers.JSONField(required=False, help_text="Additional data from the social provider")
class AvailableProviderSerializer(serializers.Serializer): class AvailableProviderSerializer(serializers.Serializer):
"""Serializer for available social provider information.""" """Serializer for available social provider information."""
id = serializers.CharField( id = serializers.CharField(help_text="Provider ID (e.g., 'google', 'discord')")
help_text="Provider ID (e.g., 'google', 'discord')" name = serializers.CharField(help_text="Human-readable provider name")
) auth_url = serializers.URLField(help_text="URL to initiate authentication with this provider")
name = serializers.CharField( connect_url = serializers.URLField(help_text="API URL to connect this provider")
help_text="Human-readable provider name"
)
auth_url = serializers.URLField(
help_text="URL to initiate authentication with this provider"
)
connect_url = serializers.URLField(
help_text="API URL to connect this provider"
)
class SocialAuthStatusSerializer(serializers.Serializer): class SocialAuthStatusSerializer(serializers.Serializer):
"""Serializer for comprehensive social authentication status.""" """Serializer for comprehensive social authentication status."""
user_id = serializers.IntegerField( user_id = serializers.IntegerField(help_text="User's ID")
help_text="User's ID" username = serializers.CharField(help_text="User's username")
) email = serializers.EmailField(help_text="User's email address")
username = serializers.CharField( has_password_auth = serializers.BooleanField(help_text="Whether user has email/password authentication set up")
help_text="User's username" connected_providers = ConnectedProviderSerializer(many=True, help_text="List of connected social providers")
) total_auth_methods = serializers.IntegerField(help_text="Total number of authentication methods available")
email = serializers.EmailField( can_disconnect_any = serializers.BooleanField(help_text="Whether user can safely disconnect any provider")
help_text="User's email address"
)
has_password_auth = serializers.BooleanField(
help_text="Whether user has email/password authentication set up"
)
connected_providers = ConnectedProviderSerializer(
many=True,
help_text="List of connected social providers"
)
total_auth_methods = serializers.IntegerField(
help_text="Total number of authentication methods available"
)
can_disconnect_any = serializers.BooleanField(
help_text="Whether user can safely disconnect any provider"
)
requires_password_setup = serializers.BooleanField( requires_password_setup = serializers.BooleanField(
help_text="Whether user needs to set up password before disconnecting" help_text="Whether user needs to set up password before disconnecting"
) )
@@ -90,9 +52,7 @@ class SocialAuthStatusSerializer(serializers.Serializer):
class ConnectProviderInputSerializer(serializers.Serializer): class ConnectProviderInputSerializer(serializers.Serializer):
"""Serializer for social provider connection requests.""" """Serializer for social provider connection requests."""
provider = serializers.CharField( provider = serializers.CharField(help_text="Provider ID to connect (e.g., 'google', 'discord')")
help_text="Provider ID to connect (e.g., 'google', 'discord')"
)
def validate_provider(self, value): def validate_provider(self, value):
"""Validate that the provider is supported and configured.""" """Validate that the provider is supported and configured."""
@@ -108,93 +68,51 @@ class ConnectProviderInputSerializer(serializers.Serializer):
class ConnectProviderOutputSerializer(serializers.Serializer): class ConnectProviderOutputSerializer(serializers.Serializer):
"""Serializer for social provider connection responses.""" """Serializer for social provider connection responses."""
success = serializers.BooleanField( success = serializers.BooleanField(help_text="Whether the connection was successful")
help_text="Whether the connection was successful" message = serializers.CharField(help_text="Success or error message")
) provider = serializers.CharField(help_text="Provider that was connected")
message = serializers.CharField( auth_url = serializers.URLField(required=False, help_text="URL to complete the connection process")
help_text="Success or error message"
)
provider = serializers.CharField(
help_text="Provider that was connected"
)
auth_url = serializers.URLField(
required=False,
help_text="URL to complete the connection process"
)
class DisconnectProviderOutputSerializer(serializers.Serializer): class DisconnectProviderOutputSerializer(serializers.Serializer):
"""Serializer for social provider disconnection responses.""" """Serializer for social provider disconnection responses."""
success = serializers.BooleanField( success = serializers.BooleanField(help_text="Whether the disconnection was successful")
help_text="Whether the disconnection was successful" message = serializers.CharField(help_text="Success or error message")
) provider = serializers.CharField(help_text="Provider that was disconnected")
message = serializers.CharField(
help_text="Success or error message"
)
provider = serializers.CharField(
help_text="Provider that was disconnected"
)
remaining_providers = serializers.ListField( remaining_providers = serializers.ListField(
child=serializers.CharField(), child=serializers.CharField(), help_text="List of remaining connected providers"
help_text="List of remaining connected providers"
)
has_password_auth = serializers.BooleanField(
help_text="Whether user still has password authentication"
) )
has_password_auth = serializers.BooleanField(help_text="Whether user still has password authentication")
suggestions = serializers.ListField( suggestions = serializers.ListField(
child=serializers.CharField(), child=serializers.CharField(),
required=False, required=False,
help_text="Suggestions for maintaining account access (if applicable)" help_text="Suggestions for maintaining account access (if applicable)",
) )
class SocialProviderListOutputSerializer(serializers.Serializer): class SocialProviderListOutputSerializer(serializers.Serializer):
"""Serializer for listing available social providers.""" """Serializer for listing available social providers."""
available_providers = AvailableProviderSerializer( available_providers = AvailableProviderSerializer(many=True, help_text="List of available social providers")
many=True, count = serializers.IntegerField(help_text="Number of available providers")
help_text="List of available social providers"
)
count = serializers.IntegerField(
help_text="Number of available providers"
)
class ConnectedProvidersListOutputSerializer(serializers.Serializer): class ConnectedProvidersListOutputSerializer(serializers.Serializer):
"""Serializer for listing connected social providers.""" """Serializer for listing connected social providers."""
connected_providers = ConnectedProviderSerializer( connected_providers = ConnectedProviderSerializer(many=True, help_text="List of connected social providers")
many=True, count = serializers.IntegerField(help_text="Number of connected providers")
help_text="List of connected social providers" has_password_auth = serializers.BooleanField(help_text="Whether user has password authentication")
) can_disconnect_any = serializers.BooleanField(help_text="Whether user can safely disconnect any provider")
count = serializers.IntegerField(
help_text="Number of connected providers"
)
has_password_auth = serializers.BooleanField(
help_text="Whether user has password authentication"
)
can_disconnect_any = serializers.BooleanField(
help_text="Whether user can safely disconnect any provider"
)
class SocialProviderErrorSerializer(serializers.Serializer): class SocialProviderErrorSerializer(serializers.Serializer):
"""Serializer for social provider error responses.""" """Serializer for social provider error responses."""
error = serializers.CharField( error = serializers.CharField(help_text="Error message")
help_text="Error message" code = serializers.CharField(required=False, help_text="Error code for programmatic handling")
)
code = serializers.CharField(
required=False,
help_text="Error code for programmatic handling"
)
suggestions = serializers.ListField( suggestions = serializers.ListField(
child=serializers.CharField(), child=serializers.CharField(), required=False, help_text="Suggestions for resolving the error"
required=False,
help_text="Suggestions for resolving the error"
)
provider = serializers.CharField(
required=False,
help_text="Provider related to the error (if applicable)"
) )
provider = serializers.CharField(required=False, help_text="Provider related to the error (if applicable)")

View File

@@ -36,13 +36,10 @@ urlpatterns = [
path("signup/", SignupAPIView.as_view(), name="auth-signup"), path("signup/", SignupAPIView.as_view(), name="auth-signup"),
path("logout/", LogoutAPIView.as_view(), name="auth-logout"), path("logout/", LogoutAPIView.as_view(), name="auth-logout"),
path("user/", CurrentUserAPIView.as_view(), name="auth-current-user"), path("user/", CurrentUserAPIView.as_view(), name="auth-current-user"),
# JWT token management # JWT token management
path("token/refresh/", TokenRefreshView.as_view(), name="auth-token-refresh"), path("token/refresh/", TokenRefreshView.as_view(), name="auth-token-refresh"),
# Social authentication endpoints (dj-rest-auth) # Social authentication endpoints (dj-rest-auth)
path("social/", include("dj_rest_auth.registration.urls")), path("social/", include("dj_rest_auth.registration.urls")),
path( path(
"password/reset/", "password/reset/",
PasswordResetAPIView.as_view(), PasswordResetAPIView.as_view(),
@@ -58,7 +55,6 @@ urlpatterns = [
SocialProvidersAPIView.as_view(), SocialProvidersAPIView.as_view(),
name="auth-social-providers", name="auth-social-providers",
), ),
# Social provider management endpoints # Social provider management endpoints
path( path(
"social/providers/available/", "social/providers/available/",
@@ -85,9 +81,7 @@ urlpatterns = [
SocialAuthStatusAPIView.as_view(), SocialAuthStatusAPIView.as_view(),
name="auth-social-status", name="auth-social-status",
), ),
path("status/", AuthStatusAPIView.as_view(), name="auth-status"), path("status/", AuthStatusAPIView.as_view(), name="auth-status"),
# Email verification endpoints # Email verification endpoints
path( path(
"verify-email/<str:token>/", "verify-email/<str:token>/",
@@ -99,7 +93,6 @@ urlpatterns = [
ResendVerificationAPIView.as_view(), ResendVerificationAPIView.as_view(),
name="auth-resend-verification", name="auth-resend-verification",
), ),
# MFA (Multi-Factor Authentication) endpoints # MFA (Multi-Factor Authentication) endpoints
path("mfa/status/", mfa_views.get_mfa_status, name="auth-mfa-status"), path("mfa/status/", mfa_views.get_mfa_status, name="auth-mfa-status"),
path("mfa/totp/setup/", mfa_views.setup_totp, name="auth-mfa-totp-setup"), path("mfa/totp/setup/", mfa_views.setup_totp, name="auth-mfa-totp-setup"),

View File

@@ -85,9 +85,7 @@ def _get_underlying_request(request: Request) -> HttpRequest:
# Helper: encapsulate user lookup + authenticate to reduce complexity in view # Helper: encapsulate user lookup + authenticate to reduce complexity in view
def _authenticate_user_by_lookup( def _authenticate_user_by_lookup(email_or_username: str, password: str, request: Request) -> UserModel | None:
email_or_username: str, password: str, request: Request
) -> UserModel | None:
""" """
Try a single optimized query to find a user by email OR username then authenticate. Try a single optimized query to find a user by email OR username then authenticate.
Returns authenticated user or None. Returns authenticated user or None.
@@ -154,7 +152,7 @@ class LoginAPIView(APIView):
# instantiate mixin before calling to avoid type-mismatch in static analysis # instantiate mixin before calling to avoid type-mismatch in static analysis
TurnstileMixin().validate_turnstile(request) TurnstileMixin().validate_turnstile(request)
except ValidationError as e: except ValidationError as e:
return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST)
except Exception: except Exception:
# If mixin doesn't do anything, continue # If mixin doesn't do anything, continue
pass pass
@@ -168,7 +166,7 @@ class LoginAPIView(APIView):
if not email_or_username or not password: if not email_or_username or not password:
return Response( return Response(
{"error": "username and password are required"}, {"detail": "username and password are required"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -177,8 +175,7 @@ class LoginAPIView(APIView):
if user: if user:
if getattr(user, "is_active", False): if getattr(user, "is_active", False):
# pass a real HttpRequest to Django login with backend specified # pass a real HttpRequest to Django login with backend specified
login(_get_underlying_request(request), user, login(_get_underlying_request(request), user, backend="django.contrib.auth.backends.ModelBackend")
backend='django.contrib.auth.backends.ModelBackend')
# Generate JWT tokens # Generate JWT tokens
from rest_framework_simplejwt.tokens import RefreshToken from rest_framework_simplejwt.tokens import RefreshToken
@@ -191,22 +188,22 @@ class LoginAPIView(APIView):
"access": str(access_token), "access": str(access_token),
"refresh": str(refresh), "refresh": str(refresh),
"user": user, "user": user,
"message": "Login successful", "detail": "Login successful",
} }
) )
return Response(response_serializer.data) return Response(response_serializer.data)
else: else:
return Response( return Response(
{ {
"error": "Email verification required", "detail": "Please verify your email address before logging in. Check your email for a verification link.",
"message": "Please verify your email address before logging in. Check your email for a verification link.", "code": "EMAIL_VERIFICATION_REQUIRED",
"email_verification_required": True "email_verification_required": True,
}, },
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
else: else:
return Response( return Response(
{"error": "Invalid credentials"}, {"detail": "Invalid credentials"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -237,7 +234,7 @@ class SignupAPIView(APIView):
# instantiate mixin before calling to avoid type-mismatch in static analysis # instantiate mixin before calling to avoid type-mismatch in static analysis
TurnstileMixin().validate_turnstile(request) TurnstileMixin().validate_turnstile(request)
except ValidationError as e: except ValidationError as e:
return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST)
except Exception: except Exception:
# If mixin doesn't do anything, continue # If mixin doesn't do anything, continue
pass pass
@@ -252,7 +249,7 @@ class SignupAPIView(APIView):
"access": None, "access": None,
"refresh": None, "refresh": None,
"user": user, "user": user,
"message": "Registration successful. Please check your email to verify your account.", "detail": "Registration successful. Please check your email to verify your account.",
"email_verification_required": True, "email_verification_required": True,
} }
) )
@@ -282,18 +279,18 @@ class LogoutAPIView(APIView):
try: try:
# Get refresh token from request data with proper type handling # Get refresh token from request data with proper type handling
refresh_token = None refresh_token = None
if hasattr(request, 'data') and request.data is not None: if hasattr(request, "data") and request.data is not None:
data = getattr(request, 'data', {}) data = getattr(request, "data", {})
if hasattr(data, 'get'): if hasattr(data, "get"):
refresh_token = data.get("refresh") refresh_token = data.get("refresh")
if refresh_token and isinstance(refresh_token, str): if refresh_token and isinstance(refresh_token, str):
# Blacklist the refresh token # Blacklist the refresh token
from rest_framework_simplejwt.tokens import RefreshToken from rest_framework_simplejwt.tokens import RefreshToken
try: try:
# Create RefreshToken from string and blacklist it # Create RefreshToken from string and blacklist it
refresh_token_obj = RefreshToken( refresh_token_obj = RefreshToken(refresh_token) # type: ignore[arg-type]
refresh_token) # type: ignore[arg-type]
refresh_token_obj.blacklist() refresh_token_obj.blacklist()
except Exception: except Exception:
# Token might be invalid or already blacklisted # Token might be invalid or already blacklisted
@@ -306,14 +303,10 @@ class LogoutAPIView(APIView):
# Logout from session using the underlying HttpRequest # Logout from session using the underlying HttpRequest
logout(_get_underlying_request(request)) logout(_get_underlying_request(request))
response_serializer = LogoutOutputSerializer( response_serializer = LogoutOutputSerializer({"detail": "Logout successful"})
{"message": "Logout successful"}
)
return Response(response_serializer.data) return Response(response_serializer.data)
except Exception: except Exception:
return Response( return Response({"detail": "Logout failed"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
{"error": "Logout failed"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
@extend_schema_view( @extend_schema_view(
@@ -357,15 +350,11 @@ class PasswordResetAPIView(APIView):
serializer_class = PasswordResetInputSerializer serializer_class = PasswordResetInputSerializer
def post(self, request: Request) -> Response: def post(self, request: Request) -> Response:
serializer = PasswordResetInputSerializer( serializer = PasswordResetInputSerializer(data=request.data, context={"request": request})
data=request.data, context={"request": request}
)
if serializer.is_valid(): if serializer.is_valid():
serializer.save() serializer.save()
response_serializer = PasswordResetOutputSerializer( response_serializer = PasswordResetOutputSerializer({"detail": "Password reset email sent"})
{"detail": "Password reset email sent"}
)
return Response(response_serializer.data) return Response(response_serializer.data)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
@@ -391,15 +380,11 @@ class PasswordChangeAPIView(APIView):
serializer_class = PasswordChangeInputSerializer serializer_class = PasswordChangeInputSerializer
def post(self, request: Request) -> Response: def post(self, request: Request) -> Response:
serializer = PasswordChangeInputSerializer( serializer = PasswordChangeInputSerializer(data=request.data, context={"request": request})
data=request.data, context={"request": request}
)
if serializer.is_valid(): if serializer.is_valid():
serializer.save() serializer.save()
response_serializer = PasswordChangeOutputSerializer( response_serializer = PasswordChangeOutputSerializer({"detail": "Password changed successfully"})
{"detail": "Password changed successfully"}
)
return Response(response_serializer.data) return Response(response_serializer.data)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
@@ -443,13 +428,9 @@ class SocialProvidersAPIView(APIView):
for social_app in social_apps: for social_app in social_apps:
try: try:
provider_name = ( provider_name = social_app.name or getattr(social_app, "provider", "").title()
social_app.name or getattr(social_app, "provider", "").title()
)
auth_url = request.build_absolute_uri( auth_url = request.build_absolute_uri(f"/accounts/{social_app.provider}/login/")
f"/accounts/{social_app.provider}/login/"
)
providers_list.append( providers_list.append(
{ {
@@ -532,7 +513,7 @@ class AvailableProvidersAPIView(APIView):
"name": "Discord", "name": "Discord",
"login_url": "/auth/social/discord/", "login_url": "/auth/social/discord/",
"connect_url": "/auth/social/connect/discord/", "connect_url": "/auth/social/connect/discord/",
} },
] ]
serializer = AvailableProviderSerializer(providers, many=True) serializer = AvailableProviderSerializer(providers, many=True)
@@ -585,31 +566,29 @@ class ConnectProviderAPIView(APIView):
def post(self, request: Request, provider: str) -> Response: def post(self, request: Request, provider: str) -> Response:
# Validate provider # Validate provider
if provider not in ['google', 'discord']: if provider not in ["google", "discord"]:
return Response( return Response(
{ {
"success": False, "detail": f"Provider '{provider}' is not supported",
"error": "INVALID_PROVIDER", "code": "INVALID_PROVIDER",
"message": f"Provider '{provider}' is not supported", "suggestions": ["Use 'google' or 'discord'"],
"suggestions": ["Use 'google' or 'discord'"]
}, },
status=status.HTTP_400_BAD_REQUEST status=status.HTTP_400_BAD_REQUEST,
) )
serializer = ConnectProviderInputSerializer(data=request.data) serializer = ConnectProviderInputSerializer(data=request.data)
if not serializer.is_valid(): if not serializer.is_valid():
return Response( return Response(
{ {
"success": False, "detail": "Invalid request data",
"error": "VALIDATION_ERROR", "code": "VALIDATION_ERROR",
"message": "Invalid request data",
"details": serializer.errors, "details": serializer.errors,
"suggestions": ["Provide a valid access_token"] "suggestions": ["Provide a valid access_token"],
}, },
status=status.HTTP_400_BAD_REQUEST status=status.HTTP_400_BAD_REQUEST,
) )
access_token = serializer.validated_data['access_token'] access_token = serializer.validated_data["access_token"]
try: try:
service = SocialProviderService() service = SocialProviderService()
@@ -622,14 +601,14 @@ class ConnectProviderAPIView(APIView):
return Response( return Response(
{ {
"success": False, "success": False,
"error": "CONNECTION_FAILED", "detail": "CONNECTION_FAILED",
"message": str(e), "message": str(e),
"suggestions": [ "suggestions": [
"Verify the access token is valid", "Verify the access token is valid",
"Ensure the provider account is not already connected to another user" "Ensure the provider account is not already connected to another user",
] ],
}, },
status=status.HTTP_400_BAD_REQUEST status=status.HTTP_400_BAD_REQUEST,
) )
@@ -653,35 +632,33 @@ class DisconnectProviderAPIView(APIView):
def post(self, request: Request, provider: str) -> Response: def post(self, request: Request, provider: str) -> Response:
# Validate provider # Validate provider
if provider not in ['google', 'discord']: if provider not in ["google", "discord"]:
return Response( return Response(
{ {
"success": False, "detail": f"Provider '{provider}' is not supported",
"error": "INVALID_PROVIDER", "code": "INVALID_PROVIDER",
"message": f"Provider '{provider}' is not supported", "suggestions": ["Use 'google' or 'discord'"],
"suggestions": ["Use 'google' or 'discord'"]
}, },
status=status.HTTP_400_BAD_REQUEST status=status.HTTP_400_BAD_REQUEST,
) )
try: try:
service = SocialProviderService() service = SocialProviderService()
# Check if disconnection is safe # Check if disconnection is safe
can_disconnect, reason = service.can_disconnect_provider( can_disconnect, reason = service.can_disconnect_provider(request.user, provider)
request.user, provider)
if not can_disconnect: if not can_disconnect:
return Response( return Response(
{ {
"success": False, "success": False,
"error": "UNSAFE_DISCONNECTION", "detail": "UNSAFE_DISCONNECTION",
"message": reason, "message": reason,
"suggestions": [ "suggestions": [
"Set up email/password authentication before disconnecting", "Set up email/password authentication before disconnecting",
"Connect another social provider before disconnecting this one" "Connect another social provider before disconnecting this one",
] ],
}, },
status=status.HTTP_400_BAD_REQUEST status=status.HTTP_400_BAD_REQUEST,
) )
# Perform disconnection # Perform disconnection
@@ -694,14 +671,14 @@ class DisconnectProviderAPIView(APIView):
return Response( return Response(
{ {
"success": False, "success": False,
"error": "DISCONNECTION_FAILED", "detail": "DISCONNECTION_FAILED",
"message": str(e), "message": str(e),
"suggestions": [ "suggestions": [
"Verify the provider is currently connected", "Verify the provider is currently connected",
"Ensure you have alternative authentication methods" "Ensure you have alternative authentication methods",
] ],
}, },
status=status.HTTP_400_BAD_REQUEST status=status.HTTP_400_BAD_REQUEST,
) )
@@ -755,7 +732,7 @@ class EmailVerificationAPIView(APIView):
from apps.accounts.models import EmailVerification from apps.accounts.models import EmailVerification
try: try:
verification = EmailVerification.objects.select_related('user').get(token=token) verification = EmailVerification.objects.select_related("user").get(token=token)
user = verification.user user = verification.user
# Activate the user # Activate the user
@@ -765,16 +742,10 @@ class EmailVerificationAPIView(APIView):
# Delete the verification record # Delete the verification record
verification.delete() verification.delete()
return Response({ return Response({"detail": "Email verified successfully. You can now log in.", "success": True})
"message": "Email verified successfully. You can now log in.",
"success": True
})
except EmailVerification.DoesNotExist: except EmailVerification.DoesNotExist:
return Response( return Response({"detail": "Invalid or expired verification token"}, status=status.HTTP_404_NOT_FOUND)
{"error": "Invalid or expired verification token"},
status=status.HTTP_404_NOT_FOUND
)
@extend_schema_view( @extend_schema_view(
@@ -803,27 +774,20 @@ class ResendVerificationAPIView(APIView):
from apps.accounts.models import EmailVerification from apps.accounts.models import EmailVerification
email = request.data.get('email') email = request.data.get("email")
if not email: if not email:
return Response( return Response({"detail": "Email address is required"}, status=status.HTTP_400_BAD_REQUEST)
{"error": "Email address is required"},
status=status.HTTP_400_BAD_REQUEST
)
try: try:
user = UserModel.objects.get(email__iexact=email.strip().lower()) user = UserModel.objects.get(email__iexact=email.strip().lower())
# Don't resend if user is already active # Don't resend if user is already active
if user.is_active: if user.is_active:
return Response( return Response({"detail": "Email is already verified"}, status=status.HTTP_400_BAD_REQUEST)
{"error": "Email is already verified"},
status=status.HTTP_400_BAD_REQUEST
)
# Create or update verification record # Create or update verification record
verification, created = EmailVerification.objects.get_or_create( verification, created = EmailVerification.objects.get_or_create(
user=user, user=user, defaults={"token": get_random_string(64)}
defaults={'token': get_random_string(64)}
) )
if not created: if not created:
@@ -833,9 +797,7 @@ class ResendVerificationAPIView(APIView):
# Send verification email # Send verification email
site = get_current_site(_get_underlying_request(request)) site = get_current_site(_get_underlying_request(request))
verification_url = request.build_absolute_uri( verification_url = request.build_absolute_uri(f"/api/v1/auth/verify-email/{verification.token}/")
f"/api/v1/auth/verify-email/{verification.token}/"
)
try: try:
EmailService.send_email( EmailService.send_email(
@@ -855,27 +817,21 @@ The ThrillWiki Team
site=site, site=site,
) )
return Response({ return Response({"detail": "Verification email sent successfully", "success": True})
"message": "Verification email sent successfully",
"success": True
})
except Exception as e: except Exception as e:
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.error(f"Failed to send verification email to {user.email}: {e}") logger.error(f"Failed to send verification email to {user.email}: {e}")
return Response( return Response(
{"error": "Failed to send verification email"}, {"detail": "Failed to send verification email"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
status=status.HTTP_500_INTERNAL_SERVER_ERROR
) )
except UserModel.DoesNotExist: except UserModel.DoesNotExist:
# Don't reveal whether email exists # Don't reveal whether email exists
return Response({ return Response({"detail": "If the email exists, a verification email has been sent", "success": True})
"message": "If the email exists, a verification email has been sent",
"success": True
})
# Note: User Profile, Top List, and Top List Item ViewSets are now handled # Note: User Profile, Top List, and Top List Item ViewSets are now handled

View File

@@ -8,7 +8,6 @@ Caching Strategy:
- EntityNotFoundView: No caching - POST requests with context-specific data - EntityNotFoundView: No caching - POST requests with context-specific data
""" """
import contextlib import contextlib
from drf_spectacular.utils import extend_schema from drf_spectacular.utils import extend_schema
@@ -82,9 +81,7 @@ class EntityFuzzySearchView(APIView):
try: try:
# Parse request data # Parse request data
query = request.data.get("query", "").strip() query = request.data.get("query", "").strip()
entity_types_raw = request.data.get( entity_types_raw = request.data.get("entity_types", ["park", "ride", "company"])
"entity_types", ["park", "ride", "company"]
)
include_suggestions = request.data.get("include_suggestions", True) include_suggestions = request.data.get("include_suggestions", True)
# Validate query # Validate query
@@ -92,7 +89,7 @@ class EntityFuzzySearchView(APIView):
return Response( return Response(
{ {
"success": False, "success": False,
"error": "Query must be at least 2 characters long", "detail": "Query must be at least 2 characters long",
"code": "INVALID_QUERY", "code": "INVALID_QUERY",
}, },
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
@@ -120,9 +117,7 @@ class EntityFuzzySearchView(APIView):
"query": query, "query": query,
"matches": [match.to_dict() for match in matches], "matches": [match.to_dict() for match in matches],
"user_authenticated": ( "user_authenticated": (
request.user.is_authenticated request.user.is_authenticated if hasattr(request.user, "is_authenticated") else False
if hasattr(request.user, "is_authenticated")
else False
), ),
} }
@@ -143,7 +138,7 @@ class EntityFuzzySearchView(APIView):
return Response( return Response(
{ {
"success": False, "success": False,
"error": f"Internal server error: {str(e)}", "detail": f"Internal server error: {str(e)}",
"code": "INTERNAL_ERROR", "code": "INTERNAL_ERROR",
}, },
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -192,7 +187,7 @@ class EntityNotFoundView(APIView):
return Response( return Response(
{ {
"success": False, "success": False,
"error": "original_query is required", "detail": "original_query is required",
"code": "MISSING_QUERY", "code": "MISSING_QUERY",
}, },
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
@@ -233,9 +228,7 @@ class EntityNotFoundView(APIView):
"context": context, "context": context,
"matches": [match.to_dict() for match in matches], "matches": [match.to_dict() for match in matches],
"user_authenticated": ( "user_authenticated": (
request.user.is_authenticated request.user.is_authenticated if hasattr(request.user, "is_authenticated") else False
if hasattr(request.user, "is_authenticated")
else False
), ),
"has_matches": len(matches) > 0, "has_matches": len(matches) > 0,
} }
@@ -257,7 +250,7 @@ class EntityNotFoundView(APIView):
return Response( return Response(
{ {
"success": False, "success": False,
"error": f"Internal server error: {str(e)}", "detail": f"Internal server error: {str(e)}",
"code": "INTERNAL_ERROR", "code": "INTERNAL_ERROR",
}, },
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -297,9 +290,7 @@ class QuickEntitySuggestionView(APIView):
limit = min(int(request.GET.get("limit", 5)), 10) # Cap at 10 limit = min(int(request.GET.get("limit", 5)), 10) # Cap at 10
if not query or len(query) < 2: if not query or len(query) < 2:
return Response( return Response({"suggestions": [], "query": query}, status=status.HTTP_200_OK)
{"suggestions": [], "query": query}, status=status.HTTP_200_OK
)
# Parse entity types # Parse entity types
entity_types = [] entity_types = []
@@ -312,9 +303,7 @@ class QuickEntitySuggestionView(APIView):
entity_types = [EntityType.PARK, EntityType.RIDE, EntityType.COMPANY] entity_types = [EntityType.PARK, EntityType.RIDE, EntityType.COMPANY]
# Get fuzzy matches # Get fuzzy matches
matches, _ = entity_fuzzy_matcher.find_entity( matches, _ = entity_fuzzy_matcher.find_entity(query=query, entity_types=entity_types, user=request.user)
query=query, entity_types=entity_types, user=request.user
)
# Format as simple suggestions # Format as simple suggestions
suggestions = [] suggestions = []
@@ -337,15 +326,13 @@ class QuickEntitySuggestionView(APIView):
except Exception as e: except Exception as e:
return Response( return Response(
{"suggestions": [], "query": request.GET.get("q", ""), "error": str(e)}, {"suggestions": [], "query": request.GET.get("q", ""), "detail": str(e)},
status=status.HTTP_200_OK, status=status.HTTP_200_OK,
) # Return 200 even on errors for autocomplete ) # Return 200 even on errors for autocomplete
# Utility function for other views to use # Utility function for other views to use
def get_entity_suggestions( def get_entity_suggestions(query: str, entity_types: list[str] | None = None, user=None):
query: str, entity_types: list[str] | None = None, user=None
):
""" """
Utility function for other Django views to get entity suggestions. Utility function for other Django views to get entity suggestions.
@@ -370,8 +357,6 @@ def get_entity_suggestions(
if not parsed_types: if not parsed_types:
parsed_types = [EntityType.PARK, EntityType.RIDE, EntityType.COMPANY] parsed_types = [EntityType.PARK, EntityType.RIDE, EntityType.COMPANY]
return entity_fuzzy_matcher.find_entity( return entity_fuzzy_matcher.find_entity(query=query, entity_types=parsed_types, user=user)
query=query, entity_types=parsed_types, user=user
)
except Exception: except Exception:
return [], None return [], None

View File

@@ -76,7 +76,7 @@ class SendEmailView(APIView):
if not all([to, subject, text]): if not all([to, subject, text]):
return Response( return Response(
{ {
"error": "Missing required fields", "detail": "Missing required fields",
"required_fields": ["to", "subject", "text"], "required_fields": ["to", "subject", "text"],
}, },
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
@@ -96,11 +96,9 @@ class SendEmailView(APIView):
) )
return Response( return Response(
{"message": "Email sent successfully", "response": response}, {"detail": "Email sent successfully", "response": response},
status=status.HTTP_200_OK, status=status.HTTP_200_OK,
) )
except Exception as e: except Exception as e:
return Response( return Response({"detail": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
{"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)

View File

@@ -37,21 +37,11 @@ class _FallbackSerializer(drf_serializers.Serializer):
return {} return {}
ParkHistoryEventSerializer = getattr( ParkHistoryEventSerializer = getattr(history_serializers, "ParkHistoryEventSerializer", _FallbackSerializer)
history_serializers, "ParkHistoryEventSerializer", _FallbackSerializer RideHistoryEventSerializer = getattr(history_serializers, "RideHistoryEventSerializer", _FallbackSerializer)
) ParkHistoryOutputSerializer = getattr(history_serializers, "ParkHistoryOutputSerializer", _FallbackSerializer)
RideHistoryEventSerializer = getattr( RideHistoryOutputSerializer = getattr(history_serializers, "RideHistoryOutputSerializer", _FallbackSerializer)
history_serializers, "RideHistoryEventSerializer", _FallbackSerializer UnifiedHistoryTimelineSerializer = getattr(history_serializers, "UnifiedHistoryTimelineSerializer", _FallbackSerializer)
)
ParkHistoryOutputSerializer = getattr(
history_serializers, "ParkHistoryOutputSerializer", _FallbackSerializer
)
RideHistoryOutputSerializer = getattr(
history_serializers, "RideHistoryOutputSerializer", _FallbackSerializer
)
UnifiedHistoryTimelineSerializer = getattr(
history_serializers, "UnifiedHistoryTimelineSerializer", _FallbackSerializer
)
# --- Constants for model strings to avoid duplication --- # --- Constants for model strings to avoid duplication ---
PARK_MODEL = "parks.park" PARK_MODEL = "parks.park"
@@ -201,18 +191,14 @@ class ParkHistoryViewSet(ReadOnlyModelViewSet):
# Base queryset for park events # Base queryset for park events
queryset = ( queryset = (
pghistory.models.Events.objects.filter( pghistory.models.Events.objects.filter(pgh_model__in=[PARK_MODEL], pgh_obj_id=getattr(park, "id", None))
pgh_model__in=[PARK_MODEL], pgh_obj_id=getattr(park, "id", None)
)
.select_related() .select_related()
.order_by("-pgh_created_at") .order_by("-pgh_created_at")
) )
# Apply list filters via helper to reduce complexity # Apply list filters via helper to reduce complexity
if self.action == "list": if self.action == "list":
queryset = _apply_list_filters( queryset = _apply_list_filters(queryset, cast(Request, self.request), default_limit=50, max_limit=500)
queryset, cast(Request, self.request), default_limit=50, max_limit=500
)
return queryset return queryset
@@ -322,18 +308,14 @@ class RideHistoryViewSet(ReadOnlyModelViewSet):
# Base queryset for ride events # Base queryset for ride events
queryset = ( queryset = (
pghistory.models.Events.objects.filter( pghistory.models.Events.objects.filter(pgh_model__in=RIDE_MODELS, pgh_obj_id=getattr(ride, "id", None))
pgh_model__in=RIDE_MODELS, pgh_obj_id=getattr(ride, "id", None)
)
.select_related() .select_related()
.order_by("-pgh_created_at") .order_by("-pgh_created_at")
) )
# Apply list filters via helper # Apply list filters via helper
if self.action == "list": if self.action == "list":
queryset = _apply_list_filters( queryset = _apply_list_filters(queryset, cast(Request, self.request), default_limit=50, max_limit=500)
queryset, cast(Request, self.request), default_limit=50, max_limit=500
)
return queryset return queryset
@@ -462,9 +444,7 @@ class UnifiedHistoryViewSet(ReadOnlyModelViewSet):
# Apply shared list filters when serving the list action # Apply shared list filters when serving the list action
if self.action == "list": if self.action == "list":
queryset = _apply_list_filters( queryset = _apply_list_filters(queryset, cast(Request, self.request), default_limit=100, max_limit=1000)
queryset, cast(Request, self.request), default_limit=100, max_limit=1000
)
return queryset return queryset
@@ -477,9 +457,7 @@ class UnifiedHistoryViewSet(ReadOnlyModelViewSet):
events = list(self.get_queryset()) # evaluate for counts / earliest/latest use events = list(self.get_queryset()) # evaluate for counts / earliest/latest use
# Summary statistics across all tracked models # Summary statistics across all tracked models
total_events = pghistory.models.Events.objects.filter( total_events = pghistory.models.Events.objects.filter(pgh_model__in=ALL_TRACKED_MODELS).count()
pgh_model__in=ALL_TRACKED_MODELS
).count()
event_type_counts = ( event_type_counts = (
pghistory.models.Events.objects.filter(pgh_model__in=ALL_TRACKED_MODELS) pghistory.models.Events.objects.filter(pgh_model__in=ALL_TRACKED_MODELS)
@@ -497,12 +475,8 @@ class UnifiedHistoryViewSet(ReadOnlyModelViewSet):
"summary": { "summary": {
"total_events": total_events, "total_events": total_events,
"events_returned": len(events), "events_returned": len(events),
"event_type_breakdown": { "event_type_breakdown": {item["pgh_label"]: item["count"] for item in event_type_counts},
item["pgh_label"]: item["count"] for item in event_type_counts "model_type_breakdown": {item["pgh_model"]: item["count"] for item in model_type_counts},
},
"model_type_breakdown": {
item["pgh_model"]: item["count"] for item in model_type_counts
},
"time_range": { "time_range": {
"earliest": events[-1].pgh_created_at if events else None, "earliest": events[-1].pgh_created_at if events else None,
"latest": events[0].pgh_created_at if events else None, "latest": events[0].pgh_created_at if events else None,

View File

@@ -11,6 +11,7 @@ from apps.core.utils.cloudflare import get_direct_upload_url
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GenerateUploadURLView(APIView): class GenerateUploadURLView(APIView):
permission_classes = [IsAuthenticated] permission_classes = [IsAuthenticated]
@@ -21,19 +22,10 @@ class GenerateUploadURLView(APIView):
return Response(result, status=status.HTTP_200_OK) return Response(result, status=status.HTTP_200_OK)
except ImproperlyConfigured as e: except ImproperlyConfigured as e:
logger.error(f"Configuration Error: {e}") logger.error(f"Configuration Error: {e}")
return Response( return Response({"detail": "Server configuration error."}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
{"detail": "Server configuration error."},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
except requests.RequestException as e: except requests.RequestException as e:
logger.error(f"Cloudflare API Error: {e}") logger.error(f"Cloudflare API Error: {e}")
return Response( return Response({"detail": "Failed to generate upload URL."}, status=status.HTTP_502_BAD_GATEWAY)
{"detail": "Failed to generate upload URL."},
status=status.HTTP_502_BAD_GATEWAY
)
except Exception: except Exception:
logger.exception("Unexpected error generating upload URL") logger.exception("Unexpected error generating upload URL")
return Response( return Response({"detail": "An unexpected error occurred."}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
{"detail": "An unexpected error occurred."},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
)

View File

@@ -162,16 +162,13 @@ class MapLocationsAPIView(APIView):
if not all([north, south, east, west]): if not all([north, south, east, west]):
return None return None
try: try:
return Polygon.from_bbox( return Polygon.from_bbox((float(west), float(south), float(east), float(north)))
(float(west), float(south), float(east), float(north))
)
except (ValueError, TypeError): except (ValueError, TypeError):
return None return None
def _serialize_park_location(self, park) -> dict: def _serialize_park_location(self, park) -> dict:
"""Serialize park location data.""" """Serialize park location data."""
location = park.location if hasattr( location = park.location if hasattr(park, "location") and park.location else None
park, "location") and park.location else None
return { return {
"city": location.city if location else "", "city": location.city if location else "",
"state": location.state if location else "", "state": location.state if location else "",
@@ -181,8 +178,7 @@ class MapLocationsAPIView(APIView):
def _serialize_park_data(self, park) -> dict: def _serialize_park_data(self, park) -> dict:
"""Serialize park data for map response.""" """Serialize park data for map response."""
location = park.location if hasattr( location = park.location if hasattr(park, "location") and park.location else None
park, "location") and park.location else None
return { return {
"id": park.id, "id": park.id,
"type": "park", "type": "park",
@@ -195,9 +191,7 @@ class MapLocationsAPIView(APIView):
"stats": { "stats": {
"coaster_count": park.coaster_count or 0, "coaster_count": park.coaster_count or 0,
"ride_count": park.ride_count or 0, "ride_count": park.ride_count or 0,
"average_rating": ( "average_rating": (float(park.average_rating) if park.average_rating else None),
float(park.average_rating) if park.average_rating else None
),
}, },
} }
@@ -206,14 +200,10 @@ class MapLocationsAPIView(APIView):
if "park" not in params["types"]: if "park" not in params["types"]:
return [] return []
parks_query = Park.objects.select_related( parks_query = Park.objects.select_related("location", "operator").filter(location__point__isnull=False)
"location", "operator"
).filter(location__point__isnull=False)
# Apply bounds filtering # Apply bounds filtering
bounds_polygon = self._create_bounds_polygon( bounds_polygon = self._create_bounds_polygon(params["north"], params["south"], params["east"], params["west"])
params["north"], params["south"], params["east"], params["west"]
)
if bounds_polygon: if bounds_polygon:
parks_query = parks_query.filter(location__point__within=bounds_polygon) parks_query = parks_query.filter(location__point__within=bounds_polygon)
@@ -229,11 +219,7 @@ class MapLocationsAPIView(APIView):
def _serialize_ride_location(self, ride) -> dict: def _serialize_ride_location(self, ride) -> dict:
"""Serialize ride location data.""" """Serialize ride location data."""
location = ( location = ride.park.location if hasattr(ride.park, "location") and ride.park.location else None
ride.park.location
if hasattr(ride.park, "location") and ride.park.location
else None
)
return { return {
"city": location.city if location else "", "city": location.city if location else "",
"state": location.state if location else "", "state": location.state if location else "",
@@ -243,11 +229,7 @@ class MapLocationsAPIView(APIView):
def _serialize_ride_data(self, ride) -> dict: def _serialize_ride_data(self, ride) -> dict:
"""Serialize ride data for map response.""" """Serialize ride data for map response."""
location = ( location = ride.park.location if hasattr(ride.park, "location") and ride.park.location else None
ride.park.location
if hasattr(ride.park, "location") and ride.park.location
else None
)
return { return {
"id": ride.id, "id": ride.id,
"type": "ride", "type": "ride",
@@ -259,9 +241,7 @@ class MapLocationsAPIView(APIView):
"location": self._serialize_ride_location(ride), "location": self._serialize_ride_location(ride),
"stats": { "stats": {
"category": ride.get_category_display() if ride.category else None, "category": ride.get_category_display() if ride.category else None,
"average_rating": ( "average_rating": (float(ride.average_rating) if ride.average_rating else None),
float(ride.average_rating) if ride.average_rating else None
),
"park_name": ride.park.name, "park_name": ride.park.name,
}, },
} }
@@ -271,17 +251,14 @@ class MapLocationsAPIView(APIView):
if "ride" not in params["types"]: if "ride" not in params["types"]:
return [] return []
rides_query = Ride.objects.select_related( rides_query = Ride.objects.select_related("park__location", "manufacturer").filter(
"park__location", "manufacturer" park__location__point__isnull=False
).filter(park__location__point__isnull=False) )
# Apply bounds filtering # Apply bounds filtering
bounds_polygon = self._create_bounds_polygon( bounds_polygon = self._create_bounds_polygon(params["north"], params["south"], params["east"], params["west"])
params["north"], params["south"], params["east"], params["west"]
)
if bounds_polygon: if bounds_polygon:
rides_query = rides_query.filter( rides_query = rides_query.filter(park__location__point__within=bounds_polygon)
park__location__point__within=bounds_polygon)
# Apply text search # Apply text search
if params["query"]: if params["query"]:
@@ -335,7 +312,7 @@ class MapLocationsAPIView(APIView):
# Use EnhancedCacheService for improved caching with monitoring # Use EnhancedCacheService for improved caching with monitoring
cache_service = EnhancedCacheService() cache_service = EnhancedCacheService()
cached_result = cache_service.get_cached_api_response('map_locations', params) cached_result = cache_service.get_cached_api_response("map_locations", params)
if cached_result: if cached_result:
logger.debug(f"Cache hit for map_locations with key: {cache_key}") logger.debug(f"Cache hit for map_locations with key: {cache_key}")
return Response(cached_result) return Response(cached_result)
@@ -349,7 +326,7 @@ class MapLocationsAPIView(APIView):
result = self._build_response(locations, params) result = self._build_response(locations, params)
# Cache result for 5 minutes using EnhancedCacheService # Cache result for 5 minutes using EnhancedCacheService
cache_service.cache_api_response('map_locations', params, result, timeout=300) cache_service.cache_api_response("map_locations", params, result, timeout=300)
logger.debug(f"Cached map_locations result for key: {cache_key}") logger.debug(f"Cached map_locations result for key: {cache_key}")
return Response(result) return Response(result)
@@ -357,7 +334,7 @@ class MapLocationsAPIView(APIView):
except Exception as e: except Exception as e:
logger.error(f"Error in MapLocationsAPIView: {str(e)}", exc_info=True) logger.error(f"Error in MapLocationsAPIView: {str(e)}", exc_info=True)
return Response( return Response(
{"status": "error", "message": "Failed to retrieve map locations"}, {"status": "error", "detail": "Failed to retrieve map locations"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@@ -401,34 +378,28 @@ class MapLocationDetailAPIView(APIView):
permission_classes = [AllowAny] permission_classes = [AllowAny]
@cache_api_response(timeout=1800, key_prefix="map_detail") @cache_api_response(timeout=1800, key_prefix="map_detail")
def get( def get(self, request: HttpRequest, location_type: str, location_id: int) -> Response:
self, request: HttpRequest, location_type: str, location_id: int
) -> Response:
"""Get detailed information for a specific location.""" """Get detailed information for a specific location."""
try: try:
if location_type == "park": if location_type == "park":
try: try:
obj = Park.objects.select_related("location", "operator").get( obj = Park.objects.select_related("location", "operator").get(id=location_id)
id=location_id
)
except Park.DoesNotExist: except Park.DoesNotExist:
return Response( return Response(
{"status": "error", "message": "Park not found"}, {"status": "error", "detail": "Park not found"},
status=status.HTTP_404_NOT_FOUND, status=status.HTTP_404_NOT_FOUND,
) )
elif location_type == "ride": elif location_type == "ride":
try: try:
obj = Ride.objects.select_related( obj = Ride.objects.select_related("park__location", "manufacturer").get(id=location_id)
"park__location", "manufacturer"
).get(id=location_id)
except Ride.DoesNotExist: except Ride.DoesNotExist:
return Response( return Response(
{"status": "error", "message": "Ride not found"}, {"status": "error", "detail": "Ride not found"},
status=status.HTTP_404_NOT_FOUND, status=status.HTTP_404_NOT_FOUND,
) )
else: else:
return Response( return Response(
{"status": "error", "message": "Invalid location type"}, {"status": "error", "detail": "Invalid location type"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -440,59 +411,27 @@ class MapLocationDetailAPIView(APIView):
"name": obj.name, "name": obj.name,
"slug": obj.slug, "slug": obj.slug,
"description": obj.description, "description": obj.description,
"latitude": ( "latitude": (obj.location.latitude if hasattr(obj, "location") and obj.location else None),
obj.location.latitude "longitude": (obj.location.longitude if hasattr(obj, "location") and obj.location else None),
if hasattr(obj, "location") and obj.location
else None
),
"longitude": (
obj.location.longitude
if hasattr(obj, "location") and obj.location
else None
),
"status": obj.status, "status": obj.status,
"location": { "location": {
"street_address": ( "street_address": (
obj.location.street_address obj.location.street_address if hasattr(obj, "location") and obj.location else ""
if hasattr(obj, "location") and obj.location
else ""
),
"city": (
obj.location.city
if hasattr(obj, "location") and obj.location
else ""
),
"state": (
obj.location.state
if hasattr(obj, "location") and obj.location
else ""
),
"country": (
obj.location.country
if hasattr(obj, "location") and obj.location
else ""
),
"postal_code": (
obj.location.postal_code
if hasattr(obj, "location") and obj.location
else ""
), ),
"city": (obj.location.city if hasattr(obj, "location") and obj.location else ""),
"state": (obj.location.state if hasattr(obj, "location") and obj.location else ""),
"country": (obj.location.country if hasattr(obj, "location") and obj.location else ""),
"postal_code": (obj.location.postal_code if hasattr(obj, "location") and obj.location else ""),
"formatted_address": ( "formatted_address": (
obj.location.formatted_address obj.location.formatted_address if hasattr(obj, "location") and obj.location else ""
if hasattr(obj, "location") and obj.location
else ""
), ),
}, },
"stats": { "stats": {
"coaster_count": obj.coaster_count or 0, "coaster_count": obj.coaster_count or 0,
"ride_count": obj.ride_count or 0, "ride_count": obj.ride_count or 0,
"average_rating": ( "average_rating": (float(obj.average_rating) if obj.average_rating else None),
float(obj.average_rating) if obj.average_rating else None
),
"size_acres": float(obj.size_acres) if obj.size_acres else None, "size_acres": float(obj.size_acres) if obj.size_acres else None,
"opening_date": ( "opening_date": (obj.opening_date.isoformat() if obj.opening_date else None),
obj.opening_date.isoformat() if obj.opening_date else None
),
}, },
"nearby_locations": [], # See FUTURE_WORK.md - THRILLWIKI-107 "nearby_locations": [], # See FUTURE_WORK.md - THRILLWIKI-107
} }
@@ -504,14 +443,10 @@ class MapLocationDetailAPIView(APIView):
"slug": obj.slug, "slug": obj.slug,
"description": obj.description, "description": obj.description,
"latitude": ( "latitude": (
obj.park.location.latitude obj.park.location.latitude if hasattr(obj.park, "location") and obj.park.location else None
if hasattr(obj.park, "location") and obj.park.location
else None
), ),
"longitude": ( "longitude": (
obj.park.location.longitude obj.park.location.longitude if hasattr(obj.park, "location") and obj.park.location else None
if hasattr(obj.park, "location") and obj.park.location
else None
), ),
"status": obj.status, "status": obj.status,
"location": { "location": {
@@ -520,25 +455,15 @@ class MapLocationDetailAPIView(APIView):
if hasattr(obj.park, "location") and obj.park.location if hasattr(obj.park, "location") and obj.park.location
else "" else ""
), ),
"city": ( "city": (obj.park.location.city if hasattr(obj.park, "location") and obj.park.location else ""),
obj.park.location.city
if hasattr(obj.park, "location") and obj.park.location
else ""
),
"state": ( "state": (
obj.park.location.state obj.park.location.state if hasattr(obj.park, "location") and obj.park.location else ""
if hasattr(obj.park, "location") and obj.park.location
else ""
), ),
"country": ( "country": (
obj.park.location.country obj.park.location.country if hasattr(obj.park, "location") and obj.park.location else ""
if hasattr(obj.park, "location") and obj.park.location
else ""
), ),
"postal_code": ( "postal_code": (
obj.park.location.postal_code obj.park.location.postal_code if hasattr(obj.park, "location") and obj.park.location else ""
if hasattr(obj.park, "location") and obj.park.location
else ""
), ),
"formatted_address": ( "formatted_address": (
obj.park.location.formatted_address obj.park.location.formatted_address
@@ -547,19 +472,11 @@ class MapLocationDetailAPIView(APIView):
), ),
}, },
"stats": { "stats": {
"category": ( "category": (obj.get_category_display() if obj.category else None),
obj.get_category_display() if obj.category else None "average_rating": (float(obj.average_rating) if obj.average_rating else None),
),
"average_rating": (
float(obj.average_rating) if obj.average_rating else None
),
"park_name": obj.park.name, "park_name": obj.park.name,
"opening_date": ( "opening_date": (obj.opening_date.isoformat() if obj.opening_date else None),
obj.opening_date.isoformat() if obj.opening_date else None "manufacturer": (obj.manufacturer.name if obj.manufacturer else None),
),
"manufacturer": (
obj.manufacturer.name if obj.manufacturer else None
),
}, },
"nearby_locations": [], # See FUTURE_WORK.md - THRILLWIKI-107 "nearby_locations": [], # See FUTURE_WORK.md - THRILLWIKI-107
} }
@@ -574,7 +491,7 @@ class MapLocationDetailAPIView(APIView):
except Exception as e: except Exception as e:
logger.error(f"Error in MapLocationDetailAPIView: {str(e)}", exc_info=True) logger.error(f"Error in MapLocationDetailAPIView: {str(e)}", exc_info=True)
return Response( return Response(
{"status": "error", "message": "Failed to retrieve location details"}, {"status": "error", "detail": "Failed to retrieve location details"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@@ -640,7 +557,7 @@ class MapSearchAPIView(APIView):
return Response( return Response(
{ {
"status": "error", "status": "error",
"message": "Search query 'q' parameter is required", "detail": "Search query 'q' parameter is required",
}, },
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -672,30 +589,16 @@ class MapSearchAPIView(APIView):
"name": park.name, "name": park.name,
"slug": park.slug, "slug": park.slug,
"latitude": ( "latitude": (
park.location.latitude park.location.latitude if hasattr(park, "location") and park.location else None
if hasattr(park, "location") and park.location
else None
), ),
"longitude": ( "longitude": (
park.location.longitude park.location.longitude if hasattr(park, "location") and park.location else None
if hasattr(park, "location") and park.location
else None
), ),
"location": { "location": {
"city": ( "city": (park.location.city if hasattr(park, "location") and park.location else ""),
park.location.city "state": (park.location.state if hasattr(park, "location") and park.location else ""),
if hasattr(park, "location") and park.location
else ""
),
"state": (
park.location.state
if hasattr(park, "location") and park.location
else ""
),
"country": ( "country": (
park.location.country park.location.country if hasattr(park, "location") and park.location else ""
if hasattr(park, "location") and park.location
else ""
), ),
}, },
"relevance_score": 1.0, # See FUTURE_WORK.md - THRILLWIKI-108 "relevance_score": 1.0, # See FUTURE_WORK.md - THRILLWIKI-108
@@ -734,20 +637,17 @@ class MapSearchAPIView(APIView):
"location": { "location": {
"city": ( "city": (
ride.park.location.city ride.park.location.city
if hasattr(ride.park, "location") if hasattr(ride.park, "location") and ride.park.location
and ride.park.location
else "" else ""
), ),
"state": ( "state": (
ride.park.location.state ride.park.location.state
if hasattr(ride.park, "location") if hasattr(ride.park, "location") and ride.park.location
and ride.park.location
else "" else ""
), ),
"country": ( "country": (
ride.park.location.country ride.park.location.country
if hasattr(ride.park, "location") if hasattr(ride.park, "location") and ride.park.location
and ride.park.location
else "" else ""
), ),
}, },
@@ -776,7 +676,7 @@ class MapSearchAPIView(APIView):
except Exception as e: except Exception as e:
logger.error(f"Error in MapSearchAPIView: {str(e)}", exc_info=True) logger.error(f"Error in MapSearchAPIView: {str(e)}", exc_info=True)
return Response( return Response(
{"status": "error", "message": "Search failed due to internal error"}, {"status": "error", "detail": "Search failed due to internal error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@@ -848,8 +748,7 @@ class MapBoundsAPIView(APIView):
if not all([north_str, south_str, east_str, west_str]): if not all([north_str, south_str, east_str, west_str]):
return Response( return Response(
{"status": "error", {"status": "error", "detail": "All bounds parameters (north, south, east, west) are required"},
"message": "All bounds parameters (north, south, east, west) are required"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -860,7 +759,7 @@ class MapBoundsAPIView(APIView):
west = float(west_str) if west_str else 0.0 west = float(west_str) if west_str else 0.0
except (TypeError, ValueError): except (TypeError, ValueError):
return Response( return Response(
{"status": "error", "message": "Invalid bounds parameters"}, {"status": "error", "detail": "Invalid bounds parameters"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -869,7 +768,7 @@ class MapBoundsAPIView(APIView):
return Response( return Response(
{ {
"status": "error", "status": "error",
"message": "North bound must be greater than south bound", "detail": "North bound must be greater than south bound",
}, },
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -878,7 +777,7 @@ class MapBoundsAPIView(APIView):
return Response( return Response(
{ {
"status": "error", "status": "error",
"message": "West bound must be less than east bound", "detail": "West bound must be less than east bound",
}, },
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -891,9 +790,7 @@ class MapBoundsAPIView(APIView):
# Get parks within bounds # Get parks within bounds
if "park" in types: if "park" in types:
parks_query = Park.objects.select_related("location").filter( parks_query = Park.objects.select_related("location").filter(location__point__within=bounds_polygon)
location__point__within=bounds_polygon
)
for park in parks_query[:100]: # Limit results for park in parks_query[:100]: # Limit results
locations.append( locations.append(
@@ -903,14 +800,10 @@ class MapBoundsAPIView(APIView):
"name": park.name, "name": park.name,
"slug": park.slug, "slug": park.slug,
"latitude": ( "latitude": (
park.location.latitude park.location.latitude if hasattr(park, "location") and park.location else None
if hasattr(park, "location") and park.location
else None
), ),
"longitude": ( "longitude": (
park.location.longitude park.location.longitude if hasattr(park, "location") and park.location else None
if hasattr(park, "location") and park.location
else None
), ),
"status": park.status, "status": park.status,
} }
@@ -960,7 +853,7 @@ class MapBoundsAPIView(APIView):
except Exception as e: except Exception as e:
logger.error(f"Error in MapBoundsAPIView: {str(e)}", exc_info=True) logger.error(f"Error in MapBoundsAPIView: {str(e)}", exc_info=True)
return Response( return Response(
{"status": "error", "message": "Failed to retrieve locations within bounds"}, {"status": "error", "detail": "Failed to retrieve locations within bounds"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@@ -987,18 +880,15 @@ class MapStatsAPIView(APIView):
"""Get map service statistics and performance metrics.""" """Get map service statistics and performance metrics."""
try: try:
# Count locations with coordinates # Count locations with coordinates
parks_with_location = Park.objects.filter( parks_with_location = Park.objects.filter(location__point__isnull=False).count()
location__point__isnull=False rides_with_location = Ride.objects.filter(park__location__point__isnull=False).count()
).count()
rides_with_location = Ride.objects.filter(
park__location__point__isnull=False
).count()
total_locations = parks_with_location + rides_with_location total_locations = parks_with_location + rides_with_location
# Get cache statistics # Get cache statistics
from apps.core.services.enhanced_cache_service import CacheMonitor from apps.core.services.enhanced_cache_service import CacheMonitor
cache_monitor = CacheMonitor() cache_monitor = CacheMonitor()
cache_stats = cache_monitor.get_cache_statistics('map_locations') cache_stats = cache_monitor.get_cache_statistics("map_locations")
return Response( return Response(
{ {
@@ -1006,17 +896,17 @@ class MapStatsAPIView(APIView):
"total_locations": total_locations, "total_locations": total_locations,
"parks_with_location": parks_with_location, "parks_with_location": parks_with_location,
"rides_with_location": rides_with_location, "rides_with_location": rides_with_location,
"cache_hits": cache_stats.get('hits', 0), "cache_hits": cache_stats.get("hits", 0),
"cache_misses": cache_stats.get('misses', 0), "cache_misses": cache_stats.get("misses", 0),
"cache_hit_rate": cache_stats.get('hit_rate', 0.0), "cache_hit_rate": cache_stats.get("hit_rate", 0.0),
"cache_size": cache_stats.get('size', 0), "cache_size": cache_stats.get("size", 0),
} }
) )
except Exception as e: except Exception as e:
logger.error(f"Error in MapStatsAPIView: {str(e)}", exc_info=True) logger.error(f"Error in MapStatsAPIView: {str(e)}", exc_info=True)
return Response( return Response(
{"status": "error", "message": "Failed to retrieve map statistics"}, {"status": "error", "detail": "Failed to retrieve map statistics"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@@ -1060,7 +950,7 @@ class MapCacheAPIView(APIView):
return Response( return Response(
{ {
"status": "success", "status": "success",
"message": f"Map cache cleared successfully. Cleared {cleared_count} entries.", "detail": f"Map cache cleared successfully. Cleared {cleared_count} entries.",
"cleared_count": cleared_count, "cleared_count": cleared_count,
} }
) )
@@ -1068,7 +958,7 @@ class MapCacheAPIView(APIView):
except Exception as e: except Exception as e:
logger.error(f"Error in MapCacheAPIView.delete: {str(e)}", exc_info=True) logger.error(f"Error in MapCacheAPIView.delete: {str(e)}", exc_info=True)
return Response( return Response(
{"status": "error", "message": "Failed to clear map cache"}, {"status": "error", "detail": "Failed to clear map cache"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@@ -1076,7 +966,7 @@ class MapCacheAPIView(APIView):
"""Invalidate specific cache entries.""" """Invalidate specific cache entries."""
try: try:
# Get cache keys to invalidate from request data # Get cache keys to invalidate from request data
request_data = getattr(request, 'data', {}) request_data = getattr(request, "data", {})
cache_keys = request_data.get("cache_keys", []) if request_data else [] cache_keys = request_data.get("cache_keys", []) if request_data else []
if cache_keys: if cache_keys:
@@ -1088,7 +978,7 @@ class MapCacheAPIView(APIView):
return Response( return Response(
{ {
"status": "success", "status": "success",
"message": f"Cache invalidated successfully. Invalidated {invalidated_count} entries.", "detail": f"Cache invalidated successfully. Invalidated {invalidated_count} entries.",
"invalidated_count": invalidated_count, "invalidated_count": invalidated_count,
} }
) )
@@ -1096,7 +986,7 @@ class MapCacheAPIView(APIView):
except Exception as e: except Exception as e:
logger.error(f"Error in MapCacheAPIView.post: {str(e)}", exc_info=True) logger.error(f"Error in MapCacheAPIView.post: {str(e)}", exc_info=True)
return Response( return Response(
{"status": "error", "message": "Failed to invalidate cache"}, {"status": "error", "detail": "Failed to invalidate cache"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )

View File

@@ -33,7 +33,7 @@ class ContractValidationMiddleware(MiddlewareMixin):
def __init__(self, get_response): def __init__(self, get_response):
super().__init__(get_response) super().__init__(get_response)
self.get_response = get_response self.get_response = get_response
self.enabled = getattr(settings, 'DEBUG', False) self.enabled = getattr(settings, "DEBUG", False)
if self.enabled: if self.enabled:
logger.info("Contract validation middleware enabled (DEBUG mode)") logger.info("Contract validation middleware enabled (DEBUG mode)")
@@ -45,11 +45,11 @@ class ContractValidationMiddleware(MiddlewareMixin):
return response return response
# Only validate API endpoints # Only validate API endpoints
if not request.path.startswith('/api/'): if not request.path.startswith("/api/"):
return response return response
# Only validate JSON responses # Only validate JSON responses
if not isinstance(response, (JsonResponse, Response)): if not isinstance(response, JsonResponse | Response):
return response return response
# Only validate successful responses (2xx status codes) # Only validate successful responses (2xx status codes)
@@ -58,7 +58,7 @@ class ContractValidationMiddleware(MiddlewareMixin):
try: try:
# Get response data # Get response data
data = response.data if isinstance(response, Response) else json.loads(response.content.decode('utf-8')) data = response.data if isinstance(response, Response) else json.loads(response.content.decode("utf-8"))
# Validate the response # Validate the response
self._validate_response_contract(request.path, data) self._validate_response_contract(request.path, data)
@@ -68,11 +68,11 @@ class ContractValidationMiddleware(MiddlewareMixin):
logger.warning( logger.warning(
f"Contract validation error for {request.path}: {str(e)}", f"Contract validation error for {request.path}: {str(e)}",
extra={ extra={
'path': request.path, "path": request.path,
'method': request.method, "method": request.method,
'status_code': response.status_code, "status_code": response.status_code,
'validation_error': str(e) "validation_error": str(e),
} },
) )
return response return response
@@ -81,15 +81,15 @@ class ContractValidationMiddleware(MiddlewareMixin):
"""Validate response data against expected contracts.""" """Validate response data against expected contracts."""
# Check for filter metadata endpoints # Check for filter metadata endpoints
if 'filter-options' in path or 'filter_options' in path: if "filter-options" in path or "filter_options" in path:
self._validate_filter_metadata(path, data) self._validate_filter_metadata(path, data)
# Check for hybrid filtering endpoints # Check for hybrid filtering endpoints
if 'hybrid' in path: if "hybrid" in path:
self._validate_hybrid_response(path, data) self._validate_hybrid_response(path, data)
# Check for pagination responses # Check for pagination responses
if isinstance(data, dict) and 'results' in data: if isinstance(data, dict) and "results" in data:
self._validate_pagination_response(path, data) self._validate_pagination_response(path, data)
# Check for common contract violations # Check for common contract violations
@@ -100,22 +100,20 @@ class ContractValidationMiddleware(MiddlewareMixin):
if not isinstance(data, dict): if not isinstance(data, dict):
self._log_contract_violation( self._log_contract_violation(
path, path, "FILTER_METADATA_NOT_DICT", f"Filter metadata should be a dictionary, got {type(data).__name__}"
"FILTER_METADATA_NOT_DICT",
f"Filter metadata should be a dictionary, got {type(data).__name__}"
) )
return return
# Check for categorical filters # Check for categorical filters
if 'categorical' in data: if "categorical" in data:
categorical = data['categorical'] categorical = data["categorical"]
if isinstance(categorical, dict): if isinstance(categorical, dict):
for filter_name, filter_options in categorical.items(): for filter_name, filter_options in categorical.items():
self._validate_categorical_filter(path, filter_name, filter_options) self._validate_categorical_filter(path, filter_name, filter_options)
# Check for ranges # Check for ranges
if 'ranges' in data: if "ranges" in data:
ranges = data['ranges'] ranges = data["ranges"]
if isinstance(ranges, dict): if isinstance(ranges, dict):
for range_name, range_data in ranges.items(): for range_name, range_data in ranges.items():
self._validate_range_filter(path, range_name, range_data) self._validate_range_filter(path, range_name, range_data)
@@ -127,7 +125,7 @@ class ContractValidationMiddleware(MiddlewareMixin):
self._log_contract_violation( self._log_contract_violation(
path, path,
"CATEGORICAL_FILTER_NOT_ARRAY", "CATEGORICAL_FILTER_NOT_ARRAY",
f"Categorical filter '{filter_name}' should be an array, got {type(filter_options).__name__}" f"Categorical filter '{filter_name}' should be an array, got {type(filter_options).__name__}",
) )
return return
@@ -138,28 +136,28 @@ class ContractValidationMiddleware(MiddlewareMixin):
path, path,
"CATEGORICAL_OPTION_IS_STRING", "CATEGORICAL_OPTION_IS_STRING",
f"Categorical filter '{filter_name}' option {i} is a string '{option}' but should be an object with value/label/count properties", f"Categorical filter '{filter_name}' option {i} is a string '{option}' but should be an object with value/label/count properties",
severity="ERROR" severity="ERROR",
) )
elif isinstance(option, dict): elif isinstance(option, dict):
# Validate object structure # Validate object structure
if 'value' not in option: if "value" not in option:
self._log_contract_violation( self._log_contract_violation(
path, path,
"MISSING_VALUE_PROPERTY", "MISSING_VALUE_PROPERTY",
f"Categorical filter '{filter_name}' option {i} missing 'value' property" f"Categorical filter '{filter_name}' option {i} missing 'value' property",
) )
if 'label' not in option: if "label" not in option:
self._log_contract_violation( self._log_contract_violation(
path, path,
"MISSING_LABEL_PROPERTY", "MISSING_LABEL_PROPERTY",
f"Categorical filter '{filter_name}' option {i} missing 'label' property" f"Categorical filter '{filter_name}' option {i} missing 'label' property",
) )
# Count is optional but should be number if present # Count is optional but should be number if present
if 'count' in option and option['count'] is not None and not isinstance(option['count'], (int, float)): if "count" in option and option["count"] is not None and not isinstance(option["count"], int | float):
self._log_contract_violation( self._log_contract_violation(
path, path,
"INVALID_COUNT_TYPE", "INVALID_COUNT_TYPE",
f"Categorical filter '{filter_name}' option {i} 'count' should be a number, got {type(option['count']).__name__}" f"Categorical filter '{filter_name}' option {i} 'count' should be a number, got {type(option['count']).__name__}",
) )
def _validate_range_filter(self, path: str, range_name: str, range_data: Any) -> None: def _validate_range_filter(self, path: str, range_name: str, range_data: Any) -> None:
@@ -169,26 +167,24 @@ class ContractValidationMiddleware(MiddlewareMixin):
self._log_contract_violation( self._log_contract_violation(
path, path,
"RANGE_FILTER_NOT_OBJECT", "RANGE_FILTER_NOT_OBJECT",
f"Range filter '{range_name}' should be an object, got {type(range_data).__name__}" f"Range filter '{range_name}' should be an object, got {type(range_data).__name__}",
) )
return return
# Check required properties # Check required properties
required_props = ['min', 'max'] required_props = ["min", "max"]
for prop in required_props: for prop in required_props:
if prop not in range_data: if prop not in range_data:
self._log_contract_violation( self._log_contract_violation(
path, path, "MISSING_RANGE_PROPERTY", f"Range filter '{range_name}' missing required property '{prop}'"
"MISSING_RANGE_PROPERTY",
f"Range filter '{range_name}' missing required property '{prop}'"
) )
# Check step property # Check step property
if 'step' in range_data and not isinstance(range_data['step'], (int, float)): if "step" in range_data and not isinstance(range_data["step"], int | float):
self._log_contract_violation( self._log_contract_violation(
path, path,
"INVALID_STEP_TYPE", "INVALID_STEP_TYPE",
f"Range filter '{range_name}' 'step' should be a number, got {type(range_data['step']).__name__}" f"Range filter '{range_name}' 'step' should be a number, got {type(range_data['step']).__name__}",
) )
def _validate_hybrid_response(self, path: str, data: Any) -> None: def _validate_hybrid_response(self, path: str, data: Any) -> None:
@@ -198,38 +194,36 @@ class ContractValidationMiddleware(MiddlewareMixin):
return return
# Check for strategy field # Check for strategy field
if 'strategy' in data: if "strategy" in data:
strategy = data['strategy'] strategy = data["strategy"]
if strategy not in ['client_side', 'server_side']: if strategy not in ["client_side", "server_side"]:
self._log_contract_violation( self._log_contract_violation(
path, path,
"INVALID_STRATEGY_VALUE", "INVALID_STRATEGY_VALUE",
f"Hybrid response strategy should be 'client_side' or 'server_side', got '{strategy}'" f"Hybrid response strategy should be 'client_side' or 'server_side', got '{strategy}'",
) )
# Check filter_metadata structure # Check filter_metadata structure
if 'filter_metadata' in data: if "filter_metadata" in data:
self._validate_filter_metadata(path, data['filter_metadata']) self._validate_filter_metadata(path, data["filter_metadata"])
def _validate_pagination_response(self, path: str, data: dict[str, Any]) -> None: def _validate_pagination_response(self, path: str, data: dict[str, Any]) -> None:
"""Validate pagination response structure.""" """Validate pagination response structure."""
# Check for required pagination fields # Check for required pagination fields
required_fields = ['count', 'results'] required_fields = ["count", "results"]
for field in required_fields: for field in required_fields:
if field not in data: if field not in data:
self._log_contract_violation( self._log_contract_violation(
path, path, "MISSING_PAGINATION_FIELD", f"Pagination response missing required field '{field}'"
"MISSING_PAGINATION_FIELD",
f"Pagination response missing required field '{field}'"
) )
# Check results is array # Check results is array
if 'results' in data and not isinstance(data['results'], list): if "results" in data and not isinstance(data["results"], list):
self._log_contract_violation( self._log_contract_violation(
path, path,
"RESULTS_NOT_ARRAY", "RESULTS_NOT_ARRAY",
f"Pagination 'results' should be an array, got {type(data['results']).__name__}" f"Pagination 'results' should be an array, got {type(data['results']).__name__}",
) )
def _validate_common_patterns(self, path: str, data: Any) -> None: def _validate_common_patterns(self, path: str, data: Any) -> None:
@@ -238,38 +232,32 @@ class ContractValidationMiddleware(MiddlewareMixin):
if isinstance(data, dict): if isinstance(data, dict):
# Check for null vs undefined issues # Check for null vs undefined issues
for key, value in data.items(): for key, value in data.items():
if value is None and key.endswith('_id'): if value is None and key.endswith("_id"):
# ID fields should probably be null, not undefined # ID fields should probably be null, not undefined
continue continue
# Check for numeric fields that might be strings # Check for numeric fields that might be strings
if key.endswith('_count') and isinstance(value, str): if key.endswith("_count") and isinstance(value, str):
try: try:
int(value) int(value)
self._log_contract_violation( self._log_contract_violation(
path, path,
"NUMERIC_FIELD_AS_STRING", "NUMERIC_FIELD_AS_STRING",
f"Field '{key}' appears to be numeric but is a string: '{value}'" f"Field '{key}' appears to be numeric but is a string: '{value}'",
) )
except ValueError: except ValueError:
pass pass
def _log_contract_violation( def _log_contract_violation(self, path: str, violation_type: str, message: str, severity: str = "WARNING") -> None:
self,
path: str,
violation_type: str,
message: str,
severity: str = "WARNING"
) -> None:
"""Log a contract violation with structured data.""" """Log a contract violation with structured data."""
log_data = { log_data = {
'contract_violation': True, "contract_violation": True,
'violation_type': violation_type, "violation_type": violation_type,
'api_path': path, "api_path": path,
'severity': severity, "severity": severity,
'message': message, "message": message,
'suggestion': self._get_violation_suggestion(violation_type) "suggestion": self._get_violation_suggestion(violation_type),
} }
if severity == "ERROR": if severity == "ERROR":
@@ -302,9 +290,8 @@ class ContractValidationMiddleware(MiddlewareMixin):
"Check serializer field types and database field types." "Check serializer field types and database field types."
), ),
"RESULTS_NOT_ARRAY": ( "RESULTS_NOT_ARRAY": (
"Ensure pagination 'results' field is always an array. " "Ensure pagination 'results' field is always an array. " "Check serializer implementation."
"Check serializer implementation." ),
)
} }
return suggestions.get(violation_type, "Check the API response format against frontend TypeScript interfaces.") return suggestions.get(violation_type, "Check the API response format against frontend TypeScript interfaces.")
@@ -326,9 +313,9 @@ class ContractValidationSettings:
# Paths to exclude from validation # Paths to exclude from validation
EXCLUDED_PATHS = [ EXCLUDED_PATHS = [
'/api/docs/', "/api/docs/",
'/api/schema/', "/api/schema/",
'/api/v1/auth/', # Auth endpoints might have different structures "/api/v1/auth/", # Auth endpoints might have different structures
] ]
@classmethod @classmethod

View File

@@ -17,6 +17,7 @@ class ParkHistoryViewSet(viewsets.GenericViewSet):
""" """
ViewSet for retrieving park history. ViewSet for retrieving park history.
""" """
permission_classes = [AllowAny] permission_classes = [AllowAny]
lookup_field = "slug" lookup_field = "slug"
lookup_url_kwarg = "park_slug" lookup_url_kwarg = "park_slug"
@@ -40,12 +41,7 @@ class ParkHistoryViewSet(viewsets.GenericViewSet):
"last_modified": events.first().pgh_created_at if len(events) else None, "last_modified": events.first().pgh_created_at if len(events) else None,
} }
data = { data = {"park": park, "current_state": park, "summary": summary, "events": events}
"park": park,
"current_state": park,
"summary": summary,
"events": events
}
serializer = ParkHistoryOutputSerializer(data) serializer = ParkHistoryOutputSerializer(data)
return Response(serializer.data) return Response(serializer.data)
@@ -55,6 +51,7 @@ class RideHistoryViewSet(viewsets.GenericViewSet):
""" """
ViewSet for retrieving ride history. ViewSet for retrieving ride history.
""" """
permission_classes = [AllowAny] permission_classes = [AllowAny]
lookup_field = "slug" lookup_field = "slug"
lookup_url_kwarg = "ride_slug" lookup_url_kwarg = "ride_slug"
@@ -79,12 +76,7 @@ class RideHistoryViewSet(viewsets.GenericViewSet):
"last_modified": events.first().pgh_created_at if len(events) else None, "last_modified": events.first().pgh_created_at if len(events) else None,
} }
data = { data = {"ride": ride, "current_state": ride, "summary": summary, "events": events}
"ride": ride,
"current_state": ride,
"summary": summary,
"events": events
}
serializer = RideHistoryOutputSerializer(data) serializer = RideHistoryOutputSerializer(data)
return Response(serializer.data) return Response(serializer.data)

View File

@@ -65,14 +65,12 @@ class ParkReviewViewSet(ModelViewSet):
def get_permissions(self): def get_permissions(self):
"""Set permissions based on action.""" """Set permissions based on action."""
permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated] permission_classes = [AllowAny] if self.action in ["list", "retrieve", "stats"] else [IsAuthenticated]
return [permission() for permission in permission_classes] return [permission() for permission in permission_classes]
def get_queryset(self): def get_queryset(self):
"""Get reviews for the current park.""" """Get reviews for the current park."""
queryset = ParkReview.objects.select_related( queryset = ParkReview.objects.select_related("park", "user", "user__profile")
"park", "user", "user__profile"
)
park_slug = self.kwargs.get("park_slug") park_slug = self.kwargs.get("park_slug")
if park_slug: if park_slug:
@@ -82,7 +80,7 @@ class ParkReviewViewSet(ModelViewSet):
except Park.DoesNotExist: except Park.DoesNotExist:
return queryset.none() return queryset.none()
if not (hasattr(self.request, 'user') and getattr(self.request.user, 'is_staff', False)): if not (hasattr(self.request, "user") and getattr(self.request.user, "is_staff", False)):
queryset = queryset.filter(is_published=True) queryset = queryset.filter(is_published=True)
return queryset.order_by("-created_at") return queryset.order_by("-created_at")
@@ -102,16 +100,12 @@ class ParkReviewViewSet(ModelViewSet):
try: try:
park, _ = Park.get_by_slug(park_slug) park, _ = Park.get_by_slug(park_slug)
except Park.DoesNotExist: except Park.DoesNotExist:
raise NotFound("Park not found") raise NotFound("Park not found") from None
if ParkReview.objects.filter(park=park, user=self.request.user).exists(): if ParkReview.objects.filter(park=park, user=self.request.user).exists():
raise ValidationError("You have already reviewed this park") raise ValidationError("You have already reviewed this park")
serializer.save( serializer.save(park=park, user=self.request.user, is_published=True)
park=park,
user=self.request.user,
is_published=True
)
def perform_update(self, serializer): def perform_update(self, serializer):
instance = self.get_object() instance = self.get_object()
@@ -134,17 +128,18 @@ class ParkReviewViewSet(ModelViewSet):
try: try:
park, _ = Park.get_by_slug(park_slug) park, _ = Park.get_by_slug(park_slug)
except Park.DoesNotExist: except Park.DoesNotExist:
return Response({"error": "Park not found"}, status=status.HTTP_404_NOT_FOUND) return Response({"detail": "Park not found"}, status=status.HTTP_404_NOT_FOUND)
reviews = ParkReview.objects.filter(park=park, is_published=True) reviews = ParkReview.objects.filter(park=park, is_published=True)
total_reviews = reviews.count() total_reviews = reviews.count()
avg_rating = reviews.aggregate(avg=Avg('rating'))['avg'] avg_rating = reviews.aggregate(avg=Avg("rating"))["avg"]
rating_distribution = {} rating_distribution = {}
for i in range(1, 11): for i in range(1, 11):
rating_distribution[str(i)] = reviews.filter(rating=i).count() rating_distribution[str(i)] = reviews.filter(rating=i).count()
from datetime import timedelta from datetime import timedelta
recent_reviews = reviews.filter(created_at__gte=timezone.now() - timedelta(days=30)).count() recent_reviews = reviews.filter(created_at__gte=timezone.now() - timedelta(days=30)).count()
stats = { stats = {

View File

@@ -21,6 +21,7 @@ from rest_framework.views import APIView
try: try:
from apps.parks.models import Park from apps.parks.models import Park
from apps.rides.models import Ride from apps.rides.models import Ride
MODELS_AVAILABLE = True MODELS_AVAILABLE = True
except Exception: except Exception:
Park = None # type: ignore Park = None # type: ignore
@@ -31,6 +32,7 @@ except Exception:
try: try:
from apps.api.v1.serializers.parks import ParkDetailOutputSerializer from apps.api.v1.serializers.parks import ParkDetailOutputSerializer
from apps.api.v1.serializers.rides import RideDetailOutputSerializer, RideListOutputSerializer from apps.api.v1.serializers.rides import RideDetailOutputSerializer, RideListOutputSerializer
SERIALIZERS_AVAILABLE = True SERIALIZERS_AVAILABLE = True
except Exception: except Exception:
SERIALIZERS_AVAILABLE = False SERIALIZERS_AVAILABLE = False
@@ -52,22 +54,41 @@ class ParkRidesListAPIView(APIView):
description="Get paginated list of rides at a specific park with filtering options", description="Get paginated list of rides at a specific park with filtering options",
parameters=[ parameters=[
# Pagination # Pagination
OpenApiParameter(name="page", location=OpenApiParameter.QUERY, OpenApiParameter(
type=OpenApiTypes.INT, description="Page number"), name="page", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, description="Page number"
OpenApiParameter(name="page_size", location=OpenApiParameter.QUERY, ),
type=OpenApiTypes.INT, description="Number of results per page (max 100)"), OpenApiParameter(
name="page_size",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.INT,
description="Number of results per page (max 100)",
),
# Filtering # Filtering
OpenApiParameter(name="category", location=OpenApiParameter.QUERY, OpenApiParameter(
type=OpenApiTypes.STR, description="Filter by ride category"), name="category",
OpenApiParameter(name="status", location=OpenApiParameter.QUERY, location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR, description="Filter by operational status"), type=OpenApiTypes.STR,
OpenApiParameter(name="search", location=OpenApiParameter.QUERY, description="Filter by ride category",
type=OpenApiTypes.STR, description="Search rides by name"), ),
OpenApiParameter(
name="status",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR,
description="Filter by operational status",
),
OpenApiParameter(
name="search",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR,
description="Search rides by name",
),
# Ordering # Ordering
OpenApiParameter(name="ordering", location=OpenApiParameter.QUERY, OpenApiParameter(
type=OpenApiTypes.STR, description="Order results by field"), name="ordering",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR,
description="Order results by field",
),
], ],
responses={ responses={
200: OpenApiTypes.OBJECT, 200: OpenApiTypes.OBJECT,
@@ -87,12 +108,14 @@ class ParkRidesListAPIView(APIView):
try: try:
park, is_historical = Park.get_by_slug(park_slug) park, is_historical = Park.get_by_slug(park_slug)
except Park.DoesNotExist: except Park.DoesNotExist:
raise NotFound("Park not found") raise NotFound("Park not found") from None
# Get rides for this park # Get rides for this park
qs = Ride.objects.filter(park=park).select_related( qs = (
"manufacturer", "designer", "ride_model", "park_area" Ride.objects.filter(park=park)
).prefetch_related("photos") .select_related("manufacturer", "designer", "ride_model", "park_area")
.prefetch_related("photos")
)
# Apply filtering # Apply filtering
qs = self._apply_filters(qs, request.query_params) qs = self._apply_filters(qs, request.query_params)
@@ -107,9 +130,7 @@ class ParkRidesListAPIView(APIView):
page = paginator.paginate_queryset(qs, request) page = paginator.paginate_queryset(qs, request)
if SERIALIZERS_AVAILABLE: if SERIALIZERS_AVAILABLE:
serializer = RideListOutputSerializer( serializer = RideListOutputSerializer(page, many=True, context={"request": request, "park": park})
page, many=True, context={"request": request, "park": park}
)
return paginator.get_paginated_response(serializer.data) return paginator.get_paginated_response(serializer.data)
else: else:
# Fallback serialization # Fallback serialization
@@ -145,9 +166,7 @@ class ParkRidesListAPIView(APIView):
search = params.get("search") search = params.get("search")
if search: if search:
qs = qs.filter( qs = qs.filter(
Q(name__icontains=search) | Q(name__icontains=search) | Q(description__icontains=search) | Q(manufacturer__name__icontains=search)
Q(description__icontains=search) |
Q(manufacturer__name__icontains=search)
) )
return qs return qs
@@ -179,42 +198,46 @@ class ParkRideDetailAPIView(APIView):
try: try:
park, is_historical = Park.get_by_slug(park_slug) park, is_historical = Park.get_by_slug(park_slug)
except Park.DoesNotExist: except Park.DoesNotExist:
raise NotFound("Park not found") raise NotFound("Park not found") from None
# Get the ride # Get the ride
try: try:
ride, is_historical = Ride.get_by_slug(ride_slug, park=park) ride, is_historical = Ride.get_by_slug(ride_slug, park=park)
except Ride.DoesNotExist: except Ride.DoesNotExist:
raise NotFound("Ride not found at this park") raise NotFound("Ride not found at this park") from None
# Ensure ride belongs to this park # Ensure ride belongs to this park
if ride.park_id != park.id: if ride.park_id != park.id:
raise NotFound("Ride not found at this park") raise NotFound("Ride not found at this park")
if SERIALIZERS_AVAILABLE: if SERIALIZERS_AVAILABLE:
serializer = RideDetailOutputSerializer( serializer = RideDetailOutputSerializer(ride, context={"request": request, "park": park})
ride, context={"request": request, "park": park}
)
return Response(serializer.data) return Response(serializer.data)
else: else:
# Fallback serialization # Fallback serialization
return Response({ return Response(
"id": ride.id, {
"name": ride.name, "id": ride.id,
"slug": ride.slug, "name": ride.name,
"description": getattr(ride, "description", ""), "slug": ride.slug,
"category": getattr(ride, "category", ""), "description": getattr(ride, "description", ""),
"status": getattr(ride, "status", ""), "category": getattr(ride, "category", ""),
"park": { "status": getattr(ride, "status", ""),
"id": park.id, "park": {
"name": park.name, "id": park.id,
"slug": park.slug, "name": park.name,
}, "slug": park.slug,
"manufacturer": { },
"name": ride.manufacturer.name if ride.manufacturer else "", "manufacturer": (
"slug": getattr(ride.manufacturer, "slug", "") if ride.manufacturer else "", {
} if ride.manufacturer else None, "name": ride.manufacturer.name if ride.manufacturer else "",
}) "slug": getattr(ride.manufacturer, "slug", "") if ride.manufacturer else "",
}
if ride.manufacturer
else None
),
}
)
class ParkComprehensiveDetailAPIView(APIView): class ParkComprehensiveDetailAPIView(APIView):
@@ -243,25 +266,21 @@ class ParkComprehensiveDetailAPIView(APIView):
try: try:
park, is_historical = Park.get_by_slug(park_slug) park, is_historical = Park.get_by_slug(park_slug)
except Park.DoesNotExist: except Park.DoesNotExist:
raise NotFound("Park not found") raise NotFound("Park not found") from None
# Get park with full related data # Get park with full related data
park = Park.objects.select_related( park = (
"operator", "property_owner", "location" Park.objects.select_related("operator", "property_owner", "location")
).prefetch_related( .prefetch_related("areas", "rides", "photos")
"areas", "rides", "photos" .get(pk=park.pk)
).get(pk=park.pk) )
# Get a sample of rides (first 10) for preview # Get a sample of rides (first 10) for preview
rides_sample = Ride.objects.filter(park=park).select_related( rides_sample = Ride.objects.filter(park=park).select_related("manufacturer", "designer", "ride_model")[:10]
"manufacturer", "designer", "ride_model"
)[:10]
if SERIALIZERS_AVAILABLE: if SERIALIZERS_AVAILABLE:
# Get full park details # Get full park details
park_serializer = ParkDetailOutputSerializer( park_serializer = ParkDetailOutputSerializer(park, context={"request": request})
park, context={"request": request}
)
park_data = park_serializer.data park_data = park_serializer.data
# Add rides summary # Add rides summary
@@ -279,25 +298,27 @@ class ParkComprehensiveDetailAPIView(APIView):
return Response(park_data) return Response(park_data)
else: else:
# Fallback serialization # Fallback serialization
return Response({ return Response(
"id": park.id, {
"name": park.name, "id": park.id,
"slug": park.slug, "name": park.name,
"description": getattr(park, "description", ""), "slug": park.slug,
"location": str(getattr(park, "location", "")), "description": getattr(park, "description", ""),
"operator": getattr(park.operator, "name", "") if hasattr(park, "operator") else "", "location": str(getattr(park, "location", "")),
"ride_count": getattr(park, "ride_count", 0), "operator": getattr(park.operator, "name", "") if hasattr(park, "operator") else "",
"rides_summary": { "ride_count": getattr(park, "ride_count", 0),
"total_count": getattr(park, "ride_count", 0), "rides_summary": {
"sample": [ "total_count": getattr(park, "ride_count", 0),
{ "sample": [
"id": ride.id, {
"name": ride.name, "id": ride.id,
"slug": ride.slug, "name": ride.name,
"category": getattr(ride, "category", ""), "slug": ride.slug,
} "category": getattr(ride, "category", ""),
for ride in rides_sample }
], for ride in rides_sample
"full_list_url": f"/api/v1/parks/{park_slug}/rides/", ],
}, "full_list_url": f"/api/v1/parks/{park_slug}/rides/",
}) },
}
)

File diff suppressed because it is too large Load Diff

View File

@@ -116,14 +116,12 @@ class RidePhotoViewSet(ModelViewSet):
def get_permissions(self): def get_permissions(self):
"""Set permissions based on action.""" """Set permissions based on action."""
permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated] permission_classes = [AllowAny] if self.action in ["list", "retrieve", "stats"] else [IsAuthenticated]
return [permission() for permission in permission_classes] return [permission() for permission in permission_classes]
def get_queryset(self): def get_queryset(self):
"""Get photos for the current ride with optimized queries.""" """Get photos for the current ride with optimized queries."""
queryset = RidePhoto.objects.select_related( queryset = RidePhoto.objects.select_related("ride", "ride__park", "ride__park__operator", "uploaded_by")
"ride", "ride__park", "ride__park__operator", "uploaded_by"
)
# Filter by park and ride from URL kwargs # Filter by park and ride from URL kwargs
park_slug = self.kwargs.get("park_slug") park_slug = self.kwargs.get("park_slug")
@@ -163,9 +161,9 @@ class RidePhotoViewSet(ModelViewSet):
park, _ = Park.get_by_slug(park_slug) park, _ = Park.get_by_slug(park_slug)
ride, _ = Ride.get_by_slug(ride_slug, park=park) ride, _ = Ride.get_by_slug(ride_slug, park=park)
except Park.DoesNotExist: except Park.DoesNotExist:
raise NotFound("Park not found") raise NotFound("Park not found") from None
except Ride.DoesNotExist: except Ride.DoesNotExist:
raise NotFound("Ride not found at this park") raise NotFound("Ride not found at this park") from None
try: try:
# Use the service to create the photo with proper business logic # Use the service to create the photo with proper business logic
@@ -187,17 +185,14 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e: except Exception as e:
logger.error(f"Error creating ride photo: {e}") logger.error(f"Error creating ride photo: {e}")
raise ValidationError(f"Failed to create photo: {str(e)}") raise ValidationError(f"Failed to create photo: {str(e)}") from None
def perform_update(self, serializer): def perform_update(self, serializer):
"""Update ride photo with permission checking.""" """Update ride photo with permission checking."""
instance = self.get_object() instance = self.get_object()
# Check permissions - allow owner or staff # Check permissions - allow owner or staff
if not ( if not (self.request.user == instance.uploaded_by or getattr(self.request.user, "is_staff", False)):
self.request.user == instance.uploaded_by
or getattr(self.request.user, "is_staff", False)
):
raise PermissionDenied("You can only edit your own photos or be an admin.") raise PermissionDenied("You can only edit your own photos or be an admin.")
# Handle primary photo logic using service # Handle primary photo logic using service
@@ -209,48 +204,40 @@ class RidePhotoViewSet(ModelViewSet):
del serializer.validated_data["is_primary"] del serializer.validated_data["is_primary"]
except Exception as e: except Exception as e:
logger.error(f"Error setting primary photo: {e}") logger.error(f"Error setting primary photo: {e}")
raise ValidationError(f"Failed to set primary photo: {str(e)}") raise ValidationError(f"Failed to set primary photo: {str(e)}") from None
try: try:
serializer.save() serializer.save()
logger.info(f"Updated ride photo {instance.id} by user {self.request.user.username}") logger.info(f"Updated ride photo {instance.id} by user {self.request.user.username}")
except Exception as e: except Exception as e:
logger.error(f"Error updating ride photo: {e}") logger.error(f"Error updating ride photo: {e}")
raise ValidationError(f"Failed to update photo: {str(e)}") raise ValidationError(f"Failed to update photo: {str(e)}") from None
def perform_destroy(self, instance): def perform_destroy(self, instance):
"""Delete ride photo with permission checking.""" """Delete ride photo with permission checking."""
# Check permissions - allow owner or staff # Check permissions - allow owner or staff
if not ( if not (self.request.user == instance.uploaded_by or getattr(self.request.user, "is_staff", False)):
self.request.user == instance.uploaded_by raise PermissionDenied("You can only delete your own photos or be an admin.")
or getattr(self.request.user, "is_staff", False)
):
raise PermissionDenied(
"You can only delete your own photos or be an admin."
)
try: try:
# Delete from Cloudflare first if image exists # Delete from Cloudflare first if image exists
if instance.image: if instance.image:
try: try:
from django_cloudflareimages_toolkit.services import CloudflareImagesService from django_cloudflareimages_toolkit.services import CloudflareImagesService
service = CloudflareImagesService() service = CloudflareImagesService()
service.delete_image(instance.image) service.delete_image(instance.image)
logger.info( logger.info(f"Successfully deleted ride photo from Cloudflare: {instance.image.cloudflare_id}")
f"Successfully deleted ride photo from Cloudflare: {instance.image.cloudflare_id}")
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to delete ride photo from Cloudflare: {str(e)}")
f"Failed to delete ride photo from Cloudflare: {str(e)}")
# Continue with database deletion even if Cloudflare deletion fails # Continue with database deletion even if Cloudflare deletion fails
RideMediaService.delete_photo( RideMediaService.delete_photo(instance, deleted_by=self.request.user)
instance, deleted_by=self.request.user
)
logger.info(f"Deleted ride photo {instance.id} by user {self.request.user.username}") logger.info(f"Deleted ride photo {instance.id} by user {self.request.user.username}")
except Exception as e: except Exception as e:
logger.error(f"Error deleting ride photo: {e}") logger.error(f"Error deleting ride photo: {e}")
raise ValidationError(f"Failed to delete photo: {str(e)}") raise ValidationError(f"Failed to delete photo: {str(e)}") from None
@extend_schema( @extend_schema(
summary="Set photo as primary", summary="Set photo as primary",
@@ -269,13 +256,8 @@ class RidePhotoViewSet(ModelViewSet):
photo = self.get_object() photo = self.get_object()
# Check permissions - allow owner or staff # Check permissions - allow owner or staff
if not ( if not (request.user == photo.uploaded_by or getattr(request.user, "is_staff", False)):
request.user == photo.uploaded_by raise PermissionDenied("You can only modify your own photos or be an admin.")
or getattr(request.user, "is_staff", False)
):
raise PermissionDenied(
"You can only modify your own photos or be an admin."
)
try: try:
success = RideMediaService.set_primary_photo(ride=photo.ride, photo=photo) success = RideMediaService.set_primary_photo(ride=photo.ride, photo=photo)
@@ -287,21 +269,21 @@ class RidePhotoViewSet(ModelViewSet):
return Response( return Response(
{ {
"message": "Photo set as primary successfully", "detail": "Photo set as primary successfully",
"photo": serializer.data, "photo": serializer.data,
}, },
status=status.HTTP_200_OK, status=status.HTTP_200_OK,
) )
else: else:
return Response( return Response(
{"error": "Failed to set primary photo"}, {"detail": "Failed to set primary photo"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
except Exception as e: except Exception as e:
logger.error(f"Error setting primary photo: {e}") logger.error(f"Error setting primary photo: {e}")
return Response( return Response(
{"error": f"Failed to set primary photo: {str(e)}"}, {"detail": f"Failed to set primary photo: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -334,7 +316,7 @@ class RidePhotoViewSet(ModelViewSet):
if photo_ids is None or approve is None: if photo_ids is None or approve is None:
return Response( return Response(
{"error": "Missing required fields: photo_ids and/or approve."}, {"detail": "Missing required fields: photo_ids and/or approve."},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -350,7 +332,7 @@ class RidePhotoViewSet(ModelViewSet):
return Response( return Response(
{ {
"message": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos", "detail": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos",
"updated_count": updated_count, "updated_count": updated_count,
}, },
status=status.HTTP_200_OK, status=status.HTTP_200_OK,
@@ -359,7 +341,7 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e: except Exception as e:
logger.error(f"Error in bulk photo approval: {e}") logger.error(f"Error in bulk photo approval: {e}")
return Response( return Response(
{"error": f"Failed to update photos: {str(e)}"}, {"detail": f"Failed to update photos: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -381,7 +363,7 @@ class RidePhotoViewSet(ModelViewSet):
if not park_slug or not ride_slug: if not park_slug or not ride_slug:
return Response( return Response(
{"error": "Park and ride slugs are required"}, {"detail": "Park and ride slugs are required"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -390,12 +372,12 @@ class RidePhotoViewSet(ModelViewSet):
ride, _ = Ride.get_by_slug(ride_slug, park=park) ride, _ = Ride.get_by_slug(ride_slug, park=park)
except Park.DoesNotExist: except Park.DoesNotExist:
return Response( return Response(
{"error": "Park not found"}, {"detail": "Park not found"},
status=status.HTTP_404_NOT_FOUND, status=status.HTTP_404_NOT_FOUND,
) )
except Ride.DoesNotExist: except Ride.DoesNotExist:
return Response( return Response(
{"error": "Ride not found at this park"}, {"detail": "Ride not found at this park"},
status=status.HTTP_404_NOT_FOUND, status=status.HTTP_404_NOT_FOUND,
) )
@@ -407,7 +389,7 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e: except Exception as e:
logger.error(f"Error getting ride photo stats: {e}") logger.error(f"Error getting ride photo stats: {e}")
return Response( return Response(
{"error": f"Failed to get photo statistics: {str(e)}"}, {"detail": f"Failed to get photo statistics: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@@ -431,7 +413,7 @@ class RidePhotoViewSet(ModelViewSet):
if not park_slug or not ride_slug: if not park_slug or not ride_slug:
return Response( return Response(
{"error": "Park and ride slugs are required"}, {"detail": "Park and ride slugs are required"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -440,19 +422,19 @@ class RidePhotoViewSet(ModelViewSet):
ride, _ = Ride.get_by_slug(ride_slug, park=park) ride, _ = Ride.get_by_slug(ride_slug, park=park)
except Park.DoesNotExist: except Park.DoesNotExist:
return Response( return Response(
{"error": "Park not found"}, {"detail": "Park not found"},
status=status.HTTP_404_NOT_FOUND, status=status.HTTP_404_NOT_FOUND,
) )
except Ride.DoesNotExist: except Ride.DoesNotExist:
return Response( return Response(
{"error": "Ride not found at this park"}, {"detail": "Ride not found at this park"},
status=status.HTTP_404_NOT_FOUND, status=status.HTTP_404_NOT_FOUND,
) )
cloudflare_image_id = request.data.get("cloudflare_image_id") cloudflare_image_id = request.data.get("cloudflare_image_id")
if not cloudflare_image_id: if not cloudflare_image_id:
return Response( return Response(
{"error": "cloudflare_image_id is required"}, {"detail": "cloudflare_image_id is required"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -469,27 +451,25 @@ class RidePhotoViewSet(ModelViewSet):
if not image_data: if not image_data:
return Response( return Response(
{"error": "Image not found in Cloudflare"}, {"detail": "Image not found in Cloudflare"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
# Try to find existing CloudflareImage record by cloudflare_id # Try to find existing CloudflareImage record by cloudflare_id
cloudflare_image = None cloudflare_image = None
try: try:
cloudflare_image = CloudflareImage.objects.get( cloudflare_image = CloudflareImage.objects.get(cloudflare_id=cloudflare_image_id)
cloudflare_id=cloudflare_image_id)
# Update existing record with latest data from Cloudflare # Update existing record with latest data from Cloudflare
cloudflare_image.status = 'uploaded' cloudflare_image.status = "uploaded"
cloudflare_image.uploaded_at = timezone.now() cloudflare_image.uploaded_at = timezone.now()
cloudflare_image.metadata = image_data.get('meta', {}) cloudflare_image.metadata = image_data.get("meta", {})
# Extract variants from nested result structure # Extract variants from nested result structure
cloudflare_image.variants = image_data.get( cloudflare_image.variants = image_data.get("result", {}).get("variants", [])
'result', {}).get('variants', [])
cloudflare_image.cloudflare_metadata = image_data cloudflare_image.cloudflare_metadata = image_data
cloudflare_image.width = image_data.get('width') cloudflare_image.width = image_data.get("width")
cloudflare_image.height = image_data.get('height') cloudflare_image.height = image_data.get("height")
cloudflare_image.format = image_data.get('format', '') cloudflare_image.format = image_data.get("format", "")
cloudflare_image.save() cloudflare_image.save()
except CloudflareImage.DoesNotExist: except CloudflareImage.DoesNotExist:
@@ -497,24 +477,23 @@ class RidePhotoViewSet(ModelViewSet):
cloudflare_image = CloudflareImage.objects.create( cloudflare_image = CloudflareImage.objects.create(
cloudflare_id=cloudflare_image_id, cloudflare_id=cloudflare_image_id,
user=request.user, user=request.user,
status='uploaded', status="uploaded",
upload_url='', # Not needed for uploaded images upload_url="", # Not needed for uploaded images
expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry
uploaded_at=timezone.now(), uploaded_at=timezone.now(),
metadata=image_data.get('meta', {}), metadata=image_data.get("meta", {}),
# Extract variants from nested result structure # Extract variants from nested result structure
variants=image_data.get('result', {}).get('variants', []), variants=image_data.get("result", {}).get("variants", []),
cloudflare_metadata=image_data, cloudflare_metadata=image_data,
width=image_data.get('width'), width=image_data.get("width"),
height=image_data.get('height'), height=image_data.get("height"),
format=image_data.get('format', ''), format=image_data.get("format", ""),
) )
except Exception as api_error: except Exception as api_error:
logger.error( logger.error(f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True)
f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True)
return Response( return Response(
{"error": f"Failed to fetch image from Cloudflare: {str(api_error)}"}, {"detail": f"Failed to fetch image from Cloudflare: {str(api_error)}"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -544,6 +523,6 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e: except Exception as e:
logger.error(f"Error saving ride photo: {e}") logger.error(f"Error saving ride photo: {e}")
return Response( return Response(
{"error": f"Failed to save photo: {str(e)}"}, {"detail": f"Failed to save photo: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )

View File

@@ -115,14 +115,12 @@ class RideReviewViewSet(ModelViewSet):
def get_permissions(self): def get_permissions(self):
"""Set permissions based on action.""" """Set permissions based on action."""
permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated] permission_classes = [AllowAny] if self.action in ["list", "retrieve", "stats"] else [IsAuthenticated]
return [permission() for permission in permission_classes] return [permission() for permission in permission_classes]
def get_queryset(self): def get_queryset(self):
"""Get reviews for the current ride with optimized queries.""" """Get reviews for the current ride with optimized queries."""
queryset = RideReview.objects.select_related( queryset = RideReview.objects.select_related("ride", "ride__park", "user", "user__profile")
"ride", "ride__park", "user", "user__profile"
)
# Filter by park and ride from URL kwargs # Filter by park and ride from URL kwargs
park_slug = self.kwargs.get("park_slug") park_slug = self.kwargs.get("park_slug")
@@ -138,8 +136,7 @@ class RideReviewViewSet(ModelViewSet):
return queryset.none() return queryset.none()
# Filter published reviews for non-staff users # Filter published reviews for non-staff users
if not (hasattr(self.request, 'user') and if not (hasattr(self.request, "user") and getattr(self.request.user, "is_staff", False)):
getattr(self.request.user, 'is_staff', False)):
queryset = queryset.filter(is_published=True) queryset = queryset.filter(is_published=True)
return queryset.order_by("-created_at") return queryset.order_by("-created_at")
@@ -167,9 +164,9 @@ class RideReviewViewSet(ModelViewSet):
park, _ = Park.get_by_slug(park_slug) park, _ = Park.get_by_slug(park_slug)
ride, _ = Ride.get_by_slug(ride_slug, park=park) ride, _ = Ride.get_by_slug(ride_slug, park=park)
except Park.DoesNotExist: except Park.DoesNotExist:
raise NotFound("Park not found") raise NotFound("Park not found") from None
except Ride.DoesNotExist: except Ride.DoesNotExist:
raise NotFound("Ride not found at this park") raise NotFound("Ride not found at this park") from None
# Check if user already has a review for this ride # Check if user already has a review for this ride
if RideReview.objects.filter(ride=ride, user=self.request.user).exists(): if RideReview.objects.filter(ride=ride, user=self.request.user).exists():
@@ -178,26 +175,21 @@ class RideReviewViewSet(ModelViewSet):
try: try:
# Save the review # Save the review
review = serializer.save( review = serializer.save(
ride=ride, ride=ride, user=self.request.user, is_published=True # Auto-publish for now, can add moderation later
user=self.request.user,
is_published=True # Auto-publish for now, can add moderation later
) )
logger.info(f"Created ride review {review.id} for ride {ride.name} by user {self.request.user.username}") logger.info(f"Created ride review {review.id} for ride {ride.name} by user {self.request.user.username}")
except Exception as e: except Exception as e:
logger.error(f"Error creating ride review: {e}") logger.error(f"Error creating ride review: {e}")
raise ValidationError(f"Failed to create review: {str(e)}") raise ValidationError(f"Failed to create review: {str(e)}") from None
def perform_update(self, serializer): def perform_update(self, serializer):
"""Update ride review with permission checking.""" """Update ride review with permission checking."""
instance = self.get_object() instance = self.get_object()
# Check permissions - allow owner or staff # Check permissions - allow owner or staff
if not ( if not (self.request.user == instance.user or getattr(self.request.user, "is_staff", False)):
self.request.user == instance.user
or getattr(self.request.user, "is_staff", False)
):
raise PermissionDenied("You can only edit your own reviews or be an admin.") raise PermissionDenied("You can only edit your own reviews or be an admin.")
try: try:
@@ -205,15 +197,12 @@ class RideReviewViewSet(ModelViewSet):
logger.info(f"Updated ride review {instance.id} by user {self.request.user.username}") logger.info(f"Updated ride review {instance.id} by user {self.request.user.username}")
except Exception as e: except Exception as e:
logger.error(f"Error updating ride review: {e}") logger.error(f"Error updating ride review: {e}")
raise ValidationError(f"Failed to update review: {str(e)}") raise ValidationError(f"Failed to update review: {str(e)}") from None
def perform_destroy(self, instance): def perform_destroy(self, instance):
"""Delete ride review with permission checking.""" """Delete ride review with permission checking."""
# Check permissions - allow owner or staff # Check permissions - allow owner or staff
if not ( if not (self.request.user == instance.user or getattr(self.request.user, "is_staff", False)):
self.request.user == instance.user
or getattr(self.request.user, "is_staff", False)
):
raise PermissionDenied("You can only delete your own reviews or be an admin.") raise PermissionDenied("You can only delete your own reviews or be an admin.")
try: try:
@@ -221,7 +210,7 @@ class RideReviewViewSet(ModelViewSet):
instance.delete() instance.delete()
except Exception as e: except Exception as e:
logger.error(f"Error deleting ride review: {e}") logger.error(f"Error deleting ride review: {e}")
raise ValidationError(f"Failed to delete review: {str(e)}") raise ValidationError(f"Failed to delete review: {str(e)}") from None
@extend_schema( @extend_schema(
summary="Get ride review statistics", summary="Get ride review statistics",
@@ -241,7 +230,7 @@ class RideReviewViewSet(ModelViewSet):
if not park_slug or not ride_slug: if not park_slug or not ride_slug:
return Response( return Response(
{"error": "Park and ride slugs are required"}, {"detail": "Park and ride slugs are required"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -250,12 +239,12 @@ class RideReviewViewSet(ModelViewSet):
ride, _ = Ride.get_by_slug(ride_slug, park=park) ride, _ = Ride.get_by_slug(ride_slug, park=park)
except Park.DoesNotExist: except Park.DoesNotExist:
return Response( return Response(
{"error": "Park not found"}, {"detail": "Park not found"},
status=status.HTTP_404_NOT_FOUND, status=status.HTTP_404_NOT_FOUND,
) )
except Ride.DoesNotExist: except Ride.DoesNotExist:
return Response( return Response(
{"error": "Ride not found at this park"}, {"detail": "Ride not found at this park"},
status=status.HTTP_404_NOT_FOUND, status=status.HTTP_404_NOT_FOUND,
) )
@@ -268,7 +257,7 @@ class RideReviewViewSet(ModelViewSet):
pending_reviews = RideReview.objects.filter(ride=ride, is_published=False).count() pending_reviews = RideReview.objects.filter(ride=ride, is_published=False).count()
# Calculate average rating # Calculate average rating
avg_rating = reviews.aggregate(avg_rating=Avg('rating'))['avg_rating'] avg_rating = reviews.aggregate(avg_rating=Avg("rating"))["avg_rating"]
# Get rating distribution # Get rating distribution
rating_distribution = {} rating_distribution = {}
@@ -277,6 +266,7 @@ class RideReviewViewSet(ModelViewSet):
# Get recent reviews count (last 30 days) # Get recent reviews count (last 30 days)
from datetime import timedelta from datetime import timedelta
thirty_days_ago = timezone.now() - timedelta(days=30) thirty_days_ago = timezone.now() - timedelta(days=30)
recent_reviews = reviews.filter(created_at__gte=thirty_days_ago).count() recent_reviews = reviews.filter(created_at__gte=thirty_days_ago).count()
@@ -295,7 +285,7 @@ class RideReviewViewSet(ModelViewSet):
except Exception as e: except Exception as e:
logger.error(f"Error getting ride review stats: {e}") logger.error(f"Error getting ride review stats: {e}")
return Response( return Response(
{"error": f"Failed to get review statistics: {str(e)}"}, {"detail": f"Failed to get review statistics: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@@ -340,7 +330,7 @@ class RideReviewViewSet(ModelViewSet):
is_published=True, is_published=True,
moderated_by=request.user, moderated_by=request.user,
moderated_at=timezone.now(), moderated_at=timezone.now(),
moderation_notes=moderation_notes moderation_notes=moderation_notes,
) )
message = f"Successfully published {updated_count} reviews" message = f"Successfully published {updated_count} reviews"
elif action_type == "unpublish": elif action_type == "unpublish":
@@ -348,7 +338,7 @@ class RideReviewViewSet(ModelViewSet):
is_published=False, is_published=False,
moderated_by=request.user, moderated_by=request.user,
moderated_at=timezone.now(), moderated_at=timezone.now(),
moderation_notes=moderation_notes moderation_notes=moderation_notes,
) )
message = f"Successfully unpublished {updated_count} reviews" message = f"Successfully unpublished {updated_count} reviews"
elif action_type == "delete": elif action_type == "delete":
@@ -357,13 +347,13 @@ class RideReviewViewSet(ModelViewSet):
message = f"Successfully deleted {updated_count} reviews" message = f"Successfully deleted {updated_count} reviews"
else: else:
return Response( return Response(
{"error": "Invalid action type"}, {"detail": "Invalid action type"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
return Response( return Response(
{ {
"message": message, "detail": message,
"updated_count": updated_count, "updated_count": updated_count,
}, },
status=status.HTTP_200_OK, status=status.HTTP_200_OK,
@@ -372,6 +362,6 @@ class RideReviewViewSet(ModelViewSet):
except Exception as e: except Exception as e:
logger.error(f"Error in bulk review moderation: {e}") logger.error(f"Error in bulk review moderation: {e}")
return Response( return Response(
{"error": f"Failed to moderate reviews: {str(e)}"}, {"detail": f"Failed to moderate reviews: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )

View File

@@ -50,18 +50,14 @@ from apps.parks.models import Park, ParkPhoto
class ParkPhotoOutputSerializer(serializers.ModelSerializer): class ParkPhotoOutputSerializer(serializers.ModelSerializer):
"""Enhanced output serializer for park photos with Cloudflare Images support.""" """Enhanced output serializer for park photos with Cloudflare Images support."""
uploaded_by_username = serializers.CharField( uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True)
source="uploaded_by.username", read_only=True
)
file_size = serializers.SerializerMethodField() file_size = serializers.SerializerMethodField()
dimensions = serializers.SerializerMethodField() dimensions = serializers.SerializerMethodField()
image_url = serializers.SerializerMethodField() image_url = serializers.SerializerMethodField()
image_variants = serializers.SerializerMethodField() image_variants = serializers.SerializerMethodField()
@extend_schema_field( @extend_schema_field(serializers.IntegerField(allow_null=True, help_text="File size in bytes"))
serializers.IntegerField(allow_null=True, help_text="File size in bytes")
)
def get_file_size(self, obj): def get_file_size(self, obj):
"""Get file size in bytes.""" """Get file size in bytes."""
return obj.file_size return obj.file_size
@@ -79,11 +75,7 @@ class ParkPhotoOutputSerializer(serializers.ModelSerializer):
"""Get image dimensions as [width, height].""" """Get image dimensions as [width, height]."""
return obj.dimensions return obj.dimensions
@extend_schema_field( @extend_schema_field(serializers.URLField(help_text="Full URL to the Cloudflare Images asset", allow_null=True))
serializers.URLField(
help_text="Full URL to the Cloudflare Images asset", allow_null=True
)
)
def get_image_url(self, obj): def get_image_url(self, obj):
"""Get the full Cloudflare Images URL.""" """Get the full Cloudflare Images URL."""
if obj.image: if obj.image:
@@ -175,9 +167,7 @@ class ParkPhotoUpdateInputSerializer(serializers.ModelSerializer):
class ParkPhotoListOutputSerializer(serializers.ModelSerializer): class ParkPhotoListOutputSerializer(serializers.ModelSerializer):
"""Optimized output serializer for park photo lists.""" """Optimized output serializer for park photo lists."""
uploaded_by_username = serializers.CharField( uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True)
source="uploaded_by.username", read_only=True
)
class Meta: class Meta:
model = ParkPhoto model = ParkPhoto
@@ -196,12 +186,8 @@ class ParkPhotoListOutputSerializer(serializers.ModelSerializer):
class ParkPhotoApprovalInputSerializer(serializers.Serializer): class ParkPhotoApprovalInputSerializer(serializers.Serializer):
"""Input serializer for bulk photo approval operations.""" """Input serializer for bulk photo approval operations."""
photo_ids = serializers.ListField( photo_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of photo IDs to approve")
child=serializers.IntegerField(), help_text="List of photo IDs to approve" approve = serializers.BooleanField(default=True, help_text="Whether to approve (True) or reject (False) the photos")
)
approve = serializers.BooleanField(
default=True, help_text="Whether to approve (True) or reject (False) the photos"
)
class ParkPhotoStatsOutputSerializer(serializers.Serializer): class ParkPhotoStatsOutputSerializer(serializers.Serializer):
@@ -261,7 +247,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
def get_city(self, obj): def get_city(self, obj):
"""Get city from related location.""" """Get city from related location."""
try: try:
return obj.location.city if hasattr(obj, 'location') and obj.location else None return obj.location.city if hasattr(obj, "location") and obj.location else None
except AttributeError: except AttributeError:
return None return None
@@ -269,7 +255,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
def get_state(self, obj): def get_state(self, obj):
"""Get state from related location.""" """Get state from related location."""
try: try:
return obj.location.state if hasattr(obj, 'location') and obj.location else None return obj.location.state if hasattr(obj, "location") and obj.location else None
except AttributeError: except AttributeError:
return None return None
@@ -277,7 +263,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
def get_country(self, obj): def get_country(self, obj):
"""Get country from related location.""" """Get country from related location."""
try: try:
return obj.location.country if hasattr(obj, 'location') and obj.location else None return obj.location.country if hasattr(obj, "location") and obj.location else None
except AttributeError: except AttributeError:
return None return None
@@ -285,7 +271,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
def get_continent(self, obj): def get_continent(self, obj):
"""Get continent from related location.""" """Get continent from related location."""
try: try:
return obj.location.continent if hasattr(obj, 'location') and obj.location else None return obj.location.continent if hasattr(obj, "location") and obj.location else None
except AttributeError: except AttributeError:
return None return None
@@ -293,7 +279,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
def get_latitude(self, obj): def get_latitude(self, obj):
"""Get latitude from related location.""" """Get latitude from related location."""
try: try:
if hasattr(obj, 'location') and obj.location and obj.location.coordinates: if hasattr(obj, "location") and obj.location and obj.location.coordinates:
return obj.location.coordinates[1] # PostGIS returns [lon, lat] return obj.location.coordinates[1] # PostGIS returns [lon, lat]
return None return None
except (AttributeError, IndexError, TypeError): except (AttributeError, IndexError, TypeError):
@@ -303,7 +289,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
def get_longitude(self, obj): def get_longitude(self, obj):
"""Get longitude from related location.""" """Get longitude from related location."""
try: try:
if hasattr(obj, 'location') and obj.location and obj.location.coordinates: if hasattr(obj, "location") and obj.location and obj.location.coordinates:
return obj.location.coordinates[0] # PostGIS returns [lon, lat] return obj.location.coordinates[0] # PostGIS returns [lon, lat]
return None return None
except (AttributeError, IndexError, TypeError): except (AttributeError, IndexError, TypeError):
@@ -333,13 +319,11 @@ class HybridParkSerializer(serializers.ModelSerializer):
"description", "description",
"status", "status",
"park_type", "park_type",
# Dates and computed fields # Dates and computed fields
"opening_date", "opening_date",
"closing_date", "closing_date",
"opening_year", "opening_year",
"operating_season", "operating_season",
# Location fields # Location fields
"city", "city",
"state", "state",
@@ -347,28 +331,22 @@ class HybridParkSerializer(serializers.ModelSerializer):
"continent", "continent",
"latitude", "latitude",
"longitude", "longitude",
# Company relationships # Company relationships
"operator_name", "operator_name",
"property_owner_name", "property_owner_name",
# Statistics # Statistics
"size_acres", "size_acres",
"average_rating", "average_rating",
"ride_count", "ride_count",
"coaster_count", "coaster_count",
# Images # Images
"banner_image_url", "banner_image_url",
"card_image_url", "card_image_url",
# URLs # URLs
"website", "website",
"url", "url",
# Computed fields for filtering # Computed fields for filtering
"search_text", "search_text",
# Metadata # Metadata
"created_at", "created_at",
"updated_at", "updated_at",

View File

@@ -46,8 +46,8 @@ ride_photos_router.register(r"", RidePhotoViewSet, basename="ride-photo")
ride_reviews_router = DefaultRouter() ride_reviews_router = DefaultRouter()
ride_reviews_router.register(r"", RideReviewViewSet, basename="ride-review") ride_reviews_router.register(r"", RideReviewViewSet, basename="ride-review")
from .history_views import ParkHistoryViewSet, RideHistoryViewSet from .history_views import ParkHistoryViewSet, RideHistoryViewSet # noqa: E402
from .park_reviews_views import ParkReviewViewSet from .park_reviews_views import ParkReviewViewSet # noqa: E402
# Create routers for nested park endpoints # Create routers for nested park endpoints
reviews_router = DefaultRouter() reviews_router = DefaultRouter()
@@ -59,11 +59,9 @@ app_name = "api_v1_parks"
urlpatterns = [ urlpatterns = [
# Core list/create endpoints # Core list/create endpoints
path("", ParkListCreateAPIView.as_view(), name="park-list-create"), path("", ParkListCreateAPIView.as_view(), name="park-list-create"),
# Hybrid filtering endpoints # Hybrid filtering endpoints
path("hybrid/", HybridParkAPIView.as_view(), name="park-hybrid-list"), path("hybrid/", HybridParkAPIView.as_view(), name="park-hybrid-list"),
path("hybrid/filter-metadata/", ParkFilterMetadataAPIView.as_view(), name="park-hybrid-filter-metadata"), path("hybrid/filter-metadata/", ParkFilterMetadataAPIView.as_view(), name="park-hybrid-filter-metadata"),
# Filter options # Filter options
path("filter-options/", FilterOptionsAPIView.as_view(), name="park-filter-options"), path("filter-options/", FilterOptionsAPIView.as_view(), name="park-filter-options"),
# Autocomplete / suggestion endpoints # Autocomplete / suggestion endpoints
@@ -79,14 +77,11 @@ urlpatterns = [
), ),
# Detail and action endpoints - supports both ID and slug # Detail and action endpoints - supports both ID and slug
path("<str:pk>/", ParkDetailAPIView.as_view(), name="park-detail"), path("<str:pk>/", ParkDetailAPIView.as_view(), name="park-detail"),
# Park rides endpoints # Park rides endpoints
path("<str:park_slug>/rides/", ParkRidesListAPIView.as_view(), name="park-rides-list"), path("<str:park_slug>/rides/", ParkRidesListAPIView.as_view(), name="park-rides-list"),
path("<str:park_slug>/rides/<str:ride_slug>/", ParkRideDetailAPIView.as_view(), name="park-ride-detail"), path("<str:park_slug>/rides/<str:ride_slug>/", ParkRideDetailAPIView.as_view(), name="park-ride-detail"),
# Comprehensive park detail endpoint with rides summary # Comprehensive park detail endpoint with rides summary
path("<str:park_slug>/detail/", ParkComprehensiveDetailAPIView.as_view(), name="park-comprehensive-detail"), path("<str:park_slug>/detail/", ParkComprehensiveDetailAPIView.as_view(), name="park-comprehensive-detail"),
# Park image settings endpoint # Park image settings endpoint
path( path(
"<int:pk>/image-settings/", "<int:pk>/image-settings/",
@@ -95,33 +90,29 @@ urlpatterns = [
), ),
# Park photo endpoints - domain-specific photo management # Park photo endpoints - domain-specific photo management
path("<str:park_pk>/photos/", include(router.urls)), path("<str:park_pk>/photos/", include(router.urls)),
# Nested ride photo endpoints - photos for specific rides within parks # Nested ride photo endpoints - photos for specific rides within parks
path("<str:park_slug>/rides/<str:ride_slug>/photos/", include(ride_photos_router.urls)), path("<str:park_slug>/rides/<str:ride_slug>/photos/", include(ride_photos_router.urls)),
# Nested ride review endpoints - reviews for specific rides within parks # Nested ride review endpoints - reviews for specific rides within parks
path("<str:park_slug>/rides/<str:ride_slug>/reviews/", include(ride_reviews_router.urls)), path("<str:park_slug>/rides/<str:ride_slug>/reviews/", include(ride_reviews_router.urls)),
# Nested ride review endpoints - reviews for specific rides within parks # Nested ride review endpoints - reviews for specific rides within parks
path("<str:park_slug>/rides/<str:ride_slug>/reviews/", include(ride_reviews_router.urls)), path("<str:park_slug>/rides/<str:ride_slug>/reviews/", include(ride_reviews_router.urls)),
# Ride History # Ride History
path("<str:park_slug>/rides/<str:ride_slug>/history/", RideHistoryViewSet.as_view({'get': 'list'}), name="ride-history"), path(
"<str:park_slug>/rides/<str:ride_slug>/history/",
RideHistoryViewSet.as_view({"get": "list"}),
name="ride-history",
),
# Park Reviews # Park Reviews
path("<str:park_slug>/reviews/", include(reviews_router.urls)), path("<str:park_slug>/reviews/", include(reviews_router.urls)),
# Park History # Park History
path("<str:park_slug>/history/", ParkHistoryViewSet.as_view({'get': 'list'}), name="park-history"), path("<str:park_slug>/history/", ParkHistoryViewSet.as_view({"get": "list"}), name="park-history"),
# Roadtrip API endpoints # Roadtrip API endpoints
path("roadtrip/create/", CreateTripView.as_view(), name="roadtrip-create"), path("roadtrip/create/", CreateTripView.as_view(), name="roadtrip-create"),
path("roadtrip/find-along-route/", FindParksAlongRouteView.as_view(), name="roadtrip-find"), path("roadtrip/find-along-route/", FindParksAlongRouteView.as_view(), name="roadtrip-find"),
path("roadtrip/geocode/", GeocodeAddressView.as_view(), name="roadtrip-geocode"), path("roadtrip/geocode/", GeocodeAddressView.as_view(), name="roadtrip-geocode"),
path("roadtrip/distance/", ParkDistanceCalculatorView.as_view(), name="roadtrip-distance"), path("roadtrip/distance/", ParkDistanceCalculatorView.as_view(), name="roadtrip-distance"),
# Operator endpoints # Operator endpoints
path("operators/", OperatorListAPIView.as_view(), name="operator-list"), path("operators/", OperatorListAPIView.as_view(), name="operator-list"),
# Location search endpoints # Location search endpoints
path("search/location/", location_search, name="location-search"), path("search/location/", location_search, name="location-search"),
path("search/reverse-geocode/", reverse_geocode, name="reverse-geocode"), path("search/reverse-geocode/", reverse_geocode, name="reverse-geocode"),

View File

@@ -134,9 +134,7 @@ class ParkPhotoViewSet(ModelViewSet):
def get_queryset(self): # type: ignore[override] def get_queryset(self): # type: ignore[override]
"""Get photos for the current park with optimized queries.""" """Get photos for the current park with optimized queries."""
queryset = ParkPhoto.objects.select_related( queryset = ParkPhoto.objects.select_related("park", "park__operator", "uploaded_by")
"park", "park__operator", "uploaded_by"
)
# If park_pk is provided in URL kwargs, filter by park # If park_pk is provided in URL kwargs, filter by park
# If park_pk is provided in URL kwargs, filter by park # If park_pk is provided in URL kwargs, filter by park
@@ -172,7 +170,7 @@ class ParkPhotoViewSet(ModelViewSet):
# Use real park ID # Use real park ID
park_id = park.id park_id = park.id
except Park.DoesNotExist: except Park.DoesNotExist:
raise ValidationError("Park not found") raise ValidationError("Park not found") from None
try: try:
# Use the service to create the photo with proper business logic # Use the service to create the photo with proper business logic
@@ -188,48 +186,38 @@ class ParkPhotoViewSet(ModelViewSet):
except (ValidationException, ValidationError) as e: except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error creating park photo: {e}") logger.warning(f"Validation error creating park photo: {e}")
raise ValidationError(str(e)) raise ValidationError(str(e)) from None
except ServiceError as e: except ServiceError as e:
logger.error(f"Service error creating park photo: {e}") logger.error(f"Service error creating park photo: {e}")
raise ValidationError(f"Failed to create photo: {str(e)}") raise ValidationError(f"Failed to create photo: {str(e)}") from None
def perform_update(self, serializer): def perform_update(self, serializer):
"""Update park photo with permission checking.""" """Update park photo with permission checking."""
instance = self.get_object() instance = self.get_object()
# Check permissions - allow owner or staff # Check permissions - allow owner or staff
if not ( if not (self.request.user == instance.uploaded_by or cast(Any, self.request.user).is_staff):
self.request.user == instance.uploaded_by
or cast(Any, self.request.user).is_staff
):
raise PermissionDenied("You can only edit your own photos or be an admin.") raise PermissionDenied("You can only edit your own photos or be an admin.")
# Handle primary photo logic using service # Handle primary photo logic using service
if serializer.validated_data.get("is_primary", False): if serializer.validated_data.get("is_primary", False):
try: try:
ParkMediaService().set_primary_photo( ParkMediaService().set_primary_photo(park_id=instance.park_id, photo_id=instance.id)
park_id=instance.park_id, photo_id=instance.id
)
# Remove is_primary from validated_data since service handles it # Remove is_primary from validated_data since service handles it
if "is_primary" in serializer.validated_data: if "is_primary" in serializer.validated_data:
del serializer.validated_data["is_primary"] del serializer.validated_data["is_primary"]
except (ValidationException, ValidationError) as e: except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error setting primary photo: {e}") logger.warning(f"Validation error setting primary photo: {e}")
raise ValidationError(str(e)) raise ValidationError(str(e)) from None
except ServiceError as e: except ServiceError as e:
logger.error(f"Service error setting primary photo: {e}") logger.error(f"Service error setting primary photo: {e}")
raise ValidationError(f"Failed to set primary photo: {str(e)}") raise ValidationError(f"Failed to set primary photo: {str(e)}") from None
def perform_destroy(self, instance): def perform_destroy(self, instance):
"""Delete park photo with permission checking.""" """Delete park photo with permission checking."""
# Check permissions - allow owner or staff # Check permissions - allow owner or staff
if not ( if not (self.request.user == instance.uploaded_by or cast(Any, self.request.user).is_staff):
self.request.user == instance.uploaded_by raise PermissionDenied("You can only delete your own photos or be an admin.")
or cast(Any, self.request.user).is_staff
):
raise PermissionDenied(
"You can only delete your own photos or be an admin."
)
# Delete from Cloudflare first if image exists # Delete from Cloudflare first if image exists
if instance.image: if instance.image:
@@ -240,9 +228,7 @@ class ParkPhotoViewSet(ModelViewSet):
service = CloudflareImagesService() service = CloudflareImagesService()
service.delete_image(instance.image) service.delete_image(instance.image)
logger.info( logger.info(f"Successfully deleted park photo from Cloudflare: {instance.image.cloudflare_id}")
f"Successfully deleted park photo from Cloudflare: {instance.image.cloudflare_id}"
)
except ImportError: except ImportError:
logger.warning("CloudflareImagesService not available") logger.warning("CloudflareImagesService not available")
except ServiceError as e: except ServiceError as e:
@@ -250,12 +236,10 @@ class ParkPhotoViewSet(ModelViewSet):
# Continue with database deletion even if Cloudflare deletion fails # Continue with database deletion even if Cloudflare deletion fails
try: try:
ParkMediaService().delete_photo( ParkMediaService().delete_photo(instance.id, deleted_by=cast(UserModel, self.request.user))
instance.id, deleted_by=cast(UserModel, self.request.user)
)
except ServiceError as e: except ServiceError as e:
logger.error(f"Service error deleting park photo: {e}") logger.error(f"Service error deleting park photo: {e}")
raise ValidationError(f"Failed to delete photo: {str(e)}") raise ValidationError(f"Failed to delete photo: {str(e)}") from None
@extend_schema( @extend_schema(
summary="Set photo as primary", summary="Set photo as primary",
@@ -275,14 +259,10 @@ class ParkPhotoViewSet(ModelViewSet):
# Check permissions - allow owner or staff # Check permissions - allow owner or staff
if not (request.user == photo.uploaded_by or cast(Any, request.user).is_staff): if not (request.user == photo.uploaded_by or cast(Any, request.user).is_staff):
raise PermissionDenied( raise PermissionDenied("You can only modify your own photos or be an admin.")
"You can only modify your own photos or be an admin."
)
try: try:
ParkMediaService().set_primary_photo( ParkMediaService().set_primary_photo(park_id=photo.park_id, photo_id=photo.id)
park_id=photo.park_id, photo_id=photo.id
)
# Refresh the photo instance # Refresh the photo instance
photo.refresh_from_db() photo.refresh_from_db()
@@ -290,7 +270,7 @@ class ParkPhotoViewSet(ModelViewSet):
return Response( return Response(
{ {
"message": "Photo set as primary successfully", "detail": "Photo set as primary successfully",
"photo": serializer.data, "photo": serializer.data,
}, },
status=status.HTTP_200_OK, status=status.HTTP_200_OK,
@@ -337,7 +317,7 @@ class ParkPhotoViewSet(ModelViewSet):
if photo_ids is None or approve is None: if photo_ids is None or approve is None:
return Response( return Response(
{"error": "Missing required fields: photo_ids and/or approve."}, {"detail": "Missing required fields: photo_ids and/or approve."},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -354,7 +334,7 @@ class ParkPhotoViewSet(ModelViewSet):
return Response( return Response(
{ {
"message": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos", "detail": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos",
"updated_count": updated_count, "updated_count": updated_count,
}, },
status=status.HTTP_200_OK, status=status.HTTP_200_OK,
@@ -430,19 +410,14 @@ class ParkPhotoViewSet(ModelViewSet):
def set_primary_legacy(self, request, id=None): def set_primary_legacy(self, request, id=None):
"""Legacy set primary action for backwards compatibility.""" """Legacy set primary action for backwards compatibility."""
photo = self.get_object() photo = self.get_object()
if not ( if not (request.user == photo.uploaded_by or request.user.has_perm("parks.change_parkphoto")):
request.user == photo.uploaded_by
or request.user.has_perm("parks.change_parkphoto")
):
return Response( return Response(
{"error": "You do not have permission to edit photos for this park."}, {"detail": "You do not have permission to edit photos for this park."},
status=status.HTTP_403_FORBIDDEN, status=status.HTTP_403_FORBIDDEN,
) )
try: try:
ParkMediaService().set_primary_photo( ParkMediaService().set_primary_photo(park_id=photo.park_id, photo_id=photo.id)
park_id=photo.park_id, photo_id=photo.id return Response({"detail": "Photo set as primary successfully."})
)
return Response({"message": "Photo set as primary successfully."})
except (ValidationException, ValidationError) as e: except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error in set_primary_photo: {str(e)}") logger.warning(f"Validation error in set_primary_photo: {str(e)}")
return ErrorHandler.handle_api_error( return ErrorHandler.handle_api_error(
@@ -475,7 +450,7 @@ class ParkPhotoViewSet(ModelViewSet):
park_pk = self.kwargs.get("park_pk") park_pk = self.kwargs.get("park_pk")
if not park_pk: if not park_pk:
return Response( return Response(
{"error": "Park ID is required"}, {"detail": "Park ID is required"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -483,14 +458,14 @@ class ParkPhotoViewSet(ModelViewSet):
park = Park.objects.get(pk=park_pk) if str(park_pk).isdigit() else Park.objects.get(slug=park_pk) park = Park.objects.get(pk=park_pk) if str(park_pk).isdigit() else Park.objects.get(slug=park_pk)
except Park.DoesNotExist: except Park.DoesNotExist:
return Response( return Response(
{"error": "Park not found"}, {"detail": "Park not found"},
status=status.HTTP_404_NOT_FOUND, status=status.HTTP_404_NOT_FOUND,
) )
cloudflare_image_id = request.data.get("cloudflare_image_id") cloudflare_image_id = request.data.get("cloudflare_image_id")
if not cloudflare_image_id: if not cloudflare_image_id:
return Response( return Response(
{"error": "cloudflare_image_id is required"}, {"detail": "cloudflare_image_id is required"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -515,18 +490,14 @@ class ParkPhotoViewSet(ModelViewSet):
# Try to find existing CloudflareImage record by cloudflare_id # Try to find existing CloudflareImage record by cloudflare_id
cloudflare_image = None cloudflare_image = None
try: try:
cloudflare_image = CloudflareImage.objects.get( cloudflare_image = CloudflareImage.objects.get(cloudflare_id=cloudflare_image_id)
cloudflare_id=cloudflare_image_id
)
# Update existing record with latest data from Cloudflare # Update existing record with latest data from Cloudflare
cloudflare_image.status = "uploaded" cloudflare_image.status = "uploaded"
cloudflare_image.uploaded_at = timezone.now() cloudflare_image.uploaded_at = timezone.now()
cloudflare_image.metadata = image_data.get("meta", {}) cloudflare_image.metadata = image_data.get("meta", {})
# Extract variants from nested result structure # Extract variants from nested result structure
cloudflare_image.variants = image_data.get("result", {}).get( cloudflare_image.variants = image_data.get("result", {}).get("variants", [])
"variants", []
)
cloudflare_image.cloudflare_metadata = image_data cloudflare_image.cloudflare_metadata = image_data
cloudflare_image.width = image_data.get("width") cloudflare_image.width = image_data.get("width")
cloudflare_image.height = image_data.get("height") cloudflare_image.height = image_data.get("height")
@@ -540,8 +511,7 @@ class ParkPhotoViewSet(ModelViewSet):
user=request.user, user=request.user,
status="uploaded", status="uploaded",
upload_url="", # Not needed for uploaded images upload_url="", # Not needed for uploaded images
expires_at=timezone.now() expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry
+ timezone.timedelta(days=365), # Set far future expiry
uploaded_at=timezone.now(), uploaded_at=timezone.now(),
metadata=image_data.get("meta", {}), metadata=image_data.get("meta", {}),
# Extract variants from nested result structure # Extract variants from nested result structure
@@ -567,9 +537,7 @@ class ParkPhotoViewSet(ModelViewSet):
# Handle primary photo logic if requested # Handle primary photo logic if requested
if request.data.get("is_primary", False): if request.data.get("is_primary", False):
try: try:
ParkMediaService().set_primary_photo( ParkMediaService().set_primary_photo(park_id=park.id, photo_id=photo.id)
park_id=park.id, photo_id=photo.id
)
except ServiceError as e: except ServiceError as e:
logger.error(f"Error setting primary photo: {e}") logger.error(f"Error setting primary photo: {e}")
# Don't fail the entire operation, just log the error # Don't fail the entire operation, just log the error
@@ -624,12 +592,8 @@ class ParkPhotoViewSet(ModelViewSet):
OpenApiTypes.STR, OpenApiTypes.STR,
description="Filter by state (comma-separated for multiple)", description="Filter by state (comma-separated for multiple)",
), ),
OpenApiParameter( OpenApiParameter("opening_year_min", OpenApiTypes.INT, description="Minimum opening year"),
"opening_year_min", OpenApiTypes.INT, description="Minimum opening year" OpenApiParameter("opening_year_max", OpenApiTypes.INT, description="Maximum opening year"),
),
OpenApiParameter(
"opening_year_max", OpenApiTypes.INT, description="Maximum opening year"
),
OpenApiParameter( OpenApiParameter(
"size_min", "size_min",
OpenApiTypes.NUMBER, OpenApiTypes.NUMBER,
@@ -640,18 +604,10 @@ class ParkPhotoViewSet(ModelViewSet):
OpenApiTypes.NUMBER, OpenApiTypes.NUMBER,
description="Maximum park size in acres", description="Maximum park size in acres",
), ),
OpenApiParameter( OpenApiParameter("rating_min", OpenApiTypes.NUMBER, description="Minimum average rating"),
"rating_min", OpenApiTypes.NUMBER, description="Minimum average rating" OpenApiParameter("rating_max", OpenApiTypes.NUMBER, description="Maximum average rating"),
), OpenApiParameter("ride_count_min", OpenApiTypes.INT, description="Minimum ride count"),
OpenApiParameter( OpenApiParameter("ride_count_max", OpenApiTypes.INT, description="Maximum ride count"),
"rating_max", OpenApiTypes.NUMBER, description="Maximum average rating"
),
OpenApiParameter(
"ride_count_min", OpenApiTypes.INT, description="Minimum ride count"
),
OpenApiParameter(
"ride_count_max", OpenApiTypes.INT, description="Maximum ride count"
),
OpenApiParameter( OpenApiParameter(
"coaster_count_min", "coaster_count_min",
OpenApiTypes.INT, OpenApiTypes.INT,
@@ -688,9 +644,7 @@ class ParkPhotoViewSet(ModelViewSet):
"properties": { "properties": {
"parks": { "parks": {
"type": "array", "type": "array",
"items": { "items": {"$ref": "#/components/schemas/HybridParkSerializer"},
"$ref": "#/components/schemas/HybridParkSerializer"
},
}, },
"total_count": {"type": "integer"}, "total_count": {"type": "integer"},
"strategy": { "strategy": {
@@ -808,7 +762,7 @@ class HybridParkAPIView(APIView):
for param in int_params: for param in int_params:
value = query_params.get(param) value = query_params.get(param)
if value: if value:
try: try: # noqa: SIM105
filters[param] = int(value) filters[param] = int(value)
except ValueError: except ValueError:
pass # Skip invalid integer values pass # Skip invalid integer values
@@ -818,7 +772,7 @@ class HybridParkAPIView(APIView):
for param in float_params: for param in float_params:
value = query_params.get(param) value = query_params.get(param)
if value: if value:
try: try: # noqa: SIM105
filters[param] = float(value) filters[param] = float(value)
except ValueError: except ValueError:
pass # Skip invalid float values pass # Skip invalid float values

View File

@@ -0,0 +1,167 @@
"""
Standardized API response helpers for ThrillWiki.
This module provides consistent response formatting across all API endpoints:
Success responses:
- Action completed: {"detail": "Success message"}
- With data: {"detail": "...", "data": {...}}
Error responses:
- Validation: {"field": ["error"]} (DRF default)
- Application: {"detail": "Error message", "code": "ERROR_CODE"}
Usage:
from apps.api.v1.responses import success_response, error_response
# Success
return success_response("Avatar saved successfully")
# Error
return error_response("User not found", code="NOT_FOUND", status_code=404)
"""
from rest_framework import status
from rest_framework.response import Response
# Standard error codes for machine-readable error handling
class ErrorCodes:
"""Standard error codes for API responses."""
# Authentication / Authorization
UNAUTHORIZED = "UNAUTHORIZED"
FORBIDDEN = "FORBIDDEN"
INVALID_CREDENTIALS = "INVALID_CREDENTIALS"
TOKEN_EXPIRED = "TOKEN_EXPIRED"
TOKEN_INVALID = "TOKEN_INVALID"
# Resource errors
NOT_FOUND = "NOT_FOUND"
ALREADY_EXISTS = "ALREADY_EXISTS"
CONFLICT = "CONFLICT"
# Validation errors
VALIDATION_ERROR = "VALIDATION_ERROR"
INVALID_INPUT = "INVALID_INPUT"
MISSING_FIELD = "MISSING_FIELD"
# Operation errors
OPERATION_FAILED = "OPERATION_FAILED"
PERMISSION_DENIED = "PERMISSION_DENIED"
RATE_LIMITED = "RATE_LIMITED"
# User-specific errors
USER_NOT_FOUND = "USER_NOT_FOUND"
USER_INACTIVE = "USER_INACTIVE"
USER_BANNED = "USER_BANNED"
CANNOT_DELETE_SUPERUSER = "CANNOT_DELETE_SUPERUSER"
CANNOT_DELETE_SELF = "CANNOT_DELETE_SELF"
# Verification errors
VERIFICATION_EXPIRED = "VERIFICATION_EXPIRED"
VERIFICATION_INVALID = "VERIFICATION_INVALID"
ALREADY_VERIFIED = "ALREADY_VERIFIED"
# External service errors
EXTERNAL_SERVICE_ERROR = "EXTERNAL_SERVICE_ERROR"
CLOUDFLARE_ERROR = "CLOUDFLARE_ERROR"
def success_response(
detail: str,
data: dict | None = None,
status_code: int = status.HTTP_200_OK,
) -> Response:
"""
Create a standardized success response.
Args:
detail: Human-readable success message
data: Optional additional data to include
status_code: HTTP status code (default 200)
Returns:
DRF Response object
Example:
return success_response("Avatar saved successfully")
return success_response("User created", data={"id": user.id}, status_code=201)
"""
response_data = {"detail": detail}
if data:
response_data.update(data)
return Response(response_data, status=status_code)
def error_response(
detail: str,
code: str | None = None,
status_code: int = status.HTTP_400_BAD_REQUEST,
extra: dict | None = None,
) -> Response:
"""
Create a standardized error response.
Args:
detail: Human-readable error message
code: Machine-readable error code from ErrorCodes
status_code: HTTP status code (default 400)
extra: Optional additional data to include
Returns:
DRF Response object
Example:
return error_response("User not found", code=ErrorCodes.NOT_FOUND, status_code=404)
return error_response("Invalid input", code=ErrorCodes.VALIDATION_ERROR)
"""
response_data = {"detail": detail}
if code:
response_data["code"] = code
if extra:
response_data.update(extra)
return Response(response_data, status=status_code)
def created_response(detail: str, data: dict | None = None) -> Response:
"""Convenience wrapper for 201 Created responses."""
return success_response(detail, data=data, status_code=status.HTTP_201_CREATED)
def not_found_response(detail: str = "Resource not found") -> Response:
"""Convenience wrapper for 404 Not Found responses."""
return error_response(
detail,
code=ErrorCodes.NOT_FOUND,
status_code=status.HTTP_404_NOT_FOUND,
)
def forbidden_response(detail: str = "Permission denied") -> Response:
"""Convenience wrapper for 403 Forbidden responses."""
return error_response(
detail,
code=ErrorCodes.FORBIDDEN,
status_code=status.HTTP_403_FORBIDDEN,
)
def unauthorized_response(detail: str = "Authentication required") -> Response:
"""Convenience wrapper for 401 Unauthorized responses."""
return error_response(
detail,
code=ErrorCodes.UNAUTHORIZED,
status_code=status.HTTP_401_UNAUTHORIZED,
)
__all__ = [
"ErrorCodes",
"success_response",
"error_response",
"created_response",
"not_found_response",
"forbidden_response",
"unauthorized_response",
]

View File

@@ -24,6 +24,7 @@ from apps.api.v1.serializers.companies import (
try: try:
from apps.rides.models.company import Company from apps.rides.models.company import Company
MODELS_AVAILABLE = True MODELS_AVAILABLE = True
except ImportError: except ImportError:
Company = None Company = None
@@ -65,9 +66,7 @@ class CompanyListCreateAPIView(APIView):
# Search filter # Search filter
search = request.query_params.get("search", "") search = request.query_params.get("search", "")
if search: if search:
qs = qs.filter( qs = qs.filter(Q(name__icontains=search) | Q(description__icontains=search))
Q(name__icontains=search) | Q(description__icontains=search)
)
# Role filter # Role filter
role = request.query_params.get("role", "") role = request.query_params.get("role", "")
@@ -120,7 +119,7 @@ class CompanyDetailAPIView(APIView):
try: try:
return Company.objects.get(pk=pk) return Company.objects.get(pk=pk)
except Company.DoesNotExist: except Company.DoesNotExist:
raise NotFound("Company not found") raise NotFound("Company not found") from None
@extend_schema( @extend_schema(
summary="Retrieve a company", summary="Retrieve a company",

View File

@@ -93,18 +93,10 @@ class RideModelListCreateAPIView(APIView):
type=OpenApiTypes.STR, type=OpenApiTypes.STR,
required=True, required=True,
), ),
OpenApiParameter( OpenApiParameter(name="page", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT),
name="page", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT OpenApiParameter(name="page_size", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT),
), OpenApiParameter(name="search", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR),
OpenApiParameter( OpenApiParameter(name="category", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR),
name="page_size", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT
),
OpenApiParameter(
name="search", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR
),
OpenApiParameter(
name="category", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR
),
OpenApiParameter( OpenApiParameter(
name="target_market", name="target_market",
location=OpenApiParameter.QUERY, location=OpenApiParameter.QUERY,
@@ -134,7 +126,7 @@ class RideModelListCreateAPIView(APIView):
try: try:
manufacturer = Company.objects.get(slug=manufacturer_slug) manufacturer = Company.objects.get(slug=manufacturer_slug)
except Company.DoesNotExist: except Company.DoesNotExist:
raise NotFound("Manufacturer not found") raise NotFound("Manufacturer not found") from None
qs = ( qs = (
RideModel.objects.filter(manufacturer=manufacturer) RideModel.objects.filter(manufacturer=manufacturer)
@@ -176,13 +168,9 @@ class RideModelListCreateAPIView(APIView):
# Year filters # Year filters
if filters.get("first_installation_year_min"): if filters.get("first_installation_year_min"):
qs = qs.filter( qs = qs.filter(first_installation_year__gte=filters["first_installation_year_min"])
first_installation_year__gte=filters["first_installation_year_min"]
)
if filters.get("first_installation_year_max"): if filters.get("first_installation_year_max"):
qs = qs.filter( qs = qs.filter(first_installation_year__lte=filters["first_installation_year_max"])
first_installation_year__lte=filters["first_installation_year_max"]
)
# Installation count filter # Installation count filter
if filters.get("min_installations"): if filters.get("min_installations"):
@@ -190,23 +178,15 @@ class RideModelListCreateAPIView(APIView):
# Height filters # Height filters
if filters.get("min_height_ft"): if filters.get("min_height_ft"):
qs = qs.filter( qs = qs.filter(typical_height_range_max_ft__gte=filters["min_height_ft"])
typical_height_range_max_ft__gte=filters["min_height_ft"]
)
if filters.get("max_height_ft"): if filters.get("max_height_ft"):
qs = qs.filter( qs = qs.filter(typical_height_range_min_ft__lte=filters["max_height_ft"])
typical_height_range_min_ft__lte=filters["max_height_ft"]
)
# Speed filters # Speed filters
if filters.get("min_speed_mph"): if filters.get("min_speed_mph"):
qs = qs.filter( qs = qs.filter(typical_speed_range_max_mph__gte=filters["min_speed_mph"])
typical_speed_range_max_mph__gte=filters["min_speed_mph"]
)
if filters.get("max_speed_mph"): if filters.get("max_speed_mph"):
qs = qs.filter( qs = qs.filter(typical_speed_range_min_mph__lte=filters["max_speed_mph"])
typical_speed_range_min_mph__lte=filters["max_speed_mph"]
)
# Ordering # Ordering
ordering = filters.get("ordering", "manufacturer__name,name") ordering = filters.get("ordering", "manufacturer__name,name")
@@ -216,9 +196,7 @@ class RideModelListCreateAPIView(APIView):
paginator = StandardResultsSetPagination() paginator = StandardResultsSetPagination()
page = paginator.paginate_queryset(qs, request) page = paginator.paginate_queryset(qs, request)
serializer = RideModelListOutputSerializer( serializer = RideModelListOutputSerializer(page, many=True, context={"request": request})
page, many=True, context={"request": request}
)
return paginator.get_paginated_response(serializer.data) return paginator.get_paginated_response(serializer.data)
@extend_schema( @extend_schema(
@@ -240,9 +218,7 @@ class RideModelListCreateAPIView(APIView):
"""Create a new ride model for a specific manufacturer.""" """Create a new ride model for a specific manufacturer."""
if not MODELS_AVAILABLE: if not MODELS_AVAILABLE:
return Response( return Response(
{ {"detail": "Ride model creation is not available because domain models are not imported."},
"detail": "Ride model creation is not available because domain models are not imported."
},
status=status.HTTP_501_NOT_IMPLEMENTED, status=status.HTTP_501_NOT_IMPLEMENTED,
) )
@@ -250,7 +226,7 @@ class RideModelListCreateAPIView(APIView):
try: try:
manufacturer = Company.objects.get(slug=manufacturer_slug) manufacturer = Company.objects.get(slug=manufacturer_slug)
except Company.DoesNotExist: except Company.DoesNotExist:
raise NotFound("Manufacturer not found") raise NotFound("Manufacturer not found") from None
serializer_in = RideModelCreateInputSerializer(data=request.data) serializer_in = RideModelCreateInputSerializer(data=request.data)
serializer_in.is_valid(raise_exception=True) serializer_in.is_valid(raise_exception=True)
@@ -279,18 +255,14 @@ class RideModelListCreateAPIView(APIView):
target_market=validated.get("target_market", ""), target_market=validated.get("target_market", ""),
) )
out_serializer = RideModelDetailOutputSerializer( out_serializer = RideModelDetailOutputSerializer(ride_model, context={"request": request})
ride_model, context={"request": request}
)
return Response(out_serializer.data, status=status.HTTP_201_CREATED) return Response(out_serializer.data, status=status.HTTP_201_CREATED)
class RideModelDetailAPIView(APIView): class RideModelDetailAPIView(APIView):
permission_classes = [permissions.AllowAny] permission_classes = [permissions.AllowAny]
def _get_ride_model_or_404( def _get_ride_model_or_404(self, manufacturer_slug: str, ride_model_slug: str) -> Any:
self, manufacturer_slug: str, ride_model_slug: str
) -> Any:
if not MODELS_AVAILABLE: if not MODELS_AVAILABLE:
raise NotFound("Ride model models not available") raise NotFound("Ride model models not available")
try: try:
@@ -300,7 +272,7 @@ class RideModelDetailAPIView(APIView):
.get(manufacturer__slug=manufacturer_slug, slug=ride_model_slug) .get(manufacturer__slug=manufacturer_slug, slug=ride_model_slug)
) )
except RideModel.DoesNotExist: except RideModel.DoesNotExist:
raise NotFound("Ride model not found") raise NotFound("Ride model not found") from None
@extend_schema( @extend_schema(
summary="Retrieve a ride model", summary="Retrieve a ride model",
@@ -322,13 +294,9 @@ class RideModelDetailAPIView(APIView):
responses={200: RideModelDetailOutputSerializer()}, responses={200: RideModelDetailOutputSerializer()},
tags=["Ride Models"], tags=["Ride Models"],
) )
def get( def get(self, request: Request, manufacturer_slug: str, ride_model_slug: str) -> Response:
self, request: Request, manufacturer_slug: str, ride_model_slug: str
) -> Response:
ride_model = self._get_ride_model_or_404(manufacturer_slug, ride_model_slug) ride_model = self._get_ride_model_or_404(manufacturer_slug, ride_model_slug)
serializer = RideModelDetailOutputSerializer( serializer = RideModelDetailOutputSerializer(ride_model, context={"request": request})
ride_model, context={"request": request}
)
return Response(serializer.data) return Response(serializer.data)
@extend_schema( @extend_schema(
@@ -352,9 +320,7 @@ class RideModelDetailAPIView(APIView):
responses={200: RideModelDetailOutputSerializer()}, responses={200: RideModelDetailOutputSerializer()},
tags=["Ride Models"], tags=["Ride Models"],
) )
def patch( def patch(self, request: Request, manufacturer_slug: str, ride_model_slug: str) -> Response:
self, request: Request, manufacturer_slug: str, ride_model_slug: str
) -> Response:
ride_model = self._get_ride_model_or_404(manufacturer_slug, ride_model_slug) ride_model = self._get_ride_model_or_404(manufacturer_slug, ride_model_slug)
serializer_in = RideModelUpdateInputSerializer(data=request.data, partial=True) serializer_in = RideModelUpdateInputSerializer(data=request.data, partial=True)
serializer_in.is_valid(raise_exception=True) serializer_in.is_valid(raise_exception=True)
@@ -366,20 +332,16 @@ class RideModelDetailAPIView(APIView):
manufacturer = Company.objects.get(id=value) manufacturer = Company.objects.get(id=value)
ride_model.manufacturer = manufacturer ride_model.manufacturer = manufacturer
except Company.DoesNotExist: except Company.DoesNotExist:
raise ValidationError({"manufacturer_id": "Manufacturer not found"}) raise ValidationError({"manufacturer_id": "Manufacturer not found"}) from None
else: else:
setattr(ride_model, field, value) setattr(ride_model, field, value)
ride_model.save() ride_model.save()
serializer = RideModelDetailOutputSerializer( serializer = RideModelDetailOutputSerializer(ride_model, context={"request": request})
ride_model, context={"request": request}
)
return Response(serializer.data) return Response(serializer.data)
def put( def put(self, request: Request, manufacturer_slug: str, ride_model_slug: str) -> Response:
self, request: Request, manufacturer_slug: str, ride_model_slug: str
) -> Response:
# Full replace - reuse patch behavior for simplicity # Full replace - reuse patch behavior for simplicity
return self.patch(request, manufacturer_slug, ride_model_slug) return self.patch(request, manufacturer_slug, ride_model_slug)
@@ -403,9 +365,7 @@ class RideModelDetailAPIView(APIView):
responses={204: None}, responses={204: None},
tags=["Ride Models"], tags=["Ride Models"],
) )
def delete( def delete(self, request: Request, manufacturer_slug: str, ride_model_slug: str) -> Response:
self, request: Request, manufacturer_slug: str, ride_model_slug: str
) -> Response:
ride_model = self._get_ride_model_or_404(manufacturer_slug, ride_model_slug) ride_model = self._get_ride_model_or_404(manufacturer_slug, ride_model_slug)
ride_model.delete() ride_model.delete()
return Response(status=status.HTTP_204_NO_CONTENT) return Response(status=status.HTTP_204_NO_CONTENT)
@@ -449,9 +409,7 @@ class RideModelSearchAPIView(APIView):
) )
qs = RideModel.objects.filter( qs = RideModel.objects.filter(
Q(name__icontains=q) Q(name__icontains=q) | Q(description__icontains=q) | Q(manufacturer__name__icontains=q)
| Q(description__icontains=q)
| Q(manufacturer__name__icontains=q)
).select_related("manufacturer")[:20] ).select_related("manufacturer")[:20]
results = [ results = [
@@ -491,8 +449,8 @@ class RideModelFilterOptionsAPIView(APIView):
# Use Rich Choice Objects for fallback options # Use Rich Choice Objects for fallback options
try: try:
# Get rich choice objects from registry # Get rich choice objects from registry
categories = get_choices('categories', 'rides') categories = get_choices("categories", "rides")
target_markets = get_choices('target_markets', 'rides') target_markets = get_choices("target_markets", "rides")
# Convert Rich Choice Objects to frontend format with metadata # Convert Rich Choice Objects to frontend format with metadata
categories_data = [ categories_data = [
@@ -500,10 +458,10 @@ class RideModelFilterOptionsAPIView(APIView):
"value": choice.value, "value": choice.value,
"label": choice.label, "label": choice.label,
"description": choice.description, "description": choice.description,
"color": choice.metadata.get('color'), "color": choice.metadata.get("color"),
"icon": choice.metadata.get('icon'), "icon": choice.metadata.get("icon"),
"css_class": choice.metadata.get('css_class'), "css_class": choice.metadata.get("css_class"),
"sort_order": choice.metadata.get('sort_order', 0) "sort_order": choice.metadata.get("sort_order", 0),
} }
for choice in categories for choice in categories
] ]
@@ -513,10 +471,10 @@ class RideModelFilterOptionsAPIView(APIView):
"value": choice.value, "value": choice.value,
"label": choice.label, "label": choice.label,
"description": choice.description, "description": choice.description,
"color": choice.metadata.get('color'), "color": choice.metadata.get("color"),
"icon": choice.metadata.get('icon'), "icon": choice.metadata.get("icon"),
"css_class": choice.metadata.get('css_class'), "css_class": choice.metadata.get("css_class"),
"sort_order": choice.metadata.get('sort_order', 0) "sort_order": choice.metadata.get("sort_order", 0),
} }
for choice in target_markets for choice in target_markets
] ]
@@ -524,25 +482,173 @@ class RideModelFilterOptionsAPIView(APIView):
except Exception: except Exception:
# Ultimate fallback with basic structure # Ultimate fallback with basic structure
categories_data = [ categories_data = [
{"value": "RC", "label": "Roller Coaster", "description": "High-speed thrill rides with tracks", "color": "red", "icon": "roller-coaster", "css_class": "bg-red-100 text-red-800", "sort_order": 1}, {
{"value": "DR", "label": "Dark Ride", "description": "Indoor themed experiences", "color": "purple", "icon": "dark-ride", "css_class": "bg-purple-100 text-purple-800", "sort_order": 2}, "value": "RC",
{"value": "FR", "label": "Flat Ride", "description": "Spinning and rotating attractions", "color": "blue", "icon": "flat-ride", "css_class": "bg-blue-100 text-blue-800", "sort_order": 3}, "label": "Roller Coaster",
{"value": "WR", "label": "Water Ride", "description": "Water-based attractions and slides", "color": "cyan", "icon": "water-ride", "css_class": "bg-cyan-100 text-cyan-800", "sort_order": 4}, "description": "High-speed thrill rides with tracks",
{"value": "TR", "label": "Transport", "description": "Transportation systems within parks", "color": "green", "icon": "transport", "css_class": "bg-green-100 text-green-800", "sort_order": 5}, "color": "red",
{"value": "OT", "label": "Other", "description": "Miscellaneous attractions", "color": "gray", "icon": "other", "css_class": "bg-gray-100 text-gray-800", "sort_order": 6}, "icon": "roller-coaster",
"css_class": "bg-red-100 text-red-800",
"sort_order": 1,
},
{
"value": "DR",
"label": "Dark Ride",
"description": "Indoor themed experiences",
"color": "purple",
"icon": "dark-ride",
"css_class": "bg-purple-100 text-purple-800",
"sort_order": 2,
},
{
"value": "FR",
"label": "Flat Ride",
"description": "Spinning and rotating attractions",
"color": "blue",
"icon": "flat-ride",
"css_class": "bg-blue-100 text-blue-800",
"sort_order": 3,
},
{
"value": "WR",
"label": "Water Ride",
"description": "Water-based attractions and slides",
"color": "cyan",
"icon": "water-ride",
"css_class": "bg-cyan-100 text-cyan-800",
"sort_order": 4,
},
{
"value": "TR",
"label": "Transport",
"description": "Transportation systems within parks",
"color": "green",
"icon": "transport",
"css_class": "bg-green-100 text-green-800",
"sort_order": 5,
},
{
"value": "OT",
"label": "Other",
"description": "Miscellaneous attractions",
"color": "gray",
"icon": "other",
"css_class": "bg-gray-100 text-gray-800",
"sort_order": 6,
},
] ]
target_markets_data = [ target_markets_data = [
{"value": "FAMILY", "label": "Family", "description": "Suitable for all family members", "color": "green", "icon": "family", "css_class": "bg-green-100 text-green-800", "sort_order": 1}, {
{"value": "THRILL", "label": "Thrill", "description": "High-intensity thrill experience", "color": "orange", "icon": "thrill", "css_class": "bg-orange-100 text-orange-800", "sort_order": 2}, "value": "FAMILY",
{"value": "EXTREME", "label": "Extreme", "description": "Maximum intensity experience", "color": "red", "icon": "extreme", "css_class": "bg-red-100 text-red-800", "sort_order": 3}, "label": "Family",
{"value": "KIDDIE", "label": "Kiddie", "description": "Designed for young children", "color": "pink", "icon": "kiddie", "css_class": "bg-pink-100 text-pink-800", "sort_order": 4}, "description": "Suitable for all family members",
{"value": "ALL_AGES", "label": "All Ages", "description": "Enjoyable for all age groups", "color": "blue", "icon": "all-ages", "css_class": "bg-blue-100 text-blue-800", "sort_order": 5}, "color": "green",
"icon": "family",
"css_class": "bg-green-100 text-green-800",
"sort_order": 1,
},
{
"value": "THRILL",
"label": "Thrill",
"description": "High-intensity thrill experience",
"color": "orange",
"icon": "thrill",
"css_class": "bg-orange-100 text-orange-800",
"sort_order": 2,
},
{
"value": "EXTREME",
"label": "Extreme",
"description": "Maximum intensity experience",
"color": "red",
"icon": "extreme",
"css_class": "bg-red-100 text-red-800",
"sort_order": 3,
},
{
"value": "KIDDIE",
"label": "Kiddie",
"description": "Designed for young children",
"color": "pink",
"icon": "kiddie",
"css_class": "bg-pink-100 text-pink-800",
"sort_order": 4,
},
{
"value": "ALL_AGES",
"label": "All Ages",
"description": "Enjoyable for all age groups",
"color": "blue",
"icon": "all-ages",
"css_class": "bg-blue-100 text-blue-800",
"sort_order": 5,
},
] ]
return Response({ return Response(
{
"categories": categories_data,
"target_markets": target_markets_data,
"manufacturers": [{"id": 1, "name": "Bolliger & Mabillard", "slug": "bolliger-mabillard"}],
"ordering_options": [
{"value": "name", "label": "Name A-Z"},
{"value": "-name", "label": "Name Z-A"},
{"value": "manufacturer__name", "label": "Manufacturer A-Z"},
{"value": "-manufacturer__name", "label": "Manufacturer Z-A"},
{"value": "first_installation_year", "label": "Oldest First"},
{"value": "-first_installation_year", "label": "Newest First"},
{"value": "total_installations", "label": "Fewest Installations"},
{"value": "-total_installations", "label": "Most Installations"},
],
}
)
# Get static choice definitions from Rich Choice Objects (primary source)
# Get dynamic data from database queries
# Get rich choice objects from registry
categories = get_choices("categories", "rides")
target_markets = get_choices("target_markets", "rides")
# Convert Rich Choice Objects to frontend format with metadata
categories_data = [
{
"value": choice.value,
"label": choice.label,
"description": choice.description,
"color": choice.metadata.get("color"),
"icon": choice.metadata.get("icon"),
"css_class": choice.metadata.get("css_class"),
"sort_order": choice.metadata.get("sort_order", 0),
}
for choice in categories
]
target_markets_data = [
{
"value": choice.value,
"label": choice.label,
"description": choice.description,
"color": choice.metadata.get("color"),
"icon": choice.metadata.get("icon"),
"css_class": choice.metadata.get("css_class"),
"sort_order": choice.metadata.get("sort_order", 0),
}
for choice in target_markets
]
# Get actual data from database
manufacturers = (
Company.objects.filter(roles__contains=["MANUFACTURER"], ride_models__isnull=False)
.distinct()
.values("id", "name", "slug")
)
return Response(
{
"categories": categories_data, "categories": categories_data,
"target_markets": target_markets_data, "target_markets": target_markets_data,
"manufacturers": [{"id": 1, "name": "Bolliger & Mabillard", "slug": "bolliger-mabillard"}], "manufacturers": list(manufacturers),
"ordering_options": [ "ordering_options": [
{"value": "name", "label": "Name A-Z"}, {"value": "name", "label": "Name A-Z"},
{"value": "-name", "label": "Name Z-A"}, {"value": "-name", "label": "Name Z-A"},
@@ -553,68 +659,9 @@ class RideModelFilterOptionsAPIView(APIView):
{"value": "total_installations", "label": "Fewest Installations"}, {"value": "total_installations", "label": "Fewest Installations"},
{"value": "-total_installations", "label": "Most Installations"}, {"value": "-total_installations", "label": "Most Installations"},
], ],
})
# Get static choice definitions from Rich Choice Objects (primary source)
# Get dynamic data from database queries
# Get rich choice objects from registry
categories = get_choices('categories', 'rides')
target_markets = get_choices('target_markets', 'rides')
# Convert Rich Choice Objects to frontend format with metadata
categories_data = [
{
"value": choice.value,
"label": choice.label,
"description": choice.description,
"color": choice.metadata.get('color'),
"icon": choice.metadata.get('icon'),
"css_class": choice.metadata.get('css_class'),
"sort_order": choice.metadata.get('sort_order', 0)
} }
for choice in categories
]
target_markets_data = [
{
"value": choice.value,
"label": choice.label,
"description": choice.description,
"color": choice.metadata.get('color'),
"icon": choice.metadata.get('icon'),
"css_class": choice.metadata.get('css_class'),
"sort_order": choice.metadata.get('sort_order', 0)
}
for choice in target_markets
]
# Get actual data from database
manufacturers = (
Company.objects.filter(
roles__contains=["MANUFACTURER"], ride_models__isnull=False
)
.distinct()
.values("id", "name", "slug")
) )
return Response({
"categories": categories_data,
"target_markets": target_markets_data,
"manufacturers": list(manufacturers),
"ordering_options": [
{"value": "name", "label": "Name A-Z"},
{"value": "-name", "label": "Name Z-A"},
{"value": "manufacturer__name", "label": "Manufacturer A-Z"},
{"value": "-manufacturer__name", "label": "Manufacturer Z-A"},
{"value": "first_installation_year", "label": "Oldest First"},
{"value": "-first_installation_year", "label": "Newest First"},
{"value": "total_installations", "label": "Fewest Installations"},
{"value": "-total_installations", "label": "Most Installations"},
],
})
# === RIDE MODEL STATISTICS === # === RIDE MODEL STATISTICS ===
@@ -646,37 +693,23 @@ class RideModelStatsAPIView(APIView):
# Calculate statistics # Calculate statistics
total_models = RideModel.objects.count() total_models = RideModel.objects.count()
total_installations = ( total_installations = RideModel.objects.aggregate(total=Count("rides"))["total"] or 0
RideModel.objects.aggregate(total=Count("rides"))["total"] or 0
)
active_manufacturers = ( active_manufacturers = (
Company.objects.filter( Company.objects.filter(roles__contains=["MANUFACTURER"], ride_models__isnull=False).distinct().count()
roles__contains=["MANUFACTURER"], ride_models__isnull=False
)
.distinct()
.count()
) )
discontinued_models = RideModel.objects.filter(is_discontinued=True).count() discontinued_models = RideModel.objects.filter(is_discontinued=True).count()
# Category breakdown # Category breakdown
by_category = {} by_category = {}
category_counts = ( category_counts = RideModel.objects.exclude(category="").values("category").annotate(count=Count("id"))
RideModel.objects.exclude(category="")
.values("category")
.annotate(count=Count("id"))
)
for item in category_counts: for item in category_counts:
by_category[item["category"]] = item["count"] by_category[item["category"]] = item["count"]
# Target market breakdown # Target market breakdown
by_target_market = {} by_target_market = {}
market_counts = ( market_counts = RideModel.objects.exclude(target_market="").values("target_market").annotate(count=Count("id"))
RideModel.objects.exclude(target_market="")
.values("target_market")
.annotate(count=Count("id"))
)
for item in market_counts: for item in market_counts:
by_target_market[item["target_market"]] = item["count"] by_target_market[item["target_market"]] = item["count"]
@@ -693,9 +726,7 @@ class RideModelStatsAPIView(APIView):
# Recent models (last 30 days) # Recent models (last 30 days)
thirty_days_ago = timezone.now() - timedelta(days=30) thirty_days_ago = timezone.now() - timedelta(days=30)
recent_models = RideModel.objects.filter( recent_models = RideModel.objects.filter(created_at__gte=thirty_days_ago).count()
created_at__gte=thirty_days_ago
).count()
return Response( return Response(
{ {
@@ -730,7 +761,7 @@ class RideModelVariantListCreateAPIView(APIView):
try: try:
ride_model = RideModel.objects.get(pk=ride_model_pk) ride_model = RideModel.objects.get(pk=ride_model_pk)
except RideModel.DoesNotExist: except RideModel.DoesNotExist:
raise NotFound("Ride model not found") raise NotFound("Ride model not found") from None
variants = RideModelVariant.objects.filter(ride_model=ride_model) variants = RideModelVariant.objects.filter(ride_model=ride_model)
serializer = RideModelVariantOutputSerializer(variants, many=True) serializer = RideModelVariantOutputSerializer(variants, many=True)
@@ -753,7 +784,7 @@ class RideModelVariantListCreateAPIView(APIView):
try: try:
ride_model = RideModel.objects.get(pk=ride_model_pk) ride_model = RideModel.objects.get(pk=ride_model_pk)
except RideModel.DoesNotExist: except RideModel.DoesNotExist:
raise NotFound("Ride model not found") raise NotFound("Ride model not found") from None
# Override ride_model_id in the data # Override ride_model_id in the data
data = request.data.copy() data = request.data.copy()
@@ -787,7 +818,7 @@ class RideModelVariantDetailAPIView(APIView):
try: try:
return RideModelVariant.objects.get(ride_model_id=ride_model_pk, pk=pk) return RideModelVariant.objects.get(ride_model_id=ride_model_pk, pk=pk)
except RideModelVariant.DoesNotExist: except RideModelVariant.DoesNotExist:
raise NotFound("Variant not found") raise NotFound("Variant not found") from None
@extend_schema( @extend_schema(
summary="Get a ride model variant", summary="Get a ride model variant",
@@ -807,9 +838,7 @@ class RideModelVariantDetailAPIView(APIView):
) )
def patch(self, request: Request, ride_model_pk: int, pk: int) -> Response: def patch(self, request: Request, ride_model_pk: int, pk: int) -> Response:
variant = self._get_variant_or_404(ride_model_pk, pk) variant = self._get_variant_or_404(ride_model_pk, pk)
serializer_in = RideModelVariantUpdateInputSerializer( serializer_in = RideModelVariantUpdateInputSerializer(data=request.data, partial=True)
data=request.data, partial=True
)
serializer_in.is_valid(raise_exception=True) serializer_in.is_valid(raise_exception=True)
for field, value in serializer_in.validated_data.items(): for field, value in serializer_in.validated_data.items():

View File

@@ -118,9 +118,7 @@ class RidePhotoViewSet(ModelViewSet):
def get_queryset(self): # type: ignore[override] def get_queryset(self): # type: ignore[override]
"""Get photos for the current ride with optimized queries.""" """Get photos for the current ride with optimized queries."""
queryset = RidePhoto.objects.select_related( queryset = RidePhoto.objects.select_related("ride", "ride__park", "ride__park__operator", "uploaded_by")
"ride", "ride__park", "ride__park__operator", "uploaded_by"
)
# If ride_pk is provided in URL kwargs, filter by ride # If ride_pk is provided in URL kwargs, filter by ride
ride_pk = self.kwargs.get("ride_pk") ride_pk = self.kwargs.get("ride_pk")
@@ -149,7 +147,7 @@ class RidePhotoViewSet(ModelViewSet):
try: try:
ride = Ride.objects.get(pk=ride_id) ride = Ride.objects.get(pk=ride_id)
except Ride.DoesNotExist: except Ride.DoesNotExist:
raise ValidationError("Ride not found") raise ValidationError("Ride not found") from None
try: try:
# Use the service to create the photo with proper business logic # Use the service to create the photo with proper business logic
@@ -169,17 +167,14 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e: except Exception as e:
logger.error(f"Error creating ride photo: {e}") logger.error(f"Error creating ride photo: {e}")
raise ValidationError(f"Failed to create photo: {str(e)}") raise ValidationError(f"Failed to create photo: {str(e)}") from None
def perform_update(self, serializer): def perform_update(self, serializer):
"""Update ride photo with permission checking.""" """Update ride photo with permission checking."""
instance = self.get_object() instance = self.get_object()
# Check permissions - allow owner or staff # Check permissions - allow owner or staff
if not ( if not (self.request.user == instance.uploaded_by or getattr(self.request.user, "is_staff", False)):
self.request.user == instance.uploaded_by
or getattr(self.request.user, "is_staff", False)
):
raise PermissionDenied("You can only edit your own photos or be an admin.") raise PermissionDenied("You can only edit your own photos or be an admin.")
# Handle primary photo logic using service # Handle primary photo logic using service
@@ -191,39 +186,31 @@ class RidePhotoViewSet(ModelViewSet):
del serializer.validated_data["is_primary"] del serializer.validated_data["is_primary"]
except Exception as e: except Exception as e:
logger.error(f"Error setting primary photo: {e}") logger.error(f"Error setting primary photo: {e}")
raise ValidationError(f"Failed to set primary photo: {str(e)}") raise ValidationError(f"Failed to set primary photo: {str(e)}") from None
def perform_destroy(self, instance): def perform_destroy(self, instance):
"""Delete ride photo with permission checking.""" """Delete ride photo with permission checking."""
# Check permissions - allow owner or staff # Check permissions - allow owner or staff
if not ( if not (self.request.user == instance.uploaded_by or getattr(self.request.user, "is_staff", False)):
self.request.user == instance.uploaded_by raise PermissionDenied("You can only delete your own photos or be an admin.")
or getattr(self.request.user, "is_staff", False)
):
raise PermissionDenied(
"You can only delete your own photos or be an admin."
)
try: try:
# Delete from Cloudflare first if image exists # Delete from Cloudflare first if image exists
if instance.image: if instance.image:
try: try:
from django_cloudflareimages_toolkit.services import CloudflareImagesService from django_cloudflareimages_toolkit.services import CloudflareImagesService
service = CloudflareImagesService() service = CloudflareImagesService()
service.delete_image(instance.image) service.delete_image(instance.image)
logger.info( logger.info(f"Successfully deleted ride photo from Cloudflare: {instance.image.cloudflare_id}")
f"Successfully deleted ride photo from Cloudflare: {instance.image.cloudflare_id}")
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to delete ride photo from Cloudflare: {str(e)}")
f"Failed to delete ride photo from Cloudflare: {str(e)}")
# Continue with database deletion even if Cloudflare deletion fails # Continue with database deletion even if Cloudflare deletion fails
RideMediaService.delete_photo( RideMediaService.delete_photo(instance, deleted_by=self.request.user) # type: ignore
instance, deleted_by=self.request.user # type: ignore
)
except Exception as e: except Exception as e:
logger.error(f"Error deleting ride photo: {e}") logger.error(f"Error deleting ride photo: {e}")
raise ValidationError(f"Failed to delete photo: {str(e)}") raise ValidationError(f"Failed to delete photo: {str(e)}") from None
@extend_schema( @extend_schema(
summary="Set photo as primary", summary="Set photo as primary",
@@ -242,13 +229,8 @@ class RidePhotoViewSet(ModelViewSet):
photo = self.get_object() photo = self.get_object()
# Check permissions - allow owner or staff # Check permissions - allow owner or staff
if not ( if not (request.user == photo.uploaded_by or getattr(request.user, "is_staff", False)):
request.user == photo.uploaded_by raise PermissionDenied("You can only modify your own photos or be an admin.")
or getattr(request.user, "is_staff", False)
):
raise PermissionDenied(
"You can only modify your own photos or be an admin."
)
try: try:
success = RideMediaService.set_primary_photo(ride=photo.ride, photo=photo) success = RideMediaService.set_primary_photo(ride=photo.ride, photo=photo)
@@ -260,21 +242,21 @@ class RidePhotoViewSet(ModelViewSet):
return Response( return Response(
{ {
"message": "Photo set as primary successfully", "detail": "Photo set as primary successfully",
"photo": serializer.data, "photo": serializer.data,
}, },
status=status.HTTP_200_OK, status=status.HTTP_200_OK,
) )
else: else:
return Response( return Response(
{"error": "Failed to set primary photo"}, {"detail": "Failed to set primary photo"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
except Exception as e: except Exception as e:
logger.error(f"Error setting primary photo: {e}") logger.error(f"Error setting primary photo: {e}")
return Response( return Response(
{"error": f"Failed to set primary photo: {str(e)}"}, {"detail": f"Failed to set primary photo: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -305,7 +287,7 @@ class RidePhotoViewSet(ModelViewSet):
if photo_ids is None or approve is None: if photo_ids is None or approve is None:
return Response( return Response(
{"error": "Missing required fields: photo_ids and/or approve."}, {"detail": "Missing required fields: photo_ids and/or approve."},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -319,7 +301,7 @@ class RidePhotoViewSet(ModelViewSet):
return Response( return Response(
{ {
"message": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos", "detail": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos",
"updated_count": updated_count, "updated_count": updated_count,
}, },
status=status.HTTP_200_OK, status=status.HTTP_200_OK,
@@ -328,7 +310,7 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e: except Exception as e:
logger.error(f"Error in bulk photo approval: {e}") logger.error(f"Error in bulk photo approval: {e}")
return Response( return Response(
{"error": f"Failed to update photos: {str(e)}"}, {"detail": f"Failed to update photos: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -352,7 +334,7 @@ class RidePhotoViewSet(ModelViewSet):
ride = Ride.objects.get(pk=ride_pk) ride = Ride.objects.get(pk=ride_pk)
except Ride.DoesNotExist: except Ride.DoesNotExist:
return Response( return Response(
{"error": "Ride not found."}, {"detail": "Ride not found."},
status=status.HTTP_404_NOT_FOUND, status=status.HTTP_404_NOT_FOUND,
) )
@@ -363,16 +345,10 @@ class RidePhotoViewSet(ModelViewSet):
# Global stats across all rides # Global stats across all rides
stats = { stats = {
"total_photos": RidePhoto.objects.count(), "total_photos": RidePhoto.objects.count(),
"approved_photos": RidePhoto.objects.filter( "approved_photos": RidePhoto.objects.filter(is_approved=True).count(),
is_approved=True "pending_photos": RidePhoto.objects.filter(is_approved=False).count(),
).count(),
"pending_photos": RidePhoto.objects.filter(
is_approved=False
).count(),
"has_primary": False, # Not applicable for global stats "has_primary": False, # Not applicable for global stats
"recent_uploads": RidePhoto.objects.order_by("-created_at")[ "recent_uploads": RidePhoto.objects.order_by("-created_at")[:5].count(),
:5
].count(),
"by_type": {}, "by_type": {},
} }
@@ -382,7 +358,7 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e: except Exception as e:
logger.error(f"Error getting ride photo stats: {e}") logger.error(f"Error getting ride photo stats: {e}")
return Response( return Response(
{"error": f"Failed to get photo statistics: {str(e)}"}, {"detail": f"Failed to get photo statistics: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@@ -401,26 +377,23 @@ class RidePhotoViewSet(ModelViewSet):
def set_primary_legacy(self, request, id=None): def set_primary_legacy(self, request, id=None):
"""Legacy set primary action for backwards compatibility.""" """Legacy set primary action for backwards compatibility."""
photo = self.get_object() photo = self.get_object()
if not ( if not (request.user == photo.uploaded_by or request.user.has_perm("rides.change_ridephoto")):
request.user == photo.uploaded_by
or request.user.has_perm("rides.change_ridephoto")
):
return Response( return Response(
{"error": "You do not have permission to edit photos for this ride."}, {"detail": "You do not have permission to edit photos for this ride."},
status=status.HTTP_403_FORBIDDEN, status=status.HTTP_403_FORBIDDEN,
) )
try: try:
success = RideMediaService.set_primary_photo(ride=photo.ride, photo=photo) success = RideMediaService.set_primary_photo(ride=photo.ride, photo=photo)
if success: if success:
return Response({"message": "Photo set as primary successfully."}) return Response({"detail": "Photo set as primary successfully."})
else: else:
return Response( return Response(
{"error": "Failed to set primary photo"}, {"detail": "Failed to set primary photo"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
except Exception as e: except Exception as e:
logger.error(f"Error in set_primary_photo: {str(e)}", exc_info=True) logger.error(f"Error in set_primary_photo: {str(e)}", exc_info=True)
return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST)
@extend_schema( @extend_schema(
summary="Save Cloudflare image as ride photo", summary="Save Cloudflare image as ride photo",
@@ -440,7 +413,7 @@ class RidePhotoViewSet(ModelViewSet):
ride_pk = self.kwargs.get("ride_pk") ride_pk = self.kwargs.get("ride_pk")
if not ride_pk: if not ride_pk:
return Response( return Response(
{"error": "Ride ID is required"}, {"detail": "Ride ID is required"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -448,14 +421,14 @@ class RidePhotoViewSet(ModelViewSet):
ride = Ride.objects.get(pk=ride_pk) ride = Ride.objects.get(pk=ride_pk)
except Ride.DoesNotExist: except Ride.DoesNotExist:
return Response( return Response(
{"error": "Ride not found"}, {"detail": "Ride not found"},
status=status.HTTP_404_NOT_FOUND, status=status.HTTP_404_NOT_FOUND,
) )
cloudflare_image_id = request.data.get("cloudflare_image_id") cloudflare_image_id = request.data.get("cloudflare_image_id")
if not cloudflare_image_id: if not cloudflare_image_id:
return Response( return Response(
{"error": "cloudflare_image_id is required"}, {"detail": "cloudflare_image_id is required"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -473,27 +446,25 @@ class RidePhotoViewSet(ModelViewSet):
if not image_data: if not image_data:
return Response( return Response(
{"error": "Image not found in Cloudflare"}, {"detail": "Image not found in Cloudflare"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
# Try to find existing CloudflareImage record by cloudflare_id # Try to find existing CloudflareImage record by cloudflare_id
cloudflare_image = None cloudflare_image = None
try: try:
cloudflare_image = CloudflareImage.objects.get( cloudflare_image = CloudflareImage.objects.get(cloudflare_id=cloudflare_image_id)
cloudflare_id=cloudflare_image_id)
# Update existing record with latest data from Cloudflare # Update existing record with latest data from Cloudflare
cloudflare_image.status = 'uploaded' cloudflare_image.status = "uploaded"
cloudflare_image.uploaded_at = timezone.now() cloudflare_image.uploaded_at = timezone.now()
cloudflare_image.metadata = image_data.get('meta', {}) cloudflare_image.metadata = image_data.get("meta", {})
# Extract variants from nested result structure # Extract variants from nested result structure
cloudflare_image.variants = image_data.get( cloudflare_image.variants = image_data.get("result", {}).get("variants", [])
'result', {}).get('variants', [])
cloudflare_image.cloudflare_metadata = image_data cloudflare_image.cloudflare_metadata = image_data
cloudflare_image.width = image_data.get('width') cloudflare_image.width = image_data.get("width")
cloudflare_image.height = image_data.get('height') cloudflare_image.height = image_data.get("height")
cloudflare_image.format = image_data.get('format', '') cloudflare_image.format = image_data.get("format", "")
cloudflare_image.save() cloudflare_image.save()
except CloudflareImage.DoesNotExist: except CloudflareImage.DoesNotExist:
@@ -501,24 +472,23 @@ class RidePhotoViewSet(ModelViewSet):
cloudflare_image = CloudflareImage.objects.create( cloudflare_image = CloudflareImage.objects.create(
cloudflare_id=cloudflare_image_id, cloudflare_id=cloudflare_image_id,
user=request.user, user=request.user,
status='uploaded', status="uploaded",
upload_url='', # Not needed for uploaded images upload_url="", # Not needed for uploaded images
expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry
uploaded_at=timezone.now(), uploaded_at=timezone.now(),
metadata=image_data.get('meta', {}), metadata=image_data.get("meta", {}),
# Extract variants from nested result structure # Extract variants from nested result structure
variants=image_data.get('result', {}).get('variants', []), variants=image_data.get("result", {}).get("variants", []),
cloudflare_metadata=image_data, cloudflare_metadata=image_data,
width=image_data.get('width'), width=image_data.get("width"),
height=image_data.get('height'), height=image_data.get("height"),
format=image_data.get('format', ''), format=image_data.get("format", ""),
) )
except Exception as api_error: except Exception as api_error:
logger.error( logger.error(f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True)
f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True)
return Response( return Response(
{"error": f"Failed to fetch image from Cloudflare: {str(api_error)}"}, {"detail": f"Failed to fetch image from Cloudflare: {str(api_error)}"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
@@ -548,6 +518,6 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e: except Exception as e:
logger.error(f"Error saving ride photo: {e}") logger.error(f"Error saving ride photo: {e}")
return Response( return Response(
{"error": f"Failed to save photo: {str(e)}"}, {"detail": f"Failed to save photo: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )

View File

@@ -52,18 +52,14 @@ from apps.rides.models import Ride, RidePhoto
class RidePhotoOutputSerializer(serializers.ModelSerializer): class RidePhotoOutputSerializer(serializers.ModelSerializer):
"""Output serializer for ride photos with Cloudflare Images support.""" """Output serializer for ride photos with Cloudflare Images support."""
uploaded_by_username = serializers.CharField( uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True)
source="uploaded_by.username", read_only=True
)
file_size = serializers.SerializerMethodField() file_size = serializers.SerializerMethodField()
dimensions = serializers.SerializerMethodField() dimensions = serializers.SerializerMethodField()
image_url = serializers.SerializerMethodField() image_url = serializers.SerializerMethodField()
image_variants = serializers.SerializerMethodField() image_variants = serializers.SerializerMethodField()
@extend_schema_field( @extend_schema_field(serializers.IntegerField(allow_null=True, help_text="File size in bytes"))
serializers.IntegerField(allow_null=True, help_text="File size in bytes")
)
def get_file_size(self, obj): def get_file_size(self, obj):
"""Get file size in bytes.""" """Get file size in bytes."""
return obj.file_size return obj.file_size
@@ -81,11 +77,7 @@ class RidePhotoOutputSerializer(serializers.ModelSerializer):
"""Get image dimensions as [width, height].""" """Get image dimensions as [width, height]."""
return obj.dimensions return obj.dimensions
@extend_schema_field( @extend_schema_field(serializers.URLField(help_text="Full URL to the Cloudflare Images asset", allow_null=True))
serializers.URLField(
help_text="Full URL to the Cloudflare Images asset", allow_null=True
)
)
def get_image_url(self, obj): def get_image_url(self, obj):
"""Get the full Cloudflare Images URL.""" """Get the full Cloudflare Images URL."""
if obj.image: if obj.image:
@@ -186,9 +178,7 @@ class RidePhotoUpdateInputSerializer(serializers.ModelSerializer):
class RidePhotoListOutputSerializer(serializers.ModelSerializer): class RidePhotoListOutputSerializer(serializers.ModelSerializer):
"""Simplified output serializer for ride photo lists.""" """Simplified output serializer for ride photo lists."""
uploaded_by_username = serializers.CharField( uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True)
source="uploaded_by.username", read_only=True
)
class Meta: class Meta:
model = RidePhoto model = RidePhoto
@@ -208,12 +198,8 @@ class RidePhotoListOutputSerializer(serializers.ModelSerializer):
class RidePhotoApprovalInputSerializer(serializers.Serializer): class RidePhotoApprovalInputSerializer(serializers.Serializer):
"""Input serializer for photo approval operations.""" """Input serializer for photo approval operations."""
photo_ids = serializers.ListField( photo_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of photo IDs to approve")
child=serializers.IntegerField(), help_text="List of photo IDs to approve" approve = serializers.BooleanField(default=True, help_text="Whether to approve (True) or reject (False) the photos")
)
approve = serializers.BooleanField(
default=True, help_text="Whether to approve (True) or reject (False) the photos"
)
class RidePhotoStatsOutputSerializer(serializers.Serializer): class RidePhotoStatsOutputSerializer(serializers.Serializer):
@@ -224,9 +210,7 @@ class RidePhotoStatsOutputSerializer(serializers.Serializer):
pending_photos = serializers.IntegerField() pending_photos = serializers.IntegerField()
has_primary = serializers.BooleanField() has_primary = serializers.BooleanField()
recent_uploads = serializers.IntegerField() recent_uploads = serializers.IntegerField()
by_type = serializers.DictField( by_type = serializers.DictField(child=serializers.IntegerField(), help_text="Photo counts by type")
child=serializers.IntegerField(), help_text="Photo counts by type"
)
class RidePhotoTypeFilterSerializer(serializers.Serializer): class RidePhotoTypeFilterSerializer(serializers.Serializer):
@@ -292,8 +276,12 @@ class HybridRideSerializer(serializers.ModelSerializer):
ride_model_name = serializers.CharField(source="ride_model.name", read_only=True, allow_null=True) ride_model_name = serializers.CharField(source="ride_model.name", read_only=True, allow_null=True)
ride_model_slug = serializers.CharField(source="ride_model.slug", read_only=True, allow_null=True) ride_model_slug = serializers.CharField(source="ride_model.slug", read_only=True, allow_null=True)
ride_model_category = serializers.CharField(source="ride_model.category", read_only=True, allow_null=True) ride_model_category = serializers.CharField(source="ride_model.category", read_only=True, allow_null=True)
ride_model_manufacturer_name = serializers.CharField(source="ride_model.manufacturer.name", read_only=True, allow_null=True) ride_model_manufacturer_name = serializers.CharField(
ride_model_manufacturer_slug = serializers.CharField(source="ride_model.manufacturer.slug", read_only=True, allow_null=True) source="ride_model.manufacturer.name", read_only=True, allow_null=True
)
ride_model_manufacturer_slug = serializers.CharField(
source="ride_model.manufacturer.slug", read_only=True, allow_null=True
)
# Roller coaster stats fields # Roller coaster stats fields
coaster_height_ft = serializers.SerializerMethodField() coaster_height_ft = serializers.SerializerMethodField()
@@ -323,7 +311,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_park_city(self, obj): def get_park_city(self, obj):
"""Get city from park location.""" """Get city from park location."""
try: try:
if obj.park and hasattr(obj.park, 'location') and obj.park.location: if obj.park and hasattr(obj.park, "location") and obj.park.location:
return obj.park.location.city return obj.park.location.city
return None return None
except AttributeError: except AttributeError:
@@ -333,7 +321,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_park_state(self, obj): def get_park_state(self, obj):
"""Get state from park location.""" """Get state from park location."""
try: try:
if obj.park and hasattr(obj.park, 'location') and obj.park.location: if obj.park and hasattr(obj.park, "location") and obj.park.location:
return obj.park.location.state return obj.park.location.state
return None return None
except AttributeError: except AttributeError:
@@ -343,7 +331,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_park_country(self, obj): def get_park_country(self, obj):
"""Get country from park location.""" """Get country from park location."""
try: try:
if obj.park and hasattr(obj.park, 'location') and obj.park.location: if obj.park and hasattr(obj.park, "location") and obj.park.location:
return obj.park.location.country return obj.park.location.country
return None return None
except AttributeError: except AttributeError:
@@ -353,7 +341,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_height_ft(self, obj): def get_coaster_height_ft(self, obj):
"""Get roller coaster height.""" """Get roller coaster height."""
try: try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats: if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return float(obj.coaster_stats.height_ft) if obj.coaster_stats.height_ft else None return float(obj.coaster_stats.height_ft) if obj.coaster_stats.height_ft else None
return None return None
except (AttributeError, TypeError): except (AttributeError, TypeError):
@@ -363,7 +351,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_length_ft(self, obj): def get_coaster_length_ft(self, obj):
"""Get roller coaster length.""" """Get roller coaster length."""
try: try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats: if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return float(obj.coaster_stats.length_ft) if obj.coaster_stats.length_ft else None return float(obj.coaster_stats.length_ft) if obj.coaster_stats.length_ft else None
return None return None
except (AttributeError, TypeError): except (AttributeError, TypeError):
@@ -373,7 +361,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_speed_mph(self, obj): def get_coaster_speed_mph(self, obj):
"""Get roller coaster speed.""" """Get roller coaster speed."""
try: try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats: if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return float(obj.coaster_stats.speed_mph) if obj.coaster_stats.speed_mph else None return float(obj.coaster_stats.speed_mph) if obj.coaster_stats.speed_mph else None
return None return None
except (AttributeError, TypeError): except (AttributeError, TypeError):
@@ -383,7 +371,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_inversions(self, obj): def get_coaster_inversions(self, obj):
"""Get roller coaster inversions.""" """Get roller coaster inversions."""
try: try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats: if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.inversions return obj.coaster_stats.inversions
return None return None
except AttributeError: except AttributeError:
@@ -393,7 +381,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_ride_time_seconds(self, obj): def get_coaster_ride_time_seconds(self, obj):
"""Get roller coaster ride time.""" """Get roller coaster ride time."""
try: try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats: if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.ride_time_seconds return obj.coaster_stats.ride_time_seconds
return None return None
except AttributeError: except AttributeError:
@@ -403,7 +391,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_track_type(self, obj): def get_coaster_track_type(self, obj):
"""Get roller coaster track type.""" """Get roller coaster track type."""
try: try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats: if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.track_type return obj.coaster_stats.track_type
return None return None
except AttributeError: except AttributeError:
@@ -413,7 +401,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_track_material(self, obj): def get_coaster_track_material(self, obj):
"""Get roller coaster track material.""" """Get roller coaster track material."""
try: try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats: if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.track_material return obj.coaster_stats.track_material
return None return None
except AttributeError: except AttributeError:
@@ -423,7 +411,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_roller_coaster_type(self, obj): def get_coaster_roller_coaster_type(self, obj):
"""Get roller coaster type.""" """Get roller coaster type."""
try: try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats: if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.roller_coaster_type return obj.coaster_stats.roller_coaster_type
return None return None
except AttributeError: except AttributeError:
@@ -433,7 +421,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_max_drop_height_ft(self, obj): def get_coaster_max_drop_height_ft(self, obj):
"""Get roller coaster max drop height.""" """Get roller coaster max drop height."""
try: try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats: if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return float(obj.coaster_stats.max_drop_height_ft) if obj.coaster_stats.max_drop_height_ft else None return float(obj.coaster_stats.max_drop_height_ft) if obj.coaster_stats.max_drop_height_ft else None
return None return None
except (AttributeError, TypeError): except (AttributeError, TypeError):
@@ -443,7 +431,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_propulsion_system(self, obj): def get_coaster_propulsion_system(self, obj):
"""Get roller coaster propulsion system.""" """Get roller coaster propulsion system."""
try: try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats: if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.propulsion_system return obj.coaster_stats.propulsion_system
return None return None
except AttributeError: except AttributeError:
@@ -453,7 +441,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_train_style(self, obj): def get_coaster_train_style(self, obj):
"""Get roller coaster train style.""" """Get roller coaster train style."""
try: try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats: if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.train_style return obj.coaster_stats.train_style
return None return None
except AttributeError: except AttributeError:
@@ -463,7 +451,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_trains_count(self, obj): def get_coaster_trains_count(self, obj):
"""Get roller coaster trains count.""" """Get roller coaster trains count."""
try: try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats: if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.trains_count return obj.coaster_stats.trains_count
return None return None
except AttributeError: except AttributeError:
@@ -473,7 +461,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_cars_per_train(self, obj): def get_coaster_cars_per_train(self, obj):
"""Get roller coaster cars per train.""" """Get roller coaster cars per train."""
try: try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats: if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.cars_per_train return obj.coaster_stats.cars_per_train
return None return None
except AttributeError: except AttributeError:
@@ -483,7 +471,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_seats_per_car(self, obj): def get_coaster_seats_per_car(self, obj):
"""Get roller coaster seats per car.""" """Get roller coaster seats per car."""
try: try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats: if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.seats_per_car return obj.coaster_stats.seats_per_car
return None return None
except AttributeError: except AttributeError:
@@ -514,44 +502,37 @@ class HybridRideSerializer(serializers.ModelSerializer):
"category", "category",
"status", "status",
"post_closing_status", "post_closing_status",
# Dates and computed fields # Dates and computed fields
"opening_date", "opening_date",
"closing_date", "closing_date",
"status_since", "status_since",
"opening_year", "opening_year",
# Park fields # Park fields
"park_name", "park_name",
"park_slug", "park_slug",
"park_city", "park_city",
"park_state", "park_state",
"park_country", "park_country",
# Park area fields # Park area fields
"park_area_name", "park_area_name",
"park_area_slug", "park_area_slug",
# Company fields # Company fields
"manufacturer_name", "manufacturer_name",
"manufacturer_slug", "manufacturer_slug",
"designer_name", "designer_name",
"designer_slug", "designer_slug",
# Ride model fields # Ride model fields
"ride_model_name", "ride_model_name",
"ride_model_slug", "ride_model_slug",
"ride_model_category", "ride_model_category",
"ride_model_manufacturer_name", "ride_model_manufacturer_name",
"ride_model_manufacturer_slug", "ride_model_manufacturer_slug",
# Ride specifications # Ride specifications
"min_height_in", "min_height_in",
"max_height_in", "max_height_in",
"capacity_per_hour", "capacity_per_hour",
"ride_duration_seconds", "ride_duration_seconds",
"average_rating", "average_rating",
# Roller coaster stats # Roller coaster stats
"coaster_height_ft", "coaster_height_ft",
"coaster_length_ft", "coaster_length_ft",
@@ -567,18 +548,14 @@ class HybridRideSerializer(serializers.ModelSerializer):
"coaster_trains_count", "coaster_trains_count",
"coaster_cars_per_train", "coaster_cars_per_train",
"coaster_seats_per_car", "coaster_seats_per_car",
# Images # Images
"banner_image_url", "banner_image_url",
"card_image_url", "card_image_url",
# URLs # URLs
"url", "url",
"park_url", "park_url",
# Computed fields for filtering # Computed fields for filtering
"search_text", "search_text",
# Metadata # Metadata
"created_at", "created_at",
"updated_at", "updated_at",

View File

@@ -35,11 +35,9 @@ app_name = "api_v1_rides"
urlpatterns = [ urlpatterns = [
# Core list/create endpoints # Core list/create endpoints
path("", RideListCreateAPIView.as_view(), name="ride-list-create"), path("", RideListCreateAPIView.as_view(), name="ride-list-create"),
# Hybrid filtering endpoints # Hybrid filtering endpoints
path("hybrid/", HybridRideAPIView.as_view(), name="ride-hybrid-filtering"), path("hybrid/", HybridRideAPIView.as_view(), name="ride-hybrid-filtering"),
path("hybrid/filter-metadata/", RideFilterMetadataAPIView.as_view(), name="ride-hybrid-filter-metadata"), path("hybrid/filter-metadata/", RideFilterMetadataAPIView.as_view(), name="ride-hybrid-filter-metadata"),
# Filter options # Filter options
path("filter-options/", FilterOptionsAPIView.as_view(), name="ride-filter-options"), path("filter-options/", FilterOptionsAPIView.as_view(), name="ride-filter-options"),
# Autocomplete / suggestion endpoints # Autocomplete / suggestion endpoints
@@ -61,7 +59,6 @@ urlpatterns = [
# Manufacturer and Designer endpoints # Manufacturer and Designer endpoints
path("manufacturers/", ManufacturerListAPIView.as_view(), name="manufacturer-list"), path("manufacturers/", ManufacturerListAPIView.as_view(), name="manufacturer-list"),
path("designers/", DesignerListAPIView.as_view(), name="designer-list"), path("designers/", DesignerListAPIView.as_view(), name="designer-list"),
# Ride model management endpoints - nested under rides/manufacturers # Ride model management endpoints - nested under rides/manufacturers
path( path(
"manufacturers/<slug:manufacturer_slug>/", "manufacturers/<slug:manufacturer_slug>/",

View File

@@ -28,6 +28,7 @@ import logging
from typing import Any from typing import Any
from django.db import models from django.db import models
from django.db.models import Count
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
from rest_framework import permissions, status from rest_framework import permissions, status
@@ -333,9 +334,7 @@ class RideListCreateAPIView(APIView):
paginator = StandardResultsSetPagination() paginator = StandardResultsSetPagination()
page = paginator.paginate_queryset(qs, request) page = paginator.paginate_queryset(qs, request)
serializer = RideListOutputSerializer( serializer = RideListOutputSerializer(page, many=True, context={"request": request})
page, many=True, context={"request": request}
)
return paginator.get_paginated_response(serializer.data) return paginator.get_paginated_response(serializer.data)
def _apply_filters(self, qs, params): def _apply_filters(self, qs, params):
@@ -567,9 +566,9 @@ class RideListCreateAPIView(APIView):
if ordering in valid_orderings: if ordering in valid_orderings:
if ordering in ["height_ft", "-height_ft", "speed_mph", "-speed_mph"]: if ordering in ["height_ft", "-height_ft", "speed_mph", "-speed_mph"]:
# For coaster stats ordering, we need to join and order by the stats # For coaster stats ordering, we need to join and order by the stats
ordering_field = ordering.replace( ordering_field = ordering.replace("height_ft", "coaster_stats__height_ft").replace(
"height_ft", "coaster_stats__height_ft" "speed_mph", "coaster_stats__speed_mph"
).replace("speed_mph", "coaster_stats__speed_mph") )
qs = qs.order_by(ordering_field) qs = qs.order_by(ordering_field)
else: else:
qs = qs.order_by(ordering) qs = qs.order_by(ordering)
@@ -602,7 +601,7 @@ class RideListCreateAPIView(APIView):
try: try:
park = Park.objects.get(id=validated["park_id"]) # type: ignore park = Park.objects.get(id=validated["park_id"]) # type: ignore
except Park.DoesNotExist: # type: ignore except Park.DoesNotExist: # type: ignore
raise NotFound("Park not found") raise NotFound("Park not found") from None
ride = Ride.objects.create( # type: ignore ride = Ride.objects.create( # type: ignore
name=validated["name"], name=validated["name"],
@@ -658,7 +657,7 @@ class RideDetailAPIView(APIView):
try: try:
return Ride.objects.select_related("park").get(pk=pk) # type: ignore return Ride.objects.select_related("park").get(pk=pk) # type: ignore
except Ride.DoesNotExist: # type: ignore except Ride.DoesNotExist: # type: ignore
raise NotFound("Ride not found") raise NotFound("Ride not found") from None
@cache_api_response(timeout=1800, key_prefix="ride_detail") @cache_api_response(timeout=1800, key_prefix="ride_detail")
def get(self, request: Request, pk: int) -> Response: def get(self, request: Request, pk: int) -> Response:
@@ -672,9 +671,7 @@ class RideDetailAPIView(APIView):
serializer_in.is_valid(raise_exception=True) serializer_in.is_valid(raise_exception=True)
if not MODELS_AVAILABLE: if not MODELS_AVAILABLE:
return Response( return Response(
{ {"detail": "Ride update is not available because domain models are not imported."},
"detail": "Ride update is not available because domain models are not imported."
},
status=status.HTTP_501_NOT_IMPLEMENTED, status=status.HTTP_501_NOT_IMPLEMENTED,
) )
@@ -690,7 +687,7 @@ class RideDetailAPIView(APIView):
# Use the move_to_park method for proper handling # Use the move_to_park method for proper handling
park_change_info = ride.move_to_park(new_park) park_change_info = ride.move_to_park(new_park)
except Park.DoesNotExist: # type: ignore except Park.DoesNotExist: # type: ignore
raise NotFound("Target park not found") raise NotFound("Target park not found") from None
# Apply other field updates # Apply other field updates
for key, value in validated_data.items(): for key, value in validated_data.items():
@@ -715,9 +712,7 @@ class RideDetailAPIView(APIView):
def delete(self, request: Request, pk: int) -> Response: def delete(self, request: Request, pk: int) -> Response:
if not MODELS_AVAILABLE: if not MODELS_AVAILABLE:
return Response( return Response(
{ {"detail": "Ride delete is not available because domain models are not imported."},
"detail": "Ride delete is not available because domain models are not imported."
},
status=status.HTTP_501_NOT_IMPLEMENTED, status=status.HTTP_501_NOT_IMPLEMENTED,
) )
ride = self._get_ride_or_404(pk) ride = self._get_ride_or_404(pk)
@@ -1491,16 +1486,12 @@ class FilterOptionsAPIView(APIView):
# Get manufacturers (companies with MANUFACTURER role) # Get manufacturers (companies with MANUFACTURER role)
manufacturers = list( manufacturers = list(
Company.objects.filter(roles__contains=["MANUFACTURER"]) Company.objects.filter(roles__contains=["MANUFACTURER"]).values("id", "name", "slug").order_by("name")
.values("id", "name", "slug")
.order_by("name")
) )
# Get designers (companies with DESIGNER role) # Get designers (companies with DESIGNER role)
designers = list( designers = list(
Company.objects.filter(roles__contains=["DESIGNER"]) Company.objects.filter(roles__contains=["DESIGNER"]).values("id", "name", "slug").order_by("name")
.values("id", "name", "slug")
.order_by("name")
) )
# Get ride models data from database # Get ride models data from database
@@ -1722,11 +1713,7 @@ class FilterOptionsAPIView(APIView):
# --- Company search (autocomplete) ----------------------------------------- # --- Company search (autocomplete) -----------------------------------------
@extend_schema( @extend_schema(
summary="Search companies (manufacturers/designers) for autocomplete", summary="Search companies (manufacturers/designers) for autocomplete",
parameters=[ parameters=[OpenApiParameter(name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR)],
OpenApiParameter(
name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR
)
],
responses={200: OpenApiTypes.OBJECT}, responses={200: OpenApiTypes.OBJECT},
tags=["Rides"], tags=["Rides"],
) )
@@ -1753,20 +1740,14 @@ class CompanySearchAPIView(APIView):
) )
qs = Company.objects.filter(name__icontains=q)[:20] # type: ignore qs = Company.objects.filter(name__icontains=q)[:20] # type: ignore
results = [ results = [{"id": c.id, "name": c.name, "slug": getattr(c, "slug", "")} for c in qs]
{"id": c.id, "name": c.name, "slug": getattr(c, "slug", "")} for c in qs
]
return Response(results) return Response(results)
# --- Ride model search (autocomplete) -------------------------------------- # --- Ride model search (autocomplete) --------------------------------------
@extend_schema( @extend_schema(
summary="Search ride models for autocomplete", summary="Search ride models for autocomplete",
parameters=[ parameters=[OpenApiParameter(name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR)],
OpenApiParameter(
name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR
)
],
tags=["Rides"], tags=["Rides"],
) )
class RideModelSearchAPIView(APIView): class RideModelSearchAPIView(APIView):
@@ -1795,21 +1776,14 @@ class RideModelSearchAPIView(APIView):
) )
qs = RideModel.objects.filter(name__icontains=q)[:20] # type: ignore qs = RideModel.objects.filter(name__icontains=q)[:20] # type: ignore
results = [ results = [{"id": m.id, "name": m.name, "category": getattr(m, "category", "")} for m in qs]
{"id": m.id, "name": m.name, "category": getattr(m, "category", "")}
for m in qs
]
return Response(results) return Response(results)
# --- Search suggestions ----------------------------------------------------- # --- Search suggestions -----------------------------------------------------
@extend_schema( @extend_schema(
summary="Search suggestions for ride search box", summary="Search suggestions for ride search box",
parameters=[ parameters=[OpenApiParameter(name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR)],
OpenApiParameter(
name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR
)
],
tags=["Rides"], tags=["Rides"],
) )
class RideSearchSuggestionsAPIView(APIView): class RideSearchSuggestionsAPIView(APIView):
@@ -1827,9 +1801,7 @@ class RideSearchSuggestionsAPIView(APIView):
# Very small suggestion implementation: look in ride names if available # Very small suggestion implementation: look in ride names if available
if MODELS_AVAILABLE and Ride is not None: if MODELS_AVAILABLE and Ride is not None:
qs = Ride.objects.filter(name__icontains=q).values_list("name", flat=True)[ qs = Ride.objects.filter(name__icontains=q).values_list("name", flat=True)[:10] # type: ignore
:10
] # type: ignore
return Response([{"suggestion": name} for name in qs]) return Response([{"suggestion": name} for name in qs])
# Fallback suggestions # Fallback suggestions
@@ -1862,7 +1834,7 @@ class RideImageSettingsAPIView(APIView):
try: try:
return Ride.objects.get(pk=pk) # type: ignore return Ride.objects.get(pk=pk) # type: ignore
except Ride.DoesNotExist: # type: ignore except Ride.DoesNotExist: # type: ignore
raise NotFound("Ride not found") raise NotFound("Ride not found") from None
def patch(self, request: Request, pk: int) -> Response: def patch(self, request: Request, pk: int) -> Response:
"""Set banner and card images for the ride.""" """Set banner and card images for the ride."""
@@ -1878,9 +1850,7 @@ class RideImageSettingsAPIView(APIView):
ride.save() ride.save()
# Return updated ride data # Return updated ride data
output_serializer = RideDetailOutputSerializer( output_serializer = RideDetailOutputSerializer(ride, context={"request": request})
ride, context={"request": request}
)
return Response(output_serializer.data) return Response(output_serializer.data)
@@ -1902,12 +1872,8 @@ class RideImageSettingsAPIView(APIView):
OpenApiTypes.STR, OpenApiTypes.STR,
description="Filter by ride status (comma-separated for multiple)", description="Filter by ride status (comma-separated for multiple)",
), ),
OpenApiParameter( OpenApiParameter("park_slug", OpenApiTypes.STR, description="Filter by park slug"),
"park_slug", OpenApiTypes.STR, description="Filter by park slug" OpenApiParameter("park_id", OpenApiTypes.INT, description="Filter by park ID"),
),
OpenApiParameter(
"park_id", OpenApiTypes.INT, description="Filter by park ID"
),
OpenApiParameter( OpenApiParameter(
"manufacturer", "manufacturer",
OpenApiTypes.STR, OpenApiTypes.STR,
@@ -1923,18 +1889,10 @@ class RideImageSettingsAPIView(APIView):
OpenApiTypes.STR, OpenApiTypes.STR,
description="Filter by ride model slug (comma-separated for multiple)", description="Filter by ride model slug (comma-separated for multiple)",
), ),
OpenApiParameter( OpenApiParameter("opening_year_min", OpenApiTypes.INT, description="Minimum opening year"),
"opening_year_min", OpenApiTypes.INT, description="Minimum opening year" OpenApiParameter("opening_year_max", OpenApiTypes.INT, description="Maximum opening year"),
), OpenApiParameter("rating_min", OpenApiTypes.NUMBER, description="Minimum average rating"),
OpenApiParameter( OpenApiParameter("rating_max", OpenApiTypes.NUMBER, description="Maximum average rating"),
"opening_year_max", OpenApiTypes.INT, description="Maximum opening year"
),
OpenApiParameter(
"rating_min", OpenApiTypes.NUMBER, description="Minimum average rating"
),
OpenApiParameter(
"rating_max", OpenApiTypes.NUMBER, description="Maximum average rating"
),
OpenApiParameter( OpenApiParameter(
"height_requirement_min", "height_requirement_min",
OpenApiTypes.INT, OpenApiTypes.INT,
@@ -1945,12 +1903,8 @@ class RideImageSettingsAPIView(APIView):
OpenApiTypes.INT, OpenApiTypes.INT,
description="Maximum height requirement in inches", description="Maximum height requirement in inches",
), ),
OpenApiParameter( OpenApiParameter("capacity_min", OpenApiTypes.INT, description="Minimum hourly capacity"),
"capacity_min", OpenApiTypes.INT, description="Minimum hourly capacity" OpenApiParameter("capacity_max", OpenApiTypes.INT, description="Maximum hourly capacity"),
),
OpenApiParameter(
"capacity_max", OpenApiTypes.INT, description="Maximum hourly capacity"
),
OpenApiParameter( OpenApiParameter(
"roller_coaster_type", "roller_coaster_type",
OpenApiTypes.STR, OpenApiTypes.STR,
@@ -2022,9 +1976,7 @@ class RideImageSettingsAPIView(APIView):
"properties": { "properties": {
"rides": { "rides": {
"type": "array", "type": "array",
"items": { "items": {"$ref": "#/components/schemas/HybridRideSerializer"},
"$ref": "#/components/schemas/HybridRideSerializer"
},
}, },
"total_count": {"type": "integer"}, "total_count": {"type": "integer"},
"strategy": { "strategy": {
@@ -2084,7 +2036,7 @@ class HybridRideAPIView(APIView):
data = smart_ride_loader.get_progressive_load(offset, filters) data = smart_ride_loader.get_progressive_load(offset, filters)
except ValueError: except ValueError:
return Response( return Response(
{"error": "Invalid offset parameter"}, {"detail": "Invalid offset parameter"},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
else: else:
@@ -2109,7 +2061,7 @@ class HybridRideAPIView(APIView):
except Exception as e: except Exception as e:
logger.error(f"Error in HybridRideAPIView: {e}") logger.error(f"Error in HybridRideAPIView: {e}")
return Response( return Response(
{"error": "Internal server error"}, {"detail": "Internal server error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@@ -2158,7 +2110,7 @@ class HybridRideAPIView(APIView):
for param in int_params: for param in int_params:
value = query_params.get(param) value = query_params.get(param)
if value: if value:
try: try: # noqa: SIM105
filters[param] = int(value) filters[param] = int(value)
except ValueError: except ValueError:
pass # Skip invalid integer values pass # Skip invalid integer values
@@ -2175,7 +2127,7 @@ class HybridRideAPIView(APIView):
for param in float_params: for param in float_params:
value = query_params.get(param) value = query_params.get(param)
if value: if value:
try: try: # noqa: SIM105
filters[param] = float(value) filters[param] = float(value)
except ValueError: except ValueError:
pass # Skip invalid float values pass # Skip invalid float values
@@ -2408,7 +2360,7 @@ class RideFilterMetadataAPIView(APIView):
except Exception as e: except Exception as e:
logger.error(f"Error in RideFilterMetadataAPIView: {e}") logger.error(f"Error in RideFilterMetadataAPIView: {e}")
return Response( return Response(
{"error": "Internal server error"}, {"detail": "Internal server error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@@ -2417,18 +2369,18 @@ class RideFilterMetadataAPIView(APIView):
# Reuse the same filter extraction logic # Reuse the same filter extraction logic
view = HybridRideAPIView() view = HybridRideAPIView()
return view._extract_filters(query_params) return view._extract_filters(query_params)
# === MANUFACTURER & DESIGNER LISTS === # === MANUFACTURER & DESIGNER LISTS ===
class BaseCompanyListAPIView(APIView): class BaseCompanyListAPIView(APIView):
permission_classes = [permissions.AllowAny] permission_classes = [permissions.AllowAny]
role = None role = None
def get(self, request: Request) -> Response: def get(self, request: Request) -> Response:
if not MODELS_AVAILABLE: if not MODELS_AVAILABLE:
return Response( return Response({"detail": "Models not available"}, status=status.HTTP_501_NOT_IMPLEMENTED)
{"detail": "Models not available"},
status=status.HTTP_501_NOT_IMPLEMENTED
)
companies = ( companies = (
Company.objects.filter(roles__contains=[self.role]) Company.objects.filter(roles__contains=[self.role])
@@ -2448,10 +2400,8 @@ class BaseCompanyListAPIView(APIView):
for c in companies for c in companies
] ]
return Response({ return Response({"results": data, "count": len(data)})
"results": data,
"count": len(data)
})
@extend_schema( @extend_schema(
summary="List manufacturers", summary="List manufacturers",
@@ -2462,6 +2412,7 @@ class BaseCompanyListAPIView(APIView):
class ManufacturerListAPIView(BaseCompanyListAPIView): class ManufacturerListAPIView(BaseCompanyListAPIView):
role = "MANUFACTURER" role = "MANUFACTURER"
@extend_schema( @extend_schema(
summary="List designers", summary="List designers",
description="List all companies with DESIGNER role.", description="List all companies with DESIGNER role.",

View File

@@ -49,5 +49,4 @@ __all__ = (
"UserProfileCreateInputSerializer", "UserProfileCreateInputSerializer",
"UserProfileUpdateInputSerializer", "UserProfileUpdateInputSerializer",
"UserProfileOutputSerializer", "UserProfileOutputSerializer",
) )

View File

@@ -90,7 +90,6 @@ _ACCOUNTS_SYMBOLS: list[str] = [
"UserProfileOutputSerializer", "UserProfileOutputSerializer",
"UserProfileCreateInputSerializer", "UserProfileCreateInputSerializer",
"UserProfileUpdateInputSerializer", "UserProfileUpdateInputSerializer",
"UserOutputSerializer", "UserOutputSerializer",
"LoginInputSerializer", "LoginInputSerializer",
"LoginOutputSerializer", "LoginOutputSerializer",

View File

@@ -187,6 +187,7 @@ class PublicUserSerializer(serializers.ModelSerializer):
Public user serializer for viewing other users' profiles. Public user serializer for viewing other users' profiles.
Only exposes public information. Only exposes public information.
""" """
profile = UserProfileSerializer(read_only=True) profile = UserProfileSerializer(read_only=True)
class Meta: class Meta:
@@ -228,37 +229,21 @@ class UserPreferencesSerializer(serializers.Serializer):
"""Serializer for user preferences and settings.""" """Serializer for user preferences and settings."""
theme_preference = RichChoiceFieldSerializer( theme_preference = RichChoiceFieldSerializer(
choice_group="theme_preferences", choice_group="theme_preferences", domain="accounts", help_text="User's theme preference"
domain="accounts",
help_text="User's theme preference"
)
email_notifications = serializers.BooleanField(
default=True, help_text="Whether to receive email notifications"
)
push_notifications = serializers.BooleanField(
default=False, help_text="Whether to receive push notifications"
) )
email_notifications = serializers.BooleanField(default=True, help_text="Whether to receive email notifications")
push_notifications = serializers.BooleanField(default=False, help_text="Whether to receive push notifications")
privacy_level = RichChoiceFieldSerializer( privacy_level = RichChoiceFieldSerializer(
choice_group="privacy_levels", choice_group="privacy_levels",
domain="accounts", domain="accounts",
default="public", default="public",
help_text="Profile visibility level", help_text="Profile visibility level",
) )
show_email = serializers.BooleanField( show_email = serializers.BooleanField(default=False, help_text="Whether to show email on profile")
default=False, help_text="Whether to show email on profile" show_real_name = serializers.BooleanField(default=True, help_text="Whether to show real name on profile")
) show_statistics = serializers.BooleanField(default=True, help_text="Whether to show ride statistics on profile")
show_real_name = serializers.BooleanField( allow_friend_requests = serializers.BooleanField(default=True, help_text="Whether to allow friend requests")
default=True, help_text="Whether to show real name on profile" allow_messages = serializers.BooleanField(default=True, help_text="Whether to allow direct messages")
)
show_statistics = serializers.BooleanField(
default=True, help_text="Whether to show ride statistics on profile"
)
allow_friend_requests = serializers.BooleanField(
default=True, help_text="Whether to allow friend requests"
)
allow_messages = serializers.BooleanField(
default=True, help_text="Whether to allow direct messages"
)
# === NOTIFICATION SETTINGS SERIALIZERS === # === NOTIFICATION SETTINGS SERIALIZERS ===
@@ -363,39 +348,17 @@ class PrivacySettingsSerializer(serializers.Serializer):
default="public", default="public",
help_text="Overall profile visibility", help_text="Overall profile visibility",
) )
show_email = serializers.BooleanField( show_email = serializers.BooleanField(default=False, help_text="Show email address on profile")
default=False, help_text="Show email address on profile" show_real_name = serializers.BooleanField(default=True, help_text="Show real name on profile")
) show_join_date = serializers.BooleanField(default=True, help_text="Show join date on profile")
show_real_name = serializers.BooleanField( show_statistics = serializers.BooleanField(default=True, help_text="Show ride statistics on profile")
default=True, help_text="Show real name on profile" show_reviews = serializers.BooleanField(default=True, help_text="Show reviews on profile")
) show_photos = serializers.BooleanField(default=True, help_text="Show uploaded photos on profile")
show_join_date = serializers.BooleanField( show_top_lists = serializers.BooleanField(default=True, help_text="Show top lists on profile")
default=True, help_text="Show join date on profile" allow_friend_requests = serializers.BooleanField(default=True, help_text="Allow others to send friend requests")
) allow_messages = serializers.BooleanField(default=True, help_text="Allow others to send direct messages")
show_statistics = serializers.BooleanField( allow_profile_comments = serializers.BooleanField(default=False, help_text="Allow others to comment on profile")
default=True, help_text="Show ride statistics on profile" search_visibility = serializers.BooleanField(default=True, help_text="Allow profile to appear in search results")
)
show_reviews = serializers.BooleanField(
default=True, help_text="Show reviews on profile"
)
show_photos = serializers.BooleanField(
default=True, help_text="Show uploaded photos on profile"
)
show_top_lists = serializers.BooleanField(
default=True, help_text="Show top lists on profile"
)
allow_friend_requests = serializers.BooleanField(
default=True, help_text="Allow others to send friend requests"
)
allow_messages = serializers.BooleanField(
default=True, help_text="Allow others to send direct messages"
)
allow_profile_comments = serializers.BooleanField(
default=False, help_text="Allow others to comment on profile"
)
search_visibility = serializers.BooleanField(
default=True, help_text="Allow profile to appear in search results"
)
activity_visibility = RichChoiceFieldSerializer( activity_visibility = RichChoiceFieldSerializer(
choice_group="privacy_levels", choice_group="privacy_levels",
domain="accounts", domain="accounts",
@@ -431,21 +394,13 @@ class SecuritySettingsSerializer(serializers.Serializer):
two_factor_enabled = serializers.BooleanField( two_factor_enabled = serializers.BooleanField(
default=False, help_text="Whether two-factor authentication is enabled" default=False, help_text="Whether two-factor authentication is enabled"
) )
login_notifications = serializers.BooleanField( login_notifications = serializers.BooleanField(default=True, help_text="Send notifications for new logins")
default=True, help_text="Send notifications for new logins"
)
session_timeout = serializers.IntegerField( session_timeout = serializers.IntegerField(
default=30, min_value=5, max_value=180, help_text="Session timeout in days" default=30, min_value=5, max_value=180, help_text="Session timeout in days"
) )
require_password_change = serializers.BooleanField( require_password_change = serializers.BooleanField(default=False, help_text="Whether password change is required")
default=False, help_text="Whether password change is required" last_password_change = serializers.DateTimeField(read_only=True, help_text="When password was last changed")
) active_sessions = serializers.IntegerField(read_only=True, help_text="Number of active sessions")
last_password_change = serializers.DateTimeField(
read_only=True, help_text="When password was last changed"
)
active_sessions = serializers.IntegerField(
read_only=True, help_text="Number of active sessions"
)
login_history_retention = serializers.IntegerField( login_history_retention = serializers.IntegerField(
default=90, default=90,
min_value=30, min_value=30,
@@ -699,7 +654,7 @@ class ThemePreferenceSerializer(serializers.ModelSerializer):
"id": 1, "id": 1,
"notification_type": "submission_approved", "notification_type": "submission_approved",
"title": "Your submission has been approved!", "title": "Your submission has been approved!",
"message": "Your photo submission for Cedar Point has been approved and is now live on the site.", "detail": "Your photo submission for Cedar Point has been approved and is now live on the site.",
"priority": "normal", "priority": "normal",
"is_read": False, "is_read": False,
"read_at": None, "read_at": None,
@@ -866,15 +821,11 @@ class MarkNotificationsReadSerializer(serializers.Serializer):
def validate_notification_ids(self, value): def validate_notification_ids(self, value):
"""Validate that all notification IDs belong to the requesting user.""" """Validate that all notification IDs belong to the requesting user."""
user = self.context["request"].user user = self.context["request"].user
valid_ids = UserNotification.objects.filter( valid_ids = UserNotification.objects.filter(id__in=value, user=user).values_list("id", flat=True)
id__in=value, user=user
).values_list("id", flat=True)
invalid_ids = set(value) - set(valid_ids) invalid_ids = set(value) - set(valid_ids)
if invalid_ids: if invalid_ids:
raise serializers.ValidationError( raise serializers.ValidationError(f"Invalid notification IDs: {list(invalid_ids)}")
f"Invalid notification IDs: {list(invalid_ids)}"
)
return value return value
@@ -901,9 +852,8 @@ class AvatarUploadSerializer(serializers.Serializer):
raise serializers.ValidationError("No file provided") raise serializers.ValidationError("No file provided")
# Check file size constraints (max 10MB for Cloudflare Images) # Check file size constraints (max 10MB for Cloudflare Images)
if hasattr(value, 'size') and value.size > 10 * 1024 * 1024: if hasattr(value, "size") and value.size > 10 * 1024 * 1024:
raise serializers.ValidationError( raise serializers.ValidationError("Image file too large. Maximum size is 10MB.")
"Image file too large. Maximum size is 10MB.")
# Try to validate with PIL # Try to validate with PIL
try: try:
@@ -926,13 +876,13 @@ class AvatarUploadSerializer(serializers.Serializer):
# Check image dimensions (max 12,000x12,000 for Cloudflare Images) # Check image dimensions (max 12,000x12,000 for Cloudflare Images)
if image.size[0] > 12000 or image.size[1] > 12000: if image.size[0] > 12000 or image.size[1] > 12000:
raise serializers.ValidationError( raise serializers.ValidationError("Image dimensions too large. Maximum is 12,000x12,000 pixels.")
"Image dimensions too large. Maximum is 12,000x12,000 pixels.")
# Check if it's a supported format # Check if it's a supported format
if image.format not in ['JPEG', 'PNG', 'GIF', 'WEBP']: if image.format not in ["JPEG", "PNG", "GIF", "WEBP"]:
raise serializers.ValidationError( raise serializers.ValidationError(
f"Unsupported image format: {image.format}. Supported formats: JPEG, PNG, GIF, WebP.") f"Unsupported image format: {image.format}. Supported formats: JPEG, PNG, GIF, WebP."
)
except serializers.ValidationError: except serializers.ValidationError:
raise # Re-raise validation errors raise # Re-raise validation errors

View File

@@ -97,7 +97,7 @@ class LoginInputSerializer(serializers.Serializer):
password=password, password=password,
) )
if not user: if not user: # noqa: SIM102
# Try email-based authentication if username failed # Try email-based authentication if username failed
if "@" in username: if "@" in username:
try: try:
@@ -138,7 +138,7 @@ class LoginInputSerializer(serializers.Serializer):
"first_name": "John", "first_name": "John",
"last_name": "Doe", "last_name": "Doe",
}, },
"message": "Login successful", "detail": "Login successful",
}, },
) )
] ]
@@ -213,7 +213,7 @@ class SignupInputSerializer(serializers.ModelSerializer):
try: try:
validate_password(value) validate_password(value)
except DjangoValidationError as e: except DjangoValidationError as e:
raise serializers.ValidationError(list(e.messages)) raise serializers.ValidationError(list(e.messages)) from None
return value return value
def validate(self, attrs): def validate(self, attrs):
@@ -253,7 +253,7 @@ class SignupInputSerializer(serializers.ModelSerializer):
"first_name": "Jane", "first_name": "Jane",
"last_name": "Smith", "last_name": "Smith",
}, },
"message": "Registration successful", "detail": "Registration successful",
}, },
) )
] ]
@@ -276,7 +276,7 @@ class SignupOutputSerializer(serializers.Serializer):
summary="Example logout response", summary="Example logout response",
description="Successful logout response", description="Successful logout response",
value={ value={
"message": "Logout successful", "detail": "Logout successful",
}, },
) )
] ]
@@ -318,9 +318,9 @@ class PasswordResetInputSerializer(serializers.Serializer):
"""Send password reset email.""" """Send password reset email."""
email = self.validated_data["email"] # type: ignore[index] email = self.validated_data["email"] # type: ignore[index]
try: try:
_user = UserModel.objects.get(email=email) # Check if email exists (but don't reveal the result for security)
UserModel.objects.get(email=email)
# Here you would typically send a password reset email # Here you would typically send a password reset email
# For now, we'll just pass
pass pass
except UserModel.DoesNotExist: except UserModel.DoesNotExist:
# Don't reveal if email exists for security # Don't reveal if email exists for security
@@ -393,7 +393,7 @@ class PasswordChangeInputSerializer(serializers.Serializer):
try: try:
validate_password(value, user=self.context["request"].user) validate_password(value, user=self.context["request"].user)
except DjangoValidationError as e: except DjangoValidationError as e:
raise serializers.ValidationError(list(e.messages)) raise serializers.ValidationError(list(e.messages)) from None
return value return value
def validate(self, attrs): def validate(self, attrs):
@@ -492,6 +492,4 @@ class AuthStatusOutputSerializer(serializers.Serializer):
"""Output serializer for authentication status.""" """Output serializer for authentication status."""
authenticated = serializers.BooleanField(help_text="Whether user is authenticated") authenticated = serializers.BooleanField(help_text="Whether user is authenticated")
user = UserOutputSerializer( user = UserOutputSerializer(allow_null=True, help_text="User information if authenticated")
allow_null=True, help_text="User information if authenticated"
)

View File

@@ -112,10 +112,7 @@ class RideModelDetailOutputSerializer(serializers.Serializer):
id = serializers.IntegerField() id = serializers.IntegerField()
name = serializers.CharField() name = serializers.CharField()
description = serializers.CharField() description = serializers.CharField()
category = RichChoiceFieldSerializer( category = RichChoiceFieldSerializer(choice_group="categories", domain="rides")
choice_group="categories",
domain="rides"
)
# Manufacturer info # Manufacturer info
manufacturer = serializers.SerializerMethodField() manufacturer = serializers.SerializerMethodField()

View File

@@ -99,9 +99,7 @@ class ParkHistoryOutputSerializer(serializers.Serializer):
"slug": park.slug, "slug": park.slug,
"status": park.status, "status": park.status,
"opening_date": ( "opening_date": (
park.opening_date.isoformat() park.opening_date.isoformat() if hasattr(park, "opening_date") and park.opening_date else None
if hasattr(park, "opening_date") and park.opening_date
else None
), ),
"coaster_count": getattr(park, "coaster_count", 0), "coaster_count": getattr(park, "coaster_count", 0),
"ride_count": getattr(park, "ride_count", 0), "ride_count": getattr(park, "ride_count", 0),
@@ -143,9 +141,7 @@ class RideHistoryOutputSerializer(serializers.Serializer):
"park_name": ride.park.name if hasattr(ride, "park") else None, "park_name": ride.park.name if hasattr(ride, "park") else None,
"status": getattr(ride, "status", "UNKNOWN"), "status": getattr(ride, "status", "UNKNOWN"),
"opening_date": ( "opening_date": (
ride.opening_date.isoformat() ride.opening_date.isoformat() if hasattr(ride, "opening_date") and ride.opening_date else None
if hasattr(ride, "opening_date") and ride.opening_date
else None
), ),
"ride_type": getattr(ride, "ride_type", "Unknown"), "ride_type": getattr(ride, "ride_type", "Unknown"),
} }

View File

@@ -79,16 +79,12 @@ class MapLocationSerializer(serializers.Serializer):
return { return {
"coaster_count": obj.coaster_count or 0, "coaster_count": obj.coaster_count or 0,
"ride_count": obj.ride_count or 0, "ride_count": obj.ride_count or 0,
"average_rating": ( "average_rating": (float(obj.average_rating) if obj.average_rating else None),
float(obj.average_rating) if obj.average_rating else None
),
} }
elif obj._meta.model_name == "ride": elif obj._meta.model_name == "ride":
return { return {
"category": obj.get_category_display() if obj.category else None, "category": obj.get_category_display() if obj.category else None,
"average_rating": ( "average_rating": (float(obj.average_rating) if obj.average_rating else None),
float(obj.average_rating) if obj.average_rating else None
),
"park_name": obj.park.name if obj.park else None, "park_name": obj.park.name if obj.park else None,
} }
return {} return {}
@@ -339,24 +335,16 @@ class MapLocationDetailSerializer(serializers.Serializer):
return { return {
"coaster_count": obj.coaster_count or 0, "coaster_count": obj.coaster_count or 0,
"ride_count": obj.ride_count or 0, "ride_count": obj.ride_count or 0,
"average_rating": ( "average_rating": (float(obj.average_rating) if obj.average_rating else None),
float(obj.average_rating) if obj.average_rating else None
),
"size_acres": float(obj.size_acres) if obj.size_acres else None, "size_acres": float(obj.size_acres) if obj.size_acres else None,
"opening_date": ( "opening_date": (obj.opening_date.isoformat() if obj.opening_date else None),
obj.opening_date.isoformat() if obj.opening_date else None
),
} }
elif obj._meta.model_name == "ride": elif obj._meta.model_name == "ride":
return { return {
"category": obj.get_category_display() if obj.category else None, "category": obj.get_category_display() if obj.category else None,
"average_rating": ( "average_rating": (float(obj.average_rating) if obj.average_rating else None),
float(obj.average_rating) if obj.average_rating else None
),
"park_name": obj.park.name if obj.park else None, "park_name": obj.park.name if obj.park else None,
"opening_date": ( "opening_date": (obj.opening_date.isoformat() if obj.opening_date else None),
obj.opening_date.isoformat() if obj.opening_date else None
),
"manufacturer": obj.manufacturer.name if obj.manufacturer else None, "manufacturer": obj.manufacturer.name if obj.manufacturer else None,
} }
return {} return {}
@@ -382,9 +370,7 @@ class MapBoundsInputSerializer(serializers.Serializer):
def validate(self, attrs): def validate(self, attrs):
"""Validate that bounds make geographic sense.""" """Validate that bounds make geographic sense."""
if attrs["north"] <= attrs["south"]: if attrs["north"] <= attrs["south"]:
raise serializers.ValidationError( raise serializers.ValidationError("North bound must be greater than south bound")
"North bound must be greater than south bound"
)
# Handle longitude wraparound (e.g., crossing the international date line) # Handle longitude wraparound (e.g., crossing the international date line)
# For now, we'll require west < east for simplicity # For now, we'll require west < east for simplicity

View File

@@ -31,9 +31,7 @@ class PhotoUploadInputSerializer(serializers.Serializer):
allow_blank=True, allow_blank=True,
help_text="Alt text for accessibility", help_text="Alt text for accessibility",
) )
is_primary = serializers.BooleanField( is_primary = serializers.BooleanField(default=False, help_text="Whether this should be the primary photo")
default=False, help_text="Whether this should be the primary photo"
)
@extend_schema_serializer( @extend_schema_serializer(
@@ -89,9 +87,7 @@ class PhotoDetailOutputSerializer(serializers.Serializer):
return { return {
"id": obj.uploaded_by.id, "id": obj.uploaded_by.id,
"username": obj.uploaded_by.username, "username": obj.uploaded_by.username,
"display_name": getattr( "display_name": getattr(obj.uploaded_by, "get_display_name", lambda: obj.uploaded_by.username)(),
obj.uploaded_by, "get_display_name", lambda: obj.uploaded_by.username
)(),
} }

View File

@@ -24,12 +24,8 @@ class ParkStatsOutputSerializer(serializers.Serializer):
under_construction = serializers.IntegerField() under_construction = serializers.IntegerField()
# Averages # Averages
average_rating = serializers.DecimalField( average_rating = serializers.DecimalField(max_digits=3, decimal_places=2, allow_null=True)
max_digits=3, decimal_places=2, allow_null=True average_coaster_count = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True)
)
average_coaster_count = serializers.DecimalField(
max_digits=5, decimal_places=2, allow_null=True
)
# Top countries # Top countries
top_countries = serializers.ListField(child=serializers.DictField()) top_countries = serializers.ListField(child=serializers.DictField())
@@ -50,12 +46,8 @@ class RideStatsOutputSerializer(serializers.Serializer):
rides_by_category = serializers.DictField() rides_by_category = serializers.DictField()
# Averages # Averages
average_rating = serializers.DecimalField( average_rating = serializers.DecimalField(max_digits=3, decimal_places=2, allow_null=True)
max_digits=3, decimal_places=2, allow_null=True average_capacity = serializers.DecimalField(max_digits=8, decimal_places=2, allow_null=True)
)
average_capacity = serializers.DecimalField(
max_digits=8, decimal_places=2, allow_null=True
)
# Top manufacturers # Top manufacturers
top_manufacturers = serializers.ListField(child=serializers.DictField()) top_manufacturers = serializers.ListField(child=serializers.DictField())
@@ -91,10 +83,7 @@ class ParkReviewOutputSerializer(serializers.Serializer):
class HealthCheckOutputSerializer(serializers.Serializer): class HealthCheckOutputSerializer(serializers.Serializer):
"""Output serializer for health check responses.""" """Output serializer for health check responses."""
status = RichChoiceFieldSerializer( status = RichChoiceFieldSerializer(choice_group="health_statuses", domain="core")
choice_group="health_statuses",
domain="core"
)
timestamp = serializers.DateTimeField() timestamp = serializers.DateTimeField()
version = serializers.CharField() version = serializers.CharField()
environment = serializers.CharField() environment = serializers.CharField()
@@ -115,9 +104,6 @@ class PerformanceMetricsOutputSerializer(serializers.Serializer):
class SimpleHealthOutputSerializer(serializers.Serializer): class SimpleHealthOutputSerializer(serializers.Serializer):
"""Output serializer for simple health check.""" """Output serializer for simple health check."""
status = RichChoiceFieldSerializer( status = RichChoiceFieldSerializer(choice_group="simple_health_statuses", domain="core")
choice_group="simple_health_statuses",
domain="core"
)
timestamp = serializers.DateTimeField() timestamp = serializers.DateTimeField()
error = serializers.CharField(required=False) error = serializers.CharField(required=False)

View File

@@ -29,14 +29,10 @@ from apps.parks.models.reviews import ParkReview
"user": { "user": {
"username": "park_fan", "username": "park_fan",
"display_name": "Park Fan", "display_name": "Park Fan",
"avatar_url": "https://example.com/avatar.jpg" "avatar_url": "https://example.com/avatar.jpg",
}, },
"park": { "park": {"id": 101, "name": "Cedar Point", "slug": "cedar-point"},
"id": 101, },
"name": "Cedar Point",
"slug": "cedar-point"
}
}
) )
] ]
) )
@@ -145,8 +141,7 @@ class ParkReviewStatsOutputSerializer(serializers.Serializer):
pending_reviews = serializers.IntegerField() pending_reviews = serializers.IntegerField()
average_rating = serializers.FloatField(allow_null=True) average_rating = serializers.FloatField(allow_null=True)
rating_distribution = serializers.DictField( rating_distribution = serializers.DictField(
child=serializers.IntegerField(), child=serializers.IntegerField(), help_text="Count of reviews by rating (1-10)"
help_text="Count of reviews by rating (1-10)"
) )
recent_reviews = serializers.IntegerField() recent_reviews = serializers.IntegerField()
@@ -154,20 +149,15 @@ class ParkReviewStatsOutputSerializer(serializers.Serializer):
class ParkReviewModerationInputSerializer(serializers.Serializer): class ParkReviewModerationInputSerializer(serializers.Serializer):
"""Input serializer for review moderation operations.""" """Input serializer for review moderation operations."""
review_ids = serializers.ListField( review_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of review IDs to moderate")
child=serializers.IntegerField(),
help_text="List of review IDs to moderate"
)
action = serializers.ChoiceField( action = serializers.ChoiceField(
choices=[ choices=[
("publish", "Publish"), ("publish", "Publish"),
("unpublish", "Unpublish"), ("unpublish", "Unpublish"),
("delete", "Delete"), ("delete", "Delete"),
], ],
help_text="Moderation action to perform" help_text="Moderation action to perform",
) )
moderation_notes = serializers.CharField( moderation_notes = serializers.CharField(
required=False, required=False, allow_blank=True, help_text="Optional notes about the moderation action"
allow_blank=True,
help_text="Optional notes about the moderation action"
) )

View File

@@ -52,16 +52,11 @@ class ParkListOutputSerializer(serializers.Serializer):
id = serializers.IntegerField() id = serializers.IntegerField()
name = serializers.CharField() name = serializers.CharField()
slug = serializers.CharField() slug = serializers.CharField()
status = RichChoiceFieldSerializer( status = RichChoiceFieldSerializer(choice_group="statuses", domain="parks")
choice_group="statuses",
domain="parks"
)
description = serializers.CharField() description = serializers.CharField()
# Statistics # Statistics
average_rating = serializers.DecimalField( average_rating = serializers.DecimalField(max_digits=3, decimal_places=2, allow_null=True)
max_digits=3, decimal_places=2, allow_null=True
)
coaster_count = serializers.IntegerField(allow_null=True) coaster_count = serializers.IntegerField(allow_null=True)
ride_count = serializers.IntegerField(allow_null=True) ride_count = serializers.IntegerField(allow_null=True)
@@ -145,25 +140,18 @@ class ParkDetailOutputSerializer(serializers.Serializer):
id = serializers.IntegerField() id = serializers.IntegerField()
name = serializers.CharField() name = serializers.CharField()
slug = serializers.CharField() slug = serializers.CharField()
status = RichChoiceFieldSerializer( status = RichChoiceFieldSerializer(choice_group="statuses", domain="parks")
choice_group="statuses",
domain="parks"
)
description = serializers.CharField() description = serializers.CharField()
# Details # Details
opening_date = serializers.DateField(allow_null=True) opening_date = serializers.DateField(allow_null=True)
closing_date = serializers.DateField(allow_null=True) closing_date = serializers.DateField(allow_null=True)
operating_season = serializers.CharField() operating_season = serializers.CharField()
size_acres = serializers.DecimalField( size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, allow_null=True)
max_digits=10, decimal_places=2, allow_null=True
)
website = serializers.URLField() website = serializers.URLField()
# Statistics # Statistics
average_rating = serializers.DecimalField( average_rating = serializers.DecimalField(max_digits=3, decimal_places=2, allow_null=True)
max_digits=3, decimal_places=2, allow_null=True
)
coaster_count = serializers.IntegerField(allow_null=True) coaster_count = serializers.IntegerField(allow_null=True)
ride_count = serializers.IntegerField(allow_null=True) ride_count = serializers.IntegerField(allow_null=True)
@@ -211,9 +199,7 @@ class ParkDetailOutputSerializer(serializers.Serializer):
"""Get all approved photos for this park.""" """Get all approved photos for this park."""
from apps.parks.models import ParkPhoto from apps.parks.models import ParkPhoto
photos = ParkPhoto.objects.filter(park=obj, is_approved=True).order_by( photos = ParkPhoto.objects.filter(park=obj, is_approved=True).order_by("-is_primary", "-created_at")[
"-is_primary", "-created_at"
)[
:10 :10
] # Limit to 10 photos ] # Limit to 10 photos
@@ -228,7 +214,9 @@ class ParkDetailOutputSerializer(serializers.Serializer):
"public": MediaURLService.get_cloudflare_url_with_fallback(photo.image, "public"), "public": MediaURLService.get_cloudflare_url_with_fallback(photo.image, "public"),
}, },
"friendly_urls": { "friendly_urls": {
"thumbnail": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "thumbnail"), "thumbnail": MediaURLService.generate_park_photo_url(
obj.slug, photo.caption, photo.pk, "thumbnail"
),
"medium": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "medium"), "medium": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "medium"),
"large": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "large"), "large": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "large"),
"public": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "public"), "public": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "public"),
@@ -246,9 +234,7 @@ class ParkDetailOutputSerializer(serializers.Serializer):
from apps.parks.models import ParkPhoto from apps.parks.models import ParkPhoto
try: try:
photo = ParkPhoto.objects.filter( photo = ParkPhoto.objects.filter(park=obj, is_primary=True, is_approved=True).first()
park=obj, is_primary=True, is_approved=True
).first()
if photo and photo.image: if photo and photo.image:
return { return {
@@ -261,7 +247,9 @@ class ParkDetailOutputSerializer(serializers.Serializer):
"public": MediaURLService.get_cloudflare_url_with_fallback(photo.image, "public"), "public": MediaURLService.get_cloudflare_url_with_fallback(photo.image, "public"),
}, },
"friendly_urls": { "friendly_urls": {
"thumbnail": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "thumbnail"), "thumbnail": MediaURLService.generate_park_photo_url(
obj.slug, photo.caption, photo.pk, "thumbnail"
),
"medium": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "medium"), "medium": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "medium"),
"large": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "large"), "large": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "large"),
"public": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "public"), "public": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "public"),
@@ -289,10 +277,18 @@ class ParkDetailOutputSerializer(serializers.Serializer):
"public": MediaURLService.get_cloudflare_url_with_fallback(obj.banner_image.image, "public"), "public": MediaURLService.get_cloudflare_url_with_fallback(obj.banner_image.image, "public"),
}, },
"friendly_urls": { "friendly_urls": {
"thumbnail": MediaURLService.generate_park_photo_url(obj.slug, obj.banner_image.caption, obj.banner_image.pk, "thumbnail"), "thumbnail": MediaURLService.generate_park_photo_url(
"medium": MediaURLService.generate_park_photo_url(obj.slug, obj.banner_image.caption, obj.banner_image.pk, "medium"), obj.slug, obj.banner_image.caption, obj.banner_image.pk, "thumbnail"
"large": MediaURLService.generate_park_photo_url(obj.slug, obj.banner_image.caption, obj.banner_image.pk, "large"), ),
"public": MediaURLService.generate_park_photo_url(obj.slug, obj.banner_image.caption, obj.banner_image.pk, "public"), "medium": MediaURLService.generate_park_photo_url(
obj.slug, obj.banner_image.caption, obj.banner_image.pk, "medium"
),
"large": MediaURLService.generate_park_photo_url(
obj.slug, obj.banner_image.caption, obj.banner_image.pk, "large"
),
"public": MediaURLService.generate_park_photo_url(
obj.slug, obj.banner_image.caption, obj.banner_image.pk, "public"
),
}, },
"caption": obj.banner_image.caption, "caption": obj.banner_image.caption,
"alt_text": obj.banner_image.alt_text, "alt_text": obj.banner_image.alt_text,
@@ -303,9 +299,7 @@ class ParkDetailOutputSerializer(serializers.Serializer):
try: try:
latest_photo = ( latest_photo = (
ParkPhoto.objects.filter( ParkPhoto.objects.filter(park=obj, is_approved=True, image__isnull=False)
park=obj, is_approved=True, image__isnull=False
)
.order_by("-created_at") .order_by("-created_at")
.first() .first()
) )
@@ -321,10 +315,18 @@ class ParkDetailOutputSerializer(serializers.Serializer):
"public": MediaURLService.get_cloudflare_url_with_fallback(latest_photo.image, "public"), "public": MediaURLService.get_cloudflare_url_with_fallback(latest_photo.image, "public"),
}, },
"friendly_urls": { "friendly_urls": {
"thumbnail": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "thumbnail"), "thumbnail": MediaURLService.generate_park_photo_url(
"medium": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "medium"), obj.slug, latest_photo.caption, latest_photo.pk, "thumbnail"
"large": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "large"), ),
"public": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "public"), "medium": MediaURLService.generate_park_photo_url(
obj.slug, latest_photo.caption, latest_photo.pk, "medium"
),
"large": MediaURLService.generate_park_photo_url(
obj.slug, latest_photo.caption, latest_photo.pk, "large"
),
"public": MediaURLService.generate_park_photo_url(
obj.slug, latest_photo.caption, latest_photo.pk, "public"
),
}, },
"caption": latest_photo.caption, "caption": latest_photo.caption,
"alt_text": latest_photo.alt_text, "alt_text": latest_photo.alt_text,
@@ -350,10 +352,18 @@ class ParkDetailOutputSerializer(serializers.Serializer):
"public": MediaURLService.get_cloudflare_url_with_fallback(obj.card_image.image, "public"), "public": MediaURLService.get_cloudflare_url_with_fallback(obj.card_image.image, "public"),
}, },
"friendly_urls": { "friendly_urls": {
"thumbnail": MediaURLService.generate_park_photo_url(obj.slug, obj.card_image.caption, obj.card_image.pk, "thumbnail"), "thumbnail": MediaURLService.generate_park_photo_url(
"medium": MediaURLService.generate_park_photo_url(obj.slug, obj.card_image.caption, obj.card_image.pk, "medium"), obj.slug, obj.card_image.caption, obj.card_image.pk, "thumbnail"
"large": MediaURLService.generate_park_photo_url(obj.slug, obj.card_image.caption, obj.card_image.pk, "large"), ),
"public": MediaURLService.generate_park_photo_url(obj.slug, obj.card_image.caption, obj.card_image.pk, "public"), "medium": MediaURLService.generate_park_photo_url(
obj.slug, obj.card_image.caption, obj.card_image.pk, "medium"
),
"large": MediaURLService.generate_park_photo_url(
obj.slug, obj.card_image.caption, obj.card_image.pk, "large"
),
"public": MediaURLService.generate_park_photo_url(
obj.slug, obj.card_image.caption, obj.card_image.pk, "public"
),
}, },
"caption": obj.card_image.caption, "caption": obj.card_image.caption,
"alt_text": obj.card_image.alt_text, "alt_text": obj.card_image.alt_text,
@@ -364,9 +374,7 @@ class ParkDetailOutputSerializer(serializers.Serializer):
try: try:
latest_photo = ( latest_photo = (
ParkPhoto.objects.filter( ParkPhoto.objects.filter(park=obj, is_approved=True, image__isnull=False)
park=obj, is_approved=True, image__isnull=False
)
.order_by("-created_at") .order_by("-created_at")
.first() .first()
) )
@@ -382,10 +390,18 @@ class ParkDetailOutputSerializer(serializers.Serializer):
"public": MediaURLService.get_cloudflare_url_with_fallback(latest_photo.image, "public"), "public": MediaURLService.get_cloudflare_url_with_fallback(latest_photo.image, "public"),
}, },
"friendly_urls": { "friendly_urls": {
"thumbnail": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "thumbnail"), "thumbnail": MediaURLService.generate_park_photo_url(
"medium": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "medium"), obj.slug, latest_photo.caption, latest_photo.pk, "thumbnail"
"large": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "large"), ),
"public": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "public"), "medium": MediaURLService.generate_park_photo_url(
obj.slug, latest_photo.caption, latest_photo.pk, "medium"
),
"large": MediaURLService.generate_park_photo_url(
obj.slug, latest_photo.caption, latest_photo.pk, "large"
),
"public": MediaURLService.generate_park_photo_url(
obj.slug, latest_photo.caption, latest_photo.pk, "public"
),
}, },
"caption": latest_photo.caption, "caption": latest_photo.caption,
"alt_text": latest_photo.alt_text, "alt_text": latest_photo.alt_text,
@@ -417,7 +433,7 @@ class ParkImageSettingsInputSerializer(serializers.Serializer):
# The park will be validated in the view # The park will be validated in the view
return value return value
except ParkPhoto.DoesNotExist: except ParkPhoto.DoesNotExist:
raise serializers.ValidationError("Photo not found") raise serializers.ValidationError("Photo not found") from None
return value return value
def validate_card_image_id(self, value): def validate_card_image_id(self, value):
@@ -430,7 +446,7 @@ class ParkImageSettingsInputSerializer(serializers.Serializer):
# The park will be validated in the view # The park will be validated in the view
return value return value
except ParkPhoto.DoesNotExist: except ParkPhoto.DoesNotExist:
raise serializers.ValidationError("Photo not found") raise serializers.ValidationError("Photo not found") from None
return value return value
@@ -439,19 +455,13 @@ class ParkCreateInputSerializer(serializers.Serializer):
name = serializers.CharField(max_length=255) name = serializers.CharField(max_length=255)
description = serializers.CharField(allow_blank=True, default="") description = serializers.CharField(allow_blank=True, default="")
status = serializers.ChoiceField( status = serializers.ChoiceField(choices=ModelChoices.get_park_status_choices(), default="OPERATING")
choices=ModelChoices.get_park_status_choices(), default="OPERATING"
)
# Optional details # Optional details
opening_date = serializers.DateField(required=False, allow_null=True) opening_date = serializers.DateField(required=False, allow_null=True)
closing_date = serializers.DateField(required=False, allow_null=True) closing_date = serializers.DateField(required=False, allow_null=True)
operating_season = serializers.CharField( operating_season = serializers.CharField(max_length=255, required=False, allow_blank=True)
max_length=255, required=False, allow_blank=True size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, required=False, allow_null=True)
)
size_acres = serializers.DecimalField(
max_digits=10, decimal_places=2, required=False, allow_null=True
)
website = serializers.URLField(required=False, allow_blank=True) website = serializers.URLField(required=False, allow_blank=True)
# Required operator # Required operator
@@ -466,9 +476,7 @@ class ParkCreateInputSerializer(serializers.Serializer):
closing_date = attrs.get("closing_date") closing_date = attrs.get("closing_date")
if opening_date and closing_date and closing_date < opening_date: if opening_date and closing_date and closing_date < opening_date:
raise serializers.ValidationError( raise serializers.ValidationError("Closing date cannot be before opening date")
"Closing date cannot be before opening date"
)
return attrs return attrs
@@ -478,19 +486,13 @@ class ParkUpdateInputSerializer(serializers.Serializer):
name = serializers.CharField(max_length=255, required=False) name = serializers.CharField(max_length=255, required=False)
description = serializers.CharField(allow_blank=True, required=False) description = serializers.CharField(allow_blank=True, required=False)
status = serializers.ChoiceField( status = serializers.ChoiceField(choices=ModelChoices.get_park_status_choices(), required=False)
choices=ModelChoices.get_park_status_choices(), required=False
)
# Optional details # Optional details
opening_date = serializers.DateField(required=False, allow_null=True) opening_date = serializers.DateField(required=False, allow_null=True)
closing_date = serializers.DateField(required=False, allow_null=True) closing_date = serializers.DateField(required=False, allow_null=True)
operating_season = serializers.CharField( operating_season = serializers.CharField(max_length=255, required=False, allow_blank=True)
max_length=255, required=False, allow_blank=True size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, required=False, allow_null=True)
)
size_acres = serializers.DecimalField(
max_digits=10, decimal_places=2, required=False, allow_null=True
)
website = serializers.URLField(required=False, allow_blank=True) website = serializers.URLField(required=False, allow_blank=True)
# Companies # Companies
@@ -503,9 +505,7 @@ class ParkUpdateInputSerializer(serializers.Serializer):
closing_date = attrs.get("closing_date") closing_date = attrs.get("closing_date")
if opening_date and closing_date and closing_date < opening_date: if opening_date and closing_date and closing_date < opening_date:
raise serializers.ValidationError( raise serializers.ValidationError("Closing date cannot be before opening date")
"Closing date cannot be before opening date"
)
return attrs return attrs
@@ -537,12 +537,8 @@ class ParkFilterInputSerializer(serializers.Serializer):
) )
# Size filter # Size filter
min_size_acres = serializers.DecimalField( min_size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, required=False, min_value=0)
max_digits=10, decimal_places=2, required=False, min_value=0 max_size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, required=False, min_value=0)
)
max_size_acres = serializers.DecimalField(
max_digits=10, decimal_places=2, required=False, min_value=0
)
# Company filters # Company filters
operator_id = serializers.IntegerField(required=False) operator_id = serializers.IntegerField(required=False)
@@ -625,9 +621,7 @@ class ParkAreaCreateInputSerializer(serializers.Serializer):
closing_date = attrs.get("closing_date") closing_date = attrs.get("closing_date")
if opening_date and closing_date and closing_date < opening_date: if opening_date and closing_date and closing_date < opening_date:
raise serializers.ValidationError( raise serializers.ValidationError("Closing date cannot be before opening date")
"Closing date cannot be before opening date"
)
return attrs return attrs
@@ -646,9 +640,7 @@ class ParkAreaUpdateInputSerializer(serializers.Serializer):
closing_date = attrs.get("closing_date") closing_date = attrs.get("closing_date")
if opening_date and closing_date and closing_date < opening_date: if opening_date and closing_date and closing_date < opening_date:
raise serializers.ValidationError( raise serializers.ValidationError("Closing date cannot be before opening date")
"Closing date cannot be before opening date"
)
return attrs return attrs

View File

@@ -12,9 +12,7 @@ from apps.parks.models import ParkPhoto
class ParkPhotoOutputSerializer(serializers.ModelSerializer): class ParkPhotoOutputSerializer(serializers.ModelSerializer):
"""Output serializer for park photos.""" """Output serializer for park photos."""
uploaded_by_username = serializers.CharField( uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True)
source="uploaded_by.username", read_only=True
)
file_size = serializers.ReadOnlyField() file_size = serializers.ReadOnlyField()
dimensions = serializers.ReadOnlyField() dimensions = serializers.ReadOnlyField()
park_slug = serializers.CharField(source="park.slug", read_only=True) park_slug = serializers.CharField(source="park.slug", read_only=True)
@@ -78,9 +76,7 @@ class ParkPhotoUpdateInputSerializer(serializers.ModelSerializer):
class ParkPhotoListOutputSerializer(serializers.ModelSerializer): class ParkPhotoListOutputSerializer(serializers.ModelSerializer):
"""Simplified output serializer for park photo lists.""" """Simplified output serializer for park photo lists."""
uploaded_by_username = serializers.CharField( uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True)
source="uploaded_by.username", read_only=True
)
class Meta: class Meta:
model = ParkPhoto model = ParkPhoto
@@ -99,12 +95,8 @@ class ParkPhotoListOutputSerializer(serializers.ModelSerializer):
class ParkPhotoApprovalInputSerializer(serializers.Serializer): class ParkPhotoApprovalInputSerializer(serializers.Serializer):
"""Input serializer for photo approval operations.""" """Input serializer for photo approval operations."""
photo_ids = serializers.ListField( photo_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of photo IDs to approve")
child=serializers.IntegerField(), help_text="List of photo IDs to approve" approve = serializers.BooleanField(default=True, help_text="Whether to approve (True) or reject (False) the photos")
)
approve = serializers.BooleanField(
default=True, help_text="Whether to approve (True) or reject (False) the photos"
)
class ParkPhotoStatsOutputSerializer(serializers.Serializer): class ParkPhotoStatsOutputSerializer(serializers.Serializer):

View File

@@ -8,35 +8,33 @@ from apps.rides.models.credits import RideCredit
class RideCreditSerializer(serializers.ModelSerializer): class RideCreditSerializer(serializers.ModelSerializer):
"""Serializer for user ride credits.""" """Serializer for user ride credits."""
ride_id = serializers.PrimaryKeyRelatedField( ride_id = serializers.PrimaryKeyRelatedField(queryset=Ride.objects.all(), source="ride", write_only=True)
queryset=Ride.objects.all(), source='ride', write_only=True
)
ride = RideListOutputSerializer(read_only=True) ride = RideListOutputSerializer(read_only=True)
class Meta: class Meta:
model = RideCredit model = RideCredit
fields = [ fields = [
'id', "id",
'ride', "ride",
'ride_id', "ride_id",
'count', "count",
'rating', "rating",
'first_ridden_at', "first_ridden_at",
'last_ridden_at', "last_ridden_at",
'notes', "notes",
'display_order', "display_order",
'created_at', "created_at",
'updated_at', "updated_at",
] ]
read_only_fields = ['id', 'created_at', 'updated_at'] read_only_fields = ["id", "created_at", "updated_at"]
def validate(self, attrs): def validate(self, attrs):
""" """
Validate data. Validate data.
""" """
# Ensure dates make sense # Ensure dates make sense
first = attrs.get('first_ridden_at') first = attrs.get("first_ridden_at")
last = attrs.get('last_ridden_at') last = attrs.get("last_ridden_at")
if first and last and last < first: if first and last and last < first:
raise serializers.ValidationError("Last ridden date cannot be before first ridden date.") raise serializers.ValidationError("Last ridden date cannot be before first ridden date.")
@@ -44,6 +42,6 @@ class RideCreditSerializer(serializers.ModelSerializer):
def create(self, validated_data): def create(self, validated_data):
"""Create a new ride credit.""" """Create a new ride credit."""
user = self.context['request'].user user = self.context["request"].user
validated_data['user'] = user validated_data["user"] = user
return super().create(validated_data) return super().create(validated_data)

View File

@@ -80,18 +80,10 @@ class RideModelVariantOutputSerializer(serializers.Serializer):
id = serializers.IntegerField() id = serializers.IntegerField()
name = serializers.CharField() name = serializers.CharField()
description = serializers.CharField() description = serializers.CharField()
min_height_ft = serializers.DecimalField( min_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, allow_null=True)
max_digits=6, decimal_places=2, allow_null=True max_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, allow_null=True)
) min_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True)
max_height_ft = serializers.DecimalField( max_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True)
max_digits=6, decimal_places=2, allow_null=True
)
min_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, allow_null=True
)
max_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, allow_null=True
)
distinguishing_features = serializers.CharField() distinguishing_features = serializers.CharField()
@@ -134,20 +126,14 @@ class RideModelListOutputSerializer(serializers.Serializer):
id = serializers.IntegerField() id = serializers.IntegerField()
name = serializers.CharField() name = serializers.CharField()
slug = serializers.CharField() slug = serializers.CharField()
category = RichChoiceFieldSerializer( category = RichChoiceFieldSerializer(choice_group="categories", domain="rides")
choice_group="categories",
domain="rides"
)
description = serializers.CharField() description = serializers.CharField()
# Manufacturer info # Manufacturer info
manufacturer = RideModelManufacturerOutputSerializer(allow_null=True) manufacturer = RideModelManufacturerOutputSerializer(allow_null=True)
# Market info # Market info
target_market = RichChoiceFieldSerializer( target_market = RichChoiceFieldSerializer(choice_group="target_markets", domain="rides")
choice_group="target_markets",
domain="rides"
)
is_discontinued = serializers.BooleanField() is_discontinued = serializers.BooleanField()
total_installations = serializers.IntegerField() total_installations = serializers.IntegerField()
first_installation_year = serializers.IntegerField(allow_null=True) first_installation_year = serializers.IntegerField(allow_null=True)
@@ -258,18 +244,10 @@ class RideModelDetailOutputSerializer(serializers.Serializer):
manufacturer = RideModelManufacturerOutputSerializer(allow_null=True) manufacturer = RideModelManufacturerOutputSerializer(allow_null=True)
# Technical specifications # Technical specifications
typical_height_range_min_ft = serializers.DecimalField( typical_height_range_min_ft = serializers.DecimalField(max_digits=6, decimal_places=2, allow_null=True)
max_digits=6, decimal_places=2, allow_null=True typical_height_range_max_ft = serializers.DecimalField(max_digits=6, decimal_places=2, allow_null=True)
) typical_speed_range_min_mph = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True)
typical_height_range_max_ft = serializers.DecimalField( typical_speed_range_max_mph = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True)
max_digits=6, decimal_places=2, allow_null=True
)
typical_speed_range_min_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, allow_null=True
)
typical_speed_range_max_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, allow_null=True
)
typical_capacity_range_min = serializers.IntegerField(allow_null=True) typical_capacity_range_min = serializers.IntegerField(allow_null=True)
typical_capacity_range_max = serializers.IntegerField(allow_null=True) typical_capacity_range_max = serializers.IntegerField(allow_null=True)
@@ -343,9 +321,7 @@ class RideModelCreateInputSerializer(serializers.Serializer):
name = serializers.CharField(max_length=255) name = serializers.CharField(max_length=255)
description = serializers.CharField(allow_blank=True, default="") description = serializers.CharField(allow_blank=True, default="")
category = serializers.ChoiceField( category = serializers.ChoiceField(choices=ModelChoices.get_ride_category_choices(), allow_blank=True, default="")
choices=ModelChoices.get_ride_category_choices(), allow_blank=True, default=""
)
# Required manufacturer # Required manufacturer
manufacturer_id = serializers.IntegerField() manufacturer_id = serializers.IntegerField()
@@ -363,32 +339,18 @@ class RideModelCreateInputSerializer(serializers.Serializer):
typical_speed_range_max_mph = serializers.DecimalField( typical_speed_range_max_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False, allow_null=True max_digits=5, decimal_places=2, required=False, allow_null=True
) )
typical_capacity_range_min = serializers.IntegerField( typical_capacity_range_min = serializers.IntegerField(required=False, allow_null=True, min_value=1)
required=False, allow_null=True, min_value=1 typical_capacity_range_max = serializers.IntegerField(required=False, allow_null=True, min_value=1)
)
typical_capacity_range_max = serializers.IntegerField(
required=False, allow_null=True, min_value=1
)
# Design characteristics # Design characteristics
track_type = serializers.CharField(max_length=100, allow_blank=True, default="") track_type = serializers.CharField(max_length=100, allow_blank=True, default="")
support_structure = serializers.CharField( support_structure = serializers.CharField(max_length=100, allow_blank=True, default="")
max_length=100, allow_blank=True, default="" train_configuration = serializers.CharField(max_length=200, allow_blank=True, default="")
) restraint_system = serializers.CharField(max_length=100, allow_blank=True, default="")
train_configuration = serializers.CharField(
max_length=200, allow_blank=True, default=""
)
restraint_system = serializers.CharField(
max_length=100, allow_blank=True, default=""
)
# Market information # Market information
first_installation_year = serializers.IntegerField( first_installation_year = serializers.IntegerField(required=False, allow_null=True, min_value=1800, max_value=2100)
required=False, allow_null=True, min_value=1800, max_value=2100 last_installation_year = serializers.IntegerField(required=False, allow_null=True, min_value=1800, max_value=2100)
)
last_installation_year = serializers.IntegerField(
required=False, allow_null=True, min_value=1800, max_value=2100
)
is_discontinued = serializers.BooleanField(default=False) is_discontinued = serializers.BooleanField(default=False)
# Design features # Design features
@@ -406,36 +368,28 @@ class RideModelCreateInputSerializer(serializers.Serializer):
max_height = attrs.get("typical_height_range_max_ft") max_height = attrs.get("typical_height_range_max_ft")
if min_height and max_height and min_height > max_height: if min_height and max_height and min_height > max_height:
raise serializers.ValidationError( raise serializers.ValidationError("Minimum height cannot be greater than maximum height")
"Minimum height cannot be greater than maximum height"
)
# Speed range validation # Speed range validation
min_speed = attrs.get("typical_speed_range_min_mph") min_speed = attrs.get("typical_speed_range_min_mph")
max_speed = attrs.get("typical_speed_range_max_mph") max_speed = attrs.get("typical_speed_range_max_mph")
if min_speed and max_speed and min_speed > max_speed: if min_speed and max_speed and min_speed > max_speed:
raise serializers.ValidationError( raise serializers.ValidationError("Minimum speed cannot be greater than maximum speed")
"Minimum speed cannot be greater than maximum speed"
)
# Capacity range validation # Capacity range validation
min_capacity = attrs.get("typical_capacity_range_min") min_capacity = attrs.get("typical_capacity_range_min")
max_capacity = attrs.get("typical_capacity_range_max") max_capacity = attrs.get("typical_capacity_range_max")
if min_capacity and max_capacity and min_capacity > max_capacity: if min_capacity and max_capacity and min_capacity > max_capacity:
raise serializers.ValidationError( raise serializers.ValidationError("Minimum capacity cannot be greater than maximum capacity")
"Minimum capacity cannot be greater than maximum capacity"
)
# Installation years validation # Installation years validation
first_year = attrs.get("first_installation_year") first_year = attrs.get("first_installation_year")
last_year = attrs.get("last_installation_year") last_year = attrs.get("last_installation_year")
if first_year and last_year and first_year > last_year: if first_year and last_year and first_year > last_year:
raise serializers.ValidationError( raise serializers.ValidationError("First installation year cannot be after last installation year")
"First installation year cannot be after last installation year"
)
return attrs return attrs
@@ -467,32 +421,18 @@ class RideModelUpdateInputSerializer(serializers.Serializer):
typical_speed_range_max_mph = serializers.DecimalField( typical_speed_range_max_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False, allow_null=True max_digits=5, decimal_places=2, required=False, allow_null=True
) )
typical_capacity_range_min = serializers.IntegerField( typical_capacity_range_min = serializers.IntegerField(required=False, allow_null=True, min_value=1)
required=False, allow_null=True, min_value=1 typical_capacity_range_max = serializers.IntegerField(required=False, allow_null=True, min_value=1)
)
typical_capacity_range_max = serializers.IntegerField(
required=False, allow_null=True, min_value=1
)
# Design characteristics # Design characteristics
track_type = serializers.CharField(max_length=100, allow_blank=True, required=False) track_type = serializers.CharField(max_length=100, allow_blank=True, required=False)
support_structure = serializers.CharField( support_structure = serializers.CharField(max_length=100, allow_blank=True, required=False)
max_length=100, allow_blank=True, required=False train_configuration = serializers.CharField(max_length=200, allow_blank=True, required=False)
) restraint_system = serializers.CharField(max_length=100, allow_blank=True, required=False)
train_configuration = serializers.CharField(
max_length=200, allow_blank=True, required=False
)
restraint_system = serializers.CharField(
max_length=100, allow_blank=True, required=False
)
# Market information # Market information
first_installation_year = serializers.IntegerField( first_installation_year = serializers.IntegerField(required=False, allow_null=True, min_value=1800, max_value=2100)
required=False, allow_null=True, min_value=1800, max_value=2100 last_installation_year = serializers.IntegerField(required=False, allow_null=True, min_value=1800, max_value=2100)
)
last_installation_year = serializers.IntegerField(
required=False, allow_null=True, min_value=1800, max_value=2100
)
is_discontinued = serializers.BooleanField(required=False) is_discontinued = serializers.BooleanField(required=False)
# Design features # Design features
@@ -510,36 +450,28 @@ class RideModelUpdateInputSerializer(serializers.Serializer):
max_height = attrs.get("typical_height_range_max_ft") max_height = attrs.get("typical_height_range_max_ft")
if min_height and max_height and min_height > max_height: if min_height and max_height and min_height > max_height:
raise serializers.ValidationError( raise serializers.ValidationError("Minimum height cannot be greater than maximum height")
"Minimum height cannot be greater than maximum height"
)
# Speed range validation # Speed range validation
min_speed = attrs.get("typical_speed_range_min_mph") min_speed = attrs.get("typical_speed_range_min_mph")
max_speed = attrs.get("typical_speed_range_max_mph") max_speed = attrs.get("typical_speed_range_max_mph")
if min_speed and max_speed and min_speed > max_speed: if min_speed and max_speed and min_speed > max_speed:
raise serializers.ValidationError( raise serializers.ValidationError("Minimum speed cannot be greater than maximum speed")
"Minimum speed cannot be greater than maximum speed"
)
# Capacity range validation # Capacity range validation
min_capacity = attrs.get("typical_capacity_range_min") min_capacity = attrs.get("typical_capacity_range_min")
max_capacity = attrs.get("typical_capacity_range_max") max_capacity = attrs.get("typical_capacity_range_max")
if min_capacity and max_capacity and min_capacity > max_capacity: if min_capacity and max_capacity and min_capacity > max_capacity:
raise serializers.ValidationError( raise serializers.ValidationError("Minimum capacity cannot be greater than maximum capacity")
"Minimum capacity cannot be greater than maximum capacity"
)
# Installation years validation # Installation years validation
first_year = attrs.get("first_installation_year") first_year = attrs.get("first_installation_year")
last_year = attrs.get("last_installation_year") last_year = attrs.get("last_installation_year")
if first_year and last_year and first_year > last_year: if first_year and last_year and first_year > last_year:
raise serializers.ValidationError( raise serializers.ValidationError("First installation year cannot be after last installation year")
"First installation year cannot be after last installation year"
)
return attrs return attrs
@@ -551,9 +483,7 @@ class RideModelFilterInputSerializer(serializers.Serializer):
search = serializers.CharField(required=False, allow_blank=True) search = serializers.CharField(required=False, allow_blank=True)
# Category filter # Category filter
category = serializers.MultipleChoiceField( category = serializers.MultipleChoiceField(choices=ModelChoices.get_ride_category_choices(), required=False)
choices=ModelChoices.get_ride_category_choices(), required=False
)
# Manufacturer filter # Manufacturer filter
manufacturer_id = serializers.IntegerField(required=False) manufacturer_id = serializers.IntegerField(required=False)
@@ -576,20 +506,12 @@ class RideModelFilterInputSerializer(serializers.Serializer):
min_installations = serializers.IntegerField(required=False, min_value=0) min_installations = serializers.IntegerField(required=False, min_value=0)
# Height filters # Height filters
min_height_ft = serializers.DecimalField( min_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False)
max_digits=6, decimal_places=2, required=False max_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False)
)
max_height_ft = serializers.DecimalField(
max_digits=6, decimal_places=2, required=False
)
# Speed filters # Speed filters
min_speed_mph = serializers.DecimalField( min_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False)
max_digits=5, decimal_places=2, required=False max_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False)
)
max_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False
)
# Ordering # Ordering
ordering = serializers.ChoiceField( ordering = serializers.ChoiceField(
@@ -621,18 +543,10 @@ class RideModelVariantCreateInputSerializer(serializers.Serializer):
description = serializers.CharField(allow_blank=True, default="") description = serializers.CharField(allow_blank=True, default="")
# Variant-specific specifications # Variant-specific specifications
min_height_ft = serializers.DecimalField( min_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True)
max_digits=6, decimal_places=2, required=False, allow_null=True max_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True)
) min_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False, allow_null=True)
max_height_ft = serializers.DecimalField( max_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False, allow_null=True)
max_digits=6, decimal_places=2, required=False, allow_null=True
)
min_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False, allow_null=True
)
max_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False, allow_null=True
)
# Distinguishing features # Distinguishing features
distinguishing_features = serializers.CharField(allow_blank=True, default="") distinguishing_features = serializers.CharField(allow_blank=True, default="")
@@ -644,18 +558,14 @@ class RideModelVariantCreateInputSerializer(serializers.Serializer):
max_height = attrs.get("max_height_ft") max_height = attrs.get("max_height_ft")
if min_height and max_height and min_height > max_height: if min_height and max_height and min_height > max_height:
raise serializers.ValidationError( raise serializers.ValidationError("Minimum height cannot be greater than maximum height")
"Minimum height cannot be greater than maximum height"
)
# Speed range validation # Speed range validation
min_speed = attrs.get("min_speed_mph") min_speed = attrs.get("min_speed_mph")
max_speed = attrs.get("max_speed_mph") max_speed = attrs.get("max_speed_mph")
if min_speed and max_speed and min_speed > max_speed: if min_speed and max_speed and min_speed > max_speed:
raise serializers.ValidationError( raise serializers.ValidationError("Minimum speed cannot be greater than maximum speed")
"Minimum speed cannot be greater than maximum speed"
)
return attrs return attrs
@@ -667,18 +577,10 @@ class RideModelVariantUpdateInputSerializer(serializers.Serializer):
description = serializers.CharField(allow_blank=True, required=False) description = serializers.CharField(allow_blank=True, required=False)
# Variant-specific specifications # Variant-specific specifications
min_height_ft = serializers.DecimalField( min_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True)
max_digits=6, decimal_places=2, required=False, allow_null=True max_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True)
) min_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False, allow_null=True)
max_height_ft = serializers.DecimalField( max_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False, allow_null=True)
max_digits=6, decimal_places=2, required=False, allow_null=True
)
min_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False, allow_null=True
)
max_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False, allow_null=True
)
# Distinguishing features # Distinguishing features
distinguishing_features = serializers.CharField(allow_blank=True, required=False) distinguishing_features = serializers.CharField(allow_blank=True, required=False)
@@ -690,18 +592,14 @@ class RideModelVariantUpdateInputSerializer(serializers.Serializer):
max_height = attrs.get("max_height_ft") max_height = attrs.get("max_height_ft")
if min_height and max_height and min_height > max_height: if min_height and max_height and min_height > max_height:
raise serializers.ValidationError( raise serializers.ValidationError("Minimum height cannot be greater than maximum height")
"Minimum height cannot be greater than maximum height"
)
# Speed range validation # Speed range validation
min_speed = attrs.get("min_speed_mph") min_speed = attrs.get("min_speed_mph")
max_speed = attrs.get("max_speed_mph") max_speed = attrs.get("max_speed_mph")
if min_speed and max_speed and min_speed > max_speed: if min_speed and max_speed and min_speed > max_speed:
raise serializers.ValidationError( raise serializers.ValidationError("Minimum speed cannot be greater than maximum speed")
"Minimum speed cannot be greater than maximum speed"
)
return attrs return attrs
@@ -713,9 +611,7 @@ class RideModelTechnicalSpecCreateInputSerializer(serializers.Serializer):
"""Input serializer for creating ride model technical specifications.""" """Input serializer for creating ride model technical specifications."""
ride_model_id = serializers.IntegerField() ride_model_id = serializers.IntegerField()
spec_category = serializers.ChoiceField( spec_category = serializers.ChoiceField(choices=ModelChoices.get_technical_spec_category_choices())
choices=ModelChoices.get_technical_spec_category_choices()
)
spec_name = serializers.CharField(max_length=100) spec_name = serializers.CharField(max_length=100)
spec_value = serializers.CharField(max_length=255) spec_value = serializers.CharField(max_length=255)
spec_unit = serializers.CharField(max_length=20, allow_blank=True, default="") spec_unit = serializers.CharField(max_length=20, allow_blank=True, default="")
@@ -765,13 +661,9 @@ class RideModelPhotoUpdateInputSerializer(serializers.Serializer):
required=False, required=False,
) )
is_primary = serializers.BooleanField(required=False) is_primary = serializers.BooleanField(required=False)
photographer = serializers.CharField( photographer = serializers.CharField(max_length=255, allow_blank=True, required=False)
max_length=255, allow_blank=True, required=False
)
source = serializers.CharField(max_length=255, allow_blank=True, required=False) source = serializers.CharField(max_length=255, allow_blank=True, required=False)
copyright_info = serializers.CharField( copyright_info = serializers.CharField(max_length=255, allow_blank=True, required=False)
max_length=255, allow_blank=True, required=False
)
# === RIDE MODEL STATS SERIALIZERS === # === RIDE MODEL STATS SERIALIZERS ===
@@ -784,15 +676,9 @@ class RideModelStatsOutputSerializer(serializers.Serializer):
total_installations = serializers.IntegerField() total_installations = serializers.IntegerField()
active_manufacturers = serializers.IntegerField() active_manufacturers = serializers.IntegerField()
discontinued_models = serializers.IntegerField() discontinued_models = serializers.IntegerField()
by_category = serializers.DictField( by_category = serializers.DictField(child=serializers.IntegerField(), help_text="Model counts by category")
child=serializers.IntegerField(), help_text="Model counts by category"
)
by_target_market = serializers.DictField( by_target_market = serializers.DictField(
child=serializers.IntegerField(), help_text="Model counts by target market" child=serializers.IntegerField(), help_text="Model counts by target market"
) )
by_manufacturer = serializers.DictField( by_manufacturer = serializers.DictField(child=serializers.IntegerField(), help_text="Model counts by manufacturer")
child=serializers.IntegerField(), help_text="Model counts by manufacturer" recent_models = serializers.IntegerField(help_text="Models created in the last 30 days")
)
recent_models = serializers.IntegerField(
help_text="Models created in the last 30 days"
)

View File

@@ -54,19 +54,11 @@ class ReviewUserSerializer(serializers.ModelSerializer):
"id": 456, "id": 456,
"username": "coaster_fan", "username": "coaster_fan",
"display_name": "Coaster Fan", "display_name": "Coaster Fan",
"avatar_url": "https://example.com/avatar.jpg" "avatar_url": "https://example.com/avatar.jpg",
}, },
"ride": { "ride": {"id": 789, "name": "Steel Vengeance", "slug": "steel-vengeance"},
"id": 789, "park": {"id": 101, "name": "Cedar Point", "slug": "cedar-point"},
"name": "Steel Vengeance", },
"slug": "steel-vengeance"
},
"park": {
"id": 101,
"name": "Cedar Point",
"slug": "cedar-point"
}
}
) )
] ]
) )
@@ -191,8 +183,7 @@ class RideReviewStatsOutputSerializer(serializers.Serializer):
pending_reviews = serializers.IntegerField() pending_reviews = serializers.IntegerField()
average_rating = serializers.FloatField(allow_null=True) average_rating = serializers.FloatField(allow_null=True)
rating_distribution = serializers.DictField( rating_distribution = serializers.DictField(
child=serializers.IntegerField(), child=serializers.IntegerField(), help_text="Count of reviews by rating (1-10)"
help_text="Count of reviews by rating (1-10)"
) )
recent_reviews = serializers.IntegerField() recent_reviews = serializers.IntegerField()
@@ -200,20 +191,15 @@ class RideReviewStatsOutputSerializer(serializers.Serializer):
class RideReviewModerationInputSerializer(serializers.Serializer): class RideReviewModerationInputSerializer(serializers.Serializer):
"""Input serializer for review moderation operations.""" """Input serializer for review moderation operations."""
review_ids = serializers.ListField( review_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of review IDs to moderate")
child=serializers.IntegerField(),
help_text="List of review IDs to moderate"
)
action = serializers.ChoiceField( action = serializers.ChoiceField(
choices=[ choices=[
("publish", "Publish"), ("publish", "Publish"),
("unpublish", "Unpublish"), ("unpublish", "Unpublish"),
("delete", "Delete"), ("delete", "Delete"),
], ],
help_text="Moderation action to perform" help_text="Moderation action to perform",
) )
moderation_notes = serializers.CharField( moderation_notes = serializers.CharField(
required=False, required=False, allow_blank=True, help_text="Optional notes about the moderation action"
allow_blank=True,
help_text="Optional notes about the moderation action"
) )

Some files were not shown because too many files have changed in this diff Show More