From b243b17af7fd24e2dc76cc45544657b0dc0f858b Mon Sep 17 00:00:00 2001 From: pacnpal <183241239+pacnpal@users.noreply.github.com> Date: Thu, 1 Jan 2026 15:13:01 -0500 Subject: [PATCH] feat: Implement initial schema and add various API, service, and management command enhancements across the application. --- backend/.flake8 | 37 + backend/apps/accounts/__init__.py | 2 +- backend/apps/accounts/admin.py | 32 +- backend/apps/accounts/choices.py | 102 +- backend/apps/accounts/export_service.py | 37 +- backend/apps/accounts/login_history.py | 4 +- .../commands/check_all_social_tables.py | 12 +- .../management/commands/check_social_apps.py | 4 +- .../commands/cleanup_social_auth.py | 9 +- .../management/commands/cleanup_test_data.py | 12 +- .../management/commands/create_social_apps.py | 10 +- .../management/commands/create_test_users.py | 10 +- .../management/commands/delete_user.py | 58 +- .../commands/fix_migration_history.py | 11 +- .../management/commands/fix_social_apps.py | 4 +- .../commands/generate_letter_avatars.py | 4 +- .../management/commands/regenerate_avatars.py | 4 +- .../accounts/management/commands/reset_db.py | 14 +- .../management/commands/reset_social_apps.py | 4 +- .../management/commands/reset_social_auth.py | 12 +- .../management/commands/setup_groups.py | 4 +- .../management/commands/setup_site.py | 4 +- .../management/commands/setup_social_auth.py | 36 +- .../commands/setup_social_providers.py | 4 +- .../management/commands/test_discord_auth.py | 4 +- .../commands/update_social_apps_sites.py | 4 +- .../commands/verify_discord_settings.py | 8 +- .../apps/accounts/migrations/0001_initial.py | 20 +- ...t_passwordresetevent_userevent_and_more.py | 28 +- ...quest_userdeletionrequestevent_and_more.py | 28 +- ...sert_remove_user_update_update_and_more.py | 8 +- ...ce_notificationpreferenceevent_and_more.py | 12 +- .../migrations/0010_auto_20250830_1657.py | 34 +- ...0011_fix_userprofile_event_avatar_field.py | 9 +- .../migrations/0013_add_user_query_indexes.py | 16 +- ...er_remove_toplistitem_top_list_and_more.py | 309 ++---- backend/apps/accounts/mixins.py | 15 +- backend/apps/accounts/models.py | 211 +--- backend/apps/accounts/selectors.py | 61 +- backend/apps/accounts/serializers.py | 36 +- backend/apps/accounts/services.py | 140 +-- backend/apps/accounts/services/__init__.py | 2 +- .../accounts/services/notification_service.py | 28 +- .../services/social_provider_service.py | 88 +- .../services/user_deletion_service.py | 104 +- backend/apps/accounts/signals.py | 20 +- backend/apps/accounts/tests.py | 12 +- backend/apps/accounts/tests/test_admin.py | 3 - .../accounts/tests/test_model_constraints.py | 6 +- .../apps/accounts/tests/test_user_deletion.py | 18 +- backend/apps/accounts/views.py | 44 +- .../apps/api/management/commands/seed_data.py | 986 +++++++++++------- backend/apps/api/v1/accounts/serializers.py | 1 + backend/apps/api/v1/accounts/urls.py | 13 +- backend/apps/api/v1/accounts/views.py | 282 ++--- backend/apps/api/v1/accounts/views_credits.py | 69 +- .../apps/api/v1/accounts/views_magic_link.py | 129 +-- backend/apps/api/v1/auth/mfa.py | 95 +- backend/apps/api/v1/auth/serializers.py | 54 +- .../v1/auth/serializers_package/__init__.py | 18 +- .../api/v1/auth/serializers_package/social.py | 160 +-- backend/apps/api/v1/auth/urls.py | 7 - backend/apps/api/v1/auth/views.py | 174 ++-- backend/apps/api/v1/core/views.py | 39 +- backend/apps/api/v1/email/views.py | 8 +- backend/apps/api/v1/history/views.py | 52 +- backend/apps/api/v1/images/views.py | 16 +- backend/apps/api/v1/maps/views.py | 262 ++--- backend/apps/api/v1/middleware.py | 123 +-- backend/apps/api/v1/parks/history_views.py | 16 +- .../apps/api/v1/parks/park_reviews_views.py | 21 +- backend/apps/api/v1/parks/park_rides_views.py | 181 ++-- backend/apps/api/v1/parks/park_views.py | 793 +++++++------- .../apps/api/v1/parks/ride_photos_views.py | 119 +-- .../apps/api/v1/parks/ride_reviews_views.py | 54 +- backend/apps/api/v1/parks/serializers.py | 46 +- backend/apps/api/v1/parks/urls.py | 25 +- backend/apps/api/v1/parks/views.py | 122 +-- backend/apps/api/v1/responses.py | 167 +++ backend/apps/api/v1/rides/company_views.py | 7 +- .../apps/api/v1/rides/manufacturers/views.py | 381 +++---- backend/apps/api/v1/rides/photo_views.py | 132 +-- backend/apps/api/v1/rides/serializers.py | 83 +- backend/apps/api/v1/rides/urls.py | 3 - backend/apps/api/v1/rides/views.py | 131 +-- backend/apps/api/v1/serializers.py | 1 - backend/apps/api/v1/serializers/__init__.py | 1 - backend/apps/api/v1/serializers/accounts.py | 116 +-- backend/apps/api/v1/serializers/auth.py | 20 +- backend/apps/api/v1/serializers/companies.py | 5 +- backend/apps/api/v1/serializers/history.py | 8 +- backend/apps/api/v1/serializers/maps.py | 28 +- backend/apps/api/v1/serializers/media.py | 8 +- backend/apps/api/v1/serializers/other.py | 26 +- .../apps/api/v1/serializers/park_reviews.py | 24 +- backend/apps/api/v1/serializers/parks.py | 162 ++- .../apps/api/v1/serializers/parks_media.py | 16 +- .../apps/api/v1/serializers/ride_credits.py | 36 +- .../apps/api/v1/serializers/ride_models.py | 226 +--- .../apps/api/v1/serializers/ride_reviews.py | 30 +- backend/apps/api/v1/serializers/rides.py | 212 +--- .../apps/api/v1/serializers/rides_media.py | 20 +- backend/apps/api/v1/serializers/search.py | 13 +- backend/apps/api/v1/serializers/services.py | 10 +- backend/apps/api/v1/serializers/shared.py | 338 ++---- backend/apps/api/v1/serializers/stats.py | 128 +-- backend/apps/api/v1/serializers_rankings.py | 34 +- backend/apps/api/v1/tests/test_contracts.py | 275 +++-- backend/apps/api/v1/views/auth.py | 51 +- backend/apps/api/v1/views/base.py | 167 ++- backend/apps/api/v1/views/discovery.py | 19 +- backend/apps/api/v1/views/health.py | 34 +- backend/apps/api/v1/views/leaderboard.py | 135 ++- backend/apps/api/v1/views/stats.py | 32 +- backend/apps/api/v1/views/trending.py | 16 +- backend/apps/api/v1/viewsets.py | 4 +- backend/apps/api/v1/viewsets_rankings.py | 44 +- backend/apps/blog/models.py | 9 +- backend/apps/blog/serializers.py | 9 +- backend/apps/blog/views.py | 4 +- .../versions/2025_06_17_initial_schema.py | 4 +- backend/apps/core/__init__.py | 2 +- backend/apps/core/admin.py | 4 +- backend/apps/core/admin/mixins.py | 16 +- backend/apps/core/analytics.py | 16 +- backend/apps/core/api/exceptions.py | 13 +- backend/apps/core/api/mixins.py | 16 +- backend/apps/core/checks.py | 185 ++-- backend/apps/core/choices/__init__.py | 20 +- backend/apps/core/choices/base.py | 41 +- backend/apps/core/choices/core_choices.py | 102 +- backend/apps/core/choices/fields.py | 59 +- backend/apps/core/choices/registry.py | 19 +- backend/apps/core/choices/serializers.py | 146 ++- backend/apps/core/choices/utils.py | 112 +- .../apps/core/decorators/cache_decorators.py | 84 +- backend/apps/core/forms.py | 4 +- backend/apps/core/forms/htmx_forms.py | 6 +- backend/apps/core/forms/search.py | 17 +- .../apps/core/health_checks/custom_checks.py | 33 +- backend/apps/core/history.py | 4 +- backend/apps/core/logging.py | 32 +- .../commands/calculate_new_content.py | 36 +- .../management/commands/calculate_trending.py | 105 +- .../core/management/commands/clear_cache.py | 160 +-- .../commands/list_transition_callbacks.py | 132 ++- .../management/commands/optimize_static.py | 51 +- .../apps/core/management/commands/rundev.py | 18 +- .../management/commands/security_audit.py | 169 ++- .../core/management/commands/setup_dev.py | 58 +- .../commands/test_transition_callbacks.py | 146 ++- .../core/management/commands/test_trending.py | 60 +- .../core/management/commands/warm_cache.py | 64 +- backend/apps/core/managers.py | 20 +- backend/apps/core/middleware/analytics.py | 10 +- .../core/middleware/htmx_error_middleware.py | 1 + backend/apps/core/middleware/nextjs.py | 8 +- .../core/middleware/performance_middleware.py | 42 +- backend/apps/core/middleware/rate_limiting.py | 103 +- .../apps/core/middleware/request_logging.py | 119 +-- .../apps/core/middleware/security_headers.py | 68 +- backend/apps/core/middleware/view_tracking.py | 50 +- ...0004_alter_slughistory_options_and_more.py | 12 +- backend/apps/core/mixins/__init__.py | 5 +- backend/apps/core/models.py | 6 +- backend/apps/core/permissions.py | 3 +- backend/apps/core/selectors.py | 28 +- .../apps/core/services/clustering_service.py | 31 +- backend/apps/core/services/data_structures.py | 10 +- .../core/services/enhanced_cache_service.py | 26 +- .../core/services/entity_fuzzy_matching.py | 47 +- .../apps/core/services/location_adapters.py | 119 +-- backend/apps/core/services/location_search.py | 110 +- .../apps/core/services/map_cache_service.py | 34 +- backend/apps/core/services/map_service.py | 73 +- backend/apps/core/services/media_service.py | 8 +- .../apps/core/services/media_url_service.py | 21 +- .../core/services/performance_monitoring.py | 47 +- .../apps/core/services/trending_service.py | 148 +-- backend/apps/core/state_machine/__init__.py | 1 + backend/apps/core/state_machine/builder.py | 20 +- .../apps/core/state_machine/callback_base.py | 54 +- .../core/state_machine/callbacks/cache.py | 79 +- .../state_machine/callbacks/notifications.py | 176 ++-- .../callbacks/related_updates.py | 112 +- backend/apps/core/state_machine/config.py | 190 ++-- backend/apps/core/state_machine/decorators.py | 67 +- backend/apps/core/state_machine/exceptions.py | 24 +- backend/apps/core/state_machine/fields.py | 13 +- backend/apps/core/state_machine/guards.py | 14 +- .../apps/core/state_machine/integration.py | 55 +- backend/apps/core/state_machine/mixins.py | 111 +- backend/apps/core/state_machine/monitoring.py | 76 +- backend/apps/core/state_machine/registry.py | 77 +- backend/apps/core/state_machine/signals.py | 79 +- .../apps/core/state_machine/tests/fixtures.py | 126 +-- .../apps/core/state_machine/tests/helpers.py | 100 +- .../core/state_machine/tests/test_builder.py | 1 + .../state_machine/tests/test_callbacks.py | 341 +++--- .../state_machine/tests/test_decorators.py | 23 +- .../core/state_machine/tests/test_guards.py | 267 ++--- .../state_machine/tests/test_integration.py | 17 +- .../core/state_machine/tests/test_registry.py | 61 +- .../state_machine/tests/test_validators.py | 25 +- backend/apps/core/state_machine/validators.py | 35 +- backend/apps/core/tasks/trending.py | 107 +- .../apps/core/templatetags/common_filters.py | 79 +- backend/apps/core/templatetags/fsm_tags.py | 126 ++- backend/apps/core/templatetags/safe_html.py | 107 +- backend/apps/core/tests/test_history.py | 10 +- backend/apps/core/urls/__init__.py | 4 +- backend/apps/core/utils/cloudflare.py | 9 +- backend/apps/core/utils/file_scanner.py | 122 +-- backend/apps/core/utils/html_sanitizer.py | 244 +++-- backend/apps/core/utils/query_optimization.py | 88 +- backend/apps/core/utils/turnstile.py | 41 +- backend/apps/core/views/base.py | 1 - backend/apps/core/views/entity_search.py | 29 +- backend/apps/core/views/map_views.py | 34 +- backend/apps/core/views/maps.py | 30 +- .../apps/core/views/performance_dashboard.py | 41 +- backend/apps/core/views/search.py | 10 +- backend/apps/core/views/views.py | 74 +- backend/apps/lists/admin.py | 4 + backend/apps/lists/views.py | 4 +- .../apps/media/commands/download_photos.py | 8 +- .../apps/media/commands/fix_photo_paths.py | 8 +- backend/apps/media/commands/move_photos.py | 20 +- backend/apps/media/models.py | 10 +- backend/apps/media/serializers.py | 5 +- backend/apps/moderation/admin.py | 24 +- backend/apps/moderation/apps.py | 85 +- backend/apps/moderation/choices.py | 982 ++++++++--------- backend/apps/moderation/context_processors.py | 11 +- backend/apps/moderation/filters.py | 68 +- .../commands/analyze_transitions.py | 226 ++-- .../management/commands/seed_submissions.py | 18 +- .../commands/validate_state_machines.py | 55 +- .../moderation/migrations/0001_initial.py | 8 +- ...perationevent_moderationaction_and_more.py | 36 +- ..._alter_moderationqueue_options_and_more.py | 176 +--- ...08_alter_bulkoperation_options_and_more.py | 128 +-- backend/apps/moderation/mixins.py | 8 +- backend/apps/moderation/models.py | 353 ++----- backend/apps/moderation/permissions.py | 8 +- backend/apps/moderation/selectors.py | 45 +- backend/apps/moderation/serializers.py | 148 +-- backend/apps/moderation/services.py | 218 ++-- backend/apps/moderation/signals.py | 169 +-- backend/apps/moderation/sse.py | 31 +- .../templatetags/moderation_tags.py | 8 +- backend/apps/moderation/tests.py | 447 ++++---- backend/apps/moderation/tests/test_admin.py | 17 +- .../apps/moderation/tests/test_workflows.py | 241 ++--- backend/apps/moderation/urls.py | 4 + backend/apps/moderation/views.py | 157 +-- backend/apps/parks/admin.py | 23 +- backend/apps/parks/apps.py | 29 +- backend/apps/parks/choices.py | 249 ++--- backend/apps/parks/filters.py | 6 +- backend/apps/parks/forms.py | 22 +- .../management/commands/create_sample_data.py | 20 +- .../management/commands/fix_migrations.py | 4 +- .../management/commands/seed_initial_data.py | 16 +- .../management/commands/seed_sample_data.py | 90 +- .../management/commands/test_location.py | 8 +- .../management/commands/update_park_counts.py | 8 +- backend/apps/parks/managers.py | 53 +- backend/apps/parks/migrations/0001_initial.py | 20 +- ...uartersevent_parklocationevent_and_more.py | 8 +- .../0008_parkphoto_parkphotoevent_and_more.py | 12 +- .../0015_populate_hybrid_filtering_fields.py | 22 +- .../0016_add_hybrid_filtering_indexes.py | 31 +- .../migrations/0019_fix_pghistory_timezone.py | 2 +- .../0020_fix_pghistory_update_timezone.py | 2 +- .../0023_add_company_roles_gin_index.py | 2 +- .../migrations/0024_add_timezone_default.py | 8 +- ...any_options_alter_park_options_and_more.py | 132 +-- backend/apps/parks/models/__init__.py | 2 +- backend/apps/parks/models/areas.py | 12 +- backend/apps/parks/models/companies.py | 20 +- backend/apps/parks/models/location.py | 9 +- backend/apps/parks/models/media.py | 28 +- backend/apps/parks/models/parks.py | 44 +- backend/apps/parks/models/reviews.py | 17 +- backend/apps/parks/querysets.py | 4 +- backend/apps/parks/selectors.py | 30 +- backend/apps/parks/services.py | 6 +- backend/apps/parks/services/filter_service.py | 33 +- backend/apps/parks/services/hybrid_loader.py | 312 +++--- .../apps/parks/services/location_service.py | 42 +- backend/apps/parks/services/media_service.py | 12 +- .../apps/parks/services/park_management.py | 10 +- backend/apps/parks/services/roadtrip.py | 77 +- backend/apps/parks/signals.py | 43 +- backend/apps/parks/templatetags/park_tags.py | 62 +- backend/apps/parks/tests.py | 242 ++--- .../apps/parks/tests/test_park_workflows.py | 262 ++--- .../parks/tests/test_query_optimization.py | 20 +- .../apps/parks/tests_disabled/test_filters.py | 8 +- .../apps/parks/tests_disabled/test_models.py | 24 +- backend/apps/parks/views.py | 116 +-- backend/apps/parks/views_roadtrip.py | 62 +- backend/apps/reviews/models.py | 10 +- backend/apps/reviews/signals.py | 11 +- backend/apps/rides/__init__.py | 2 +- backend/apps/rides/admin.py | 14 +- backend/apps/rides/apps.py | 44 +- backend/apps/rides/choices.py | 618 ++++------- backend/apps/rides/events.py | 2 - backend/apps/rides/forms.py | 4 +- backend/apps/rides/forms/base.py | 4 +- backend/apps/rides/forms/search.py | 85 +- .../commands/update_ride_rankings.py | 6 +- backend/apps/rides/managers.py | 39 +- backend/apps/rides/migrations/0001_initial.py | 24 +- ...levent_rollercoasterstatsevent_and_more.py | 20 +- .../migrations/0006_add_ride_rankings.py | 20 +- .../0007_ridephoto_ridephotoevent_and_more.py | 16 +- ...010_add_comprehensive_ride_model_system.py | 72 +- .../0012_make_ride_model_slug_unique.py | 4 +- .../0014_update_ride_model_slugs_data.py | 12 +- ...5_remove_company_insert_insert_and_more.py | 16 +- ...sert_remove_ride_update_update_and_more.py | 8 +- .../0019_populate_hybrid_filtering_fields.py | 26 +- .../0020_add_hybrid_filtering_indexes.py | 83 +- ..._convert_unique_together_to_constraints.py | 20 +- ..._alter_rankingsnapshot_options_and_more.py | 112 +- backend/apps/rides/mixins.py | 11 +- backend/apps/rides/models/company.py | 31 +- backend/apps/rides/models/credits.py | 20 +- backend/apps/rides/models/location.py | 8 +- backend/apps/rides/models/media.py | 14 +- backend/apps/rides/models/rankings.py | 50 +- backend/apps/rides/models/reviews.py | 17 +- backend/apps/rides/models/rides.py | 253 ++--- backend/apps/rides/selectors.py | 26 +- backend/apps/rides/services/__init__.py | 1 - backend/apps/rides/services/hybrid_loader.py | 727 +++++++------ .../apps/rides/services/location_service.py | 33 +- backend/apps/rides/services/media_service.py | 18 +- .../apps/rides/services/ranking_service.py | 106 +- backend/apps/rides/services/search.py | 127 +-- backend/apps/rides/services/status_service.py | 13 +- backend/apps/rides/services_core.py | 35 +- backend/apps/rides/signals.py | 82 +- backend/apps/rides/tasks.py | 8 +- backend/apps/rides/tests.py | 506 ++++----- .../apps/rides/tests/test_ride_workflows.py | 510 ++++----- backend/apps/rides/views.py | 43 +- backend/apps/support/models.py | 48 +- backend/apps/support/serializers.py | 9 +- backend/apps/support/views.py | 5 +- backend/config/django/base.py | 8 +- backend/config/django/local.py | 5 +- backend/config/django/production.py | 25 +- backend/config/settings/cache.py | 36 +- backend/config/settings/database.py | 15 +- backend/config/settings/email.py | 15 +- backend/config/settings/local.py | 0 backend/config/settings/logging.py | 5 +- backend/config/settings/rest_framework.py | 34 +- backend/config/settings/secrets.py | 39 +- backend/config/settings/security.py | 32 +- backend/config/settings/storage.py | 49 +- backend/config/settings/third_party.py | 48 +- backend/config/settings/validation.py | 50 +- backend/ensure_admin.py | 8 +- backend/scripts/benchmark_queries.py | 40 +- backend/stubs/environ.pyi | 1 + backend/templates/base/base.html | 4 +- backend/test_avatar_upload.py | 18 +- .../accessibility/test_wcag_compliance.py | 184 ++-- backend/tests/api/test_auth_api.py | 207 ++-- backend/tests/api/test_error_handling.py | 26 +- backend/tests/api/test_filters.py | 53 +- backend/tests/api/test_pagination.py | 5 +- backend/tests/api/test_parks_api.py | 183 ++-- backend/tests/api/test_response_format.py | 2 +- backend/tests/api/test_rides_api.py | 304 +++--- backend/tests/conftest.py | 8 +- backend/tests/e2e/conftest.py | 476 +++++++-- backend/tests/e2e/test_fsm_error_handling.py | 214 ++-- backend/tests/e2e/test_fsm_permissions.py | 154 +-- backend/tests/e2e/test_moderation_fsm.py | 108 +- backend/tests/e2e/test_park_browsing.py | 28 +- backend/tests/e2e/test_park_ride_fsm.py | 269 ++--- backend/tests/e2e/test_review_submission.py | 104 +- backend/tests/e2e/test_user_registration.py | 48 +- backend/tests/factories.py | 16 +- backend/tests/forms/test_park_forms.py | 1 - backend/tests/forms/test_ride_forms.py | 1 - .../integration/test_fsm_transition_view.py | 216 ++-- .../test_fsm_transition_workflow.py | 3 +- .../test_park_creation_workflow.py | 4 +- .../integration/test_photo_upload_workflow.py | 4 +- backend/tests/managers/test_core_managers.py | 11 +- backend/tests/managers/test_park_managers.py | 1 - .../test_contract_validation_middleware.py | 77 +- .../serializers/test_account_serializers.py | 3 - .../serializers/test_park_serializers.py | 12 +- .../serializers/test_ride_serializers.py | 10 +- .../tests/services/test_park_media_service.py | 8 +- backend/tests/services/test_ride_service.py | 16 +- backend/tests/test_factories.py | 4 +- backend/tests/test_parks_api.py | 12 +- backend/tests/test_utils.py | 8 +- backend/tests/utils/fsm_test_helpers.py | 161 +-- backend/tests/ux/test_breadcrumbs.py | 4 +- backend/tests/ux/test_messages.py | 7 +- backend/thrillwiki/views.py | 64 +- backend/verify_backend.py | 31 +- backend/verify_no_tuple_fallbacks.py | 40 +- 413 files changed, 11164 insertions(+), 17433 deletions(-) create mode 100644 backend/.flake8 create mode 100644 backend/apps/api/v1/responses.py create mode 100644 backend/config/settings/local.py diff --git a/backend/.flake8 b/backend/.flake8 new file mode 100644 index 00000000..9288cef7 --- /dev/null +++ b/backend/.flake8 @@ -0,0 +1,37 @@ +[flake8] +# Match Black and Ruff line length +max-line-length = 120 + +# Ignore rules that conflict with Black formatting or are handled by other tools +ignore = + # E203: whitespace before ':' - Black intentionally does this + E203, + # E501: line too long - handled by Black/Ruff + E501, + # W503: line break before binary operator - conflicts with Black + W503, + # E226: missing whitespace around arithmetic operator - Black style + E226, + # W391: blank line at end of file - not critical + W391, + # C901: function is too complex - these are intentional for complex business logic + C901, + # F401: imported but unused - star imports for choice registration are intentional + F401 + +# Exclude common directories +exclude = + .git, + __pycache__, + migrations, + .venv, + venv, + build, + dist, + *.egg-info, + node_modules, + htmlcov, + .pytest_cache + +# Complexity threshold - set high since we have intentional complex functions +max-complexity = 50 diff --git a/backend/apps/accounts/__init__.py b/backend/apps/accounts/__init__.py index e2210ac0..0eab0094 100644 --- a/backend/apps/accounts/__init__.py +++ b/backend/apps/accounts/__init__.py @@ -1,2 +1,2 @@ # Import choices to trigger registration -from .choices import * +from .choices import * # noqa: F403 diff --git a/backend/apps/accounts/admin.py b/backend/apps/accounts/admin.py index c40d4f03..bf17412f 100644 --- a/backend/apps/accounts/admin.py +++ b/backend/apps/accounts/admin.py @@ -77,8 +77,6 @@ class UserProfileInline(admin.StackedInline): ) - - @admin.register(User) class CustomUserAdmin(QueryOptimizationMixin, ExportActionMixin, UserAdmin): """ @@ -332,8 +330,9 @@ class CustomUserAdmin(QueryOptimizationMixin, ExportActionMixin, UserAdmin): try: profile = user.profile # Credits would be recalculated from ride history here - profile.save(update_fields=["coaster_credits", "dark_ride_credits", - "flat_ride_credits", "water_ride_credits"]) + profile.save( + update_fields=["coaster_credits", "dark_ride_credits", "flat_ride_credits", "water_ride_credits"] + ) count += 1 except UserProfile.DoesNotExist: pass @@ -442,12 +441,14 @@ class UserProfileAdmin(QueryOptimizationMixin, ExportActionMixin, BaseModelAdmin @admin.display(description="Completeness") def profile_completeness(self, obj): """Display profile completeness indicator.""" - fields_filled = sum([ - bool(obj.display_name), - bool(obj.avatar), - bool(obj.bio), - bool(obj.twitter or obj.instagram or obj.youtube or obj.discord), - ]) + fields_filled = sum( + [ + bool(obj.display_name), + bool(obj.avatar), + bool(obj.bio), + bool(obj.twitter or obj.instagram or obj.youtube or obj.discord), + ] + ) percentage = (fields_filled / 4) * 100 color = "green" if percentage >= 75 else "orange" if percentage >= 50 else "red" return format_html( @@ -529,12 +530,8 @@ class EmailVerificationAdmin(QueryOptimizationMixin, BaseModelAdmin): def expiration_status(self, obj): """Display expiration status with color coding.""" if timezone.now() - obj.last_sent > timedelta(days=1): - return format_html( - 'Expired' - ) - return format_html( - 'Valid' - ) + return format_html('Expired') + return format_html('Valid') @admin.display(description="Can Resend", boolean=True) def can_resend(self, obj): @@ -665,6 +662,3 @@ class PasswordResetAdmin(ReadOnlyAdminMixin, BaseModelAdmin): "Cleanup old tokens", ) return actions - - - diff --git a/backend/apps/accounts/choices.py b/backend/apps/accounts/choices.py index 83fb5c42..4e377dd3 100644 --- a/backend/apps/accounts/choices.py +++ b/backend/apps/accounts/choices.py @@ -26,7 +26,7 @@ user_roles = ChoiceGroup( "css_class": "text-blue-600 bg-blue-50", "permissions": ["create_content", "create_reviews", "create_lists"], "sort_order": 1, - } + }, ), RichChoice( value="MODERATOR", @@ -38,7 +38,7 @@ user_roles = ChoiceGroup( "css_class": "text-green-600 bg-green-50", "permissions": ["moderate_content", "review_submissions", "manage_reports"], "sort_order": 2, - } + }, ), RichChoice( value="ADMIN", @@ -50,7 +50,7 @@ user_roles = ChoiceGroup( "css_class": "text-purple-600 bg-purple-50", "permissions": ["manage_users", "site_configuration", "advanced_moderation"], "sort_order": 3, - } + }, ), RichChoice( value="SUPERUSER", @@ -62,9 +62,9 @@ user_roles = ChoiceGroup( "css_class": "text-red-600 bg-red-50", "permissions": ["full_access", "system_administration", "database_access"], "sort_order": 4, - } + }, ), - ] + ], ) @@ -83,13 +83,9 @@ theme_preferences = ChoiceGroup( "color": "yellow", "icon": "sun", "css_class": "text-yellow-600 bg-yellow-50", - "preview_colors": { - "background": "#ffffff", - "text": "#1f2937", - "accent": "#3b82f6" - }, + "preview_colors": {"background": "#ffffff", "text": "#1f2937", "accent": "#3b82f6"}, "sort_order": 1, - } + }, ), RichChoice( value="dark", @@ -99,15 +95,11 @@ theme_preferences = ChoiceGroup( "color": "gray", "icon": "moon", "css_class": "text-gray-600 bg-gray-50", - "preview_colors": { - "background": "#1f2937", - "text": "#f9fafb", - "accent": "#60a5fa" - }, + "preview_colors": {"background": "#1f2937", "text": "#f9fafb", "accent": "#60a5fa"}, "sort_order": 2, - } + }, ), - ] + ], ) @@ -133,7 +125,7 @@ unit_systems = ChoiceGroup( "large_distance": "km", }, "sort_order": 1, - } + }, ), RichChoice( value="imperial", @@ -150,9 +142,9 @@ unit_systems = ChoiceGroup( "large_distance": "mi", }, "sort_order": 2, - } + }, ), - ] + ], ) @@ -177,10 +169,10 @@ privacy_levels = ChoiceGroup( "Profile visible to all users", "Activity appears in public feeds", "Searchable by search engines", - "Can be found by username search" + "Can be found by username search", ], "sort_order": 1, - } + }, ), RichChoice( value="friends", @@ -196,10 +188,10 @@ privacy_levels = ChoiceGroup( "Profile visible only to friends", "Activity hidden from public feeds", "Not searchable by search engines", - "Requires friend request approval" + "Requires friend request approval", ], "sort_order": 2, - } + }, ), RichChoice( value="private", @@ -215,12 +207,12 @@ privacy_levels = ChoiceGroup( "Profile completely hidden", "No activity in any feeds", "Not discoverable by other users", - "Maximum privacy protection" + "Maximum privacy protection", ], "sort_order": 3, - } + }, ), - ] + ], ) @@ -242,7 +234,7 @@ top_list_categories = ChoiceGroup( "ride_category": "roller_coaster", "typical_list_size": 10, "sort_order": 1, - } + }, ), RichChoice( value="DR", @@ -255,7 +247,7 @@ top_list_categories = ChoiceGroup( "ride_category": "dark_ride", "typical_list_size": 10, "sort_order": 2, - } + }, ), RichChoice( value="FR", @@ -268,7 +260,7 @@ top_list_categories = ChoiceGroup( "ride_category": "flat_ride", "typical_list_size": 10, "sort_order": 3, - } + }, ), RichChoice( value="WR", @@ -281,7 +273,7 @@ top_list_categories = ChoiceGroup( "ride_category": "water_ride", "typical_list_size": 10, "sort_order": 4, - } + }, ), RichChoice( value="PK", @@ -294,9 +286,9 @@ top_list_categories = ChoiceGroup( "entity_type": "park", "typical_list_size": 10, "sort_order": 5, - } + }, ), - ] + ], ) @@ -320,7 +312,7 @@ notification_types = ChoiceGroup( "default_channels": ["email", "push", "inapp"], "priority": "normal", "sort_order": 1, - } + }, ), RichChoice( value="submission_rejected", @@ -334,7 +326,7 @@ notification_types = ChoiceGroup( "default_channels": ["email", "push", "inapp"], "priority": "normal", "sort_order": 2, - } + }, ), RichChoice( value="submission_pending", @@ -348,7 +340,7 @@ notification_types = ChoiceGroup( "default_channels": ["inapp"], "priority": "low", "sort_order": 3, - } + }, ), # Review related RichChoice( @@ -363,7 +355,7 @@ notification_types = ChoiceGroup( "default_channels": ["email", "push", "inapp"], "priority": "normal", "sort_order": 4, - } + }, ), RichChoice( value="review_helpful", @@ -377,7 +369,7 @@ notification_types = ChoiceGroup( "default_channels": ["push", "inapp"], "priority": "low", "sort_order": 5, - } + }, ), # Social related RichChoice( @@ -392,7 +384,7 @@ notification_types = ChoiceGroup( "default_channels": ["email", "push", "inapp"], "priority": "normal", "sort_order": 6, - } + }, ), RichChoice( value="friend_accepted", @@ -406,7 +398,7 @@ notification_types = ChoiceGroup( "default_channels": ["push", "inapp"], "priority": "low", "sort_order": 7, - } + }, ), RichChoice( value="message_received", @@ -420,7 +412,7 @@ notification_types = ChoiceGroup( "default_channels": ["email", "push", "inapp"], "priority": "normal", "sort_order": 8, - } + }, ), RichChoice( value="profile_comment", @@ -434,7 +426,7 @@ notification_types = ChoiceGroup( "default_channels": ["email", "push", "inapp"], "priority": "normal", "sort_order": 9, - } + }, ), # System related RichChoice( @@ -449,7 +441,7 @@ notification_types = ChoiceGroup( "default_channels": ["email", "inapp"], "priority": "normal", "sort_order": 10, - } + }, ), RichChoice( value="account_security", @@ -463,7 +455,7 @@ notification_types = ChoiceGroup( "default_channels": ["email", "push", "inapp"], "priority": "high", "sort_order": 11, - } + }, ), RichChoice( value="feature_update", @@ -477,7 +469,7 @@ notification_types = ChoiceGroup( "default_channels": ["email", "inapp"], "priority": "low", "sort_order": 12, - } + }, ), RichChoice( value="maintenance", @@ -491,7 +483,7 @@ notification_types = ChoiceGroup( "default_channels": ["email", "inapp"], "priority": "normal", "sort_order": 13, - } + }, ), # Achievement related RichChoice( @@ -506,7 +498,7 @@ notification_types = ChoiceGroup( "default_channels": ["push", "inapp"], "priority": "low", "sort_order": 14, - } + }, ), RichChoice( value="milestone_reached", @@ -520,9 +512,9 @@ notification_types = ChoiceGroup( "default_channels": ["push", "inapp"], "priority": "low", "sort_order": 15, - } + }, ), - ] + ], ) @@ -545,7 +537,7 @@ notification_priorities = ChoiceGroup( "batch_eligible": True, "delay_minutes": 60, "sort_order": 1, - } + }, ), RichChoice( value="normal", @@ -559,7 +551,7 @@ notification_priorities = ChoiceGroup( "batch_eligible": True, "delay_minutes": 15, "sort_order": 2, - } + }, ), RichChoice( value="high", @@ -573,7 +565,7 @@ notification_priorities = ChoiceGroup( "batch_eligible": False, "delay_minutes": 0, "sort_order": 3, - } + }, ), RichChoice( value="urgent", @@ -588,9 +580,9 @@ notification_priorities = ChoiceGroup( "delay_minutes": 0, "bypass_preferences": True, "sort_order": 4, - } + }, ), - ] + ], ) diff --git a/backend/apps/accounts/export_service.py b/backend/apps/accounts/export_service.py index 66b363f5..21bdbe15 100644 --- a/backend/apps/accounts/export_service.py +++ b/backend/apps/accounts/export_service.py @@ -53,28 +53,34 @@ class UserExportService: "dark_ride": profile.dark_ride_credits, "flat_ride": profile.flat_ride_credits, "water_ride": profile.water_ride_credits, - } + }, } # Reviews - park_reviews = list(ParkReview.objects.filter(user=user).values( - "park__name", "rating", "review", "created_at", "updated_at", "is_published" - )) + park_reviews = list( + ParkReview.objects.filter(user=user).values( + "park__name", "rating", "review", "created_at", "updated_at", "is_published" + ) + ) - ride_reviews = list(RideReview.objects.filter(user=user).values( - "ride__name", "rating", "review", "created_at", "updated_at", "is_published" - )) + ride_reviews = list( + RideReview.objects.filter(user=user).values( + "ride__name", "rating", "review", "created_at", "updated_at", "is_published" + ) + ) # Lists user_lists = [] for user_list in UserList.objects.filter(user=user): items = list(user_list.items.values("order", "content_type__model", "object_id", "comment")) - user_lists.append({ - "title": user_list.title, - "description": user_list.description, - "created_at": user_list.created_at, - "items": items - }) + user_lists.append( + { + "title": user_list.title, + "description": user_list.description, + "created_at": user_list.created_at, + "items": items, + } + ) export_data = { "account": user_data, @@ -85,10 +91,7 @@ class UserExportService: "ride_reviews": ride_reviews, "lists": user_lists, }, - "export_info": { - "generated_at": timezone.now(), - "version": "1.0" - } + "export_info": {"generated_at": timezone.now(), "version": "1.0"}, } return export_data diff --git a/backend/apps/accounts/login_history.py b/backend/apps/accounts/login_history.py index 2d914c19..0fe1c0c7 100644 --- a/backend/apps/accounts/login_history.py +++ b/backend/apps/accounts/login_history.py @@ -99,8 +99,6 @@ class LoginHistory(models.Model): # Default cleanup for entries older than the specified days cutoff = timezone.now() - timedelta(days=days) - deleted_count, _ = cls.objects.filter( - login_timestamp__lt=cutoff - ).delete() + deleted_count, _ = cls.objects.filter(login_timestamp__lt=cutoff).delete() return deleted_count diff --git a/backend/apps/accounts/management/commands/check_all_social_tables.py b/backend/apps/accounts/management/commands/check_all_social_tables.py index 79ec30d4..689b75c4 100644 --- a/backend/apps/accounts/management/commands/check_all_social_tables.py +++ b/backend/apps/accounts/management/commands/check_all_social_tables.py @@ -22,20 +22,14 @@ class Command(BaseCommand): # Check SocialAccount self.stdout.write("\nChecking SocialAccount table:") for account in SocialAccount.objects.all(): - self.stdout.write( - f"ID: {account.pk}, Provider: {account.provider}, UID: {account.uid}" - ) + self.stdout.write(f"ID: {account.pk}, Provider: {account.provider}, UID: {account.uid}") # Check SocialToken self.stdout.write("\nChecking SocialToken table:") for token in SocialToken.objects.all(): - self.stdout.write( - f"ID: {token.pk}, Account: {token.account}, App: {token.app}" - ) + self.stdout.write(f"ID: {token.pk}, Account: {token.account}, App: {token.app}") # Check Site self.stdout.write("\nChecking Site table:") for site in Site.objects.all(): - self.stdout.write( - f"ID: {site.pk}, Domain: {site.domain}, Name: {site.name}" - ) + self.stdout.write(f"ID: {site.pk}, Domain: {site.domain}, Name: {site.name}") diff --git a/backend/apps/accounts/management/commands/check_social_apps.py b/backend/apps/accounts/management/commands/check_social_apps.py index 4a3980b8..d3db002a 100644 --- a/backend/apps/accounts/management/commands/check_social_apps.py +++ b/backend/apps/accounts/management/commands/check_social_apps.py @@ -17,6 +17,4 @@ class Command(BaseCommand): self.stdout.write(f"Name: {app.name}") self.stdout.write(f"Client ID: {app.client_id}") self.stdout.write(f"Secret: {app.secret}") - self.stdout.write( - f"Sites: {', '.join(str(site.domain) for site in app.sites.all())}" - ) + self.stdout.write(f"Sites: {', '.join(str(site.domain) for site in app.sites.all())}") diff --git a/backend/apps/accounts/management/commands/cleanup_social_auth.py b/backend/apps/accounts/management/commands/cleanup_social_auth.py index 56e7d8fb..16238bc4 100644 --- a/backend/apps/accounts/management/commands/cleanup_social_auth.py +++ b/backend/apps/accounts/management/commands/cleanup_social_auth.py @@ -15,14 +15,9 @@ class Command(BaseCommand): # Remove migration records cursor.execute("DELETE FROM django_migrations WHERE app='socialaccount'") - cursor.execute( - "DELETE FROM django_migrations WHERE app='accounts' " - "AND name LIKE '%social%'" - ) + cursor.execute("DELETE FROM django_migrations WHERE app='accounts' " "AND name LIKE '%social%'") # Reset sequences cursor.execute("DELETE FROM sqlite_sequence WHERE name LIKE '%social%'") - self.stdout.write( - self.style.SUCCESS("Successfully cleaned up social auth configuration") - ) + self.stdout.write(self.style.SUCCESS("Successfully cleaned up social auth configuration")) diff --git a/backend/apps/accounts/management/commands/cleanup_test_data.py b/backend/apps/accounts/management/commands/cleanup_test_data.py index b819752e..995eeec5 100644 --- a/backend/apps/accounts/management/commands/cleanup_test_data.py +++ b/backend/apps/accounts/management/commands/cleanup_test_data.py @@ -18,24 +18,18 @@ class Command(BaseCommand): self.stdout.write(self.style.SUCCESS(f"Deleted {count} test users")) # Delete test reviews - reviews = ParkReview.objects.filter( - user__username__in=["testuser", "moderator"] - ) + reviews = ParkReview.objects.filter(user__username__in=["testuser", "moderator"]) count = reviews.count() reviews.delete() self.stdout.write(self.style.SUCCESS(f"Deleted {count} test reviews")) # Delete test photos - both park and ride photos - park_photos = ParkPhoto.objects.filter( - uploader__username__in=["testuser", "moderator"] - ) + park_photos = ParkPhoto.objects.filter(uploader__username__in=["testuser", "moderator"]) park_count = park_photos.count() park_photos.delete() self.stdout.write(self.style.SUCCESS(f"Deleted {park_count} test park photos")) - ride_photos = RidePhoto.objects.filter( - uploader__username__in=["testuser", "moderator"] - ) + ride_photos = RidePhoto.objects.filter(uploader__username__in=["testuser", "moderator"]) ride_count = ride_photos.count() ride_photos.delete() self.stdout.write(self.style.SUCCESS(f"Deleted {ride_count} test ride photos")) diff --git a/backend/apps/accounts/management/commands/create_social_apps.py b/backend/apps/accounts/management/commands/create_social_apps.py index 4e678581..ed307f92 100644 --- a/backend/apps/accounts/management/commands/create_social_apps.py +++ b/backend/apps/accounts/management/commands/create_social_apps.py @@ -37,18 +37,12 @@ class Command(BaseCommand): provider="google", defaults={ "name": "Google", - "client_id": ( - "135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2." - "apps.googleusercontent.com" - ), + "client_id": ("135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2." "apps.googleusercontent.com"), "secret": "GOCSPX-Wd_0Ue0Ue0Ue0Ue0Ue0Ue0Ue0Ue", }, ) if not created: - google_app.client_id = ( - "135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2." - "apps.googleusercontent.com" - ) + google_app.client_id = "135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2." "apps.googleusercontent.com" google_app.secret = "GOCSPX-Wd_0Ue0Ue0Ue0Ue0Ue0Ue0Ue0Ue" google_app.save() google_app.sites.add(site) diff --git a/backend/apps/accounts/management/commands/create_test_users.py b/backend/apps/accounts/management/commands/create_test_users.py index 70c445ad..21ee200a 100644 --- a/backend/apps/accounts/management/commands/create_test_users.py +++ b/backend/apps/accounts/management/commands/create_test_users.py @@ -14,9 +14,7 @@ class Command(BaseCommand): ) user.set_password("testpass123") user.save() - self.stdout.write( - self.style.SUCCESS(f"Created test user: {user.get_username()}") - ) + self.stdout.write(self.style.SUCCESS(f"Created test user: {user.get_username()}")) else: self.stdout.write(self.style.WARNING("Test user already exists")) @@ -47,11 +45,7 @@ class Command(BaseCommand): # Add user to moderator group moderator.groups.add(moderator_group) - self.stdout.write( - self.style.SUCCESS( - f"Created moderator user: {moderator.get_username()}" - ) - ) + self.stdout.write(self.style.SUCCESS(f"Created moderator user: {moderator.get_username()}")) else: self.stdout.write(self.style.WARNING("Moderator user already exists")) diff --git a/backend/apps/accounts/management/commands/delete_user.py b/backend/apps/accounts/management/commands/delete_user.py index 46aeed4e..d266c870 100644 --- a/backend/apps/accounts/management/commands/delete_user.py +++ b/backend/apps/accounts/management/commands/delete_user.py @@ -17,9 +17,7 @@ class Command(BaseCommand): help = "Delete a user while preserving all their submissions" def add_arguments(self, parser): - parser.add_argument( - "username", nargs="?", type=str, help="Username of the user to delete" - ) + parser.add_argument("username", nargs="?", type=str, help="Username of the user to delete") parser.add_argument( "--user-id", type=str, @@ -30,9 +28,7 @@ class Command(BaseCommand): action="store_true", help="Show what would be deleted without actually deleting", ) - parser.add_argument( - "--force", action="store_true", help="Skip confirmation prompt" - ) + parser.add_argument("--force", action="store_true", help="Skip confirmation prompt") def handle(self, *args, **options): username = options.get("username") @@ -52,7 +48,7 @@ class Command(BaseCommand): user = User.objects.get(username=username) if username else User.objects.get(user_id=user_id) except User.DoesNotExist: identifier = username or user_id - raise CommandError(f'User "{identifier}" does not exist') + raise CommandError(f'User "{identifier}" does not exist') from None # Check if user can be deleted can_delete, reason = UserDeletionService.can_delete_user(user) @@ -61,27 +57,13 @@ class Command(BaseCommand): # Count submissions submission_counts = { - "park_reviews": getattr( - user, "park_reviews", user.__class__.objects.none() - ).count(), - "ride_reviews": getattr( - user, "ride_reviews", user.__class__.objects.none() - ).count(), - "uploaded_park_photos": getattr( - user, "uploaded_park_photos", user.__class__.objects.none() - ).count(), - "uploaded_ride_photos": getattr( - user, "uploaded_ride_photos", user.__class__.objects.none() - ).count(), - "top_lists": getattr( - user, "top_lists", user.__class__.objects.none() - ).count(), - "edit_submissions": getattr( - user, "edit_submissions", user.__class__.objects.none() - ).count(), - "photo_submissions": getattr( - user, "photo_submissions", user.__class__.objects.none() - ).count(), + "park_reviews": getattr(user, "park_reviews", user.__class__.objects.none()).count(), + "ride_reviews": getattr(user, "ride_reviews", user.__class__.objects.none()).count(), + "uploaded_park_photos": getattr(user, "uploaded_park_photos", user.__class__.objects.none()).count(), + "uploaded_ride_photos": getattr(user, "uploaded_ride_photos", user.__class__.objects.none()).count(), + "top_lists": getattr(user, "top_lists", user.__class__.objects.none()).count(), + "edit_submissions": getattr(user, "edit_submissions", user.__class__.objects.none()).count(), + "photo_submissions": getattr(user, "photo_submissions", user.__class__.objects.none()).count(), } total_submissions = sum(submission_counts.values()) @@ -98,9 +80,7 @@ class Command(BaseCommand): self.stdout.write(self.style.WARNING("\nSubmissions to preserve:")) for submission_type, count in submission_counts.items(): if count > 0: - self.stdout.write( - f' {submission_type.replace("_", " ").title()}: {count}' - ) + self.stdout.write(f' {submission_type.replace("_", " ").title()}: {count}') self.stdout.write(f"\nTotal submissions: {total_submissions}") @@ -111,9 +91,7 @@ class Command(BaseCommand): ) ) else: - self.stdout.write( - self.style.WARNING("\nNo submissions found for this user.") - ) + self.stdout.write(self.style.WARNING("\nNo submissions found for this user.")) if dry_run: self.stdout.write(self.style.SUCCESS("\n[DRY RUN] No changes were made.")) @@ -136,11 +114,7 @@ class Command(BaseCommand): try: result = UserDeletionService.delete_user_preserve_submissions(user) - self.stdout.write( - self.style.SUCCESS( - f'\nSuccessfully deleted user "{result["deleted_user"]["username"]}"' - ) - ) + self.stdout.write(self.style.SUCCESS(f'\nSuccessfully deleted user "{result["deleted_user"]["username"]}"')) preserved_count = sum(result["preserved_submissions"].values()) if preserved_count > 0: @@ -154,9 +128,7 @@ class Command(BaseCommand): self.stdout.write(self.style.WARNING("\nPreservation Summary:")) for submission_type, count in result["preserved_submissions"].items(): if count > 0: - self.stdout.write( - f' {submission_type.replace("_", " ").title()}: {count}' - ) + self.stdout.write(f' {submission_type.replace("_", " ").title()}: {count}') except Exception as e: - raise CommandError(f"Error deleting user: {str(e)}") + raise CommandError(f"Error deleting user: {str(e)}") from None diff --git a/backend/apps/accounts/management/commands/fix_migration_history.py b/backend/apps/accounts/management/commands/fix_migration_history.py index 3a8eafe1..390260b1 100644 --- a/backend/apps/accounts/management/commands/fix_migration_history.py +++ b/backend/apps/accounts/management/commands/fix_migration_history.py @@ -7,12 +7,5 @@ class Command(BaseCommand): def handle(self, *args, **kwargs): with connection.cursor() as cursor: - cursor.execute( - "DELETE FROM django_migrations WHERE app='rides' " - "AND name='0001_initial';" - ) - self.stdout.write( - self.style.SUCCESS( - "Successfully removed rides.0001_initial from migration history" - ) - ) + cursor.execute("DELETE FROM django_migrations WHERE app='rides' " "AND name='0001_initial';") + self.stdout.write(self.style.SUCCESS("Successfully removed rides.0001_initial from migration history")) diff --git a/backend/apps/accounts/management/commands/fix_social_apps.py b/backend/apps/accounts/management/commands/fix_social_apps.py index ab3d5397..5dee4552 100644 --- a/backend/apps/accounts/management/commands/fix_social_apps.py +++ b/backend/apps/accounts/management/commands/fix_social_apps.py @@ -34,6 +34,4 @@ class Command(BaseCommand): secret=os.getenv("DISCORD_CLIENT_SECRET"), ) discord_app.sites.add(site) - self.stdout.write( - f"Created Discord app with client_id: {discord_app.client_id}" - ) + self.stdout.write(f"Created Discord app with client_id: {discord_app.client_id}") diff --git a/backend/apps/accounts/management/commands/generate_letter_avatars.py b/backend/apps/accounts/management/commands/generate_letter_avatars.py index 9a4c4fee..cb645f7a 100644 --- a/backend/apps/accounts/management/commands/generate_letter_avatars.py +++ b/backend/apps/accounts/management/commands/generate_letter_avatars.py @@ -47,9 +47,7 @@ class Command(BaseCommand): help = "Generate avatars for letters A-Z and numbers 0-9" def handle(self, *args, **kwargs): - characters = [chr(i) for i in range(65, 91)] + [ - str(i) for i in range(10) - ] # A-Z and 0-9 + characters = [chr(i) for i in range(65, 91)] + [str(i) for i in range(10)] # A-Z and 0-9 for char in characters: generate_avatar(char) self.stdout.write(self.style.SUCCESS(f"Generated avatar for {char}")) diff --git a/backend/apps/accounts/management/commands/regenerate_avatars.py b/backend/apps/accounts/management/commands/regenerate_avatars.py index 95c6cabd..c23e7b92 100644 --- a/backend/apps/accounts/management/commands/regenerate_avatars.py +++ b/backend/apps/accounts/management/commands/regenerate_avatars.py @@ -11,6 +11,4 @@ class Command(BaseCommand): for profile in profiles: # This will trigger the avatar generation logic in the save method profile.save() - self.stdout.write( - self.style.SUCCESS(f"Regenerated avatar for {profile.user.username}") - ) + self.stdout.write(self.style.SUCCESS(f"Regenerated avatar for {profile.user.username}")) diff --git a/backend/apps/accounts/management/commands/reset_db.py b/backend/apps/accounts/management/commands/reset_db.py index cd3657f0..a5f1f222 100644 --- a/backend/apps/accounts/management/commands/reset_db.py +++ b/backend/apps/accounts/management/commands/reset_db.py @@ -69,18 +69,18 @@ class Command(BaseCommand): # Security: Using Django ORM instead of raw SQL for user creation user = User.objects.create_superuser( - username='admin', - email='admin@thrillwiki.com', - password='admin', - role='SUPERUSER', + username="admin", + email="admin@thrillwiki.com", + password="admin", + role="SUPERUSER", ) # Create profile using ORM UserProfile.objects.create( user=user, - display_name='Admin', - pronouns='they/them', - bio='ThrillWiki Administrator', + display_name="Admin", + pronouns="they/them", + bio="ThrillWiki Administrator", ) self.stdout.write("Superuser created.") diff --git a/backend/apps/accounts/management/commands/reset_social_apps.py b/backend/apps/accounts/management/commands/reset_social_apps.py index 40ba7b7e..4459b9d1 100644 --- a/backend/apps/accounts/management/commands/reset_social_apps.py +++ b/backend/apps/accounts/management/commands/reset_social_apps.py @@ -30,9 +30,7 @@ class Command(BaseCommand): google_app = SocialApp.objects.create( provider="google", name="Google", - client_id=( - "135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2.apps.googleusercontent.com" - ), + client_id=("135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2.apps.googleusercontent.com"), secret="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", ) google_app.sites.add(site) diff --git a/backend/apps/accounts/management/commands/reset_social_auth.py b/backend/apps/accounts/management/commands/reset_social_auth.py index 5dbc7707..4a5f77aa 100644 --- a/backend/apps/accounts/management/commands/reset_social_auth.py +++ b/backend/apps/accounts/management/commands/reset_social_auth.py @@ -12,13 +12,7 @@ class Command(BaseCommand): cursor.execute("DELETE FROM socialaccount_socialapp_sites") # Reset sequences - cursor.execute( - "DELETE FROM sqlite_sequence WHERE name='socialaccount_socialapp'" - ) - cursor.execute( - "DELETE FROM sqlite_sequence WHERE name='socialaccount_socialapp_sites'" - ) + cursor.execute("DELETE FROM sqlite_sequence WHERE name='socialaccount_socialapp'") + cursor.execute("DELETE FROM sqlite_sequence WHERE name='socialaccount_socialapp_sites'") - self.stdout.write( - self.style.SUCCESS("Successfully reset social auth configuration") - ) + self.stdout.write(self.style.SUCCESS("Successfully reset social auth configuration")) diff --git a/backend/apps/accounts/management/commands/setup_groups.py b/backend/apps/accounts/management/commands/setup_groups.py index 12ab3051..3f7bc243 100644 --- a/backend/apps/accounts/management/commands/setup_groups.py +++ b/backend/apps/accounts/management/commands/setup_groups.py @@ -30,9 +30,7 @@ class Command(BaseCommand): user.is_staff = True user.save() - self.stdout.write( - self.style.SUCCESS("Successfully set up groups and permissions") - ) + self.stdout.write(self.style.SUCCESS("Successfully set up groups and permissions")) # Print summary for group in Group.objects.all(): diff --git a/backend/apps/accounts/management/commands/setup_site.py b/backend/apps/accounts/management/commands/setup_site.py index 00f2b491..efc3c399 100644 --- a/backend/apps/accounts/management/commands/setup_site.py +++ b/backend/apps/accounts/management/commands/setup_site.py @@ -10,7 +10,5 @@ class Command(BaseCommand): Site.objects.all().delete() # Create default site - site = Site.objects.create( - id=1, domain="localhost:8000", name="ThrillWiki Development" - ) + site = Site.objects.create(id=1, domain="localhost:8000", name="ThrillWiki Development") self.stdout.write(self.style.SUCCESS(f"Created site: {site.domain}")) diff --git a/backend/apps/accounts/management/commands/setup_social_auth.py b/backend/apps/accounts/management/commands/setup_social_auth.py index 3763718a..b6c5bb5f 100644 --- a/backend/apps/accounts/management/commands/setup_social_auth.py +++ b/backend/apps/accounts/management/commands/setup_social_auth.py @@ -49,27 +49,15 @@ class Command(BaseCommand): discord_client_secret, ] ): - self.stdout.write( - self.style.ERROR("Missing required environment variables") - ) - self.stdout.write( - f"DEBUG: google_client_id is None: {google_client_id is None}" - ) - self.stdout.write( - f"DEBUG: google_client_secret is None: {google_client_secret is None}" - ) - self.stdout.write( - f"DEBUG: discord_client_id is None: {discord_client_id is None}" - ) - self.stdout.write( - f"DEBUG: discord_client_secret is None: {discord_client_secret is None}" - ) + self.stdout.write(self.style.ERROR("Missing required environment variables")) + self.stdout.write(f"DEBUG: google_client_id is None: {google_client_id is None}") + self.stdout.write(f"DEBUG: google_client_secret is None: {google_client_secret is None}") + self.stdout.write(f"DEBUG: discord_client_id is None: {discord_client_id is None}") + self.stdout.write(f"DEBUG: discord_client_secret is None: {discord_client_secret is None}") return # Get or create the default site - site, _ = Site.objects.get_or_create( - id=1, defaults={"domain": "localhost:8000", "name": "localhost"} - ) + site, _ = Site.objects.get_or_create(id=1, defaults={"domain": "localhost:8000", "name": "localhost"}) # Set up Google google_app, created = SocialApp.objects.get_or_create( @@ -92,11 +80,7 @@ class Command(BaseCommand): google_app.save() self.stdout.write("DEBUG: Successfully updated Google app") else: - self.stdout.write( - self.style.ERROR( - "Google client_id or secret is None, skipping update." - ) - ) + self.stdout.write(self.style.ERROR("Google client_id or secret is None, skipping update.")) google_app.sites.add(site) # Set up Discord @@ -120,11 +104,7 @@ class Command(BaseCommand): discord_app.save() self.stdout.write("DEBUG: Successfully updated Discord app") else: - self.stdout.write( - self.style.ERROR( - "Discord client_id or secret is None, skipping update." - ) - ) + self.stdout.write(self.style.ERROR("Discord client_id or secret is None, skipping update.")) discord_app.sites.add(site) self.stdout.write(self.style.SUCCESS("Successfully set up social auth apps")) diff --git a/backend/apps/accounts/management/commands/setup_social_providers.py b/backend/apps/accounts/management/commands/setup_social_providers.py index 08cf2bdf..3be88d57 100644 --- a/backend/apps/accounts/management/commands/setup_social_providers.py +++ b/backend/apps/accounts/management/commands/setup_social_providers.py @@ -42,6 +42,4 @@ class Command(BaseCommand): for app in SocialApp.objects.all(): self.stdout.write(f"- {app.name} ({app.provider}): {app.client_id}") - self.stdout.write( - self.style.SUCCESS(f"\nTotal social apps: {SocialApp.objects.count()}") - ) + self.stdout.write(self.style.SUCCESS(f"\nTotal social apps: {SocialApp.objects.count()}")) diff --git a/backend/apps/accounts/management/commands/test_discord_auth.py b/backend/apps/accounts/management/commands/test_discord_auth.py index 30428530..6ec73785 100644 --- a/backend/apps/accounts/management/commands/test_discord_auth.py +++ b/backend/apps/accounts/management/commands/test_discord_auth.py @@ -40,9 +40,7 @@ class Command(BaseCommand): # Show callback URL callback_url = "http://localhost:8000/accounts/discord/login/callback/" - self.stdout.write( - "\nCallback URL to configure in Discord Developer Portal:" - ) + self.stdout.write("\nCallback URL to configure in Discord Developer Portal:") self.stdout.write(callback_url) # Show frontend login URL diff --git a/backend/apps/accounts/management/commands/update_social_apps_sites.py b/backend/apps/accounts/management/commands/update_social_apps_sites.py index 9b1b9dd1..2f7b151c 100644 --- a/backend/apps/accounts/management/commands/update_social_apps_sites.py +++ b/backend/apps/accounts/management/commands/update_social_apps_sites.py @@ -18,6 +18,4 @@ class Command(BaseCommand): # Add all sites for site in sites: app.sites.add(site) - self.stdout.write( - f"Added sites: {', '.join(site.domain for site in sites)}" - ) + self.stdout.write(f"Added sites: {', '.join(site.domain for site in sites)}") diff --git a/backend/apps/accounts/management/commands/verify_discord_settings.py b/backend/apps/accounts/management/commands/verify_discord_settings.py index 80892249..a453c237 100644 --- a/backend/apps/accounts/management/commands/verify_discord_settings.py +++ b/backend/apps/accounts/management/commands/verify_discord_settings.py @@ -22,17 +22,13 @@ class Command(BaseCommand): # Show callback URL callback_url = "http://localhost:8000/accounts/discord/login/callback/" - self.stdout.write( - "\nCallback URL to configure in Discord Developer Portal:" - ) + self.stdout.write("\nCallback URL to configure in Discord Developer Portal:") self.stdout.write(callback_url) # Show OAuth2 settings self.stdout.write("\nOAuth2 settings in settings.py:") discord_settings = settings.SOCIALACCOUNT_PROVIDERS.get("discord", {}) - self.stdout.write( - f"PKCE Enabled: {discord_settings.get('OAUTH_PKCE_ENABLED', False)}" - ) + self.stdout.write(f"PKCE Enabled: {discord_settings.get('OAUTH_PKCE_ENABLED', False)}") self.stdout.write(f"Scopes: {discord_settings.get('SCOPE', [])}") except SocialApp.DoesNotExist: diff --git a/backend/apps/accounts/migrations/0001_initial.py b/backend/apps/accounts/migrations/0001_initial.py index 544048ac..9e9a946a 100644 --- a/backend/apps/accounts/migrations/0001_initial.py +++ b/backend/apps/accounts/migrations/0001_initial.py @@ -38,9 +38,7 @@ class Migration(migrations.Migration): ), ( "last_login", - models.DateTimeField( - blank=True, null=True, verbose_name="last login" - ), + models.DateTimeField(blank=True, null=True, verbose_name="last login"), ), ( "is_superuser", @@ -53,29 +51,21 @@ class Migration(migrations.Migration): ( "username", models.CharField( - error_messages={ - "unique": "A user with that username already exists." - }, + error_messages={"unique": "A user with that username already exists."}, help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.", max_length=150, unique=True, - validators=[ - django.contrib.auth.validators.UnicodeUsernameValidator() - ], + validators=[django.contrib.auth.validators.UnicodeUsernameValidator()], verbose_name="username", ), ), ( "first_name", - models.CharField( - blank=True, max_length=150, verbose_name="first name" - ), + models.CharField(blank=True, max_length=150, verbose_name="first name"), ), ( "last_name", - models.CharField( - blank=True, max_length=150, verbose_name="last name" - ), + models.CharField(blank=True, max_length=150, verbose_name="last name"), ), ( "email", diff --git a/backend/apps/accounts/migrations/0003_emailverificationevent_passwordresetevent_userevent_and_more.py b/backend/apps/accounts/migrations/0003_emailverificationevent_passwordresetevent_userevent_and_more.py index 96c31203..13d37d2b 100644 --- a/backend/apps/accounts/migrations/0003_emailverificationevent_passwordresetevent_userevent_and_more.py +++ b/backend/apps/accounts/migrations/0003_emailverificationevent_passwordresetevent_userevent_and_more.py @@ -57,9 +57,7 @@ class Migration(migrations.Migration): ("password", models.CharField(max_length=128, verbose_name="password")), ( "last_login", - models.DateTimeField( - blank=True, null=True, verbose_name="last login" - ), + models.DateTimeField(blank=True, null=True, verbose_name="last login"), ), ( "is_superuser", @@ -72,34 +70,24 @@ class Migration(migrations.Migration): ( "username", models.CharField( - error_messages={ - "unique": "A user with that username already exists." - }, + error_messages={"unique": "A user with that username already exists."}, help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.", max_length=150, - validators=[ - django.contrib.auth.validators.UnicodeUsernameValidator() - ], + validators=[django.contrib.auth.validators.UnicodeUsernameValidator()], verbose_name="username", ), ), ( "first_name", - models.CharField( - blank=True, max_length=150, verbose_name="first name" - ), + models.CharField(blank=True, max_length=150, verbose_name="first name"), ), ( "last_name", - models.CharField( - blank=True, max_length=150, verbose_name="last name" - ), + models.CharField(blank=True, max_length=150, verbose_name="last name"), ), ( "email", - models.EmailField( - blank=True, max_length=254, verbose_name="email address" - ), + models.EmailField(blank=True, max_length=254, verbose_name="email address"), ), ( "is_staff", @@ -119,9 +107,7 @@ class Migration(migrations.Migration): ), ( "date_joined", - models.DateTimeField( - default=django.utils.timezone.now, verbose_name="date joined" - ), + models.DateTimeField(default=django.utils.timezone.now, verbose_name="date joined"), ), ( "user_id", diff --git a/backend/apps/accounts/migrations/0004_userdeletionrequest_userdeletionrequestevent_and_more.py b/backend/apps/accounts/migrations/0004_userdeletionrequest_userdeletionrequestevent_and_more.py index 43e83fdf..f74d5d22 100644 --- a/backend/apps/accounts/migrations/0004_userdeletionrequest_userdeletionrequestevent_and_more.py +++ b/backend/apps/accounts/migrations/0004_userdeletionrequest_userdeletionrequestevent_and_more.py @@ -41,9 +41,7 @@ class Migration(migrations.Migration): ("created_at", models.DateTimeField(auto_now_add=True)), ( "expires_at", - models.DateTimeField( - help_text="When this deletion request expires" - ), + models.DateTimeField(help_text="When this deletion request expires"), ), ( "email_sent_at", @@ -55,9 +53,7 @@ class Migration(migrations.Migration): ), ( "attempts", - models.PositiveIntegerField( - default=0, help_text="Number of verification attempts made" - ), + models.PositiveIntegerField(default=0, help_text="Number of verification attempts made"), ), ( "max_attempts", @@ -103,9 +99,7 @@ class Migration(migrations.Migration): ("created_at", models.DateTimeField(auto_now_add=True)), ( "expires_at", - models.DateTimeField( - help_text="When this deletion request expires" - ), + models.DateTimeField(help_text="When this deletion request expires"), ), ( "email_sent_at", @@ -117,9 +111,7 @@ class Migration(migrations.Migration): ), ( "attempts", - models.PositiveIntegerField( - default=0, help_text="Number of verification attempts made" - ), + models.PositiveIntegerField(default=0, help_text="Number of verification attempts made"), ), ( "max_attempts", @@ -171,21 +163,15 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="userdeletionrequest", - index=models.Index( - fields=["verification_code"], name="accounts_us_verific_94460d_idx" - ), + index=models.Index(fields=["verification_code"], name="accounts_us_verific_94460d_idx"), ), migrations.AddIndex( model_name="userdeletionrequest", - index=models.Index( - fields=["expires_at"], name="accounts_us_expires_1d1dca_idx" - ), + index=models.Index(fields=["expires_at"], name="accounts_us_expires_1d1dca_idx"), ), migrations.AddIndex( model_name="userdeletionrequest", - index=models.Index( - fields=["user", "is_used"], name="accounts_us_user_id_1ce18a_idx" - ), + index=models.Index(fields=["user", "is_used"], name="accounts_us_user_id_1ce18a_idx"), ), pgtrigger.migrations.AddTrigger( model_name="userdeletionrequest", diff --git a/backend/apps/accounts/migrations/0005_remove_user_insert_insert_remove_user_update_update_and_more.py b/backend/apps/accounts/migrations/0005_remove_user_insert_insert_remove_user_update_update_and_more.py index 2d7f5d6a..ccc9ed35 100644 --- a/backend/apps/accounts/migrations/0005_remove_user_insert_insert_remove_user_update_update_and_more.py +++ b/backend/apps/accounts/migrations/0005_remove_user_insert_insert_remove_user_update_update_and_more.py @@ -57,9 +57,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name="user", name="last_password_change", - field=models.DateTimeField( - auto_now_add=True, default=django.utils.timezone.now - ), + field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now), preserve_default=False, ), migrations.AddField( @@ -185,9 +183,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name="userevent", name="last_password_change", - field=models.DateTimeField( - auto_now_add=True, default=django.utils.timezone.now - ), + field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now), preserve_default=False, ), migrations.AddField( diff --git a/backend/apps/accounts/migrations/0009_notificationpreference_notificationpreferenceevent_and_more.py b/backend/apps/accounts/migrations/0009_notificationpreference_notificationpreferenceevent_and_more.py index 344b81bd..2a108493 100644 --- a/backend/apps/accounts/migrations/0009_notificationpreference_notificationpreferenceevent_and_more.py +++ b/backend/apps/accounts/migrations/0009_notificationpreference_notificationpreferenceevent_and_more.py @@ -454,9 +454,7 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="usernotification", - index=models.Index( - fields=["user", "is_read"], name="accounts_us_user_id_785929_idx" - ), + index=models.Index(fields=["user", "is_read"], name="accounts_us_user_id_785929_idx"), ), migrations.AddIndex( model_name="usernotification", @@ -467,15 +465,11 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="usernotification", - index=models.Index( - fields=["created_at"], name="accounts_us_created_a62f54_idx" - ), + index=models.Index(fields=["created_at"], name="accounts_us_created_a62f54_idx"), ), migrations.AddIndex( model_name="usernotification", - index=models.Index( - fields=["expires_at"], name="accounts_us_expires_f267b1_idx" - ), + index=models.Index(fields=["expires_at"], name="accounts_us_expires_f267b1_idx"), ), pgtrigger.migrations.AddTrigger( model_name="usernotification", diff --git a/backend/apps/accounts/migrations/0010_auto_20250830_1657.py b/backend/apps/accounts/migrations/0010_auto_20250830_1657.py index 89a56fa3..f346caed 100644 --- a/backend/apps/accounts/migrations/0010_auto_20250830_1657.py +++ b/backend/apps/accounts/migrations/0010_auto_20250830_1657.py @@ -26,25 +26,24 @@ def safe_add_avatar_field(apps, schema_editor): """ # Check if the column already exists with schema_editor.connection.cursor() as cursor: - cursor.execute(""" + cursor.execute( + """ SELECT column_name FROM information_schema.columns WHERE table_name='accounts_userprofile' AND column_name='avatar_id' - """) + """ + ) column_exists = cursor.fetchone() is not None if not column_exists: # Column doesn't exist, add it - UserProfile = apps.get_model('accounts', 'UserProfile') + UserProfile = apps.get_model("accounts", "UserProfile") field = models.ForeignKey( - 'django_cloudflareimages_toolkit.CloudflareImage', - on_delete=models.SET_NULL, - null=True, - blank=True + "django_cloudflareimages_toolkit.CloudflareImage", on_delete=models.SET_NULL, null=True, blank=True ) - field.set_attributes_from_name('avatar') + field.set_attributes_from_name("avatar") schema_editor.add_field(UserProfile, field) @@ -54,24 +53,23 @@ def reverse_safe_add_avatar_field(apps, schema_editor): """ # Check if the column exists and remove it with schema_editor.connection.cursor() as cursor: - cursor.execute(""" + cursor.execute( + """ SELECT column_name FROM information_schema.columns WHERE table_name='accounts_userprofile' AND column_name='avatar_id' - """) + """ + ) column_exists = cursor.fetchone() is not None if column_exists: - UserProfile = apps.get_model('accounts', 'UserProfile') + UserProfile = apps.get_model("accounts", "UserProfile") field = models.ForeignKey( - 'django_cloudflareimages_toolkit.CloudflareImage', - on_delete=models.SET_NULL, - null=True, - blank=True + "django_cloudflareimages_toolkit.CloudflareImage", on_delete=models.SET_NULL, null=True, blank=True ) - field.set_attributes_from_name('avatar') + field.set_attributes_from_name("avatar") schema_editor.remove_field(UserProfile, field) @@ -89,15 +87,13 @@ class Migration(migrations.Migration): # First, remove the old avatar column (CloudflareImageField) migrations.RunSQL( "ALTER TABLE accounts_userprofile DROP COLUMN IF EXISTS avatar;", - reverse_sql="-- Cannot reverse this operation" + reverse_sql="-- Cannot reverse this operation", ), - # Safely add the new avatar_id column for ForeignKey migrations.RunPython( safe_add_avatar_field, reverse_safe_add_avatar_field, ), - # Run the data migration migrations.RunPython( migrate_avatar_data, diff --git a/backend/apps/accounts/migrations/0011_fix_userprofile_event_avatar_field.py b/backend/apps/accounts/migrations/0011_fix_userprofile_event_avatar_field.py index 3e558466..8ced0a45 100644 --- a/backend/apps/accounts/migrations/0011_fix_userprofile_event_avatar_field.py +++ b/backend/apps/accounts/migrations/0011_fix_userprofile_event_avatar_field.py @@ -6,17 +6,16 @@ from django.db import migrations class Migration(migrations.Migration): dependencies = [ - ('accounts', '0010_auto_20250830_1657'), - ('django_cloudflareimages_toolkit', '0001_initial'), + ("accounts", "0010_auto_20250830_1657"), + ("django_cloudflareimages_toolkit", "0001_initial"), ] operations = [ # Remove the old avatar field from the event table migrations.RunSQL( "ALTER TABLE accounts_userprofileevent DROP COLUMN IF EXISTS avatar;", - reverse_sql="-- Cannot reverse this operation" + reverse_sql="-- Cannot reverse this operation", ), - # Add the new avatar_id field to match the main table (only if it doesn't exist) migrations.RunSQL( """ @@ -32,6 +31,6 @@ class Migration(migrations.Migration): END IF; END $$; """, - reverse_sql="ALTER TABLE accounts_userprofileevent DROP COLUMN IF EXISTS avatar_id;" + reverse_sql="ALTER TABLE accounts_userprofileevent DROP COLUMN IF EXISTS avatar_id;", ), ] diff --git a/backend/apps/accounts/migrations/0013_add_user_query_indexes.py b/backend/apps/accounts/migrations/0013_add_user_query_indexes.py index 749cb233..d1a02d81 100644 --- a/backend/apps/accounts/migrations/0013_add_user_query_indexes.py +++ b/backend/apps/accounts/migrations/0013_add_user_query_indexes.py @@ -13,28 +13,28 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('accounts', '0012_alter_toplist_category_and_more'), + ("accounts", "0012_alter_toplist_category_and_more"), ] operations = [ # Add db_index to is_banned field migrations.AlterField( - model_name='user', - name='is_banned', + model_name="user", + name="is_banned", field=models.BooleanField(default=False, db_index=True), ), # Add composite index for common query patterns migrations.AddIndex( - model_name='user', - index=models.Index(fields=['is_banned', 'role'], name='accounts_user_banned_role_idx'), + model_name="user", + index=models.Index(fields=["is_banned", "role"], name="accounts_user_banned_role_idx"), ), # Add CheckConstraint for ban consistency migrations.AddConstraint( - model_name='user', + model_name="user", constraint=models.CheckConstraint( - name='user_ban_consistency', + name="user_ban_consistency", check=models.Q(is_banned=False) | models.Q(ban_date__isnull=False), - violation_error_message='Banned users must have a ban_date set' + violation_error_message="Banned users must have a ban_date set", ), ), ] diff --git a/backend/apps/accounts/migrations/0014_remove_toplist_user_remove_toplistitem_top_list_and_more.py b/backend/apps/accounts/migrations/0014_remove_toplist_user_remove_toplistitem_top_list_and_more.py index 97f288be..22b20d5c 100644 --- a/backend/apps/accounts/migrations/0014_remove_toplist_user_remove_toplistitem_top_list_and_more.py +++ b/backend/apps/accounts/migrations/0014_remove_toplist_user_remove_toplistitem_top_list_and_more.py @@ -18,7 +18,6 @@ class Migration(migrations.Migration): ] operations = [ - migrations.AlterModelOptions( name="user", options={"verbose_name": "User", "verbose_name_plural": "Users"}, @@ -58,9 +57,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name="userprofile", name="location", - field=models.CharField( - blank=True, help_text="User's location (City, Country)", max_length=100 - ), + field=models.CharField(blank=True, help_text="User's location (City, Country)", max_length=100), ), migrations.AddField( model_name="userprofile", @@ -78,9 +75,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name="userprofileevent", name="location", - field=models.CharField( - blank=True, help_text="User's location (City, Country)", max_length=100 - ), + field=models.CharField(blank=True, help_text="User's location (City, Country)", max_length=100), ), migrations.AddField( model_name="userprofileevent", @@ -98,23 +93,17 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="emailverification", name="created_at", - field=models.DateTimeField( - auto_now_add=True, help_text="When this verification was created" - ), + field=models.DateTimeField(auto_now_add=True, help_text="When this verification was created"), ), migrations.AlterField( model_name="emailverification", name="last_sent", - field=models.DateTimeField( - auto_now_add=True, help_text="When the verification email was last sent" - ), + field=models.DateTimeField(auto_now_add=True, help_text="When the verification email was last sent"), ), migrations.AlterField( model_name="emailverification", name="token", - field=models.CharField( - help_text="Verification token", max_length=64, unique=True - ), + field=models.CharField(help_text="Verification token", max_length=64, unique=True), ), migrations.AlterField( model_name="emailverification", @@ -128,16 +117,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="emailverificationevent", name="created_at", - field=models.DateTimeField( - auto_now_add=True, help_text="When this verification was created" - ), + field=models.DateTimeField(auto_now_add=True, help_text="When this verification was created"), ), migrations.AlterField( model_name="emailverificationevent", name="last_sent", - field=models.DateTimeField( - auto_now_add=True, help_text="When the verification email was last sent" - ), + field=models.DateTimeField(auto_now_add=True, help_text="When the verification email was last sent"), ), migrations.AlterField( model_name="emailverificationevent", @@ -181,9 +166,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="passwordreset", name="created_at", - field=models.DateTimeField( - auto_now_add=True, help_text="When this reset was requested" - ), + field=models.DateTimeField(auto_now_add=True, help_text="When this reset was requested"), ), migrations.AlterField( model_name="passwordreset", @@ -198,9 +181,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="passwordreset", name="used", - field=models.BooleanField( - default=False, help_text="Whether this token has been used" - ), + field=models.BooleanField(default=False, help_text="Whether this token has been used"), ), migrations.AlterField( model_name="passwordreset", @@ -214,9 +195,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="passwordresetevent", name="created_at", - field=models.DateTimeField( - auto_now_add=True, help_text="When this reset was requested" - ), + field=models.DateTimeField(auto_now_add=True, help_text="When this reset was requested"), ), migrations.AlterField( model_name="passwordresetevent", @@ -231,9 +210,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="passwordresetevent", name="used", - field=models.BooleanField( - default=False, help_text="Whether this token has been used" - ), + field=models.BooleanField(default=False, help_text="Whether this token has been used"), ), migrations.AlterField( model_name="passwordresetevent", @@ -267,30 +244,22 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="user", name="allow_friend_requests", - field=models.BooleanField( - default=True, help_text="Whether to allow friend requests" - ), + field=models.BooleanField(default=True, help_text="Whether to allow friend requests"), ), migrations.AlterField( model_name="user", name="allow_messages", - field=models.BooleanField( - default=True, help_text="Whether to allow direct messages" - ), + field=models.BooleanField(default=True, help_text="Whether to allow direct messages"), ), migrations.AlterField( model_name="user", name="allow_profile_comments", - field=models.BooleanField( - default=False, help_text="Whether to allow profile comments" - ), + field=models.BooleanField(default=False, help_text="Whether to allow profile comments"), ), migrations.AlterField( model_name="user", name="ban_date", - field=models.DateTimeField( - blank=True, help_text="Date the user was banned", null=True - ), + field=models.DateTimeField(blank=True, help_text="Date the user was banned", null=True), ), migrations.AlterField( model_name="user", @@ -300,37 +269,27 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="user", name="email_notifications", - field=models.BooleanField( - default=True, help_text="Whether to send email notifications" - ), + field=models.BooleanField(default=True, help_text="Whether to send email notifications"), ), migrations.AlterField( model_name="user", name="is_banned", - field=models.BooleanField( - db_index=True, default=False, help_text="Whether this user is banned" - ), + field=models.BooleanField(db_index=True, default=False, help_text="Whether this user is banned"), ), migrations.AlterField( model_name="user", name="last_password_change", - field=models.DateTimeField( - auto_now_add=True, help_text="When the password was last changed" - ), + field=models.DateTimeField(auto_now_add=True, help_text="When the password was last changed"), ), migrations.AlterField( model_name="user", name="login_history_retention", - field=models.IntegerField( - default=90, help_text="How long to retain login history (days)" - ), + field=models.IntegerField(default=90, help_text="How long to retain login history (days)"), ), migrations.AlterField( model_name="user", name="login_notifications", - field=models.BooleanField( - default=True, help_text="Whether to send login notifications" - ), + field=models.BooleanField(default=True, help_text="Whether to send login notifications"), ), migrations.AlterField( model_name="user", @@ -352,9 +311,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="user", name="push_notifications", - field=models.BooleanField( - default=False, help_text="Whether to send push notifications" - ), + field=models.BooleanField(default=False, help_text="Whether to send push notifications"), ), migrations.AlterField( model_name="user", @@ -378,9 +335,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="user", name="search_visibility", - field=models.BooleanField( - default=True, help_text="Whether profile appears in search results" - ), + field=models.BooleanField(default=True, help_text="Whether profile appears in search results"), ), migrations.AlterField( model_name="user", @@ -390,51 +345,37 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="user", name="show_email", - field=models.BooleanField( - default=False, help_text="Whether to show email on profile" - ), + field=models.BooleanField(default=False, help_text="Whether to show email on profile"), ), migrations.AlterField( model_name="user", name="show_join_date", - field=models.BooleanField( - default=True, help_text="Whether to show join date on profile" - ), + field=models.BooleanField(default=True, help_text="Whether to show join date on profile"), ), migrations.AlterField( model_name="user", name="show_photos", - field=models.BooleanField( - default=True, help_text="Whether to show photos on profile" - ), + field=models.BooleanField(default=True, help_text="Whether to show photos on profile"), ), migrations.AlterField( model_name="user", name="show_real_name", - field=models.BooleanField( - default=True, help_text="Whether to show real name on profile" - ), + field=models.BooleanField(default=True, help_text="Whether to show real name on profile"), ), migrations.AlterField( model_name="user", name="show_reviews", - field=models.BooleanField( - default=True, help_text="Whether to show reviews on profile" - ), + field=models.BooleanField(default=True, help_text="Whether to show reviews on profile"), ), migrations.AlterField( model_name="user", name="show_statistics", - field=models.BooleanField( - default=True, help_text="Whether to show statistics on profile" - ), + field=models.BooleanField(default=True, help_text="Whether to show statistics on profile"), ), migrations.AlterField( model_name="user", name="show_top_lists", - field=models.BooleanField( - default=True, help_text="Whether to show top lists on profile" - ), + field=models.BooleanField(default=True, help_text="Whether to show top lists on profile"), ), migrations.AlterField( model_name="user", @@ -452,9 +393,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="user", name="two_factor_enabled", - field=models.BooleanField( - default=False, help_text="Whether two-factor authentication is enabled" - ), + field=models.BooleanField(default=False, help_text="Whether two-factor authentication is enabled"), ), migrations.AlterField( model_name="userevent", @@ -476,30 +415,22 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="userevent", name="allow_friend_requests", - field=models.BooleanField( - default=True, help_text="Whether to allow friend requests" - ), + field=models.BooleanField(default=True, help_text="Whether to allow friend requests"), ), migrations.AlterField( model_name="userevent", name="allow_messages", - field=models.BooleanField( - default=True, help_text="Whether to allow direct messages" - ), + field=models.BooleanField(default=True, help_text="Whether to allow direct messages"), ), migrations.AlterField( model_name="userevent", name="allow_profile_comments", - field=models.BooleanField( - default=False, help_text="Whether to allow profile comments" - ), + field=models.BooleanField(default=False, help_text="Whether to allow profile comments"), ), migrations.AlterField( model_name="userevent", name="ban_date", - field=models.DateTimeField( - blank=True, help_text="Date the user was banned", null=True - ), + field=models.DateTimeField(blank=True, help_text="Date the user was banned", null=True), ), migrations.AlterField( model_name="userevent", @@ -509,37 +440,27 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="userevent", name="email_notifications", - field=models.BooleanField( - default=True, help_text="Whether to send email notifications" - ), + field=models.BooleanField(default=True, help_text="Whether to send email notifications"), ), migrations.AlterField( model_name="userevent", name="is_banned", - field=models.BooleanField( - default=False, help_text="Whether this user is banned" - ), + field=models.BooleanField(default=False, help_text="Whether this user is banned"), ), migrations.AlterField( model_name="userevent", name="last_password_change", - field=models.DateTimeField( - auto_now_add=True, help_text="When the password was last changed" - ), + field=models.DateTimeField(auto_now_add=True, help_text="When the password was last changed"), ), migrations.AlterField( model_name="userevent", name="login_history_retention", - field=models.IntegerField( - default=90, help_text="How long to retain login history (days)" - ), + field=models.IntegerField(default=90, help_text="How long to retain login history (days)"), ), migrations.AlterField( model_name="userevent", name="login_notifications", - field=models.BooleanField( - default=True, help_text="Whether to send login notifications" - ), + field=models.BooleanField(default=True, help_text="Whether to send login notifications"), ), migrations.AlterField( model_name="userevent", @@ -561,9 +482,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="userevent", name="push_notifications", - field=models.BooleanField( - default=False, help_text="Whether to send push notifications" - ), + field=models.BooleanField(default=False, help_text="Whether to send push notifications"), ), migrations.AlterField( model_name="userevent", @@ -586,9 +505,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="userevent", name="search_visibility", - field=models.BooleanField( - default=True, help_text="Whether profile appears in search results" - ), + field=models.BooleanField(default=True, help_text="Whether profile appears in search results"), ), migrations.AlterField( model_name="userevent", @@ -598,51 +515,37 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="userevent", name="show_email", - field=models.BooleanField( - default=False, help_text="Whether to show email on profile" - ), + field=models.BooleanField(default=False, help_text="Whether to show email on profile"), ), migrations.AlterField( model_name="userevent", name="show_join_date", - field=models.BooleanField( - default=True, help_text="Whether to show join date on profile" - ), + field=models.BooleanField(default=True, help_text="Whether to show join date on profile"), ), migrations.AlterField( model_name="userevent", name="show_photos", - field=models.BooleanField( - default=True, help_text="Whether to show photos on profile" - ), + field=models.BooleanField(default=True, help_text="Whether to show photos on profile"), ), migrations.AlterField( model_name="userevent", name="show_real_name", - field=models.BooleanField( - default=True, help_text="Whether to show real name on profile" - ), + field=models.BooleanField(default=True, help_text="Whether to show real name on profile"), ), migrations.AlterField( model_name="userevent", name="show_reviews", - field=models.BooleanField( - default=True, help_text="Whether to show reviews on profile" - ), + field=models.BooleanField(default=True, help_text="Whether to show reviews on profile"), ), migrations.AlterField( model_name="userevent", name="show_statistics", - field=models.BooleanField( - default=True, help_text="Whether to show statistics on profile" - ), + field=models.BooleanField(default=True, help_text="Whether to show statistics on profile"), ), migrations.AlterField( model_name="userevent", name="show_top_lists", - field=models.BooleanField( - default=True, help_text="Whether to show top lists on profile" - ), + field=models.BooleanField(default=True, help_text="Whether to show top lists on profile"), ), migrations.AlterField( model_name="userevent", @@ -660,9 +563,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="userevent", name="two_factor_enabled", - field=models.BooleanField( - default=False, help_text="Whether two-factor authentication is enabled" - ), + field=models.BooleanField(default=False, help_text="Whether two-factor authentication is enabled"), ), migrations.AlterField( model_name="usernotification", @@ -678,23 +579,17 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="usernotification", name="email_sent", - field=models.BooleanField( - default=False, help_text="Whether email was sent" - ), + field=models.BooleanField(default=False, help_text="Whether email was sent"), ), migrations.AlterField( model_name="usernotification", name="email_sent_at", - field=models.DateTimeField( - blank=True, help_text="When email was sent", null=True - ), + field=models.DateTimeField(blank=True, help_text="When email was sent", null=True), ), migrations.AlterField( model_name="usernotification", name="is_read", - field=models.BooleanField( - default=False, help_text="Whether this notification has been read" - ), + field=models.BooleanField(default=False, help_text="Whether this notification has been read"), ), migrations.AlterField( model_name="usernotification", @@ -704,30 +599,22 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="usernotification", name="object_id", - field=models.PositiveIntegerField( - blank=True, help_text="ID of related object", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="ID of related object", null=True), ), migrations.AlterField( model_name="usernotification", name="push_sent", - field=models.BooleanField( - default=False, help_text="Whether push notification was sent" - ), + field=models.BooleanField(default=False, help_text="Whether push notification was sent"), ), migrations.AlterField( model_name="usernotification", name="push_sent_at", - field=models.DateTimeField( - blank=True, help_text="When push notification was sent", null=True - ), + field=models.DateTimeField(blank=True, help_text="When push notification was sent", null=True), ), migrations.AlterField( model_name="usernotification", name="read_at", - field=models.DateTimeField( - blank=True, help_text="When this notification was read", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this notification was read", null=True), ), migrations.AlterField( model_name="usernotification", @@ -761,23 +648,17 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="usernotificationevent", name="email_sent", - field=models.BooleanField( - default=False, help_text="Whether email was sent" - ), + field=models.BooleanField(default=False, help_text="Whether email was sent"), ), migrations.AlterField( model_name="usernotificationevent", name="email_sent_at", - field=models.DateTimeField( - blank=True, help_text="When email was sent", null=True - ), + field=models.DateTimeField(blank=True, help_text="When email was sent", null=True), ), migrations.AlterField( model_name="usernotificationevent", name="is_read", - field=models.BooleanField( - default=False, help_text="Whether this notification has been read" - ), + field=models.BooleanField(default=False, help_text="Whether this notification has been read"), ), migrations.AlterField( model_name="usernotificationevent", @@ -787,30 +668,22 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="usernotificationevent", name="object_id", - field=models.PositiveIntegerField( - blank=True, help_text="ID of related object", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="ID of related object", null=True), ), migrations.AlterField( model_name="usernotificationevent", name="push_sent", - field=models.BooleanField( - default=False, help_text="Whether push notification was sent" - ), + field=models.BooleanField(default=False, help_text="Whether push notification was sent"), ), migrations.AlterField( model_name="usernotificationevent", name="push_sent_at", - field=models.DateTimeField( - blank=True, help_text="When push notification was sent", null=True - ), + field=models.DateTimeField(blank=True, help_text="When push notification was sent", null=True), ), migrations.AlterField( model_name="usernotificationevent", name="read_at", - field=models.DateTimeField( - blank=True, help_text="When this notification was read", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this notification was read", null=True), ), migrations.AlterField( model_name="usernotificationevent", @@ -844,37 +717,27 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="userprofile", name="bio", - field=models.TextField( - blank=True, help_text="User biography", max_length=500 - ), + field=models.TextField(blank=True, help_text="User biography", max_length=500), ), migrations.AlterField( model_name="userprofile", name="coaster_credits", - field=models.IntegerField( - default=0, help_text="Number of roller coasters ridden" - ), + field=models.IntegerField(default=0, help_text="Number of roller coasters ridden"), ), migrations.AlterField( model_name="userprofile", name="dark_ride_credits", - field=models.IntegerField( - default=0, help_text="Number of dark rides ridden" - ), + field=models.IntegerField(default=0, help_text="Number of dark rides ridden"), ), migrations.AlterField( model_name="userprofile", name="discord", - field=models.CharField( - blank=True, help_text="Discord username", max_length=100 - ), + field=models.CharField(blank=True, help_text="Discord username", max_length=100), ), migrations.AlterField( model_name="userprofile", name="flat_ride_credits", - field=models.IntegerField( - default=0, help_text="Number of flat rides ridden" - ), + field=models.IntegerField(default=0, help_text="Number of flat rides ridden"), ), migrations.AlterField( model_name="userprofile", @@ -884,9 +747,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="userprofile", name="pronouns", - field=models.CharField( - blank=True, help_text="User's preferred pronouns", max_length=50 - ), + field=models.CharField(blank=True, help_text="User's preferred pronouns", max_length=50), ), migrations.AlterField( model_name="userprofile", @@ -906,9 +767,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="userprofile", name="water_ride_credits", - field=models.IntegerField( - default=0, help_text="Number of water rides ridden" - ), + field=models.IntegerField(default=0, help_text="Number of water rides ridden"), ), migrations.AlterField( model_name="userprofile", @@ -932,37 +791,27 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="userprofileevent", name="bio", - field=models.TextField( - blank=True, help_text="User biography", max_length=500 - ), + field=models.TextField(blank=True, help_text="User biography", max_length=500), ), migrations.AlterField( model_name="userprofileevent", name="coaster_credits", - field=models.IntegerField( - default=0, help_text="Number of roller coasters ridden" - ), + field=models.IntegerField(default=0, help_text="Number of roller coasters ridden"), ), migrations.AlterField( model_name="userprofileevent", name="dark_ride_credits", - field=models.IntegerField( - default=0, help_text="Number of dark rides ridden" - ), + field=models.IntegerField(default=0, help_text="Number of dark rides ridden"), ), migrations.AlterField( model_name="userprofileevent", name="discord", - field=models.CharField( - blank=True, help_text="Discord username", max_length=100 - ), + field=models.CharField(blank=True, help_text="Discord username", max_length=100), ), migrations.AlterField( model_name="userprofileevent", name="flat_ride_credits", - field=models.IntegerField( - default=0, help_text="Number of flat rides ridden" - ), + field=models.IntegerField(default=0, help_text="Number of flat rides ridden"), ), migrations.AlterField( model_name="userprofileevent", @@ -972,9 +821,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="userprofileevent", name="pronouns", - field=models.CharField( - blank=True, help_text="User's preferred pronouns", max_length=50 - ), + field=models.CharField(blank=True, help_text="User's preferred pronouns", max_length=50), ), migrations.AlterField( model_name="userprofileevent", @@ -996,9 +843,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="userprofileevent", name="water_ride_credits", - field=models.IntegerField( - default=0, help_text="Number of water rides ridden" - ), + field=models.IntegerField(default=0, help_text="Number of water rides ridden"), ), migrations.AlterField( model_name="userprofileevent", diff --git a/backend/apps/accounts/mixins.py b/backend/apps/accounts/mixins.py index c4be1ec4..17990b51 100644 --- a/backend/apps/accounts/mixins.py +++ b/backend/apps/accounts/mixins.py @@ -1,6 +1,7 @@ """ Mixins for authentication views. """ + from django.core.exceptions import ValidationError from apps.core.utils.turnstile import get_client_ip, validate_turnstile_token @@ -24,14 +25,14 @@ class TurnstileMixin: token = None # Check POST data (form submissions) - if hasattr(request, 'POST'): + if hasattr(request, "POST"): token = request.POST.get("cf-turnstile-response") # Check JSON body (API requests) - if not token and hasattr(request, 'data'): - data = getattr(request, 'data', {}) - if hasattr(data, 'get'): - token = data.get('turnstile_token') or data.get('cf-turnstile-response') + if not token and hasattr(request, "data"): + data = getattr(request, "data", {}) + if hasattr(data, "get"): + token = data.get("turnstile_token") or data.get("cf-turnstile-response") # Get client IP ip = get_client_ip(request) @@ -39,6 +40,6 @@ class TurnstileMixin: # Validate the token result = validate_turnstile_token(token, ip) - if not result.get('success'): - error_msg = result.get('error', 'Captcha verification failed. Please try again.') + if not result.get("success"): + error_msg = result.get("error", "Captcha verification failed. Please try again.") raise ValidationError(error_msg) diff --git a/backend/apps/accounts/models.py b/backend/apps/accounts/models.py index 8a3e36ae..b5f9507b 100644 --- a/backend/apps/accounts/models.py +++ b/backend/apps/accounts/models.py @@ -41,10 +41,7 @@ class User(AbstractUser): max_length=10, unique=True, editable=False, - help_text=( - "Unique identifier for this user that remains constant even if the " - "username changes" - ), + help_text=("Unique identifier for this user that remains constant even if the " "username changes"), ) role = RichChoiceField( @@ -55,13 +52,9 @@ class User(AbstractUser): db_index=True, help_text="User role (user, moderator, admin)", ) - is_banned = models.BooleanField( - default=False, db_index=True, help_text="Whether this user is banned" - ) + is_banned = models.BooleanField(default=False, db_index=True, help_text="Whether this user is banned") ban_reason = models.TextField(blank=True, help_text="Reason for ban") - ban_date = models.DateTimeField( - null=True, blank=True, help_text="Date the user was banned" - ) + ban_date = models.DateTimeField(null=True, blank=True, help_text="Date the user was banned") pending_email = models.EmailField(blank=True, null=True) theme_preference = RichChoiceField( choice_group="theme_preferences", @@ -72,12 +65,8 @@ class User(AbstractUser): ) # Notification preferences - email_notifications = models.BooleanField( - default=True, help_text="Whether to send email notifications" - ) - push_notifications = models.BooleanField( - default=False, help_text="Whether to send push notifications" - ) + email_notifications = models.BooleanField(default=True, help_text="Whether to send email notifications") + push_notifications = models.BooleanField(default=False, help_text="Whether to send push notifications") # Privacy settings privacy_level = RichChoiceField( @@ -87,39 +76,17 @@ class User(AbstractUser): default="public", help_text="Overall privacy level", ) - show_email = models.BooleanField( - default=False, help_text="Whether to show email on profile" - ) - show_real_name = models.BooleanField( - default=True, help_text="Whether to show real name on profile" - ) - show_join_date = models.BooleanField( - default=True, help_text="Whether to show join date on profile" - ) - show_statistics = models.BooleanField( - default=True, help_text="Whether to show statistics on profile" - ) - show_reviews = models.BooleanField( - default=True, help_text="Whether to show reviews on profile" - ) - show_photos = models.BooleanField( - default=True, help_text="Whether to show photos on profile" - ) - show_top_lists = models.BooleanField( - default=True, help_text="Whether to show top lists on profile" - ) - allow_friend_requests = models.BooleanField( - default=True, help_text="Whether to allow friend requests" - ) - allow_messages = models.BooleanField( - default=True, help_text="Whether to allow direct messages" - ) - allow_profile_comments = models.BooleanField( - default=False, help_text="Whether to allow profile comments" - ) - search_visibility = models.BooleanField( - default=True, help_text="Whether profile appears in search results" - ) + show_email = models.BooleanField(default=False, help_text="Whether to show email on profile") + show_real_name = models.BooleanField(default=True, help_text="Whether to show real name on profile") + show_join_date = models.BooleanField(default=True, help_text="Whether to show join date on profile") + show_statistics = models.BooleanField(default=True, help_text="Whether to show statistics on profile") + show_reviews = models.BooleanField(default=True, help_text="Whether to show reviews on profile") + show_photos = models.BooleanField(default=True, help_text="Whether to show photos on profile") + show_top_lists = models.BooleanField(default=True, help_text="Whether to show top lists on profile") + allow_friend_requests = models.BooleanField(default=True, help_text="Whether to allow friend requests") + allow_messages = models.BooleanField(default=True, help_text="Whether to allow direct messages") + allow_profile_comments = models.BooleanField(default=False, help_text="Whether to allow profile comments") + search_visibility = models.BooleanField(default=True, help_text="Whether profile appears in search results") activity_visibility = RichChoiceField( choice_group="privacy_levels", domain="accounts", @@ -129,21 +96,11 @@ class User(AbstractUser): ) # Security settings - two_factor_enabled = models.BooleanField( - default=False, help_text="Whether two-factor authentication is enabled" - ) - login_notifications = models.BooleanField( - default=True, help_text="Whether to send login notifications" - ) - session_timeout = models.IntegerField( - default=30, help_text="Session timeout in days" - ) - login_history_retention = models.IntegerField( - default=90, help_text="How long to retain login history (days)" - ) - last_password_change = models.DateTimeField( - auto_now_add=True, help_text="When the password was last changed" - ) + two_factor_enabled = models.BooleanField(default=False, help_text="Whether two-factor authentication is enabled") + login_notifications = models.BooleanField(default=True, help_text="Whether to send login notifications") + session_timeout = models.IntegerField(default=30, help_text="Session timeout in days") + login_history_retention = models.IntegerField(default=90, help_text="How long to retain login history (days)") + last_password_change = models.DateTimeField(auto_now_add=True, help_text="When the password was last changed") # Display name - core user data for better performance display_name = models.CharField( @@ -179,13 +136,13 @@ class User(AbstractUser): verbose_name = "User" verbose_name_plural = "Users" indexes = [ - models.Index(fields=['is_banned', 'role'], name='accounts_user_banned_role_idx'), + models.Index(fields=["is_banned", "role"], name="accounts_user_banned_role_idx"), ] constraints = [ models.CheckConstraint( - name='user_ban_consistency', + name="user_ban_consistency", check=models.Q(is_banned=False) | models.Q(ban_date__isnull=False), - violation_error_message='Banned users must have a ban_date set' + violation_error_message="Banned users must have a ban_date set", ), ] @@ -224,14 +181,10 @@ class UserProfile(models.Model): related_name="user_profiles", help_text="User's avatar image", ) - pronouns = models.CharField( - max_length=50, blank=True, help_text="User's preferred pronouns" - ) + pronouns = models.CharField(max_length=50, blank=True, help_text="User's preferred pronouns") bio = models.TextField(max_length=500, blank=True, help_text="User biography") - location = models.CharField( - max_length=100, blank=True, help_text="User's location (City, Country)" - ) + location = models.CharField(max_length=100, blank=True, help_text="User's location (City, Country)") unit_system = RichChoiceField( choice_group="unit_systems", domain="accounts", @@ -247,18 +200,10 @@ class UserProfile(models.Model): discord = models.CharField(max_length=100, blank=True, help_text="Discord username") # Ride statistics - coaster_credits = models.IntegerField( - default=0, help_text="Number of roller coasters ridden" - ) - dark_ride_credits = models.IntegerField( - default=0, help_text="Number of dark rides ridden" - ) - flat_ride_credits = models.IntegerField( - default=0, help_text="Number of flat rides ridden" - ) - water_ride_credits = models.IntegerField( - default=0, help_text="Number of water rides ridden" - ) + coaster_credits = models.IntegerField(default=0, help_text="Number of roller coasters ridden") + dark_ride_credits = models.IntegerField(default=0, help_text="Number of dark rides ridden") + flat_ride_credits = models.IntegerField(default=0, help_text="Number of flat rides ridden") + water_ride_credits = models.IntegerField(default=0, help_text="Number of water rides ridden") def get_avatar_url(self): """ @@ -266,12 +211,12 @@ class UserProfile(models.Model): """ if self.avatar and self.avatar.is_uploaded: # Try to get avatar variant first, fallback to public - avatar_url = self.avatar.get_url('avatar') + avatar_url = self.avatar.get_url("avatar") if avatar_url: return avatar_url # Fallback to public variant - public_url = self.avatar.get_url('public') + public_url = self.avatar.get_url("public") if public_url: return public_url @@ -298,10 +243,10 @@ class UserProfile(models.Model): variants = {} # Try to get specific variants - thumbnail_url = self.avatar.get_url('thumbnail') - avatar_url = self.avatar.get_url('avatar') - large_url = self.avatar.get_url('large') - public_url = self.avatar.get_url('public') + thumbnail_url = self.avatar.get_url("thumbnail") + avatar_url = self.avatar.get_url("avatar") + large_url = self.avatar.get_url("large") + public_url = self.avatar.get_url("public") # Use specific variants if available, otherwise fallback to public or first available fallback_url = public_url @@ -354,18 +299,10 @@ class EmailVerification(models.Model): on_delete=models.CASCADE, help_text="User this verification belongs to", ) - token = models.CharField( - max_length=64, unique=True, help_text="Verification token" - ) - created_at = models.DateTimeField( - auto_now_add=True, help_text="When this verification was created" - ) - updated_at = models.DateTimeField( - auto_now=True, help_text="When this verification was last updated" - ) - last_sent = models.DateTimeField( - auto_now_add=True, help_text="When the verification email was last sent" - ) + token = models.CharField(max_length=64, unique=True, help_text="Verification token") + created_at = models.DateTimeField(auto_now_add=True, help_text="When this verification was created") + updated_at = models.DateTimeField(auto_now=True, help_text="When this verification was last updated") + last_sent = models.DateTimeField(auto_now_add=True, help_text="When the verification email was last sent") def __str__(self): return f"Email verification for {self.user.username}" @@ -383,9 +320,7 @@ class PasswordReset(models.Model): help_text="User requesting password reset", ) token = models.CharField(max_length=64, help_text="Reset token") - created_at = models.DateTimeField( - auto_now_add=True, help_text="When this reset was requested" - ) + created_at = models.DateTimeField(auto_now_add=True, help_text="When this reset was requested") expires_at = models.DateTimeField(help_text="When this reset token expires") used = models.BooleanField(default=False, help_text="Whether this token has been used") @@ -397,8 +332,6 @@ class PasswordReset(models.Model): verbose_name_plural = "Password Resets" - - @pghistory.track() class UserDeletionRequest(models.Model): """ @@ -409,9 +342,7 @@ class UserDeletionRequest(models.Model): provide the correct code. """ - user = models.OneToOneField( - User, on_delete=models.CASCADE, related_name="deletion_request" - ) + user = models.OneToOneField(User, on_delete=models.CASCADE, related_name="deletion_request") verification_code = models.CharField( max_length=32, @@ -422,21 +353,13 @@ class UserDeletionRequest(models.Model): created_at = models.DateTimeField(auto_now_add=True) expires_at = models.DateTimeField(help_text="When this deletion request expires") - email_sent_at = models.DateTimeField( - null=True, blank=True, help_text="When the verification email was sent" - ) + email_sent_at = models.DateTimeField(null=True, blank=True, help_text="When the verification email was sent") - attempts = models.PositiveIntegerField( - default=0, help_text="Number of verification attempts made" - ) + attempts = models.PositiveIntegerField(default=0, help_text="Number of verification attempts made") - max_attempts = models.PositiveIntegerField( - default=5, help_text="Maximum number of verification attempts allowed" - ) + max_attempts = models.PositiveIntegerField(default=5, help_text="Maximum number of verification attempts allowed") - is_used = models.BooleanField( - default=False, help_text="Whether this deletion request has been used" - ) + is_used = models.BooleanField(default=False, help_text="Whether this deletion request has been used") class Meta: verbose_name = "User Deletion Request" @@ -466,9 +389,7 @@ class UserDeletionRequest(models.Model): """Generate a unique 8-character verification code.""" while True: # Generate a random 8-character alphanumeric code - code = "".join( - secrets.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") for _ in range(8) - ) + code = "".join(secrets.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") for _ in range(8)) # Ensure it's unique if not UserDeletionRequest.objects.filter(verification_code=code).exists(): @@ -480,11 +401,7 @@ class UserDeletionRequest(models.Model): def is_valid(self): """Check if this deletion request is still valid.""" - return ( - not self.is_used - and not self.is_expired() - and self.attempts < self.max_attempts - ) + return not self.is_used and not self.is_expired() and self.attempts < self.max_attempts def increment_attempts(self): """Increment the number of verification attempts.""" @@ -499,9 +416,7 @@ class UserDeletionRequest(models.Model): @classmethod def cleanup_expired(cls): """Remove expired deletion requests.""" - expired_requests = cls.objects.filter( - expires_at__lt=timezone.now(), is_used=False - ) + expired_requests = cls.objects.filter(expires_at__lt=timezone.now(), is_used=False) count = expired_requests.count() expired_requests.delete() return count @@ -541,9 +456,7 @@ class UserNotification(TrackedModel): blank=True, help_text="Type of related object", ) - object_id = models.PositiveIntegerField( - null=True, blank=True, help_text="ID of related object" - ) + object_id = models.PositiveIntegerField(null=True, blank=True, help_text="ID of related object") related_object = GenericForeignKey("content_type", "object_id") # Metadata @@ -555,24 +468,14 @@ class UserNotification(TrackedModel): ) # Status tracking - is_read = models.BooleanField( - default=False, help_text="Whether this notification has been read" - ) - read_at = models.DateTimeField( - null=True, blank=True, help_text="When this notification was read" - ) + is_read = models.BooleanField(default=False, help_text="Whether this notification has been read") + read_at = models.DateTimeField(null=True, blank=True, help_text="When this notification was read") # Delivery tracking email_sent = models.BooleanField(default=False, help_text="Whether email was sent") - email_sent_at = models.DateTimeField( - null=True, blank=True, help_text="When email was sent" - ) - push_sent = models.BooleanField( - default=False, help_text="Whether push notification was sent" - ) - push_sent_at = models.DateTimeField( - null=True, blank=True, help_text="When push notification was sent" - ) + email_sent_at = models.DateTimeField(null=True, blank=True, help_text="When email was sent") + push_sent = models.BooleanField(default=False, help_text="Whether push notification was sent") + push_sent_at = models.DateTimeField(null=True, blank=True, help_text="When push notification was sent") # Additional data (JSON field for flexibility) extra_data = models.JSONField(default=dict, blank=True) @@ -619,9 +522,7 @@ class UserNotification(TrackedModel): @classmethod def mark_all_read_for_user(cls, user): """Mark all notifications as read for a specific user.""" - return cls.objects.filter(user=user, is_read=False).update( - is_read=True, read_at=timezone.now() - ) + return cls.objects.filter(user=user, is_read=False).update(is_read=True, read_at=timezone.now()) @pghistory.track() diff --git a/backend/apps/accounts/selectors.py b/backend/apps/accounts/selectors.py index 6a1f0d58..19fccfa0 100644 --- a/backend/apps/accounts/selectors.py +++ b/backend/apps/accounts/selectors.py @@ -27,16 +27,10 @@ def user_profile_optimized(*, user_id: int) -> Any: User.DoesNotExist: If user doesn't exist """ return ( - User.objects.prefetch_related( - "park_reviews", "ride_reviews", "socialaccount_set" - ) + User.objects.prefetch_related("park_reviews", "ride_reviews", "socialaccount_set") .annotate( - park_review_count=Count( - "park_reviews", filter=Q(park_reviews__is_published=True) - ), - ride_review_count=Count( - "ride_reviews", filter=Q(ride_reviews__is_published=True) - ), + park_review_count=Count("park_reviews", filter=Q(park_reviews__is_published=True)), + ride_review_count=Count("ride_reviews", filter=Q(ride_reviews__is_published=True)), total_review_count=F("park_review_count") + F("ride_review_count"), ) .get(id=user_id) @@ -53,12 +47,8 @@ def active_users_with_stats() -> QuerySet: return ( User.objects.filter(is_active=True) .annotate( - park_review_count=Count( - "park_reviews", filter=Q(park_reviews__is_published=True) - ), - ride_review_count=Count( - "ride_reviews", filter=Q(ride_reviews__is_published=True) - ), + park_review_count=Count("park_reviews", filter=Q(park_reviews__is_published=True)), + ride_review_count=Count("ride_reviews", filter=Q(ride_reviews__is_published=True)), total_review_count=F("park_review_count") + F("ride_review_count"), ) .order_by("-total_review_count") @@ -112,12 +102,8 @@ def top_reviewers(*, limit: int = 10) -> QuerySet: return ( User.objects.filter(is_active=True) .annotate( - park_review_count=Count( - "park_reviews", filter=Q(park_reviews__is_published=True) - ), - ride_review_count=Count( - "ride_reviews", filter=Q(ride_reviews__is_published=True) - ), + park_review_count=Count("park_reviews", filter=Q(park_reviews__is_published=True)), + ride_review_count=Count("ride_reviews", filter=Q(ride_reviews__is_published=True)), total_review_count=F("park_review_count") + F("ride_review_count"), ) .filter(total_review_count__gt=0) @@ -159,9 +145,9 @@ def users_by_registration_date(*, start_date, end_date) -> QuerySet: Returns: QuerySet of users registered in the date range """ - return User.objects.filter( - date_joined__date__gte=start_date, date_joined__date__lte=end_date - ).order_by("-date_joined") + return User.objects.filter(date_joined__date__gte=start_date, date_joined__date__lte=end_date).order_by( + "-date_joined" + ) def user_search_autocomplete(*, query: str, limit: int = 10) -> QuerySet: @@ -176,8 +162,7 @@ def user_search_autocomplete(*, query: str, limit: int = 10) -> QuerySet: QuerySet of matching users for autocomplete """ return User.objects.filter( - Q(username__icontains=query) - | Q(display_name__icontains=query), + Q(username__icontains=query) | Q(display_name__icontains=query), is_active=True, ).order_by("username")[:limit] @@ -210,11 +195,7 @@ def user_statistics_summary() -> dict[str, Any]: # Users with reviews users_with_reviews = ( - User.objects.filter( - Q(park_reviews__isnull=False) | Q(ride_reviews__isnull=False) - ) - .distinct() - .count() + User.objects.filter(Q(park_reviews__isnull=False) | Q(ride_reviews__isnull=False)).distinct().count() ) # Recent registrations (last 30 days) @@ -228,9 +209,7 @@ def user_statistics_summary() -> dict[str, Any]: "staff_users": staff_users, "users_with_reviews": users_with_reviews, "recent_registrations": recent_registrations, - "review_participation_rate": ( - (users_with_reviews / total_users * 100) if total_users > 0 else 0 - ), + "review_participation_rate": ((users_with_reviews / total_users * 100) if total_users > 0 else 0), } @@ -241,11 +220,7 @@ def users_needing_email_verification() -> QuerySet: Returns: QuerySet of users with unverified emails """ - return ( - User.objects.filter(is_active=True, emailaddress__verified=False) - .distinct() - .order_by("date_joined") - ) + return User.objects.filter(is_active=True, emailaddress__verified=False).distinct().order_by("date_joined") def users_by_review_activity(*, min_reviews: int = 1) -> QuerySet: @@ -260,12 +235,8 @@ def users_by_review_activity(*, min_reviews: int = 1) -> QuerySet: """ return ( User.objects.annotate( - park_review_count=Count( - "park_reviews", filter=Q(park_reviews__is_published=True) - ), - ride_review_count=Count( - "ride_reviews", filter=Q(ride_reviews__is_published=True) - ), + park_review_count=Count("park_reviews", filter=Q(park_reviews__is_published=True)), + ride_review_count=Count("ride_reviews", filter=Q(ride_reviews__is_published=True)), total_review_count=F("park_review_count") + F("ride_review_count"), ) .filter(total_review_count__gte=min_reviews) diff --git a/backend/apps/accounts/serializers.py b/backend/apps/accounts/serializers.py index b09f58a0..39e8414a 100644 --- a/backend/apps/accounts/serializers.py +++ b/backend/apps/accounts/serializers.py @@ -62,12 +62,8 @@ class LoginSerializer(serializers.Serializer): Serializer for user login """ - username = serializers.CharField( - max_length=254, help_text="Username or email address" - ) - password = serializers.CharField( - max_length=128, style={"input_type": "password"}, trim_whitespace=False - ) + username = serializers.CharField(max_length=254, help_text="Username or email address") + password = serializers.CharField(max_length=128, style={"input_type": "password"}, trim_whitespace=False) def validate(self, attrs): username = attrs.get("username") @@ -89,9 +85,7 @@ class SignupSerializer(serializers.ModelSerializer): validators=[validate_password], style={"input_type": "password"}, ) - password_confirm = serializers.CharField( - write_only=True, style={"input_type": "password"} - ) + password_confirm = serializers.CharField(write_only=True, style={"input_type": "password"}) class Meta: model = User @@ -118,9 +112,7 @@ class SignupSerializer(serializers.ModelSerializer): def validate_username(self, value): """Validate username is unique""" if UserModel.objects.filter(username=value).exists(): - raise serializers.ValidationError( - "A user with this username already exists." - ) + raise serializers.ValidationError("A user with this username already exists.") return value def validate(self, attrs): @@ -129,9 +121,7 @@ class SignupSerializer(serializers.ModelSerializer): password_confirm = attrs.get("password_confirm") if password != password_confirm: - raise serializers.ValidationError( - {"password_confirm": "Passwords do not match."} - ) + raise serializers.ValidationError({"password_confirm": "Passwords do not match."}) return attrs @@ -194,9 +184,7 @@ class PasswordResetSerializer(serializers.Serializer): "site_name": site.name, } - email_html = render_to_string( - "accounts/email/password_reset.html", context - ) + email_html = render_to_string("accounts/email/password_reset.html", context) # Narrow and validate email type for the static checker email = getattr(self.user, "email", None) @@ -218,15 +206,11 @@ class PasswordChangeSerializer(serializers.Serializer): Serializer for password change """ - old_password = serializers.CharField( - max_length=128, style={"input_type": "password"} - ) + old_password = serializers.CharField(max_length=128, style={"input_type": "password"}) new_password = serializers.CharField( max_length=128, validators=[validate_password], style={"input_type": "password"} ) - new_password_confirm = serializers.CharField( - max_length=128, style={"input_type": "password"} - ) + new_password_confirm = serializers.CharField(max_length=128, style={"input_type": "password"}) def validate_old_password(self, value): """Validate old password is correct""" @@ -241,9 +225,7 @@ class PasswordChangeSerializer(serializers.Serializer): new_password_confirm = attrs.get("new_password_confirm") if new_password != new_password_confirm: - raise serializers.ValidationError( - {"new_password_confirm": "New passwords do not match."} - ) + raise serializers.ValidationError({"new_password_confirm": "New passwords do not match."}) return attrs diff --git a/backend/apps/accounts/services.py b/backend/apps/accounts/services.py index 779a3d9c..9b098489 100644 --- a/backend/apps/accounts/services.py +++ b/backend/apps/accounts/services.py @@ -81,21 +81,15 @@ class AccountService: """ # Verify old password if not user.check_password(old_password): - logger.warning( - f"Password change failed: incorrect current password for user {user.id}" - ) - return { - 'success': False, - 'message': "Current password is incorrect", - 'redirect_url': None - } + logger.warning(f"Password change failed: incorrect current password for user {user.id}") + return {"success": False, "message": "Current password is incorrect", "redirect_url": None} # Validate new password if not AccountService.validate_password(new_password): return { - 'success': False, - 'message': "Password must be at least 8 characters and contain uppercase, lowercase, and numbers", - 'redirect_url': None + "success": False, + "message": "Password must be at least 8 characters and contain uppercase, lowercase, and numbers", + "redirect_url": None, } # Update password @@ -111,9 +105,9 @@ class AccountService: logger.info(f"Password changed successfully for user {user.id}") return { - 'success': True, - 'message': "Password changed successfully. Please check your email for confirmation.", - 'redirect_url': None + "success": True, + "message": "Password changed successfully. Please check your email for confirmation.", + "redirect_url": None, } @staticmethod @@ -125,9 +119,7 @@ class AccountService: "site_name": site.name, } - email_html = render_to_string( - "accounts/email/password_change_confirmation.html", context - ) + email_html = render_to_string("accounts/email/password_change_confirmation.html", context) try: EmailService.send_email( @@ -166,26 +158,17 @@ class AccountService: } """ if not new_email: - return { - 'success': False, - 'message': "New email is required" - } + return {"success": False, "message": "New email is required"} # Check if email is already in use if User.objects.filter(email=new_email).exclude(id=user.id).exists(): - return { - 'success': False, - 'message': "This email address is already in use" - } + return {"success": False, "message": "This email address is already in use"} # Generate verification token token = get_random_string(64) # Create or update email verification record - EmailVerification.objects.update_or_create( - user=user, - defaults={"token": token} - ) + EmailVerification.objects.update_or_create(user=user, defaults={"token": token}) # Store pending email user.pending_email = new_email @@ -196,18 +179,10 @@ class AccountService: logger.info(f"Email change initiated for user {user.id} to {new_email}") - return { - 'success': True, - 'message': "Verification email sent to your new email address" - } + return {"success": True, "message": "Verification email sent to your new email address"} @staticmethod - def _send_email_verification( - request: HttpRequest, - user: User, - new_email: str, - token: str - ) -> None: + def _send_email_verification(request: HttpRequest, user: User, new_email: str, token: str) -> None: """Send email verification for email change.""" from django.urls import reverse @@ -245,22 +220,14 @@ class AccountService: Dictionary with success status and message """ try: - verification = EmailVerification.objects.select_related("user").get( - token=token - ) + verification = EmailVerification.objects.select_related("user").get(token=token) except EmailVerification.DoesNotExist: - return { - 'success': False, - 'message': "Invalid or expired verification token" - } + return {"success": False, "message": "Invalid or expired verification token"} user = verification.user if not user.pending_email: - return { - 'success': False, - 'message': "No pending email change found" - } + return {"success": False, "message": "No pending email change found"} # Update email old_email = user.email @@ -273,10 +240,7 @@ class AccountService: logger.info(f"Email changed for user {user.id} from {old_email} to {user.email}") - return { - 'success': True, - 'message': "Email address updated successfully" - } + return {"success": True, "message": "Email address updated successfully"} class UserDeletionService: @@ -337,39 +301,17 @@ class UserDeletionService: # Count submissions before transfer submission_counts = { - "park_reviews": getattr( - user, "park_reviews", user.__class__.objects.none() - ).count(), - "ride_reviews": getattr( - user, "ride_reviews", user.__class__.objects.none() - ).count(), - "uploaded_park_photos": getattr( - user, "uploaded_park_photos", user.__class__.objects.none() - ).count(), - "uploaded_ride_photos": getattr( - user, "uploaded_ride_photos", user.__class__.objects.none() - ).count(), - "top_lists": getattr( - user, "top_lists", user.__class__.objects.none() - ).count(), - "edit_submissions": getattr( - user, "edit_submissions", user.__class__.objects.none() - ).count(), - "photo_submissions": getattr( - user, "photo_submissions", user.__class__.objects.none() - ).count(), - "moderated_park_reviews": getattr( - user, "moderated_park_reviews", user.__class__.objects.none() - ).count(), - "moderated_ride_reviews": getattr( - user, "moderated_ride_reviews", user.__class__.objects.none() - ).count(), - "handled_submissions": getattr( - user, "handled_submissions", user.__class__.objects.none() - ).count(), - "handled_photos": getattr( - user, "handled_photos", user.__class__.objects.none() - ).count(), + "park_reviews": getattr(user, "park_reviews", user.__class__.objects.none()).count(), + "ride_reviews": getattr(user, "ride_reviews", user.__class__.objects.none()).count(), + "uploaded_park_photos": getattr(user, "uploaded_park_photos", user.__class__.objects.none()).count(), + "uploaded_ride_photos": getattr(user, "uploaded_ride_photos", user.__class__.objects.none()).count(), + "top_lists": getattr(user, "top_lists", user.__class__.objects.none()).count(), + "edit_submissions": getattr(user, "edit_submissions", user.__class__.objects.none()).count(), + "photo_submissions": getattr(user, "photo_submissions", user.__class__.objects.none()).count(), + "moderated_park_reviews": getattr(user, "moderated_park_reviews", user.__class__.objects.none()).count(), + "moderated_ride_reviews": getattr(user, "moderated_ride_reviews", user.__class__.objects.none()).count(), + "handled_submissions": getattr(user, "handled_submissions", user.__class__.objects.none()).count(), + "handled_photos": getattr(user, "handled_photos", user.__class__.objects.none()).count(), } # Transfer all submissions to deleted user @@ -440,11 +382,17 @@ class UserDeletionService: return False, "Cannot delete the system deleted user placeholder" if user.is_superuser: - return False, "Superuser accounts cannot be deleted for security reasons. Please contact system administrator or remove superuser privileges first." + return ( + False, + "Superuser accounts cannot be deleted for security reasons. Please contact system administrator or remove superuser privileges first.", + ) # Check if user has critical admin role if user.role == User.Roles.ADMIN and user.is_staff: - return False, "Admin accounts with staff privileges cannot be deleted. Please remove admin privileges first or contact system administrator." + return ( + False, + "Admin accounts with staff privileges cannot be deleted. Please remove admin privileges first or contact system administrator.", + ) # Add any other business rules here @@ -492,9 +440,7 @@ class UserDeletionService: site = Site.objects.get_current() except Site.DoesNotExist: # Fallback to default site - site = Site.objects.get_or_create( - id=1, defaults={"domain": "localhost:8000", "name": "localhost:8000"} - )[0] + site = Site.objects.get_or_create(id=1, defaults={"domain": "localhost:8000", "name": "localhost:8000"})[0] # Prepare email context context = { @@ -502,9 +448,7 @@ class UserDeletionService: "verification_code": deletion_request.verification_code, "expires_at": deletion_request.expires_at, "site_name": getattr(settings, "SITE_NAME", "ThrillWiki"), - "frontend_domain": getattr( - settings, "FRONTEND_DOMAIN", "http://localhost:3000" - ), + "frontend_domain": getattr(settings, "FRONTEND_DOMAIN", "http://localhost:3000"), } # Render email content @@ -564,11 +508,9 @@ The ThrillWiki Team ValueError: If verification fails """ try: - deletion_request = UserDeletionRequest.objects.get( - verification_code=verification_code - ) + deletion_request = UserDeletionRequest.objects.get(verification_code=verification_code) except UserDeletionRequest.DoesNotExist: - raise ValueError("Invalid verification code") + raise ValueError("Invalid verification code") from None # Check if request is still valid if not deletion_request.is_valid(): diff --git a/backend/apps/accounts/services/__init__.py b/backend/apps/accounts/services/__init__.py index 0134fad4..e4ce3c19 100644 --- a/backend/apps/accounts/services/__init__.py +++ b/backend/apps/accounts/services/__init__.py @@ -8,4 +8,4 @@ including social provider management, user authentication, and profile services. from .social_provider_service import SocialProviderService from .user_deletion_service import UserDeletionService -__all__ = ['SocialProviderService', 'UserDeletionService'] +__all__ = ["SocialProviderService", "UserDeletionService"] diff --git a/backend/apps/accounts/services/notification_service.py b/backend/apps/accounts/services/notification_service.py index 0b4e7a25..30de90e0 100644 --- a/backend/apps/accounts/services/notification_service.py +++ b/backend/apps/accounts/services/notification_service.py @@ -139,7 +139,9 @@ class NotificationService: UserNotification: The created notification """ title = f"Your {submission_type} needs attention" - message = f"Your {submission_type} submission has been reviewed and needs some changes before it can be approved." + message = ( + f"Your {submission_type} submission has been reviewed and needs some changes before it can be approved." + ) message += f"\n\nReason: {rejection_reason}" if additional_message: @@ -216,9 +218,7 @@ class NotificationService: preferences = NotificationPreference.objects.create(user=user) # Send email notification if enabled - if preferences.should_send_notification( - notification.notification_type, "email" - ): + if preferences.should_send_notification(notification.notification_type, "email"): NotificationService._send_email_notification(notification) # Toast notifications are always created (the notification object itself) @@ -261,14 +261,10 @@ class NotificationService: notification.email_sent_at = timezone.now() notification.save(update_fields=["email_sent", "email_sent_at"]) - logger.info( - f"Email notification sent to {user.email} for notification {notification.id}" - ) + logger.info(f"Email notification sent to {user.email} for notification {notification.id}") except Exception as e: - logger.error( - f"Failed to send email notification {notification.id}: {str(e)}" - ) + logger.error(f"Failed to send email notification {notification.id}: {str(e)}") @staticmethod def get_user_notifications( @@ -298,9 +294,7 @@ class NotificationService: queryset = queryset.filter(notification_type__in=notification_types) # Exclude expired notifications - queryset = queryset.filter( - models.Q(expires_at__isnull=True) | models.Q(expires_at__gt=timezone.now()) - ) + queryset = queryset.filter(models.Q(expires_at__isnull=True) | models.Q(expires_at__gt=timezone.now())) if limit: queryset = queryset[:limit] @@ -308,9 +302,7 @@ class NotificationService: return list(queryset) @staticmethod - def mark_notifications_read( - user: User, notification_ids: list[int] | None = None - ) -> int: + def mark_notifications_read(user: User, notification_ids: list[int] | None = None) -> int: """ Mark notifications as read for a user. @@ -341,9 +333,7 @@ class NotificationService: """ cutoff_date = timezone.now() - timedelta(days=days) - old_notifications = UserNotification.objects.filter( - is_read=True, read_at__lt=cutoff_date - ) + old_notifications = UserNotification.objects.filter(is_read=True, read_at__lt=cutoff_date) count = old_notifications.count() old_notifications.delete() diff --git a/backend/apps/accounts/services/social_provider_service.py b/backend/apps/accounts/services/social_provider_service.py index b64063e4..c7bee3a9 100644 --- a/backend/apps/accounts/services/social_provider_service.py +++ b/backend/apps/accounts/services/social_provider_service.py @@ -40,23 +40,20 @@ class SocialProviderService: """ try: # Count remaining social accounts after disconnection - remaining_social_accounts = user.socialaccount_set.exclude( - provider=provider - ).count() + remaining_social_accounts = user.socialaccount_set.exclude(provider=provider).count() # Check if user has email/password auth - has_password_auth = ( - user.email and - user.has_usable_password() and - bool(user.password) # Not empty/unusable - ) + has_password_auth = user.email and user.has_usable_password() and bool(user.password) # Not empty/unusable # Allow disconnection only if alternative auth exists can_disconnect = remaining_social_accounts > 0 or has_password_auth if not can_disconnect: if remaining_social_accounts == 0 and not has_password_auth: - return False, "Cannot disconnect your only authentication method. Please set up a password or connect another social provider first." + return ( + False, + "Cannot disconnect your only authentication method. Please set up a password or connect another social provider first.", + ) elif not has_password_auth: return False, "Please set up email/password authentication before disconnecting this provider." else: @@ -65,8 +62,7 @@ class SocialProviderService: return True, "Provider can be safely disconnected." except Exception as e: - logger.error( - f"Error checking disconnect permission for user {user.id}, provider {provider}: {e}") + logger.error(f"Error checking disconnect permission for user {user.id}, provider {provider}: {e}") return False, "Unable to verify disconnection safety. Please try again." @staticmethod @@ -84,18 +80,16 @@ class SocialProviderService: connected_providers = [] for social_account in user.socialaccount_set.all(): - can_disconnect, reason = SocialProviderService.can_disconnect_provider( - user, social_account.provider - ) + can_disconnect, reason = SocialProviderService.can_disconnect_provider(user, social_account.provider) provider_info = { - 'provider': social_account.provider, - 'provider_name': social_account.get_provider().name, - 'uid': social_account.uid, - 'date_joined': social_account.date_joined, - 'can_disconnect': can_disconnect, - 'disconnect_reason': reason if not can_disconnect else None, - 'extra_data': social_account.extra_data + "provider": social_account.provider, + "provider_name": social_account.get_provider().name, + "uid": social_account.uid, + "date_joined": social_account.date_joined, + "can_disconnect": can_disconnect, + "disconnect_reason": reason if not can_disconnect else None, + "extra_data": social_account.extra_data, } connected_providers.append(provider_info) @@ -122,28 +116,25 @@ class SocialProviderService: available_providers = [] # Get all social apps configured for this site - social_apps = SocialApp.objects.filter(sites=site).order_by('provider') + social_apps = SocialApp.objects.filter(sites=site).order_by("provider") for social_app in social_apps: try: provider = registry.by_id(social_app.provider) provider_info = { - 'id': social_app.provider, - 'name': provider.name, - 'auth_url': request.build_absolute_uri( - f'/accounts/{social_app.provider}/login/' + "id": social_app.provider, + "name": provider.name, + "auth_url": request.build_absolute_uri(f"/accounts/{social_app.provider}/login/"), + "connect_url": request.build_absolute_uri( + f"/api/v1/auth/social/connect/{social_app.provider}/" ), - 'connect_url': request.build_absolute_uri( - f'/api/v1/auth/social/connect/{social_app.provider}/' - ) } available_providers.append(provider_info) except Exception as e: - logger.warning( - f"Error processing provider {social_app.provider}: {e}") + logger.warning(f"Error processing provider {social_app.provider}: {e}") continue return available_providers @@ -166,8 +157,7 @@ class SocialProviderService: """ try: # First check if disconnection is allowed - can_disconnect, reason = SocialProviderService.can_disconnect_provider( - user, provider) + can_disconnect, reason = SocialProviderService.can_disconnect_provider(user, provider) if not can_disconnect: return False, reason @@ -182,8 +172,7 @@ class SocialProviderService: deleted_count = social_accounts.count() social_accounts.delete() - logger.info( - f"User {user.id} disconnected {deleted_count} {provider} account(s)") + logger.info(f"User {user.id} disconnected {deleted_count} {provider} account(s)") return True, f"{provider.title()} account disconnected successfully." @@ -205,31 +194,24 @@ class SocialProviderService: try: connected_providers = SocialProviderService.get_connected_providers(user) - has_password_auth = ( - user.email and - user.has_usable_password() and - bool(user.password) - ) + has_password_auth = user.email and user.has_usable_password() and bool(user.password) - auth_methods_count = len(connected_providers) + \ - (1 if has_password_auth else 0) + auth_methods_count = len(connected_providers) + (1 if has_password_auth else 0) return { - 'user_id': user.id, - 'username': user.username, - 'email': user.email, - 'has_password_auth': has_password_auth, - 'connected_providers': connected_providers, - 'total_auth_methods': auth_methods_count, - 'can_disconnect_any': auth_methods_count > 1, - 'requires_password_setup': not has_password_auth and len(connected_providers) == 1 + "user_id": user.id, + "username": user.username, + "email": user.email, + "has_password_auth": has_password_auth, + "connected_providers": connected_providers, + "total_auth_methods": auth_methods_count, + "can_disconnect_any": auth_methods_count > 1, + "requires_password_setup": not has_password_auth and len(connected_providers) == 1, } except Exception as e: logger.error(f"Error getting auth status for user {user.id}: {e}") - return { - 'error': 'Unable to retrieve authentication status' - } + return {"error": "Unable to retrieve authentication status"} @staticmethod def validate_provider_exists(provider: str) -> tuple[bool, str]: diff --git a/backend/apps/accounts/services/user_deletion_service.py b/backend/apps/accounts/services/user_deletion_service.py index 389d331e..75ec16d8 100644 --- a/backend/apps/accounts/services/user_deletion_service.py +++ b/backend/apps/accounts/services/user_deletion_service.py @@ -59,7 +59,7 @@ class UserDeletionService: return False, "Cannot delete staff accounts" # Check for system users (if you have any special system accounts) - if hasattr(user, 'role') and user.role in ['ADMIN', 'MODERATOR']: + if hasattr(user, "role") and user.role in ["ADMIN", "MODERATOR"]: return False, "Cannot delete admin or moderator accounts" return True, None @@ -84,8 +84,7 @@ class UserDeletionService: raise ValueError(reason) # Generate verification code - verification_code = ''.join(secrets.choice( - string.ascii_uppercase + string.digits) for _ in range(8)) + verification_code = "".join(secrets.choice(string.ascii_uppercase + string.digits) for _ in range(8)) # Set expiration (24 hours from now) expires_at = timezone.now() + timezone.timedelta(hours=24) @@ -97,8 +96,7 @@ class UserDeletionService: UserDeletionService._deletion_requests[verification_code] = deletion_request # Send verification email - UserDeletionService._send_deletion_verification_email( - user, verification_code, expires_at) + UserDeletionService._send_deletion_verification_email(user, verification_code, expires_at) return deletion_request @@ -136,10 +134,10 @@ class UserDeletionService: del UserDeletionService._deletion_requests[verification_code] # Add verification info to result - result['deletion_request'] = { - 'verification_code': verification_code, - 'created_at': deletion_request.created_at, - 'verified_at': timezone.now(), + result["deletion_request"] = { + "verification_code": verification_code, + "created_at": deletion_request.created_at, + "verified_at": timezone.now(), } return result @@ -180,13 +178,13 @@ class UserDeletionService: """ # Get or create the "deleted_user" placeholder deleted_user_placeholder, created = User.objects.get_or_create( - username='deleted_user', + username="deleted_user", defaults={ - 'email': 'deleted@thrillwiki.com', - 'first_name': 'Deleted', - 'last_name': 'User', - 'is_active': False, - } + "email": "deleted@thrillwiki.com", + "first_name": "Deleted", + "last_name": "User", + "is_active": False, + }, ) # Count submissions before transfer @@ -197,22 +195,22 @@ class UserDeletionService: # Store user info before deletion deleted_user_info = { - 'username': user.username, - 'user_id': getattr(user, 'user_id', user.id), - 'email': user.email, - 'date_joined': user.date_joined, + "username": user.username, + "user_id": getattr(user, "user_id", user.id), + "email": user.email, + "date_joined": user.date_joined, } # Delete the user account user.delete() return { - 'deleted_user': deleted_user_info, - 'preserved_submissions': submission_counts, - 'transferred_to': { - 'username': deleted_user_placeholder.username, - 'user_id': getattr(deleted_user_placeholder, 'user_id', deleted_user_placeholder.id), - } + "deleted_user": deleted_user_info, + "preserved_submissions": submission_counts, + "transferred_to": { + "username": deleted_user_placeholder.username, + "user_id": getattr(deleted_user_placeholder, "user_id", deleted_user_placeholder.id), + }, } @staticmethod @@ -222,20 +220,13 @@ class UserDeletionService: # Count different types of submissions # Note: These are placeholder counts - adjust based on your actual models - counts['park_reviews'] = getattr( - user, 'park_reviews', user.__class__.objects.none()).count() - counts['ride_reviews'] = getattr( - user, 'ride_reviews', user.__class__.objects.none()).count() - counts['uploaded_park_photos'] = getattr( - user, 'uploaded_park_photos', user.__class__.objects.none()).count() - counts['uploaded_ride_photos'] = getattr( - user, 'uploaded_ride_photos', user.__class__.objects.none()).count() - counts['top_lists'] = getattr( - user, 'top_lists', user.__class__.objects.none()).count() - counts['edit_submissions'] = getattr( - user, 'edit_submissions', user.__class__.objects.none()).count() - counts['photo_submissions'] = getattr( - user, 'photo_submissions', user.__class__.objects.none()).count() + counts["park_reviews"] = getattr(user, "park_reviews", user.__class__.objects.none()).count() + counts["ride_reviews"] = getattr(user, "ride_reviews", user.__class__.objects.none()).count() + counts["uploaded_park_photos"] = getattr(user, "uploaded_park_photos", user.__class__.objects.none()).count() + counts["uploaded_ride_photos"] = getattr(user, "uploaded_ride_photos", user.__class__.objects.none()).count() + counts["top_lists"] = getattr(user, "top_lists", user.__class__.objects.none()).count() + counts["edit_submissions"] = getattr(user, "edit_submissions", user.__class__.objects.none()).count() + counts["photo_submissions"] = getattr(user, "photo_submissions", user.__class__.objects.none()).count() return counts @@ -247,30 +238,30 @@ class UserDeletionService: # Note: Adjust these based on your actual model relationships # Park reviews - if hasattr(user, 'park_reviews'): + if hasattr(user, "park_reviews"): user.park_reviews.all().update(user=placeholder_user) # Ride reviews - if hasattr(user, 'ride_reviews'): + if hasattr(user, "ride_reviews"): user.ride_reviews.all().update(user=placeholder_user) # Uploaded photos - if hasattr(user, 'uploaded_park_photos'): + if hasattr(user, "uploaded_park_photos"): user.uploaded_park_photos.all().update(user=placeholder_user) - if hasattr(user, 'uploaded_ride_photos'): + if hasattr(user, "uploaded_ride_photos"): user.uploaded_ride_photos.all().update(user=placeholder_user) # Top lists - if hasattr(user, 'top_lists'): + if hasattr(user, "top_lists"): user.top_lists.all().update(user=placeholder_user) # Edit submissions - if hasattr(user, 'edit_submissions'): + if hasattr(user, "edit_submissions"): user.edit_submissions.all().update(user=placeholder_user) # Photo submissions - if hasattr(user, 'photo_submissions'): + if hasattr(user, "photo_submissions"): user.photo_submissions.all().update(user=placeholder_user) @staticmethod @@ -278,18 +269,16 @@ class UserDeletionService: """Send verification email for account deletion.""" try: context = { - 'user': user, - 'verification_code': verification_code, - 'expires_at': expires_at, - 'site_name': 'ThrillWiki', - 'site_url': getattr(settings, 'SITE_URL', 'https://thrillwiki.com'), + "user": user, + "verification_code": verification_code, + "expires_at": expires_at, + "site_name": "ThrillWiki", + "site_url": getattr(settings, "SITE_URL", "https://thrillwiki.com"), } - subject = 'ThrillWiki: Confirm Account Deletion' - html_message = render_to_string( - 'emails/account_deletion_verification.html', context) - plain_message = render_to_string( - 'emails/account_deletion_verification.txt', context) + subject = "ThrillWiki: Confirm Account Deletion" + html_message = render_to_string("emails/account_deletion_verification.html", context) + plain_message = render_to_string("emails/account_deletion_verification.txt", context) send_mail( subject=subject, @@ -303,6 +292,5 @@ class UserDeletionService: logger.info(f"Deletion verification email sent to {user.email}") except Exception as e: - logger.error( - f"Failed to send deletion verification email to {user.email}: {str(e)}") + logger.error(f"Failed to send deletion verification email to {user.email}: {str(e)}") raise diff --git a/backend/apps/accounts/signals.py b/backend/apps/accounts/signals.py index 4367cd7d..864f4177 100644 --- a/backend/apps/accounts/signals.py +++ b/backend/apps/accounts/signals.py @@ -108,7 +108,7 @@ def sync_user_role_with_groups(sender, instance, **kwargs): User.Roles.MODERATOR, ]: instance.is_staff = True - elif old_instance.role in [ + elif old_instance.role in [ # noqa: SIM102 User.Roles.ADMIN, User.Roles.MODERATOR, ]: @@ -119,9 +119,7 @@ def sync_user_role_with_groups(sender, instance, **kwargs): except User.DoesNotExist: pass except Exception as e: - print( - f"Error syncing role with groups for user {instance.username}: {str(e)}" - ) + print(f"Error syncing role with groups for user {instance.username}: {str(e)}") def create_default_groups(): @@ -200,19 +198,19 @@ def log_successful_login(sender, user, request, **kwargs): """ try: # Get IP address - x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR') - ip_address = x_forwarded_for.split(',')[0].strip() if x_forwarded_for else request.META.get('REMOTE_ADDR') + x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR") + ip_address = x_forwarded_for.split(",")[0].strip() if x_forwarded_for else request.META.get("REMOTE_ADDR") # Get user agent - user_agent = request.META.get('HTTP_USER_AGENT', '')[:500] + user_agent = request.META.get("HTTP_USER_AGENT", "")[:500] # Determine login method from session or request - login_method = 'PASSWORD' - if hasattr(request, 'session'): - sociallogin = getattr(request, '_sociallogin', None) + login_method = "PASSWORD" + if hasattr(request, "session"): + sociallogin = getattr(request, "_sociallogin", None) if sociallogin: provider = sociallogin.account.provider.upper() - if provider in ['GOOGLE', 'DISCORD']: + if provider in ["GOOGLE", "DISCORD"]: login_method = provider # Create login history entry diff --git a/backend/apps/accounts/tests.py b/backend/apps/accounts/tests.py index b54eac7e..acaa876c 100644 --- a/backend/apps/accounts/tests.py +++ b/backend/apps/accounts/tests.py @@ -113,16 +113,10 @@ class SignalsTestCase(TestCase): moderator_group = Group.objects.get(name=User.Roles.MODERATOR) self.assertIsNotNone(moderator_group) - self.assertTrue( - moderator_group.permissions.filter(codename="change_review").exists() - ) - self.assertFalse( - moderator_group.permissions.filter(codename="change_user").exists() - ) + self.assertTrue(moderator_group.permissions.filter(codename="change_review").exists()) + self.assertFalse(moderator_group.permissions.filter(codename="change_user").exists()) admin_group = Group.objects.get(name=User.Roles.ADMIN) self.assertIsNotNone(admin_group) - self.assertTrue( - admin_group.permissions.filter(codename="change_review").exists() - ) + self.assertTrue(admin_group.permissions.filter(codename="change_review").exists()) self.assertTrue(admin_group.permissions.filter(codename="change_user").exists()) diff --git a/backend/apps/accounts/tests/test_admin.py b/backend/apps/accounts/tests/test_admin.py index d9971c65..3092c1f3 100644 --- a/backend/apps/accounts/tests/test_admin.py +++ b/backend/apps/accounts/tests/test_admin.py @@ -150,6 +150,3 @@ class TestPasswordResetAdmin(TestCase): request.user = UserModel(is_superuser=True) actions = self.admin.get_actions(request) assert "cleanup_old_tokens" in actions - - - diff --git a/backend/apps/accounts/tests/test_model_constraints.py b/backend/apps/accounts/tests/test_model_constraints.py index 92f6dff7..db160c01 100644 --- a/backend/apps/accounts/tests/test_model_constraints.py +++ b/backend/apps/accounts/tests/test_model_constraints.py @@ -85,16 +85,16 @@ class UserIndexTests(TestCase): def test_is_banned_field_is_indexed(self): """Verify is_banned field has db_index=True.""" - field = User._meta.get_field('is_banned') + field = User._meta.get_field("is_banned") self.assertTrue(field.db_index) def test_role_field_is_indexed(self): """Verify role field has db_index=True.""" - field = User._meta.get_field('role') + field = User._meta.get_field("role") self.assertTrue(field.db_index) def test_composite_index_exists(self): """Verify composite index on (is_banned, role) exists.""" indexes = User._meta.indexes index_names = [idx.name for idx in indexes] - self.assertIn('accounts_user_banned_role_idx', index_names) + self.assertIn("accounts_user_banned_role_idx", index_names) diff --git a/backend/apps/accounts/tests/test_user_deletion.py b/backend/apps/accounts/tests/test_user_deletion.py index 5a8910d8..5def6d2b 100644 --- a/backend/apps/accounts/tests/test_user_deletion.py +++ b/backend/apps/accounts/tests/test_user_deletion.py @@ -15,9 +15,7 @@ class UserDeletionServiceTest(TestCase): def setUp(self): """Set up test data.""" # Create test users - self.user = User.objects.create_user( - username="testuser", email="test@example.com", password="testpass123" - ) + self.user = User.objects.create_user(username="testuser", email="test@example.com", password="testpass123") self.admin_user = User.objects.create_user( username="admin", @@ -27,13 +25,9 @@ class UserDeletionServiceTest(TestCase): ) # Create user profiles - UserProfile.objects.create( - user=self.user, display_name="Test User", bio="Test bio" - ) + UserProfile.objects.create(user=self.user, display_name="Test User", bio="Test bio") - UserProfile.objects.create( - user=self.admin_user, display_name="Admin User", bio="Admin bio" - ) + UserProfile.objects.create(user=self.admin_user, display_name="Admin User", bio="Admin bio") def test_get_or_create_deleted_user(self): """Test that deleted user placeholder is created correctly.""" @@ -108,9 +102,7 @@ class UserDeletionServiceTest(TestCase): with self.assertRaises(ValueError) as context: UserDeletionService.delete_user_preserve_submissions(deleted_user) - self.assertIn( - "Cannot delete the system deleted user placeholder", str(context.exception) - ) + self.assertIn("Cannot delete the system deleted user placeholder", str(context.exception)) def test_delete_user_with_submissions_transfers_correctly(self): """Test that user submissions are transferred to deleted user placeholder.""" @@ -141,7 +133,7 @@ class UserDeletionServiceTest(TestCase): original_user_count = User.objects.count() # Mock a failure during the deletion process - with self.assertRaises(Exception), transaction.atomic(): + with self.assertRaises(Exception), transaction.atomic(): # noqa: B017 # Start the deletion process UserDeletionService.get_or_create_deleted_user() diff --git a/backend/apps/accounts/views.py b/backend/apps/accounts/views.py index 6c1c0408..bdd6703e 100644 --- a/backend/apps/accounts/views.py +++ b/backend/apps/accounts/views.py @@ -61,11 +61,7 @@ class CustomLoginView(TurnstileMixin, LoginView): context={"user_id": user.id, "username": user.username}, request=self.request, ) - return ( - HttpResponseClientRefresh() - if getattr(self.request, "htmx", False) - else response - ) + return HttpResponseClientRefresh() if getattr(self.request, "htmx", False) else response def form_invalid(self, form): log_security_event( @@ -116,11 +112,7 @@ class CustomSignupView(TurnstileMixin, SignupView): }, request=self.request, ) - return ( - HttpResponseClientRefresh() - if getattr(self.request, "htmx", False) - else response - ) + return HttpResponseClientRefresh() if getattr(self.request, "htmx", False) else response def form_invalid(self, form): if getattr(self.request, "htmx", False): @@ -260,9 +252,7 @@ class SettingsView(LoginRequiredMixin, TemplateView): and bool(re.search(r"[0-9]", password)) ) - def _send_password_change_confirmation( - self, request: HttpRequest, user: User - ) -> None: + def _send_password_change_confirmation(self, request: HttpRequest, user: User) -> None: """Send password change confirmation email.""" site = get_current_site(request) context = { @@ -270,9 +260,7 @@ class SettingsView(LoginRequiredMixin, TemplateView): "site_name": site.name, } - email_html = render_to_string( - "accounts/email/password_change_confirmation.html", context - ) + email_html = render_to_string("accounts/email/password_change_confirmation.html", context) EmailService.send_email( to=user.email, @@ -282,9 +270,7 @@ class SettingsView(LoginRequiredMixin, TemplateView): html=email_html, ) - def _handle_password_change( - self, request: HttpRequest - ) -> HttpResponseRedirect | None: + def _handle_password_change(self, request: HttpRequest) -> HttpResponseRedirect | None: user = cast(User, request.user) old_password = request.POST.get("old_password", "") new_password = request.POST.get("new_password", "") @@ -327,9 +313,7 @@ class SettingsView(LoginRequiredMixin, TemplateView): def _handle_email_change(self, request: HttpRequest) -> None: if new_email := request.POST.get("new_email"): self._send_email_verification(request, new_email) - messages.success( - request, "Verification email sent to your new email address" - ) + messages.success(request, "Verification email sent to your new email address") else: messages.error(request, "New email is required") @@ -385,9 +369,7 @@ def create_password_reset_token(user: User) -> str: return token -def send_password_reset_email( - user: User, site: Site | RequestSite, token: str -) -> None: +def send_password_reset_email(user: User, site: Site | RequestSite, token: str) -> None: reset_url = reverse("password_reset_confirm", kwargs={"token": token}) context = { "user": user, @@ -457,16 +439,12 @@ def handle_password_reset( messages.success(request, "Password reset successfully") -def send_password_reset_confirmation( - user: User, site: Site | RequestSite -) -> None: +def send_password_reset_confirmation(user: User, site: Site | RequestSite) -> None: context = { "user": user, "site_name": site.name, } - email_html = render_to_string( - "accounts/email/password_reset_complete.html", context - ) + email_html = render_to_string("accounts/email/password_reset_complete.html", context) EmailService.send_email( to=user.email, @@ -479,9 +457,7 @@ def send_password_reset_confirmation( def reset_password(request: HttpRequest, token: str) -> HttpResponse: try: - reset = PasswordReset.objects.select_related("user").get( - token=token, expires_at__gt=timezone.now(), used=False - ) + reset = PasswordReset.objects.select_related("user").get(token=token, expires_at__gt=timezone.now(), used=False) if request.method == "POST": if new_password := request.POST.get("new_password"): diff --git a/backend/apps/api/management/commands/seed_data.py b/backend/apps/api/management/commands/seed_data.py index ff0f338b..2068e84d 100644 --- a/backend/apps/api/management/commands/seed_data.py +++ b/backend/apps/api/management/commands/seed_data.py @@ -58,69 +58,67 @@ User = get_user_model() class Command(BaseCommand): - help = 'Seed the database with comprehensive test data for all models' + help = "Seed the database with comprehensive test data for all models" def add_arguments(self, parser): parser.add_argument( - '--clear', - action='store_true', - help='Clear existing data before seeding', + "--clear", + action="store_true", + help="Clear existing data before seeding", ) parser.add_argument( - '--users', + "--users", type=int, default=25, - help='Number of users to create (default: 25)', + help="Number of users to create (default: 25)", ) parser.add_argument( - '--companies', + "--companies", type=int, default=15, - help='Number of companies to create (default: 15)', + help="Number of companies to create (default: 15)", ) parser.add_argument( - '--parks', + "--parks", type=int, default=10, - help='Number of parks to create (default: 10)', + help="Number of parks to create (default: 10)", ) parser.add_argument( - '--rides', + "--rides", type=int, default=50, - help='Number of rides to create (default: 50)', + help="Number of rides to create (default: 50)", ) parser.add_argument( - '--ride-models', + "--ride-models", type=int, default=20, - help='Number of ride models to create (default: 20)', + help="Number of ride models to create (default: 20)", ) parser.add_argument( - '--reviews', + "--reviews", type=int, default=100, - help='Number of reviews to create (default: 100)', + help="Number of reviews to create (default: 100)", ) def handle(self, *args, **options): - self.stdout.write( - self.style.SUCCESS('Starting comprehensive data seeding...') - ) + self.stdout.write(self.style.SUCCESS("Starting comprehensive data seeding...")) - if options['clear']: + if options["clear"]: self.clear_data() with transaction.atomic(): # Create data in dependency order - users = self.create_users(options['users']) - companies = self.create_companies(options['companies']) - ride_models = self.create_ride_models(options['ride_models'], companies) - parks = self.create_parks(options['parks'], companies) - rides = self.create_rides(options['rides'], parks, companies, ride_models) + users = self.create_users(options["users"]) + companies = self.create_companies(options["companies"]) + ride_models = self.create_ride_models(options["ride_models"], companies) + parks = self.create_parks(options["parks"], companies) + rides = self.create_rides(options["rides"], parks, companies, ride_models) # Create content and interactions - self.create_reviews(options['reviews'], users, parks, rides) + self.create_reviews(options["reviews"], users, parks, rides) self.create_notifications(users) self.create_moderation_data(users, parks, rides) @@ -131,30 +129,39 @@ class Command(BaseCommand): # Create rankings and statistics self.create_rankings(rides) - self.stdout.write( - self.style.SUCCESS('✅ Data seeding completed successfully!') - ) + self.stdout.write(self.style.SUCCESS("✅ Data seeding completed successfully!")) self.print_summary() def clear_data(self): """Clear existing data in reverse dependency order""" - self.stdout.write('🗑️ Clearing existing data...') + self.stdout.write("🗑️ Clearing existing data...") models_to_clear = [ # Content and interactions (clear first) - UserNotification, NotificationPreference, - ParkReview, RideReview, ModerationAction, ModerationQueue, - + UserNotification, + NotificationPreference, + ParkReview, + RideReview, + ModerationAction, + ModerationQueue, # Media - ParkPhoto, RidePhoto, CloudflareImage, - + ParkPhoto, + RidePhoto, + CloudflareImage, # Core entities - RollerCoasterStats, Ride, ParkArea, Park, ParkLocation, - RideModel, CompanyHeadquarters, ParkCompany, RideCompany, - + RollerCoasterStats, + Ride, + ParkArea, + Park, + ParkLocation, + RideModel, + CompanyHeadquarters, + ParkCompany, + RideCompany, # Users (clear last due to foreign key dependencies) - UserDeletionRequest, UserProfile, User, - + UserDeletionRequest, + UserProfile, + User, # History HistoricalSlug, ] @@ -178,71 +185,113 @@ class Command(BaseCommand): count = model.objects.count() if count > 0: model.objects.all().delete() - self.stdout.write(f' Cleared {count} {model._meta.verbose_name_plural}') + self.stdout.write(f" Cleared {count} {model._meta.verbose_name_plural}") except Exception as e: self.stdout.write( - self.style.WARNING(f' ⚠️ Could not clear {model._meta.verbose_name_plural}: {str(e)}') + self.style.WARNING(f" ⚠️ Could not clear {model._meta.verbose_name_plural}: {str(e)}") ) # Continue with other models continue def create_users(self, count: int) -> list[User]: """Create diverse users with comprehensive profiles""" - self.stdout.write(f'👥 Creating {count} users...') + self.stdout.write(f"👥 Creating {count} users...") users = [] # Create admin user if it doesn't exist admin, created = User.objects.get_or_create( - username='admin', + username="admin", defaults={ - 'email': 'admin@thrillwiki.com', - 'role': 'ADMIN', - 'is_staff': True, - 'is_superuser': True, - 'display_name': 'ThrillWiki Admin', - 'theme_preference': 'dark', - 'privacy_level': 'public', - } + "email": "admin@thrillwiki.com", + "role": "ADMIN", + "is_staff": True, + "is_superuser": True, + "display_name": "ThrillWiki Admin", + "theme_preference": "dark", + "privacy_level": "public", + }, ) if created: - admin.set_password('admin123') + admin.set_password("admin123") admin.save() users.append(admin) # Create moderator if it doesn't exist moderator, created = User.objects.get_or_create( - username='moderator', + username="moderator", defaults={ - 'email': 'mod@thrillwiki.com', - 'role': 'MODERATOR', - 'is_staff': True, - 'display_name': 'Site Moderator', - 'theme_preference': 'light', - 'privacy_level': 'public', - } + "email": "mod@thrillwiki.com", + "role": "MODERATOR", + "is_staff": True, + "display_name": "Site Moderator", + "theme_preference": "light", + "privacy_level": "public", + }, ) if created: - moderator.set_password('mod123') + moderator.set_password("mod123") moderator.save() users.append(moderator) # Sample user data first_names = [ - 'Alex', 'Jordan', 'Taylor', 'Casey', 'Morgan', 'Riley', 'Avery', 'Quinn', - 'Blake', 'Cameron', 'Drew', 'Emery', 'Finley', 'Harper', 'Hayden', - 'Jamie', 'Kendall', 'Logan', 'Parker', 'Peyton', 'Reese', 'Sage', - 'Skyler', 'Sydney', 'Tanner' + "Alex", + "Jordan", + "Taylor", + "Casey", + "Morgan", + "Riley", + "Avery", + "Quinn", + "Blake", + "Cameron", + "Drew", + "Emery", + "Finley", + "Harper", + "Hayden", + "Jamie", + "Kendall", + "Logan", + "Parker", + "Peyton", + "Reese", + "Sage", + "Skyler", + "Sydney", + "Tanner", ] last_names = [ - 'Smith', 'Johnson', 'Williams', 'Brown', 'Jones', 'Garcia', 'Miller', - 'Davis', 'Rodriguez', 'Martinez', 'Hernandez', 'Lopez', 'Gonzalez', - 'Wilson', 'Anderson', 'Thomas', 'Taylor', 'Moore', 'Jackson', 'Martin', - 'Lee', 'Perez', 'Thompson', 'White', 'Harris' + "Smith", + "Johnson", + "Williams", + "Brown", + "Jones", + "Garcia", + "Miller", + "Davis", + "Rodriguez", + "Martinez", + "Hernandez", + "Lopez", + "Gonzalez", + "Wilson", + "Anderson", + "Thomas", + "Taylor", + "Moore", + "Jackson", + "Martin", + "Lee", + "Perez", + "Thompson", + "White", + "Harris", ] - domains = ['gmail.com', 'yahoo.com', 'hotmail.com', 'outlook.com', 'icloud.com'] + domains = ["gmail.com", "yahoo.com", "hotmail.com", "outlook.com", "icloud.com"] # Create regular users for _i in range(count - 2): # -2 for admin and moderator @@ -254,11 +303,11 @@ class Command(BaseCommand): user = User.objects.create_user( username=username, email=email, - password='password123', + password="password123", display_name=f"{first_name} {last_name}", - role=random.choice(['USER'] * 8 + ['MODERATOR']), - theme_preference=random.choice(['light', 'dark']), - privacy_level=random.choice(['public', 'friends', 'private']), + role=random.choice(["USER"] * 8 + ["MODERATOR"]), + theme_preference=random.choice(["light", "dark"]), + privacy_level=random.choice(["public", "friends", "private"]), email_notifications=random.choice([True, False]), push_notifications=random.choice([True, False]), show_email=random.choice([True, False]), @@ -271,28 +320,28 @@ class Command(BaseCommand): # Create detailed notification preferences user.notification_preferences = { - 'email': { - 'reviews': random.choice([True, False]), - 'submissions': random.choice([True, False]), - 'social': random.choice([True, False]), - 'system': random.choice([True, False]), + "email": { + "reviews": random.choice([True, False]), + "submissions": random.choice([True, False]), + "social": random.choice([True, False]), + "system": random.choice([True, False]), }, - 'push': { - 'reviews': random.choice([True, False]), - 'submissions': random.choice([True, False]), - 'social': random.choice([True, False]), - 'achievements': random.choice([True, False]), + "push": { + "reviews": random.choice([True, False]), + "submissions": random.choice([True, False]), + "social": random.choice([True, False]), + "achievements": random.choice([True, False]), + }, + "in_app": { + "all": random.choice([True, False]), }, - 'in_app': { - 'all': random.choice([True, False]), - } } user.save() # Create user profile with ride credits profile = UserProfile.objects.get(user=user) profile.bio = f"Thrill seeker from {random.choice(['California', 'Florida', 'Ohio', 'Pennsylvania', 'Texas'])}. Love roller coasters!" - profile.pronouns = random.choice(['he/him', 'she/her', 'they/them', '']) + profile.pronouns = random.choice(["he/him", "she/her", "they/them", ""]) profile.coaster_credits = random.randint(0, 500) profile.dark_ride_credits = random.randint(0, 100) profile.flat_ride_credits = random.randint(0, 200) @@ -309,122 +358,122 @@ class Command(BaseCommand): profile.save() users.append(user) - self.stdout.write(f' ✅ Created {len(users)} users') + self.stdout.write(f" ✅ Created {len(users)} users") return users def create_companies(self, count: int) -> list: """Create companies with different roles""" - self.stdout.write(f'🏢 Creating {count} companies...') + self.stdout.write(f"🏢 Creating {count} companies...") companies = [] # Major theme park operators operators_data = [ - ('Walt Disney Company', ['OPERATOR', 'PROPERTY_OWNER'], 1923, 'Burbank, CA, USA'), - ('Universal Parks & Resorts', ['OPERATOR', 'PROPERTY_OWNER'], 1964, 'Orlando, FL, USA'), - ('Six Flags Entertainment', ['OPERATOR'], 1961, 'Arlington, TX, USA'), - ('Cedar Fair', ['OPERATOR'], 1983, 'Sandusky, OH, USA'), - ('SeaWorld Parks', ['OPERATOR'], 1964, 'Orlando, FL, USA'), - ('Busch Gardens', ['OPERATOR'], 1959, 'Tampa, FL, USA'), - ('Knott\'s Berry Farm', ['OPERATOR'], 1920, 'Buena Park, CA, USA'), + ("Walt Disney Company", ["OPERATOR", "PROPERTY_OWNER"], 1923, "Burbank, CA, USA"), + ("Universal Parks & Resorts", ["OPERATOR", "PROPERTY_OWNER"], 1964, "Orlando, FL, USA"), + ("Six Flags Entertainment", ["OPERATOR"], 1961, "Arlington, TX, USA"), + ("Cedar Fair", ["OPERATOR"], 1983, "Sandusky, OH, USA"), + ("SeaWorld Parks", ["OPERATOR"], 1964, "Orlando, FL, USA"), + ("Busch Gardens", ["OPERATOR"], 1959, "Tampa, FL, USA"), + ("Knott's Berry Farm", ["OPERATOR"], 1920, "Buena Park, CA, USA"), ] # Major ride manufacturers manufacturers_data = [ - ('Bolliger & Mabillard', ['MANUFACTURER'], 1988, 'Monthey, Switzerland'), - ('Intamin', ['MANUFACTURER'], 1967, 'Schaan, Liechtenstein'), - ('Vekoma', ['MANUFACTURER'], 1926, 'Vlodrop, Netherlands'), - ('Rocky Mountain Construction', ['MANUFACTURER'], 2001, 'Hayden, ID, USA'), - ('Mack Rides', ['MANUFACTURER'], 1780, 'Waldkirch, Germany'), - ('Gerstlauer', ['MANUFACTURER'], 1982, 'Münsterhausen, Germany'), - ('Premier Rides', ['MANUFACTURER'], 1994, 'Baltimore, MD, USA'), - ('S&S Worldwide', ['MANUFACTURER'], 1994, 'Logan, UT, USA'), + ("Bolliger & Mabillard", ["MANUFACTURER"], 1988, "Monthey, Switzerland"), + ("Intamin", ["MANUFACTURER"], 1967, "Schaan, Liechtenstein"), + ("Vekoma", ["MANUFACTURER"], 1926, "Vlodrop, Netherlands"), + ("Rocky Mountain Construction", ["MANUFACTURER"], 2001, "Hayden, ID, USA"), + ("Mack Rides", ["MANUFACTURER"], 1780, "Waldkirch, Germany"), + ("Gerstlauer", ["MANUFACTURER"], 1982, "Münsterhausen, Germany"), + ("Premier Rides", ["MANUFACTURER"], 1994, "Baltimore, MD, USA"), + ("S&S Worldwide", ["MANUFACTURER"], 1994, "Logan, UT, USA"), ] # Ride designers designers_data = [ - ('Werner Stengel', ['DESIGNER'], 1965, 'Munich, Germany'), - ('Alan Schilke', ['DESIGNER'], 1990, 'Hayden, ID, USA'), - ('John Wardley', ['DESIGNER'], 1970, 'London, UK'), + ("Werner Stengel", ["DESIGNER"], 1965, "Munich, Germany"), + ("Alan Schilke", ["DESIGNER"], 1990, "Hayden, ID, USA"), + ("John Wardley", ["DESIGNER"], 1970, "London, UK"), ] all_company_data = operators_data + manufacturers_data + designers_data for name, roles, founded_year, location in all_company_data: # Determine which Company model to use based on roles - if 'OPERATOR' in roles or 'PROPERTY_OWNER' in roles: + if "OPERATOR" in roles or "PROPERTY_OWNER" in roles: # Use ParkCompany for park operators and property owners company, created = ParkCompany.objects.get_or_create( name=name, defaults={ - 'slug': slugify(name), - 'roles': roles, - 'founded_year': founded_year, - 'description': f"{name} is a leading {'park operator' if 'OPERATOR' in roles else 'property owner'} in the theme park industry.", - 'website': f"https://{slugify(name).replace('-', '')}.com", - 'parks_count': random.randint(1, 20) if 'OPERATOR' in roles else 0, - 'rides_count': random.randint(10, 500) if 'MANUFACTURER' in roles else 0, - } + "slug": slugify(name), + "roles": roles, + "founded_year": founded_year, + "description": f"{name} is a leading {'park operator' if 'OPERATOR' in roles else 'property owner'} in the theme park industry.", + "website": f"https://{slugify(name).replace('-', '')}.com", + "parks_count": random.randint(1, 20) if "OPERATOR" in roles else 0, + "rides_count": random.randint(10, 500) if "MANUFACTURER" in roles else 0, + }, ) else: # Use RideCompany for manufacturers and designers company, created = RideCompany.objects.get_or_create( name=name, defaults={ - 'slug': slugify(name), - 'roles': roles, - 'founded_date': date(founded_year, 1, 1) if founded_year else None, - 'description': f"{name} is a leading {'ride manufacturer' if 'MANUFACTURER' in roles else 'ride designer'} in the theme park industry.", - 'website': f"https://{slugify(name).replace('-', '')}.com", - 'rides_count': random.randint(10, 500) if 'MANUFACTURER' in roles else 0, - 'coasters_count': random.randint(5, 100) if 'MANUFACTURER' in roles else 0, - } + "slug": slugify(name), + "roles": roles, + "founded_date": date(founded_year, 1, 1) if founded_year else None, + "description": f"{name} is a leading {'ride manufacturer' if 'MANUFACTURER' in roles else 'ride designer'} in the theme park industry.", + "website": f"https://{slugify(name).replace('-', '')}.com", + "rides_count": random.randint(10, 500) if "MANUFACTURER" in roles else 0, + "coasters_count": random.randint(5, 100) if "MANUFACTURER" in roles else 0, + }, ) # Create headquarters if company was created and is a ParkCompany if created and isinstance(company, ParkCompany): - city, state_country = location.rsplit(', ', 1) - if ', ' in city: - city, state = city.split(', ') + city, state_country = location.rsplit(", ", 1) + if ", " in city: + city, state = city.split(", ") country = state_country else: - state = '' + state = "" country = state_country CompanyHeadquarters.objects.get_or_create( company=company, defaults={ - 'city': city, - 'state_province': state, - 'country': country, - 'street_address': f"{random.randint(100, 9999)} {random.choice(['Main', 'Park', 'Industry', 'Corporate'])} {random.choice(['St', 'Ave', 'Blvd', 'Dr'])}", - 'postal_code': f"{random.randint(10000, 99999)}" if country == 'USA' else '', - } + "city": city, + "state_province": state, + "country": country, + "street_address": f"{random.randint(100, 9999)} {random.choice(['Main', 'Park', 'Industry', 'Corporate'])} {random.choice(['St', 'Ave', 'Blvd', 'Dr'])}", + "postal_code": f"{random.randint(10000, 99999)}" if country == "USA" else "", + }, ) companies.append(company) # Create additional random companies to reach the target count - company_types = ['Theme Parks', 'Amusements', 'Entertainment', 'Rides', 'Design', 'Engineering'] + company_types = ["Theme Parks", "Amusements", "Entertainment", "Rides", "Design", "Engineering"] for _i in range(len(all_company_data), count): company_type = random.choice(company_types) name = f"{random.choice(['Global', 'International', 'Premier', 'Elite', 'Advanced', 'Creative'])} {company_type} {'Group' if random.random() < 0.5 else 'Corporation'}" roles = [] - if 'Theme Parks' in name or 'Amusements' in name: - roles = ['OPERATOR'] + if "Theme Parks" in name or "Amusements" in name: + roles = ["OPERATOR"] if random.random() < 0.5: - roles.append('PROPERTY_OWNER') - elif 'Rides' in name or 'Engineering' in name: - roles = ['MANUFACTURER'] - elif 'Design' in name: - roles = ['DESIGNER'] + roles.append("PROPERTY_OWNER") + elif "Rides" in name or "Engineering" in name: + roles = ["MANUFACTURER"] + elif "Design" in name: + roles = ["DESIGNER"] else: - roles = [random.choice(['OPERATOR', 'MANUFACTURER', 'DESIGNER'])] + roles = [random.choice(["OPERATOR", "MANUFACTURER", "DESIGNER"])] # Use appropriate company model based on roles - if 'OPERATOR' in roles or 'PROPERTY_OWNER' in roles: + if "OPERATOR" in roles or "PROPERTY_OWNER" in roles: company = ParkCompany.objects.create( name=name, slug=slugify(name), @@ -432,8 +481,8 @@ class Command(BaseCommand): founded_year=random.randint(1950, 2020), description=f"{name} specializes in {'theme park operations' if 'OPERATOR' in roles else 'property ownership'}.", website=f"https://{slugify(name).replace('-', '')}.com", - parks_count=random.randint(1, 10) if 'OPERATOR' in roles else 0, - rides_count=random.randint(5, 100) if 'MANUFACTURER' in roles else 0, + parks_count=random.randint(1, 10) if "OPERATOR" in roles else 0, + rides_count=random.randint(5, 100) if "MANUFACTURER" in roles else 0, ) else: company = RideCompany.objects.create( @@ -443,62 +492,73 @@ class Command(BaseCommand): founded_date=date(random.randint(1950, 2020), 1, 1), description=f"{name} specializes in {'ride manufacturing' if 'MANUFACTURER' in roles else 'ride design'}.", website=f"https://{slugify(name).replace('-', '')}.com", - rides_count=random.randint(5, 100) if 'MANUFACTURER' in roles else 0, - coasters_count=random.randint(2, 50) if 'MANUFACTURER' in roles else 0, + rides_count=random.randint(5, 100) if "MANUFACTURER" in roles else 0, + coasters_count=random.randint(2, 50) if "MANUFACTURER" in roles else 0, ) # Create headquarters - cities = ['Los Angeles', 'New York', 'Chicago', 'Houston', 'Phoenix', 'Philadelphia', 'San Antonio', 'San Diego', 'Dallas', 'San Jose'] - states = ['CA', 'NY', 'IL', 'TX', 'AZ', 'PA', 'TX', 'CA', 'TX', 'CA'] + cities = [ + "Los Angeles", + "New York", + "Chicago", + "Houston", + "Phoenix", + "Philadelphia", + "San Antonio", + "San Diego", + "Dallas", + "San Jose", + ] + states = ["CA", "NY", "IL", "TX", "AZ", "PA", "TX", "CA", "TX", "CA"] city_state = random.choice(list(zip(cities, states, strict=False))) CompanyHeadquarters.objects.create( company=company, city=city_state[0], state_province=city_state[1], - country='USA', + country="USA", street_address=f"{random.randint(100, 9999)} {random.choice(['Business', 'Corporate', 'Industry', 'Commerce'])} {random.choice(['Pkwy', 'Blvd', 'Dr', 'Way'])}", postal_code=f"{random.randint(10000, 99999)}", ) companies.append(company) - self.stdout.write(f' ✅ Created {len(companies)} companies') + self.stdout.write(f" ✅ Created {len(companies)} companies") return companies def create_ride_models(self, count: int, companies: list) -> list[RideModel]: """Create ride models from manufacturers""" - self.stdout.write(f'🎢 Creating {count} ride models...') + self.stdout.write(f"🎢 Creating {count} ride models...") - manufacturers = [c for c in companies if 'MANUFACTURER' in c.roles] + manufacturers = [c for c in companies if "MANUFACTURER" in c.roles] if not manufacturers: - self.stdout.write(' ⚠️ No manufacturers found, skipping ride models') + self.stdout.write(" ⚠️ No manufacturers found, skipping ride models") return [] ride_models = [] # Famous ride models famous_models = [ - ('Dive Coaster', 'RC', 'Bolliger & Mabillard', 'Vertical drop roller coaster with holding brake'), - ('Hyper Coaster', 'RC', 'Bolliger & Mabillard', 'High-speed out-and-back roller coaster'), - ('Wing Coaster', 'RC', 'Bolliger & Mabillard', 'Seats positioned on sides of track'), - ('Accelerator Coaster', 'RC', 'Intamin', 'Hydraulic launch roller coaster'), - ('Mega Coaster', 'RC', 'Intamin', 'High-speed steel roller coaster'), - ('Boomerang', 'RC', 'Vekoma', 'Shuttle roller coaster with inversions'), - ('SLC', 'RC', 'Vekoma', 'Suspended Looping Coaster'), - ('I-Box Track', 'RC', 'Rocky Mountain Construction', 'Steel track on wooden structure'), - ('Raptor Track', 'RC', 'Rocky Mountain Construction', 'Single rail roller coaster'), - ('BigDipper', 'RC', 'Mack Rides', 'Family roller coaster'), - ('Launched Coaster', 'RC', 'Mack Rides', 'LSM launched roller coaster'), - ('Infinity Coaster', 'RC', 'Gerstlauer', 'Compact looping roller coaster'), - ('Sky Rocket II', 'RC', 'Premier Rides', 'Compact launched roller coaster'), - ('Air Coaster', 'RC', 'S&S Worldwide', 'Compressed air launched coaster'), - ('Dark Ride System', 'DR', 'Mack Rides', 'Trackless dark ride vehicles'), - ('Omnimover', 'DR', 'Walt Disney Company', 'Continuous loading dark ride system'), - ('Log Flume', 'WR', 'Mack Rides', 'Water ride with drops'), - ('Rapids Ride', 'WR', 'Intamin', 'Whitewater rafting experience'), - ('Drop Tower', 'FR', 'Intamin', 'Vertical drop ride'), - ('Gyro Drop', 'FR', 'Intamin', 'Tilting drop tower'), + ("Dive Coaster", "RC", "Bolliger & Mabillard", "Vertical drop roller coaster with holding brake"), + ("Hyper Coaster", "RC", "Bolliger & Mabillard", "High-speed out-and-back roller coaster"), + ("Wing Coaster", "RC", "Bolliger & Mabillard", "Seats positioned on sides of track"), + ("Accelerator Coaster", "RC", "Intamin", "Hydraulic launch roller coaster"), + ("Mega Coaster", "RC", "Intamin", "High-speed steel roller coaster"), + ("Boomerang", "RC", "Vekoma", "Shuttle roller coaster with inversions"), + ("SLC", "RC", "Vekoma", "Suspended Looping Coaster"), + ("I-Box Track", "RC", "Rocky Mountain Construction", "Steel track on wooden structure"), + ("Raptor Track", "RC", "Rocky Mountain Construction", "Single rail roller coaster"), + ("BigDipper", "RC", "Mack Rides", "Family roller coaster"), + ("Launched Coaster", "RC", "Mack Rides", "LSM launched roller coaster"), + ("Infinity Coaster", "RC", "Gerstlauer", "Compact looping roller coaster"), + ("Sky Rocket II", "RC", "Premier Rides", "Compact launched roller coaster"), + ("Air Coaster", "RC", "S&S Worldwide", "Compressed air launched coaster"), + ("Dark Ride System", "DR", "Mack Rides", "Trackless dark ride vehicles"), + ("Omnimover", "DR", "Walt Disney Company", "Continuous loading dark ride system"), + ("Log Flume", "WR", "Mack Rides", "Water ride with drops"), + ("Rapids Ride", "WR", "Intamin", "Whitewater rafting experience"), + ("Drop Tower", "FR", "Intamin", "Vertical drop ride"), + ("Gyro Drop", "FR", "Intamin", "Tilting drop tower"), ] for model_name, category, manufacturer_name, description in famous_models: @@ -510,33 +570,41 @@ class Command(BaseCommand): name=model_name, manufacturer=manufacturer, defaults={ - 'description': description, - 'category': category, - 'first_installation_year': random.randint(1980, 2020), - 'is_discontinued': random.choice([True, False]), - 'target_market': random.choice(['FAMILY', 'THRILL', 'EXTREME']), - 'typical_height_range_min_ft': random.randint(50, 200) if category == 'RC' else random.randint(20, 100), - 'typical_height_range_max_ft': random.randint(200, 400) if category == 'RC' else random.randint(100, 200), - 'typical_speed_range_min_mph': random.randint(20, 60) if category == 'RC' else random.randint(5, 30), - 'typical_speed_range_max_mph': random.randint(60, 120) if category == 'RC' else random.randint(30, 60), - 'typical_capacity_range_min': random.randint(500, 1000), - 'typical_capacity_range_max': random.randint(1000, 2000), - 'track_type': random.choice(['Steel', 'Wood', 'Hybrid']) if category == 'RC' else 'Steel', - 'support_structure': random.choice(['Steel', 'Wood', 'Concrete']), - 'train_configuration': f"{random.randint(2, 4)} trains, {random.randint(6, 8)} cars per train", - 'restraint_system': random.choice(['Over-shoulder', 'Lap bar', 'Vest', 'None']), - 'notable_features': 'High-speed elements, smooth ride experience', - 'total_installations': random.randint(1, 50), - } + "description": description, + "category": category, + "first_installation_year": random.randint(1980, 2020), + "is_discontinued": random.choice([True, False]), + "target_market": random.choice(["FAMILY", "THRILL", "EXTREME"]), + "typical_height_range_min_ft": ( + random.randint(50, 200) if category == "RC" else random.randint(20, 100) + ), + "typical_height_range_max_ft": ( + random.randint(200, 400) if category == "RC" else random.randint(100, 200) + ), + "typical_speed_range_min_mph": ( + random.randint(20, 60) if category == "RC" else random.randint(5, 30) + ), + "typical_speed_range_max_mph": ( + random.randint(60, 120) if category == "RC" else random.randint(30, 60) + ), + "typical_capacity_range_min": random.randint(500, 1000), + "typical_capacity_range_max": random.randint(1000, 2000), + "track_type": random.choice(["Steel", "Wood", "Hybrid"]) if category == "RC" else "Steel", + "support_structure": random.choice(["Steel", "Wood", "Concrete"]), + "train_configuration": f"{random.randint(2, 4)} trains, {random.randint(6, 8)} cars per train", + "restraint_system": random.choice(["Over-shoulder", "Lap bar", "Vest", "None"]), + "notable_features": "High-speed elements, smooth ride experience", + "total_installations": random.randint(1, 50), + }, ) # Create technical specs if model exists - if category == 'RC' and RideModelTechnicalSpec: + if category == "RC" and RideModelTechnicalSpec: specs = [ - ('DIMENSIONS', 'Track Length', f"{random.randint(2000, 8000)}", 'ft'), - ('PERFORMANCE', 'Max G-Force', f"{random.uniform(3.0, 5.0):.1f}", 'G'), - ('CAPACITY', 'Riders per Train', f"{random.randint(20, 32)}", 'people'), - ('SAFETY', 'Block Zones', f"{random.randint(4, 8)}", 'zones'), + ("DIMENSIONS", "Track Length", f"{random.randint(2000, 8000)}", "ft"), + ("PERFORMANCE", "Max G-Force", f"{random.uniform(3.0, 5.0):.1f}", "G"), + ("CAPACITY", "Riders per Train", f"{random.randint(20, 32)}", "people"), + ("SAFETY", "Block Zones", f"{random.randint(4, 8)}", "zones"), ] for spec_category, spec_name, spec_value, spec_unit in specs: @@ -550,7 +618,7 @@ class Command(BaseCommand): # Create variants for some models if model exists if random.random() < 0.3 and RideModelVariant: - variant_names = ['Compact', 'Extended', 'Family', 'Extreme', 'Custom'] + variant_names = ["Compact", "Extended", "Family", "Extreme", "Custom"] variant_name = random.choice(variant_names) RideModelVariant.objects.create( @@ -563,12 +631,12 @@ class Command(BaseCommand): ride_models.append(ride_model) # Create additional random models - model_types = ['Coaster', 'Ride', 'System', 'Experience', 'Adventure'] - prefixes = ['Mega', 'Super', 'Ultra', 'Hyper', 'Giga', 'Extreme', 'Family', 'Junior'] + model_types = ["Coaster", "Ride", "System", "Experience", "Adventure"] + prefixes = ["Mega", "Super", "Ultra", "Hyper", "Giga", "Extreme", "Family", "Junior"] for _i in range(len(famous_models), count): manufacturer = random.choice(manufacturers) - category = random.choice(['RC', 'DR', 'FR', 'WR', 'TR']) + category = random.choice(["RC", "DR", "FR", "WR", "TR"]) model_name = f"{random.choice(prefixes)} {random.choice(model_types)}" @@ -579,85 +647,196 @@ class Command(BaseCommand): category=category, first_installation_year=random.randint(1990, 2023), is_discontinued=random.choice([True, False]), - target_market=random.choice(['FAMILY', 'THRILL', 'EXTREME', 'ALL_AGES']), + target_market=random.choice(["FAMILY", "THRILL", "EXTREME", "ALL_AGES"]), typical_height_range_min_ft=random.randint(20, 150), typical_height_range_max_ft=random.randint(150, 350), typical_speed_range_min_mph=random.randint(10, 50), typical_speed_range_max_mph=random.randint(50, 100), typical_capacity_range_min=random.randint(400, 800), typical_capacity_range_max=random.randint(800, 1800), - track_type=random.choice(['Steel', 'Wood', 'Hybrid', 'Trackless']), - support_structure=random.choice(['Steel', 'Wood', 'Concrete', 'Hybrid']), + track_type=random.choice(["Steel", "Wood", "Hybrid", "Trackless"]), + support_structure=random.choice(["Steel", "Wood", "Concrete", "Hybrid"]), train_configuration=f"{random.randint(1, 3)} trains, {random.randint(4, 12)} cars per train", - restraint_system=random.choice(['Over-shoulder', 'Lap bar', 'Vest', 'Seatbelt', 'None']), - notable_features=random.choice([ - 'Smooth ride experience', - 'High-speed elements', - 'Family-friendly design', - 'Innovative technology', - 'Compact footprint' - ]), + restraint_system=random.choice(["Over-shoulder", "Lap bar", "Vest", "Seatbelt", "None"]), + notable_features=random.choice( + [ + "Smooth ride experience", + "High-speed elements", + "Family-friendly design", + "Innovative technology", + "Compact footprint", + ] + ), total_installations=random.randint(0, 25), ) ride_models.append(ride_model) - self.stdout.write(f' ✅ Created {len(ride_models)} ride models') + self.stdout.write(f" ✅ Created {len(ride_models)} ride models") return ride_models def create_parks(self, count: int, companies: list) -> list[Park]: """Create parks with locations and areas""" - self.stdout.write(f'🏰 Creating {count} parks...') + self.stdout.write(f"🏰 Creating {count} parks...") if count == 0: - self.stdout.write(' ℹ️ Skipping park creation (count = 0)') + self.stdout.write(" ℹ️ Skipping park creation (count = 0)") return [] - operators = [c for c in companies if 'OPERATOR' in c.roles] - property_owners = [c for c in companies if 'PROPERTY_OWNER' in c.roles] + operators = [c for c in companies if "OPERATOR" in c.roles] + property_owners = [c for c in companies if "PROPERTY_OWNER" in c.roles] if not operators: - raise CommandError('No operators found. Create companies first.') + raise CommandError("No operators found. Create companies first.") parks = [] # Famous theme parks with timezone information famous_parks = [ - ('Magic Kingdom', 'Walt Disney World\'s flagship theme park', 'THEME_PARK', 'OPERATING', - date(1971, 10, 1), 107, 'Orlando', 'FL', 'USA', 28.4177, -81.5812, 'America/New_York'), - ('Disneyland', 'The original Disney theme park', 'THEME_PARK', 'OPERATING', - date(1955, 7, 17), 85, 'Anaheim', 'CA', 'USA', 33.8121, -117.9190, 'America/Los_Angeles'), - ('Universal Studios Hollywood', 'Movie studio and theme park', 'THEME_PARK', 'OPERATING', - date(1964, 7, 15), 415, 'Universal City', 'CA', 'USA', 34.1381, -118.3534, 'America/Los_Angeles'), - ('Cedar Point', 'Roller coaster capital of the world', 'AMUSEMENT_PARK', 'OPERATING', - date(1870, 5, 30), 364, 'Sandusky', 'OH', 'USA', 41.4814, -82.6838, 'America/New_York'), - ('Six Flags Magic Mountain', 'Thrill capital of the world', 'THEME_PARK', 'OPERATING', - date(1971, 5, 29), 262, 'Valencia', 'CA', 'USA', 34.4244, -118.5969, 'America/Los_Angeles'), - ('Knott\'s Berry Farm', 'America\'s first theme park', 'THEME_PARK', 'OPERATING', - date(1920, 6, 1), 57, 'Buena Park', 'CA', 'USA', 33.8442, -117.9981, 'America/Los_Angeles'), - ('Busch Gardens Tampa', 'African-themed adventure park', 'THEME_PARK', 'OPERATING', - date(1959, 3, 31), 335, 'Tampa', 'FL', 'USA', 28.0373, -82.4194, 'America/New_York'), - ('SeaWorld Orlando', 'Marine life theme park', 'THEME_PARK', 'OPERATING', - date(1973, 12, 15), 200, 'Orlando', 'FL', 'USA', 28.4110, -81.4610, 'America/New_York'), + ( + "Magic Kingdom", + "Walt Disney World's flagship theme park", + "THEME_PARK", + "OPERATING", + date(1971, 10, 1), + 107, + "Orlando", + "FL", + "USA", + 28.4177, + -81.5812, + "America/New_York", + ), + ( + "Disneyland", + "The original Disney theme park", + "THEME_PARK", + "OPERATING", + date(1955, 7, 17), + 85, + "Anaheim", + "CA", + "USA", + 33.8121, + -117.9190, + "America/Los_Angeles", + ), + ( + "Universal Studios Hollywood", + "Movie studio and theme park", + "THEME_PARK", + "OPERATING", + date(1964, 7, 15), + 415, + "Universal City", + "CA", + "USA", + 34.1381, + -118.3534, + "America/Los_Angeles", + ), + ( + "Cedar Point", + "Roller coaster capital of the world", + "AMUSEMENT_PARK", + "OPERATING", + date(1870, 5, 30), + 364, + "Sandusky", + "OH", + "USA", + 41.4814, + -82.6838, + "America/New_York", + ), + ( + "Six Flags Magic Mountain", + "Thrill capital of the world", + "THEME_PARK", + "OPERATING", + date(1971, 5, 29), + 262, + "Valencia", + "CA", + "USA", + 34.4244, + -118.5969, + "America/Los_Angeles", + ), + ( + "Knott's Berry Farm", + "America's first theme park", + "THEME_PARK", + "OPERATING", + date(1920, 6, 1), + 57, + "Buena Park", + "CA", + "USA", + 33.8442, + -117.9981, + "America/Los_Angeles", + ), + ( + "Busch Gardens Tampa", + "African-themed adventure park", + "THEME_PARK", + "OPERATING", + date(1959, 3, 31), + 335, + "Tampa", + "FL", + "USA", + 28.0373, + -82.4194, + "America/New_York", + ), + ( + "SeaWorld Orlando", + "Marine life theme park", + "THEME_PARK", + "OPERATING", + date(1973, 12, 15), + 200, + "Orlando", + "FL", + "USA", + 28.4110, + -81.4610, + "America/New_York", + ), ] - for park_name, description, park_type, status, opening_date, size_acres, city, state, country, lat, lng, timezone_str in famous_parks: + for ( + park_name, + description, + park_type, + status, + opening_date, + size_acres, + city, + state, + country, + lat, + lng, + timezone_str, + ) in famous_parks: # Find appropriate operator operator = None - if 'Disney' in park_name: - operator = next((c for c in operators if 'Disney' in c.name), None) - elif 'Universal' in park_name: - operator = next((c for c in operators if 'Universal' in c.name), None) - elif 'Cedar Point' in park_name: - operator = next((c for c in operators if 'Cedar Fair' in c.name), None) - elif 'Six Flags' in park_name: - operator = next((c for c in operators if 'Six Flags' in c.name), None) - elif 'Knott' in park_name: - operator = next((c for c in operators if 'Knott' in c.name), None) - elif 'Busch' in park_name: - operator = next((c for c in operators if 'Busch' in c.name), None) - elif 'SeaWorld' in park_name: - operator = next((c for c in operators if 'SeaWorld' in c.name), None) + if "Disney" in park_name: + operator = next((c for c in operators if "Disney" in c.name), None) + elif "Universal" in park_name: + operator = next((c for c in operators if "Universal" in c.name), None) + elif "Cedar Point" in park_name: + operator = next((c for c in operators if "Cedar Fair" in c.name), None) + elif "Six Flags" in park_name: + operator = next((c for c in operators if "Six Flags" in c.name), None) + elif "Knott" in park_name: + operator = next((c for c in operators if "Knott" in c.name), None) + elif "Busch" in park_name: + operator = next((c for c in operators if "Busch" in c.name), None) + elif "SeaWorld" in park_name: + operator = next((c for c in operators if "SeaWorld" in c.name), None) if not operator: operator = random.choice(operators) @@ -671,26 +850,26 @@ class Command(BaseCommand): park, created = Park.objects.get_or_create( name=park_name, defaults={ - 'description': description, - 'park_type': park_type, - 'status': status, - 'opening_date': opening_date, - 'size_acres': Decimal(str(size_acres)), - 'operator': operator, - 'property_owner': property_owner, - 'average_rating': Decimal(str(random.uniform(7.5, 9.5))), - 'ride_count': random.randint(20, 60), - 'coaster_count': random.randint(5, 20), - 'timezone': timezone_str, - } + "description": description, + "park_type": park_type, + "status": status, + "opening_date": opening_date, + "size_acres": Decimal(str(size_acres)), + "operator": operator, + "property_owner": property_owner, + "average_rating": Decimal(str(random.uniform(7.5, 9.5))), + "ride_count": random.randint(20, 60), + "coaster_count": random.randint(5, 20), + "timezone": timezone_str, + }, ) if not created: - self.stdout.write(f' ℹ️ Using existing park: {park_name}') + self.stdout.write(f" ℹ️ Using existing park: {park_name}") # Create park location only if it doesn't exist location_exists = False try: - location_exists = hasattr(park, 'location') and park.location is not None + location_exists = hasattr(park, "location") and park.location is not None except Exception: location_exists = False @@ -698,40 +877,40 @@ class Command(BaseCommand): ParkLocation.objects.get_or_create( park=park, defaults={ - 'point': Point(lng, lat), - 'street_address': f"{random.randint(100, 9999)} {park_name} Dr", - 'city': city, - 'state': state, - 'country': country, - 'postal_code': f"{random.randint(10000, 99999)}" if country == 'USA' else '', - } + "point": Point(lng, lat), + "street_address": f"{random.randint(100, 9999)} {park_name} Dr", + "city": city, + "state": state, + "country": country, + "postal_code": f"{random.randint(10000, 99999)}" if country == "USA" else "", + }, ) # Create park areas only if park was created if created: - area_names = ['Main Street', 'Fantasyland', 'Tomorrowland', 'Adventureland', 'Frontierland'] + area_names = ["Main Street", "Fantasyland", "Tomorrowland", "Adventureland", "Frontierland"] for area_name in random.sample(area_names, random.randint(2, 4)): ParkArea.objects.get_or_create( park=park, name=area_name, defaults={ - 'description': f"Themed area within {park_name}", - } + "description": f"Themed area within {park_name}", + }, ) parks.append(park) # Create additional random parks - park_types = ['THEME_PARK', 'AMUSEMENT_PARK', 'WATER_PARK', 'FAMILY_ENTERTAINMENT_CENTER'] + park_types = ["THEME_PARK", "AMUSEMENT_PARK", "WATER_PARK", "FAMILY_ENTERTAINMENT_CENTER"] cities_data = [ - ('Los Angeles', 'CA', 'USA', 34.0522, -118.2437), - ('New York', 'NY', 'USA', 40.7128, -74.0060), - ('Chicago', 'IL', 'USA', 41.8781, -87.6298), - ('Houston', 'TX', 'USA', 29.7604, -95.3698), - ('Phoenix', 'AZ', 'USA', 33.4484, -112.0740), - ('Philadelphia', 'PA', 'USA', 39.9526, -75.1652), - ('San Antonio', 'TX', 'USA', 29.4241, -98.4936), - ('San Diego', 'CA', 'USA', 32.7157, -117.1611), + ("Los Angeles", "CA", "USA", 34.0522, -118.2437), + ("New York", "NY", "USA", 40.7128, -74.0060), + ("Chicago", "IL", "USA", 41.8781, -87.6298), + ("Houston", "TX", "USA", 29.7604, -95.3698), + ("Phoenix", "AZ", "USA", 33.4484, -112.0740), + ("Philadelphia", "PA", "USA", 39.9526, -75.1652), + ("San Antonio", "TX", "USA", 29.4241, -98.4936), + ("San Diego", "CA", "USA", 32.7157, -117.1611), ] for i in range(len(famous_parks), count): @@ -746,20 +925,20 @@ class Command(BaseCommand): # Determine timezone based on state timezone_map = { - 'CA': 'America/Los_Angeles', - 'NY': 'America/New_York', - 'IL': 'America/Chicago', - 'TX': 'America/Chicago', - 'AZ': 'America/Phoenix', - 'PA': 'America/New_York', + "CA": "America/Los_Angeles", + "NY": "America/New_York", + "IL": "America/Chicago", + "TX": "America/Chicago", + "AZ": "America/Phoenix", + "PA": "America/New_York", } - park_timezone = timezone_map.get(state, 'America/New_York') + park_timezone = timezone_map.get(state, "America/New_York") park = Park.objects.create( name=park_name, description=f"Exciting {park_type.lower().replace('_', ' ')} featuring thrilling rides and family entertainment", park_type=park_type, - status=random.choice(['OPERATING', 'OPERATING', 'OPERATING', 'CLOSED_TEMP']), + status=random.choice(["OPERATING", "OPERATING", "OPERATING", "CLOSED_TEMP"]), opening_date=date(random.randint(1950, 2020), random.randint(1, 12), random.randint(1, 28)), size_acres=Decimal(str(random.randint(50, 500))), operator=operator, @@ -785,7 +964,7 @@ class Command(BaseCommand): ) # Create park areas - area_names = ['Main Plaza', 'Adventure Zone', 'Family Area', 'Thrill Section', 'Water World', 'Kids Corner'] + area_names = ["Main Plaza", "Adventure Zone", "Family Area", "Thrill Section", "Water World", "Kids Corner"] for area_name in random.sample(area_names, random.randint(2, 4)): ParkArea.objects.create( park=park, @@ -795,34 +974,34 @@ class Command(BaseCommand): parks.append(park) - self.stdout.write(f' ✅ Created {len(parks)} parks') + self.stdout.write(f" ✅ Created {len(parks)} parks") return parks def create_rides(self, count: int, parks: list[Park], companies: list, ride_models: list[RideModel]) -> list[Ride]: """Create rides with comprehensive details""" - self.stdout.write(f'🎠 Creating {count} rides...') + self.stdout.write(f"🎠 Creating {count} rides...") if not parks: - self.stdout.write(' ⚠️ No parks found, skipping rides') + self.stdout.write(" ⚠️ No parks found, skipping rides") return [] - manufacturers = [c for c in companies if 'MANUFACTURER' in c.roles] - designers = [c for c in companies if 'DESIGNER' in c.roles] + manufacturers = [c for c in companies if "MANUFACTURER" in c.roles] + designers = [c for c in companies if "DESIGNER" in c.roles] rides = [] # Famous roller coasters famous_coasters = [ - ('Steel Vengeance', 'RC', 'Hybrid steel-wood roller coaster', 'Rocky Mountain Construction'), - ('Millennium Force', 'RC', 'Giga coaster with 300-foot drop', 'Intamin'), - ('The Beast', 'RC', 'Legendary wooden roller coaster', None), - ('Fury 325', 'RC', 'Giga coaster with 325-foot height', 'Bolliger & Mabillard'), - ('Lightning Rod', 'RC', 'Launched wooden roller coaster', 'Rocky Mountain Construction'), - ('Maverick', 'RC', 'Multi-launch roller coaster', 'Intamin'), - ('El Toro', 'RC', 'Wooden roller coaster with steep drops', 'Intamin'), - ('Intimidator 305', 'RC', 'Giga coaster with intense elements', 'Intamin'), - ('Twisted Timbers', 'RC', 'RMC conversion of wooden coaster', 'Rocky Mountain Construction'), - ('Goliath', 'RC', 'Hyper coaster with massive drops', 'Bolliger & Mabillard'), + ("Steel Vengeance", "RC", "Hybrid steel-wood roller coaster", "Rocky Mountain Construction"), + ("Millennium Force", "RC", "Giga coaster with 300-foot drop", "Intamin"), + ("The Beast", "RC", "Legendary wooden roller coaster", None), + ("Fury 325", "RC", "Giga coaster with 325-foot height", "Bolliger & Mabillard"), + ("Lightning Rod", "RC", "Launched wooden roller coaster", "Rocky Mountain Construction"), + ("Maverick", "RC", "Multi-launch roller coaster", "Intamin"), + ("El Toro", "RC", "Wooden roller coaster with steep drops", "Intamin"), + ("Intimidator 305", "RC", "Giga coaster with intense elements", "Intamin"), + ("Twisted Timbers", "RC", "RMC conversion of wooden coaster", "Rocky Mountain Construction"), + ("Goliath", "RC", "Hyper coaster with massive drops", "Bolliger & Mabillard"), ] # Create famous coasters @@ -850,7 +1029,7 @@ class Command(BaseCommand): manufacturer=manufacturer, designer=designer, ride_model=ride_model, - status=random.choice(['OPERATING'] * 8 + ['CLOSED_TEMP', 'SBNO']), + status=random.choice(["OPERATING"] * 8 + ["CLOSED_TEMP", "SBNO"]), opening_date=date(random.randint(1990, 2023), random.randint(1, 12), random.randint(1, 28)), min_height_in=random.choice([48, 52, 54, 60]), capacity_per_hour=random.randint(800, 1800), @@ -859,7 +1038,7 @@ class Command(BaseCommand): ) # Create roller coaster stats - if category == 'RC': + if category == "RC": RollerCoasterStats.objects.create( ride=ride, height_ft=Decimal(str(random.randint(100, 350))), @@ -867,12 +1046,12 @@ class Command(BaseCommand): speed_mph=Decimal(str(random.randint(45, 120))), inversions=random.randint(0, 8), ride_time_seconds=random.randint(90, 240), - track_type=random.choice(['Steel', 'Wood', 'Hybrid']), - track_material=random.choice(['STEEL', 'WOOD', 'HYBRID']), - roller_coaster_type=random.choice(['SITDOWN', 'INVERTED', 'WING', 'DIVE', 'FLYING']), + track_type=random.choice(["Steel", "Wood", "Hybrid"]), + track_material=random.choice(["STEEL", "WOOD", "HYBRID"]), + roller_coaster_type=random.choice(["SITDOWN", "INVERTED", "WING", "DIVE", "FLYING"]), max_drop_height_ft=Decimal(str(random.randint(80, 300))), - propulsion_system=random.choice(['CHAIN', 'LSM', 'HYDRAULIC']), - train_style=random.choice(['Traditional', 'Floorless', 'Wing', 'Flying']), + propulsion_system=random.choice(["CHAIN", "LSM", "HYDRAULIC"]), + train_style=random.choice(["Traditional", "Floorless", "Wing", "Flying"]), trains_count=random.randint(2, 4), cars_per_train=random.randint(6, 8), seats_per_car=random.randint(2, 4), @@ -882,13 +1061,31 @@ class Command(BaseCommand): # Create additional random rides ride_names = [ - 'Thunder Mountain', 'Space Coaster', 'Wild Eagle', 'Dragon Fire', 'Phoenix Rising', - 'Storm Runner', 'Lightning Strike', 'Tornado Alley', 'Hurricane Force', 'Cyclone', - 'Viper', 'Cobra', 'Rattlesnake', 'Sidewinder', 'Diamondback', 'Copperhead', - 'Banshee', 'Valkyrie', 'Griffon', 'Falcon', 'Eagle\'s Flight', 'Soaring Heights' + "Thunder Mountain", + "Space Coaster", + "Wild Eagle", + "Dragon Fire", + "Phoenix Rising", + "Storm Runner", + "Lightning Strike", + "Tornado Alley", + "Hurricane Force", + "Cyclone", + "Viper", + "Cobra", + "Rattlesnake", + "Sidewinder", + "Diamondback", + "Copperhead", + "Banshee", + "Valkyrie", + "Griffon", + "Falcon", + "Eagle's Flight", + "Soaring Heights", ] - categories = ['RC', 'DR', 'FR', 'WR', 'TR', 'OT'] + categories = ["RC", "DR", "FR", "WR", "TR", "OT"] for _i in range(len(famous_coasters), count): park = random.choice(parks) @@ -911,16 +1108,16 @@ class Command(BaseCommand): manufacturer=manufacturer, designer=designer, ride_model=ride_model, - status=random.choice(['OPERATING'] * 9 + ['CLOSED_TEMP']), + status=random.choice(["OPERATING"] * 9 + ["CLOSED_TEMP"]), opening_date=date(random.randint(1980, 2023), random.randint(1, 12), random.randint(1, 28)), - min_height_in=random.choice([36, 42, 48, 52, 54]) if category != 'DR' else None, + min_height_in=random.choice([36, 42, 48, 52, 54]) if category != "DR" else None, capacity_per_hour=random.randint(400, 2000), ride_duration_seconds=random.randint(60, 300), average_rating=Decimal(str(random.uniform(6.0, 9.0))), ) # Create roller coaster stats for RC category - if category == 'RC': + if category == "RC": RollerCoasterStats.objects.create( ride=ride, height_ft=Decimal(str(random.randint(50, 300))), @@ -928,12 +1125,12 @@ class Command(BaseCommand): speed_mph=Decimal(str(random.randint(25, 100))), inversions=random.randint(0, 6), ride_time_seconds=random.randint(90, 180), - track_type=random.choice(['Steel', 'Wood']), - track_material=random.choice(['STEEL', 'WOOD', 'HYBRID']), - roller_coaster_type=random.choice(['SITDOWN', 'INVERTED', 'FAMILY', 'WILD_MOUSE']), + track_type=random.choice(["Steel", "Wood"]), + track_material=random.choice(["STEEL", "WOOD", "HYBRID"]), + roller_coaster_type=random.choice(["SITDOWN", "INVERTED", "FAMILY", "WILD_MOUSE"]), max_drop_height_ft=Decimal(str(random.randint(40, 250))), - propulsion_system=random.choice(['CHAIN', 'LSM', 'GRAVITY']), - train_style=random.choice(['Traditional', 'Family', 'Compact']), + propulsion_system=random.choice(["CHAIN", "LSM", "GRAVITY"]), + train_style=random.choice(["Traditional", "Family", "Compact"]), trains_count=random.randint(1, 3), cars_per_train=random.randint(4, 8), seats_per_car=random.randint(2, 4), @@ -941,15 +1138,15 @@ class Command(BaseCommand): rides.append(ride) - self.stdout.write(f' ✅ Created {len(rides)} rides') + self.stdout.write(f" ✅ Created {len(rides)} rides") return rides def create_reviews(self, count: int, users: list[User], parks: list[Park], rides: list[Ride]) -> None: """Create park and ride reviews""" - self.stdout.write(f'📝 Creating {count} reviews...') + self.stdout.write(f"📝 Creating {count} reviews...") if not users or (not parks and not rides): - self.stdout.write(' ⚠️ No users or content found, skipping reviews') + self.stdout.write(" ⚠️ No users or content found, skipping reviews") return review_texts = [ @@ -984,16 +1181,12 @@ class Command(BaseCommand): user=user, park=park, defaults={ - 'rating': random.randint(6, 10), - 'title': f"Great visit to {park.name}", - 'content': random.choice(review_texts), - 'is_published': random.choice([True] * 9 + [False]), - 'visit_date': date( - random.randint(2020, 2024), - random.randint(1, 12), - random.randint(1, 28) - ), - } + "rating": random.randint(6, 10), + "title": f"Great visit to {park.name}", + "content": random.choice(review_texts), + "is_published": random.choice([True] * 9 + [False]), + "visit_date": date(random.randint(2020, 2024), random.randint(1, 12), random.randint(1, 28)), + }, ) if created: @@ -1018,40 +1211,50 @@ class Command(BaseCommand): user=user, ride=ride, defaults={ - 'rating': random.randint(6, 10), - 'title': f"Awesome ride - {ride.name}", - 'content': random.choice(review_texts), - 'is_published': random.choice([True] * 9 + [False]), - 'visit_date': date( - random.randint(2020, 2024), - random.randint(1, 12), - random.randint(1, 28) - ), - } + "rating": random.randint(6, 10), + "title": f"Awesome ride - {ride.name}", + "content": random.choice(review_texts), + "is_published": random.choice([True] * 9 + [False]), + "visit_date": date(random.randint(2020, 2024), random.randint(1, 12), random.randint(1, 28)), + }, ) if created: created_ride_reviews += 1 - self.stdout.write(f' ✅ Created {count} reviews') - - + self.stdout.write(f" ✅ Created {count} reviews") def create_notifications(self, users: list[User]) -> None: """Create sample notifications for users""" - self.stdout.write('🔔 Creating notifications...') + self.stdout.write("🔔 Creating notifications...") if not users: - self.stdout.write(' ⚠️ No users found, skipping notifications') + self.stdout.write(" ⚠️ No users found, skipping notifications") return notification_count = 0 notification_types = [ - ("submission_approved", "Your park submission has been approved!", "Great news! Your submission for Adventure Park has been approved and is now live."), - ("review_helpful", "Someone found your review helpful", "Your review of Steel Vengeance was marked as helpful by another user."), - ("system_announcement", "New features available", "Check out our new ride comparison tool and enhanced search filters."), - ("achievement_unlocked", "Achievement unlocked!", "Congratulations! You've unlocked the 'Coaster Enthusiast' achievement."), + ( + "submission_approved", + "Your park submission has been approved!", + "Great news! Your submission for Adventure Park has been approved and is now live.", + ), + ( + "review_helpful", + "Someone found your review helpful", + "Your review of Steel Vengeance was marked as helpful by another user.", + ), + ( + "system_announcement", + "New features available", + "Check out our new ride comparison tool and enhanced search filters.", + ), + ( + "achievement_unlocked", + "Achievement unlocked!", + "Congratulations! You've unlocked the 'Coaster Enthusiast' achievement.", + ), ] # Create notifications for random users @@ -1064,83 +1267,82 @@ class Command(BaseCommand): notification_type=notification_type, title=title, message=message, - priority=random.choice(['normal'] * 3 + ['high']), + priority=random.choice(["normal"] * 3 + ["high"]), is_read=random.choice([True, False]), email_sent=random.choice([True, False]), push_sent=random.choice([True, False]), ) notification_count += 1 - self.stdout.write(f' ✅ Created {notification_count} notifications') + self.stdout.write(f" ✅ Created {notification_count} notifications") def create_moderation_data(self, users: list[User], parks: list[Park], rides: list[Ride]) -> None: """Create moderation queue and actions""" - self.stdout.write('🛡️ Creating moderation data...') + self.stdout.write("🛡️ Creating moderation data...") if not ModerationQueue or not ModerationAction: - self.stdout.write(' ⚠️ Moderation models not available, skipping') + self.stdout.write(" ⚠️ Moderation models not available, skipping") return if not users or (not parks and not rides): - self.stdout.write(' ⚠️ No users or content found, skipping moderation data') + self.stdout.write(" ⚠️ No users or content found, skipping moderation data") return # This would create sample moderation queue items and actions # Implementation depends on the actual moderation models structure - self.stdout.write(' ✅ Moderation data creation skipped (models not fully defined)') + self.stdout.write(" ✅ Moderation data creation skipped (models not fully defined)") def create_photos(self, parks: list[Park], rides: list[Ride], ride_models: list[RideModel]) -> None: """Create sample photo records""" - self.stdout.write('📸 Creating photo records...') + self.stdout.write("📸 Creating photo records...") if not CloudflareImage: - self.stdout.write(' ⚠️ CloudflareImage model not available, skipping photo creation') + self.stdout.write(" ⚠️ CloudflareImage model not available, skipping photo creation") return # Since we don't have actual Cloudflare images, we'll skip photo creation # In a real scenario, you would need actual CloudflareImage instances - self.stdout.write(' ⚠️ Photo creation skipped (requires actual CloudflareImage instances)') - self.stdout.write(' ℹ️ To create photos, you need to upload actual images to Cloudflare first') + self.stdout.write(" ⚠️ Photo creation skipped (requires actual CloudflareImage instances)") + self.stdout.write(" ℹ️ To create photos, you need to upload actual images to Cloudflare first") def create_rankings(self, rides: list[Ride]) -> None: """Create ride rankings if model exists""" - self.stdout.write('🏆 Creating ride rankings...') + self.stdout.write("🏆 Creating ride rankings...") if not RideRanking: - self.stdout.write(' ⚠️ RideRanking model not available, skipping') + self.stdout.write(" ⚠️ RideRanking model not available, skipping") return if not rides: - self.stdout.write(' ⚠️ No rides found, skipping rankings') + self.stdout.write(" ⚠️ No rides found, skipping rankings") return # This would create sample ride rankings # Implementation depends on the actual RideRanking model structure - self.stdout.write(' ✅ Ride rankings creation skipped (model structure not fully defined)') + self.stdout.write(" ✅ Ride rankings creation skipped (model structure not fully defined)") def print_summary(self) -> None: """Print a summary of created data""" - self.stdout.write('\n📊 Data Seeding Summary:') - self.stdout.write('=' * 50) + self.stdout.write("\n📊 Data Seeding Summary:") + self.stdout.write("=" * 50) # Count all created objects counts = { - 'Users': User.objects.count(), - 'Park Companies': ParkCompany.objects.count(), - 'Ride Companies': RideCompany.objects.count(), - 'Parks': Park.objects.count(), - 'Rides': Ride.objects.count(), - 'Ride Models': RideModel.objects.count(), - 'Park Reviews': ParkReview.objects.count(), - 'Ride Reviews': RideReview.objects.count(), - - 'Notifications': UserNotification.objects.count(), - 'Park Photos': ParkPhoto.objects.count(), - 'Ride Photos': RidePhoto.objects.count(), + "Users": User.objects.count(), + "Park Companies": ParkCompany.objects.count(), + "Ride Companies": RideCompany.objects.count(), + "Parks": Park.objects.count(), + "Rides": Ride.objects.count(), + "Ride Models": RideModel.objects.count(), + "Park Reviews": ParkReview.objects.count(), + "Ride Reviews": RideReview.objects.count(), + "Notifications": UserNotification.objects.count(), + "Park Photos": ParkPhoto.objects.count(), + "Ride Photos": RidePhoto.objects.count(), } for model_name, count in counts.items(): - self.stdout.write(f' {model_name}: {count}') + self.stdout.write(f" {model_name}: {count}") - self.stdout.write('=' * 50) - self.stdout.write('🎉 Seeding completed! Your ThrillWiki database is ready for testing.') + self.stdout.write("=" * 50) + self.stdout.write("🎉 Seeding completed! Your ThrillWiki database is ready for testing.") diff --git a/backend/apps/api/v1/accounts/serializers.py b/backend/apps/api/v1/accounts/serializers.py index 8cdff197..4787c68f 100644 --- a/backend/apps/api/v1/accounts/serializers.py +++ b/backend/apps/api/v1/accounts/serializers.py @@ -23,6 +23,7 @@ class UserProfileUpdateInputSerializer(serializers.ModelSerializer): cloudflare_id = validated_data.pop("cloudflare_image_id", None) if cloudflare_id: from django_cloudflareimages_toolkit.models import CloudflareImage + image, _ = CloudflareImage.objects.get_or_create(cloudflare_id=cloudflare_id) instance.avatar = image diff --git a/backend/apps/api/v1/accounts/urls.py b/backend/apps/api/v1/accounts/urls.py index 34eddbe0..f5aa579e 100644 --- a/backend/apps/api/v1/accounts/urls.py +++ b/backend/apps/api/v1/accounts/urls.py @@ -76,9 +76,7 @@ urlpatterns = [ name="update_privacy_settings", ), # Security settings endpoints - path( - "settings/security/", views.get_security_settings, name="get_security_settings" - ), + path("settings/security/", views.get_security_settings, name="get_security_settings"), path( "settings/security/update/", views.update_security_settings, @@ -90,9 +88,7 @@ urlpatterns = [ path("top-lists/", views.get_user_top_lists, name="get_user_top_lists"), path("top-lists/create/", views.create_top_list, name="create_top_list"), path("top-lists//", views.update_top_list, name="update_top_list"), - path( - "top-lists//delete/", views.delete_top_list, name="delete_top_list" - ), + path("top-lists//delete/", views.delete_top_list, name="delete_top_list"), # Notification endpoints path("notifications/", views.get_user_notifications, name="get_user_notifications"), path( @@ -114,18 +110,13 @@ urlpatterns = [ path("profile/avatar/upload/", views.upload_avatar, name="upload_avatar"), path("profile/avatar/save/", views.save_avatar_image, name="save_avatar_image"), path("profile/avatar/delete/", views.delete_avatar, name="delete_avatar"), - # Login history endpoint path("login-history/", views.get_login_history, name="get_login_history"), - # Magic Link (Login by Code) endpoints path("magic-link/request/", views_magic_link.request_magic_link, name="request_magic_link"), path("magic-link/verify/", views_magic_link.verify_magic_link, name="verify_magic_link"), - # Public Profile path("profiles//", views.get_public_user_profile, name="get_public_user_profile"), - # ViewSet routes path("", include(router.urls)), ] - diff --git a/backend/apps/api/v1/accounts/views.py b/backend/apps/api/v1/accounts/views.py index d4c9ba82..468f0871 100644 --- a/backend/apps/api/v1/accounts/views.py +++ b/backend/apps/api/v1/accounts/views.py @@ -69,8 +69,7 @@ logger = logging.getLogger(__name__) 200: { "description": "User successfully deleted with submissions preserved", "example": { - "success": True, - "message": "User successfully deleted with submissions preserved", + "detail": "User successfully deleted with submissions preserved", "deleted_user": { "username": "john_doe", "user_id": "1234", @@ -92,17 +91,16 @@ logger = logging.getLogger(__name__) 400: { "description": "Bad request - user cannot be deleted", "example": { - "success": False, - "error": "Cannot delete user: Cannot delete superuser accounts", + "detail": "Cannot delete user: Cannot delete superuser accounts", }, }, 404: { "description": "User not found", - "example": {"success": False, "error": "User not found"}, + "example": {"detail": "User not found"}, }, 403: { "description": "Permission denied - admin access required", - "example": {"success": False, "error": "Admin access required"}, + "example": {"detail": "Admin access required"}, }, }, tags=["User Management"], @@ -137,7 +135,7 @@ def delete_user_preserve_submissions(request, user_id): "is_superuser": user.is_superuser, "user_role": user.role, "rejection_reason": reason, - } + }, ) # Determine error code based on reason @@ -151,8 +149,7 @@ def delete_user_preserve_submissions(request, user_id): return Response( { - "success": False, - "error": f"Cannot delete user: {reason}", + "detail": f"Cannot delete user: {reason}", "error_code": error_code, "user_info": { "username": user.username, @@ -174,7 +171,7 @@ def delete_user_preserve_submissions(request, user_id): "target_user": user.username, "target_user_id": user_id, "action": "user_deletion", - } + }, ) # Perform the deletion @@ -185,17 +182,16 @@ def delete_user_preserve_submissions(request, user_id): f"Successfully deleted user {result['deleted_user']['username']} (ID: {user_id}) by admin {request.user.username}", extra={ "admin_user": request.user.username, - "deleted_user": result['deleted_user']['username'], + "deleted_user": result["deleted_user"]["username"], "deleted_user_id": user_id, - "preserved_submissions": result['preserved_submissions'], + "preserved_submissions": result["preserved_submissions"], "action": "user_deletion_completed", - } + }, ) return Response( { - "success": True, - "message": "User successfully deleted with submissions preserved", + "detail": "User successfully deleted with submissions preserved", **result, }, status=status.HTTP_200_OK, @@ -208,16 +204,15 @@ def delete_user_preserve_submissions(request, user_id): extra={ "admin_user": request.user.username, "target_user_id": user_id, - "error": str(e), + "detail": str(e), "action": "user_deletion_error", }, - exc_info=True + exc_info=True, ) return Response( { - "success": False, - "error": f"Error deleting user: {str(e)}", + "detail": f"Error deleting user: {str(e)}", "error_code": "DELETION_ERROR", "help_text": "Please try again or contact system administrator if the problem persists.", }, @@ -259,8 +254,7 @@ def delete_user_preserve_submissions(request, user_id): }, }, "example": { - "success": True, - "message": "Avatar saved successfully", + "detail": "Avatar saved successfully", "avatar_url": "https://imagedelivery.net/account-hash/image-id/avatar", "avatar_variants": { "thumbnail": "https://imagedelivery.net/account-hash/image-id/thumbnail", @@ -285,7 +279,7 @@ def save_avatar_image(request): if not cloudflare_image_id: return Response( - {"success": False, "error": "cloudflare_image_id is required"}, + {"detail": "cloudflare_image_id is required"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -299,26 +293,25 @@ def save_avatar_image(request): if not image_data: return Response( - {"success": False, "error": "Image not found in Cloudflare"}, + {"detail": "Image not found in Cloudflare"}, status=status.HTTP_400_BAD_REQUEST, ) # Try to find existing CloudflareImage record by cloudflare_id cloudflare_image = None try: - cloudflare_image = CloudflareImage.objects.get( - cloudflare_id=cloudflare_image_id) + cloudflare_image = CloudflareImage.objects.get(cloudflare_id=cloudflare_image_id) # Update existing record with latest data from Cloudflare - cloudflare_image.status = 'uploaded' + cloudflare_image.status = "uploaded" cloudflare_image.uploaded_at = timezone.now() - cloudflare_image.metadata = image_data.get('meta', {}) + cloudflare_image.metadata = image_data.get("meta", {}) # Extract variants from nested result structure - cloudflare_image.variants = image_data.get('result', {}).get('variants', []) + cloudflare_image.variants = image_data.get("result", {}).get("variants", []) cloudflare_image.cloudflare_metadata = image_data - cloudflare_image.width = image_data.get('width') - cloudflare_image.height = image_data.get('height') - cloudflare_image.format = image_data.get('format', '') + cloudflare_image.width = image_data.get("width") + cloudflare_image.height = image_data.get("height") + cloudflare_image.format = image_data.get("format", "") cloudflare_image.save() except CloudflareImage.DoesNotExist: @@ -326,25 +319,23 @@ def save_avatar_image(request): cloudflare_image = CloudflareImage.objects.create( cloudflare_id=cloudflare_image_id, user=user, - status='uploaded', - upload_url='', # Not needed for uploaded images + status="uploaded", + upload_url="", # Not needed for uploaded images expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry uploaded_at=timezone.now(), - metadata=image_data.get('meta', {}), + metadata=image_data.get("meta", {}), # Extract variants from nested result structure - variants=image_data.get('result', {}).get('variants', []), + variants=image_data.get("result", {}).get("variants", []), cloudflare_metadata=image_data, - width=image_data.get('width'), - height=image_data.get('height'), - format=image_data.get('format', ''), + width=image_data.get("width"), + height=image_data.get("height"), + format=image_data.get("format", ""), ) except Exception as api_error: - logger.error( - f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True) + logger.error(f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True) return Response( - {"success": False, - "error": f"Failed to fetch image from Cloudflare: {str(api_error)}"}, + {"detail": f"Failed to fetch image from Cloudflare: {str(api_error)}"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -391,8 +382,7 @@ def save_avatar_image(request): return Response( { - "success": True, - "message": "Avatar saved successfully", + "detail": "Avatar saved successfully", "avatar_url": avatar_url, "avatar_variants": avatar_variants, }, @@ -402,7 +392,7 @@ def save_avatar_image(request): except Exception as e: logger.error(f"Error saving avatar image: {str(e)}", exc_info=True) return Response( - {"success": False, "error": f"Failed to save avatar: {str(e)}"}, + {"detail": f"Failed to save avatar: {str(e)}"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -420,8 +410,7 @@ def save_avatar_image(request): "avatar_url": {"type": "string"}, }, "example": { - "success": True, - "message": "Avatar deleted successfully", + "detail": "Avatar deleted successfully", "avatar_url": "https://ui-avatars.com/api/?name=J&size=200&background=random&color=fff&bold=true", }, }, @@ -447,6 +436,7 @@ def delete_avatar(request): # Delete from Cloudflare first, then from database try: from django_cloudflareimages_toolkit.services import CloudflareImagesService + service = CloudflareImagesService() service.delete_image(avatar_to_delete) logger.info(f"Successfully deleted avatar from Cloudflare: {avatar_to_delete.cloudflare_id}") @@ -461,8 +451,7 @@ def delete_avatar(request): return Response( { - "success": True, - "message": "Avatar deleted successfully", + "detail": "Avatar deleted successfully", "avatar_url": avatar_url, }, status=status.HTTP_200_OK, @@ -471,8 +460,7 @@ def delete_avatar(request): except UserProfile.DoesNotExist: return Response( { - "success": True, - "message": "No avatar to delete", + "detail": "No avatar to delete", "avatar_url": f"https://ui-avatars.com/api/?name={user.username[0].upper()}&size=200&background=random&color=fff&bold=true", }, status=status.HTTP_200_OK, @@ -480,7 +468,7 @@ def delete_avatar(request): except Exception as e: return Response( - {"success": False, "error": f"Failed to delete avatar: {str(e)}"}, + {"detail": f"Failed to delete avatar: {str(e)}"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -506,7 +494,7 @@ def request_account_deletion(request): can_delete, reason = UserDeletionService.can_delete_user(user) if not can_delete: return Response( - {"success": False, "error": reason}, + {"detail": reason}, status=status.HTTP_400_BAD_REQUEST, ) @@ -515,8 +503,7 @@ def request_account_deletion(request): return Response( { - "success": True, - "message": "Verification code sent to your email", + "detail": "Verification code sent to your email", "expires_at": deletion_request.expires_at, "email": user.email, }, @@ -534,7 +521,7 @@ def request_account_deletion(request): "user_role": request.user.role, "rejection_reason": str(e), "action": "self_deletion_rejected", - } + }, ) # Determine error code based on reason @@ -549,8 +536,7 @@ def request_account_deletion(request): return Response( { - "success": False, - "error": error_message, + "detail": error_message, "error_code": error_code, "user_info": { "username": request.user.username, @@ -570,16 +556,15 @@ def request_account_deletion(request): extra={ "user": request.user.username, "user_id": request.user.user_id, - "error": str(e), + "detail": str(e), "action": "self_deletion_error", }, - exc_info=True + exc_info=True, ) return Response( { - "success": False, - "error": f"Error creating deletion request: {str(e)}", + "detail": f"Error creating deletion request: {str(e)}", "error_code": "DELETION_REQUEST_ERROR", "help_text": "Please try again or contact support if the problem persists.", }, @@ -611,8 +596,7 @@ def request_account_deletion(request): 200: { "description": "Account successfully deleted", "example": { - "success": True, - "message": "Account successfully deleted with submissions preserved", + "detail": "Account successfully deleted with submissions preserved", "deleted_user": { "username": "john_doe", "user_id": "1234", @@ -637,7 +621,7 @@ def request_account_deletion(request): }, 400: { "description": "Invalid or expired verification code", - "example": {"success": False, "error": "Verification code has expired"}, + "example": {"detail": "Verification code has expired"}, }, }, tags=["Self-Service Account Management"], @@ -663,7 +647,7 @@ def verify_account_deletion(request): if not verification_code: return Response( - {"success": False, "error": "Verification code is required"}, + {"detail": "Verification code is required"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -672,20 +656,17 @@ def verify_account_deletion(request): return Response( { - "success": True, - "message": "Account successfully deleted with submissions preserved", + "detail": "Account successfully deleted with submissions preserved", **result, }, status=status.HTTP_200_OK, ) except ValueError as e: - return Response( - {"success": False, "error": str(e)}, status=status.HTTP_400_BAD_REQUEST - ) + return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST) except Exception as e: return Response( - {"success": False, "error": f"Error verifying deletion: {str(e)}"}, + {"detail": f"Error verifying deletion: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) @@ -701,14 +682,13 @@ def verify_account_deletion(request): 200: { "description": "Deletion request cancelled or no request found", "example": { - "success": True, - "message": "Deletion request cancelled", + "detail": "Deletion request cancelled", "had_pending_request": True, }, }, 401: { "description": "Authentication required", - "example": {"success": False, "error": "Authentication required"}, + "example": {"detail": "Authentication required"}, }, }, tags=["Self-Service Account Management"], @@ -732,12 +712,7 @@ def cancel_account_deletion(request): return Response( { - "success": True, - "message": ( - "Deletion request cancelled" - if had_request - else "No pending deletion request found" - ), + "detail": ("Deletion request cancelled" if had_request else "No pending deletion request found"), "had_pending_request": had_request, }, status=status.HTTP_200_OK, @@ -745,7 +720,7 @@ def cancel_account_deletion(request): except Exception as e: return Response( - {"success": False, "error": f"Error cancelling deletion request: {str(e)}"}, + {"detail": f"Error cancelling deletion request: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) @@ -753,10 +728,7 @@ def cancel_account_deletion(request): @extend_schema( operation_id="check_user_deletion_eligibility", summary="Check if user can be deleted", - description=( - "Check if a user can be safely deleted and get a preview of " - "what submissions would be preserved." - ), + description=("Check if a user can be safely deleted and get a preview of " "what submissions would be preserved."), parameters=[ OpenApiParameter( name="user_id", @@ -792,11 +764,11 @@ def cancel_account_deletion(request): }, 404: { "description": "User not found", - "example": {"success": False, "error": "User not found"}, + "example": {"detail": "User not found"}, }, 403: { "description": "Permission denied - admin access required", - "example": {"success": False, "error": "Admin access required"}, + "example": {"detail": "Admin access required"}, }, }, tags=["User Management"], @@ -821,27 +793,13 @@ def check_user_deletion_eligibility(request, user_id): # Count submissions submission_counts = { - "park_reviews": getattr( - user, "park_reviews", user.__class__.objects.none() - ).count(), - "ride_reviews": getattr( - user, "ride_reviews", user.__class__.objects.none() - ).count(), - "uploaded_park_photos": getattr( - user, "uploaded_park_photos", user.__class__.objects.none() - ).count(), - "uploaded_ride_photos": getattr( - user, "uploaded_ride_photos", user.__class__.objects.none() - ).count(), - "top_lists": getattr( - user, "user_lists", user.__class__.objects.none() - ).count(), - "edit_submissions": getattr( - user, "edit_submissions", user.__class__.objects.none() - ).count(), - "photo_submissions": getattr( - user, "photo_submissions", user.__class__.objects.none() - ).count(), + "park_reviews": getattr(user, "park_reviews", user.__class__.objects.none()).count(), + "ride_reviews": getattr(user, "ride_reviews", user.__class__.objects.none()).count(), + "uploaded_park_photos": getattr(user, "uploaded_park_photos", user.__class__.objects.none()).count(), + "uploaded_ride_photos": getattr(user, "uploaded_ride_photos", user.__class__.objects.none()).count(), + "top_lists": getattr(user, "user_lists", user.__class__.objects.none()).count(), + "edit_submissions": getattr(user, "edit_submissions", user.__class__.objects.none()).count(), + "photo_submissions": getattr(user, "photo_submissions", user.__class__.objects.none()).count(), } total_submissions = sum(submission_counts.values()) @@ -865,7 +823,7 @@ def check_user_deletion_eligibility(request, user_id): except Exception as e: return Response( - {"success": False, "error": f"Error checking user: {str(e)}"}, + {"detail": f"Error checking user: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) @@ -912,9 +870,7 @@ def get_user_profile(request): @permission_classes([IsAuthenticated]) def update_user_account(request): """Update basic account information.""" - serializer = AccountUpdateSerializer( - request.user, data=request.data, partial=True, context={"request": request} - ) + serializer = AccountUpdateSerializer(request.user, data=request.data, partial=True, context={"request": request}) if serializer.is_valid(): serializer.save() @@ -944,9 +900,7 @@ def update_user_profile(request): """Update user profile information.""" profile, created = UserProfile.objects.get_or_create(user=request.user) - serializer = ProfileUpdateSerializer( - profile, data=request.data, partial=True, context={"request": request} - ) + serializer = ProfileUpdateSerializer(profile, data=request.data, partial=True, context={"request": request}) if serializer.is_valid(): serializer.save() @@ -1046,9 +1000,7 @@ def update_user_preferences(request): @permission_classes([IsAuthenticated]) def update_theme_preference(request): """Update theme preference.""" - serializer = ThemePreferenceSerializer( - request.user, data=request.data, partial=True - ) + serializer = ThemePreferenceSerializer(request.user, data=request.data, partial=True) if serializer.is_valid(): serializer.save() @@ -1395,14 +1347,9 @@ def update_top_list(request, list_id): try: top_list = UserList.objects.get(id=list_id, user=request.user) except UserList.DoesNotExist: - return Response( - {"error": "Top list not found"}, - status=status.HTTP_404_NOT_FOUND - ) + return Response({"detail": "Top list not found"}, status=status.HTTP_404_NOT_FOUND) - serializer = UserListSerializer( - top_list, data=request.data, partial=True, context={"request": request} - ) + serializer = UserListSerializer(top_list, data=request.data, partial=True, context={"request": request}) if serializer.is_valid(): serializer.save() @@ -1430,10 +1377,7 @@ def delete_top_list(request, list_id): top_list.delete() return Response(status=status.HTTP_204_NO_CONTENT) except UserList.DoesNotExist: - return Response( - {"error": "Top list not found"}, - status=status.HTTP_404_NOT_FOUND - ) + return Response({"detail": "Top list not found"}, status=status.HTTP_404_NOT_FOUND) # === NOTIFICATION ENDPOINTS === @@ -1453,9 +1397,9 @@ def delete_top_list(request, list_id): @permission_classes([IsAuthenticated]) def get_user_notifications(request): """Get user notifications.""" - notifications = UserNotification.objects.filter( - user=request.user - ).order_by("-created_at")[:50] # Limit to 50 most recent + notifications = UserNotification.objects.filter(user=request.user).order_by("-created_at")[ + :50 + ] # Limit to 50 most recent serializer = UserNotificationSerializer(notifications, many=True) return Response(serializer.data, status=status.HTTP_200_OK) @@ -1483,19 +1427,16 @@ def mark_notifications_read(request): mark_all = serializer.validated_data.get("mark_all", False) if mark_all: - UserNotification.objects.filter( - user=request.user, is_read=False - ).update(is_read=True, read_at=timezone.now()) + UserNotification.objects.filter(user=request.user, is_read=False).update( + is_read=True, read_at=timezone.now() + ) count = UserNotification.objects.filter(user=request.user).count() else: - count = UserNotification.objects.filter( - id__in=notification_ids, user=request.user, is_read=False - ).update(is_read=True, read_at=timezone.now()) + count = UserNotification.objects.filter(id__in=notification_ids, user=request.user, is_read=False).update( + is_read=True, read_at=timezone.now() + ) - return Response( - {"message": f"Marked {count} notifications as read"}, - status=status.HTTP_200_OK - ) + return Response({"detail": f"Marked {count} notifications as read"}, status=status.HTTP_200_OK) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -1544,9 +1485,7 @@ def update_notification_preferences(request): except NotificationPreference.DoesNotExist: preferences = NotificationPreference.objects.create(user=request.user) - serializer = NotificationPreferenceSerializer( - preferences, data=request.data, partial=True - ) + serializer = NotificationPreferenceSerializer(preferences, data=request.data, partial=True) if serializer.is_valid(): serializer.save() @@ -1578,10 +1517,7 @@ def upload_avatar(request): if serializer.is_valid(): # Handle avatar upload logic here # This would typically involve saving the file and updating the user profile - return Response( - {"message": "Avatar uploaded successfully"}, - status=status.HTTP_200_OK - ) + return Response({"detail": "Avatar uploaded successfully"}, status=status.HTTP_200_OK) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -1596,8 +1532,8 @@ def upload_avatar(request): "example": { "account": {"username": "user", "email": "user@example.com"}, "profile": {"display_name": "User"}, - "content": {"park_reviews": [], "lists": []} - } + "content": {"park_reviews": [], "lists": []}, + }, }, 401: {"description": "Authentication required"}, }, @@ -1612,10 +1548,7 @@ def export_user_data(request): return Response(export_data, status=status.HTTP_200_OK) except Exception as e: logger.error(f"Error exporting data for user {request.user.id}: {e}", exc_info=True) - return Response( - {"error": "Failed to generate data export"}, - status=status.HTTP_500_INTERNAL_SERVER_ERROR - ) + return Response({"detail": "Failed to generate data export"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) @extend_schema( @@ -1690,20 +1623,25 @@ def get_login_history(request): # Serialize results = [] for entry in entries: - results.append({ - "id": entry.id, - "ip_address": entry.ip_address, - "user_agent": entry.user_agent[:100] if entry.user_agent else None, # Truncate long user agents - "login_method": entry.login_method, - "login_method_display": dict(LoginHistory._meta.get_field('login_method').choices).get(entry.login_method, entry.login_method), - "login_timestamp": entry.login_timestamp.isoformat(), - "country": entry.country, - "city": entry.city, - "success": entry.success, - }) - - return Response({ - "results": results, - "count": len(results), - }) + results.append( + { + "id": entry.id, + "ip_address": entry.ip_address, + "user_agent": entry.user_agent[:100] if entry.user_agent else None, # Truncate long user agents + "login_method": entry.login_method, + "login_method_display": dict(LoginHistory._meta.get_field("login_method").choices).get( + entry.login_method, entry.login_method + ), + "login_timestamp": entry.login_timestamp.isoformat(), + "country": entry.country, + "city": entry.city, + "success": entry.success, + } + ) + return Response( + { + "results": results, + "count": len(results), + } + ) diff --git a/backend/apps/api/v1/accounts/views_credits.py b/backend/apps/api/v1/accounts/views_credits.py index b19faf7c..2a315567 100644 --- a/backend/apps/api/v1/accounts/views_credits.py +++ b/backend/apps/api/v1/accounts/views_credits.py @@ -15,22 +15,23 @@ class RideCreditViewSet(viewsets.ModelViewSet): ViewSet for managing Ride Credits. Allows users to track rides they have ridden. """ + serializer_class = RideCreditSerializer permission_classes = [permissions.IsAuthenticatedOrReadOnly] filter_backends = [DjangoFilterBackend, filters.OrderingFilter] - filterset_fields = ['user__username', 'ride__park__slug', 'ride__manufacturer__slug'] - ordering_fields = ['first_ridden_at', 'last_ridden_at', 'created_at', 'count', 'rating', 'display_order'] - ordering = ['display_order', '-last_ridden_at'] + filterset_fields = ["user__username", "ride__park__slug", "ride__manufacturer__slug"] + ordering_fields = ["first_ridden_at", "last_ridden_at", "created_at", "count", "rating", "display_order"] + ordering = ["display_order", "-last_ridden_at"] def get_queryset(self): """ Return ride credits. Optionally filter by user via query param ?user=username """ - queryset = RideCredit.objects.all().select_related('ride', 'ride__park', 'user') + queryset = RideCredit.objects.all().select_related("ride", "ride__park", "user") # Filter by user if provided - username = self.request.query_params.get('user') + username = self.request.query_params.get("user") if username: queryset = queryset.filter(user__username=username) @@ -40,64 +41,49 @@ class RideCreditViewSet(viewsets.ModelViewSet): """Associate the current user with the ride credit.""" serializer.save(user=self.request.user) - @action(detail=False, methods=['post'], permission_classes=[permissions.IsAuthenticated]) + @action(detail=False, methods=["post"], permission_classes=[permissions.IsAuthenticated]) @extend_schema( summary="Reorder ride credits", description="Bulk update the display order of ride credits. Send a list of {id, order} objects.", request={ - 'application/json': { - 'type': 'object', - 'properties': { - 'order': { - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'id': {'type': 'integer'}, - 'order': {'type': 'integer'} - }, - 'required': ['id', 'order'] - } + "application/json": { + "type": "object", + "properties": { + "order": { + "type": "array", + "items": { + "type": "object", + "properties": {"id": {"type": "integer"}, "order": {"type": "integer"}}, + "required": ["id", "order"], + }, } - } + }, } - } + }, ) def reorder(self, request): """ Bulk update display_order for multiple credits. Expects: {"order": [{"id": 1, "order": 0}, {"id": 2, "order": 1}, ...]} """ - order_data = request.data.get('order', []) + order_data = request.data.get("order", []) if not order_data: - return Response( - {'error': 'No order data provided'}, - status=status.HTTP_400_BAD_REQUEST - ) + return Response({"detail": "No order data provided"}, status=status.HTTP_400_BAD_REQUEST) # Validate that all credits belong to the current user - credit_ids = [item['id'] for item in order_data] - user_credits = RideCredit.objects.filter( - id__in=credit_ids, - user=request.user - ).values_list('id', flat=True) + credit_ids = [item["id"] for item in order_data] + user_credits = RideCredit.objects.filter(id__in=credit_ids, user=request.user).values_list("id", flat=True) if set(credit_ids) != set(user_credits): - return Response( - {'error': 'You can only reorder your own credits'}, - status=status.HTTP_403_FORBIDDEN - ) + return Response({"detail": "You can only reorder your own credits"}, status=status.HTTP_403_FORBIDDEN) # Bulk update in a transaction with transaction.atomic(): for item in order_data: - RideCredit.objects.filter( - id=item['id'], - user=request.user - ).update(display_order=item['order']) + RideCredit.objects.filter(id=item["id"], user=request.user).update(display_order=item["order"]) - return Response({'status': 'reordered', 'count': len(order_data)}) + return Response({"status": "reordered", "count": len(order_data)}) @extend_schema( summary="List ride credits", @@ -109,8 +95,7 @@ class RideCreditViewSet(viewsets.ModelViewSet): type=OpenApiTypes.STR, description="Filter by username", ), - ] + ], ) def list(self, request, *args, **kwargs): return super().list(request, *args, **kwargs) - diff --git a/backend/apps/api/v1/accounts/views_magic_link.py b/backend/apps/api/v1/accounts/views_magic_link.py index 806bf67e..65fdad4e 100644 --- a/backend/apps/api/v1/accounts/views_magic_link.py +++ b/backend/apps/api/v1/accounts/views_magic_link.py @@ -4,6 +4,7 @@ Magic Link (Login by Code) API views. Provides API endpoints for passwordless login via email code. Uses django-allauth's built-in login-by-code functionality. """ + from django.conf import settings from drf_spectacular.utils import OpenApiExample, extend_schema from rest_framework import status @@ -15,6 +16,7 @@ try: from allauth.account.internal.flows.login_by_code import perform_login_by_code, request_login_code from allauth.account.models import EmailAddress from allauth.account.utils import user_email # noqa: F401 - imported to verify availability + HAS_LOGIN_BY_CODE = True except ImportError: HAS_LOGIN_BY_CODE = False @@ -24,27 +26,19 @@ except ImportError: summary="Request magic link login code", description="Send a one-time login code to the user's email address.", request={ - 'application/json': { - 'type': 'object', - 'properties': { - 'email': {'type': 'string', 'format': 'email'} - }, - 'required': ['email'] + "application/json": { + "type": "object", + "properties": {"email": {"type": "string", "format": "email"}}, + "required": ["email"], } }, responses={ - 200: {'description': 'Login code sent successfully'}, - 400: {'description': 'Invalid email or feature disabled'}, + 200: {"description": "Login code sent successfully"}, + 400: {"description": "Invalid email or feature disabled"}, }, - examples=[ - OpenApiExample( - 'Request login code', - value={'email': 'user@example.com'}, - request_only=True - ) - ] + examples=[OpenApiExample("Request login code", value={"email": "user@example.com"}, request_only=True)], ) -@api_view(['POST']) +@api_view(["POST"]) @permission_classes([AllowAny]) def request_magic_link(request): """ @@ -55,25 +49,18 @@ def request_magic_link(request): 2. If the email exists, a code is sent 3. User enters the code to complete login """ - if not getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_ENABLED', False): - return Response( - {'error': 'Magic link login is not enabled'}, - status=status.HTTP_400_BAD_REQUEST - ) + if not getattr(settings, "ACCOUNT_LOGIN_BY_CODE_ENABLED", False): + return Response({"detail": "Magic link login is not enabled"}, status=status.HTTP_400_BAD_REQUEST) if not HAS_LOGIN_BY_CODE: return Response( - {'error': 'Login by code is not available in this version of allauth'}, - status=status.HTTP_400_BAD_REQUEST + {"detail": "Login by code is not available in this version of allauth"}, status=status.HTTP_400_BAD_REQUEST ) - email = request.data.get('email', '').lower().strip() + email = request.data.get("email", "").lower().strip() if not email: - return Response( - {'error': 'Email is required'}, - status=status.HTTP_400_BAD_REQUEST - ) + return Response({"detail": "Email is required"}, status=status.HTTP_400_BAD_REQUEST) # Check if email exists (don't reveal if it doesn't for security) try: @@ -83,40 +70,39 @@ def request_magic_link(request): # Request the login code request_login_code(request._request, user) - return Response({ - 'success': True, - 'message': 'If an account exists with this email, a login code has been sent.', - 'timeout': getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_TIMEOUT', 300) - }) + return Response( + { + "detail": "If an account exists with this email, a login code has been sent.", + "timeout": getattr(settings, "ACCOUNT_LOGIN_BY_CODE_TIMEOUT", 300), + } + ) except EmailAddress.DoesNotExist: # Don't reveal that the email doesn't exist - return Response({ - 'success': True, - 'message': 'If an account exists with this email, a login code has been sent.', - 'timeout': getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_TIMEOUT', 300) - }) + return Response( + { + "detail": "If an account exists with this email, a login code has been sent.", + "timeout": getattr(settings, "ACCOUNT_LOGIN_BY_CODE_TIMEOUT", 300), + } + ) @extend_schema( summary="Verify magic link code", description="Verify the login code and complete the login process.", request={ - 'application/json': { - 'type': 'object', - 'properties': { - 'email': {'type': 'string', 'format': 'email'}, - 'code': {'type': 'string'} - }, - 'required': ['email', 'code'] + "application/json": { + "type": "object", + "properties": {"email": {"type": "string", "format": "email"}, "code": {"type": "string"}}, + "required": ["email", "code"], } }, responses={ - 200: {'description': 'Login successful'}, - 400: {'description': 'Invalid or expired code'}, - } + 200: {"description": "Login successful"}, + 400: {"description": "Invalid or expired code"}, + }, ) -@api_view(['POST']) +@api_view(["POST"]) @permission_classes([AllowAny]) def verify_magic_link(request): """ @@ -124,26 +110,17 @@ def verify_magic_link(request): This is the second step of the magic link flow. """ - if not getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_ENABLED', False): - return Response( - {'error': 'Magic link login is not enabled'}, - status=status.HTTP_400_BAD_REQUEST - ) + if not getattr(settings, "ACCOUNT_LOGIN_BY_CODE_ENABLED", False): + return Response({"detail": "Magic link login is not enabled"}, status=status.HTTP_400_BAD_REQUEST) if not HAS_LOGIN_BY_CODE: - return Response( - {'error': 'Login by code is not available'}, - status=status.HTTP_400_BAD_REQUEST - ) + return Response({"detail": "Login by code is not available"}, status=status.HTTP_400_BAD_REQUEST) - email = request.data.get('email', '').lower().strip() - code = request.data.get('code', '').strip() + email = request.data.get("email", "").lower().strip() + code = request.data.get("code", "").strip() if not email or not code: - return Response( - {'error': 'Email and code are required'}, - status=status.HTTP_400_BAD_REQUEST - ) + return Response({"detail": "Email and code are required"}, status=status.HTTP_400_BAD_REQUEST) try: email_address = EmailAddress.objects.get(email__iexact=email, verified=True) @@ -153,28 +130,20 @@ def verify_magic_link(request): success = perform_login_by_code(request._request, user, code) if success: - return Response({ - 'success': True, - 'message': 'Login successful', - 'user': { - 'id': user.id, - 'username': user.username, - 'email': user.email + return Response( + { + "detail": "Login successful", + "user": {"id": user.id, "username": user.username, "email": user.email}, } - }) + ) else: return Response( - {'error': 'Invalid or expired code. Please request a new one.'}, - status=status.HTTP_400_BAD_REQUEST + {"detail": "Invalid or expired code. Please request a new one."}, status=status.HTTP_400_BAD_REQUEST ) except EmailAddress.DoesNotExist: - return Response( - {'error': 'Invalid email or code'}, - status=status.HTTP_400_BAD_REQUEST - ) + return Response({"detail": "Invalid email or code"}, status=status.HTTP_400_BAD_REQUEST) except Exception: return Response( - {'error': 'Invalid or expired code. Please request a new one.'}, - status=status.HTTP_400_BAD_REQUEST + {"detail": "Invalid or expired code. Please request a new one."}, status=status.HTTP_400_BAD_REQUEST ) diff --git a/backend/apps/api/v1/auth/mfa.py b/backend/apps/api/v1/auth/mfa.py index 9e046630..ca2c88f3 100644 --- a/backend/apps/api/v1/auth/mfa.py +++ b/backend/apps/api/v1/auth/mfa.py @@ -17,6 +17,7 @@ from rest_framework.response import Response try: import qrcode + HAS_QRCODE = True except ImportError: HAS_QRCODE = False @@ -59,12 +60,14 @@ def get_mfa_status(request): except Authenticator.DoesNotExist: pass - return Response({ - "mfa_enabled": totp_enabled, - "totp_enabled": totp_enabled, - "recovery_codes_enabled": recovery_enabled, - "recovery_codes_count": recovery_count, - }) + return Response( + { + "mfa_enabled": totp_enabled, + "totp_enabled": totp_enabled, + "recovery_codes_enabled": recovery_enabled, + "recovery_codes_count": recovery_count, + } + ) @extend_schema( @@ -110,11 +113,13 @@ def setup_totp(request): # Store secret in session for later verification request.session["pending_totp_secret"] = secret - return Response({ - "secret": secret, - "provisioning_uri": uri, - "qr_code_base64": qr_code_base64, - }) + return Response( + { + "secret": secret, + "provisioning_uri": uri, + "qr_code_base64": qr_code_base64, + } + ) @extend_schema( @@ -138,8 +143,7 @@ def setup_totp(request): 200: { "description": "TOTP activated successfully", "example": { - "success": True, - "message": "Two-factor authentication enabled", + "detail": "Two-factor authentication enabled", "recovery_codes": ["ABCD1234", "EFGH5678"], }, }, @@ -160,7 +164,7 @@ def activate_totp(request): if not code: return Response( - {"success": False, "error": "Verification code is required"}, + {"detail": "Verification code is required"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -168,21 +172,21 @@ def activate_totp(request): secret = request.session.get("pending_totp_secret") if not secret: return Response( - {"success": False, "error": "No pending TOTP setup. Please start setup again."}, + {"detail": "No pending TOTP setup. Please start setup again."}, status=status.HTTP_400_BAD_REQUEST, ) # Verify the code if not totp_auth.validate_totp_code(secret, code): return Response( - {"success": False, "error": "Invalid verification code"}, + {"detail": "Invalid verification code"}, status=status.HTTP_400_BAD_REQUEST, ) # Check if already has TOTP if Authenticator.objects.filter(user=user, type=Authenticator.Type.TOTP).exists(): return Response( - {"success": False, "error": "TOTP is already enabled"}, + {"detail": "TOTP is already enabled"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -204,11 +208,12 @@ def activate_totp(request): # Clear session del request.session["pending_totp_secret"] - return Response({ - "success": True, - "message": "Two-factor authentication enabled", - "recovery_codes": codes, - }) + return Response( + { + "detail": "Two-factor authentication enabled", + "recovery_codes": codes, + } + ) @extend_schema( @@ -230,7 +235,7 @@ def activate_totp(request): responses={ 200: { "description": "TOTP disabled", - "example": {"success": True, "message": "Two-factor authentication disabled"}, + "example": {"detail": "Two-factor authentication disabled"}, }, 400: {"description": "Invalid password or MFA not enabled"}, }, @@ -248,26 +253,26 @@ def deactivate_totp(request): # Verify password if not user.check_password(password): return Response( - {"success": False, "error": "Invalid password"}, + {"detail": "Invalid password"}, status=status.HTTP_400_BAD_REQUEST, ) # Remove TOTP and recovery codes deleted_count, _ = Authenticator.objects.filter( - user=user, - type__in=[Authenticator.Type.TOTP, Authenticator.Type.RECOVERY_CODES] + user=user, type__in=[Authenticator.Type.TOTP, Authenticator.Type.RECOVERY_CODES] ).delete() if deleted_count == 0: return Response( - {"success": False, "error": "Two-factor authentication is not enabled"}, + {"detail": "Two-factor authentication is not enabled"}, status=status.HTTP_400_BAD_REQUEST, ) - return Response({ - "success": True, - "message": "Two-factor authentication disabled", - }) + return Response( + { + "detail": "Two-factor authentication disabled", + } + ) @extend_schema( @@ -277,9 +282,7 @@ def deactivate_totp(request): request={ "application/json": { "type": "object", - "properties": { - "code": {"type": "string", "description": "6-digit TOTP code"} - }, + "properties": {"code": {"type": "string", "description": "6-digit TOTP code"}}, "required": ["code"], } }, @@ -301,7 +304,7 @@ def verify_totp(request): if not code: return Response( - {"success": False, "error": "Verification code is required"}, + {"detail": "Verification code is required"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -313,12 +316,12 @@ def verify_totp(request): return Response({"success": True}) else: return Response( - {"success": False, "error": "Invalid verification code"}, + {"detail": "Invalid verification code"}, status=status.HTTP_400_BAD_REQUEST, ) except Authenticator.DoesNotExist: return Response( - {"success": False, "error": "TOTP is not enabled"}, + {"detail": "TOTP is not enabled"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -330,9 +333,7 @@ def verify_totp(request): request={ "application/json": { "type": "object", - "properties": { - "password": {"type": "string", "description": "Current password"} - }, + "properties": {"password": {"type": "string", "description": "Current password"}}, "required": ["password"], } }, @@ -358,14 +359,14 @@ def regenerate_recovery_codes(request): # Verify password if not user.check_password(password): return Response( - {"success": False, "error": "Invalid password"}, + {"detail": "Invalid password"}, status=status.HTTP_400_BAD_REQUEST, ) # Check if TOTP is enabled if not Authenticator.objects.filter(user=user, type=Authenticator.Type.TOTP).exists(): return Response( - {"success": False, "error": "Two-factor authentication is not enabled"}, + {"detail": "Two-factor authentication is not enabled"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -379,7 +380,9 @@ def regenerate_recovery_codes(request): defaults={"data": {"codes": codes}}, ) - return Response({ - "success": True, - "recovery_codes": codes, - }) + return Response( + { + "success": True, + "recovery_codes": codes, + } + ) diff --git a/backend/apps/api/v1/auth/serializers.py b/backend/apps/api/v1/auth/serializers.py index ffd913bd..3f99577d 100644 --- a/backend/apps/api/v1/auth/serializers.py +++ b/backend/apps/api/v1/auth/serializers.py @@ -38,8 +38,6 @@ class ModelChoices: """Model choices utility class.""" - - # === AUTHENTICATION SERIALIZERS === @@ -95,12 +93,8 @@ class UserOutputSerializer(serializers.ModelSerializer): class LoginInputSerializer(serializers.Serializer): """Input serializer for user login.""" - username = serializers.CharField( - max_length=254, help_text="Username or email address" - ) - password = serializers.CharField( - max_length=128, style={"input_type": "password"}, trim_whitespace=False - ) + username = serializers.CharField(max_length=254, help_text="Username or email address") + password = serializers.CharField(max_length=128, style={"input_type": "password"}, trim_whitespace=False) def validate(self, attrs): username = attrs.get("username") @@ -129,9 +123,7 @@ class SignupInputSerializer(serializers.ModelSerializer): validators=[validate_password], style={"input_type": "password"}, ) - password_confirm = serializers.CharField( - write_only=True, style={"input_type": "password"} - ) + password_confirm = serializers.CharField(write_only=True, style={"input_type": "password"}) class Meta: model = UserModel @@ -158,9 +150,7 @@ class SignupInputSerializer(serializers.ModelSerializer): def validate_username(self, value): """Validate username is unique.""" if UserModel.objects.filter(username=value).exists(): - raise serializers.ValidationError( - "A user with this username already exists." - ) + raise serializers.ValidationError("A user with this username already exists.") return value def validate(self, attrs): @@ -169,9 +159,7 @@ class SignupInputSerializer(serializers.ModelSerializer): password_confirm = attrs.get("password_confirm") if password != password_confirm: - raise serializers.ValidationError( - {"password_confirm": "Passwords do not match."} - ) + raise serializers.ValidationError({"password_confirm": "Passwords do not match."}) return attrs @@ -204,8 +192,7 @@ class SignupInputSerializer(serializers.ModelSerializer): # Create or update email verification record verification, created = EmailVerification.objects.get_or_create( - user=user, - defaults={'token': get_random_string(64)} + user=user, defaults={"token": get_random_string(64)} ) if not created: @@ -214,14 +201,12 @@ class SignupInputSerializer(serializers.ModelSerializer): verification.save() # Get current site from request context - request = self.context.get('request') + request = self.context.get("request") if request: site = get_current_site(request._request) # Build verification URL - verification_url = request.build_absolute_uri( - f"/api/v1/auth/verify-email/{verification.token}/" - ) + verification_url = request.build_absolute_uri(f"/api/v1/auth/verify-email/{verification.token}/") # Send verification email try: @@ -243,13 +228,11 @@ The ThrillWiki Team ) # Log the ForwardEmail email ID from the response - email_id = response.get('id') if response else None + email_id = response.get("id") if response else None if email_id: - logger.info( - f"Verification email sent successfully to {user.email}. ForwardEmail ID: {email_id}") + logger.info(f"Verification email sent successfully to {user.email}. ForwardEmail ID: {email_id}") else: - logger.info( - f"Verification email sent successfully to {user.email}. No email ID in response.") + logger.info(f"Verification email sent successfully to {user.email}. No email ID in response.") except Exception as e: # Log the error but don't fail registration @@ -312,17 +295,13 @@ class PasswordResetOutputSerializer(serializers.Serializer): class PasswordChangeInputSerializer(serializers.Serializer): """Input serializer for password change.""" - old_password = serializers.CharField( - max_length=128, style={"input_type": "password"} - ) + old_password = serializers.CharField(max_length=128, style={"input_type": "password"}) new_password = serializers.CharField( max_length=128, validators=[validate_password], style={"input_type": "password"}, ) - new_password_confirm = serializers.CharField( - max_length=128, style={"input_type": "password"} - ) + new_password_confirm = serializers.CharField(max_length=128, style={"input_type": "password"}) def validate_old_password(self, value): """Validate old password is correct.""" @@ -337,9 +316,7 @@ class PasswordChangeInputSerializer(serializers.Serializer): new_password_confirm = attrs.get("new_password_confirm") if new_password != new_password_confirm: - raise serializers.ValidationError( - {"new_password_confirm": "New passwords do not match."} - ) + raise serializers.ValidationError({"new_password_confirm": "New passwords do not match."}) return attrs @@ -471,6 +448,3 @@ class UserProfileUpdateInputSerializer(serializers.Serializer): dark_ride_credits = serializers.IntegerField(required=False) flat_ride_credits = serializers.IntegerField(required=False) water_ride_credits = serializers.IntegerField(required=False) - - - diff --git a/backend/apps/api/v1/auth/serializers_package/__init__.py b/backend/apps/api/v1/auth/serializers_package/__init__.py index caf8e0d9..799ef060 100644 --- a/backend/apps/api/v1/auth/serializers_package/__init__.py +++ b/backend/apps/api/v1/auth/serializers_package/__init__.py @@ -19,13 +19,13 @@ from .social import ( __all__ = [ # Social authentication serializers - 'ConnectedProviderSerializer', - 'AvailableProviderSerializer', - 'SocialAuthStatusSerializer', - 'ConnectProviderInputSerializer', - 'ConnectProviderOutputSerializer', - 'DisconnectProviderOutputSerializer', - 'SocialProviderListOutputSerializer', - 'ConnectedProvidersListOutputSerializer', - 'SocialProviderErrorSerializer', + "ConnectedProviderSerializer", + "AvailableProviderSerializer", + "SocialAuthStatusSerializer", + "ConnectProviderInputSerializer", + "ConnectProviderOutputSerializer", + "DisconnectProviderOutputSerializer", + "SocialProviderListOutputSerializer", + "ConnectedProvidersListOutputSerializer", + "SocialProviderErrorSerializer", ] diff --git a/backend/apps/api/v1/auth/serializers_package/social.py b/backend/apps/api/v1/auth/serializers_package/social.py index dfe855b9..4a90d485 100644 --- a/backend/apps/api/v1/auth/serializers_package/social.py +++ b/backend/apps/api/v1/auth/serializers_package/social.py @@ -14,74 +14,36 @@ User = get_user_model() class ConnectedProviderSerializer(serializers.Serializer): """Serializer for connected social provider information.""" - provider = serializers.CharField( - help_text="Provider ID (e.g., 'google', 'discord')" - ) - provider_name = serializers.CharField( - help_text="Human-readable provider name" - ) - uid = serializers.CharField( - help_text="User ID on the social provider" - ) - date_joined = serializers.DateTimeField( - help_text="When this provider was connected" - ) - can_disconnect = serializers.BooleanField( - help_text="Whether this provider can be safely disconnected" - ) + provider = serializers.CharField(help_text="Provider ID (e.g., 'google', 'discord')") + provider_name = serializers.CharField(help_text="Human-readable provider name") + uid = serializers.CharField(help_text="User ID on the social provider") + date_joined = serializers.DateTimeField(help_text="When this provider was connected") + can_disconnect = serializers.BooleanField(help_text="Whether this provider can be safely disconnected") disconnect_reason = serializers.CharField( - allow_null=True, - required=False, - help_text="Reason why provider cannot be disconnected (if applicable)" - ) - extra_data = serializers.JSONField( - required=False, - help_text="Additional data from the social provider" + allow_null=True, required=False, help_text="Reason why provider cannot be disconnected (if applicable)" ) + extra_data = serializers.JSONField(required=False, help_text="Additional data from the social provider") class AvailableProviderSerializer(serializers.Serializer): """Serializer for available social provider information.""" - id = serializers.CharField( - help_text="Provider ID (e.g., 'google', 'discord')" - ) - name = serializers.CharField( - help_text="Human-readable provider name" - ) - auth_url = serializers.URLField( - help_text="URL to initiate authentication with this provider" - ) - connect_url = serializers.URLField( - help_text="API URL to connect this provider" - ) + id = serializers.CharField(help_text="Provider ID (e.g., 'google', 'discord')") + name = serializers.CharField(help_text="Human-readable provider name") + auth_url = serializers.URLField(help_text="URL to initiate authentication with this provider") + connect_url = serializers.URLField(help_text="API URL to connect this provider") class SocialAuthStatusSerializer(serializers.Serializer): """Serializer for comprehensive social authentication status.""" - user_id = serializers.IntegerField( - help_text="User's ID" - ) - username = serializers.CharField( - help_text="User's username" - ) - email = serializers.EmailField( - help_text="User's email address" - ) - has_password_auth = serializers.BooleanField( - help_text="Whether user has email/password authentication set up" - ) - connected_providers = ConnectedProviderSerializer( - many=True, - help_text="List of connected social providers" - ) - total_auth_methods = serializers.IntegerField( - help_text="Total number of authentication methods available" - ) - can_disconnect_any = serializers.BooleanField( - help_text="Whether user can safely disconnect any provider" - ) + user_id = serializers.IntegerField(help_text="User's ID") + username = serializers.CharField(help_text="User's username") + email = serializers.EmailField(help_text="User's email address") + has_password_auth = serializers.BooleanField(help_text="Whether user has email/password authentication set up") + connected_providers = ConnectedProviderSerializer(many=True, help_text="List of connected social providers") + total_auth_methods = serializers.IntegerField(help_text="Total number of authentication methods available") + can_disconnect_any = serializers.BooleanField(help_text="Whether user can safely disconnect any provider") requires_password_setup = serializers.BooleanField( help_text="Whether user needs to set up password before disconnecting" ) @@ -90,9 +52,7 @@ class SocialAuthStatusSerializer(serializers.Serializer): class ConnectProviderInputSerializer(serializers.Serializer): """Serializer for social provider connection requests.""" - provider = serializers.CharField( - help_text="Provider ID to connect (e.g., 'google', 'discord')" - ) + provider = serializers.CharField(help_text="Provider ID to connect (e.g., 'google', 'discord')") def validate_provider(self, value): """Validate that the provider is supported and configured.""" @@ -108,93 +68,51 @@ class ConnectProviderInputSerializer(serializers.Serializer): class ConnectProviderOutputSerializer(serializers.Serializer): """Serializer for social provider connection responses.""" - success = serializers.BooleanField( - help_text="Whether the connection was successful" - ) - message = serializers.CharField( - help_text="Success or error message" - ) - provider = serializers.CharField( - help_text="Provider that was connected" - ) - auth_url = serializers.URLField( - required=False, - help_text="URL to complete the connection process" - ) + success = serializers.BooleanField(help_text="Whether the connection was successful") + message = serializers.CharField(help_text="Success or error message") + provider = serializers.CharField(help_text="Provider that was connected") + auth_url = serializers.URLField(required=False, help_text="URL to complete the connection process") class DisconnectProviderOutputSerializer(serializers.Serializer): """Serializer for social provider disconnection responses.""" - success = serializers.BooleanField( - help_text="Whether the disconnection was successful" - ) - message = serializers.CharField( - help_text="Success or error message" - ) - provider = serializers.CharField( - help_text="Provider that was disconnected" - ) + success = serializers.BooleanField(help_text="Whether the disconnection was successful") + message = serializers.CharField(help_text="Success or error message") + provider = serializers.CharField(help_text="Provider that was disconnected") remaining_providers = serializers.ListField( - child=serializers.CharField(), - help_text="List of remaining connected providers" - ) - has_password_auth = serializers.BooleanField( - help_text="Whether user still has password authentication" + child=serializers.CharField(), help_text="List of remaining connected providers" ) + has_password_auth = serializers.BooleanField(help_text="Whether user still has password authentication") suggestions = serializers.ListField( child=serializers.CharField(), required=False, - help_text="Suggestions for maintaining account access (if applicable)" + help_text="Suggestions for maintaining account access (if applicable)", ) class SocialProviderListOutputSerializer(serializers.Serializer): """Serializer for listing available social providers.""" - available_providers = AvailableProviderSerializer( - many=True, - help_text="List of available social providers" - ) - count = serializers.IntegerField( - help_text="Number of available providers" - ) + available_providers = AvailableProviderSerializer(many=True, help_text="List of available social providers") + count = serializers.IntegerField(help_text="Number of available providers") class ConnectedProvidersListOutputSerializer(serializers.Serializer): """Serializer for listing connected social providers.""" - connected_providers = ConnectedProviderSerializer( - many=True, - help_text="List of connected social providers" - ) - count = serializers.IntegerField( - help_text="Number of connected providers" - ) - has_password_auth = serializers.BooleanField( - help_text="Whether user has password authentication" - ) - can_disconnect_any = serializers.BooleanField( - help_text="Whether user can safely disconnect any provider" - ) + connected_providers = ConnectedProviderSerializer(many=True, help_text="List of connected social providers") + count = serializers.IntegerField(help_text="Number of connected providers") + has_password_auth = serializers.BooleanField(help_text="Whether user has password authentication") + can_disconnect_any = serializers.BooleanField(help_text="Whether user can safely disconnect any provider") class SocialProviderErrorSerializer(serializers.Serializer): """Serializer for social provider error responses.""" - error = serializers.CharField( - help_text="Error message" - ) - code = serializers.CharField( - required=False, - help_text="Error code for programmatic handling" - ) + error = serializers.CharField(help_text="Error message") + code = serializers.CharField(required=False, help_text="Error code for programmatic handling") suggestions = serializers.ListField( - child=serializers.CharField(), - required=False, - help_text="Suggestions for resolving the error" - ) - provider = serializers.CharField( - required=False, - help_text="Provider related to the error (if applicable)" + child=serializers.CharField(), required=False, help_text="Suggestions for resolving the error" ) + provider = serializers.CharField(required=False, help_text="Provider related to the error (if applicable)") diff --git a/backend/apps/api/v1/auth/urls.py b/backend/apps/api/v1/auth/urls.py index 4fc5faf0..e73b5e99 100644 --- a/backend/apps/api/v1/auth/urls.py +++ b/backend/apps/api/v1/auth/urls.py @@ -36,13 +36,10 @@ urlpatterns = [ path("signup/", SignupAPIView.as_view(), name="auth-signup"), path("logout/", LogoutAPIView.as_view(), name="auth-logout"), path("user/", CurrentUserAPIView.as_view(), name="auth-current-user"), - # JWT token management path("token/refresh/", TokenRefreshView.as_view(), name="auth-token-refresh"), - # Social authentication endpoints (dj-rest-auth) path("social/", include("dj_rest_auth.registration.urls")), - path( "password/reset/", PasswordResetAPIView.as_view(), @@ -58,7 +55,6 @@ urlpatterns = [ SocialProvidersAPIView.as_view(), name="auth-social-providers", ), - # Social provider management endpoints path( "social/providers/available/", @@ -85,9 +81,7 @@ urlpatterns = [ SocialAuthStatusAPIView.as_view(), name="auth-social-status", ), - path("status/", AuthStatusAPIView.as_view(), name="auth-status"), - # Email verification endpoints path( "verify-email//", @@ -99,7 +93,6 @@ urlpatterns = [ ResendVerificationAPIView.as_view(), name="auth-resend-verification", ), - # MFA (Multi-Factor Authentication) endpoints path("mfa/status/", mfa_views.get_mfa_status, name="auth-mfa-status"), path("mfa/totp/setup/", mfa_views.setup_totp, name="auth-mfa-totp-setup"), diff --git a/backend/apps/api/v1/auth/views.py b/backend/apps/api/v1/auth/views.py index 028cf70c..33327944 100644 --- a/backend/apps/api/v1/auth/views.py +++ b/backend/apps/api/v1/auth/views.py @@ -85,9 +85,7 @@ def _get_underlying_request(request: Request) -> HttpRequest: # Helper: encapsulate user lookup + authenticate to reduce complexity in view -def _authenticate_user_by_lookup( - email_or_username: str, password: str, request: Request -) -> UserModel | None: +def _authenticate_user_by_lookup(email_or_username: str, password: str, request: Request) -> UserModel | None: """ Try a single optimized query to find a user by email OR username then authenticate. Returns authenticated user or None. @@ -154,7 +152,7 @@ class LoginAPIView(APIView): # instantiate mixin before calling to avoid type-mismatch in static analysis TurnstileMixin().validate_turnstile(request) except ValidationError as e: - return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST) except Exception: # If mixin doesn't do anything, continue pass @@ -168,7 +166,7 @@ class LoginAPIView(APIView): if not email_or_username or not password: return Response( - {"error": "username and password are required"}, + {"detail": "username and password are required"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -177,8 +175,7 @@ class LoginAPIView(APIView): if user: if getattr(user, "is_active", False): # pass a real HttpRequest to Django login with backend specified - login(_get_underlying_request(request), user, - backend='django.contrib.auth.backends.ModelBackend') + login(_get_underlying_request(request), user, backend="django.contrib.auth.backends.ModelBackend") # Generate JWT tokens from rest_framework_simplejwt.tokens import RefreshToken @@ -191,22 +188,22 @@ class LoginAPIView(APIView): "access": str(access_token), "refresh": str(refresh), "user": user, - "message": "Login successful", + "detail": "Login successful", } ) return Response(response_serializer.data) else: return Response( { - "error": "Email verification required", - "message": "Please verify your email address before logging in. Check your email for a verification link.", - "email_verification_required": True + "detail": "Please verify your email address before logging in. Check your email for a verification link.", + "code": "EMAIL_VERIFICATION_REQUIRED", + "email_verification_required": True, }, status=status.HTTP_400_BAD_REQUEST, ) else: return Response( - {"error": "Invalid credentials"}, + {"detail": "Invalid credentials"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -237,7 +234,7 @@ class SignupAPIView(APIView): # instantiate mixin before calling to avoid type-mismatch in static analysis TurnstileMixin().validate_turnstile(request) except ValidationError as e: - return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST) except Exception: # If mixin doesn't do anything, continue pass @@ -252,7 +249,7 @@ class SignupAPIView(APIView): "access": None, "refresh": None, "user": user, - "message": "Registration successful. Please check your email to verify your account.", + "detail": "Registration successful. Please check your email to verify your account.", "email_verification_required": True, } ) @@ -282,18 +279,18 @@ class LogoutAPIView(APIView): try: # Get refresh token from request data with proper type handling refresh_token = None - if hasattr(request, 'data') and request.data is not None: - data = getattr(request, 'data', {}) - if hasattr(data, 'get'): + if hasattr(request, "data") and request.data is not None: + data = getattr(request, "data", {}) + if hasattr(data, "get"): refresh_token = data.get("refresh") if refresh_token and isinstance(refresh_token, str): # Blacklist the refresh token from rest_framework_simplejwt.tokens import RefreshToken + try: # Create RefreshToken from string and blacklist it - refresh_token_obj = RefreshToken( - refresh_token) # type: ignore[arg-type] + refresh_token_obj = RefreshToken(refresh_token) # type: ignore[arg-type] refresh_token_obj.blacklist() except Exception: # Token might be invalid or already blacklisted @@ -306,14 +303,10 @@ class LogoutAPIView(APIView): # Logout from session using the underlying HttpRequest logout(_get_underlying_request(request)) - response_serializer = LogoutOutputSerializer( - {"message": "Logout successful"} - ) + response_serializer = LogoutOutputSerializer({"detail": "Logout successful"}) return Response(response_serializer.data) except Exception: - return Response( - {"error": "Logout failed"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR - ) + return Response({"detail": "Logout failed"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) @extend_schema_view( @@ -357,15 +350,11 @@ class PasswordResetAPIView(APIView): serializer_class = PasswordResetInputSerializer def post(self, request: Request) -> Response: - serializer = PasswordResetInputSerializer( - data=request.data, context={"request": request} - ) + serializer = PasswordResetInputSerializer(data=request.data, context={"request": request}) if serializer.is_valid(): serializer.save() - response_serializer = PasswordResetOutputSerializer( - {"detail": "Password reset email sent"} - ) + response_serializer = PasswordResetOutputSerializer({"detail": "Password reset email sent"}) return Response(response_serializer.data) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -391,15 +380,11 @@ class PasswordChangeAPIView(APIView): serializer_class = PasswordChangeInputSerializer def post(self, request: Request) -> Response: - serializer = PasswordChangeInputSerializer( - data=request.data, context={"request": request} - ) + serializer = PasswordChangeInputSerializer(data=request.data, context={"request": request}) if serializer.is_valid(): serializer.save() - response_serializer = PasswordChangeOutputSerializer( - {"detail": "Password changed successfully"} - ) + response_serializer = PasswordChangeOutputSerializer({"detail": "Password changed successfully"}) return Response(response_serializer.data) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -443,13 +428,9 @@ class SocialProvidersAPIView(APIView): for social_app in social_apps: try: - provider_name = ( - social_app.name or getattr(social_app, "provider", "").title() - ) + provider_name = social_app.name or getattr(social_app, "provider", "").title() - auth_url = request.build_absolute_uri( - f"/accounts/{social_app.provider}/login/" - ) + auth_url = request.build_absolute_uri(f"/accounts/{social_app.provider}/login/") providers_list.append( { @@ -532,7 +513,7 @@ class AvailableProvidersAPIView(APIView): "name": "Discord", "login_url": "/auth/social/discord/", "connect_url": "/auth/social/connect/discord/", - } + }, ] serializer = AvailableProviderSerializer(providers, many=True) @@ -585,31 +566,29 @@ class ConnectProviderAPIView(APIView): def post(self, request: Request, provider: str) -> Response: # Validate provider - if provider not in ['google', 'discord']: + if provider not in ["google", "discord"]: return Response( { - "success": False, - "error": "INVALID_PROVIDER", - "message": f"Provider '{provider}' is not supported", - "suggestions": ["Use 'google' or 'discord'"] + "detail": f"Provider '{provider}' is not supported", + "code": "INVALID_PROVIDER", + "suggestions": ["Use 'google' or 'discord'"], }, - status=status.HTTP_400_BAD_REQUEST + status=status.HTTP_400_BAD_REQUEST, ) serializer = ConnectProviderInputSerializer(data=request.data) if not serializer.is_valid(): return Response( { - "success": False, - "error": "VALIDATION_ERROR", - "message": "Invalid request data", + "detail": "Invalid request data", + "code": "VALIDATION_ERROR", "details": serializer.errors, - "suggestions": ["Provide a valid access_token"] + "suggestions": ["Provide a valid access_token"], }, - status=status.HTTP_400_BAD_REQUEST + status=status.HTTP_400_BAD_REQUEST, ) - access_token = serializer.validated_data['access_token'] + access_token = serializer.validated_data["access_token"] try: service = SocialProviderService() @@ -622,14 +601,14 @@ class ConnectProviderAPIView(APIView): return Response( { "success": False, - "error": "CONNECTION_FAILED", + "detail": "CONNECTION_FAILED", "message": str(e), "suggestions": [ "Verify the access token is valid", - "Ensure the provider account is not already connected to another user" - ] + "Ensure the provider account is not already connected to another user", + ], }, - status=status.HTTP_400_BAD_REQUEST + status=status.HTTP_400_BAD_REQUEST, ) @@ -653,35 +632,33 @@ class DisconnectProviderAPIView(APIView): def post(self, request: Request, provider: str) -> Response: # Validate provider - if provider not in ['google', 'discord']: + if provider not in ["google", "discord"]: return Response( { - "success": False, - "error": "INVALID_PROVIDER", - "message": f"Provider '{provider}' is not supported", - "suggestions": ["Use 'google' or 'discord'"] + "detail": f"Provider '{provider}' is not supported", + "code": "INVALID_PROVIDER", + "suggestions": ["Use 'google' or 'discord'"], }, - status=status.HTTP_400_BAD_REQUEST + status=status.HTTP_400_BAD_REQUEST, ) try: service = SocialProviderService() # Check if disconnection is safe - can_disconnect, reason = service.can_disconnect_provider( - request.user, provider) + can_disconnect, reason = service.can_disconnect_provider(request.user, provider) if not can_disconnect: return Response( { "success": False, - "error": "UNSAFE_DISCONNECTION", + "detail": "UNSAFE_DISCONNECTION", "message": reason, "suggestions": [ "Set up email/password authentication before disconnecting", - "Connect another social provider before disconnecting this one" - ] + "Connect another social provider before disconnecting this one", + ], }, - status=status.HTTP_400_BAD_REQUEST + status=status.HTTP_400_BAD_REQUEST, ) # Perform disconnection @@ -694,14 +671,14 @@ class DisconnectProviderAPIView(APIView): return Response( { "success": False, - "error": "DISCONNECTION_FAILED", + "detail": "DISCONNECTION_FAILED", "message": str(e), "suggestions": [ "Verify the provider is currently connected", - "Ensure you have alternative authentication methods" - ] + "Ensure you have alternative authentication methods", + ], }, - status=status.HTTP_400_BAD_REQUEST + status=status.HTTP_400_BAD_REQUEST, ) @@ -755,7 +732,7 @@ class EmailVerificationAPIView(APIView): from apps.accounts.models import EmailVerification try: - verification = EmailVerification.objects.select_related('user').get(token=token) + verification = EmailVerification.objects.select_related("user").get(token=token) user = verification.user # Activate the user @@ -765,16 +742,10 @@ class EmailVerificationAPIView(APIView): # Delete the verification record verification.delete() - return Response({ - "message": "Email verified successfully. You can now log in.", - "success": True - }) + return Response({"detail": "Email verified successfully. You can now log in.", "success": True}) except EmailVerification.DoesNotExist: - return Response( - {"error": "Invalid or expired verification token"}, - status=status.HTTP_404_NOT_FOUND - ) + return Response({"detail": "Invalid or expired verification token"}, status=status.HTTP_404_NOT_FOUND) @extend_schema_view( @@ -803,27 +774,20 @@ class ResendVerificationAPIView(APIView): from apps.accounts.models import EmailVerification - email = request.data.get('email') + email = request.data.get("email") if not email: - return Response( - {"error": "Email address is required"}, - status=status.HTTP_400_BAD_REQUEST - ) + return Response({"detail": "Email address is required"}, status=status.HTTP_400_BAD_REQUEST) try: user = UserModel.objects.get(email__iexact=email.strip().lower()) # Don't resend if user is already active if user.is_active: - return Response( - {"error": "Email is already verified"}, - status=status.HTTP_400_BAD_REQUEST - ) + return Response({"detail": "Email is already verified"}, status=status.HTTP_400_BAD_REQUEST) # Create or update verification record verification, created = EmailVerification.objects.get_or_create( - user=user, - defaults={'token': get_random_string(64)} + user=user, defaults={"token": get_random_string(64)} ) if not created: @@ -833,9 +797,7 @@ class ResendVerificationAPIView(APIView): # Send verification email site = get_current_site(_get_underlying_request(request)) - verification_url = request.build_absolute_uri( - f"/api/v1/auth/verify-email/{verification.token}/" - ) + verification_url = request.build_absolute_uri(f"/api/v1/auth/verify-email/{verification.token}/") try: EmailService.send_email( @@ -855,27 +817,21 @@ The ThrillWiki Team site=site, ) - return Response({ - "message": "Verification email sent successfully", - "success": True - }) + return Response({"detail": "Verification email sent successfully", "success": True}) except Exception as e: import logging + logger = logging.getLogger(__name__) logger.error(f"Failed to send verification email to {user.email}: {e}") return Response( - {"error": "Failed to send verification email"}, - status=status.HTTP_500_INTERNAL_SERVER_ERROR + {"detail": "Failed to send verification email"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR ) except UserModel.DoesNotExist: # Don't reveal whether email exists - return Response({ - "message": "If the email exists, a verification email has been sent", - "success": True - }) + return Response({"detail": "If the email exists, a verification email has been sent", "success": True}) # Note: User Profile, Top List, and Top List Item ViewSets are now handled diff --git a/backend/apps/api/v1/core/views.py b/backend/apps/api/v1/core/views.py index e0871bf9..96b6d7a7 100644 --- a/backend/apps/api/v1/core/views.py +++ b/backend/apps/api/v1/core/views.py @@ -8,7 +8,6 @@ Caching Strategy: - EntityNotFoundView: No caching - POST requests with context-specific data """ - import contextlib from drf_spectacular.utils import extend_schema @@ -82,9 +81,7 @@ class EntityFuzzySearchView(APIView): try: # Parse request data query = request.data.get("query", "").strip() - entity_types_raw = request.data.get( - "entity_types", ["park", "ride", "company"] - ) + entity_types_raw = request.data.get("entity_types", ["park", "ride", "company"]) include_suggestions = request.data.get("include_suggestions", True) # Validate query @@ -92,7 +89,7 @@ class EntityFuzzySearchView(APIView): return Response( { "success": False, - "error": "Query must be at least 2 characters long", + "detail": "Query must be at least 2 characters long", "code": "INVALID_QUERY", }, status=status.HTTP_400_BAD_REQUEST, @@ -120,9 +117,7 @@ class EntityFuzzySearchView(APIView): "query": query, "matches": [match.to_dict() for match in matches], "user_authenticated": ( - request.user.is_authenticated - if hasattr(request.user, "is_authenticated") - else False + request.user.is_authenticated if hasattr(request.user, "is_authenticated") else False ), } @@ -143,7 +138,7 @@ class EntityFuzzySearchView(APIView): return Response( { "success": False, - "error": f"Internal server error: {str(e)}", + "detail": f"Internal server error: {str(e)}", "code": "INTERNAL_ERROR", }, status=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -192,7 +187,7 @@ class EntityNotFoundView(APIView): return Response( { "success": False, - "error": "original_query is required", + "detail": "original_query is required", "code": "MISSING_QUERY", }, status=status.HTTP_400_BAD_REQUEST, @@ -233,9 +228,7 @@ class EntityNotFoundView(APIView): "context": context, "matches": [match.to_dict() for match in matches], "user_authenticated": ( - request.user.is_authenticated - if hasattr(request.user, "is_authenticated") - else False + request.user.is_authenticated if hasattr(request.user, "is_authenticated") else False ), "has_matches": len(matches) > 0, } @@ -257,7 +250,7 @@ class EntityNotFoundView(APIView): return Response( { "success": False, - "error": f"Internal server error: {str(e)}", + "detail": f"Internal server error: {str(e)}", "code": "INTERNAL_ERROR", }, status=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -297,9 +290,7 @@ class QuickEntitySuggestionView(APIView): limit = min(int(request.GET.get("limit", 5)), 10) # Cap at 10 if not query or len(query) < 2: - return Response( - {"suggestions": [], "query": query}, status=status.HTTP_200_OK - ) + return Response({"suggestions": [], "query": query}, status=status.HTTP_200_OK) # Parse entity types entity_types = [] @@ -312,9 +303,7 @@ class QuickEntitySuggestionView(APIView): entity_types = [EntityType.PARK, EntityType.RIDE, EntityType.COMPANY] # Get fuzzy matches - matches, _ = entity_fuzzy_matcher.find_entity( - query=query, entity_types=entity_types, user=request.user - ) + matches, _ = entity_fuzzy_matcher.find_entity(query=query, entity_types=entity_types, user=request.user) # Format as simple suggestions suggestions = [] @@ -337,15 +326,13 @@ class QuickEntitySuggestionView(APIView): except Exception as e: return Response( - {"suggestions": [], "query": request.GET.get("q", ""), "error": str(e)}, + {"suggestions": [], "query": request.GET.get("q", ""), "detail": str(e)}, status=status.HTTP_200_OK, ) # Return 200 even on errors for autocomplete # Utility function for other views to use -def get_entity_suggestions( - query: str, entity_types: list[str] | None = None, user=None -): +def get_entity_suggestions(query: str, entity_types: list[str] | None = None, user=None): """ Utility function for other Django views to get entity suggestions. @@ -370,8 +357,6 @@ def get_entity_suggestions( if not parsed_types: parsed_types = [EntityType.PARK, EntityType.RIDE, EntityType.COMPANY] - return entity_fuzzy_matcher.find_entity( - query=query, entity_types=parsed_types, user=user - ) + return entity_fuzzy_matcher.find_entity(query=query, entity_types=parsed_types, user=user) except Exception: return [], None diff --git a/backend/apps/api/v1/email/views.py b/backend/apps/api/v1/email/views.py index 4aa1ab6a..044405c5 100644 --- a/backend/apps/api/v1/email/views.py +++ b/backend/apps/api/v1/email/views.py @@ -76,7 +76,7 @@ class SendEmailView(APIView): if not all([to, subject, text]): return Response( { - "error": "Missing required fields", + "detail": "Missing required fields", "required_fields": ["to", "subject", "text"], }, status=status.HTTP_400_BAD_REQUEST, @@ -96,11 +96,9 @@ class SendEmailView(APIView): ) return Response( - {"message": "Email sent successfully", "response": response}, + {"detail": "Email sent successfully", "response": response}, status=status.HTTP_200_OK, ) except Exception as e: - return Response( - {"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR - ) + return Response({"detail": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) diff --git a/backend/apps/api/v1/history/views.py b/backend/apps/api/v1/history/views.py index 2ec260fa..eb1213ec 100644 --- a/backend/apps/api/v1/history/views.py +++ b/backend/apps/api/v1/history/views.py @@ -37,21 +37,11 @@ class _FallbackSerializer(drf_serializers.Serializer): return {} -ParkHistoryEventSerializer = getattr( - history_serializers, "ParkHistoryEventSerializer", _FallbackSerializer -) -RideHistoryEventSerializer = getattr( - history_serializers, "RideHistoryEventSerializer", _FallbackSerializer -) -ParkHistoryOutputSerializer = getattr( - history_serializers, "ParkHistoryOutputSerializer", _FallbackSerializer -) -RideHistoryOutputSerializer = getattr( - history_serializers, "RideHistoryOutputSerializer", _FallbackSerializer -) -UnifiedHistoryTimelineSerializer = getattr( - history_serializers, "UnifiedHistoryTimelineSerializer", _FallbackSerializer -) +ParkHistoryEventSerializer = getattr(history_serializers, "ParkHistoryEventSerializer", _FallbackSerializer) +RideHistoryEventSerializer = getattr(history_serializers, "RideHistoryEventSerializer", _FallbackSerializer) +ParkHistoryOutputSerializer = getattr(history_serializers, "ParkHistoryOutputSerializer", _FallbackSerializer) +RideHistoryOutputSerializer = getattr(history_serializers, "RideHistoryOutputSerializer", _FallbackSerializer) +UnifiedHistoryTimelineSerializer = getattr(history_serializers, "UnifiedHistoryTimelineSerializer", _FallbackSerializer) # --- Constants for model strings to avoid duplication --- PARK_MODEL = "parks.park" @@ -201,18 +191,14 @@ class ParkHistoryViewSet(ReadOnlyModelViewSet): # Base queryset for park events queryset = ( - pghistory.models.Events.objects.filter( - pgh_model__in=[PARK_MODEL], pgh_obj_id=getattr(park, "id", None) - ) + pghistory.models.Events.objects.filter(pgh_model__in=[PARK_MODEL], pgh_obj_id=getattr(park, "id", None)) .select_related() .order_by("-pgh_created_at") ) # Apply list filters via helper to reduce complexity if self.action == "list": - queryset = _apply_list_filters( - queryset, cast(Request, self.request), default_limit=50, max_limit=500 - ) + queryset = _apply_list_filters(queryset, cast(Request, self.request), default_limit=50, max_limit=500) return queryset @@ -322,18 +308,14 @@ class RideHistoryViewSet(ReadOnlyModelViewSet): # Base queryset for ride events queryset = ( - pghistory.models.Events.objects.filter( - pgh_model__in=RIDE_MODELS, pgh_obj_id=getattr(ride, "id", None) - ) + pghistory.models.Events.objects.filter(pgh_model__in=RIDE_MODELS, pgh_obj_id=getattr(ride, "id", None)) .select_related() .order_by("-pgh_created_at") ) # Apply list filters via helper if self.action == "list": - queryset = _apply_list_filters( - queryset, cast(Request, self.request), default_limit=50, max_limit=500 - ) + queryset = _apply_list_filters(queryset, cast(Request, self.request), default_limit=50, max_limit=500) return queryset @@ -462,9 +444,7 @@ class UnifiedHistoryViewSet(ReadOnlyModelViewSet): # Apply shared list filters when serving the list action if self.action == "list": - queryset = _apply_list_filters( - queryset, cast(Request, self.request), default_limit=100, max_limit=1000 - ) + queryset = _apply_list_filters(queryset, cast(Request, self.request), default_limit=100, max_limit=1000) return queryset @@ -477,9 +457,7 @@ class UnifiedHistoryViewSet(ReadOnlyModelViewSet): events = list(self.get_queryset()) # evaluate for counts / earliest/latest use # Summary statistics across all tracked models - total_events = pghistory.models.Events.objects.filter( - pgh_model__in=ALL_TRACKED_MODELS - ).count() + total_events = pghistory.models.Events.objects.filter(pgh_model__in=ALL_TRACKED_MODELS).count() event_type_counts = ( pghistory.models.Events.objects.filter(pgh_model__in=ALL_TRACKED_MODELS) @@ -497,12 +475,8 @@ class UnifiedHistoryViewSet(ReadOnlyModelViewSet): "summary": { "total_events": total_events, "events_returned": len(events), - "event_type_breakdown": { - item["pgh_label"]: item["count"] for item in event_type_counts - }, - "model_type_breakdown": { - item["pgh_model"]: item["count"] for item in model_type_counts - }, + "event_type_breakdown": {item["pgh_label"]: item["count"] for item in event_type_counts}, + "model_type_breakdown": {item["pgh_model"]: item["count"] for item in model_type_counts}, "time_range": { "earliest": events[-1].pgh_created_at if events else None, "latest": events[0].pgh_created_at if events else None, diff --git a/backend/apps/api/v1/images/views.py b/backend/apps/api/v1/images/views.py index 3e343a4e..84ce3419 100644 --- a/backend/apps/api/v1/images/views.py +++ b/backend/apps/api/v1/images/views.py @@ -11,6 +11,7 @@ from apps.core.utils.cloudflare import get_direct_upload_url logger = logging.getLogger(__name__) + class GenerateUploadURLView(APIView): permission_classes = [IsAuthenticated] @@ -21,19 +22,10 @@ class GenerateUploadURLView(APIView): return Response(result, status=status.HTTP_200_OK) except ImproperlyConfigured as e: logger.error(f"Configuration Error: {e}") - return Response( - {"detail": "Server configuration error."}, - status=status.HTTP_500_INTERNAL_SERVER_ERROR - ) + return Response({"detail": "Server configuration error."}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) except requests.RequestException as e: logger.error(f"Cloudflare API Error: {e}") - return Response( - {"detail": "Failed to generate upload URL."}, - status=status.HTTP_502_BAD_GATEWAY - ) + return Response({"detail": "Failed to generate upload URL."}, status=status.HTTP_502_BAD_GATEWAY) except Exception: logger.exception("Unexpected error generating upload URL") - return Response( - {"detail": "An unexpected error occurred."}, - status=status.HTTP_500_INTERNAL_SERVER_ERROR - ) + return Response({"detail": "An unexpected error occurred."}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) diff --git a/backend/apps/api/v1/maps/views.py b/backend/apps/api/v1/maps/views.py index 0ec2a04e..5a44597e 100644 --- a/backend/apps/api/v1/maps/views.py +++ b/backend/apps/api/v1/maps/views.py @@ -162,16 +162,13 @@ class MapLocationsAPIView(APIView): if not all([north, south, east, west]): return None try: - return Polygon.from_bbox( - (float(west), float(south), float(east), float(north)) - ) + return Polygon.from_bbox((float(west), float(south), float(east), float(north))) except (ValueError, TypeError): return None def _serialize_park_location(self, park) -> dict: """Serialize park location data.""" - location = park.location if hasattr( - park, "location") and park.location else None + location = park.location if hasattr(park, "location") and park.location else None return { "city": location.city if location else "", "state": location.state if location else "", @@ -181,8 +178,7 @@ class MapLocationsAPIView(APIView): def _serialize_park_data(self, park) -> dict: """Serialize park data for map response.""" - location = park.location if hasattr( - park, "location") and park.location else None + location = park.location if hasattr(park, "location") and park.location else None return { "id": park.id, "type": "park", @@ -195,9 +191,7 @@ class MapLocationsAPIView(APIView): "stats": { "coaster_count": park.coaster_count or 0, "ride_count": park.ride_count or 0, - "average_rating": ( - float(park.average_rating) if park.average_rating else None - ), + "average_rating": (float(park.average_rating) if park.average_rating else None), }, } @@ -206,14 +200,10 @@ class MapLocationsAPIView(APIView): if "park" not in params["types"]: return [] - parks_query = Park.objects.select_related( - "location", "operator" - ).filter(location__point__isnull=False) + parks_query = Park.objects.select_related("location", "operator").filter(location__point__isnull=False) # Apply bounds filtering - bounds_polygon = self._create_bounds_polygon( - params["north"], params["south"], params["east"], params["west"] - ) + bounds_polygon = self._create_bounds_polygon(params["north"], params["south"], params["east"], params["west"]) if bounds_polygon: parks_query = parks_query.filter(location__point__within=bounds_polygon) @@ -229,11 +219,7 @@ class MapLocationsAPIView(APIView): def _serialize_ride_location(self, ride) -> dict: """Serialize ride location data.""" - location = ( - ride.park.location - if hasattr(ride.park, "location") and ride.park.location - else None - ) + location = ride.park.location if hasattr(ride.park, "location") and ride.park.location else None return { "city": location.city if location else "", "state": location.state if location else "", @@ -243,11 +229,7 @@ class MapLocationsAPIView(APIView): def _serialize_ride_data(self, ride) -> dict: """Serialize ride data for map response.""" - location = ( - ride.park.location - if hasattr(ride.park, "location") and ride.park.location - else None - ) + location = ride.park.location if hasattr(ride.park, "location") and ride.park.location else None return { "id": ride.id, "type": "ride", @@ -259,9 +241,7 @@ class MapLocationsAPIView(APIView): "location": self._serialize_ride_location(ride), "stats": { "category": ride.get_category_display() if ride.category else None, - "average_rating": ( - float(ride.average_rating) if ride.average_rating else None - ), + "average_rating": (float(ride.average_rating) if ride.average_rating else None), "park_name": ride.park.name, }, } @@ -271,17 +251,14 @@ class MapLocationsAPIView(APIView): if "ride" not in params["types"]: return [] - rides_query = Ride.objects.select_related( - "park__location", "manufacturer" - ).filter(park__location__point__isnull=False) + rides_query = Ride.objects.select_related("park__location", "manufacturer").filter( + park__location__point__isnull=False + ) # Apply bounds filtering - bounds_polygon = self._create_bounds_polygon( - params["north"], params["south"], params["east"], params["west"] - ) + bounds_polygon = self._create_bounds_polygon(params["north"], params["south"], params["east"], params["west"]) if bounds_polygon: - rides_query = rides_query.filter( - park__location__point__within=bounds_polygon) + rides_query = rides_query.filter(park__location__point__within=bounds_polygon) # Apply text search if params["query"]: @@ -335,7 +312,7 @@ class MapLocationsAPIView(APIView): # Use EnhancedCacheService for improved caching with monitoring cache_service = EnhancedCacheService() - cached_result = cache_service.get_cached_api_response('map_locations', params) + cached_result = cache_service.get_cached_api_response("map_locations", params) if cached_result: logger.debug(f"Cache hit for map_locations with key: {cache_key}") return Response(cached_result) @@ -349,7 +326,7 @@ class MapLocationsAPIView(APIView): result = self._build_response(locations, params) # Cache result for 5 minutes using EnhancedCacheService - cache_service.cache_api_response('map_locations', params, result, timeout=300) + cache_service.cache_api_response("map_locations", params, result, timeout=300) logger.debug(f"Cached map_locations result for key: {cache_key}") return Response(result) @@ -357,7 +334,7 @@ class MapLocationsAPIView(APIView): except Exception as e: logger.error(f"Error in MapLocationsAPIView: {str(e)}", exc_info=True) return Response( - {"status": "error", "message": "Failed to retrieve map locations"}, + {"status": "error", "detail": "Failed to retrieve map locations"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) @@ -401,34 +378,28 @@ class MapLocationDetailAPIView(APIView): permission_classes = [AllowAny] @cache_api_response(timeout=1800, key_prefix="map_detail") - def get( - self, request: HttpRequest, location_type: str, location_id: int - ) -> Response: + def get(self, request: HttpRequest, location_type: str, location_id: int) -> Response: """Get detailed information for a specific location.""" try: if location_type == "park": try: - obj = Park.objects.select_related("location", "operator").get( - id=location_id - ) + obj = Park.objects.select_related("location", "operator").get(id=location_id) except Park.DoesNotExist: return Response( - {"status": "error", "message": "Park not found"}, + {"status": "error", "detail": "Park not found"}, status=status.HTTP_404_NOT_FOUND, ) elif location_type == "ride": try: - obj = Ride.objects.select_related( - "park__location", "manufacturer" - ).get(id=location_id) + obj = Ride.objects.select_related("park__location", "manufacturer").get(id=location_id) except Ride.DoesNotExist: return Response( - {"status": "error", "message": "Ride not found"}, + {"status": "error", "detail": "Ride not found"}, status=status.HTTP_404_NOT_FOUND, ) else: return Response( - {"status": "error", "message": "Invalid location type"}, + {"status": "error", "detail": "Invalid location type"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -440,59 +411,27 @@ class MapLocationDetailAPIView(APIView): "name": obj.name, "slug": obj.slug, "description": obj.description, - "latitude": ( - obj.location.latitude - if hasattr(obj, "location") and obj.location - else None - ), - "longitude": ( - obj.location.longitude - if hasattr(obj, "location") and obj.location - else None - ), + "latitude": (obj.location.latitude if hasattr(obj, "location") and obj.location else None), + "longitude": (obj.location.longitude if hasattr(obj, "location") and obj.location else None), "status": obj.status, "location": { "street_address": ( - obj.location.street_address - if hasattr(obj, "location") and obj.location - else "" - ), - "city": ( - obj.location.city - if hasattr(obj, "location") and obj.location - else "" - ), - "state": ( - obj.location.state - if hasattr(obj, "location") and obj.location - else "" - ), - "country": ( - obj.location.country - if hasattr(obj, "location") and obj.location - else "" - ), - "postal_code": ( - obj.location.postal_code - if hasattr(obj, "location") and obj.location - else "" + obj.location.street_address if hasattr(obj, "location") and obj.location else "" ), + "city": (obj.location.city if hasattr(obj, "location") and obj.location else ""), + "state": (obj.location.state if hasattr(obj, "location") and obj.location else ""), + "country": (obj.location.country if hasattr(obj, "location") and obj.location else ""), + "postal_code": (obj.location.postal_code if hasattr(obj, "location") and obj.location else ""), "formatted_address": ( - obj.location.formatted_address - if hasattr(obj, "location") and obj.location - else "" + obj.location.formatted_address if hasattr(obj, "location") and obj.location else "" ), }, "stats": { "coaster_count": obj.coaster_count or 0, "ride_count": obj.ride_count or 0, - "average_rating": ( - float(obj.average_rating) if obj.average_rating else None - ), + "average_rating": (float(obj.average_rating) if obj.average_rating else None), "size_acres": float(obj.size_acres) if obj.size_acres else None, - "opening_date": ( - obj.opening_date.isoformat() if obj.opening_date else None - ), + "opening_date": (obj.opening_date.isoformat() if obj.opening_date else None), }, "nearby_locations": [], # See FUTURE_WORK.md - THRILLWIKI-107 } @@ -504,14 +443,10 @@ class MapLocationDetailAPIView(APIView): "slug": obj.slug, "description": obj.description, "latitude": ( - obj.park.location.latitude - if hasattr(obj.park, "location") and obj.park.location - else None + obj.park.location.latitude if hasattr(obj.park, "location") and obj.park.location else None ), "longitude": ( - obj.park.location.longitude - if hasattr(obj.park, "location") and obj.park.location - else None + obj.park.location.longitude if hasattr(obj.park, "location") and obj.park.location else None ), "status": obj.status, "location": { @@ -520,25 +455,15 @@ class MapLocationDetailAPIView(APIView): if hasattr(obj.park, "location") and obj.park.location else "" ), - "city": ( - obj.park.location.city - if hasattr(obj.park, "location") and obj.park.location - else "" - ), + "city": (obj.park.location.city if hasattr(obj.park, "location") and obj.park.location else ""), "state": ( - obj.park.location.state - if hasattr(obj.park, "location") and obj.park.location - else "" + obj.park.location.state if hasattr(obj.park, "location") and obj.park.location else "" ), "country": ( - obj.park.location.country - if hasattr(obj.park, "location") and obj.park.location - else "" + obj.park.location.country if hasattr(obj.park, "location") and obj.park.location else "" ), "postal_code": ( - obj.park.location.postal_code - if hasattr(obj.park, "location") and obj.park.location - else "" + obj.park.location.postal_code if hasattr(obj.park, "location") and obj.park.location else "" ), "formatted_address": ( obj.park.location.formatted_address @@ -547,19 +472,11 @@ class MapLocationDetailAPIView(APIView): ), }, "stats": { - "category": ( - obj.get_category_display() if obj.category else None - ), - "average_rating": ( - float(obj.average_rating) if obj.average_rating else None - ), + "category": (obj.get_category_display() if obj.category else None), + "average_rating": (float(obj.average_rating) if obj.average_rating else None), "park_name": obj.park.name, - "opening_date": ( - obj.opening_date.isoformat() if obj.opening_date else None - ), - "manufacturer": ( - obj.manufacturer.name if obj.manufacturer else None - ), + "opening_date": (obj.opening_date.isoformat() if obj.opening_date else None), + "manufacturer": (obj.manufacturer.name if obj.manufacturer else None), }, "nearby_locations": [], # See FUTURE_WORK.md - THRILLWIKI-107 } @@ -574,7 +491,7 @@ class MapLocationDetailAPIView(APIView): except Exception as e: logger.error(f"Error in MapLocationDetailAPIView: {str(e)}", exc_info=True) return Response( - {"status": "error", "message": "Failed to retrieve location details"}, + {"status": "error", "detail": "Failed to retrieve location details"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) @@ -640,7 +557,7 @@ class MapSearchAPIView(APIView): return Response( { "status": "error", - "message": "Search query 'q' parameter is required", + "detail": "Search query 'q' parameter is required", }, status=status.HTTP_400_BAD_REQUEST, ) @@ -672,30 +589,16 @@ class MapSearchAPIView(APIView): "name": park.name, "slug": park.slug, "latitude": ( - park.location.latitude - if hasattr(park, "location") and park.location - else None + park.location.latitude if hasattr(park, "location") and park.location else None ), "longitude": ( - park.location.longitude - if hasattr(park, "location") and park.location - else None + park.location.longitude if hasattr(park, "location") and park.location else None ), "location": { - "city": ( - park.location.city - if hasattr(park, "location") and park.location - else "" - ), - "state": ( - park.location.state - if hasattr(park, "location") and park.location - else "" - ), + "city": (park.location.city if hasattr(park, "location") and park.location else ""), + "state": (park.location.state if hasattr(park, "location") and park.location else ""), "country": ( - park.location.country - if hasattr(park, "location") and park.location - else "" + park.location.country if hasattr(park, "location") and park.location else "" ), }, "relevance_score": 1.0, # See FUTURE_WORK.md - THRILLWIKI-108 @@ -734,20 +637,17 @@ class MapSearchAPIView(APIView): "location": { "city": ( ride.park.location.city - if hasattr(ride.park, "location") - and ride.park.location + if hasattr(ride.park, "location") and ride.park.location else "" ), "state": ( ride.park.location.state - if hasattr(ride.park, "location") - and ride.park.location + if hasattr(ride.park, "location") and ride.park.location else "" ), "country": ( ride.park.location.country - if hasattr(ride.park, "location") - and ride.park.location + if hasattr(ride.park, "location") and ride.park.location else "" ), }, @@ -776,7 +676,7 @@ class MapSearchAPIView(APIView): except Exception as e: logger.error(f"Error in MapSearchAPIView: {str(e)}", exc_info=True) return Response( - {"status": "error", "message": "Search failed due to internal error"}, + {"status": "error", "detail": "Search failed due to internal error"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) @@ -848,8 +748,7 @@ class MapBoundsAPIView(APIView): if not all([north_str, south_str, east_str, west_str]): return Response( - {"status": "error", - "message": "All bounds parameters (north, south, east, west) are required"}, + {"status": "error", "detail": "All bounds parameters (north, south, east, west) are required"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -860,7 +759,7 @@ class MapBoundsAPIView(APIView): west = float(west_str) if west_str else 0.0 except (TypeError, ValueError): return Response( - {"status": "error", "message": "Invalid bounds parameters"}, + {"status": "error", "detail": "Invalid bounds parameters"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -869,7 +768,7 @@ class MapBoundsAPIView(APIView): return Response( { "status": "error", - "message": "North bound must be greater than south bound", + "detail": "North bound must be greater than south bound", }, status=status.HTTP_400_BAD_REQUEST, ) @@ -878,7 +777,7 @@ class MapBoundsAPIView(APIView): return Response( { "status": "error", - "message": "West bound must be less than east bound", + "detail": "West bound must be less than east bound", }, status=status.HTTP_400_BAD_REQUEST, ) @@ -891,9 +790,7 @@ class MapBoundsAPIView(APIView): # Get parks within bounds if "park" in types: - parks_query = Park.objects.select_related("location").filter( - location__point__within=bounds_polygon - ) + parks_query = Park.objects.select_related("location").filter(location__point__within=bounds_polygon) for park in parks_query[:100]: # Limit results locations.append( @@ -903,14 +800,10 @@ class MapBoundsAPIView(APIView): "name": park.name, "slug": park.slug, "latitude": ( - park.location.latitude - if hasattr(park, "location") and park.location - else None + park.location.latitude if hasattr(park, "location") and park.location else None ), "longitude": ( - park.location.longitude - if hasattr(park, "location") and park.location - else None + park.location.longitude if hasattr(park, "location") and park.location else None ), "status": park.status, } @@ -960,7 +853,7 @@ class MapBoundsAPIView(APIView): except Exception as e: logger.error(f"Error in MapBoundsAPIView: {str(e)}", exc_info=True) return Response( - {"status": "error", "message": "Failed to retrieve locations within bounds"}, + {"status": "error", "detail": "Failed to retrieve locations within bounds"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) @@ -987,18 +880,15 @@ class MapStatsAPIView(APIView): """Get map service statistics and performance metrics.""" try: # Count locations with coordinates - parks_with_location = Park.objects.filter( - location__point__isnull=False - ).count() - rides_with_location = Ride.objects.filter( - park__location__point__isnull=False - ).count() + parks_with_location = Park.objects.filter(location__point__isnull=False).count() + rides_with_location = Ride.objects.filter(park__location__point__isnull=False).count() total_locations = parks_with_location + rides_with_location # Get cache statistics from apps.core.services.enhanced_cache_service import CacheMonitor + cache_monitor = CacheMonitor() - cache_stats = cache_monitor.get_cache_statistics('map_locations') + cache_stats = cache_monitor.get_cache_statistics("map_locations") return Response( { @@ -1006,17 +896,17 @@ class MapStatsAPIView(APIView): "total_locations": total_locations, "parks_with_location": parks_with_location, "rides_with_location": rides_with_location, - "cache_hits": cache_stats.get('hits', 0), - "cache_misses": cache_stats.get('misses', 0), - "cache_hit_rate": cache_stats.get('hit_rate', 0.0), - "cache_size": cache_stats.get('size', 0), + "cache_hits": cache_stats.get("hits", 0), + "cache_misses": cache_stats.get("misses", 0), + "cache_hit_rate": cache_stats.get("hit_rate", 0.0), + "cache_size": cache_stats.get("size", 0), } ) except Exception as e: logger.error(f"Error in MapStatsAPIView: {str(e)}", exc_info=True) return Response( - {"status": "error", "message": "Failed to retrieve map statistics"}, + {"status": "error", "detail": "Failed to retrieve map statistics"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) @@ -1060,7 +950,7 @@ class MapCacheAPIView(APIView): return Response( { "status": "success", - "message": f"Map cache cleared successfully. Cleared {cleared_count} entries.", + "detail": f"Map cache cleared successfully. Cleared {cleared_count} entries.", "cleared_count": cleared_count, } ) @@ -1068,7 +958,7 @@ class MapCacheAPIView(APIView): except Exception as e: logger.error(f"Error in MapCacheAPIView.delete: {str(e)}", exc_info=True) return Response( - {"status": "error", "message": "Failed to clear map cache"}, + {"status": "error", "detail": "Failed to clear map cache"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) @@ -1076,7 +966,7 @@ class MapCacheAPIView(APIView): """Invalidate specific cache entries.""" try: # Get cache keys to invalidate from request data - request_data = getattr(request, 'data', {}) + request_data = getattr(request, "data", {}) cache_keys = request_data.get("cache_keys", []) if request_data else [] if cache_keys: @@ -1088,7 +978,7 @@ class MapCacheAPIView(APIView): return Response( { "status": "success", - "message": f"Cache invalidated successfully. Invalidated {invalidated_count} entries.", + "detail": f"Cache invalidated successfully. Invalidated {invalidated_count} entries.", "invalidated_count": invalidated_count, } ) @@ -1096,7 +986,7 @@ class MapCacheAPIView(APIView): except Exception as e: logger.error(f"Error in MapCacheAPIView.post: {str(e)}", exc_info=True) return Response( - {"status": "error", "message": "Failed to invalidate cache"}, + {"status": "error", "detail": "Failed to invalidate cache"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) diff --git a/backend/apps/api/v1/middleware.py b/backend/apps/api/v1/middleware.py index b45dcb2b..4cc2335c 100644 --- a/backend/apps/api/v1/middleware.py +++ b/backend/apps/api/v1/middleware.py @@ -33,7 +33,7 @@ class ContractValidationMiddleware(MiddlewareMixin): def __init__(self, get_response): super().__init__(get_response) self.get_response = get_response - self.enabled = getattr(settings, 'DEBUG', False) + self.enabled = getattr(settings, "DEBUG", False) if self.enabled: logger.info("Contract validation middleware enabled (DEBUG mode)") @@ -45,11 +45,11 @@ class ContractValidationMiddleware(MiddlewareMixin): return response # Only validate API endpoints - if not request.path.startswith('/api/'): + if not request.path.startswith("/api/"): return response # Only validate JSON responses - if not isinstance(response, (JsonResponse, Response)): + if not isinstance(response, JsonResponse | Response): return response # Only validate successful responses (2xx status codes) @@ -58,7 +58,7 @@ class ContractValidationMiddleware(MiddlewareMixin): try: # Get response data - data = response.data if isinstance(response, Response) else json.loads(response.content.decode('utf-8')) + data = response.data if isinstance(response, Response) else json.loads(response.content.decode("utf-8")) # Validate the response self._validate_response_contract(request.path, data) @@ -68,11 +68,11 @@ class ContractValidationMiddleware(MiddlewareMixin): logger.warning( f"Contract validation error for {request.path}: {str(e)}", extra={ - 'path': request.path, - 'method': request.method, - 'status_code': response.status_code, - 'validation_error': str(e) - } + "path": request.path, + "method": request.method, + "status_code": response.status_code, + "validation_error": str(e), + }, ) return response @@ -81,15 +81,15 @@ class ContractValidationMiddleware(MiddlewareMixin): """Validate response data against expected contracts.""" # Check for filter metadata endpoints - if 'filter-options' in path or 'filter_options' in path: + if "filter-options" in path or "filter_options" in path: self._validate_filter_metadata(path, data) # Check for hybrid filtering endpoints - if 'hybrid' in path: + if "hybrid" in path: self._validate_hybrid_response(path, data) # Check for pagination responses - if isinstance(data, dict) and 'results' in data: + if isinstance(data, dict) and "results" in data: self._validate_pagination_response(path, data) # Check for common contract violations @@ -100,22 +100,20 @@ class ContractValidationMiddleware(MiddlewareMixin): if not isinstance(data, dict): self._log_contract_violation( - path, - "FILTER_METADATA_NOT_DICT", - f"Filter metadata should be a dictionary, got {type(data).__name__}" + path, "FILTER_METADATA_NOT_DICT", f"Filter metadata should be a dictionary, got {type(data).__name__}" ) return # Check for categorical filters - if 'categorical' in data: - categorical = data['categorical'] + if "categorical" in data: + categorical = data["categorical"] if isinstance(categorical, dict): for filter_name, filter_options in categorical.items(): self._validate_categorical_filter(path, filter_name, filter_options) # Check for ranges - if 'ranges' in data: - ranges = data['ranges'] + if "ranges" in data: + ranges = data["ranges"] if isinstance(ranges, dict): for range_name, range_data in ranges.items(): self._validate_range_filter(path, range_name, range_data) @@ -127,7 +125,7 @@ class ContractValidationMiddleware(MiddlewareMixin): self._log_contract_violation( path, "CATEGORICAL_FILTER_NOT_ARRAY", - f"Categorical filter '{filter_name}' should be an array, got {type(filter_options).__name__}" + f"Categorical filter '{filter_name}' should be an array, got {type(filter_options).__name__}", ) return @@ -138,28 +136,28 @@ class ContractValidationMiddleware(MiddlewareMixin): path, "CATEGORICAL_OPTION_IS_STRING", f"Categorical filter '{filter_name}' option {i} is a string '{option}' but should be an object with value/label/count properties", - severity="ERROR" + severity="ERROR", ) elif isinstance(option, dict): # Validate object structure - if 'value' not in option: + if "value" not in option: self._log_contract_violation( path, "MISSING_VALUE_PROPERTY", - f"Categorical filter '{filter_name}' option {i} missing 'value' property" + f"Categorical filter '{filter_name}' option {i} missing 'value' property", ) - if 'label' not in option: + if "label" not in option: self._log_contract_violation( path, "MISSING_LABEL_PROPERTY", - f"Categorical filter '{filter_name}' option {i} missing 'label' property" + f"Categorical filter '{filter_name}' option {i} missing 'label' property", ) # Count is optional but should be number if present - if 'count' in option and option['count'] is not None and not isinstance(option['count'], (int, float)): + if "count" in option and option["count"] is not None and not isinstance(option["count"], int | float): self._log_contract_violation( path, "INVALID_COUNT_TYPE", - f"Categorical filter '{filter_name}' option {i} 'count' should be a number, got {type(option['count']).__name__}" + f"Categorical filter '{filter_name}' option {i} 'count' should be a number, got {type(option['count']).__name__}", ) def _validate_range_filter(self, path: str, range_name: str, range_data: Any) -> None: @@ -169,26 +167,24 @@ class ContractValidationMiddleware(MiddlewareMixin): self._log_contract_violation( path, "RANGE_FILTER_NOT_OBJECT", - f"Range filter '{range_name}' should be an object, got {type(range_data).__name__}" + f"Range filter '{range_name}' should be an object, got {type(range_data).__name__}", ) return # Check required properties - required_props = ['min', 'max'] + required_props = ["min", "max"] for prop in required_props: if prop not in range_data: self._log_contract_violation( - path, - "MISSING_RANGE_PROPERTY", - f"Range filter '{range_name}' missing required property '{prop}'" + path, "MISSING_RANGE_PROPERTY", f"Range filter '{range_name}' missing required property '{prop}'" ) # Check step property - if 'step' in range_data and not isinstance(range_data['step'], (int, float)): + if "step" in range_data and not isinstance(range_data["step"], int | float): self._log_contract_violation( path, "INVALID_STEP_TYPE", - f"Range filter '{range_name}' 'step' should be a number, got {type(range_data['step']).__name__}" + f"Range filter '{range_name}' 'step' should be a number, got {type(range_data['step']).__name__}", ) def _validate_hybrid_response(self, path: str, data: Any) -> None: @@ -198,38 +194,36 @@ class ContractValidationMiddleware(MiddlewareMixin): return # Check for strategy field - if 'strategy' in data: - strategy = data['strategy'] - if strategy not in ['client_side', 'server_side']: + if "strategy" in data: + strategy = data["strategy"] + if strategy not in ["client_side", "server_side"]: self._log_contract_violation( path, "INVALID_STRATEGY_VALUE", - f"Hybrid response strategy should be 'client_side' or 'server_side', got '{strategy}'" + f"Hybrid response strategy should be 'client_side' or 'server_side', got '{strategy}'", ) # Check filter_metadata structure - if 'filter_metadata' in data: - self._validate_filter_metadata(path, data['filter_metadata']) + if "filter_metadata" in data: + self._validate_filter_metadata(path, data["filter_metadata"]) def _validate_pagination_response(self, path: str, data: dict[str, Any]) -> None: """Validate pagination response structure.""" # Check for required pagination fields - required_fields = ['count', 'results'] + required_fields = ["count", "results"] for field in required_fields: if field not in data: self._log_contract_violation( - path, - "MISSING_PAGINATION_FIELD", - f"Pagination response missing required field '{field}'" + path, "MISSING_PAGINATION_FIELD", f"Pagination response missing required field '{field}'" ) # Check results is array - if 'results' in data and not isinstance(data['results'], list): + if "results" in data and not isinstance(data["results"], list): self._log_contract_violation( path, "RESULTS_NOT_ARRAY", - f"Pagination 'results' should be an array, got {type(data['results']).__name__}" + f"Pagination 'results' should be an array, got {type(data['results']).__name__}", ) def _validate_common_patterns(self, path: str, data: Any) -> None: @@ -238,38 +232,32 @@ class ContractValidationMiddleware(MiddlewareMixin): if isinstance(data, dict): # Check for null vs undefined issues for key, value in data.items(): - if value is None and key.endswith('_id'): + if value is None and key.endswith("_id"): # ID fields should probably be null, not undefined continue # Check for numeric fields that might be strings - if key.endswith('_count') and isinstance(value, str): + if key.endswith("_count") and isinstance(value, str): try: int(value) self._log_contract_violation( path, "NUMERIC_FIELD_AS_STRING", - f"Field '{key}' appears to be numeric but is a string: '{value}'" + f"Field '{key}' appears to be numeric but is a string: '{value}'", ) except ValueError: pass - def _log_contract_violation( - self, - path: str, - violation_type: str, - message: str, - severity: str = "WARNING" - ) -> None: + def _log_contract_violation(self, path: str, violation_type: str, message: str, severity: str = "WARNING") -> None: """Log a contract violation with structured data.""" log_data = { - 'contract_violation': True, - 'violation_type': violation_type, - 'api_path': path, - 'severity': severity, - 'message': message, - 'suggestion': self._get_violation_suggestion(violation_type) + "contract_violation": True, + "violation_type": violation_type, + "api_path": path, + "severity": severity, + "message": message, + "suggestion": self._get_violation_suggestion(violation_type), } if severity == "ERROR": @@ -302,9 +290,8 @@ class ContractValidationMiddleware(MiddlewareMixin): "Check serializer field types and database field types." ), "RESULTS_NOT_ARRAY": ( - "Ensure pagination 'results' field is always an array. " - "Check serializer implementation." - ) + "Ensure pagination 'results' field is always an array. " "Check serializer implementation." + ), } return suggestions.get(violation_type, "Check the API response format against frontend TypeScript interfaces.") @@ -326,9 +313,9 @@ class ContractValidationSettings: # Paths to exclude from validation EXCLUDED_PATHS = [ - '/api/docs/', - '/api/schema/', - '/api/v1/auth/', # Auth endpoints might have different structures + "/api/docs/", + "/api/schema/", + "/api/v1/auth/", # Auth endpoints might have different structures ] @classmethod diff --git a/backend/apps/api/v1/parks/history_views.py b/backend/apps/api/v1/parks/history_views.py index 186c9ad7..98faf7de 100644 --- a/backend/apps/api/v1/parks/history_views.py +++ b/backend/apps/api/v1/parks/history_views.py @@ -17,6 +17,7 @@ class ParkHistoryViewSet(viewsets.GenericViewSet): """ ViewSet for retrieving park history. """ + permission_classes = [AllowAny] lookup_field = "slug" lookup_url_kwarg = "park_slug" @@ -40,12 +41,7 @@ class ParkHistoryViewSet(viewsets.GenericViewSet): "last_modified": events.first().pgh_created_at if len(events) else None, } - data = { - "park": park, - "current_state": park, - "summary": summary, - "events": events - } + data = {"park": park, "current_state": park, "summary": summary, "events": events} serializer = ParkHistoryOutputSerializer(data) return Response(serializer.data) @@ -55,6 +51,7 @@ class RideHistoryViewSet(viewsets.GenericViewSet): """ ViewSet for retrieving ride history. """ + permission_classes = [AllowAny] lookup_field = "slug" lookup_url_kwarg = "ride_slug" @@ -79,12 +76,7 @@ class RideHistoryViewSet(viewsets.GenericViewSet): "last_modified": events.first().pgh_created_at if len(events) else None, } - data = { - "ride": ride, - "current_state": ride, - "summary": summary, - "events": events - } + data = {"ride": ride, "current_state": ride, "summary": summary, "events": events} serializer = RideHistoryOutputSerializer(data) return Response(serializer.data) diff --git a/backend/apps/api/v1/parks/park_reviews_views.py b/backend/apps/api/v1/parks/park_reviews_views.py index bc4f99bf..6cd90ff3 100644 --- a/backend/apps/api/v1/parks/park_reviews_views.py +++ b/backend/apps/api/v1/parks/park_reviews_views.py @@ -65,14 +65,12 @@ class ParkReviewViewSet(ModelViewSet): def get_permissions(self): """Set permissions based on action.""" - permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated] + permission_classes = [AllowAny] if self.action in ["list", "retrieve", "stats"] else [IsAuthenticated] return [permission() for permission in permission_classes] def get_queryset(self): """Get reviews for the current park.""" - queryset = ParkReview.objects.select_related( - "park", "user", "user__profile" - ) + queryset = ParkReview.objects.select_related("park", "user", "user__profile") park_slug = self.kwargs.get("park_slug") if park_slug: @@ -82,7 +80,7 @@ class ParkReviewViewSet(ModelViewSet): except Park.DoesNotExist: return queryset.none() - if not (hasattr(self.request, 'user') and getattr(self.request.user, 'is_staff', False)): + if not (hasattr(self.request, "user") and getattr(self.request.user, "is_staff", False)): queryset = queryset.filter(is_published=True) return queryset.order_by("-created_at") @@ -102,16 +100,12 @@ class ParkReviewViewSet(ModelViewSet): try: park, _ = Park.get_by_slug(park_slug) except Park.DoesNotExist: - raise NotFound("Park not found") + raise NotFound("Park not found") from None if ParkReview.objects.filter(park=park, user=self.request.user).exists(): raise ValidationError("You have already reviewed this park") - serializer.save( - park=park, - user=self.request.user, - is_published=True - ) + serializer.save(park=park, user=self.request.user, is_published=True) def perform_update(self, serializer): instance = self.get_object() @@ -134,17 +128,18 @@ class ParkReviewViewSet(ModelViewSet): try: park, _ = Park.get_by_slug(park_slug) except Park.DoesNotExist: - return Response({"error": "Park not found"}, status=status.HTTP_404_NOT_FOUND) + return Response({"detail": "Park not found"}, status=status.HTTP_404_NOT_FOUND) reviews = ParkReview.objects.filter(park=park, is_published=True) total_reviews = reviews.count() - avg_rating = reviews.aggregate(avg=Avg('rating'))['avg'] + avg_rating = reviews.aggregate(avg=Avg("rating"))["avg"] rating_distribution = {} for i in range(1, 11): rating_distribution[str(i)] = reviews.filter(rating=i).count() from datetime import timedelta + recent_reviews = reviews.filter(created_at__gte=timezone.now() - timedelta(days=30)).count() stats = { diff --git a/backend/apps/api/v1/parks/park_rides_views.py b/backend/apps/api/v1/parks/park_rides_views.py index fecd9f27..bd6585a8 100644 --- a/backend/apps/api/v1/parks/park_rides_views.py +++ b/backend/apps/api/v1/parks/park_rides_views.py @@ -21,6 +21,7 @@ from rest_framework.views import APIView try: from apps.parks.models import Park from apps.rides.models import Ride + MODELS_AVAILABLE = True except Exception: Park = None # type: ignore @@ -31,6 +32,7 @@ except Exception: try: from apps.api.v1.serializers.parks import ParkDetailOutputSerializer from apps.api.v1.serializers.rides import RideDetailOutputSerializer, RideListOutputSerializer + SERIALIZERS_AVAILABLE = True except Exception: SERIALIZERS_AVAILABLE = False @@ -52,22 +54,41 @@ class ParkRidesListAPIView(APIView): description="Get paginated list of rides at a specific park with filtering options", parameters=[ # Pagination - OpenApiParameter(name="page", location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, description="Page number"), - OpenApiParameter(name="page_size", location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, description="Number of results per page (max 100)"), - + OpenApiParameter( + name="page", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, description="Page number" + ), + OpenApiParameter( + name="page_size", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + description="Number of results per page (max 100)", + ), # Filtering - OpenApiParameter(name="category", location=OpenApiParameter.QUERY, - type=OpenApiTypes.STR, description="Filter by ride category"), - OpenApiParameter(name="status", location=OpenApiParameter.QUERY, - type=OpenApiTypes.STR, description="Filter by operational status"), - OpenApiParameter(name="search", location=OpenApiParameter.QUERY, - type=OpenApiTypes.STR, description="Search rides by name"), - + OpenApiParameter( + name="category", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description="Filter by ride category", + ), + OpenApiParameter( + name="status", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description="Filter by operational status", + ), + OpenApiParameter( + name="search", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description="Search rides by name", + ), # Ordering - OpenApiParameter(name="ordering", location=OpenApiParameter.QUERY, - type=OpenApiTypes.STR, description="Order results by field"), + OpenApiParameter( + name="ordering", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description="Order results by field", + ), ], responses={ 200: OpenApiTypes.OBJECT, @@ -87,12 +108,14 @@ class ParkRidesListAPIView(APIView): try: park, is_historical = Park.get_by_slug(park_slug) except Park.DoesNotExist: - raise NotFound("Park not found") + raise NotFound("Park not found") from None # Get rides for this park - qs = Ride.objects.filter(park=park).select_related( - "manufacturer", "designer", "ride_model", "park_area" - ).prefetch_related("photos") + qs = ( + Ride.objects.filter(park=park) + .select_related("manufacturer", "designer", "ride_model", "park_area") + .prefetch_related("photos") + ) # Apply filtering qs = self._apply_filters(qs, request.query_params) @@ -107,9 +130,7 @@ class ParkRidesListAPIView(APIView): page = paginator.paginate_queryset(qs, request) if SERIALIZERS_AVAILABLE: - serializer = RideListOutputSerializer( - page, many=True, context={"request": request, "park": park} - ) + serializer = RideListOutputSerializer(page, many=True, context={"request": request, "park": park}) return paginator.get_paginated_response(serializer.data) else: # Fallback serialization @@ -145,9 +166,7 @@ class ParkRidesListAPIView(APIView): search = params.get("search") if search: qs = qs.filter( - Q(name__icontains=search) | - Q(description__icontains=search) | - Q(manufacturer__name__icontains=search) + Q(name__icontains=search) | Q(description__icontains=search) | Q(manufacturer__name__icontains=search) ) return qs @@ -179,42 +198,46 @@ class ParkRideDetailAPIView(APIView): try: park, is_historical = Park.get_by_slug(park_slug) except Park.DoesNotExist: - raise NotFound("Park not found") + raise NotFound("Park not found") from None # Get the ride try: ride, is_historical = Ride.get_by_slug(ride_slug, park=park) except Ride.DoesNotExist: - raise NotFound("Ride not found at this park") + raise NotFound("Ride not found at this park") from None # Ensure ride belongs to this park if ride.park_id != park.id: raise NotFound("Ride not found at this park") if SERIALIZERS_AVAILABLE: - serializer = RideDetailOutputSerializer( - ride, context={"request": request, "park": park} - ) + serializer = RideDetailOutputSerializer(ride, context={"request": request, "park": park}) return Response(serializer.data) else: # Fallback serialization - return Response({ - "id": ride.id, - "name": ride.name, - "slug": ride.slug, - "description": getattr(ride, "description", ""), - "category": getattr(ride, "category", ""), - "status": getattr(ride, "status", ""), - "park": { - "id": park.id, - "name": park.name, - "slug": park.slug, - }, - "manufacturer": { - "name": ride.manufacturer.name if ride.manufacturer else "", - "slug": getattr(ride.manufacturer, "slug", "") if ride.manufacturer else "", - } if ride.manufacturer else None, - }) + return Response( + { + "id": ride.id, + "name": ride.name, + "slug": ride.slug, + "description": getattr(ride, "description", ""), + "category": getattr(ride, "category", ""), + "status": getattr(ride, "status", ""), + "park": { + "id": park.id, + "name": park.name, + "slug": park.slug, + }, + "manufacturer": ( + { + "name": ride.manufacturer.name if ride.manufacturer else "", + "slug": getattr(ride.manufacturer, "slug", "") if ride.manufacturer else "", + } + if ride.manufacturer + else None + ), + } + ) class ParkComprehensiveDetailAPIView(APIView): @@ -243,25 +266,21 @@ class ParkComprehensiveDetailAPIView(APIView): try: park, is_historical = Park.get_by_slug(park_slug) except Park.DoesNotExist: - raise NotFound("Park not found") + raise NotFound("Park not found") from None # Get park with full related data - park = Park.objects.select_related( - "operator", "property_owner", "location" - ).prefetch_related( - "areas", "rides", "photos" - ).get(pk=park.pk) + park = ( + Park.objects.select_related("operator", "property_owner", "location") + .prefetch_related("areas", "rides", "photos") + .get(pk=park.pk) + ) # Get a sample of rides (first 10) for preview - rides_sample = Ride.objects.filter(park=park).select_related( - "manufacturer", "designer", "ride_model" - )[:10] + rides_sample = Ride.objects.filter(park=park).select_related("manufacturer", "designer", "ride_model")[:10] if SERIALIZERS_AVAILABLE: # Get full park details - park_serializer = ParkDetailOutputSerializer( - park, context={"request": request} - ) + park_serializer = ParkDetailOutputSerializer(park, context={"request": request}) park_data = park_serializer.data # Add rides summary @@ -279,25 +298,27 @@ class ParkComprehensiveDetailAPIView(APIView): return Response(park_data) else: # Fallback serialization - return Response({ - "id": park.id, - "name": park.name, - "slug": park.slug, - "description": getattr(park, "description", ""), - "location": str(getattr(park, "location", "")), - "operator": getattr(park.operator, "name", "") if hasattr(park, "operator") else "", - "ride_count": getattr(park, "ride_count", 0), - "rides_summary": { - "total_count": getattr(park, "ride_count", 0), - "sample": [ - { - "id": ride.id, - "name": ride.name, - "slug": ride.slug, - "category": getattr(ride, "category", ""), - } - for ride in rides_sample - ], - "full_list_url": f"/api/v1/parks/{park_slug}/rides/", - }, - }) + return Response( + { + "id": park.id, + "name": park.name, + "slug": park.slug, + "description": getattr(park, "description", ""), + "location": str(getattr(park, "location", "")), + "operator": getattr(park.operator, "name", "") if hasattr(park, "operator") else "", + "ride_count": getattr(park, "ride_count", 0), + "rides_summary": { + "total_count": getattr(park, "ride_count", 0), + "sample": [ + { + "id": ride.id, + "name": ride.name, + "slug": ride.slug, + "category": getattr(ride, "category", ""), + } + for ride in rides_sample + ], + "full_list_url": f"/api/v1/parks/{park_slug}/rides/", + }, + } + ) diff --git a/backend/apps/api/v1/parks/park_views.py b/backend/apps/api/v1/parks/park_views.py index baf11a04..d9c8affc 100644 --- a/backend/apps/api/v1/parks/park_views.py +++ b/backend/apps/api/v1/parks/park_views.py @@ -29,6 +29,7 @@ from rest_framework.views import APIView # Import models try: from apps.parks.models import Company, Park + MODELS_AVAILABLE = True except Exception: Park = None # type: ignore @@ -38,6 +39,7 @@ except Exception: # Import ModelChoices for filter options try: from apps.api.v1.serializers.shared import ModelChoices + HAVE_MODELCHOICES = True except Exception: ModelChoices = None # type: ignore @@ -52,6 +54,7 @@ try: ParkListOutputSerializer, ParkUpdateInputSerializer, ) + SERIALIZERS_AVAILABLE = True except Exception: SERIALIZERS_AVAILABLE = False @@ -72,80 +75,152 @@ class ParkListCreateAPIView(APIView): description="List parks with comprehensive filtering matching frontend API documentation. Supports all 24 filtering parameters including continent, rating ranges, ride counts, and more.", parameters=[ # Pagination - OpenApiParameter(name="page", location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, description="Page number"), - OpenApiParameter(name="page_size", location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, description="Number of results per page"), - + OpenApiParameter( + name="page", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, description="Page number" + ), + OpenApiParameter( + name="page_size", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + description="Number of results per page", + ), # Search - OpenApiParameter(name="search", location=OpenApiParameter.QUERY, - type=OpenApiTypes.STR, description="Search parks by name"), - + OpenApiParameter( + name="search", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description="Search parks by name", + ), # Location filters - OpenApiParameter(name="continent", location=OpenApiParameter.QUERY, - type=OpenApiTypes.STR, description="Filter by continent"), - OpenApiParameter(name="country", location=OpenApiParameter.QUERY, - type=OpenApiTypes.STR, description="Filter by country"), - OpenApiParameter(name="state", location=OpenApiParameter.QUERY, - type=OpenApiTypes.STR, description="Filter by state/province"), - OpenApiParameter(name="city", location=OpenApiParameter.QUERY, - type=OpenApiTypes.STR, description="Filter by city"), - + OpenApiParameter( + name="continent", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description="Filter by continent", + ), + OpenApiParameter( + name="country", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, description="Filter by country" + ), + OpenApiParameter( + name="state", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description="Filter by state/province", + ), + OpenApiParameter( + name="city", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, description="Filter by city" + ), # Park attributes - OpenApiParameter(name="park_type", location=OpenApiParameter.QUERY, - type=OpenApiTypes.STR, description="Filter by park type"), - OpenApiParameter(name="status", location=OpenApiParameter.QUERY, - type=OpenApiTypes.STR, description="Filter by operational status"), - + OpenApiParameter( + name="park_type", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description="Filter by park type", + ), + OpenApiParameter( + name="status", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description="Filter by operational status", + ), # Company filters - OpenApiParameter(name="operator_id", location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, description="Filter by operator company ID"), - OpenApiParameter(name="operator_slug", location=OpenApiParameter.QUERY, - type=OpenApiTypes.STR, description="Filter by operator company slug"), - OpenApiParameter(name="property_owner_id", location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, description="Filter by property owner company ID"), - OpenApiParameter(name="property_owner_slug", location=OpenApiParameter.QUERY, - type=OpenApiTypes.STR, description="Filter by property owner company slug"), - + OpenApiParameter( + name="operator_id", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + description="Filter by operator company ID", + ), + OpenApiParameter( + name="operator_slug", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description="Filter by operator company slug", + ), + OpenApiParameter( + name="property_owner_id", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + description="Filter by property owner company ID", + ), + OpenApiParameter( + name="property_owner_slug", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description="Filter by property owner company slug", + ), # Rating filters - OpenApiParameter(name="min_rating", location=OpenApiParameter.QUERY, - type=OpenApiTypes.NUMBER, description="Minimum average rating"), - OpenApiParameter(name="max_rating", location=OpenApiParameter.QUERY, - type=OpenApiTypes.NUMBER, description="Maximum average rating"), - + OpenApiParameter( + name="min_rating", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.NUMBER, + description="Minimum average rating", + ), + OpenApiParameter( + name="max_rating", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.NUMBER, + description="Maximum average rating", + ), # Ride count filters - OpenApiParameter(name="min_ride_count", location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, description="Minimum total ride count"), - OpenApiParameter(name="max_ride_count", location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, description="Maximum total ride count"), - + OpenApiParameter( + name="min_ride_count", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + description="Minimum total ride count", + ), + OpenApiParameter( + name="max_ride_count", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + description="Maximum total ride count", + ), # Opening year filters - OpenApiParameter(name="opening_year", location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, description="Filter by specific opening year"), - OpenApiParameter(name="min_opening_year", location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, description="Minimum opening year"), - OpenApiParameter(name="max_opening_year", location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, description="Maximum opening year"), - + OpenApiParameter( + name="opening_year", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + description="Filter by specific opening year", + ), + OpenApiParameter( + name="min_opening_year", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + description="Minimum opening year", + ), + OpenApiParameter( + name="max_opening_year", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + description="Maximum opening year", + ), # Roller coaster filters - OpenApiParameter(name="has_roller_coasters", location=OpenApiParameter.QUERY, - type=OpenApiTypes.BOOL, description="Filter parks that have roller coasters"), - OpenApiParameter(name="min_roller_coaster_count", location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, description="Minimum roller coaster count"), - OpenApiParameter(name="max_roller_coaster_count", location=OpenApiParameter.QUERY, - type=OpenApiTypes.INT, description="Maximum roller coaster count"), - + OpenApiParameter( + name="has_roller_coasters", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.BOOL, + description="Filter parks that have roller coasters", + ), + OpenApiParameter( + name="min_roller_coaster_count", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + description="Minimum roller coaster count", + ), + OpenApiParameter( + name="max_roller_coaster_count", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + description="Maximum roller coaster count", + ), # Ordering - OpenApiParameter(name="ordering", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, - description="Order results by field (prefix with - for descending)"), + OpenApiParameter( + name="ordering", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description="Order results by field (prefix with - for descending)", + ), ], - responses={ - 200: ( - "ParkListOutputSerializer(many=True)" - if SERIALIZERS_AVAILABLE - else OpenApiTypes.OBJECT - ) - }, + responses={200: ("ParkListOutputSerializer(many=True)" if SERIALIZERS_AVAILABLE else OpenApiTypes.OBJECT)}, tags=["Parks"], ) def get(self, request: Request) -> Response: @@ -163,13 +238,15 @@ class ParkListCreateAPIView(APIView): ) # Start with base queryset - qs = Park.objects.all().select_related( - "operator", "property_owner", "location" - ).prefetch_related("rides").annotate( - ride_count_calculated=Count('rides'), - roller_coaster_count_calculated=Count( - 'rides', filter=Q(rides__category='RC')), - average_rating_calculated=Avg('reviews__rating') + qs = ( + Park.objects.all() + .select_related("operator", "property_owner", "location") + .prefetch_related("rides") + .annotate( + ride_count_calculated=Count("rides"), + roller_coaster_count_calculated=Count("rides", filter=Q(rides__category="RC")), + average_rating_calculated=Avg("reviews__rating"), + ) ) # Apply comprehensive filtering @@ -185,9 +262,7 @@ class ParkListCreateAPIView(APIView): page = paginator.paginate_queryset(qs, request) if SERIALIZERS_AVAILABLE: - serializer = ParkListOutputSerializer( - page, many=True, context={"request": request} - ) + serializer = ParkListOutputSerializer(page, many=True, context={"request": request}) return paginator.get_paginated_response(serializer.data) else: # Fallback serialization @@ -232,21 +307,21 @@ class ParkListCreateAPIView(APIView): search = params.get("search") if search: qs = qs.filter( - Q(name__icontains=search) | - Q(description__icontains=search) | - Q(location__city__icontains=search) | - Q(location__state__icontains=search) | - Q(location__country__icontains=search) + Q(name__icontains=search) + | Q(description__icontains=search) + | Q(location__city__icontains=search) + | Q(location__state__icontains=search) + | Q(location__country__icontains=search) ) return qs def _apply_location_filters(self, qs: QuerySet, params: dict) -> QuerySet: """Apply location-based filtering to the queryset.""" location_filters = { - 'country': 'location__country__iexact', - 'state': 'location__state__iexact', - 'city': 'location__city__iexact', - 'continent': 'location__continent__iexact' + "country": "location__country__iexact", + "state": "location__state__iexact", + "city": "location__city__iexact", + "continent": "location__continent__iexact", } for param_name, filter_field in location_filters.items(): @@ -271,10 +346,10 @@ class ParkListCreateAPIView(APIView): def _apply_company_filters(self, qs: QuerySet, params: dict) -> QuerySet: """Apply company-related filtering to the queryset.""" company_filters = { - 'operator_id': 'operator_id', - 'operator_slug': 'operator__slug', - 'property_owner_id': 'property_owner_id', - 'property_owner_slug': 'property_owner__slug' + "operator_id": "operator_id", + "operator_slug": "operator__slug", + "property_owner_id": "property_owner_id", + "property_owner_slug": "property_owner__slug", } for param_name, filter_field in company_filters.items(): @@ -335,9 +410,9 @@ class ParkListCreateAPIView(APIView): """Apply roller coaster filtering to the queryset.""" has_roller_coasters = params.get("has_roller_coasters") if has_roller_coasters is not None: - if has_roller_coasters.lower() in ['true', '1', 'yes']: + if has_roller_coasters.lower() in ["true", "1", "yes"]: qs = qs.filter(coaster_count__gt=0) - elif has_roller_coasters.lower() in ['false', '0', 'no']: + elif has_roller_coasters.lower() in ["false", "0", "no"]: qs = qs.filter(coaster_count=0) min_roller_coaster_count = params.get("min_roller_coaster_count") @@ -355,13 +430,7 @@ class ParkListCreateAPIView(APIView): @extend_schema( summary="Create a new park", description="Create a new park.", - responses={ - 201: ( - "ParkDetailOutputSerializer()" - if SERIALIZERS_AVAILABLE - else OpenApiTypes.OBJECT - ) - }, + responses={201: ("ParkDetailOutputSerializer()" if SERIALIZERS_AVAILABLE else OpenApiTypes.OBJECT)}, tags=["Parks"], ) def post(self, request: Request) -> Response: @@ -408,11 +477,7 @@ class ParkListCreateAPIView(APIView): summary="Retrieve, update or delete a park by ID or slug", description="Retrieve full park details including location, photos, areas, rides, and company information. Supports both ID and slug-based lookup with historical slug support.", responses={ - 200: ( - "ParkDetailOutputSerializer()" - if SERIALIZERS_AVAILABLE - else OpenApiTypes.OBJECT - ), + 200: ("ParkDetailOutputSerializer()" if SERIALIZERS_AVAILABLE else OpenApiTypes.OBJECT), 404: OpenApiTypes.OBJECT, }, tags=["Parks"], @@ -423,36 +488,34 @@ class ParkDetailAPIView(APIView): def _get_park_or_404(self, identifier: str) -> Any: if not MODELS_AVAILABLE: raise NotFound( - - "Park detail is not available because domain models " - "are not imported. Implement apps.parks.models.Park " - "to enable detail endpoints." - + "Park detail is not available because domain models " + "are not imported. Implement apps.parks.models.Park " + "to enable detail endpoints." ) # Try to parse as integer ID first try: pk = int(identifier) try: - return Park.objects.select_related( - "operator", "property_owner", "location" - ).prefetch_related( - "areas", "rides", "photos" - ).get(pk=pk) + return ( + Park.objects.select_related("operator", "property_owner", "location") + .prefetch_related("areas", "rides", "photos") + .get(pk=pk) + ) except Park.DoesNotExist: - raise NotFound("Park not found") + raise NotFound("Park not found") from None except ValueError: # Not an integer, try slug lookup try: park, is_historical = Park.get_by_slug(identifier) # Ensure we have the full related data - return Park.objects.select_related( - "operator", "property_owner", "location" - ).prefetch_related( - "areas", "rides", "photos" - ).get(pk=park.pk) + return ( + Park.objects.select_related("operator", "property_owner", "location") + .prefetch_related("areas", "rides", "photos") + .get(pk=park.pk) + ) except Park.DoesNotExist: - raise NotFound("Park not found") + raise NotFound("Park not found") from None @extend_schema( summary="Get park full details", @@ -491,11 +554,7 @@ class ParkDetailAPIView(APIView): **No Query Parameters Required** - This endpoint returns full details by default. """, responses={ - 200: ( - "ParkDetailOutputSerializer()" - if SERIALIZERS_AVAILABLE - else OpenApiTypes.OBJECT - ), + 200: ("ParkDetailOutputSerializer()" if SERIALIZERS_AVAILABLE else OpenApiTypes.OBJECT), 404: OpenApiTypes.OBJECT, }, ) @@ -513,11 +572,7 @@ class ParkDetailAPIView(APIView): "slug": getattr(park, "slug", ""), "description": getattr(park, "description", ""), "location": str(getattr(park, "location", "")), - "operator": ( - getattr(park.operator, "name", "") - if hasattr(park, "operator") - else "" - ), + "operator": (getattr(park.operator, "name", "") if hasattr(park, "operator") else ""), } ) @@ -534,12 +589,7 @@ class ParkDetailAPIView(APIView): if not MODELS_AVAILABLE: return Response( - { - "detail": ( - "Park update is not available because domain models " - "are not imported." - ) - }, + {"detail": ("Park update is not available because domain models " "are not imported.")}, status=status.HTTP_501_NOT_IMPLEMENTED, ) for key, value in serializer_in.validated_data.items(): @@ -555,12 +605,7 @@ class ParkDetailAPIView(APIView): def delete(self, request: Request, pk: str) -> Response: if not MODELS_AVAILABLE: return Response( - { - "detail": ( - "Park delete is not available because domain models " - "are not imported." - ) - }, + {"detail": ("Park delete is not available because domain models " "are not imported.")}, status=status.HTTP_501_NOT_IMPLEMENTED, ) park = self._get_park_or_404(pk) @@ -583,8 +628,8 @@ class FilterOptionsAPIView(APIView): from apps.core.choices.registry import get_choices # Always get static choice definitions from Rich Choice Objects (primary source) - park_types = get_choices('types', 'parks') - statuses = get_choices('statuses', 'parks') + park_types = get_choices("types", "parks") + statuses = get_choices("statuses", "parks") # Convert Rich Choice Objects to frontend format with metadata park_types_data = [ @@ -592,10 +637,10 @@ class FilterOptionsAPIView(APIView): "value": choice.value, "label": choice.label, "description": choice.description, - "color": choice.metadata.get('color'), - "icon": choice.metadata.get('icon'), - "css_class": choice.metadata.get('css_class'), - "sort_order": choice.metadata.get('sort_order', 0) + "color": choice.metadata.get("color"), + "icon": choice.metadata.get("icon"), + "css_class": choice.metadata.get("css_class"), + "sort_order": choice.metadata.get("sort_order", 0), } for choice in park_types ] @@ -605,10 +650,10 @@ class FilterOptionsAPIView(APIView): "value": choice.value, "label": choice.label, "description": choice.description, - "color": choice.metadata.get('color'), - "icon": choice.metadata.get('icon'), - "css_class": choice.metadata.get('css_class'), - "sort_order": choice.metadata.get('sort_order', 0) + "color": choice.metadata.get("color"), + "icon": choice.metadata.get("icon"), + "css_class": choice.metadata.get("css_class"), + "sort_order": choice.metadata.get("sort_order", 0), } for choice in statuses ] @@ -618,18 +663,11 @@ class FilterOptionsAPIView(APIView): # Add any dynamic data queries here pass - return Response({ + return Response( + { "park_types": park_types_data, "statuses": statuses_data, - "continents": [ - "North America", - "South America", - "Europe", - "Asia", - "Africa", - "Australia", - "Antarctica" - ], + "continents": ["North America", "South America", "Europe", "Asia", "Africa", "Australia", "Antarctica"], "countries": [ "United States", "Canada", @@ -638,22 +676,10 @@ class FilterOptionsAPIView(APIView): "France", "Japan", "Australia", - "Brazil" - ], - "states": [ - "California", - "Florida", - "Ohio", - "Pennsylvania", - "Texas", - "New York" - ], - "cities": [ - "Orlando", - "Los Angeles", - "Cedar Point", - "Sandusky" + "Brazil", ], + "states": ["California", "Florida", "Ohio", "Pennsylvania", "Texas", "New York"], + "cities": ["Orlando", "Los Angeles", "Cedar Point", "Sandusky"], "operators": [], "property_owners": [], "ranges": { @@ -676,20 +702,19 @@ class FilterOptionsAPIView(APIView): {"value": "-average_rating", "label": "Rating (High to Low)"}, {"value": "size_acres", "label": "Size (Small to Large)"}, {"value": "-size_acres", "label": "Size (Large to Small)"}, - {"value": "created_at", - "label": "Added to Database (Oldest First)"}, - {"value": "-created_at", - "label": "Added to Database (Newest First)"}, + {"value": "created_at", "label": "Added to Database (Oldest First)"}, + {"value": "-created_at", "label": "Added to Database (Newest First)"}, {"value": "updated_at", "label": "Last Updated (Oldest First)"}, {"value": "-updated_at", "label": "Last Updated (Newest First)"}, ], - }) + } + ) # Try to get dynamic options from database using Rich Choice Objects try: # Get rich choice objects from registry - park_types = get_choices('types', 'parks') - statuses = get_choices('statuses', 'parks') + park_types = get_choices("types", "parks") + statuses = get_choices("statuses", "parks") # Convert Rich Choice Objects to frontend format with metadata park_types_data = [ @@ -697,10 +722,10 @@ class FilterOptionsAPIView(APIView): "value": choice.value, "label": choice.label, "description": choice.description, - "color": choice.metadata.get('color'), - "icon": choice.metadata.get('icon'), - "css_class": choice.metadata.get('css_class'), - "sort_order": choice.metadata.get('sort_order', 0) + "color": choice.metadata.get("color"), + "icon": choice.metadata.get("icon"), + "css_class": choice.metadata.get("css_class"), + "sort_order": choice.metadata.get("sort_order", 0), } for choice in park_types ] @@ -710,213 +735,191 @@ class FilterOptionsAPIView(APIView): "value": choice.value, "label": choice.label, "description": choice.description, - "color": choice.metadata.get('color'), - "icon": choice.metadata.get('icon'), - "css_class": choice.metadata.get('css_class'), - "sort_order": choice.metadata.get('sort_order', 0) + "color": choice.metadata.get("color"), + "icon": choice.metadata.get("icon"), + "css_class": choice.metadata.get("css_class"), + "sort_order": choice.metadata.get("sort_order", 0), } for choice in statuses ] # Get location data from database - continents = list(Park.objects.exclude( - location__continent__isnull=True - ).exclude( - location__continent__exact='' - ).values_list('location__continent', flat=True).distinct().order_by('location__continent')) + continents = list( + Park.objects.exclude(location__continent__isnull=True) + .exclude(location__continent__exact="") + .values_list("location__continent", flat=True) + .distinct() + .order_by("location__continent") + ) # Fallback to static list if no continents in database if not continents: - continents = [ - "North America", - "South America", - "Europe", - "Asia", - "Africa", - "Australia", - "Antarctica" - ] + continents = ["North America", "South America", "Europe", "Asia", "Africa", "Australia", "Antarctica"] - countries = list(Park.objects.exclude( - location__country__isnull=True - ).exclude( - location__country__exact='' - ).values_list('location__country', flat=True).distinct().order_by('location__country')) + countries = list( + Park.objects.exclude(location__country__isnull=True) + .exclude(location__country__exact="") + .values_list("location__country", flat=True) + .distinct() + .order_by("location__country") + ) - states = list(Park.objects.exclude( - location__state__isnull=True - ).exclude( - location__state__exact='' - ).values_list('location__state', flat=True).distinct().order_by('location__state')) + states = list( + Park.objects.exclude(location__state__isnull=True) + .exclude(location__state__exact="") + .values_list("location__state", flat=True) + .distinct() + .order_by("location__state") + ) - cities = list(Park.objects.exclude( - location__city__isnull=True - ).exclude( - location__city__exact='' - ).values_list('location__city', flat=True).distinct().order_by('location__city')) + cities = list( + Park.objects.exclude(location__city__isnull=True) + .exclude(location__city__exact="") + .values_list("location__city", flat=True) + .distinct() + .order_by("location__city") + ) # Get operators and property owners - operators = list(Company.objects.filter( - roles__contains=['OPERATOR'] - ).values('id', 'name', 'slug').order_by('name')) + operators = list( + Company.objects.filter(roles__contains=["OPERATOR"]).values("id", "name", "slug").order_by("name") + ) - property_owners = list(Company.objects.filter( - roles__contains=['PROPERTY_OWNER'] - ).values('id', 'name', 'slug').order_by('name')) + property_owners = list( + Company.objects.filter(roles__contains=["PROPERTY_OWNER"]).values("id", "name", "slug").order_by("name") + ) # Calculate ranges from actual data park_stats = Park.objects.aggregate( - min_rating=models.Min('average_rating'), - max_rating=models.Max('average_rating'), - min_ride_count=models.Min('ride_count'), - max_ride_count=models.Max('ride_count'), - min_coaster_count=models.Min('coaster_count'), - max_coaster_count=models.Max('coaster_count'), - min_size=models.Min('size_acres'), - max_size=models.Max('size_acres'), - min_year=models.Min('opening_date__year'), - max_year=models.Max('opening_date__year'), + min_rating=models.Min("average_rating"), + max_rating=models.Max("average_rating"), + min_ride_count=models.Min("ride_count"), + max_ride_count=models.Max("ride_count"), + min_coaster_count=models.Min("coaster_count"), + max_coaster_count=models.Max("coaster_count"), + min_size=models.Min("size_acres"), + max_size=models.Max("size_acres"), + min_year=models.Min("opening_date__year"), + max_year=models.Max("opening_date__year"), ) ranges = { "rating": { - "min": float(park_stats['min_rating'] or 1), - "max": float(park_stats['max_rating'] or 10), + "min": float(park_stats["min_rating"] or 1), + "max": float(park_stats["max_rating"] or 10), "step": 0.1, - "unit": "stars" + "unit": "stars", }, "ride_count": { - "min": park_stats['min_ride_count'] or 0, - "max": park_stats['max_ride_count'] or 100, + "min": park_stats["min_ride_count"] or 0, + "max": park_stats["max_ride_count"] or 100, "step": 1, - "unit": "rides" + "unit": "rides", }, "coaster_count": { - "min": park_stats['min_coaster_count'] or 0, - "max": park_stats['max_coaster_count'] or 50, + "min": park_stats["min_coaster_count"] or 0, + "max": park_stats["max_coaster_count"] or 50, "step": 1, - "unit": "coasters" + "unit": "coasters", }, "size_acres": { - "min": float(park_stats['min_size'] or 0), - "max": float(park_stats['max_size'] or 10000), + "min": float(park_stats["min_size"] or 0), + "max": float(park_stats["max_size"] or 10000), "step": 1, - "unit": "acres" + "unit": "acres", }, "opening_year": { - "min": park_stats['min_year'] or 1800, - "max": park_stats['max_year'] or 2030, + "min": park_stats["min_year"] or 1800, + "max": park_stats["max_year"] or 2030, "step": 1, - "unit": "year" + "unit": "year", }, } - return Response({ - "park_types": park_types_data, - "statuses": statuses_data, - "continents": continents, - "countries": countries, - "states": states, - "cities": cities, - "operators": operators, - "property_owners": property_owners, - "ranges": ranges, - "ordering_options": [ - {"value": "name", "label": "Name (A-Z)"}, - {"value": "-name", "label": "Name (Z-A)"}, - {"value": "opening_date", "label": "Opening Date (Oldest First)"}, - {"value": "-opening_date", "label": "Opening Date (Newest First)"}, - {"value": "ride_count", "label": "Ride Count (Low to High)"}, - {"value": "-ride_count", "label": "Ride Count (High to Low)"}, - {"value": "coaster_count", "label": "Coaster Count (Low to High)"}, - {"value": "-coaster_count", "label": "Coaster Count (High to Low)"}, - {"value": "average_rating", "label": "Rating (Low to High)"}, - {"value": "-average_rating", "label": "Rating (High to Low)"}, - {"value": "size_acres", "label": "Size (Small to Large)"}, - {"value": "-size_acres", "label": "Size (Large to Small)"}, - {"value": "created_at", - "label": "Added to Database (Oldest First)"}, - {"value": "-created_at", - "label": "Added to Database (Newest First)"}, - {"value": "updated_at", "label": "Last Updated (Oldest First)"}, - {"value": "-updated_at", "label": "Last Updated (Newest First)"}, - ], - }) + return Response( + { + "park_types": park_types_data, + "statuses": statuses_data, + "continents": continents, + "countries": countries, + "states": states, + "cities": cities, + "operators": operators, + "property_owners": property_owners, + "ranges": ranges, + "ordering_options": [ + {"value": "name", "label": "Name (A-Z)"}, + {"value": "-name", "label": "Name (Z-A)"}, + {"value": "opening_date", "label": "Opening Date (Oldest First)"}, + {"value": "-opening_date", "label": "Opening Date (Newest First)"}, + {"value": "ride_count", "label": "Ride Count (Low to High)"}, + {"value": "-ride_count", "label": "Ride Count (High to Low)"}, + {"value": "coaster_count", "label": "Coaster Count (Low to High)"}, + {"value": "-coaster_count", "label": "Coaster Count (High to Low)"}, + {"value": "average_rating", "label": "Rating (Low to High)"}, + {"value": "-average_rating", "label": "Rating (High to Low)"}, + {"value": "size_acres", "label": "Size (Small to Large)"}, + {"value": "-size_acres", "label": "Size (Large to Small)"}, + {"value": "created_at", "label": "Added to Database (Oldest First)"}, + {"value": "-created_at", "label": "Added to Database (Newest First)"}, + {"value": "updated_at", "label": "Last Updated (Oldest First)"}, + {"value": "-updated_at", "label": "Last Updated (Newest First)"}, + ], + } + ) except Exception: # Fallback to static options if database query fails - return Response({ - "park_types": [ - {"value": "THEME_PARK", "label": "Theme Park"}, - {"value": "AMUSEMENT_PARK", "label": "Amusement Park"}, - {"value": "WATER_PARK", "label": "Water Park"}, - {"value": "FAMILY_ENTERTAINMENT_CENTER", - "label": "Family Entertainment Center"}, - {"value": "CARNIVAL", "label": "Carnival"}, - {"value": "FAIR", "label": "Fair"}, - {"value": "PIER", "label": "Pier"}, - {"value": "BOARDWALK", "label": "Boardwalk"}, - {"value": "SAFARI_PARK", "label": "Safari Park"}, - {"value": "ZOO", "label": "Zoo"}, - {"value": "OTHER", "label": "Other"}, - ], - "statuses": [ - {"value": "OPERATING", "label": "Operating"}, - {"value": "CLOSED_TEMP", "label": "Temporarily Closed"}, - {"value": "CLOSED_PERM", "label": "Permanently Closed"}, - {"value": "UNDER_CONSTRUCTION", "label": "Under Construction"}, - {"value": "DEMOLISHED", "label": "Demolished"}, - {"value": "RELOCATED", "label": "Relocated"}, - ], - "continents": [ - "North America", - "South America", - "Europe", - "Asia", - "Africa", - "Australia" - ], - "countries": [ - "United States", - "Canada", - "United Kingdom", - "Germany", - "France", - "Japan" - ], - "states": [ - "California", - "Florida", - "Ohio", - "Pennsylvania" - ], - "cities": [ - "Orlando", - "Los Angeles", - "Cedar Point" - ], - "operators": [], - "property_owners": [], - "ranges": { - "rating": {"min": 1, "max": 10, "step": 0.1, "unit": "stars"}, - "ride_count": {"min": 0, "max": 100, "step": 1, "unit": "rides"}, - "coaster_count": {"min": 0, "max": 50, "step": 1, "unit": "coasters"}, - "size_acres": {"min": 0, "max": 10000, "step": 1, "unit": "acres"}, - "opening_year": {"min": 1800, "max": 2030, "step": 1, "unit": "year"}, - }, - "ordering_options": [ - {"value": "name", "label": "Name (A-Z)"}, - {"value": "-name", "label": "Name (Z-A)"}, - {"value": "opening_date", "label": "Opening Date (Oldest First)"}, - {"value": "-opening_date", "label": "Opening Date (Newest First)"}, - {"value": "ride_count", "label": "Ride Count (Low to High)"}, - {"value": "-ride_count", "label": "Ride Count (High to Low)"}, - {"value": "coaster_count", "label": "Coaster Count (Low to High)"}, - {"value": "-coaster_count", "label": "Coaster Count (High to Low)"}, - {"value": "average_rating", "label": "Rating (Low to High)"}, - {"value": "-average_rating", "label": "Rating (High to Low)"}, - ], - }) + return Response( + { + "park_types": [ + {"value": "THEME_PARK", "label": "Theme Park"}, + {"value": "AMUSEMENT_PARK", "label": "Amusement Park"}, + {"value": "WATER_PARK", "label": "Water Park"}, + {"value": "FAMILY_ENTERTAINMENT_CENTER", "label": "Family Entertainment Center"}, + {"value": "CARNIVAL", "label": "Carnival"}, + {"value": "FAIR", "label": "Fair"}, + {"value": "PIER", "label": "Pier"}, + {"value": "BOARDWALK", "label": "Boardwalk"}, + {"value": "SAFARI_PARK", "label": "Safari Park"}, + {"value": "ZOO", "label": "Zoo"}, + {"value": "OTHER", "label": "Other"}, + ], + "statuses": [ + {"value": "OPERATING", "label": "Operating"}, + {"value": "CLOSED_TEMP", "label": "Temporarily Closed"}, + {"value": "CLOSED_PERM", "label": "Permanently Closed"}, + {"value": "UNDER_CONSTRUCTION", "label": "Under Construction"}, + {"value": "DEMOLISHED", "label": "Demolished"}, + {"value": "RELOCATED", "label": "Relocated"}, + ], + "continents": ["North America", "South America", "Europe", "Asia", "Africa", "Australia"], + "countries": ["United States", "Canada", "United Kingdom", "Germany", "France", "Japan"], + "states": ["California", "Florida", "Ohio", "Pennsylvania"], + "cities": ["Orlando", "Los Angeles", "Cedar Point"], + "operators": [], + "property_owners": [], + "ranges": { + "rating": {"min": 1, "max": 10, "step": 0.1, "unit": "stars"}, + "ride_count": {"min": 0, "max": 100, "step": 1, "unit": "rides"}, + "coaster_count": {"min": 0, "max": 50, "step": 1, "unit": "coasters"}, + "size_acres": {"min": 0, "max": 10000, "step": 1, "unit": "acres"}, + "opening_year": {"min": 1800, "max": 2030, "step": 1, "unit": "year"}, + }, + "ordering_options": [ + {"value": "name", "label": "Name (A-Z)"}, + {"value": "-name", "label": "Name (Z-A)"}, + {"value": "opening_date", "label": "Opening Date (Oldest First)"}, + {"value": "-opening_date", "label": "Opening Date (Newest First)"}, + {"value": "ride_count", "label": "Ride Count (Low to High)"}, + {"value": "-ride_count", "label": "Ride Count (High to Low)"}, + {"value": "coaster_count", "label": "Coaster Count (Low to High)"}, + {"value": "-coaster_count", "label": "Coaster Count (High to Low)"}, + {"value": "average_rating", "label": "Rating (Low to High)"}, + {"value": "-average_rating", "label": "Rating (High to Low)"}, + ], + } + ) # --- Company search (autocomplete) ----------------------------------------- @@ -924,7 +927,10 @@ class FilterOptionsAPIView(APIView): summary="Search companies (operators/property owners) for autocomplete", parameters=[ OpenApiParameter( - name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, description="Search query for company names" + name="q", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description="Search query for company names", ) ], responses={200: OpenApiTypes.OBJECT}, @@ -940,49 +946,42 @@ class CompanySearchAPIView(APIView): if not MODELS_AVAILABLE or Company is None: # Provide helpful placeholder structure - return Response([ - {"id": 1, "name": "Six Flags Entertainment", "slug": "six-flags"}, - {"id": 2, "name": "Cedar Fair", "slug": "cedar-fair"}, - {"id": 3, "name": "Disney Parks", "slug": "disney"}, - {"id": 4, "name": "Universal Parks & Resorts", "slug": "universal"}, - {"id": 5, "name": "SeaWorld Parks & Entertainment", "slug": "seaworld"}, - ]) + return Response( + [ + {"id": 1, "name": "Six Flags Entertainment", "slug": "six-flags"}, + {"id": 2, "name": "Cedar Fair", "slug": "cedar-fair"}, + {"id": 3, "name": "Disney Parks", "slug": "disney"}, + {"id": 4, "name": "Universal Parks & Resorts", "slug": "universal"}, + {"id": 5, "name": "SeaWorld Parks & Entertainment", "slug": "seaworld"}, + ] + ) try: # Search companies that can be operators or property owners qs = Company.objects.filter( - Q(name__icontains=q) & - (Q(roles__contains=['OPERATOR']) | Q( - roles__contains=['PROPERTY_OWNER'])) + Q(name__icontains=q) & (Q(roles__contains=["OPERATOR"]) | Q(roles__contains=["PROPERTY_OWNER"])) ).distinct()[:20] results = [ - { - "id": c.id, - "name": c.name, - "slug": getattr(c, "slug", ""), - "roles": getattr(c, "roles", []) - } + {"id": c.id, "name": c.name, "slug": getattr(c, "slug", ""), "roles": getattr(c, "roles", [])} for c in qs ] return Response(results) except Exception: # Fallback to placeholder data - return Response([ - {"id": 1, "name": "Six Flags Entertainment", "slug": "six-flags"}, - {"id": 2, "name": "Cedar Fair", "slug": "cedar-fair"}, - {"id": 3, "name": "Disney Parks", "slug": "disney"}, - ]) + return Response( + [ + {"id": 1, "name": "Six Flags Entertainment", "slug": "six-flags"}, + {"id": 2, "name": "Cedar Fair", "slug": "cedar-fair"}, + {"id": 3, "name": "Disney Parks", "slug": "disney"}, + ] + ) # --- Search suggestions ----------------------------------------------------- @extend_schema( summary="Search suggestions for park search box", - parameters=[ - OpenApiParameter( - name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR - ) - ], + parameters=[OpenApiParameter(name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR)], tags=["Parks"], ) class ParkSearchSuggestionsAPIView(APIView): @@ -995,9 +994,7 @@ class ParkSearchSuggestionsAPIView(APIView): # Very small suggestion implementation: look in park names if available if MODELS_AVAILABLE and Park is not None: - qs = Park.objects.filter(name__icontains=q).values_list("name", flat=True)[ - :10 - ] # type: ignore + qs = Park.objects.filter(name__icontains=q).values_list("name", flat=True)[:10] # type: ignore return Response([{"suggestion": name} for name in qs]) # Fallback suggestions @@ -1013,17 +1010,9 @@ class ParkSearchSuggestionsAPIView(APIView): @extend_schema( summary="Set park banner and card images", description="Set banner_image and card_image for a park from existing park photos", - request=( - "ParkImageSettingsInputSerializer" - if SERIALIZERS_AVAILABLE - else OpenApiTypes.OBJECT - ), + request=("ParkImageSettingsInputSerializer" if SERIALIZERS_AVAILABLE else OpenApiTypes.OBJECT), responses={ - 200: ( - "ParkDetailOutputSerializer" - if SERIALIZERS_AVAILABLE - else OpenApiTypes.OBJECT - ), + 200: ("ParkDetailOutputSerializer" if SERIALIZERS_AVAILABLE else OpenApiTypes.OBJECT), 400: OpenApiTypes.OBJECT, 404: OpenApiTypes.OBJECT, }, @@ -1038,7 +1027,7 @@ class ParkImageSettingsAPIView(APIView): try: return Park.objects.get(pk=pk) # type: ignore except Park.DoesNotExist: # type: ignore - raise NotFound("Park not found") + raise NotFound("Park not found") from None def patch(self, request: Request, pk: int) -> Response: """Set banner and card images for the park.""" @@ -1060,10 +1049,10 @@ class ParkImageSettingsAPIView(APIView): park.save() # Return updated park data - output_serializer = ParkDetailOutputSerializer( - park, context={"request": request} - ) + output_serializer = ParkDetailOutputSerializer(park, context={"request": request}) return Response(output_serializer.data) + + # --- Operator list ---------------------------------------------------------- @extend_schema( summary="List park operators", @@ -1078,10 +1067,7 @@ class OperatorListAPIView(APIView): def get(self, request: Request) -> Response: if not MODELS_AVAILABLE: - return Response( - {"detail": "Models not available"}, - status=status.HTTP_501_NOT_IMPLEMENTED - ) + return Response({"detail": "Models not available"}, status=status.HTTP_501_NOT_IMPLEMENTED) operators = ( Company.objects.filter(roles__contains=["OPERATOR"]) @@ -1102,7 +1088,4 @@ class OperatorListAPIView(APIView): for op in operators ] - return Response({ - "results": data, - "count": len(data) - }) + return Response({"results": data, "count": len(data)}) diff --git a/backend/apps/api/v1/parks/ride_photos_views.py b/backend/apps/api/v1/parks/ride_photos_views.py index 8d5d8ad6..f2c19c7e 100644 --- a/backend/apps/api/v1/parks/ride_photos_views.py +++ b/backend/apps/api/v1/parks/ride_photos_views.py @@ -116,14 +116,12 @@ class RidePhotoViewSet(ModelViewSet): def get_permissions(self): """Set permissions based on action.""" - permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated] + permission_classes = [AllowAny] if self.action in ["list", "retrieve", "stats"] else [IsAuthenticated] return [permission() for permission in permission_classes] def get_queryset(self): """Get photos for the current ride with optimized queries.""" - queryset = RidePhoto.objects.select_related( - "ride", "ride__park", "ride__park__operator", "uploaded_by" - ) + queryset = RidePhoto.objects.select_related("ride", "ride__park", "ride__park__operator", "uploaded_by") # Filter by park and ride from URL kwargs park_slug = self.kwargs.get("park_slug") @@ -163,9 +161,9 @@ class RidePhotoViewSet(ModelViewSet): park, _ = Park.get_by_slug(park_slug) ride, _ = Ride.get_by_slug(ride_slug, park=park) except Park.DoesNotExist: - raise NotFound("Park not found") + raise NotFound("Park not found") from None except Ride.DoesNotExist: - raise NotFound("Ride not found at this park") + raise NotFound("Ride not found at this park") from None try: # Use the service to create the photo with proper business logic @@ -187,17 +185,14 @@ class RidePhotoViewSet(ModelViewSet): except Exception as e: logger.error(f"Error creating ride photo: {e}") - raise ValidationError(f"Failed to create photo: {str(e)}") + raise ValidationError(f"Failed to create photo: {str(e)}") from None def perform_update(self, serializer): """Update ride photo with permission checking.""" instance = self.get_object() # Check permissions - allow owner or staff - if not ( - self.request.user == instance.uploaded_by - or getattr(self.request.user, "is_staff", False) - ): + if not (self.request.user == instance.uploaded_by or getattr(self.request.user, "is_staff", False)): raise PermissionDenied("You can only edit your own photos or be an admin.") # Handle primary photo logic using service @@ -209,48 +204,40 @@ class RidePhotoViewSet(ModelViewSet): del serializer.validated_data["is_primary"] except Exception as e: logger.error(f"Error setting primary photo: {e}") - raise ValidationError(f"Failed to set primary photo: {str(e)}") + raise ValidationError(f"Failed to set primary photo: {str(e)}") from None try: serializer.save() logger.info(f"Updated ride photo {instance.id} by user {self.request.user.username}") except Exception as e: logger.error(f"Error updating ride photo: {e}") - raise ValidationError(f"Failed to update photo: {str(e)}") + raise ValidationError(f"Failed to update photo: {str(e)}") from None def perform_destroy(self, instance): """Delete ride photo with permission checking.""" # Check permissions - allow owner or staff - if not ( - self.request.user == instance.uploaded_by - or getattr(self.request.user, "is_staff", False) - ): - raise PermissionDenied( - "You can only delete your own photos or be an admin." - ) + if not (self.request.user == instance.uploaded_by or getattr(self.request.user, "is_staff", False)): + raise PermissionDenied("You can only delete your own photos or be an admin.") try: # Delete from Cloudflare first if image exists if instance.image: try: from django_cloudflareimages_toolkit.services import CloudflareImagesService + service = CloudflareImagesService() service.delete_image(instance.image) - logger.info( - f"Successfully deleted ride photo from Cloudflare: {instance.image.cloudflare_id}") + logger.info(f"Successfully deleted ride photo from Cloudflare: {instance.image.cloudflare_id}") except Exception as e: - logger.error( - f"Failed to delete ride photo from Cloudflare: {str(e)}") + logger.error(f"Failed to delete ride photo from Cloudflare: {str(e)}") # Continue with database deletion even if Cloudflare deletion fails - RideMediaService.delete_photo( - instance, deleted_by=self.request.user - ) + RideMediaService.delete_photo(instance, deleted_by=self.request.user) logger.info(f"Deleted ride photo {instance.id} by user {self.request.user.username}") except Exception as e: logger.error(f"Error deleting ride photo: {e}") - raise ValidationError(f"Failed to delete photo: {str(e)}") + raise ValidationError(f"Failed to delete photo: {str(e)}") from None @extend_schema( summary="Set photo as primary", @@ -269,13 +256,8 @@ class RidePhotoViewSet(ModelViewSet): photo = self.get_object() # Check permissions - allow owner or staff - if not ( - request.user == photo.uploaded_by - or getattr(request.user, "is_staff", False) - ): - raise PermissionDenied( - "You can only modify your own photos or be an admin." - ) + if not (request.user == photo.uploaded_by or getattr(request.user, "is_staff", False)): + raise PermissionDenied("You can only modify your own photos or be an admin.") try: success = RideMediaService.set_primary_photo(ride=photo.ride, photo=photo) @@ -287,21 +269,21 @@ class RidePhotoViewSet(ModelViewSet): return Response( { - "message": "Photo set as primary successfully", + "detail": "Photo set as primary successfully", "photo": serializer.data, }, status=status.HTTP_200_OK, ) else: return Response( - {"error": "Failed to set primary photo"}, + {"detail": "Failed to set primary photo"}, status=status.HTTP_400_BAD_REQUEST, ) except Exception as e: logger.error(f"Error setting primary photo: {e}") return Response( - {"error": f"Failed to set primary photo: {str(e)}"}, + {"detail": f"Failed to set primary photo: {str(e)}"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -334,7 +316,7 @@ class RidePhotoViewSet(ModelViewSet): if photo_ids is None or approve is None: return Response( - {"error": "Missing required fields: photo_ids and/or approve."}, + {"detail": "Missing required fields: photo_ids and/or approve."}, status=status.HTTP_400_BAD_REQUEST, ) @@ -350,7 +332,7 @@ class RidePhotoViewSet(ModelViewSet): return Response( { - "message": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos", + "detail": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos", "updated_count": updated_count, }, status=status.HTTP_200_OK, @@ -359,7 +341,7 @@ class RidePhotoViewSet(ModelViewSet): except Exception as e: logger.error(f"Error in bulk photo approval: {e}") return Response( - {"error": f"Failed to update photos: {str(e)}"}, + {"detail": f"Failed to update photos: {str(e)}"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -381,7 +363,7 @@ class RidePhotoViewSet(ModelViewSet): if not park_slug or not ride_slug: return Response( - {"error": "Park and ride slugs are required"}, + {"detail": "Park and ride slugs are required"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -390,12 +372,12 @@ class RidePhotoViewSet(ModelViewSet): ride, _ = Ride.get_by_slug(ride_slug, park=park) except Park.DoesNotExist: return Response( - {"error": "Park not found"}, + {"detail": "Park not found"}, status=status.HTTP_404_NOT_FOUND, ) except Ride.DoesNotExist: return Response( - {"error": "Ride not found at this park"}, + {"detail": "Ride not found at this park"}, status=status.HTTP_404_NOT_FOUND, ) @@ -407,7 +389,7 @@ class RidePhotoViewSet(ModelViewSet): except Exception as e: logger.error(f"Error getting ride photo stats: {e}") return Response( - {"error": f"Failed to get photo statistics: {str(e)}"}, + {"detail": f"Failed to get photo statistics: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) @@ -431,7 +413,7 @@ class RidePhotoViewSet(ModelViewSet): if not park_slug or not ride_slug: return Response( - {"error": "Park and ride slugs are required"}, + {"detail": "Park and ride slugs are required"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -440,19 +422,19 @@ class RidePhotoViewSet(ModelViewSet): ride, _ = Ride.get_by_slug(ride_slug, park=park) except Park.DoesNotExist: return Response( - {"error": "Park not found"}, + {"detail": "Park not found"}, status=status.HTTP_404_NOT_FOUND, ) except Ride.DoesNotExist: return Response( - {"error": "Ride not found at this park"}, + {"detail": "Ride not found at this park"}, status=status.HTTP_404_NOT_FOUND, ) cloudflare_image_id = request.data.get("cloudflare_image_id") if not cloudflare_image_id: return Response( - {"error": "cloudflare_image_id is required"}, + {"detail": "cloudflare_image_id is required"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -469,27 +451,25 @@ class RidePhotoViewSet(ModelViewSet): if not image_data: return Response( - {"error": "Image not found in Cloudflare"}, + {"detail": "Image not found in Cloudflare"}, status=status.HTTP_400_BAD_REQUEST, ) # Try to find existing CloudflareImage record by cloudflare_id cloudflare_image = None try: - cloudflare_image = CloudflareImage.objects.get( - cloudflare_id=cloudflare_image_id) + cloudflare_image = CloudflareImage.objects.get(cloudflare_id=cloudflare_image_id) # Update existing record with latest data from Cloudflare - cloudflare_image.status = 'uploaded' + cloudflare_image.status = "uploaded" cloudflare_image.uploaded_at = timezone.now() - cloudflare_image.metadata = image_data.get('meta', {}) + cloudflare_image.metadata = image_data.get("meta", {}) # Extract variants from nested result structure - cloudflare_image.variants = image_data.get( - 'result', {}).get('variants', []) + cloudflare_image.variants = image_data.get("result", {}).get("variants", []) cloudflare_image.cloudflare_metadata = image_data - cloudflare_image.width = image_data.get('width') - cloudflare_image.height = image_data.get('height') - cloudflare_image.format = image_data.get('format', '') + cloudflare_image.width = image_data.get("width") + cloudflare_image.height = image_data.get("height") + cloudflare_image.format = image_data.get("format", "") cloudflare_image.save() except CloudflareImage.DoesNotExist: @@ -497,24 +477,23 @@ class RidePhotoViewSet(ModelViewSet): cloudflare_image = CloudflareImage.objects.create( cloudflare_id=cloudflare_image_id, user=request.user, - status='uploaded', - upload_url='', # Not needed for uploaded images + status="uploaded", + upload_url="", # Not needed for uploaded images expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry uploaded_at=timezone.now(), - metadata=image_data.get('meta', {}), + metadata=image_data.get("meta", {}), # Extract variants from nested result structure - variants=image_data.get('result', {}).get('variants', []), + variants=image_data.get("result", {}).get("variants", []), cloudflare_metadata=image_data, - width=image_data.get('width'), - height=image_data.get('height'), - format=image_data.get('format', ''), + width=image_data.get("width"), + height=image_data.get("height"), + format=image_data.get("format", ""), ) except Exception as api_error: - logger.error( - f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True) + logger.error(f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True) return Response( - {"error": f"Failed to fetch image from Cloudflare: {str(api_error)}"}, + {"detail": f"Failed to fetch image from Cloudflare: {str(api_error)}"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -544,6 +523,6 @@ class RidePhotoViewSet(ModelViewSet): except Exception as e: logger.error(f"Error saving ride photo: {e}") return Response( - {"error": f"Failed to save photo: {str(e)}"}, + {"detail": f"Failed to save photo: {str(e)}"}, status=status.HTTP_400_BAD_REQUEST, ) diff --git a/backend/apps/api/v1/parks/ride_reviews_views.py b/backend/apps/api/v1/parks/ride_reviews_views.py index 80018071..4e83d738 100644 --- a/backend/apps/api/v1/parks/ride_reviews_views.py +++ b/backend/apps/api/v1/parks/ride_reviews_views.py @@ -115,14 +115,12 @@ class RideReviewViewSet(ModelViewSet): def get_permissions(self): """Set permissions based on action.""" - permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated] + permission_classes = [AllowAny] if self.action in ["list", "retrieve", "stats"] else [IsAuthenticated] return [permission() for permission in permission_classes] def get_queryset(self): """Get reviews for the current ride with optimized queries.""" - queryset = RideReview.objects.select_related( - "ride", "ride__park", "user", "user__profile" - ) + queryset = RideReview.objects.select_related("ride", "ride__park", "user", "user__profile") # Filter by park and ride from URL kwargs park_slug = self.kwargs.get("park_slug") @@ -138,8 +136,7 @@ class RideReviewViewSet(ModelViewSet): return queryset.none() # Filter published reviews for non-staff users - if not (hasattr(self.request, 'user') and - getattr(self.request.user, 'is_staff', False)): + if not (hasattr(self.request, "user") and getattr(self.request.user, "is_staff", False)): queryset = queryset.filter(is_published=True) return queryset.order_by("-created_at") @@ -167,9 +164,9 @@ class RideReviewViewSet(ModelViewSet): park, _ = Park.get_by_slug(park_slug) ride, _ = Ride.get_by_slug(ride_slug, park=park) except Park.DoesNotExist: - raise NotFound("Park not found") + raise NotFound("Park not found") from None except Ride.DoesNotExist: - raise NotFound("Ride not found at this park") + raise NotFound("Ride not found at this park") from None # Check if user already has a review for this ride if RideReview.objects.filter(ride=ride, user=self.request.user).exists(): @@ -178,26 +175,21 @@ class RideReviewViewSet(ModelViewSet): try: # Save the review review = serializer.save( - ride=ride, - user=self.request.user, - is_published=True # Auto-publish for now, can add moderation later + ride=ride, user=self.request.user, is_published=True # Auto-publish for now, can add moderation later ) logger.info(f"Created ride review {review.id} for ride {ride.name} by user {self.request.user.username}") except Exception as e: logger.error(f"Error creating ride review: {e}") - raise ValidationError(f"Failed to create review: {str(e)}") + raise ValidationError(f"Failed to create review: {str(e)}") from None def perform_update(self, serializer): """Update ride review with permission checking.""" instance = self.get_object() # Check permissions - allow owner or staff - if not ( - self.request.user == instance.user - or getattr(self.request.user, "is_staff", False) - ): + if not (self.request.user == instance.user or getattr(self.request.user, "is_staff", False)): raise PermissionDenied("You can only edit your own reviews or be an admin.") try: @@ -205,15 +197,12 @@ class RideReviewViewSet(ModelViewSet): logger.info(f"Updated ride review {instance.id} by user {self.request.user.username}") except Exception as e: logger.error(f"Error updating ride review: {e}") - raise ValidationError(f"Failed to update review: {str(e)}") + raise ValidationError(f"Failed to update review: {str(e)}") from None def perform_destroy(self, instance): """Delete ride review with permission checking.""" # Check permissions - allow owner or staff - if not ( - self.request.user == instance.user - or getattr(self.request.user, "is_staff", False) - ): + if not (self.request.user == instance.user or getattr(self.request.user, "is_staff", False)): raise PermissionDenied("You can only delete your own reviews or be an admin.") try: @@ -221,7 +210,7 @@ class RideReviewViewSet(ModelViewSet): instance.delete() except Exception as e: logger.error(f"Error deleting ride review: {e}") - raise ValidationError(f"Failed to delete review: {str(e)}") + raise ValidationError(f"Failed to delete review: {str(e)}") from None @extend_schema( summary="Get ride review statistics", @@ -241,7 +230,7 @@ class RideReviewViewSet(ModelViewSet): if not park_slug or not ride_slug: return Response( - {"error": "Park and ride slugs are required"}, + {"detail": "Park and ride slugs are required"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -250,12 +239,12 @@ class RideReviewViewSet(ModelViewSet): ride, _ = Ride.get_by_slug(ride_slug, park=park) except Park.DoesNotExist: return Response( - {"error": "Park not found"}, + {"detail": "Park not found"}, status=status.HTTP_404_NOT_FOUND, ) except Ride.DoesNotExist: return Response( - {"error": "Ride not found at this park"}, + {"detail": "Ride not found at this park"}, status=status.HTTP_404_NOT_FOUND, ) @@ -268,7 +257,7 @@ class RideReviewViewSet(ModelViewSet): pending_reviews = RideReview.objects.filter(ride=ride, is_published=False).count() # Calculate average rating - avg_rating = reviews.aggregate(avg_rating=Avg('rating'))['avg_rating'] + avg_rating = reviews.aggregate(avg_rating=Avg("rating"))["avg_rating"] # Get rating distribution rating_distribution = {} @@ -277,6 +266,7 @@ class RideReviewViewSet(ModelViewSet): # Get recent reviews count (last 30 days) from datetime import timedelta + thirty_days_ago = timezone.now() - timedelta(days=30) recent_reviews = reviews.filter(created_at__gte=thirty_days_ago).count() @@ -295,7 +285,7 @@ class RideReviewViewSet(ModelViewSet): except Exception as e: logger.error(f"Error getting ride review stats: {e}") return Response( - {"error": f"Failed to get review statistics: {str(e)}"}, + {"detail": f"Failed to get review statistics: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) @@ -340,7 +330,7 @@ class RideReviewViewSet(ModelViewSet): is_published=True, moderated_by=request.user, moderated_at=timezone.now(), - moderation_notes=moderation_notes + moderation_notes=moderation_notes, ) message = f"Successfully published {updated_count} reviews" elif action_type == "unpublish": @@ -348,7 +338,7 @@ class RideReviewViewSet(ModelViewSet): is_published=False, moderated_by=request.user, moderated_at=timezone.now(), - moderation_notes=moderation_notes + moderation_notes=moderation_notes, ) message = f"Successfully unpublished {updated_count} reviews" elif action_type == "delete": @@ -357,13 +347,13 @@ class RideReviewViewSet(ModelViewSet): message = f"Successfully deleted {updated_count} reviews" else: return Response( - {"error": "Invalid action type"}, + {"detail": "Invalid action type"}, status=status.HTTP_400_BAD_REQUEST, ) return Response( { - "message": message, + "detail": message, "updated_count": updated_count, }, status=status.HTTP_200_OK, @@ -372,6 +362,6 @@ class RideReviewViewSet(ModelViewSet): except Exception as e: logger.error(f"Error in bulk review moderation: {e}") return Response( - {"error": f"Failed to moderate reviews: {str(e)}"}, + {"detail": f"Failed to moderate reviews: {str(e)}"}, status=status.HTTP_400_BAD_REQUEST, ) diff --git a/backend/apps/api/v1/parks/serializers.py b/backend/apps/api/v1/parks/serializers.py index e223578e..5eb914ae 100644 --- a/backend/apps/api/v1/parks/serializers.py +++ b/backend/apps/api/v1/parks/serializers.py @@ -50,18 +50,14 @@ from apps.parks.models import Park, ParkPhoto class ParkPhotoOutputSerializer(serializers.ModelSerializer): """Enhanced output serializer for park photos with Cloudflare Images support.""" - uploaded_by_username = serializers.CharField( - source="uploaded_by.username", read_only=True - ) + uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True) file_size = serializers.SerializerMethodField() dimensions = serializers.SerializerMethodField() image_url = serializers.SerializerMethodField() image_variants = serializers.SerializerMethodField() - @extend_schema_field( - serializers.IntegerField(allow_null=True, help_text="File size in bytes") - ) + @extend_schema_field(serializers.IntegerField(allow_null=True, help_text="File size in bytes")) def get_file_size(self, obj): """Get file size in bytes.""" return obj.file_size @@ -79,11 +75,7 @@ class ParkPhotoOutputSerializer(serializers.ModelSerializer): """Get image dimensions as [width, height].""" return obj.dimensions - @extend_schema_field( - serializers.URLField( - help_text="Full URL to the Cloudflare Images asset", allow_null=True - ) - ) + @extend_schema_field(serializers.URLField(help_text="Full URL to the Cloudflare Images asset", allow_null=True)) def get_image_url(self, obj): """Get the full Cloudflare Images URL.""" if obj.image: @@ -175,9 +167,7 @@ class ParkPhotoUpdateInputSerializer(serializers.ModelSerializer): class ParkPhotoListOutputSerializer(serializers.ModelSerializer): """Optimized output serializer for park photo lists.""" - uploaded_by_username = serializers.CharField( - source="uploaded_by.username", read_only=True - ) + uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True) class Meta: model = ParkPhoto @@ -196,12 +186,8 @@ class ParkPhotoListOutputSerializer(serializers.ModelSerializer): class ParkPhotoApprovalInputSerializer(serializers.Serializer): """Input serializer for bulk photo approval operations.""" - photo_ids = serializers.ListField( - child=serializers.IntegerField(), help_text="List of photo IDs to approve" - ) - approve = serializers.BooleanField( - default=True, help_text="Whether to approve (True) or reject (False) the photos" - ) + photo_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of photo IDs to approve") + approve = serializers.BooleanField(default=True, help_text="Whether to approve (True) or reject (False) the photos") class ParkPhotoStatsOutputSerializer(serializers.Serializer): @@ -261,7 +247,7 @@ class HybridParkSerializer(serializers.ModelSerializer): def get_city(self, obj): """Get city from related location.""" try: - return obj.location.city if hasattr(obj, 'location') and obj.location else None + return obj.location.city if hasattr(obj, "location") and obj.location else None except AttributeError: return None @@ -269,7 +255,7 @@ class HybridParkSerializer(serializers.ModelSerializer): def get_state(self, obj): """Get state from related location.""" try: - return obj.location.state if hasattr(obj, 'location') and obj.location else None + return obj.location.state if hasattr(obj, "location") and obj.location else None except AttributeError: return None @@ -277,7 +263,7 @@ class HybridParkSerializer(serializers.ModelSerializer): def get_country(self, obj): """Get country from related location.""" try: - return obj.location.country if hasattr(obj, 'location') and obj.location else None + return obj.location.country if hasattr(obj, "location") and obj.location else None except AttributeError: return None @@ -285,7 +271,7 @@ class HybridParkSerializer(serializers.ModelSerializer): def get_continent(self, obj): """Get continent from related location.""" try: - return obj.location.continent if hasattr(obj, 'location') and obj.location else None + return obj.location.continent if hasattr(obj, "location") and obj.location else None except AttributeError: return None @@ -293,7 +279,7 @@ class HybridParkSerializer(serializers.ModelSerializer): def get_latitude(self, obj): """Get latitude from related location.""" try: - if hasattr(obj, 'location') and obj.location and obj.location.coordinates: + if hasattr(obj, "location") and obj.location and obj.location.coordinates: return obj.location.coordinates[1] # PostGIS returns [lon, lat] return None except (AttributeError, IndexError, TypeError): @@ -303,7 +289,7 @@ class HybridParkSerializer(serializers.ModelSerializer): def get_longitude(self, obj): """Get longitude from related location.""" try: - if hasattr(obj, 'location') and obj.location and obj.location.coordinates: + if hasattr(obj, "location") and obj.location and obj.location.coordinates: return obj.location.coordinates[0] # PostGIS returns [lon, lat] return None except (AttributeError, IndexError, TypeError): @@ -333,13 +319,11 @@ class HybridParkSerializer(serializers.ModelSerializer): "description", "status", "park_type", - # Dates and computed fields "opening_date", "closing_date", "opening_year", "operating_season", - # Location fields "city", "state", @@ -347,28 +331,22 @@ class HybridParkSerializer(serializers.ModelSerializer): "continent", "latitude", "longitude", - # Company relationships "operator_name", "property_owner_name", - # Statistics "size_acres", "average_rating", "ride_count", "coaster_count", - # Images "banner_image_url", "card_image_url", - # URLs "website", "url", - # Computed fields for filtering "search_text", - # Metadata "created_at", "updated_at", diff --git a/backend/apps/api/v1/parks/urls.py b/backend/apps/api/v1/parks/urls.py index 69e916c0..13c2317d 100644 --- a/backend/apps/api/v1/parks/urls.py +++ b/backend/apps/api/v1/parks/urls.py @@ -46,8 +46,8 @@ ride_photos_router.register(r"", RidePhotoViewSet, basename="ride-photo") ride_reviews_router = DefaultRouter() ride_reviews_router.register(r"", RideReviewViewSet, basename="ride-review") -from .history_views import ParkHistoryViewSet, RideHistoryViewSet -from .park_reviews_views import ParkReviewViewSet +from .history_views import ParkHistoryViewSet, RideHistoryViewSet # noqa: E402 +from .park_reviews_views import ParkReviewViewSet # noqa: E402 # Create routers for nested park endpoints reviews_router = DefaultRouter() @@ -59,11 +59,9 @@ app_name = "api_v1_parks" urlpatterns = [ # Core list/create endpoints path("", ParkListCreateAPIView.as_view(), name="park-list-create"), - # Hybrid filtering endpoints path("hybrid/", HybridParkAPIView.as_view(), name="park-hybrid-list"), path("hybrid/filter-metadata/", ParkFilterMetadataAPIView.as_view(), name="park-hybrid-filter-metadata"), - # Filter options path("filter-options/", FilterOptionsAPIView.as_view(), name="park-filter-options"), # Autocomplete / suggestion endpoints @@ -79,14 +77,11 @@ urlpatterns = [ ), # Detail and action endpoints - supports both ID and slug path("/", ParkDetailAPIView.as_view(), name="park-detail"), - # Park rides endpoints path("/rides/", ParkRidesListAPIView.as_view(), name="park-rides-list"), path("/rides//", ParkRideDetailAPIView.as_view(), name="park-ride-detail"), - # Comprehensive park detail endpoint with rides summary path("/detail/", ParkComprehensiveDetailAPIView.as_view(), name="park-comprehensive-detail"), - # Park image settings endpoint path( "/image-settings/", @@ -95,33 +90,29 @@ urlpatterns = [ ), # Park photo endpoints - domain-specific photo management path("/photos/", include(router.urls)), - # Nested ride photo endpoints - photos for specific rides within parks path("/rides//photos/", include(ride_photos_router.urls)), - # Nested ride review endpoints - reviews for specific rides within parks path("/rides//reviews/", include(ride_reviews_router.urls)), # Nested ride review endpoints - reviews for specific rides within parks path("/rides//reviews/", include(ride_reviews_router.urls)), - # Ride History - path("/rides//history/", RideHistoryViewSet.as_view({'get': 'list'}), name="ride-history"), - + path( + "/rides//history/", + RideHistoryViewSet.as_view({"get": "list"}), + name="ride-history", + ), # Park Reviews path("/reviews/", include(reviews_router.urls)), - # Park History - path("/history/", ParkHistoryViewSet.as_view({'get': 'list'}), name="park-history"), - + path("/history/", ParkHistoryViewSet.as_view({"get": "list"}), name="park-history"), # Roadtrip API endpoints path("roadtrip/create/", CreateTripView.as_view(), name="roadtrip-create"), path("roadtrip/find-along-route/", FindParksAlongRouteView.as_view(), name="roadtrip-find"), path("roadtrip/geocode/", GeocodeAddressView.as_view(), name="roadtrip-geocode"), path("roadtrip/distance/", ParkDistanceCalculatorView.as_view(), name="roadtrip-distance"), - # Operator endpoints path("operators/", OperatorListAPIView.as_view(), name="operator-list"), - # Location search endpoints path("search/location/", location_search, name="location-search"), path("search/reverse-geocode/", reverse_geocode, name="reverse-geocode"), diff --git a/backend/apps/api/v1/parks/views.py b/backend/apps/api/v1/parks/views.py index 0d3a32dc..1754d04e 100644 --- a/backend/apps/api/v1/parks/views.py +++ b/backend/apps/api/v1/parks/views.py @@ -134,9 +134,7 @@ class ParkPhotoViewSet(ModelViewSet): def get_queryset(self): # type: ignore[override] """Get photos for the current park with optimized queries.""" - queryset = ParkPhoto.objects.select_related( - "park", "park__operator", "uploaded_by" - ) + queryset = ParkPhoto.objects.select_related("park", "park__operator", "uploaded_by") # If park_pk is provided in URL kwargs, filter by park # If park_pk is provided in URL kwargs, filter by park @@ -172,7 +170,7 @@ class ParkPhotoViewSet(ModelViewSet): # Use real park ID park_id = park.id except Park.DoesNotExist: - raise ValidationError("Park not found") + raise ValidationError("Park not found") from None try: # Use the service to create the photo with proper business logic @@ -188,48 +186,38 @@ class ParkPhotoViewSet(ModelViewSet): except (ValidationException, ValidationError) as e: logger.warning(f"Validation error creating park photo: {e}") - raise ValidationError(str(e)) + raise ValidationError(str(e)) from None except ServiceError as e: logger.error(f"Service error creating park photo: {e}") - raise ValidationError(f"Failed to create photo: {str(e)}") + raise ValidationError(f"Failed to create photo: {str(e)}") from None def perform_update(self, serializer): """Update park photo with permission checking.""" instance = self.get_object() # Check permissions - allow owner or staff - if not ( - self.request.user == instance.uploaded_by - or cast(Any, self.request.user).is_staff - ): + if not (self.request.user == instance.uploaded_by or cast(Any, self.request.user).is_staff): raise PermissionDenied("You can only edit your own photos or be an admin.") # Handle primary photo logic using service if serializer.validated_data.get("is_primary", False): try: - ParkMediaService().set_primary_photo( - park_id=instance.park_id, photo_id=instance.id - ) + ParkMediaService().set_primary_photo(park_id=instance.park_id, photo_id=instance.id) # Remove is_primary from validated_data since service handles it if "is_primary" in serializer.validated_data: del serializer.validated_data["is_primary"] except (ValidationException, ValidationError) as e: logger.warning(f"Validation error setting primary photo: {e}") - raise ValidationError(str(e)) + raise ValidationError(str(e)) from None except ServiceError as e: logger.error(f"Service error setting primary photo: {e}") - raise ValidationError(f"Failed to set primary photo: {str(e)}") + raise ValidationError(f"Failed to set primary photo: {str(e)}") from None def perform_destroy(self, instance): """Delete park photo with permission checking.""" # Check permissions - allow owner or staff - if not ( - self.request.user == instance.uploaded_by - or cast(Any, self.request.user).is_staff - ): - raise PermissionDenied( - "You can only delete your own photos or be an admin." - ) + if not (self.request.user == instance.uploaded_by or cast(Any, self.request.user).is_staff): + raise PermissionDenied("You can only delete your own photos or be an admin.") # Delete from Cloudflare first if image exists if instance.image: @@ -240,9 +228,7 @@ class ParkPhotoViewSet(ModelViewSet): service = CloudflareImagesService() service.delete_image(instance.image) - logger.info( - f"Successfully deleted park photo from Cloudflare: {instance.image.cloudflare_id}" - ) + logger.info(f"Successfully deleted park photo from Cloudflare: {instance.image.cloudflare_id}") except ImportError: logger.warning("CloudflareImagesService not available") except ServiceError as e: @@ -250,12 +236,10 @@ class ParkPhotoViewSet(ModelViewSet): # Continue with database deletion even if Cloudflare deletion fails try: - ParkMediaService().delete_photo( - instance.id, deleted_by=cast(UserModel, self.request.user) - ) + ParkMediaService().delete_photo(instance.id, deleted_by=cast(UserModel, self.request.user)) except ServiceError as e: logger.error(f"Service error deleting park photo: {e}") - raise ValidationError(f"Failed to delete photo: {str(e)}") + raise ValidationError(f"Failed to delete photo: {str(e)}") from None @extend_schema( summary="Set photo as primary", @@ -275,14 +259,10 @@ class ParkPhotoViewSet(ModelViewSet): # Check permissions - allow owner or staff if not (request.user == photo.uploaded_by or cast(Any, request.user).is_staff): - raise PermissionDenied( - "You can only modify your own photos or be an admin." - ) + raise PermissionDenied("You can only modify your own photos or be an admin.") try: - ParkMediaService().set_primary_photo( - park_id=photo.park_id, photo_id=photo.id - ) + ParkMediaService().set_primary_photo(park_id=photo.park_id, photo_id=photo.id) # Refresh the photo instance photo.refresh_from_db() @@ -290,7 +270,7 @@ class ParkPhotoViewSet(ModelViewSet): return Response( { - "message": "Photo set as primary successfully", + "detail": "Photo set as primary successfully", "photo": serializer.data, }, status=status.HTTP_200_OK, @@ -337,7 +317,7 @@ class ParkPhotoViewSet(ModelViewSet): if photo_ids is None or approve is None: return Response( - {"error": "Missing required fields: photo_ids and/or approve."}, + {"detail": "Missing required fields: photo_ids and/or approve."}, status=status.HTTP_400_BAD_REQUEST, ) @@ -354,7 +334,7 @@ class ParkPhotoViewSet(ModelViewSet): return Response( { - "message": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos", + "detail": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos", "updated_count": updated_count, }, status=status.HTTP_200_OK, @@ -430,19 +410,14 @@ class ParkPhotoViewSet(ModelViewSet): def set_primary_legacy(self, request, id=None): """Legacy set primary action for backwards compatibility.""" photo = self.get_object() - if not ( - request.user == photo.uploaded_by - or request.user.has_perm("parks.change_parkphoto") - ): + if not (request.user == photo.uploaded_by or request.user.has_perm("parks.change_parkphoto")): return Response( - {"error": "You do not have permission to edit photos for this park."}, + {"detail": "You do not have permission to edit photos for this park."}, status=status.HTTP_403_FORBIDDEN, ) try: - ParkMediaService().set_primary_photo( - park_id=photo.park_id, photo_id=photo.id - ) - return Response({"message": "Photo set as primary successfully."}) + ParkMediaService().set_primary_photo(park_id=photo.park_id, photo_id=photo.id) + return Response({"detail": "Photo set as primary successfully."}) except (ValidationException, ValidationError) as e: logger.warning(f"Validation error in set_primary_photo: {str(e)}") return ErrorHandler.handle_api_error( @@ -475,7 +450,7 @@ class ParkPhotoViewSet(ModelViewSet): park_pk = self.kwargs.get("park_pk") if not park_pk: return Response( - {"error": "Park ID is required"}, + {"detail": "Park ID is required"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -483,14 +458,14 @@ class ParkPhotoViewSet(ModelViewSet): park = Park.objects.get(pk=park_pk) if str(park_pk).isdigit() else Park.objects.get(slug=park_pk) except Park.DoesNotExist: return Response( - {"error": "Park not found"}, + {"detail": "Park not found"}, status=status.HTTP_404_NOT_FOUND, ) cloudflare_image_id = request.data.get("cloudflare_image_id") if not cloudflare_image_id: return Response( - {"error": "cloudflare_image_id is required"}, + {"detail": "cloudflare_image_id is required"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -515,18 +490,14 @@ class ParkPhotoViewSet(ModelViewSet): # Try to find existing CloudflareImage record by cloudflare_id cloudflare_image = None try: - cloudflare_image = CloudflareImage.objects.get( - cloudflare_id=cloudflare_image_id - ) + cloudflare_image = CloudflareImage.objects.get(cloudflare_id=cloudflare_image_id) # Update existing record with latest data from Cloudflare cloudflare_image.status = "uploaded" cloudflare_image.uploaded_at = timezone.now() cloudflare_image.metadata = image_data.get("meta", {}) # Extract variants from nested result structure - cloudflare_image.variants = image_data.get("result", {}).get( - "variants", [] - ) + cloudflare_image.variants = image_data.get("result", {}).get("variants", []) cloudflare_image.cloudflare_metadata = image_data cloudflare_image.width = image_data.get("width") cloudflare_image.height = image_data.get("height") @@ -540,8 +511,7 @@ class ParkPhotoViewSet(ModelViewSet): user=request.user, status="uploaded", upload_url="", # Not needed for uploaded images - expires_at=timezone.now() - + timezone.timedelta(days=365), # Set far future expiry + expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry uploaded_at=timezone.now(), metadata=image_data.get("meta", {}), # Extract variants from nested result structure @@ -567,9 +537,7 @@ class ParkPhotoViewSet(ModelViewSet): # Handle primary photo logic if requested if request.data.get("is_primary", False): try: - ParkMediaService().set_primary_photo( - park_id=park.id, photo_id=photo.id - ) + ParkMediaService().set_primary_photo(park_id=park.id, photo_id=photo.id) except ServiceError as e: logger.error(f"Error setting primary photo: {e}") # Don't fail the entire operation, just log the error @@ -624,12 +592,8 @@ class ParkPhotoViewSet(ModelViewSet): OpenApiTypes.STR, description="Filter by state (comma-separated for multiple)", ), - OpenApiParameter( - "opening_year_min", OpenApiTypes.INT, description="Minimum opening year" - ), - OpenApiParameter( - "opening_year_max", OpenApiTypes.INT, description="Maximum opening year" - ), + OpenApiParameter("opening_year_min", OpenApiTypes.INT, description="Minimum opening year"), + OpenApiParameter("opening_year_max", OpenApiTypes.INT, description="Maximum opening year"), OpenApiParameter( "size_min", OpenApiTypes.NUMBER, @@ -640,18 +604,10 @@ class ParkPhotoViewSet(ModelViewSet): OpenApiTypes.NUMBER, description="Maximum park size in acres", ), - OpenApiParameter( - "rating_min", OpenApiTypes.NUMBER, description="Minimum average rating" - ), - OpenApiParameter( - "rating_max", OpenApiTypes.NUMBER, description="Maximum average rating" - ), - OpenApiParameter( - "ride_count_min", OpenApiTypes.INT, description="Minimum ride count" - ), - OpenApiParameter( - "ride_count_max", OpenApiTypes.INT, description="Maximum ride count" - ), + OpenApiParameter("rating_min", OpenApiTypes.NUMBER, description="Minimum average rating"), + OpenApiParameter("rating_max", OpenApiTypes.NUMBER, description="Maximum average rating"), + OpenApiParameter("ride_count_min", OpenApiTypes.INT, description="Minimum ride count"), + OpenApiParameter("ride_count_max", OpenApiTypes.INT, description="Maximum ride count"), OpenApiParameter( "coaster_count_min", OpenApiTypes.INT, @@ -688,9 +644,7 @@ class ParkPhotoViewSet(ModelViewSet): "properties": { "parks": { "type": "array", - "items": { - "$ref": "#/components/schemas/HybridParkSerializer" - }, + "items": {"$ref": "#/components/schemas/HybridParkSerializer"}, }, "total_count": {"type": "integer"}, "strategy": { @@ -808,7 +762,7 @@ class HybridParkAPIView(APIView): for param in int_params: value = query_params.get(param) if value: - try: + try: # noqa: SIM105 filters[param] = int(value) except ValueError: pass # Skip invalid integer values @@ -818,7 +772,7 @@ class HybridParkAPIView(APIView): for param in float_params: value = query_params.get(param) if value: - try: + try: # noqa: SIM105 filters[param] = float(value) except ValueError: pass # Skip invalid float values diff --git a/backend/apps/api/v1/responses.py b/backend/apps/api/v1/responses.py new file mode 100644 index 00000000..a17ca22d --- /dev/null +++ b/backend/apps/api/v1/responses.py @@ -0,0 +1,167 @@ +""" +Standardized API response helpers for ThrillWiki. + +This module provides consistent response formatting across all API endpoints: + +Success responses: +- Action completed: {"detail": "Success message"} +- With data: {"detail": "...", "data": {...}} + +Error responses: +- Validation: {"field": ["error"]} (DRF default) +- Application: {"detail": "Error message", "code": "ERROR_CODE"} + +Usage: + from apps.api.v1.responses import success_response, error_response + + # Success + return success_response("Avatar saved successfully") + + # Error + return error_response("User not found", code="NOT_FOUND", status_code=404) +""" + +from rest_framework import status +from rest_framework.response import Response + + +# Standard error codes for machine-readable error handling +class ErrorCodes: + """Standard error codes for API responses.""" + + # Authentication / Authorization + UNAUTHORIZED = "UNAUTHORIZED" + FORBIDDEN = "FORBIDDEN" + INVALID_CREDENTIALS = "INVALID_CREDENTIALS" + TOKEN_EXPIRED = "TOKEN_EXPIRED" + TOKEN_INVALID = "TOKEN_INVALID" + + # Resource errors + NOT_FOUND = "NOT_FOUND" + ALREADY_EXISTS = "ALREADY_EXISTS" + CONFLICT = "CONFLICT" + + # Validation errors + VALIDATION_ERROR = "VALIDATION_ERROR" + INVALID_INPUT = "INVALID_INPUT" + MISSING_FIELD = "MISSING_FIELD" + + # Operation errors + OPERATION_FAILED = "OPERATION_FAILED" + PERMISSION_DENIED = "PERMISSION_DENIED" + RATE_LIMITED = "RATE_LIMITED" + + # User-specific errors + USER_NOT_FOUND = "USER_NOT_FOUND" + USER_INACTIVE = "USER_INACTIVE" + USER_BANNED = "USER_BANNED" + CANNOT_DELETE_SUPERUSER = "CANNOT_DELETE_SUPERUSER" + CANNOT_DELETE_SELF = "CANNOT_DELETE_SELF" + + # Verification errors + VERIFICATION_EXPIRED = "VERIFICATION_EXPIRED" + VERIFICATION_INVALID = "VERIFICATION_INVALID" + ALREADY_VERIFIED = "ALREADY_VERIFIED" + + # External service errors + EXTERNAL_SERVICE_ERROR = "EXTERNAL_SERVICE_ERROR" + CLOUDFLARE_ERROR = "CLOUDFLARE_ERROR" + + +def success_response( + detail: str, + data: dict | None = None, + status_code: int = status.HTTP_200_OK, +) -> Response: + """ + Create a standardized success response. + + Args: + detail: Human-readable success message + data: Optional additional data to include + status_code: HTTP status code (default 200) + + Returns: + DRF Response object + + Example: + return success_response("Avatar saved successfully") + return success_response("User created", data={"id": user.id}, status_code=201) + """ + response_data = {"detail": detail} + if data: + response_data.update(data) + return Response(response_data, status=status_code) + + +def error_response( + detail: str, + code: str | None = None, + status_code: int = status.HTTP_400_BAD_REQUEST, + extra: dict | None = None, +) -> Response: + """ + Create a standardized error response. + + Args: + detail: Human-readable error message + code: Machine-readable error code from ErrorCodes + status_code: HTTP status code (default 400) + extra: Optional additional data to include + + Returns: + DRF Response object + + Example: + return error_response("User not found", code=ErrorCodes.NOT_FOUND, status_code=404) + return error_response("Invalid input", code=ErrorCodes.VALIDATION_ERROR) + """ + response_data = {"detail": detail} + if code: + response_data["code"] = code + if extra: + response_data.update(extra) + return Response(response_data, status=status_code) + + +def created_response(detail: str, data: dict | None = None) -> Response: + """Convenience wrapper for 201 Created responses.""" + return success_response(detail, data=data, status_code=status.HTTP_201_CREATED) + + +def not_found_response(detail: str = "Resource not found") -> Response: + """Convenience wrapper for 404 Not Found responses.""" + return error_response( + detail, + code=ErrorCodes.NOT_FOUND, + status_code=status.HTTP_404_NOT_FOUND, + ) + + +def forbidden_response(detail: str = "Permission denied") -> Response: + """Convenience wrapper for 403 Forbidden responses.""" + return error_response( + detail, + code=ErrorCodes.FORBIDDEN, + status_code=status.HTTP_403_FORBIDDEN, + ) + + +def unauthorized_response(detail: str = "Authentication required") -> Response: + """Convenience wrapper for 401 Unauthorized responses.""" + return error_response( + detail, + code=ErrorCodes.UNAUTHORIZED, + status_code=status.HTTP_401_UNAUTHORIZED, + ) + + +__all__ = [ + "ErrorCodes", + "success_response", + "error_response", + "created_response", + "not_found_response", + "forbidden_response", + "unauthorized_response", +] diff --git a/backend/apps/api/v1/rides/company_views.py b/backend/apps/api/v1/rides/company_views.py index d0c3b010..6b819720 100644 --- a/backend/apps/api/v1/rides/company_views.py +++ b/backend/apps/api/v1/rides/company_views.py @@ -24,6 +24,7 @@ from apps.api.v1.serializers.companies import ( try: from apps.rides.models.company import Company + MODELS_AVAILABLE = True except ImportError: Company = None @@ -65,9 +66,7 @@ class CompanyListCreateAPIView(APIView): # Search filter search = request.query_params.get("search", "") if search: - qs = qs.filter( - Q(name__icontains=search) | Q(description__icontains=search) - ) + qs = qs.filter(Q(name__icontains=search) | Q(description__icontains=search)) # Role filter role = request.query_params.get("role", "") @@ -120,7 +119,7 @@ class CompanyDetailAPIView(APIView): try: return Company.objects.get(pk=pk) except Company.DoesNotExist: - raise NotFound("Company not found") + raise NotFound("Company not found") from None @extend_schema( summary="Retrieve a company", diff --git a/backend/apps/api/v1/rides/manufacturers/views.py b/backend/apps/api/v1/rides/manufacturers/views.py index e06317f3..ff443f06 100644 --- a/backend/apps/api/v1/rides/manufacturers/views.py +++ b/backend/apps/api/v1/rides/manufacturers/views.py @@ -93,18 +93,10 @@ class RideModelListCreateAPIView(APIView): type=OpenApiTypes.STR, required=True, ), - OpenApiParameter( - name="page", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT - ), - OpenApiParameter( - name="page_size", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT - ), - OpenApiParameter( - name="search", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR - ), - OpenApiParameter( - name="category", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR - ), + OpenApiParameter(name="page", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT), + OpenApiParameter(name="page_size", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT), + OpenApiParameter(name="search", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR), + OpenApiParameter(name="category", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR), OpenApiParameter( name="target_market", location=OpenApiParameter.QUERY, @@ -134,7 +126,7 @@ class RideModelListCreateAPIView(APIView): try: manufacturer = Company.objects.get(slug=manufacturer_slug) except Company.DoesNotExist: - raise NotFound("Manufacturer not found") + raise NotFound("Manufacturer not found") from None qs = ( RideModel.objects.filter(manufacturer=manufacturer) @@ -176,13 +168,9 @@ class RideModelListCreateAPIView(APIView): # Year filters if filters.get("first_installation_year_min"): - qs = qs.filter( - first_installation_year__gte=filters["first_installation_year_min"] - ) + qs = qs.filter(first_installation_year__gte=filters["first_installation_year_min"]) if filters.get("first_installation_year_max"): - qs = qs.filter( - first_installation_year__lte=filters["first_installation_year_max"] - ) + qs = qs.filter(first_installation_year__lte=filters["first_installation_year_max"]) # Installation count filter if filters.get("min_installations"): @@ -190,23 +178,15 @@ class RideModelListCreateAPIView(APIView): # Height filters if filters.get("min_height_ft"): - qs = qs.filter( - typical_height_range_max_ft__gte=filters["min_height_ft"] - ) + qs = qs.filter(typical_height_range_max_ft__gte=filters["min_height_ft"]) if filters.get("max_height_ft"): - qs = qs.filter( - typical_height_range_min_ft__lte=filters["max_height_ft"] - ) + qs = qs.filter(typical_height_range_min_ft__lte=filters["max_height_ft"]) # Speed filters if filters.get("min_speed_mph"): - qs = qs.filter( - typical_speed_range_max_mph__gte=filters["min_speed_mph"] - ) + qs = qs.filter(typical_speed_range_max_mph__gte=filters["min_speed_mph"]) if filters.get("max_speed_mph"): - qs = qs.filter( - typical_speed_range_min_mph__lte=filters["max_speed_mph"] - ) + qs = qs.filter(typical_speed_range_min_mph__lte=filters["max_speed_mph"]) # Ordering ordering = filters.get("ordering", "manufacturer__name,name") @@ -216,9 +196,7 @@ class RideModelListCreateAPIView(APIView): paginator = StandardResultsSetPagination() page = paginator.paginate_queryset(qs, request) - serializer = RideModelListOutputSerializer( - page, many=True, context={"request": request} - ) + serializer = RideModelListOutputSerializer(page, many=True, context={"request": request}) return paginator.get_paginated_response(serializer.data) @extend_schema( @@ -240,9 +218,7 @@ class RideModelListCreateAPIView(APIView): """Create a new ride model for a specific manufacturer.""" if not MODELS_AVAILABLE: return Response( - { - "detail": "Ride model creation is not available because domain models are not imported." - }, + {"detail": "Ride model creation is not available because domain models are not imported."}, status=status.HTTP_501_NOT_IMPLEMENTED, ) @@ -250,7 +226,7 @@ class RideModelListCreateAPIView(APIView): try: manufacturer = Company.objects.get(slug=manufacturer_slug) except Company.DoesNotExist: - raise NotFound("Manufacturer not found") + raise NotFound("Manufacturer not found") from None serializer_in = RideModelCreateInputSerializer(data=request.data) serializer_in.is_valid(raise_exception=True) @@ -279,18 +255,14 @@ class RideModelListCreateAPIView(APIView): target_market=validated.get("target_market", ""), ) - out_serializer = RideModelDetailOutputSerializer( - ride_model, context={"request": request} - ) + out_serializer = RideModelDetailOutputSerializer(ride_model, context={"request": request}) return Response(out_serializer.data, status=status.HTTP_201_CREATED) class RideModelDetailAPIView(APIView): permission_classes = [permissions.AllowAny] - def _get_ride_model_or_404( - self, manufacturer_slug: str, ride_model_slug: str - ) -> Any: + def _get_ride_model_or_404(self, manufacturer_slug: str, ride_model_slug: str) -> Any: if not MODELS_AVAILABLE: raise NotFound("Ride model models not available") try: @@ -300,7 +272,7 @@ class RideModelDetailAPIView(APIView): .get(manufacturer__slug=manufacturer_slug, slug=ride_model_slug) ) except RideModel.DoesNotExist: - raise NotFound("Ride model not found") + raise NotFound("Ride model not found") from None @extend_schema( summary="Retrieve a ride model", @@ -322,13 +294,9 @@ class RideModelDetailAPIView(APIView): responses={200: RideModelDetailOutputSerializer()}, tags=["Ride Models"], ) - def get( - self, request: Request, manufacturer_slug: str, ride_model_slug: str - ) -> Response: + def get(self, request: Request, manufacturer_slug: str, ride_model_slug: str) -> Response: ride_model = self._get_ride_model_or_404(manufacturer_slug, ride_model_slug) - serializer = RideModelDetailOutputSerializer( - ride_model, context={"request": request} - ) + serializer = RideModelDetailOutputSerializer(ride_model, context={"request": request}) return Response(serializer.data) @extend_schema( @@ -352,9 +320,7 @@ class RideModelDetailAPIView(APIView): responses={200: RideModelDetailOutputSerializer()}, tags=["Ride Models"], ) - def patch( - self, request: Request, manufacturer_slug: str, ride_model_slug: str - ) -> Response: + def patch(self, request: Request, manufacturer_slug: str, ride_model_slug: str) -> Response: ride_model = self._get_ride_model_or_404(manufacturer_slug, ride_model_slug) serializer_in = RideModelUpdateInputSerializer(data=request.data, partial=True) serializer_in.is_valid(raise_exception=True) @@ -366,20 +332,16 @@ class RideModelDetailAPIView(APIView): manufacturer = Company.objects.get(id=value) ride_model.manufacturer = manufacturer except Company.DoesNotExist: - raise ValidationError({"manufacturer_id": "Manufacturer not found"}) + raise ValidationError({"manufacturer_id": "Manufacturer not found"}) from None else: setattr(ride_model, field, value) ride_model.save() - serializer = RideModelDetailOutputSerializer( - ride_model, context={"request": request} - ) + serializer = RideModelDetailOutputSerializer(ride_model, context={"request": request}) return Response(serializer.data) - def put( - self, request: Request, manufacturer_slug: str, ride_model_slug: str - ) -> Response: + def put(self, request: Request, manufacturer_slug: str, ride_model_slug: str) -> Response: # Full replace - reuse patch behavior for simplicity return self.patch(request, manufacturer_slug, ride_model_slug) @@ -403,9 +365,7 @@ class RideModelDetailAPIView(APIView): responses={204: None}, tags=["Ride Models"], ) - def delete( - self, request: Request, manufacturer_slug: str, ride_model_slug: str - ) -> Response: + def delete(self, request: Request, manufacturer_slug: str, ride_model_slug: str) -> Response: ride_model = self._get_ride_model_or_404(manufacturer_slug, ride_model_slug) ride_model.delete() return Response(status=status.HTTP_204_NO_CONTENT) @@ -449,9 +409,7 @@ class RideModelSearchAPIView(APIView): ) qs = RideModel.objects.filter( - Q(name__icontains=q) - | Q(description__icontains=q) - | Q(manufacturer__name__icontains=q) + Q(name__icontains=q) | Q(description__icontains=q) | Q(manufacturer__name__icontains=q) ).select_related("manufacturer")[:20] results = [ @@ -491,8 +449,8 @@ class RideModelFilterOptionsAPIView(APIView): # Use Rich Choice Objects for fallback options try: # Get rich choice objects from registry - categories = get_choices('categories', 'rides') - target_markets = get_choices('target_markets', 'rides') + categories = get_choices("categories", "rides") + target_markets = get_choices("target_markets", "rides") # Convert Rich Choice Objects to frontend format with metadata categories_data = [ @@ -500,10 +458,10 @@ class RideModelFilterOptionsAPIView(APIView): "value": choice.value, "label": choice.label, "description": choice.description, - "color": choice.metadata.get('color'), - "icon": choice.metadata.get('icon'), - "css_class": choice.metadata.get('css_class'), - "sort_order": choice.metadata.get('sort_order', 0) + "color": choice.metadata.get("color"), + "icon": choice.metadata.get("icon"), + "css_class": choice.metadata.get("css_class"), + "sort_order": choice.metadata.get("sort_order", 0), } for choice in categories ] @@ -513,10 +471,10 @@ class RideModelFilterOptionsAPIView(APIView): "value": choice.value, "label": choice.label, "description": choice.description, - "color": choice.metadata.get('color'), - "icon": choice.metadata.get('icon'), - "css_class": choice.metadata.get('css_class'), - "sort_order": choice.metadata.get('sort_order', 0) + "color": choice.metadata.get("color"), + "icon": choice.metadata.get("icon"), + "css_class": choice.metadata.get("css_class"), + "sort_order": choice.metadata.get("sort_order", 0), } for choice in target_markets ] @@ -524,25 +482,173 @@ class RideModelFilterOptionsAPIView(APIView): except Exception: # Ultimate fallback with basic structure categories_data = [ - {"value": "RC", "label": "Roller Coaster", "description": "High-speed thrill rides with tracks", "color": "red", "icon": "roller-coaster", "css_class": "bg-red-100 text-red-800", "sort_order": 1}, - {"value": "DR", "label": "Dark Ride", "description": "Indoor themed experiences", "color": "purple", "icon": "dark-ride", "css_class": "bg-purple-100 text-purple-800", "sort_order": 2}, - {"value": "FR", "label": "Flat Ride", "description": "Spinning and rotating attractions", "color": "blue", "icon": "flat-ride", "css_class": "bg-blue-100 text-blue-800", "sort_order": 3}, - {"value": "WR", "label": "Water Ride", "description": "Water-based attractions and slides", "color": "cyan", "icon": "water-ride", "css_class": "bg-cyan-100 text-cyan-800", "sort_order": 4}, - {"value": "TR", "label": "Transport", "description": "Transportation systems within parks", "color": "green", "icon": "transport", "css_class": "bg-green-100 text-green-800", "sort_order": 5}, - {"value": "OT", "label": "Other", "description": "Miscellaneous attractions", "color": "gray", "icon": "other", "css_class": "bg-gray-100 text-gray-800", "sort_order": 6}, + { + "value": "RC", + "label": "Roller Coaster", + "description": "High-speed thrill rides with tracks", + "color": "red", + "icon": "roller-coaster", + "css_class": "bg-red-100 text-red-800", + "sort_order": 1, + }, + { + "value": "DR", + "label": "Dark Ride", + "description": "Indoor themed experiences", + "color": "purple", + "icon": "dark-ride", + "css_class": "bg-purple-100 text-purple-800", + "sort_order": 2, + }, + { + "value": "FR", + "label": "Flat Ride", + "description": "Spinning and rotating attractions", + "color": "blue", + "icon": "flat-ride", + "css_class": "bg-blue-100 text-blue-800", + "sort_order": 3, + }, + { + "value": "WR", + "label": "Water Ride", + "description": "Water-based attractions and slides", + "color": "cyan", + "icon": "water-ride", + "css_class": "bg-cyan-100 text-cyan-800", + "sort_order": 4, + }, + { + "value": "TR", + "label": "Transport", + "description": "Transportation systems within parks", + "color": "green", + "icon": "transport", + "css_class": "bg-green-100 text-green-800", + "sort_order": 5, + }, + { + "value": "OT", + "label": "Other", + "description": "Miscellaneous attractions", + "color": "gray", + "icon": "other", + "css_class": "bg-gray-100 text-gray-800", + "sort_order": 6, + }, ] target_markets_data = [ - {"value": "FAMILY", "label": "Family", "description": "Suitable for all family members", "color": "green", "icon": "family", "css_class": "bg-green-100 text-green-800", "sort_order": 1}, - {"value": "THRILL", "label": "Thrill", "description": "High-intensity thrill experience", "color": "orange", "icon": "thrill", "css_class": "bg-orange-100 text-orange-800", "sort_order": 2}, - {"value": "EXTREME", "label": "Extreme", "description": "Maximum intensity experience", "color": "red", "icon": "extreme", "css_class": "bg-red-100 text-red-800", "sort_order": 3}, - {"value": "KIDDIE", "label": "Kiddie", "description": "Designed for young children", "color": "pink", "icon": "kiddie", "css_class": "bg-pink-100 text-pink-800", "sort_order": 4}, - {"value": "ALL_AGES", "label": "All Ages", "description": "Enjoyable for all age groups", "color": "blue", "icon": "all-ages", "css_class": "bg-blue-100 text-blue-800", "sort_order": 5}, + { + "value": "FAMILY", + "label": "Family", + "description": "Suitable for all family members", + "color": "green", + "icon": "family", + "css_class": "bg-green-100 text-green-800", + "sort_order": 1, + }, + { + "value": "THRILL", + "label": "Thrill", + "description": "High-intensity thrill experience", + "color": "orange", + "icon": "thrill", + "css_class": "bg-orange-100 text-orange-800", + "sort_order": 2, + }, + { + "value": "EXTREME", + "label": "Extreme", + "description": "Maximum intensity experience", + "color": "red", + "icon": "extreme", + "css_class": "bg-red-100 text-red-800", + "sort_order": 3, + }, + { + "value": "KIDDIE", + "label": "Kiddie", + "description": "Designed for young children", + "color": "pink", + "icon": "kiddie", + "css_class": "bg-pink-100 text-pink-800", + "sort_order": 4, + }, + { + "value": "ALL_AGES", + "label": "All Ages", + "description": "Enjoyable for all age groups", + "color": "blue", + "icon": "all-ages", + "css_class": "bg-blue-100 text-blue-800", + "sort_order": 5, + }, ] - return Response({ + return Response( + { + "categories": categories_data, + "target_markets": target_markets_data, + "manufacturers": [{"id": 1, "name": "Bolliger & Mabillard", "slug": "bolliger-mabillard"}], + "ordering_options": [ + {"value": "name", "label": "Name A-Z"}, + {"value": "-name", "label": "Name Z-A"}, + {"value": "manufacturer__name", "label": "Manufacturer A-Z"}, + {"value": "-manufacturer__name", "label": "Manufacturer Z-A"}, + {"value": "first_installation_year", "label": "Oldest First"}, + {"value": "-first_installation_year", "label": "Newest First"}, + {"value": "total_installations", "label": "Fewest Installations"}, + {"value": "-total_installations", "label": "Most Installations"}, + ], + } + ) + + # Get static choice definitions from Rich Choice Objects (primary source) + # Get dynamic data from database queries + + # Get rich choice objects from registry + categories = get_choices("categories", "rides") + target_markets = get_choices("target_markets", "rides") + + # Convert Rich Choice Objects to frontend format with metadata + categories_data = [ + { + "value": choice.value, + "label": choice.label, + "description": choice.description, + "color": choice.metadata.get("color"), + "icon": choice.metadata.get("icon"), + "css_class": choice.metadata.get("css_class"), + "sort_order": choice.metadata.get("sort_order", 0), + } + for choice in categories + ] + + target_markets_data = [ + { + "value": choice.value, + "label": choice.label, + "description": choice.description, + "color": choice.metadata.get("color"), + "icon": choice.metadata.get("icon"), + "css_class": choice.metadata.get("css_class"), + "sort_order": choice.metadata.get("sort_order", 0), + } + for choice in target_markets + ] + + # Get actual data from database + manufacturers = ( + Company.objects.filter(roles__contains=["MANUFACTURER"], ride_models__isnull=False) + .distinct() + .values("id", "name", "slug") + ) + + return Response( + { "categories": categories_data, "target_markets": target_markets_data, - "manufacturers": [{"id": 1, "name": "Bolliger & Mabillard", "slug": "bolliger-mabillard"}], + "manufacturers": list(manufacturers), "ordering_options": [ {"value": "name", "label": "Name A-Z"}, {"value": "-name", "label": "Name Z-A"}, @@ -553,68 +659,9 @@ class RideModelFilterOptionsAPIView(APIView): {"value": "total_installations", "label": "Fewest Installations"}, {"value": "-total_installations", "label": "Most Installations"}, ], - }) - - # Get static choice definitions from Rich Choice Objects (primary source) - # Get dynamic data from database queries - - # Get rich choice objects from registry - categories = get_choices('categories', 'rides') - target_markets = get_choices('target_markets', 'rides') - - # Convert Rich Choice Objects to frontend format with metadata - categories_data = [ - { - "value": choice.value, - "label": choice.label, - "description": choice.description, - "color": choice.metadata.get('color'), - "icon": choice.metadata.get('icon'), - "css_class": choice.metadata.get('css_class'), - "sort_order": choice.metadata.get('sort_order', 0) } - for choice in categories - ] - - target_markets_data = [ - { - "value": choice.value, - "label": choice.label, - "description": choice.description, - "color": choice.metadata.get('color'), - "icon": choice.metadata.get('icon'), - "css_class": choice.metadata.get('css_class'), - "sort_order": choice.metadata.get('sort_order', 0) - } - for choice in target_markets - ] - - # Get actual data from database - manufacturers = ( - Company.objects.filter( - roles__contains=["MANUFACTURER"], ride_models__isnull=False - ) - .distinct() - .values("id", "name", "slug") ) - return Response({ - "categories": categories_data, - "target_markets": target_markets_data, - "manufacturers": list(manufacturers), - "ordering_options": [ - {"value": "name", "label": "Name A-Z"}, - {"value": "-name", "label": "Name Z-A"}, - {"value": "manufacturer__name", "label": "Manufacturer A-Z"}, - {"value": "-manufacturer__name", "label": "Manufacturer Z-A"}, - {"value": "first_installation_year", "label": "Oldest First"}, - {"value": "-first_installation_year", "label": "Newest First"}, - {"value": "total_installations", "label": "Fewest Installations"}, - {"value": "-total_installations", "label": "Most Installations"}, - ], - }) - - # === RIDE MODEL STATISTICS === @@ -646,37 +693,23 @@ class RideModelStatsAPIView(APIView): # Calculate statistics total_models = RideModel.objects.count() - total_installations = ( - RideModel.objects.aggregate(total=Count("rides"))["total"] or 0 - ) + total_installations = RideModel.objects.aggregate(total=Count("rides"))["total"] or 0 active_manufacturers = ( - Company.objects.filter( - roles__contains=["MANUFACTURER"], ride_models__isnull=False - ) - .distinct() - .count() + Company.objects.filter(roles__contains=["MANUFACTURER"], ride_models__isnull=False).distinct().count() ) discontinued_models = RideModel.objects.filter(is_discontinued=True).count() # Category breakdown by_category = {} - category_counts = ( - RideModel.objects.exclude(category="") - .values("category") - .annotate(count=Count("id")) - ) + category_counts = RideModel.objects.exclude(category="").values("category").annotate(count=Count("id")) for item in category_counts: by_category[item["category"]] = item["count"] # Target market breakdown by_target_market = {} - market_counts = ( - RideModel.objects.exclude(target_market="") - .values("target_market") - .annotate(count=Count("id")) - ) + market_counts = RideModel.objects.exclude(target_market="").values("target_market").annotate(count=Count("id")) for item in market_counts: by_target_market[item["target_market"]] = item["count"] @@ -693,9 +726,7 @@ class RideModelStatsAPIView(APIView): # Recent models (last 30 days) thirty_days_ago = timezone.now() - timedelta(days=30) - recent_models = RideModel.objects.filter( - created_at__gte=thirty_days_ago - ).count() + recent_models = RideModel.objects.filter(created_at__gte=thirty_days_ago).count() return Response( { @@ -730,7 +761,7 @@ class RideModelVariantListCreateAPIView(APIView): try: ride_model = RideModel.objects.get(pk=ride_model_pk) except RideModel.DoesNotExist: - raise NotFound("Ride model not found") + raise NotFound("Ride model not found") from None variants = RideModelVariant.objects.filter(ride_model=ride_model) serializer = RideModelVariantOutputSerializer(variants, many=True) @@ -753,7 +784,7 @@ class RideModelVariantListCreateAPIView(APIView): try: ride_model = RideModel.objects.get(pk=ride_model_pk) except RideModel.DoesNotExist: - raise NotFound("Ride model not found") + raise NotFound("Ride model not found") from None # Override ride_model_id in the data data = request.data.copy() @@ -787,7 +818,7 @@ class RideModelVariantDetailAPIView(APIView): try: return RideModelVariant.objects.get(ride_model_id=ride_model_pk, pk=pk) except RideModelVariant.DoesNotExist: - raise NotFound("Variant not found") + raise NotFound("Variant not found") from None @extend_schema( summary="Get a ride model variant", @@ -807,9 +838,7 @@ class RideModelVariantDetailAPIView(APIView): ) def patch(self, request: Request, ride_model_pk: int, pk: int) -> Response: variant = self._get_variant_or_404(ride_model_pk, pk) - serializer_in = RideModelVariantUpdateInputSerializer( - data=request.data, partial=True - ) + serializer_in = RideModelVariantUpdateInputSerializer(data=request.data, partial=True) serializer_in.is_valid(raise_exception=True) for field, value in serializer_in.validated_data.items(): diff --git a/backend/apps/api/v1/rides/photo_views.py b/backend/apps/api/v1/rides/photo_views.py index eb400826..7e11bf22 100644 --- a/backend/apps/api/v1/rides/photo_views.py +++ b/backend/apps/api/v1/rides/photo_views.py @@ -118,9 +118,7 @@ class RidePhotoViewSet(ModelViewSet): def get_queryset(self): # type: ignore[override] """Get photos for the current ride with optimized queries.""" - queryset = RidePhoto.objects.select_related( - "ride", "ride__park", "ride__park__operator", "uploaded_by" - ) + queryset = RidePhoto.objects.select_related("ride", "ride__park", "ride__park__operator", "uploaded_by") # If ride_pk is provided in URL kwargs, filter by ride ride_pk = self.kwargs.get("ride_pk") @@ -149,7 +147,7 @@ class RidePhotoViewSet(ModelViewSet): try: ride = Ride.objects.get(pk=ride_id) except Ride.DoesNotExist: - raise ValidationError("Ride not found") + raise ValidationError("Ride not found") from None try: # Use the service to create the photo with proper business logic @@ -169,17 +167,14 @@ class RidePhotoViewSet(ModelViewSet): except Exception as e: logger.error(f"Error creating ride photo: {e}") - raise ValidationError(f"Failed to create photo: {str(e)}") + raise ValidationError(f"Failed to create photo: {str(e)}") from None def perform_update(self, serializer): """Update ride photo with permission checking.""" instance = self.get_object() # Check permissions - allow owner or staff - if not ( - self.request.user == instance.uploaded_by - or getattr(self.request.user, "is_staff", False) - ): + if not (self.request.user == instance.uploaded_by or getattr(self.request.user, "is_staff", False)): raise PermissionDenied("You can only edit your own photos or be an admin.") # Handle primary photo logic using service @@ -191,39 +186,31 @@ class RidePhotoViewSet(ModelViewSet): del serializer.validated_data["is_primary"] except Exception as e: logger.error(f"Error setting primary photo: {e}") - raise ValidationError(f"Failed to set primary photo: {str(e)}") + raise ValidationError(f"Failed to set primary photo: {str(e)}") from None def perform_destroy(self, instance): """Delete ride photo with permission checking.""" # Check permissions - allow owner or staff - if not ( - self.request.user == instance.uploaded_by - or getattr(self.request.user, "is_staff", False) - ): - raise PermissionDenied( - "You can only delete your own photos or be an admin." - ) + if not (self.request.user == instance.uploaded_by or getattr(self.request.user, "is_staff", False)): + raise PermissionDenied("You can only delete your own photos or be an admin.") try: # Delete from Cloudflare first if image exists if instance.image: try: from django_cloudflareimages_toolkit.services import CloudflareImagesService + service = CloudflareImagesService() service.delete_image(instance.image) - logger.info( - f"Successfully deleted ride photo from Cloudflare: {instance.image.cloudflare_id}") + logger.info(f"Successfully deleted ride photo from Cloudflare: {instance.image.cloudflare_id}") except Exception as e: - logger.error( - f"Failed to delete ride photo from Cloudflare: {str(e)}") + logger.error(f"Failed to delete ride photo from Cloudflare: {str(e)}") # Continue with database deletion even if Cloudflare deletion fails - RideMediaService.delete_photo( - instance, deleted_by=self.request.user # type: ignore - ) + RideMediaService.delete_photo(instance, deleted_by=self.request.user) # type: ignore except Exception as e: logger.error(f"Error deleting ride photo: {e}") - raise ValidationError(f"Failed to delete photo: {str(e)}") + raise ValidationError(f"Failed to delete photo: {str(e)}") from None @extend_schema( summary="Set photo as primary", @@ -242,13 +229,8 @@ class RidePhotoViewSet(ModelViewSet): photo = self.get_object() # Check permissions - allow owner or staff - if not ( - request.user == photo.uploaded_by - or getattr(request.user, "is_staff", False) - ): - raise PermissionDenied( - "You can only modify your own photos or be an admin." - ) + if not (request.user == photo.uploaded_by or getattr(request.user, "is_staff", False)): + raise PermissionDenied("You can only modify your own photos or be an admin.") try: success = RideMediaService.set_primary_photo(ride=photo.ride, photo=photo) @@ -260,21 +242,21 @@ class RidePhotoViewSet(ModelViewSet): return Response( { - "message": "Photo set as primary successfully", + "detail": "Photo set as primary successfully", "photo": serializer.data, }, status=status.HTTP_200_OK, ) else: return Response( - {"error": "Failed to set primary photo"}, + {"detail": "Failed to set primary photo"}, status=status.HTTP_400_BAD_REQUEST, ) except Exception as e: logger.error(f"Error setting primary photo: {e}") return Response( - {"error": f"Failed to set primary photo: {str(e)}"}, + {"detail": f"Failed to set primary photo: {str(e)}"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -305,7 +287,7 @@ class RidePhotoViewSet(ModelViewSet): if photo_ids is None or approve is None: return Response( - {"error": "Missing required fields: photo_ids and/or approve."}, + {"detail": "Missing required fields: photo_ids and/or approve."}, status=status.HTTP_400_BAD_REQUEST, ) @@ -319,7 +301,7 @@ class RidePhotoViewSet(ModelViewSet): return Response( { - "message": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos", + "detail": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos", "updated_count": updated_count, }, status=status.HTTP_200_OK, @@ -328,7 +310,7 @@ class RidePhotoViewSet(ModelViewSet): except Exception as e: logger.error(f"Error in bulk photo approval: {e}") return Response( - {"error": f"Failed to update photos: {str(e)}"}, + {"detail": f"Failed to update photos: {str(e)}"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -352,7 +334,7 @@ class RidePhotoViewSet(ModelViewSet): ride = Ride.objects.get(pk=ride_pk) except Ride.DoesNotExist: return Response( - {"error": "Ride not found."}, + {"detail": "Ride not found."}, status=status.HTTP_404_NOT_FOUND, ) @@ -363,16 +345,10 @@ class RidePhotoViewSet(ModelViewSet): # Global stats across all rides stats = { "total_photos": RidePhoto.objects.count(), - "approved_photos": RidePhoto.objects.filter( - is_approved=True - ).count(), - "pending_photos": RidePhoto.objects.filter( - is_approved=False - ).count(), + "approved_photos": RidePhoto.objects.filter(is_approved=True).count(), + "pending_photos": RidePhoto.objects.filter(is_approved=False).count(), "has_primary": False, # Not applicable for global stats - "recent_uploads": RidePhoto.objects.order_by("-created_at")[ - :5 - ].count(), + "recent_uploads": RidePhoto.objects.order_by("-created_at")[:5].count(), "by_type": {}, } @@ -382,7 +358,7 @@ class RidePhotoViewSet(ModelViewSet): except Exception as e: logger.error(f"Error getting ride photo stats: {e}") return Response( - {"error": f"Failed to get photo statistics: {str(e)}"}, + {"detail": f"Failed to get photo statistics: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) @@ -401,26 +377,23 @@ class RidePhotoViewSet(ModelViewSet): def set_primary_legacy(self, request, id=None): """Legacy set primary action for backwards compatibility.""" photo = self.get_object() - if not ( - request.user == photo.uploaded_by - or request.user.has_perm("rides.change_ridephoto") - ): + if not (request.user == photo.uploaded_by or request.user.has_perm("rides.change_ridephoto")): return Response( - {"error": "You do not have permission to edit photos for this ride."}, + {"detail": "You do not have permission to edit photos for this ride."}, status=status.HTTP_403_FORBIDDEN, ) try: success = RideMediaService.set_primary_photo(ride=photo.ride, photo=photo) if success: - return Response({"message": "Photo set as primary successfully."}) + return Response({"detail": "Photo set as primary successfully."}) else: return Response( - {"error": "Failed to set primary photo"}, + {"detail": "Failed to set primary photo"}, status=status.HTTP_400_BAD_REQUEST, ) except Exception as e: logger.error(f"Error in set_primary_photo: {str(e)}", exc_info=True) - return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST) @extend_schema( summary="Save Cloudflare image as ride photo", @@ -440,7 +413,7 @@ class RidePhotoViewSet(ModelViewSet): ride_pk = self.kwargs.get("ride_pk") if not ride_pk: return Response( - {"error": "Ride ID is required"}, + {"detail": "Ride ID is required"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -448,14 +421,14 @@ class RidePhotoViewSet(ModelViewSet): ride = Ride.objects.get(pk=ride_pk) except Ride.DoesNotExist: return Response( - {"error": "Ride not found"}, + {"detail": "Ride not found"}, status=status.HTTP_404_NOT_FOUND, ) cloudflare_image_id = request.data.get("cloudflare_image_id") if not cloudflare_image_id: return Response( - {"error": "cloudflare_image_id is required"}, + {"detail": "cloudflare_image_id is required"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -473,27 +446,25 @@ class RidePhotoViewSet(ModelViewSet): if not image_data: return Response( - {"error": "Image not found in Cloudflare"}, + {"detail": "Image not found in Cloudflare"}, status=status.HTTP_400_BAD_REQUEST, ) # Try to find existing CloudflareImage record by cloudflare_id cloudflare_image = None try: - cloudflare_image = CloudflareImage.objects.get( - cloudflare_id=cloudflare_image_id) + cloudflare_image = CloudflareImage.objects.get(cloudflare_id=cloudflare_image_id) # Update existing record with latest data from Cloudflare - cloudflare_image.status = 'uploaded' + cloudflare_image.status = "uploaded" cloudflare_image.uploaded_at = timezone.now() - cloudflare_image.metadata = image_data.get('meta', {}) + cloudflare_image.metadata = image_data.get("meta", {}) # Extract variants from nested result structure - cloudflare_image.variants = image_data.get( - 'result', {}).get('variants', []) + cloudflare_image.variants = image_data.get("result", {}).get("variants", []) cloudflare_image.cloudflare_metadata = image_data - cloudflare_image.width = image_data.get('width') - cloudflare_image.height = image_data.get('height') - cloudflare_image.format = image_data.get('format', '') + cloudflare_image.width = image_data.get("width") + cloudflare_image.height = image_data.get("height") + cloudflare_image.format = image_data.get("format", "") cloudflare_image.save() except CloudflareImage.DoesNotExist: @@ -501,24 +472,23 @@ class RidePhotoViewSet(ModelViewSet): cloudflare_image = CloudflareImage.objects.create( cloudflare_id=cloudflare_image_id, user=request.user, - status='uploaded', - upload_url='', # Not needed for uploaded images + status="uploaded", + upload_url="", # Not needed for uploaded images expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry uploaded_at=timezone.now(), - metadata=image_data.get('meta', {}), + metadata=image_data.get("meta", {}), # Extract variants from nested result structure - variants=image_data.get('result', {}).get('variants', []), + variants=image_data.get("result", {}).get("variants", []), cloudflare_metadata=image_data, - width=image_data.get('width'), - height=image_data.get('height'), - format=image_data.get('format', ''), + width=image_data.get("width"), + height=image_data.get("height"), + format=image_data.get("format", ""), ) except Exception as api_error: - logger.error( - f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True) + logger.error(f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True) return Response( - {"error": f"Failed to fetch image from Cloudflare: {str(api_error)}"}, + {"detail": f"Failed to fetch image from Cloudflare: {str(api_error)}"}, status=status.HTTP_400_BAD_REQUEST, ) @@ -548,6 +518,6 @@ class RidePhotoViewSet(ModelViewSet): except Exception as e: logger.error(f"Error saving ride photo: {e}") return Response( - {"error": f"Failed to save photo: {str(e)}"}, + {"detail": f"Failed to save photo: {str(e)}"}, status=status.HTTP_400_BAD_REQUEST, ) diff --git a/backend/apps/api/v1/rides/serializers.py b/backend/apps/api/v1/rides/serializers.py index f2700152..76e32744 100644 --- a/backend/apps/api/v1/rides/serializers.py +++ b/backend/apps/api/v1/rides/serializers.py @@ -52,18 +52,14 @@ from apps.rides.models import Ride, RidePhoto class RidePhotoOutputSerializer(serializers.ModelSerializer): """Output serializer for ride photos with Cloudflare Images support.""" - uploaded_by_username = serializers.CharField( - source="uploaded_by.username", read_only=True - ) + uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True) file_size = serializers.SerializerMethodField() dimensions = serializers.SerializerMethodField() image_url = serializers.SerializerMethodField() image_variants = serializers.SerializerMethodField() - @extend_schema_field( - serializers.IntegerField(allow_null=True, help_text="File size in bytes") - ) + @extend_schema_field(serializers.IntegerField(allow_null=True, help_text="File size in bytes")) def get_file_size(self, obj): """Get file size in bytes.""" return obj.file_size @@ -81,11 +77,7 @@ class RidePhotoOutputSerializer(serializers.ModelSerializer): """Get image dimensions as [width, height].""" return obj.dimensions - @extend_schema_field( - serializers.URLField( - help_text="Full URL to the Cloudflare Images asset", allow_null=True - ) - ) + @extend_schema_field(serializers.URLField(help_text="Full URL to the Cloudflare Images asset", allow_null=True)) def get_image_url(self, obj): """Get the full Cloudflare Images URL.""" if obj.image: @@ -186,9 +178,7 @@ class RidePhotoUpdateInputSerializer(serializers.ModelSerializer): class RidePhotoListOutputSerializer(serializers.ModelSerializer): """Simplified output serializer for ride photo lists.""" - uploaded_by_username = serializers.CharField( - source="uploaded_by.username", read_only=True - ) + uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True) class Meta: model = RidePhoto @@ -208,12 +198,8 @@ class RidePhotoListOutputSerializer(serializers.ModelSerializer): class RidePhotoApprovalInputSerializer(serializers.Serializer): """Input serializer for photo approval operations.""" - photo_ids = serializers.ListField( - child=serializers.IntegerField(), help_text="List of photo IDs to approve" - ) - approve = serializers.BooleanField( - default=True, help_text="Whether to approve (True) or reject (False) the photos" - ) + photo_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of photo IDs to approve") + approve = serializers.BooleanField(default=True, help_text="Whether to approve (True) or reject (False) the photos") class RidePhotoStatsOutputSerializer(serializers.Serializer): @@ -224,9 +210,7 @@ class RidePhotoStatsOutputSerializer(serializers.Serializer): pending_photos = serializers.IntegerField() has_primary = serializers.BooleanField() recent_uploads = serializers.IntegerField() - by_type = serializers.DictField( - child=serializers.IntegerField(), help_text="Photo counts by type" - ) + by_type = serializers.DictField(child=serializers.IntegerField(), help_text="Photo counts by type") class RidePhotoTypeFilterSerializer(serializers.Serializer): @@ -292,8 +276,12 @@ class HybridRideSerializer(serializers.ModelSerializer): ride_model_name = serializers.CharField(source="ride_model.name", read_only=True, allow_null=True) ride_model_slug = serializers.CharField(source="ride_model.slug", read_only=True, allow_null=True) ride_model_category = serializers.CharField(source="ride_model.category", read_only=True, allow_null=True) - ride_model_manufacturer_name = serializers.CharField(source="ride_model.manufacturer.name", read_only=True, allow_null=True) - ride_model_manufacturer_slug = serializers.CharField(source="ride_model.manufacturer.slug", read_only=True, allow_null=True) + ride_model_manufacturer_name = serializers.CharField( + source="ride_model.manufacturer.name", read_only=True, allow_null=True + ) + ride_model_manufacturer_slug = serializers.CharField( + source="ride_model.manufacturer.slug", read_only=True, allow_null=True + ) # Roller coaster stats fields coaster_height_ft = serializers.SerializerMethodField() @@ -323,7 +311,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_park_city(self, obj): """Get city from park location.""" try: - if obj.park and hasattr(obj.park, 'location') and obj.park.location: + if obj.park and hasattr(obj.park, "location") and obj.park.location: return obj.park.location.city return None except AttributeError: @@ -333,7 +321,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_park_state(self, obj): """Get state from park location.""" try: - if obj.park and hasattr(obj.park, 'location') and obj.park.location: + if obj.park and hasattr(obj.park, "location") and obj.park.location: return obj.park.location.state return None except AttributeError: @@ -343,7 +331,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_park_country(self, obj): """Get country from park location.""" try: - if obj.park and hasattr(obj.park, 'location') and obj.park.location: + if obj.park and hasattr(obj.park, "location") and obj.park.location: return obj.park.location.country return None except AttributeError: @@ -353,7 +341,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_coaster_height_ft(self, obj): """Get roller coaster height.""" try: - if hasattr(obj, 'coaster_stats') and obj.coaster_stats: + if hasattr(obj, "coaster_stats") and obj.coaster_stats: return float(obj.coaster_stats.height_ft) if obj.coaster_stats.height_ft else None return None except (AttributeError, TypeError): @@ -363,7 +351,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_coaster_length_ft(self, obj): """Get roller coaster length.""" try: - if hasattr(obj, 'coaster_stats') and obj.coaster_stats: + if hasattr(obj, "coaster_stats") and obj.coaster_stats: return float(obj.coaster_stats.length_ft) if obj.coaster_stats.length_ft else None return None except (AttributeError, TypeError): @@ -373,7 +361,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_coaster_speed_mph(self, obj): """Get roller coaster speed.""" try: - if hasattr(obj, 'coaster_stats') and obj.coaster_stats: + if hasattr(obj, "coaster_stats") and obj.coaster_stats: return float(obj.coaster_stats.speed_mph) if obj.coaster_stats.speed_mph else None return None except (AttributeError, TypeError): @@ -383,7 +371,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_coaster_inversions(self, obj): """Get roller coaster inversions.""" try: - if hasattr(obj, 'coaster_stats') and obj.coaster_stats: + if hasattr(obj, "coaster_stats") and obj.coaster_stats: return obj.coaster_stats.inversions return None except AttributeError: @@ -393,7 +381,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_coaster_ride_time_seconds(self, obj): """Get roller coaster ride time.""" try: - if hasattr(obj, 'coaster_stats') and obj.coaster_stats: + if hasattr(obj, "coaster_stats") and obj.coaster_stats: return obj.coaster_stats.ride_time_seconds return None except AttributeError: @@ -403,7 +391,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_coaster_track_type(self, obj): """Get roller coaster track type.""" try: - if hasattr(obj, 'coaster_stats') and obj.coaster_stats: + if hasattr(obj, "coaster_stats") and obj.coaster_stats: return obj.coaster_stats.track_type return None except AttributeError: @@ -413,7 +401,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_coaster_track_material(self, obj): """Get roller coaster track material.""" try: - if hasattr(obj, 'coaster_stats') and obj.coaster_stats: + if hasattr(obj, "coaster_stats") and obj.coaster_stats: return obj.coaster_stats.track_material return None except AttributeError: @@ -423,7 +411,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_coaster_roller_coaster_type(self, obj): """Get roller coaster type.""" try: - if hasattr(obj, 'coaster_stats') and obj.coaster_stats: + if hasattr(obj, "coaster_stats") and obj.coaster_stats: return obj.coaster_stats.roller_coaster_type return None except AttributeError: @@ -433,7 +421,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_coaster_max_drop_height_ft(self, obj): """Get roller coaster max drop height.""" try: - if hasattr(obj, 'coaster_stats') and obj.coaster_stats: + if hasattr(obj, "coaster_stats") and obj.coaster_stats: return float(obj.coaster_stats.max_drop_height_ft) if obj.coaster_stats.max_drop_height_ft else None return None except (AttributeError, TypeError): @@ -443,7 +431,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_coaster_propulsion_system(self, obj): """Get roller coaster propulsion system.""" try: - if hasattr(obj, 'coaster_stats') and obj.coaster_stats: + if hasattr(obj, "coaster_stats") and obj.coaster_stats: return obj.coaster_stats.propulsion_system return None except AttributeError: @@ -453,7 +441,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_coaster_train_style(self, obj): """Get roller coaster train style.""" try: - if hasattr(obj, 'coaster_stats') and obj.coaster_stats: + if hasattr(obj, "coaster_stats") and obj.coaster_stats: return obj.coaster_stats.train_style return None except AttributeError: @@ -463,7 +451,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_coaster_trains_count(self, obj): """Get roller coaster trains count.""" try: - if hasattr(obj, 'coaster_stats') and obj.coaster_stats: + if hasattr(obj, "coaster_stats") and obj.coaster_stats: return obj.coaster_stats.trains_count return None except AttributeError: @@ -473,7 +461,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_coaster_cars_per_train(self, obj): """Get roller coaster cars per train.""" try: - if hasattr(obj, 'coaster_stats') and obj.coaster_stats: + if hasattr(obj, "coaster_stats") and obj.coaster_stats: return obj.coaster_stats.cars_per_train return None except AttributeError: @@ -483,7 +471,7 @@ class HybridRideSerializer(serializers.ModelSerializer): def get_coaster_seats_per_car(self, obj): """Get roller coaster seats per car.""" try: - if hasattr(obj, 'coaster_stats') and obj.coaster_stats: + if hasattr(obj, "coaster_stats") and obj.coaster_stats: return obj.coaster_stats.seats_per_car return None except AttributeError: @@ -514,44 +502,37 @@ class HybridRideSerializer(serializers.ModelSerializer): "category", "status", "post_closing_status", - # Dates and computed fields "opening_date", "closing_date", "status_since", "opening_year", - # Park fields "park_name", "park_slug", "park_city", "park_state", "park_country", - # Park area fields "park_area_name", "park_area_slug", - # Company fields "manufacturer_name", "manufacturer_slug", "designer_name", "designer_slug", - # Ride model fields "ride_model_name", "ride_model_slug", "ride_model_category", "ride_model_manufacturer_name", "ride_model_manufacturer_slug", - # Ride specifications "min_height_in", "max_height_in", "capacity_per_hour", "ride_duration_seconds", "average_rating", - # Roller coaster stats "coaster_height_ft", "coaster_length_ft", @@ -567,18 +548,14 @@ class HybridRideSerializer(serializers.ModelSerializer): "coaster_trains_count", "coaster_cars_per_train", "coaster_seats_per_car", - # Images "banner_image_url", "card_image_url", - # URLs "url", "park_url", - # Computed fields for filtering "search_text", - # Metadata "created_at", "updated_at", diff --git a/backend/apps/api/v1/rides/urls.py b/backend/apps/api/v1/rides/urls.py index 5ccdd456..d8a516bc 100644 --- a/backend/apps/api/v1/rides/urls.py +++ b/backend/apps/api/v1/rides/urls.py @@ -35,11 +35,9 @@ app_name = "api_v1_rides" urlpatterns = [ # Core list/create endpoints path("", RideListCreateAPIView.as_view(), name="ride-list-create"), - # Hybrid filtering endpoints path("hybrid/", HybridRideAPIView.as_view(), name="ride-hybrid-filtering"), path("hybrid/filter-metadata/", RideFilterMetadataAPIView.as_view(), name="ride-hybrid-filter-metadata"), - # Filter options path("filter-options/", FilterOptionsAPIView.as_view(), name="ride-filter-options"), # Autocomplete / suggestion endpoints @@ -61,7 +59,6 @@ urlpatterns = [ # Manufacturer and Designer endpoints path("manufacturers/", ManufacturerListAPIView.as_view(), name="manufacturer-list"), path("designers/", DesignerListAPIView.as_view(), name="designer-list"), - # Ride model management endpoints - nested under rides/manufacturers path( "manufacturers//", diff --git a/backend/apps/api/v1/rides/views.py b/backend/apps/api/v1/rides/views.py index 183a532c..5f858ad6 100644 --- a/backend/apps/api/v1/rides/views.py +++ b/backend/apps/api/v1/rides/views.py @@ -28,6 +28,7 @@ import logging from typing import Any from django.db import models +from django.db.models import Count from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view from rest_framework import permissions, status @@ -333,9 +334,7 @@ class RideListCreateAPIView(APIView): paginator = StandardResultsSetPagination() page = paginator.paginate_queryset(qs, request) - serializer = RideListOutputSerializer( - page, many=True, context={"request": request} - ) + serializer = RideListOutputSerializer(page, many=True, context={"request": request}) return paginator.get_paginated_response(serializer.data) def _apply_filters(self, qs, params): @@ -567,9 +566,9 @@ class RideListCreateAPIView(APIView): if ordering in valid_orderings: if ordering in ["height_ft", "-height_ft", "speed_mph", "-speed_mph"]: # For coaster stats ordering, we need to join and order by the stats - ordering_field = ordering.replace( - "height_ft", "coaster_stats__height_ft" - ).replace("speed_mph", "coaster_stats__speed_mph") + ordering_field = ordering.replace("height_ft", "coaster_stats__height_ft").replace( + "speed_mph", "coaster_stats__speed_mph" + ) qs = qs.order_by(ordering_field) else: qs = qs.order_by(ordering) @@ -602,7 +601,7 @@ class RideListCreateAPIView(APIView): try: park = Park.objects.get(id=validated["park_id"]) # type: ignore except Park.DoesNotExist: # type: ignore - raise NotFound("Park not found") + raise NotFound("Park not found") from None ride = Ride.objects.create( # type: ignore name=validated["name"], @@ -658,7 +657,7 @@ class RideDetailAPIView(APIView): try: return Ride.objects.select_related("park").get(pk=pk) # type: ignore except Ride.DoesNotExist: # type: ignore - raise NotFound("Ride not found") + raise NotFound("Ride not found") from None @cache_api_response(timeout=1800, key_prefix="ride_detail") def get(self, request: Request, pk: int) -> Response: @@ -672,9 +671,7 @@ class RideDetailAPIView(APIView): serializer_in.is_valid(raise_exception=True) if not MODELS_AVAILABLE: return Response( - { - "detail": "Ride update is not available because domain models are not imported." - }, + {"detail": "Ride update is not available because domain models are not imported."}, status=status.HTTP_501_NOT_IMPLEMENTED, ) @@ -690,7 +687,7 @@ class RideDetailAPIView(APIView): # Use the move_to_park method for proper handling park_change_info = ride.move_to_park(new_park) except Park.DoesNotExist: # type: ignore - raise NotFound("Target park not found") + raise NotFound("Target park not found") from None # Apply other field updates for key, value in validated_data.items(): @@ -715,9 +712,7 @@ class RideDetailAPIView(APIView): def delete(self, request: Request, pk: int) -> Response: if not MODELS_AVAILABLE: return Response( - { - "detail": "Ride delete is not available because domain models are not imported." - }, + {"detail": "Ride delete is not available because domain models are not imported."}, status=status.HTTP_501_NOT_IMPLEMENTED, ) ride = self._get_ride_or_404(pk) @@ -1491,16 +1486,12 @@ class FilterOptionsAPIView(APIView): # Get manufacturers (companies with MANUFACTURER role) manufacturers = list( - Company.objects.filter(roles__contains=["MANUFACTURER"]) - .values("id", "name", "slug") - .order_by("name") + Company.objects.filter(roles__contains=["MANUFACTURER"]).values("id", "name", "slug").order_by("name") ) # Get designers (companies with DESIGNER role) designers = list( - Company.objects.filter(roles__contains=["DESIGNER"]) - .values("id", "name", "slug") - .order_by("name") + Company.objects.filter(roles__contains=["DESIGNER"]).values("id", "name", "slug").order_by("name") ) # Get ride models data from database @@ -1722,11 +1713,7 @@ class FilterOptionsAPIView(APIView): # --- Company search (autocomplete) ----------------------------------------- @extend_schema( summary="Search companies (manufacturers/designers) for autocomplete", - parameters=[ - OpenApiParameter( - name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR - ) - ], + parameters=[OpenApiParameter(name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR)], responses={200: OpenApiTypes.OBJECT}, tags=["Rides"], ) @@ -1753,20 +1740,14 @@ class CompanySearchAPIView(APIView): ) qs = Company.objects.filter(name__icontains=q)[:20] # type: ignore - results = [ - {"id": c.id, "name": c.name, "slug": getattr(c, "slug", "")} for c in qs - ] + results = [{"id": c.id, "name": c.name, "slug": getattr(c, "slug", "")} for c in qs] return Response(results) # --- Ride model search (autocomplete) -------------------------------------- @extend_schema( summary="Search ride models for autocomplete", - parameters=[ - OpenApiParameter( - name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR - ) - ], + parameters=[OpenApiParameter(name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR)], tags=["Rides"], ) class RideModelSearchAPIView(APIView): @@ -1795,21 +1776,14 @@ class RideModelSearchAPIView(APIView): ) qs = RideModel.objects.filter(name__icontains=q)[:20] # type: ignore - results = [ - {"id": m.id, "name": m.name, "category": getattr(m, "category", "")} - for m in qs - ] + results = [{"id": m.id, "name": m.name, "category": getattr(m, "category", "")} for m in qs] return Response(results) # --- Search suggestions ----------------------------------------------------- @extend_schema( summary="Search suggestions for ride search box", - parameters=[ - OpenApiParameter( - name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR - ) - ], + parameters=[OpenApiParameter(name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR)], tags=["Rides"], ) class RideSearchSuggestionsAPIView(APIView): @@ -1827,9 +1801,7 @@ class RideSearchSuggestionsAPIView(APIView): # Very small suggestion implementation: look in ride names if available if MODELS_AVAILABLE and Ride is not None: - qs = Ride.objects.filter(name__icontains=q).values_list("name", flat=True)[ - :10 - ] # type: ignore + qs = Ride.objects.filter(name__icontains=q).values_list("name", flat=True)[:10] # type: ignore return Response([{"suggestion": name} for name in qs]) # Fallback suggestions @@ -1862,7 +1834,7 @@ class RideImageSettingsAPIView(APIView): try: return Ride.objects.get(pk=pk) # type: ignore except Ride.DoesNotExist: # type: ignore - raise NotFound("Ride not found") + raise NotFound("Ride not found") from None def patch(self, request: Request, pk: int) -> Response: """Set banner and card images for the ride.""" @@ -1878,9 +1850,7 @@ class RideImageSettingsAPIView(APIView): ride.save() # Return updated ride data - output_serializer = RideDetailOutputSerializer( - ride, context={"request": request} - ) + output_serializer = RideDetailOutputSerializer(ride, context={"request": request}) return Response(output_serializer.data) @@ -1902,12 +1872,8 @@ class RideImageSettingsAPIView(APIView): OpenApiTypes.STR, description="Filter by ride status (comma-separated for multiple)", ), - OpenApiParameter( - "park_slug", OpenApiTypes.STR, description="Filter by park slug" - ), - OpenApiParameter( - "park_id", OpenApiTypes.INT, description="Filter by park ID" - ), + OpenApiParameter("park_slug", OpenApiTypes.STR, description="Filter by park slug"), + OpenApiParameter("park_id", OpenApiTypes.INT, description="Filter by park ID"), OpenApiParameter( "manufacturer", OpenApiTypes.STR, @@ -1923,18 +1889,10 @@ class RideImageSettingsAPIView(APIView): OpenApiTypes.STR, description="Filter by ride model slug (comma-separated for multiple)", ), - OpenApiParameter( - "opening_year_min", OpenApiTypes.INT, description="Minimum opening year" - ), - OpenApiParameter( - "opening_year_max", OpenApiTypes.INT, description="Maximum opening year" - ), - OpenApiParameter( - "rating_min", OpenApiTypes.NUMBER, description="Minimum average rating" - ), - OpenApiParameter( - "rating_max", OpenApiTypes.NUMBER, description="Maximum average rating" - ), + OpenApiParameter("opening_year_min", OpenApiTypes.INT, description="Minimum opening year"), + OpenApiParameter("opening_year_max", OpenApiTypes.INT, description="Maximum opening year"), + OpenApiParameter("rating_min", OpenApiTypes.NUMBER, description="Minimum average rating"), + OpenApiParameter("rating_max", OpenApiTypes.NUMBER, description="Maximum average rating"), OpenApiParameter( "height_requirement_min", OpenApiTypes.INT, @@ -1945,12 +1903,8 @@ class RideImageSettingsAPIView(APIView): OpenApiTypes.INT, description="Maximum height requirement in inches", ), - OpenApiParameter( - "capacity_min", OpenApiTypes.INT, description="Minimum hourly capacity" - ), - OpenApiParameter( - "capacity_max", OpenApiTypes.INT, description="Maximum hourly capacity" - ), + OpenApiParameter("capacity_min", OpenApiTypes.INT, description="Minimum hourly capacity"), + OpenApiParameter("capacity_max", OpenApiTypes.INT, description="Maximum hourly capacity"), OpenApiParameter( "roller_coaster_type", OpenApiTypes.STR, @@ -2022,9 +1976,7 @@ class RideImageSettingsAPIView(APIView): "properties": { "rides": { "type": "array", - "items": { - "$ref": "#/components/schemas/HybridRideSerializer" - }, + "items": {"$ref": "#/components/schemas/HybridRideSerializer"}, }, "total_count": {"type": "integer"}, "strategy": { @@ -2084,7 +2036,7 @@ class HybridRideAPIView(APIView): data = smart_ride_loader.get_progressive_load(offset, filters) except ValueError: return Response( - {"error": "Invalid offset parameter"}, + {"detail": "Invalid offset parameter"}, status=status.HTTP_400_BAD_REQUEST, ) else: @@ -2109,7 +2061,7 @@ class HybridRideAPIView(APIView): except Exception as e: logger.error(f"Error in HybridRideAPIView: {e}") return Response( - {"error": "Internal server error"}, + {"detail": "Internal server error"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) @@ -2158,7 +2110,7 @@ class HybridRideAPIView(APIView): for param in int_params: value = query_params.get(param) if value: - try: + try: # noqa: SIM105 filters[param] = int(value) except ValueError: pass # Skip invalid integer values @@ -2175,7 +2127,7 @@ class HybridRideAPIView(APIView): for param in float_params: value = query_params.get(param) if value: - try: + try: # noqa: SIM105 filters[param] = float(value) except ValueError: pass # Skip invalid float values @@ -2408,7 +2360,7 @@ class RideFilterMetadataAPIView(APIView): except Exception as e: logger.error(f"Error in RideFilterMetadataAPIView: {e}") return Response( - {"error": "Internal server error"}, + {"detail": "Internal server error"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) @@ -2417,18 +2369,18 @@ class RideFilterMetadataAPIView(APIView): # Reuse the same filter extraction logic view = HybridRideAPIView() return view._extract_filters(query_params) + + # === MANUFACTURER & DESIGNER LISTS === + class BaseCompanyListAPIView(APIView): permission_classes = [permissions.AllowAny] role = None def get(self, request: Request) -> Response: if not MODELS_AVAILABLE: - return Response( - {"detail": "Models not available"}, - status=status.HTTP_501_NOT_IMPLEMENTED - ) + return Response({"detail": "Models not available"}, status=status.HTTP_501_NOT_IMPLEMENTED) companies = ( Company.objects.filter(roles__contains=[self.role]) @@ -2448,10 +2400,8 @@ class BaseCompanyListAPIView(APIView): for c in companies ] - return Response({ - "results": data, - "count": len(data) - }) + return Response({"results": data, "count": len(data)}) + @extend_schema( summary="List manufacturers", @@ -2462,6 +2412,7 @@ class BaseCompanyListAPIView(APIView): class ManufacturerListAPIView(BaseCompanyListAPIView): role = "MANUFACTURER" + @extend_schema( summary="List designers", description="List all companies with DESIGNER role.", diff --git a/backend/apps/api/v1/serializers.py b/backend/apps/api/v1/serializers.py index 5f5a2b89..310fbe9f 100644 --- a/backend/apps/api/v1/serializers.py +++ b/backend/apps/api/v1/serializers.py @@ -49,5 +49,4 @@ __all__ = ( "UserProfileCreateInputSerializer", "UserProfileUpdateInputSerializer", "UserProfileOutputSerializer", - ) diff --git a/backend/apps/api/v1/serializers/__init__.py b/backend/apps/api/v1/serializers/__init__.py index 9b763ae5..2c94b220 100644 --- a/backend/apps/api/v1/serializers/__init__.py +++ b/backend/apps/api/v1/serializers/__init__.py @@ -90,7 +90,6 @@ _ACCOUNTS_SYMBOLS: list[str] = [ "UserProfileOutputSerializer", "UserProfileCreateInputSerializer", "UserProfileUpdateInputSerializer", - "UserOutputSerializer", "LoginInputSerializer", "LoginOutputSerializer", diff --git a/backend/apps/api/v1/serializers/accounts.py b/backend/apps/api/v1/serializers/accounts.py index 1af81654..03de38e6 100644 --- a/backend/apps/api/v1/serializers/accounts.py +++ b/backend/apps/api/v1/serializers/accounts.py @@ -187,6 +187,7 @@ class PublicUserSerializer(serializers.ModelSerializer): Public user serializer for viewing other users' profiles. Only exposes public information. """ + profile = UserProfileSerializer(read_only=True) class Meta: @@ -228,37 +229,21 @@ class UserPreferencesSerializer(serializers.Serializer): """Serializer for user preferences and settings.""" theme_preference = RichChoiceFieldSerializer( - choice_group="theme_preferences", - domain="accounts", - help_text="User's theme preference" - ) - email_notifications = serializers.BooleanField( - default=True, help_text="Whether to receive email notifications" - ) - push_notifications = serializers.BooleanField( - default=False, help_text="Whether to receive push notifications" + choice_group="theme_preferences", domain="accounts", help_text="User's theme preference" ) + email_notifications = serializers.BooleanField(default=True, help_text="Whether to receive email notifications") + push_notifications = serializers.BooleanField(default=False, help_text="Whether to receive push notifications") privacy_level = RichChoiceFieldSerializer( choice_group="privacy_levels", domain="accounts", default="public", help_text="Profile visibility level", ) - show_email = serializers.BooleanField( - default=False, help_text="Whether to show email on profile" - ) - show_real_name = serializers.BooleanField( - default=True, help_text="Whether to show real name on profile" - ) - show_statistics = serializers.BooleanField( - default=True, help_text="Whether to show ride statistics on profile" - ) - allow_friend_requests = serializers.BooleanField( - default=True, help_text="Whether to allow friend requests" - ) - allow_messages = serializers.BooleanField( - default=True, help_text="Whether to allow direct messages" - ) + show_email = serializers.BooleanField(default=False, help_text="Whether to show email on profile") + show_real_name = serializers.BooleanField(default=True, help_text="Whether to show real name on profile") + show_statistics = serializers.BooleanField(default=True, help_text="Whether to show ride statistics on profile") + allow_friend_requests = serializers.BooleanField(default=True, help_text="Whether to allow friend requests") + allow_messages = serializers.BooleanField(default=True, help_text="Whether to allow direct messages") # === NOTIFICATION SETTINGS SERIALIZERS === @@ -363,39 +348,17 @@ class PrivacySettingsSerializer(serializers.Serializer): default="public", help_text="Overall profile visibility", ) - show_email = serializers.BooleanField( - default=False, help_text="Show email address on profile" - ) - show_real_name = serializers.BooleanField( - default=True, help_text="Show real name on profile" - ) - show_join_date = serializers.BooleanField( - default=True, help_text="Show join date on profile" - ) - show_statistics = serializers.BooleanField( - default=True, help_text="Show ride statistics on profile" - ) - show_reviews = serializers.BooleanField( - default=True, help_text="Show reviews on profile" - ) - show_photos = serializers.BooleanField( - default=True, help_text="Show uploaded photos on profile" - ) - show_top_lists = serializers.BooleanField( - default=True, help_text="Show top lists on profile" - ) - allow_friend_requests = serializers.BooleanField( - default=True, help_text="Allow others to send friend requests" - ) - allow_messages = serializers.BooleanField( - default=True, help_text="Allow others to send direct messages" - ) - allow_profile_comments = serializers.BooleanField( - default=False, help_text="Allow others to comment on profile" - ) - search_visibility = serializers.BooleanField( - default=True, help_text="Allow profile to appear in search results" - ) + show_email = serializers.BooleanField(default=False, help_text="Show email address on profile") + show_real_name = serializers.BooleanField(default=True, help_text="Show real name on profile") + show_join_date = serializers.BooleanField(default=True, help_text="Show join date on profile") + show_statistics = serializers.BooleanField(default=True, help_text="Show ride statistics on profile") + show_reviews = serializers.BooleanField(default=True, help_text="Show reviews on profile") + show_photos = serializers.BooleanField(default=True, help_text="Show uploaded photos on profile") + show_top_lists = serializers.BooleanField(default=True, help_text="Show top lists on profile") + allow_friend_requests = serializers.BooleanField(default=True, help_text="Allow others to send friend requests") + allow_messages = serializers.BooleanField(default=True, help_text="Allow others to send direct messages") + allow_profile_comments = serializers.BooleanField(default=False, help_text="Allow others to comment on profile") + search_visibility = serializers.BooleanField(default=True, help_text="Allow profile to appear in search results") activity_visibility = RichChoiceFieldSerializer( choice_group="privacy_levels", domain="accounts", @@ -431,21 +394,13 @@ class SecuritySettingsSerializer(serializers.Serializer): two_factor_enabled = serializers.BooleanField( default=False, help_text="Whether two-factor authentication is enabled" ) - login_notifications = serializers.BooleanField( - default=True, help_text="Send notifications for new logins" - ) + login_notifications = serializers.BooleanField(default=True, help_text="Send notifications for new logins") session_timeout = serializers.IntegerField( default=30, min_value=5, max_value=180, help_text="Session timeout in days" ) - require_password_change = serializers.BooleanField( - default=False, help_text="Whether password change is required" - ) - last_password_change = serializers.DateTimeField( - read_only=True, help_text="When password was last changed" - ) - active_sessions = serializers.IntegerField( - read_only=True, help_text="Number of active sessions" - ) + require_password_change = serializers.BooleanField(default=False, help_text="Whether password change is required") + last_password_change = serializers.DateTimeField(read_only=True, help_text="When password was last changed") + active_sessions = serializers.IntegerField(read_only=True, help_text="Number of active sessions") login_history_retention = serializers.IntegerField( default=90, min_value=30, @@ -699,7 +654,7 @@ class ThemePreferenceSerializer(serializers.ModelSerializer): "id": 1, "notification_type": "submission_approved", "title": "Your submission has been approved!", - "message": "Your photo submission for Cedar Point has been approved and is now live on the site.", + "detail": "Your photo submission for Cedar Point has been approved and is now live on the site.", "priority": "normal", "is_read": False, "read_at": None, @@ -866,15 +821,11 @@ class MarkNotificationsReadSerializer(serializers.Serializer): def validate_notification_ids(self, value): """Validate that all notification IDs belong to the requesting user.""" user = self.context["request"].user - valid_ids = UserNotification.objects.filter( - id__in=value, user=user - ).values_list("id", flat=True) + valid_ids = UserNotification.objects.filter(id__in=value, user=user).values_list("id", flat=True) invalid_ids = set(value) - set(valid_ids) if invalid_ids: - raise serializers.ValidationError( - f"Invalid notification IDs: {list(invalid_ids)}" - ) + raise serializers.ValidationError(f"Invalid notification IDs: {list(invalid_ids)}") return value @@ -901,9 +852,8 @@ class AvatarUploadSerializer(serializers.Serializer): raise serializers.ValidationError("No file provided") # Check file size constraints (max 10MB for Cloudflare Images) - if hasattr(value, 'size') and value.size > 10 * 1024 * 1024: - raise serializers.ValidationError( - "Image file too large. Maximum size is 10MB.") + if hasattr(value, "size") and value.size > 10 * 1024 * 1024: + raise serializers.ValidationError("Image file too large. Maximum size is 10MB.") # Try to validate with PIL try: @@ -926,13 +876,13 @@ class AvatarUploadSerializer(serializers.Serializer): # Check image dimensions (max 12,000x12,000 for Cloudflare Images) if image.size[0] > 12000 or image.size[1] > 12000: - raise serializers.ValidationError( - "Image dimensions too large. Maximum is 12,000x12,000 pixels.") + raise serializers.ValidationError("Image dimensions too large. Maximum is 12,000x12,000 pixels.") # Check if it's a supported format - if image.format not in ['JPEG', 'PNG', 'GIF', 'WEBP']: + if image.format not in ["JPEG", "PNG", "GIF", "WEBP"]: raise serializers.ValidationError( - f"Unsupported image format: {image.format}. Supported formats: JPEG, PNG, GIF, WebP.") + f"Unsupported image format: {image.format}. Supported formats: JPEG, PNG, GIF, WebP." + ) except serializers.ValidationError: raise # Re-raise validation errors diff --git a/backend/apps/api/v1/serializers/auth.py b/backend/apps/api/v1/serializers/auth.py index 4bf1a650..fd61e2b8 100644 --- a/backend/apps/api/v1/serializers/auth.py +++ b/backend/apps/api/v1/serializers/auth.py @@ -97,7 +97,7 @@ class LoginInputSerializer(serializers.Serializer): password=password, ) - if not user: + if not user: # noqa: SIM102 # Try email-based authentication if username failed if "@" in username: try: @@ -138,7 +138,7 @@ class LoginInputSerializer(serializers.Serializer): "first_name": "John", "last_name": "Doe", }, - "message": "Login successful", + "detail": "Login successful", }, ) ] @@ -213,7 +213,7 @@ class SignupInputSerializer(serializers.ModelSerializer): try: validate_password(value) except DjangoValidationError as e: - raise serializers.ValidationError(list(e.messages)) + raise serializers.ValidationError(list(e.messages)) from None return value def validate(self, attrs): @@ -253,7 +253,7 @@ class SignupInputSerializer(serializers.ModelSerializer): "first_name": "Jane", "last_name": "Smith", }, - "message": "Registration successful", + "detail": "Registration successful", }, ) ] @@ -276,7 +276,7 @@ class SignupOutputSerializer(serializers.Serializer): summary="Example logout response", description="Successful logout response", value={ - "message": "Logout successful", + "detail": "Logout successful", }, ) ] @@ -318,9 +318,9 @@ class PasswordResetInputSerializer(serializers.Serializer): """Send password reset email.""" email = self.validated_data["email"] # type: ignore[index] try: - _user = UserModel.objects.get(email=email) + # Check if email exists (but don't reveal the result for security) + UserModel.objects.get(email=email) # Here you would typically send a password reset email - # For now, we'll just pass pass except UserModel.DoesNotExist: # Don't reveal if email exists for security @@ -393,7 +393,7 @@ class PasswordChangeInputSerializer(serializers.Serializer): try: validate_password(value, user=self.context["request"].user) except DjangoValidationError as e: - raise serializers.ValidationError(list(e.messages)) + raise serializers.ValidationError(list(e.messages)) from None return value def validate(self, attrs): @@ -492,6 +492,4 @@ class AuthStatusOutputSerializer(serializers.Serializer): """Output serializer for authentication status.""" authenticated = serializers.BooleanField(help_text="Whether user is authenticated") - user = UserOutputSerializer( - allow_null=True, help_text="User information if authenticated" - ) + user = UserOutputSerializer(allow_null=True, help_text="User information if authenticated") diff --git a/backend/apps/api/v1/serializers/companies.py b/backend/apps/api/v1/serializers/companies.py index 5e8257a2..df4935e1 100644 --- a/backend/apps/api/v1/serializers/companies.py +++ b/backend/apps/api/v1/serializers/companies.py @@ -112,10 +112,7 @@ class RideModelDetailOutputSerializer(serializers.Serializer): id = serializers.IntegerField() name = serializers.CharField() description = serializers.CharField() - category = RichChoiceFieldSerializer( - choice_group="categories", - domain="rides" - ) + category = RichChoiceFieldSerializer(choice_group="categories", domain="rides") # Manufacturer info manufacturer = serializers.SerializerMethodField() diff --git a/backend/apps/api/v1/serializers/history.py b/backend/apps/api/v1/serializers/history.py index 9f35dcb1..6f9c0ceb 100644 --- a/backend/apps/api/v1/serializers/history.py +++ b/backend/apps/api/v1/serializers/history.py @@ -99,9 +99,7 @@ class ParkHistoryOutputSerializer(serializers.Serializer): "slug": park.slug, "status": park.status, "opening_date": ( - park.opening_date.isoformat() - if hasattr(park, "opening_date") and park.opening_date - else None + park.opening_date.isoformat() if hasattr(park, "opening_date") and park.opening_date else None ), "coaster_count": getattr(park, "coaster_count", 0), "ride_count": getattr(park, "ride_count", 0), @@ -143,9 +141,7 @@ class RideHistoryOutputSerializer(serializers.Serializer): "park_name": ride.park.name if hasattr(ride, "park") else None, "status": getattr(ride, "status", "UNKNOWN"), "opening_date": ( - ride.opening_date.isoformat() - if hasattr(ride, "opening_date") and ride.opening_date - else None + ride.opening_date.isoformat() if hasattr(ride, "opening_date") and ride.opening_date else None ), "ride_type": getattr(ride, "ride_type", "Unknown"), } diff --git a/backend/apps/api/v1/serializers/maps.py b/backend/apps/api/v1/serializers/maps.py index 4938ffde..de5b31d3 100644 --- a/backend/apps/api/v1/serializers/maps.py +++ b/backend/apps/api/v1/serializers/maps.py @@ -79,16 +79,12 @@ class MapLocationSerializer(serializers.Serializer): return { "coaster_count": obj.coaster_count or 0, "ride_count": obj.ride_count or 0, - "average_rating": ( - float(obj.average_rating) if obj.average_rating else None - ), + "average_rating": (float(obj.average_rating) if obj.average_rating else None), } elif obj._meta.model_name == "ride": return { "category": obj.get_category_display() if obj.category else None, - "average_rating": ( - float(obj.average_rating) if obj.average_rating else None - ), + "average_rating": (float(obj.average_rating) if obj.average_rating else None), "park_name": obj.park.name if obj.park else None, } return {} @@ -339,24 +335,16 @@ class MapLocationDetailSerializer(serializers.Serializer): return { "coaster_count": obj.coaster_count or 0, "ride_count": obj.ride_count or 0, - "average_rating": ( - float(obj.average_rating) if obj.average_rating else None - ), + "average_rating": (float(obj.average_rating) if obj.average_rating else None), "size_acres": float(obj.size_acres) if obj.size_acres else None, - "opening_date": ( - obj.opening_date.isoformat() if obj.opening_date else None - ), + "opening_date": (obj.opening_date.isoformat() if obj.opening_date else None), } elif obj._meta.model_name == "ride": return { "category": obj.get_category_display() if obj.category else None, - "average_rating": ( - float(obj.average_rating) if obj.average_rating else None - ), + "average_rating": (float(obj.average_rating) if obj.average_rating else None), "park_name": obj.park.name if obj.park else None, - "opening_date": ( - obj.opening_date.isoformat() if obj.opening_date else None - ), + "opening_date": (obj.opening_date.isoformat() if obj.opening_date else None), "manufacturer": obj.manufacturer.name if obj.manufacturer else None, } return {} @@ -382,9 +370,7 @@ class MapBoundsInputSerializer(serializers.Serializer): def validate(self, attrs): """Validate that bounds make geographic sense.""" if attrs["north"] <= attrs["south"]: - raise serializers.ValidationError( - "North bound must be greater than south bound" - ) + raise serializers.ValidationError("North bound must be greater than south bound") # Handle longitude wraparound (e.g., crossing the international date line) # For now, we'll require west < east for simplicity diff --git a/backend/apps/api/v1/serializers/media.py b/backend/apps/api/v1/serializers/media.py index 7a8a654e..11e4a51d 100644 --- a/backend/apps/api/v1/serializers/media.py +++ b/backend/apps/api/v1/serializers/media.py @@ -31,9 +31,7 @@ class PhotoUploadInputSerializer(serializers.Serializer): allow_blank=True, help_text="Alt text for accessibility", ) - is_primary = serializers.BooleanField( - default=False, help_text="Whether this should be the primary photo" - ) + is_primary = serializers.BooleanField(default=False, help_text="Whether this should be the primary photo") @extend_schema_serializer( @@ -89,9 +87,7 @@ class PhotoDetailOutputSerializer(serializers.Serializer): return { "id": obj.uploaded_by.id, "username": obj.uploaded_by.username, - "display_name": getattr( - obj.uploaded_by, "get_display_name", lambda: obj.uploaded_by.username - )(), + "display_name": getattr(obj.uploaded_by, "get_display_name", lambda: obj.uploaded_by.username)(), } diff --git a/backend/apps/api/v1/serializers/other.py b/backend/apps/api/v1/serializers/other.py index f1ac4ec6..7bc8c0ab 100644 --- a/backend/apps/api/v1/serializers/other.py +++ b/backend/apps/api/v1/serializers/other.py @@ -24,12 +24,8 @@ class ParkStatsOutputSerializer(serializers.Serializer): under_construction = serializers.IntegerField() # Averages - average_rating = serializers.DecimalField( - max_digits=3, decimal_places=2, allow_null=True - ) - average_coaster_count = serializers.DecimalField( - max_digits=5, decimal_places=2, allow_null=True - ) + average_rating = serializers.DecimalField(max_digits=3, decimal_places=2, allow_null=True) + average_coaster_count = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True) # Top countries top_countries = serializers.ListField(child=serializers.DictField()) @@ -50,12 +46,8 @@ class RideStatsOutputSerializer(serializers.Serializer): rides_by_category = serializers.DictField() # Averages - average_rating = serializers.DecimalField( - max_digits=3, decimal_places=2, allow_null=True - ) - average_capacity = serializers.DecimalField( - max_digits=8, decimal_places=2, allow_null=True - ) + average_rating = serializers.DecimalField(max_digits=3, decimal_places=2, allow_null=True) + average_capacity = serializers.DecimalField(max_digits=8, decimal_places=2, allow_null=True) # Top manufacturers top_manufacturers = serializers.ListField(child=serializers.DictField()) @@ -91,10 +83,7 @@ class ParkReviewOutputSerializer(serializers.Serializer): class HealthCheckOutputSerializer(serializers.Serializer): """Output serializer for health check responses.""" - status = RichChoiceFieldSerializer( - choice_group="health_statuses", - domain="core" - ) + status = RichChoiceFieldSerializer(choice_group="health_statuses", domain="core") timestamp = serializers.DateTimeField() version = serializers.CharField() environment = serializers.CharField() @@ -115,9 +104,6 @@ class PerformanceMetricsOutputSerializer(serializers.Serializer): class SimpleHealthOutputSerializer(serializers.Serializer): """Output serializer for simple health check.""" - status = RichChoiceFieldSerializer( - choice_group="simple_health_statuses", - domain="core" - ) + status = RichChoiceFieldSerializer(choice_group="simple_health_statuses", domain="core") timestamp = serializers.DateTimeField() error = serializers.CharField(required=False) diff --git a/backend/apps/api/v1/serializers/park_reviews.py b/backend/apps/api/v1/serializers/park_reviews.py index f1265213..610d77c3 100644 --- a/backend/apps/api/v1/serializers/park_reviews.py +++ b/backend/apps/api/v1/serializers/park_reviews.py @@ -29,14 +29,10 @@ from apps.parks.models.reviews import ParkReview "user": { "username": "park_fan", "display_name": "Park Fan", - "avatar_url": "https://example.com/avatar.jpg" + "avatar_url": "https://example.com/avatar.jpg", }, - "park": { - "id": 101, - "name": "Cedar Point", - "slug": "cedar-point" - } - } + "park": {"id": 101, "name": "Cedar Point", "slug": "cedar-point"}, + }, ) ] ) @@ -145,8 +141,7 @@ class ParkReviewStatsOutputSerializer(serializers.Serializer): pending_reviews = serializers.IntegerField() average_rating = serializers.FloatField(allow_null=True) rating_distribution = serializers.DictField( - child=serializers.IntegerField(), - help_text="Count of reviews by rating (1-10)" + child=serializers.IntegerField(), help_text="Count of reviews by rating (1-10)" ) recent_reviews = serializers.IntegerField() @@ -154,20 +149,15 @@ class ParkReviewStatsOutputSerializer(serializers.Serializer): class ParkReviewModerationInputSerializer(serializers.Serializer): """Input serializer for review moderation operations.""" - review_ids = serializers.ListField( - child=serializers.IntegerField(), - help_text="List of review IDs to moderate" - ) + review_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of review IDs to moderate") action = serializers.ChoiceField( choices=[ ("publish", "Publish"), ("unpublish", "Unpublish"), ("delete", "Delete"), ], - help_text="Moderation action to perform" + help_text="Moderation action to perform", ) moderation_notes = serializers.CharField( - required=False, - allow_blank=True, - help_text="Optional notes about the moderation action" + required=False, allow_blank=True, help_text="Optional notes about the moderation action" ) diff --git a/backend/apps/api/v1/serializers/parks.py b/backend/apps/api/v1/serializers/parks.py index d72552ca..6dbce6de 100644 --- a/backend/apps/api/v1/serializers/parks.py +++ b/backend/apps/api/v1/serializers/parks.py @@ -52,16 +52,11 @@ class ParkListOutputSerializer(serializers.Serializer): id = serializers.IntegerField() name = serializers.CharField() slug = serializers.CharField() - status = RichChoiceFieldSerializer( - choice_group="statuses", - domain="parks" - ) + status = RichChoiceFieldSerializer(choice_group="statuses", domain="parks") description = serializers.CharField() # Statistics - average_rating = serializers.DecimalField( - max_digits=3, decimal_places=2, allow_null=True - ) + average_rating = serializers.DecimalField(max_digits=3, decimal_places=2, allow_null=True) coaster_count = serializers.IntegerField(allow_null=True) ride_count = serializers.IntegerField(allow_null=True) @@ -145,25 +140,18 @@ class ParkDetailOutputSerializer(serializers.Serializer): id = serializers.IntegerField() name = serializers.CharField() slug = serializers.CharField() - status = RichChoiceFieldSerializer( - choice_group="statuses", - domain="parks" - ) + status = RichChoiceFieldSerializer(choice_group="statuses", domain="parks") description = serializers.CharField() # Details opening_date = serializers.DateField(allow_null=True) closing_date = serializers.DateField(allow_null=True) operating_season = serializers.CharField() - size_acres = serializers.DecimalField( - max_digits=10, decimal_places=2, allow_null=True - ) + size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, allow_null=True) website = serializers.URLField() # Statistics - average_rating = serializers.DecimalField( - max_digits=3, decimal_places=2, allow_null=True - ) + average_rating = serializers.DecimalField(max_digits=3, decimal_places=2, allow_null=True) coaster_count = serializers.IntegerField(allow_null=True) ride_count = serializers.IntegerField(allow_null=True) @@ -211,9 +199,7 @@ class ParkDetailOutputSerializer(serializers.Serializer): """Get all approved photos for this park.""" from apps.parks.models import ParkPhoto - photos = ParkPhoto.objects.filter(park=obj, is_approved=True).order_by( - "-is_primary", "-created_at" - )[ + photos = ParkPhoto.objects.filter(park=obj, is_approved=True).order_by("-is_primary", "-created_at")[ :10 ] # Limit to 10 photos @@ -228,7 +214,9 @@ class ParkDetailOutputSerializer(serializers.Serializer): "public": MediaURLService.get_cloudflare_url_with_fallback(photo.image, "public"), }, "friendly_urls": { - "thumbnail": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "thumbnail"), + "thumbnail": MediaURLService.generate_park_photo_url( + obj.slug, photo.caption, photo.pk, "thumbnail" + ), "medium": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "medium"), "large": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "large"), "public": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "public"), @@ -246,9 +234,7 @@ class ParkDetailOutputSerializer(serializers.Serializer): from apps.parks.models import ParkPhoto try: - photo = ParkPhoto.objects.filter( - park=obj, is_primary=True, is_approved=True - ).first() + photo = ParkPhoto.objects.filter(park=obj, is_primary=True, is_approved=True).first() if photo and photo.image: return { @@ -261,7 +247,9 @@ class ParkDetailOutputSerializer(serializers.Serializer): "public": MediaURLService.get_cloudflare_url_with_fallback(photo.image, "public"), }, "friendly_urls": { - "thumbnail": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "thumbnail"), + "thumbnail": MediaURLService.generate_park_photo_url( + obj.slug, photo.caption, photo.pk, "thumbnail" + ), "medium": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "medium"), "large": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "large"), "public": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "public"), @@ -289,10 +277,18 @@ class ParkDetailOutputSerializer(serializers.Serializer): "public": MediaURLService.get_cloudflare_url_with_fallback(obj.banner_image.image, "public"), }, "friendly_urls": { - "thumbnail": MediaURLService.generate_park_photo_url(obj.slug, obj.banner_image.caption, obj.banner_image.pk, "thumbnail"), - "medium": MediaURLService.generate_park_photo_url(obj.slug, obj.banner_image.caption, obj.banner_image.pk, "medium"), - "large": MediaURLService.generate_park_photo_url(obj.slug, obj.banner_image.caption, obj.banner_image.pk, "large"), - "public": MediaURLService.generate_park_photo_url(obj.slug, obj.banner_image.caption, obj.banner_image.pk, "public"), + "thumbnail": MediaURLService.generate_park_photo_url( + obj.slug, obj.banner_image.caption, obj.banner_image.pk, "thumbnail" + ), + "medium": MediaURLService.generate_park_photo_url( + obj.slug, obj.banner_image.caption, obj.banner_image.pk, "medium" + ), + "large": MediaURLService.generate_park_photo_url( + obj.slug, obj.banner_image.caption, obj.banner_image.pk, "large" + ), + "public": MediaURLService.generate_park_photo_url( + obj.slug, obj.banner_image.caption, obj.banner_image.pk, "public" + ), }, "caption": obj.banner_image.caption, "alt_text": obj.banner_image.alt_text, @@ -303,9 +299,7 @@ class ParkDetailOutputSerializer(serializers.Serializer): try: latest_photo = ( - ParkPhoto.objects.filter( - park=obj, is_approved=True, image__isnull=False - ) + ParkPhoto.objects.filter(park=obj, is_approved=True, image__isnull=False) .order_by("-created_at") .first() ) @@ -321,10 +315,18 @@ class ParkDetailOutputSerializer(serializers.Serializer): "public": MediaURLService.get_cloudflare_url_with_fallback(latest_photo.image, "public"), }, "friendly_urls": { - "thumbnail": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "thumbnail"), - "medium": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "medium"), - "large": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "large"), - "public": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "public"), + "thumbnail": MediaURLService.generate_park_photo_url( + obj.slug, latest_photo.caption, latest_photo.pk, "thumbnail" + ), + "medium": MediaURLService.generate_park_photo_url( + obj.slug, latest_photo.caption, latest_photo.pk, "medium" + ), + "large": MediaURLService.generate_park_photo_url( + obj.slug, latest_photo.caption, latest_photo.pk, "large" + ), + "public": MediaURLService.generate_park_photo_url( + obj.slug, latest_photo.caption, latest_photo.pk, "public" + ), }, "caption": latest_photo.caption, "alt_text": latest_photo.alt_text, @@ -350,10 +352,18 @@ class ParkDetailOutputSerializer(serializers.Serializer): "public": MediaURLService.get_cloudflare_url_with_fallback(obj.card_image.image, "public"), }, "friendly_urls": { - "thumbnail": MediaURLService.generate_park_photo_url(obj.slug, obj.card_image.caption, obj.card_image.pk, "thumbnail"), - "medium": MediaURLService.generate_park_photo_url(obj.slug, obj.card_image.caption, obj.card_image.pk, "medium"), - "large": MediaURLService.generate_park_photo_url(obj.slug, obj.card_image.caption, obj.card_image.pk, "large"), - "public": MediaURLService.generate_park_photo_url(obj.slug, obj.card_image.caption, obj.card_image.pk, "public"), + "thumbnail": MediaURLService.generate_park_photo_url( + obj.slug, obj.card_image.caption, obj.card_image.pk, "thumbnail" + ), + "medium": MediaURLService.generate_park_photo_url( + obj.slug, obj.card_image.caption, obj.card_image.pk, "medium" + ), + "large": MediaURLService.generate_park_photo_url( + obj.slug, obj.card_image.caption, obj.card_image.pk, "large" + ), + "public": MediaURLService.generate_park_photo_url( + obj.slug, obj.card_image.caption, obj.card_image.pk, "public" + ), }, "caption": obj.card_image.caption, "alt_text": obj.card_image.alt_text, @@ -364,9 +374,7 @@ class ParkDetailOutputSerializer(serializers.Serializer): try: latest_photo = ( - ParkPhoto.objects.filter( - park=obj, is_approved=True, image__isnull=False - ) + ParkPhoto.objects.filter(park=obj, is_approved=True, image__isnull=False) .order_by("-created_at") .first() ) @@ -382,10 +390,18 @@ class ParkDetailOutputSerializer(serializers.Serializer): "public": MediaURLService.get_cloudflare_url_with_fallback(latest_photo.image, "public"), }, "friendly_urls": { - "thumbnail": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "thumbnail"), - "medium": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "medium"), - "large": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "large"), - "public": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "public"), + "thumbnail": MediaURLService.generate_park_photo_url( + obj.slug, latest_photo.caption, latest_photo.pk, "thumbnail" + ), + "medium": MediaURLService.generate_park_photo_url( + obj.slug, latest_photo.caption, latest_photo.pk, "medium" + ), + "large": MediaURLService.generate_park_photo_url( + obj.slug, latest_photo.caption, latest_photo.pk, "large" + ), + "public": MediaURLService.generate_park_photo_url( + obj.slug, latest_photo.caption, latest_photo.pk, "public" + ), }, "caption": latest_photo.caption, "alt_text": latest_photo.alt_text, @@ -417,7 +433,7 @@ class ParkImageSettingsInputSerializer(serializers.Serializer): # The park will be validated in the view return value except ParkPhoto.DoesNotExist: - raise serializers.ValidationError("Photo not found") + raise serializers.ValidationError("Photo not found") from None return value def validate_card_image_id(self, value): @@ -430,7 +446,7 @@ class ParkImageSettingsInputSerializer(serializers.Serializer): # The park will be validated in the view return value except ParkPhoto.DoesNotExist: - raise serializers.ValidationError("Photo not found") + raise serializers.ValidationError("Photo not found") from None return value @@ -439,19 +455,13 @@ class ParkCreateInputSerializer(serializers.Serializer): name = serializers.CharField(max_length=255) description = serializers.CharField(allow_blank=True, default="") - status = serializers.ChoiceField( - choices=ModelChoices.get_park_status_choices(), default="OPERATING" - ) + status = serializers.ChoiceField(choices=ModelChoices.get_park_status_choices(), default="OPERATING") # Optional details opening_date = serializers.DateField(required=False, allow_null=True) closing_date = serializers.DateField(required=False, allow_null=True) - operating_season = serializers.CharField( - max_length=255, required=False, allow_blank=True - ) - size_acres = serializers.DecimalField( - max_digits=10, decimal_places=2, required=False, allow_null=True - ) + operating_season = serializers.CharField(max_length=255, required=False, allow_blank=True) + size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, required=False, allow_null=True) website = serializers.URLField(required=False, allow_blank=True) # Required operator @@ -466,9 +476,7 @@ class ParkCreateInputSerializer(serializers.Serializer): closing_date = attrs.get("closing_date") if opening_date and closing_date and closing_date < opening_date: - raise serializers.ValidationError( - "Closing date cannot be before opening date" - ) + raise serializers.ValidationError("Closing date cannot be before opening date") return attrs @@ -478,19 +486,13 @@ class ParkUpdateInputSerializer(serializers.Serializer): name = serializers.CharField(max_length=255, required=False) description = serializers.CharField(allow_blank=True, required=False) - status = serializers.ChoiceField( - choices=ModelChoices.get_park_status_choices(), required=False - ) + status = serializers.ChoiceField(choices=ModelChoices.get_park_status_choices(), required=False) # Optional details opening_date = serializers.DateField(required=False, allow_null=True) closing_date = serializers.DateField(required=False, allow_null=True) - operating_season = serializers.CharField( - max_length=255, required=False, allow_blank=True - ) - size_acres = serializers.DecimalField( - max_digits=10, decimal_places=2, required=False, allow_null=True - ) + operating_season = serializers.CharField(max_length=255, required=False, allow_blank=True) + size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, required=False, allow_null=True) website = serializers.URLField(required=False, allow_blank=True) # Companies @@ -503,9 +505,7 @@ class ParkUpdateInputSerializer(serializers.Serializer): closing_date = attrs.get("closing_date") if opening_date and closing_date and closing_date < opening_date: - raise serializers.ValidationError( - "Closing date cannot be before opening date" - ) + raise serializers.ValidationError("Closing date cannot be before opening date") return attrs @@ -537,12 +537,8 @@ class ParkFilterInputSerializer(serializers.Serializer): ) # Size filter - min_size_acres = serializers.DecimalField( - max_digits=10, decimal_places=2, required=False, min_value=0 - ) - max_size_acres = serializers.DecimalField( - max_digits=10, decimal_places=2, required=False, min_value=0 - ) + min_size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, required=False, min_value=0) + max_size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, required=False, min_value=0) # Company filters operator_id = serializers.IntegerField(required=False) @@ -625,9 +621,7 @@ class ParkAreaCreateInputSerializer(serializers.Serializer): closing_date = attrs.get("closing_date") if opening_date and closing_date and closing_date < opening_date: - raise serializers.ValidationError( - "Closing date cannot be before opening date" - ) + raise serializers.ValidationError("Closing date cannot be before opening date") return attrs @@ -646,9 +640,7 @@ class ParkAreaUpdateInputSerializer(serializers.Serializer): closing_date = attrs.get("closing_date") if opening_date and closing_date and closing_date < opening_date: - raise serializers.ValidationError( - "Closing date cannot be before opening date" - ) + raise serializers.ValidationError("Closing date cannot be before opening date") return attrs diff --git a/backend/apps/api/v1/serializers/parks_media.py b/backend/apps/api/v1/serializers/parks_media.py index 2b86f988..611facd4 100644 --- a/backend/apps/api/v1/serializers/parks_media.py +++ b/backend/apps/api/v1/serializers/parks_media.py @@ -12,9 +12,7 @@ from apps.parks.models import ParkPhoto class ParkPhotoOutputSerializer(serializers.ModelSerializer): """Output serializer for park photos.""" - uploaded_by_username = serializers.CharField( - source="uploaded_by.username", read_only=True - ) + uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True) file_size = serializers.ReadOnlyField() dimensions = serializers.ReadOnlyField() park_slug = serializers.CharField(source="park.slug", read_only=True) @@ -78,9 +76,7 @@ class ParkPhotoUpdateInputSerializer(serializers.ModelSerializer): class ParkPhotoListOutputSerializer(serializers.ModelSerializer): """Simplified output serializer for park photo lists.""" - uploaded_by_username = serializers.CharField( - source="uploaded_by.username", read_only=True - ) + uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True) class Meta: model = ParkPhoto @@ -99,12 +95,8 @@ class ParkPhotoListOutputSerializer(serializers.ModelSerializer): class ParkPhotoApprovalInputSerializer(serializers.Serializer): """Input serializer for photo approval operations.""" - photo_ids = serializers.ListField( - child=serializers.IntegerField(), help_text="List of photo IDs to approve" - ) - approve = serializers.BooleanField( - default=True, help_text="Whether to approve (True) or reject (False) the photos" - ) + photo_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of photo IDs to approve") + approve = serializers.BooleanField(default=True, help_text="Whether to approve (True) or reject (False) the photos") class ParkPhotoStatsOutputSerializer(serializers.Serializer): diff --git a/backend/apps/api/v1/serializers/ride_credits.py b/backend/apps/api/v1/serializers/ride_credits.py index b1d282ef..a28dbbf7 100644 --- a/backend/apps/api/v1/serializers/ride_credits.py +++ b/backend/apps/api/v1/serializers/ride_credits.py @@ -8,35 +8,33 @@ from apps.rides.models.credits import RideCredit class RideCreditSerializer(serializers.ModelSerializer): """Serializer for user ride credits.""" - ride_id = serializers.PrimaryKeyRelatedField( - queryset=Ride.objects.all(), source='ride', write_only=True - ) + ride_id = serializers.PrimaryKeyRelatedField(queryset=Ride.objects.all(), source="ride", write_only=True) ride = RideListOutputSerializer(read_only=True) class Meta: model = RideCredit fields = [ - 'id', - 'ride', - 'ride_id', - 'count', - 'rating', - 'first_ridden_at', - 'last_ridden_at', - 'notes', - 'display_order', - 'created_at', - 'updated_at', + "id", + "ride", + "ride_id", + "count", + "rating", + "first_ridden_at", + "last_ridden_at", + "notes", + "display_order", + "created_at", + "updated_at", ] - read_only_fields = ['id', 'created_at', 'updated_at'] + read_only_fields = ["id", "created_at", "updated_at"] def validate(self, attrs): """ Validate data. """ # Ensure dates make sense - first = attrs.get('first_ridden_at') - last = attrs.get('last_ridden_at') + first = attrs.get("first_ridden_at") + last = attrs.get("last_ridden_at") if first and last and last < first: raise serializers.ValidationError("Last ridden date cannot be before first ridden date.") @@ -44,6 +42,6 @@ class RideCreditSerializer(serializers.ModelSerializer): def create(self, validated_data): """Create a new ride credit.""" - user = self.context['request'].user - validated_data['user'] = user + user = self.context["request"].user + validated_data["user"] = user return super().create(validated_data) diff --git a/backend/apps/api/v1/serializers/ride_models.py b/backend/apps/api/v1/serializers/ride_models.py index feeede88..3c1c267e 100644 --- a/backend/apps/api/v1/serializers/ride_models.py +++ b/backend/apps/api/v1/serializers/ride_models.py @@ -80,18 +80,10 @@ class RideModelVariantOutputSerializer(serializers.Serializer): id = serializers.IntegerField() name = serializers.CharField() description = serializers.CharField() - min_height_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, allow_null=True - ) - max_height_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, allow_null=True - ) - min_speed_mph = serializers.DecimalField( - max_digits=5, decimal_places=2, allow_null=True - ) - max_speed_mph = serializers.DecimalField( - max_digits=5, decimal_places=2, allow_null=True - ) + min_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, allow_null=True) + max_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, allow_null=True) + min_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True) + max_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True) distinguishing_features = serializers.CharField() @@ -134,20 +126,14 @@ class RideModelListOutputSerializer(serializers.Serializer): id = serializers.IntegerField() name = serializers.CharField() slug = serializers.CharField() - category = RichChoiceFieldSerializer( - choice_group="categories", - domain="rides" - ) + category = RichChoiceFieldSerializer(choice_group="categories", domain="rides") description = serializers.CharField() # Manufacturer info manufacturer = RideModelManufacturerOutputSerializer(allow_null=True) # Market info - target_market = RichChoiceFieldSerializer( - choice_group="target_markets", - domain="rides" - ) + target_market = RichChoiceFieldSerializer(choice_group="target_markets", domain="rides") is_discontinued = serializers.BooleanField() total_installations = serializers.IntegerField() first_installation_year = serializers.IntegerField(allow_null=True) @@ -258,18 +244,10 @@ class RideModelDetailOutputSerializer(serializers.Serializer): manufacturer = RideModelManufacturerOutputSerializer(allow_null=True) # Technical specifications - typical_height_range_min_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, allow_null=True - ) - typical_height_range_max_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, allow_null=True - ) - typical_speed_range_min_mph = serializers.DecimalField( - max_digits=5, decimal_places=2, allow_null=True - ) - typical_speed_range_max_mph = serializers.DecimalField( - max_digits=5, decimal_places=2, allow_null=True - ) + typical_height_range_min_ft = serializers.DecimalField(max_digits=6, decimal_places=2, allow_null=True) + typical_height_range_max_ft = serializers.DecimalField(max_digits=6, decimal_places=2, allow_null=True) + typical_speed_range_min_mph = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True) + typical_speed_range_max_mph = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True) typical_capacity_range_min = serializers.IntegerField(allow_null=True) typical_capacity_range_max = serializers.IntegerField(allow_null=True) @@ -343,9 +321,7 @@ class RideModelCreateInputSerializer(serializers.Serializer): name = serializers.CharField(max_length=255) description = serializers.CharField(allow_blank=True, default="") - category = serializers.ChoiceField( - choices=ModelChoices.get_ride_category_choices(), allow_blank=True, default="" - ) + category = serializers.ChoiceField(choices=ModelChoices.get_ride_category_choices(), allow_blank=True, default="") # Required manufacturer manufacturer_id = serializers.IntegerField() @@ -363,32 +339,18 @@ class RideModelCreateInputSerializer(serializers.Serializer): typical_speed_range_max_mph = serializers.DecimalField( max_digits=5, decimal_places=2, required=False, allow_null=True ) - typical_capacity_range_min = serializers.IntegerField( - required=False, allow_null=True, min_value=1 - ) - typical_capacity_range_max = serializers.IntegerField( - required=False, allow_null=True, min_value=1 - ) + typical_capacity_range_min = serializers.IntegerField(required=False, allow_null=True, min_value=1) + typical_capacity_range_max = serializers.IntegerField(required=False, allow_null=True, min_value=1) # Design characteristics track_type = serializers.CharField(max_length=100, allow_blank=True, default="") - support_structure = serializers.CharField( - max_length=100, allow_blank=True, default="" - ) - train_configuration = serializers.CharField( - max_length=200, allow_blank=True, default="" - ) - restraint_system = serializers.CharField( - max_length=100, allow_blank=True, default="" - ) + support_structure = serializers.CharField(max_length=100, allow_blank=True, default="") + train_configuration = serializers.CharField(max_length=200, allow_blank=True, default="") + restraint_system = serializers.CharField(max_length=100, allow_blank=True, default="") # Market information - first_installation_year = serializers.IntegerField( - required=False, allow_null=True, min_value=1800, max_value=2100 - ) - last_installation_year = serializers.IntegerField( - required=False, allow_null=True, min_value=1800, max_value=2100 - ) + first_installation_year = serializers.IntegerField(required=False, allow_null=True, min_value=1800, max_value=2100) + last_installation_year = serializers.IntegerField(required=False, allow_null=True, min_value=1800, max_value=2100) is_discontinued = serializers.BooleanField(default=False) # Design features @@ -406,36 +368,28 @@ class RideModelCreateInputSerializer(serializers.Serializer): max_height = attrs.get("typical_height_range_max_ft") if min_height and max_height and min_height > max_height: - raise serializers.ValidationError( - "Minimum height cannot be greater than maximum height" - ) + raise serializers.ValidationError("Minimum height cannot be greater than maximum height") # Speed range validation min_speed = attrs.get("typical_speed_range_min_mph") max_speed = attrs.get("typical_speed_range_max_mph") if min_speed and max_speed and min_speed > max_speed: - raise serializers.ValidationError( - "Minimum speed cannot be greater than maximum speed" - ) + raise serializers.ValidationError("Minimum speed cannot be greater than maximum speed") # Capacity range validation min_capacity = attrs.get("typical_capacity_range_min") max_capacity = attrs.get("typical_capacity_range_max") if min_capacity and max_capacity and min_capacity > max_capacity: - raise serializers.ValidationError( - "Minimum capacity cannot be greater than maximum capacity" - ) + raise serializers.ValidationError("Minimum capacity cannot be greater than maximum capacity") # Installation years validation first_year = attrs.get("first_installation_year") last_year = attrs.get("last_installation_year") if first_year and last_year and first_year > last_year: - raise serializers.ValidationError( - "First installation year cannot be after last installation year" - ) + raise serializers.ValidationError("First installation year cannot be after last installation year") return attrs @@ -467,32 +421,18 @@ class RideModelUpdateInputSerializer(serializers.Serializer): typical_speed_range_max_mph = serializers.DecimalField( max_digits=5, decimal_places=2, required=False, allow_null=True ) - typical_capacity_range_min = serializers.IntegerField( - required=False, allow_null=True, min_value=1 - ) - typical_capacity_range_max = serializers.IntegerField( - required=False, allow_null=True, min_value=1 - ) + typical_capacity_range_min = serializers.IntegerField(required=False, allow_null=True, min_value=1) + typical_capacity_range_max = serializers.IntegerField(required=False, allow_null=True, min_value=1) # Design characteristics track_type = serializers.CharField(max_length=100, allow_blank=True, required=False) - support_structure = serializers.CharField( - max_length=100, allow_blank=True, required=False - ) - train_configuration = serializers.CharField( - max_length=200, allow_blank=True, required=False - ) - restraint_system = serializers.CharField( - max_length=100, allow_blank=True, required=False - ) + support_structure = serializers.CharField(max_length=100, allow_blank=True, required=False) + train_configuration = serializers.CharField(max_length=200, allow_blank=True, required=False) + restraint_system = serializers.CharField(max_length=100, allow_blank=True, required=False) # Market information - first_installation_year = serializers.IntegerField( - required=False, allow_null=True, min_value=1800, max_value=2100 - ) - last_installation_year = serializers.IntegerField( - required=False, allow_null=True, min_value=1800, max_value=2100 - ) + first_installation_year = serializers.IntegerField(required=False, allow_null=True, min_value=1800, max_value=2100) + last_installation_year = serializers.IntegerField(required=False, allow_null=True, min_value=1800, max_value=2100) is_discontinued = serializers.BooleanField(required=False) # Design features @@ -510,36 +450,28 @@ class RideModelUpdateInputSerializer(serializers.Serializer): max_height = attrs.get("typical_height_range_max_ft") if min_height and max_height and min_height > max_height: - raise serializers.ValidationError( - "Minimum height cannot be greater than maximum height" - ) + raise serializers.ValidationError("Minimum height cannot be greater than maximum height") # Speed range validation min_speed = attrs.get("typical_speed_range_min_mph") max_speed = attrs.get("typical_speed_range_max_mph") if min_speed and max_speed and min_speed > max_speed: - raise serializers.ValidationError( - "Minimum speed cannot be greater than maximum speed" - ) + raise serializers.ValidationError("Minimum speed cannot be greater than maximum speed") # Capacity range validation min_capacity = attrs.get("typical_capacity_range_min") max_capacity = attrs.get("typical_capacity_range_max") if min_capacity and max_capacity and min_capacity > max_capacity: - raise serializers.ValidationError( - "Minimum capacity cannot be greater than maximum capacity" - ) + raise serializers.ValidationError("Minimum capacity cannot be greater than maximum capacity") # Installation years validation first_year = attrs.get("first_installation_year") last_year = attrs.get("last_installation_year") if first_year and last_year and first_year > last_year: - raise serializers.ValidationError( - "First installation year cannot be after last installation year" - ) + raise serializers.ValidationError("First installation year cannot be after last installation year") return attrs @@ -551,9 +483,7 @@ class RideModelFilterInputSerializer(serializers.Serializer): search = serializers.CharField(required=False, allow_blank=True) # Category filter - category = serializers.MultipleChoiceField( - choices=ModelChoices.get_ride_category_choices(), required=False - ) + category = serializers.MultipleChoiceField(choices=ModelChoices.get_ride_category_choices(), required=False) # Manufacturer filter manufacturer_id = serializers.IntegerField(required=False) @@ -576,20 +506,12 @@ class RideModelFilterInputSerializer(serializers.Serializer): min_installations = serializers.IntegerField(required=False, min_value=0) # Height filters - min_height_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, required=False - ) - max_height_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, required=False - ) + min_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False) + max_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False) # Speed filters - min_speed_mph = serializers.DecimalField( - max_digits=5, decimal_places=2, required=False - ) - max_speed_mph = serializers.DecimalField( - max_digits=5, decimal_places=2, required=False - ) + min_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False) + max_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False) # Ordering ordering = serializers.ChoiceField( @@ -621,18 +543,10 @@ class RideModelVariantCreateInputSerializer(serializers.Serializer): description = serializers.CharField(allow_blank=True, default="") # Variant-specific specifications - min_height_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, required=False, allow_null=True - ) - max_height_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, required=False, allow_null=True - ) - min_speed_mph = serializers.DecimalField( - max_digits=5, decimal_places=2, required=False, allow_null=True - ) - max_speed_mph = serializers.DecimalField( - max_digits=5, decimal_places=2, required=False, allow_null=True - ) + min_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True) + max_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True) + min_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False, allow_null=True) + max_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False, allow_null=True) # Distinguishing features distinguishing_features = serializers.CharField(allow_blank=True, default="") @@ -644,18 +558,14 @@ class RideModelVariantCreateInputSerializer(serializers.Serializer): max_height = attrs.get("max_height_ft") if min_height and max_height and min_height > max_height: - raise serializers.ValidationError( - "Minimum height cannot be greater than maximum height" - ) + raise serializers.ValidationError("Minimum height cannot be greater than maximum height") # Speed range validation min_speed = attrs.get("min_speed_mph") max_speed = attrs.get("max_speed_mph") if min_speed and max_speed and min_speed > max_speed: - raise serializers.ValidationError( - "Minimum speed cannot be greater than maximum speed" - ) + raise serializers.ValidationError("Minimum speed cannot be greater than maximum speed") return attrs @@ -667,18 +577,10 @@ class RideModelVariantUpdateInputSerializer(serializers.Serializer): description = serializers.CharField(allow_blank=True, required=False) # Variant-specific specifications - min_height_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, required=False, allow_null=True - ) - max_height_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, required=False, allow_null=True - ) - min_speed_mph = serializers.DecimalField( - max_digits=5, decimal_places=2, required=False, allow_null=True - ) - max_speed_mph = serializers.DecimalField( - max_digits=5, decimal_places=2, required=False, allow_null=True - ) + min_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True) + max_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True) + min_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False, allow_null=True) + max_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False, allow_null=True) # Distinguishing features distinguishing_features = serializers.CharField(allow_blank=True, required=False) @@ -690,18 +592,14 @@ class RideModelVariantUpdateInputSerializer(serializers.Serializer): max_height = attrs.get("max_height_ft") if min_height and max_height and min_height > max_height: - raise serializers.ValidationError( - "Minimum height cannot be greater than maximum height" - ) + raise serializers.ValidationError("Minimum height cannot be greater than maximum height") # Speed range validation min_speed = attrs.get("min_speed_mph") max_speed = attrs.get("max_speed_mph") if min_speed and max_speed and min_speed > max_speed: - raise serializers.ValidationError( - "Minimum speed cannot be greater than maximum speed" - ) + raise serializers.ValidationError("Minimum speed cannot be greater than maximum speed") return attrs @@ -713,9 +611,7 @@ class RideModelTechnicalSpecCreateInputSerializer(serializers.Serializer): """Input serializer for creating ride model technical specifications.""" ride_model_id = serializers.IntegerField() - spec_category = serializers.ChoiceField( - choices=ModelChoices.get_technical_spec_category_choices() - ) + spec_category = serializers.ChoiceField(choices=ModelChoices.get_technical_spec_category_choices()) spec_name = serializers.CharField(max_length=100) spec_value = serializers.CharField(max_length=255) spec_unit = serializers.CharField(max_length=20, allow_blank=True, default="") @@ -765,13 +661,9 @@ class RideModelPhotoUpdateInputSerializer(serializers.Serializer): required=False, ) is_primary = serializers.BooleanField(required=False) - photographer = serializers.CharField( - max_length=255, allow_blank=True, required=False - ) + photographer = serializers.CharField(max_length=255, allow_blank=True, required=False) source = serializers.CharField(max_length=255, allow_blank=True, required=False) - copyright_info = serializers.CharField( - max_length=255, allow_blank=True, required=False - ) + copyright_info = serializers.CharField(max_length=255, allow_blank=True, required=False) # === RIDE MODEL STATS SERIALIZERS === @@ -784,15 +676,9 @@ class RideModelStatsOutputSerializer(serializers.Serializer): total_installations = serializers.IntegerField() active_manufacturers = serializers.IntegerField() discontinued_models = serializers.IntegerField() - by_category = serializers.DictField( - child=serializers.IntegerField(), help_text="Model counts by category" - ) + by_category = serializers.DictField(child=serializers.IntegerField(), help_text="Model counts by category") by_target_market = serializers.DictField( child=serializers.IntegerField(), help_text="Model counts by target market" ) - by_manufacturer = serializers.DictField( - child=serializers.IntegerField(), help_text="Model counts by manufacturer" - ) - recent_models = serializers.IntegerField( - help_text="Models created in the last 30 days" - ) + by_manufacturer = serializers.DictField(child=serializers.IntegerField(), help_text="Model counts by manufacturer") + recent_models = serializers.IntegerField(help_text="Models created in the last 30 days") diff --git a/backend/apps/api/v1/serializers/ride_reviews.py b/backend/apps/api/v1/serializers/ride_reviews.py index 9ba6012b..59a6269f 100644 --- a/backend/apps/api/v1/serializers/ride_reviews.py +++ b/backend/apps/api/v1/serializers/ride_reviews.py @@ -54,19 +54,11 @@ class ReviewUserSerializer(serializers.ModelSerializer): "id": 456, "username": "coaster_fan", "display_name": "Coaster Fan", - "avatar_url": "https://example.com/avatar.jpg" + "avatar_url": "https://example.com/avatar.jpg", }, - "ride": { - "id": 789, - "name": "Steel Vengeance", - "slug": "steel-vengeance" - }, - "park": { - "id": 101, - "name": "Cedar Point", - "slug": "cedar-point" - } - } + "ride": {"id": 789, "name": "Steel Vengeance", "slug": "steel-vengeance"}, + "park": {"id": 101, "name": "Cedar Point", "slug": "cedar-point"}, + }, ) ] ) @@ -191,8 +183,7 @@ class RideReviewStatsOutputSerializer(serializers.Serializer): pending_reviews = serializers.IntegerField() average_rating = serializers.FloatField(allow_null=True) rating_distribution = serializers.DictField( - child=serializers.IntegerField(), - help_text="Count of reviews by rating (1-10)" + child=serializers.IntegerField(), help_text="Count of reviews by rating (1-10)" ) recent_reviews = serializers.IntegerField() @@ -200,20 +191,15 @@ class RideReviewStatsOutputSerializer(serializers.Serializer): class RideReviewModerationInputSerializer(serializers.Serializer): """Input serializer for review moderation operations.""" - review_ids = serializers.ListField( - child=serializers.IntegerField(), - help_text="List of review IDs to moderate" - ) + review_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of review IDs to moderate") action = serializers.ChoiceField( choices=[ ("publish", "Publish"), ("unpublish", "Unpublish"), ("delete", "Delete"), ], - help_text="Moderation action to perform" + help_text="Moderation action to perform", ) moderation_notes = serializers.CharField( - required=False, - allow_blank=True, - help_text="Optional notes about the moderation action" + required=False, allow_blank=True, help_text="Optional notes about the moderation action" ) diff --git a/backend/apps/api/v1/serializers/rides.py b/backend/apps/api/v1/serializers/rides.py index 417cf444..4e5d3840 100644 --- a/backend/apps/api/v1/serializers/rides.py +++ b/backend/apps/api/v1/serializers/rides.py @@ -81,23 +81,15 @@ class RideListOutputSerializer(serializers.Serializer): id = serializers.IntegerField() name = serializers.CharField() slug = serializers.CharField() - category = RichChoiceFieldSerializer( - choice_group="categories", - domain="rides" - ) - status = RichChoiceFieldSerializer( - choice_group="statuses", - domain="rides" - ) + category = RichChoiceFieldSerializer(choice_group="categories", domain="rides") + status = RichChoiceFieldSerializer(choice_group="statuses", domain="rides") description = serializers.CharField() # Park info park = RideParkOutputSerializer() # Statistics - average_rating = serializers.DecimalField( - max_digits=3, decimal_places=2, allow_null=True - ) + average_rating = serializers.DecimalField(max_digits=3, decimal_places=2, allow_null=True) capacity_per_hour = serializers.IntegerField(allow_null=True) # Dates @@ -178,18 +170,10 @@ class RideDetailOutputSerializer(serializers.Serializer): id = serializers.IntegerField() name = serializers.CharField() slug = serializers.CharField() - category = RichChoiceFieldSerializer( - choice_group="categories", - domain="rides" - ) - status = RichChoiceFieldSerializer( - choice_group="statuses", - domain="rides" - ) + category = RichChoiceFieldSerializer(choice_group="categories", domain="rides") + status = RichChoiceFieldSerializer(choice_group="statuses", domain="rides") post_closing_status = RichChoiceFieldSerializer( - choice_group="post_closing_statuses", - domain="rides", - allow_null=True + choice_group="post_closing_statuses", domain="rides", allow_null=True ) description = serializers.CharField() @@ -209,9 +193,7 @@ class RideDetailOutputSerializer(serializers.Serializer): ride_duration_seconds = serializers.IntegerField(allow_null=True) # Statistics - average_rating = serializers.DecimalField( - max_digits=3, decimal_places=2, allow_null=True - ) + average_rating = serializers.DecimalField(max_digits=3, decimal_places=2, allow_null=True) # Companies manufacturer = serializers.SerializerMethodField() @@ -273,9 +255,7 @@ class RideDetailOutputSerializer(serializers.Serializer): """Get all approved photos for this ride.""" from apps.rides.models import RidePhoto - photos = RidePhoto.objects.filter(ride=obj, is_approved=True).order_by( - "-is_primary", "-created_at" - )[ + photos = RidePhoto.objects.filter(ride=obj, is_approved=True).order_by("-is_primary", "-created_at")[ :10 ] # Limit to 10 photos @@ -285,9 +265,7 @@ class RideDetailOutputSerializer(serializers.Serializer): "image_url": photo.image.url if photo.image else None, "image_variants": ( { - "thumbnail": ( - f"{photo.image.url}/thumbnail" if photo.image else None - ), + "thumbnail": (f"{photo.image.url}/thumbnail" if photo.image else None), "medium": f"{photo.image.url}/medium" if photo.image else None, "large": f"{photo.image.url}/large" if photo.image else None, "public": f"{photo.image.url}/public" if photo.image else None, @@ -309,9 +287,7 @@ class RideDetailOutputSerializer(serializers.Serializer): from apps.rides.models import RidePhoto try: - photo = RidePhoto.objects.filter( - ride=obj, is_primary=True, is_approved=True - ).first() + photo = RidePhoto.objects.filter(ride=obj, is_primary=True, is_approved=True).first() if photo and photo.image: return { @@ -356,9 +332,7 @@ class RideDetailOutputSerializer(serializers.Serializer): try: latest_photo = ( - RidePhoto.objects.filter( - ride=obj, is_approved=True, image__isnull=False - ) + RidePhoto.objects.filter(ride=obj, is_approved=True, image__isnull=False) .order_by("-created_at") .first() ) @@ -407,9 +381,7 @@ class RideDetailOutputSerializer(serializers.Serializer): try: latest_photo = ( - RidePhoto.objects.filter( - ride=obj, is_approved=True, image__isnull=False - ) + RidePhoto.objects.filter(ride=obj, is_approved=True, image__isnull=False) .order_by("-created_at") .first() ) @@ -451,7 +423,7 @@ class RideImageSettingsInputSerializer(serializers.Serializer): # The ride will be validated in the view return value except RidePhoto.DoesNotExist: - raise serializers.ValidationError("Photo not found") + raise serializers.ValidationError("Photo not found") from None return value def validate_card_image_id(self, value): @@ -464,7 +436,7 @@ class RideImageSettingsInputSerializer(serializers.Serializer): # The ride will be validated in the view return value except RidePhoto.DoesNotExist: - raise serializers.ValidationError("Photo not found") + raise serializers.ValidationError("Photo not found") from None return value @@ -474,9 +446,7 @@ class RideCreateInputSerializer(serializers.Serializer): name = serializers.CharField(max_length=255) description = serializers.CharField(allow_blank=True, default="") category = serializers.ChoiceField(choices=ModelChoices.get_ride_category_choices()) - status = serializers.ChoiceField( - choices=ModelChoices.get_ride_status_choices(), default="OPERATING" - ) + status = serializers.ChoiceField(choices=ModelChoices.get_ride_status_choices(), default="OPERATING") # Required park park_id = serializers.IntegerField() @@ -490,18 +460,10 @@ class RideCreateInputSerializer(serializers.Serializer): status_since = serializers.DateField(required=False, allow_null=True) # Optional specs - min_height_in = serializers.IntegerField( - required=False, allow_null=True, min_value=30, max_value=90 - ) - max_height_in = serializers.IntegerField( - required=False, allow_null=True, min_value=30, max_value=90 - ) - capacity_per_hour = serializers.IntegerField( - required=False, allow_null=True, min_value=1 - ) - ride_duration_seconds = serializers.IntegerField( - required=False, allow_null=True, min_value=1 - ) + min_height_in = serializers.IntegerField(required=False, allow_null=True, min_value=30, max_value=90) + max_height_in = serializers.IntegerField(required=False, allow_null=True, min_value=30, max_value=90) + capacity_per_hour = serializers.IntegerField(required=False, allow_null=True, min_value=1) + ride_duration_seconds = serializers.IntegerField(required=False, allow_null=True, min_value=1) # Optional companies manufacturer_id = serializers.IntegerField(required=False, allow_null=True) @@ -517,18 +479,14 @@ class RideCreateInputSerializer(serializers.Serializer): closing_date = attrs.get("closing_date") if opening_date and closing_date and closing_date < opening_date: - raise serializers.ValidationError( - "Closing date cannot be before opening date" - ) + raise serializers.ValidationError("Closing date cannot be before opening date") # Height validation min_height = attrs.get("min_height_in") max_height = attrs.get("max_height_in") if min_height and max_height and min_height > max_height: - raise serializers.ValidationError( - "Minimum height cannot be greater than maximum height" - ) + raise serializers.ValidationError("Minimum height cannot be greater than maximum height") # Park area validation when park changes park_id = attrs.get("park_id") @@ -537,6 +495,7 @@ class RideCreateInputSerializer(serializers.Serializer): if park_id and park_area_id: try: from apps.parks.models import ParkArea + park_area = ParkArea.objects.get(id=park_area_id) if park_area.park_id != park_id: raise serializers.ValidationError( @@ -554,12 +513,8 @@ class RideUpdateInputSerializer(serializers.Serializer): name = serializers.CharField(max_length=255, required=False) description = serializers.CharField(allow_blank=True, required=False) - category = serializers.ChoiceField( - choices=ModelChoices.get_ride_category_choices(), required=False - ) - status = serializers.ChoiceField( - choices=ModelChoices.get_ride_status_choices(), required=False - ) + category = serializers.ChoiceField(choices=ModelChoices.get_ride_category_choices(), required=False) + status = serializers.ChoiceField(choices=ModelChoices.get_ride_status_choices(), required=False) post_closing_status = serializers.ChoiceField( choices=ModelChoices.get_ride_post_closing_choices(), required=False, @@ -576,18 +531,10 @@ class RideUpdateInputSerializer(serializers.Serializer): status_since = serializers.DateField(required=False, allow_null=True) # Specs - min_height_in = serializers.IntegerField( - required=False, allow_null=True, min_value=30, max_value=90 - ) - max_height_in = serializers.IntegerField( - required=False, allow_null=True, min_value=30, max_value=90 - ) - capacity_per_hour = serializers.IntegerField( - required=False, allow_null=True, min_value=1 - ) - ride_duration_seconds = serializers.IntegerField( - required=False, allow_null=True, min_value=1 - ) + min_height_in = serializers.IntegerField(required=False, allow_null=True, min_value=30, max_value=90) + max_height_in = serializers.IntegerField(required=False, allow_null=True, min_value=30, max_value=90) + capacity_per_hour = serializers.IntegerField(required=False, allow_null=True, min_value=1) + ride_duration_seconds = serializers.IntegerField(required=False, allow_null=True, min_value=1) # Companies manufacturer_id = serializers.IntegerField(required=False, allow_null=True) @@ -603,18 +550,14 @@ class RideUpdateInputSerializer(serializers.Serializer): closing_date = attrs.get("closing_date") if opening_date and closing_date and closing_date < opening_date: - raise serializers.ValidationError( - "Closing date cannot be before opening date" - ) + raise serializers.ValidationError("Closing date cannot be before opening date") # Height validation min_height = attrs.get("min_height_in") max_height = attrs.get("max_height_in") if min_height and max_height and min_height > max_height: - raise serializers.ValidationError( - "Minimum height cannot be greater than maximum height" - ) + raise serializers.ValidationError("Minimum height cannot be greater than maximum height") return attrs @@ -626,9 +569,7 @@ class RideFilterInputSerializer(serializers.Serializer): search = serializers.CharField(required=False, allow_blank=True) # Category filter - category = serializers.MultipleChoiceField( - choices=ModelChoices.get_ride_category_choices(), required=False - ) + category = serializers.MultipleChoiceField(choices=ModelChoices.get_ride_category_choices(), required=False) # Status filter status = serializers.MultipleChoiceField( @@ -707,33 +648,16 @@ class RollerCoasterStatsOutputSerializer(serializers.Serializer): """Output serializer for roller coaster statistics.""" id = serializers.IntegerField() - height_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, allow_null=True - ) - length_ft = serializers.DecimalField( - max_digits=7, decimal_places=2, allow_null=True - ) - speed_mph = serializers.DecimalField( - max_digits=5, decimal_places=2, allow_null=True - ) + height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, allow_null=True) + length_ft = serializers.DecimalField(max_digits=7, decimal_places=2, allow_null=True) + speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True) inversions = serializers.IntegerField() ride_time_seconds = serializers.IntegerField(allow_null=True) track_type = serializers.CharField() - track_material = RichChoiceFieldSerializer( - choice_group="track_materials", - domain="rides" - ) - roller_coaster_type = RichChoiceFieldSerializer( - choice_group="coaster_types", - domain="rides" - ) - max_drop_height_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, allow_null=True - ) - propulsion_system = RichChoiceFieldSerializer( - choice_group="propulsion_systems", - domain="rides" - ) + track_material = RichChoiceFieldSerializer(choice_group="track_materials", domain="rides") + roller_coaster_type = RichChoiceFieldSerializer(choice_group="coaster_types", domain="rides") + max_drop_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, allow_null=True) + propulsion_system = RichChoiceFieldSerializer(choice_group="propulsion_systems", domain="rides") train_style = serializers.CharField() trains_count = serializers.IntegerField(allow_null=True) cars_per_train = serializers.IntegerField(allow_null=True) @@ -755,30 +679,16 @@ class RollerCoasterStatsCreateInputSerializer(serializers.Serializer): """Input serializer for creating roller coaster statistics.""" ride_id = serializers.IntegerField() - height_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, required=False, allow_null=True - ) - length_ft = serializers.DecimalField( - max_digits=7, decimal_places=2, required=False, allow_null=True - ) - speed_mph = serializers.DecimalField( - max_digits=5, decimal_places=2, required=False, allow_null=True - ) + height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True) + length_ft = serializers.DecimalField(max_digits=7, decimal_places=2, required=False, allow_null=True) + speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False, allow_null=True) inversions = serializers.IntegerField(default=0) ride_time_seconds = serializers.IntegerField(required=False, allow_null=True) track_type = serializers.CharField(max_length=255, allow_blank=True, default="") - track_material = serializers.ChoiceField( - choices=ModelChoices.get_coaster_track_choices(), default="STEEL" - ) - roller_coaster_type = serializers.ChoiceField( - choices=ModelChoices.get_coaster_type_choices(), default="SITDOWN" - ) - max_drop_height_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, required=False, allow_null=True - ) - propulsion_system = serializers.ChoiceField( - choices=ModelChoices.get_propulsion_system_choices(), default="CHAIN" - ) + track_material = serializers.ChoiceField(choices=ModelChoices.get_coaster_track_choices(), default="STEEL") + roller_coaster_type = serializers.ChoiceField(choices=ModelChoices.get_coaster_type_choices(), default="SITDOWN") + max_drop_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True) + propulsion_system = serializers.ChoiceField(choices=ModelChoices.get_propulsion_system_choices(), default="CHAIN") train_style = serializers.CharField(max_length=255, allow_blank=True, default="") trains_count = serializers.IntegerField(required=False, allow_null=True) cars_per_train = serializers.IntegerField(required=False, allow_null=True) @@ -788,33 +698,17 @@ class RollerCoasterStatsCreateInputSerializer(serializers.Serializer): class RollerCoasterStatsUpdateInputSerializer(serializers.Serializer): """Input serializer for updating roller coaster statistics.""" - height_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, required=False, allow_null=True - ) - length_ft = serializers.DecimalField( - max_digits=7, decimal_places=2, required=False, allow_null=True - ) - speed_mph = serializers.DecimalField( - max_digits=5, decimal_places=2, required=False, allow_null=True - ) + height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True) + length_ft = serializers.DecimalField(max_digits=7, decimal_places=2, required=False, allow_null=True) + speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False, allow_null=True) inversions = serializers.IntegerField(required=False) ride_time_seconds = serializers.IntegerField(required=False, allow_null=True) track_type = serializers.CharField(max_length=255, allow_blank=True, required=False) - track_material = serializers.ChoiceField( - choices=ModelChoices.get_coaster_track_choices(), required=False - ) - roller_coaster_type = serializers.ChoiceField( - choices=ModelChoices.get_coaster_type_choices(), required=False - ) - max_drop_height_ft = serializers.DecimalField( - max_digits=6, decimal_places=2, required=False, allow_null=True - ) - propulsion_system = serializers.ChoiceField( - choices=ModelChoices.get_propulsion_system_choices(), required=False - ) - train_style = serializers.CharField( - max_length=255, allow_blank=True, required=False - ) + track_material = serializers.ChoiceField(choices=ModelChoices.get_coaster_track_choices(), required=False) + roller_coaster_type = serializers.ChoiceField(choices=ModelChoices.get_coaster_type_choices(), required=False) + max_drop_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True) + propulsion_system = serializers.ChoiceField(choices=ModelChoices.get_propulsion_system_choices(), required=False) + train_style = serializers.CharField(max_length=255, allow_blank=True, required=False) trains_count = serializers.IntegerField(required=False, allow_null=True) cars_per_train = serializers.IntegerField(required=False, allow_null=True) seats_per_car = serializers.IntegerField(required=False, allow_null=True) diff --git a/backend/apps/api/v1/serializers/rides_media.py b/backend/apps/api/v1/serializers/rides_media.py index 59dbf49c..4e973cc4 100644 --- a/backend/apps/api/v1/serializers/rides_media.py +++ b/backend/apps/api/v1/serializers/rides_media.py @@ -12,9 +12,7 @@ from apps.rides.models import RidePhoto class RidePhotoOutputSerializer(serializers.ModelSerializer): """Output serializer for ride photos.""" - uploaded_by_username = serializers.CharField( - source="uploaded_by.username", read_only=True - ) + uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True) file_size = serializers.ReadOnlyField() dimensions = serializers.ReadOnlyField() ride_slug = serializers.CharField(source="ride.slug", read_only=True) @@ -87,9 +85,7 @@ class RidePhotoUpdateInputSerializer(serializers.ModelSerializer): class RidePhotoListOutputSerializer(serializers.ModelSerializer): """Simplified output serializer for ride photo lists.""" - uploaded_by_username = serializers.CharField( - source="uploaded_by.username", read_only=True - ) + uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True) class Meta: model = RidePhoto @@ -109,12 +105,8 @@ class RidePhotoListOutputSerializer(serializers.ModelSerializer): class RidePhotoApprovalInputSerializer(serializers.Serializer): """Input serializer for photo approval operations.""" - photo_ids = serializers.ListField( - child=serializers.IntegerField(), help_text="List of photo IDs to approve" - ) - approve = serializers.BooleanField( - default=True, help_text="Whether to approve (True) or reject (False) the photos" - ) + photo_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of photo IDs to approve") + approve = serializers.BooleanField(default=True, help_text="Whether to approve (True) or reject (False) the photos") class RidePhotoStatsOutputSerializer(serializers.Serializer): @@ -125,9 +117,7 @@ class RidePhotoStatsOutputSerializer(serializers.Serializer): pending_photos = serializers.IntegerField() has_primary = serializers.BooleanField() recent_uploads = serializers.IntegerField() - by_type = serializers.DictField( - child=serializers.IntegerField(), help_text="Photo counts by type" - ) + by_type = serializers.DictField(child=serializers.IntegerField(), help_text="Photo counts by type") class RidePhotoTypeFilterSerializer(serializers.Serializer): diff --git a/backend/apps/api/v1/serializers/search.py b/backend/apps/api/v1/serializers/search.py index d08e527f..891b24bd 100644 --- a/backend/apps/api/v1/serializers/search.py +++ b/backend/apps/api/v1/serializers/search.py @@ -19,9 +19,7 @@ class EntitySearchInputSerializer(serializers.Serializer): query = serializers.CharField(max_length=255, help_text="Search query string") entity_types = serializers.ListField( - child=serializers.ChoiceField( - choices=ModelChoices.get_entity_type_choices() - ), + child=serializers.ChoiceField(choices=ModelChoices.get_entity_type_choices()), required=False, help_text="Types of entities to search for", ) @@ -39,17 +37,12 @@ class EntitySearchResultSerializer(serializers.Serializer): id = serializers.IntegerField() name = serializers.CharField() slug = serializers.CharField() - type = RichChoiceFieldSerializer( - choice_group="entity_types", - domain="core" - ) + type = RichChoiceFieldSerializer(choice_group="entity_types", domain="core") description = serializers.CharField() relevance_score = serializers.FloatField() # Context-specific info — renamed to avoid overriding Serializer.context - context_info = serializers.JSONField( - help_text="Additional context based on entity type" - ) + context_info = serializers.JSONField(help_text="Additional context based on entity type") class EntitySearchOutputSerializer(serializers.Serializer): diff --git a/backend/apps/api/v1/serializers/services.py b/backend/apps/api/v1/serializers/services.py index 25dc6052..de2fe2bb 100644 --- a/backend/apps/api/v1/serializers/services.py +++ b/backend/apps/api/v1/serializers/services.py @@ -39,9 +39,7 @@ class SimpleHealthOutputSerializer(serializers.Serializer): status = serializers.CharField(help_text="Simple health status") timestamp = serializers.DateTimeField(help_text="Timestamp of health check") - error = serializers.CharField( - required=False, help_text="Error message if unhealthy" - ) + error = serializers.CharField(required=False, help_text="Error message if unhealthy") # === EMAIL SERVICE SERIALIZERS === @@ -151,7 +149,7 @@ class ModerationSubmissionSerializer(serializers.Serializer): ("PHOTO", "Photo Submission"), ("REVIEW", "Review Submission"), ], - help_text="Type of submission" + help_text="Type of submission", ) content_type = serializers.CharField(help_text="Content type being modified") object_id = serializers.IntegerField(help_text="ID of object being modified") @@ -221,9 +219,7 @@ class RoadtripOutputSerializer(serializers.Serializer): parks = RoadtripParkSerializer(many=True) total_distance_miles = serializers.FloatField() estimated_drive_time_hours = serializers.FloatField() - route_coordinates = serializers.ListField( - child=serializers.ListField(child=serializers.FloatField()) - ) + route_coordinates = serializers.ListField(child=serializers.ListField(child=serializers.FloatField())) created_at = serializers.DateTimeField() diff --git a/backend/apps/api/v1/serializers/shared.py b/backend/apps/api/v1/serializers/shared.py index 2c8cfe84..a625e7b2 100644 --- a/backend/apps/api/v1/serializers/shared.py +++ b/backend/apps/api/v1/serializers/shared.py @@ -25,21 +25,13 @@ class FilterOptionSerializer(serializers.Serializer): selected?: boolean; } """ - value = serializers.CharField( - help_text="The actual value used for filtering" - ) - label = serializers.CharField( - help_text="Human-readable display label" - ) + + value = serializers.CharField(help_text="The actual value used for filtering") + label = serializers.CharField(help_text="Human-readable display label") count = serializers.IntegerField( - required=False, - allow_null=True, - help_text="Number of items matching this filter option" - ) - selected = serializers.BooleanField( - default=False, - help_text="Whether this option is currently selected" + required=False, allow_null=True, help_text="Number of items matching this filter option" ) + selected = serializers.BooleanField(default=False, help_text="Whether this option is currently selected") class FilterRangeSerializer(serializers.Serializer): @@ -54,22 +46,12 @@ class FilterRangeSerializer(serializers.Serializer): unit?: string; } """ - min = serializers.FloatField( - allow_null=True, - help_text="Minimum value for the range" - ) - max = serializers.FloatField( - allow_null=True, - help_text="Maximum value for the range" - ) - step = serializers.FloatField( - default=1.0, - help_text="Step size for range inputs" - ) + + min = serializers.FloatField(allow_null=True, help_text="Minimum value for the range") + max = serializers.FloatField(allow_null=True, help_text="Maximum value for the range") + step = serializers.FloatField(default=1.0, help_text="Step size for range inputs") unit = serializers.CharField( - required=False, - allow_null=True, - help_text="Unit of measurement (e.g., 'feet', 'mph', 'stars')" + required=False, allow_null=True, help_text="Unit of measurement (e.g., 'feet', 'mph', 'stars')" ) @@ -84,15 +66,10 @@ class BooleanFilterSerializer(serializers.Serializer): description: string; } """ - key = serializers.CharField( - help_text="The filter parameter key" - ) - label = serializers.CharField( - help_text="Human-readable label for the filter" - ) - description = serializers.CharField( - help_text="Description of what this filter does" - ) + + key = serializers.CharField(help_text="The filter parameter key") + label = serializers.CharField(help_text="Human-readable label for the filter") + description = serializers.CharField(help_text="Description of what this filter does") class OrderingOptionSerializer(serializers.Serializer): @@ -105,12 +82,9 @@ class OrderingOptionSerializer(serializers.Serializer): label: string; } """ - value = serializers.CharField( - help_text="The ordering parameter value" - ) - label = serializers.CharField( - help_text="Human-readable label for the ordering option" - ) + + value = serializers.CharField(help_text="The ordering parameter value") + label = serializers.CharField(help_text="Human-readable label for the ordering option") class StandardizedFilterMetadataSerializer(serializers.Serializer): @@ -120,27 +94,16 @@ class StandardizedFilterMetadataSerializer(serializers.Serializer): This serializer ensures all filter metadata responses follow the same structure that the frontend expects, preventing runtime type errors. """ + categorical = serializers.DictField( - child=FilterOptionSerializer(many=True), - help_text="Categorical filter options with value/label/count structure" + child=FilterOptionSerializer(many=True), help_text="Categorical filter options with value/label/count structure" ) ranges = serializers.DictField( - child=FilterRangeSerializer(), - help_text="Range filter metadata with min/max/step/unit" - ) - total_count = serializers.IntegerField( - help_text="Total number of items in the filtered dataset" - ) - ordering_options = FilterOptionSerializer( - many=True, - required=False, - help_text="Available ordering options" - ) - boolean_filters = BooleanFilterSerializer( - many=True, - required=False, - help_text="Available boolean filter options" + child=FilterRangeSerializer(), help_text="Range filter metadata with min/max/step/unit" ) + total_count = serializers.IntegerField(help_text="Total number of items in the filtered dataset") + ordering_options = FilterOptionSerializer(many=True, required=False, help_text="Available ordering options") + boolean_filters = BooleanFilterSerializer(many=True, required=False, help_text="Available boolean filter options") class PaginationMetadataSerializer(serializers.Serializer): @@ -157,28 +120,13 @@ class PaginationMetadataSerializer(serializers.Serializer): total_pages: number; } """ - count = serializers.IntegerField( - help_text="Total number of items across all pages" - ) - next = serializers.URLField( - allow_null=True, - required=False, - help_text="URL for the next page of results" - ) - previous = serializers.URLField( - allow_null=True, - required=False, - help_text="URL for the previous page of results" - ) - page_size = serializers.IntegerField( - help_text="Number of items per page" - ) - current_page = serializers.IntegerField( - help_text="Current page number (1-indexed)" - ) - total_pages = serializers.IntegerField( - help_text="Total number of pages" - ) + + count = serializers.IntegerField(help_text="Total number of items across all pages") + next = serializers.URLField(allow_null=True, required=False, help_text="URL for the next page of results") + previous = serializers.URLField(allow_null=True, required=False, help_text="URL for the previous page of results") + page_size = serializers.IntegerField(help_text="Number of items per page") + current_page = serializers.IntegerField(help_text="Current page number (1-indexed)") + total_pages = serializers.IntegerField(help_text="Total number of pages") class ApiResponseSerializer(serializers.Serializer): @@ -193,22 +141,14 @@ class ApiResponseSerializer(serializers.Serializer): errors?: ValidationError; } """ - success = serializers.BooleanField( - help_text="Whether the request was successful" - ) + + success = serializers.BooleanField(help_text="Whether the request was successful") response_data = serializers.JSONField( - required=False, - help_text="Response data (structure varies by endpoint)", - source='data' - ) - message = serializers.CharField( - required=False, - help_text="Human-readable message about the operation" + required=False, help_text="Response data (structure varies by endpoint)", source="data" ) + message = serializers.CharField(required=False, help_text="Human-readable message about the operation") response_errors = serializers.DictField( - required=False, - help_text="Validation errors (field -> error messages)", - source='errors' + required=False, help_text="Validation errors (field -> error messages)", source="errors" ) @@ -228,18 +168,11 @@ class ErrorResponseSerializer(serializers.Serializer): data: null; } """ - status = serializers.CharField( - default="error", - help_text="Response status indicator" - ) - error = serializers.DictField( - help_text="Error details" - ) + + status = serializers.CharField(default="error", help_text="Response status indicator") + error = serializers.DictField(help_text="Error details") response_data = serializers.JSONField( - default=None, - allow_null=True, - help_text="Always null for error responses", - source='data' + default=None, allow_null=True, help_text="Always null for error responses", source="data" ) @@ -257,32 +190,13 @@ class LocationSerializer(serializers.Serializer): longitude?: number; } """ - city = serializers.CharField( - help_text="City name" - ) - state = serializers.CharField( - required=False, - allow_null=True, - help_text="State/province name" - ) - country = serializers.CharField( - help_text="Country name" - ) - address = serializers.CharField( - required=False, - allow_null=True, - help_text="Street address" - ) - latitude = serializers.FloatField( - required=False, - allow_null=True, - help_text="Latitude coordinate" - ) - longitude = serializers.FloatField( - required=False, - allow_null=True, - help_text="Longitude coordinate" - ) + + city = serializers.CharField(help_text="City name") + state = serializers.CharField(required=False, allow_null=True, help_text="State/province name") + country = serializers.CharField(help_text="Country name") + address = serializers.CharField(required=False, allow_null=True, help_text="Street address") + latitude = serializers.FloatField(required=False, allow_null=True, help_text="Latitude coordinate") + longitude = serializers.FloatField(required=False, allow_null=True, help_text="Longitude coordinate") # Alias for backward compatibility @@ -301,24 +215,15 @@ class CompanyOutputSerializer(serializers.Serializer): roles?: string[]; } """ - id = serializers.IntegerField( - help_text="Company ID" - ) - name = serializers.CharField( - help_text="Company name" - ) - slug = serializers.SlugField( - help_text="URL-friendly identifier" - ) + + id = serializers.IntegerField(help_text="Company ID") + name = serializers.CharField(help_text="Company name") + slug = serializers.SlugField(help_text="URL-friendly identifier") roles = serializers.ListField( - child=serializers.CharField(), - required=False, - help_text="Company roles (manufacturer, operator, etc.)" + child=serializers.CharField(), required=False, help_text="Company roles (manufacturer, operator, etc.)" ) - - class ModelChoices: """ Utility class to provide model choices for serializers using Rich Choice Objects. @@ -331,6 +236,7 @@ class ModelChoices: def get_park_status_choices(): """Get park status choices from Rich Choice registry.""" from apps.core.choices.registry import get_choices + choices = get_choices("statuses", "parks") return [(choice.value, choice.label) for choice in choices] @@ -338,6 +244,7 @@ class ModelChoices: def get_ride_status_choices(): """Get ride status choices from Rich Choice registry.""" from apps.core.choices.registry import get_choices + choices = get_choices("statuses", "rides") return [(choice.value, choice.label) for choice in choices] @@ -345,6 +252,7 @@ class ModelChoices: def get_company_role_choices(): """Get company role choices from Rich Choice registry.""" from apps.core.choices.registry import get_choices + # Get rides domain company roles (MANUFACTURER, DESIGNER) rides_choices = get_choices("company_roles", "rides") # Get parks domain company roles (OPERATOR, PROPERTY_OWNER) @@ -356,6 +264,7 @@ class ModelChoices: def get_ride_category_choices(): """Get ride category choices from Rich Choice registry.""" from apps.core.choices.registry import get_choices + choices = get_choices("categories", "rides") return [(choice.value, choice.label) for choice in choices] @@ -363,6 +272,7 @@ class ModelChoices: def get_ride_post_closing_choices(): """Get ride post-closing status choices from Rich Choice registry.""" from apps.core.choices.registry import get_choices + choices = get_choices("post_closing_statuses", "rides") return [(choice.value, choice.label) for choice in choices] @@ -370,6 +280,7 @@ class ModelChoices: def get_coaster_track_choices(): """Get coaster track material choices from Rich Choice registry.""" from apps.core.choices.registry import get_choices + choices = get_choices("track_materials", "rides") return [(choice.value, choice.label) for choice in choices] @@ -377,6 +288,7 @@ class ModelChoices: def get_coaster_type_choices(): """Get coaster type choices from Rich Choice registry.""" from apps.core.choices.registry import get_choices + choices = get_choices("coaster_types", "rides") return [(choice.value, choice.label) for choice in choices] @@ -384,6 +296,7 @@ class ModelChoices: def get_launch_choices(): """Get launch system choices from Rich Choice registry (legacy method).""" from apps.core.choices.registry import get_choices + choices = get_choices("propulsion_systems", "rides") return [(choice.value, choice.label) for choice in choices] @@ -391,6 +304,7 @@ class ModelChoices: def get_propulsion_system_choices(): """Get propulsion system choices from Rich Choice registry.""" from apps.core.choices.registry import get_choices + choices = get_choices("propulsion_systems", "rides") return [(choice.value, choice.label) for choice in choices] @@ -398,6 +312,7 @@ class ModelChoices: def get_photo_type_choices(): """Get photo type choices from Rich Choice registry.""" from apps.core.choices.registry import get_choices + choices = get_choices("photo_types", "rides") return [(choice.value, choice.label) for choice in choices] @@ -405,6 +320,7 @@ class ModelChoices: def get_spec_category_choices(): """Get technical specification category choices from Rich Choice registry.""" from apps.core.choices.registry import get_choices + choices = get_choices("spec_categories", "rides") return [(choice.value, choice.label) for choice in choices] @@ -412,6 +328,7 @@ class ModelChoices: def get_technical_spec_category_choices(): """Get technical specification category choices from Rich Choice registry.""" from apps.core.choices.registry import get_choices + choices = get_choices("spec_categories", "rides") return [(choice.value, choice.label) for choice in choices] @@ -419,6 +336,7 @@ class ModelChoices: def get_target_market_choices(): """Get target market choices from Rich Choice registry.""" from apps.core.choices.registry import get_choices + choices = get_choices("target_markets", "rides") return [(choice.value, choice.label) for choice in choices] @@ -426,6 +344,7 @@ class ModelChoices: def get_entity_type_choices(): """Get entity type choices for search functionality.""" from apps.core.choices.registry import get_choices + choices = get_choices("entity_types", "core") return [(choice.value, choice.label) for choice in choices] @@ -433,6 +352,7 @@ class ModelChoices: def get_health_status_choices(): """Get health check status choices from Rich Choice registry.""" from apps.core.choices.registry import get_choices + choices = get_choices("health_statuses", "core") return [(choice.value, choice.label) for choice in choices] @@ -440,6 +360,7 @@ class ModelChoices: def get_simple_health_status_choices(): """Get simple health check status choices from Rich Choice registry.""" from apps.core.choices.registry import get_choices + choices = get_choices("simple_health_statuses", "core") return [(choice.value, choice.label) for choice in choices] @@ -455,15 +376,10 @@ class EntityReferenceSerializer(serializers.Serializer): slug: string; } """ - id = serializers.IntegerField( - help_text="Unique identifier" - ) - name = serializers.CharField( - help_text="Display name" - ) - slug = serializers.SlugField( - help_text="URL-friendly identifier" - ) + + id = serializers.IntegerField(help_text="Unique identifier") + name = serializers.CharField(help_text="Display name") + slug = serializers.SlugField(help_text="URL-friendly identifier") class ImageVariantsSerializer(serializers.Serializer): @@ -478,19 +394,11 @@ class ImageVariantsSerializer(serializers.Serializer): avatar?: string; } """ - thumbnail = serializers.URLField( - help_text="Thumbnail size image URL" - ) - medium = serializers.URLField( - help_text="Medium size image URL" - ) - large = serializers.URLField( - help_text="Large size image URL" - ) - avatar = serializers.URLField( - required=False, - help_text="Avatar size image URL (for user avatars)" - ) + + thumbnail = serializers.URLField(help_text="Thumbnail size image URL") + medium = serializers.URLField(help_text="Medium size image URL") + large = serializers.URLField(help_text="Large size image URL") + avatar = serializers.URLField(required=False, help_text="Avatar size image URL (for user avatars)") class PhotoSerializer(serializers.Serializer): @@ -509,39 +417,15 @@ class PhotoSerializer(serializers.Serializer): uploaded_at?: string; } """ - id = serializers.IntegerField( - help_text="Photo ID" - ) - image_variants = ImageVariantsSerializer( - help_text="Available image size variants" - ) - alt_text = serializers.CharField( - required=False, - allow_null=True, - help_text="Alternative text for accessibility" - ) - image_url = serializers.URLField( - required=False, - help_text="Primary image URL (for compatibility)" - ) - caption = serializers.CharField( - required=False, - allow_null=True, - help_text="Photo caption" - ) - photo_type = serializers.CharField( - required=False, - allow_null=True, - help_text="Type/category of photo" - ) - uploaded_by = EntityReferenceSerializer( - required=False, - help_text="User who uploaded the photo" - ) - uploaded_at = serializers.DateTimeField( - required=False, - help_text="When the photo was uploaded" - ) + + id = serializers.IntegerField(help_text="Photo ID") + image_variants = ImageVariantsSerializer(help_text="Available image size variants") + alt_text = serializers.CharField(required=False, allow_null=True, help_text="Alternative text for accessibility") + image_url = serializers.URLField(required=False, help_text="Primary image URL (for compatibility)") + caption = serializers.CharField(required=False, allow_null=True, help_text="Photo caption") + photo_type = serializers.CharField(required=False, allow_null=True, help_text="Type/category of photo") + uploaded_by = EntityReferenceSerializer(required=False, help_text="User who uploaded the photo") + uploaded_at = serializers.DateTimeField(required=False, help_text="When the photo was uploaded") class UserInfoSerializer(serializers.Serializer): @@ -556,20 +440,11 @@ class UserInfoSerializer(serializers.Serializer): avatar_url?: string; } """ - id = serializers.IntegerField( - help_text="User ID" - ) - username = serializers.CharField( - help_text="Username" - ) - display_name = serializers.CharField( - help_text="Display name" - ) - avatar_url = serializers.URLField( - required=False, - allow_null=True, - help_text="User avatar URL" - ) + + id = serializers.IntegerField(help_text="User ID") + username = serializers.CharField(help_text="Username") + display_name = serializers.CharField(help_text="Display name") + avatar_url = serializers.URLField(required=False, allow_null=True, help_text="User avatar URL") def validate_filter_metadata_contract(data: dict[str, Any]) -> dict[str, Any]: @@ -613,27 +488,22 @@ def ensure_filter_option_format(options: list[Any]) -> list[dict[str, Any]]: if isinstance(option, dict): # Already in correct format or close to it standardized_option = { - 'value': str(option.get('value', option.get('id', ''))), - 'label': option.get('label', option.get('name', str(option.get('value', '')))), - 'count': option.get('count'), - 'selected': option.get('selected', False) + "value": str(option.get("value", option.get("id", ""))), + "label": option.get("label", option.get("name", str(option.get("value", "")))), + "count": option.get("count"), + "selected": option.get("selected", False), } - elif hasattr(option, 'value') and hasattr(option, 'label'): + elif hasattr(option, "value") and hasattr(option, "label"): # RichChoice object format standardized_option = { - 'value': str(option.value), - 'label': str(option.label), - 'count': None, - 'selected': False + "value": str(option.value), + "label": str(option.label), + "count": None, + "selected": False, } else: # Simple value - use as both value and label - standardized_option = { - 'value': str(option), - 'label': str(option), - 'count': None, - 'selected': False - } + standardized_option = {"value": str(option), "label": str(option), "count": None, "selected": False} standardized.append(standardized_option) @@ -651,8 +521,8 @@ def ensure_range_format(range_data: dict[str, Any]) -> dict[str, Any]: Range data in standard format """ return { - 'min': range_data.get('min'), - 'max': range_data.get('max'), - 'step': range_data.get('step', 1.0), - 'unit': range_data.get('unit') + "min": range_data.get("min"), + "max": range_data.get("max"), + "step": range_data.get("step", 1.0), + "unit": range_data.get("unit"), } diff --git a/backend/apps/api/v1/serializers/stats.py b/backend/apps/api/v1/serializers/stats.py index 20bd693b..4a7bb3cd 100644 --- a/backend/apps/api/v1/serializers/stats.py +++ b/backend/apps/api/v1/serializers/stats.py @@ -16,120 +16,56 @@ class StatsSerializer(serializers.Serializer): """ # Core entity counts - total_parks = serializers.IntegerField( - help_text="Total number of parks in the database" - ) - total_rides = serializers.IntegerField( - help_text="Total number of rides in the database" - ) - total_manufacturers = serializers.IntegerField( - help_text="Total number of ride manufacturers" - ) - total_operators = serializers.IntegerField( - help_text="Total number of park operators" - ) - total_designers = serializers.IntegerField( - help_text="Total number of ride designers" - ) - total_property_owners = serializers.IntegerField( - help_text="Total number of property owners" - ) - total_roller_coasters = serializers.IntegerField( - help_text="Total number of roller coasters with detailed stats" - ) + total_parks = serializers.IntegerField(help_text="Total number of parks in the database") + total_rides = serializers.IntegerField(help_text="Total number of rides in the database") + total_manufacturers = serializers.IntegerField(help_text="Total number of ride manufacturers") + total_operators = serializers.IntegerField(help_text="Total number of park operators") + total_designers = serializers.IntegerField(help_text="Total number of ride designers") + total_property_owners = serializers.IntegerField(help_text="Total number of property owners") + total_roller_coasters = serializers.IntegerField(help_text="Total number of roller coasters with detailed stats") # Photo counts - total_photos = serializers.IntegerField( - help_text="Total number of photos (parks + rides combined)" - ) - total_park_photos = serializers.IntegerField( - help_text="Total number of park photos" - ) - total_ride_photos = serializers.IntegerField( - help_text="Total number of ride photos" - ) + total_photos = serializers.IntegerField(help_text="Total number of photos (parks + rides combined)") + total_park_photos = serializers.IntegerField(help_text="Total number of park photos") + total_ride_photos = serializers.IntegerField(help_text="Total number of ride photos") # Review counts - total_reviews = serializers.IntegerField( - help_text="Total number of reviews (parks + rides)" - ) - total_park_reviews = serializers.IntegerField( - help_text="Total number of park reviews" - ) - total_ride_reviews = serializers.IntegerField( - help_text="Total number of ride reviews" - ) + total_reviews = serializers.IntegerField(help_text="Total number of reviews (parks + rides)") + total_park_reviews = serializers.IntegerField(help_text="Total number of park reviews") + total_ride_reviews = serializers.IntegerField(help_text="Total number of ride reviews") # Ride category counts (optional fields since they depend on data) roller_coasters = serializers.IntegerField( required=False, help_text="Number of rides categorized as roller coasters" ) - dark_rides = serializers.IntegerField( - required=False, help_text="Number of rides categorized as dark rides" - ) - flat_rides = serializers.IntegerField( - required=False, help_text="Number of rides categorized as flat rides" - ) - water_rides = serializers.IntegerField( - required=False, help_text="Number of rides categorized as water rides" - ) + dark_rides = serializers.IntegerField(required=False, help_text="Number of rides categorized as dark rides") + flat_rides = serializers.IntegerField(required=False, help_text="Number of rides categorized as flat rides") + water_rides = serializers.IntegerField(required=False, help_text="Number of rides categorized as water rides") transport_rides = serializers.IntegerField( required=False, help_text="Number of rides categorized as transport rides" ) - other_rides = serializers.IntegerField( - required=False, help_text="Number of rides categorized as other" - ) + other_rides = serializers.IntegerField(required=False, help_text="Number of rides categorized as other") # Park status counts (optional fields since they depend on data) - operating_parks = serializers.IntegerField( - required=False, help_text="Number of currently operating parks" - ) - temporarily_closed_parks = serializers.IntegerField( - required=False, help_text="Number of temporarily closed parks" - ) - permanently_closed_parks = serializers.IntegerField( - required=False, help_text="Number of permanently closed parks" - ) - under_construction_parks = serializers.IntegerField( - required=False, help_text="Number of parks under construction" - ) - demolished_parks = serializers.IntegerField( - required=False, help_text="Number of demolished parks" - ) - relocated_parks = serializers.IntegerField( - required=False, help_text="Number of relocated parks" - ) + operating_parks = serializers.IntegerField(required=False, help_text="Number of currently operating parks") + temporarily_closed_parks = serializers.IntegerField(required=False, help_text="Number of temporarily closed parks") + permanently_closed_parks = serializers.IntegerField(required=False, help_text="Number of permanently closed parks") + under_construction_parks = serializers.IntegerField(required=False, help_text="Number of parks under construction") + demolished_parks = serializers.IntegerField(required=False, help_text="Number of demolished parks") + relocated_parks = serializers.IntegerField(required=False, help_text="Number of relocated parks") # Ride status counts (optional fields since they depend on data) - operating_rides = serializers.IntegerField( - required=False, help_text="Number of currently operating rides" - ) - temporarily_closed_rides = serializers.IntegerField( - required=False, help_text="Number of temporarily closed rides" - ) - sbno_rides = serializers.IntegerField( - required=False, help_text="Number of rides standing but not operating" - ) - closing_rides = serializers.IntegerField( - required=False, help_text="Number of rides in the process of closing" - ) - permanently_closed_rides = serializers.IntegerField( - required=False, help_text="Number of permanently closed rides" - ) - under_construction_rides = serializers.IntegerField( - required=False, help_text="Number of rides under construction" - ) - demolished_rides = serializers.IntegerField( - required=False, help_text="Number of demolished rides" - ) - relocated_rides = serializers.IntegerField( - required=False, help_text="Number of relocated rides" - ) + operating_rides = serializers.IntegerField(required=False, help_text="Number of currently operating rides") + temporarily_closed_rides = serializers.IntegerField(required=False, help_text="Number of temporarily closed rides") + sbno_rides = serializers.IntegerField(required=False, help_text="Number of rides standing but not operating") + closing_rides = serializers.IntegerField(required=False, help_text="Number of rides in the process of closing") + permanently_closed_rides = serializers.IntegerField(required=False, help_text="Number of permanently closed rides") + under_construction_rides = serializers.IntegerField(required=False, help_text="Number of rides under construction") + demolished_rides = serializers.IntegerField(required=False, help_text="Number of demolished rides") + relocated_rides = serializers.IntegerField(required=False, help_text="Number of relocated rides") # Metadata - last_updated = serializers.CharField( - help_text="ISO timestamp when these statistics were last calculated" - ) + last_updated = serializers.CharField(help_text="ISO timestamp when these statistics were last calculated") relative_last_updated = serializers.CharField( help_text="Human-readable relative time since last update (e.g., '2 minutes ago')" ) diff --git a/backend/apps/api/v1/serializers_rankings.py b/backend/apps/api/v1/serializers_rankings.py index ce90af96..a7d6fb7a 100644 --- a/backend/apps/api/v1/serializers_rankings.py +++ b/backend/apps/api/v1/serializers_rankings.py @@ -87,9 +87,7 @@ class RideRankingSerializer(serializers.ModelSerializer): """Calculate rank change from previous snapshot.""" from apps.rides.models import RankingSnapshot - latest_snapshots = RankingSnapshot.objects.filter(ride=obj.ride).order_by( - "-snapshot_date" - )[:2] + latest_snapshots = RankingSnapshot.objects.filter(ride=obj.ride).order_by("-snapshot_date")[:2] if len(latest_snapshots) >= 2: return latest_snapshots[0].rank - latest_snapshots[1].rank @@ -100,9 +98,7 @@ class RideRankingSerializer(serializers.ModelSerializer): """Get previous rank.""" from apps.rides.models import RankingSnapshot - latest_snapshots = RankingSnapshot.objects.filter(ride=obj.ride).order_by( - "-snapshot_date" - )[:2] + latest_snapshots = RankingSnapshot.objects.filter(ride=obj.ride).order_by("-snapshot_date")[:2] if len(latest_snapshots) >= 2: return latest_snapshots[1].rank @@ -149,28 +145,14 @@ class RideRankingDetailSerializer(serializers.ModelSerializer): "name": ride.park.name, "slug": ride.park.slug, "location": { - "city": ( - ride.park.location.city - if hasattr(ride.park, "location") - else None - ), - "state": ( - ride.park.location.state - if hasattr(ride.park, "location") - else None - ), - "country": ( - ride.park.location.country - if hasattr(ride.park, "location") - else None - ), + "city": (ride.park.location.city if hasattr(ride.park, "location") else None), + "state": (ride.park.location.state if hasattr(ride.park, "location") else None), + "country": (ride.park.location.country if hasattr(ride.park, "location") else None), }, }, "category": ride.category, "manufacturer": ( - {"id": ride.manufacturer.id, "name": ride.manufacturer.name} - if ride.manufacturer - else None + {"id": ride.manufacturer.id, "name": ride.manufacturer.name} if ride.manufacturer else None ), "opening_date": ride.opening_date, "status": ride.status, @@ -225,9 +207,7 @@ class RideRankingDetailSerializer(serializers.ModelSerializer): """Get recent ranking history.""" from apps.rides.models import RankingSnapshot - history = RankingSnapshot.objects.filter(ride=obj.ride).order_by( - "-snapshot_date" - )[:30] + history = RankingSnapshot.objects.filter(ride=obj.ride).order_by("-snapshot_date")[:30] return [ { diff --git a/backend/apps/api/v1/tests/test_contracts.py b/backend/apps/api/v1/tests/test_contracts.py index 3657a430..8881ad16 100644 --- a/backend/apps/api/v1/tests/test_contracts.py +++ b/backend/apps/api/v1/tests/test_contracts.py @@ -29,40 +29,43 @@ class FilterMetadataContractTests(TestCase): metadata = smart_park_loader.get_filter_metadata() # Should have required top-level keys - self.assertIn('categorical', metadata) - self.assertIn('ranges', metadata) - self.assertIn('total_count', metadata) + self.assertIn("categorical", metadata) + self.assertIn("ranges", metadata) + self.assertIn("total_count", metadata) # Categorical filters should be objects with value/label/count - categorical = metadata['categorical'] + categorical = metadata["categorical"] self.assertIsInstance(categorical, dict) for filter_name, filter_options in categorical.items(): with self.subTest(filter_name=filter_name): - self.assertIsInstance(filter_options, list, - f"Filter '{filter_name}' should be a list") + self.assertIsInstance(filter_options, list, f"Filter '{filter_name}' should be a list") for i, option in enumerate(filter_options): with self.subTest(filter_name=filter_name, option_index=i): - self.assertIsInstance(option, dict, - f"Filter '{filter_name}' option {i} should be an object, not {type(option).__name__}") + self.assertIsInstance( + option, + dict, + f"Filter '{filter_name}' option {i} should be an object, not {type(option).__name__}", + ) # Check required properties - self.assertIn('value', option, - f"Filter '{filter_name}' option {i} missing 'value' property") - self.assertIn('label', option, - f"Filter '{filter_name}' option {i} missing 'label' property") + self.assertIn("value", option, f"Filter '{filter_name}' option {i} missing 'value' property") + self.assertIn("label", option, f"Filter '{filter_name}' option {i} missing 'label' property") # Check types - self.assertIsInstance(option['value'], str, - f"Filter '{filter_name}' option {i} 'value' should be string") - self.assertIsInstance(option['label'], str, - f"Filter '{filter_name}' option {i} 'label' should be string") + self.assertIsInstance( + option["value"], str, f"Filter '{filter_name}' option {i} 'value' should be string" + ) + self.assertIsInstance( + option["label"], str, f"Filter '{filter_name}' option {i} 'label' should be string" + ) # Count is optional but should be int if present - if 'count' in option and option['count'] is not None: - self.assertIsInstance(option['count'], int, - f"Filter '{filter_name}' option {i} 'count' should be int") + if "count" in option and option["count"] is not None: + self.assertIsInstance( + option["count"], int, f"Filter '{filter_name}' option {i} 'count' should be int" + ) def test_rides_filter_metadata_structure(self): """Test that rides filter metadata has correct structure.""" @@ -70,16 +73,16 @@ class FilterMetadataContractTests(TestCase): metadata = loader.get_filter_metadata() # Should have required top-level keys - self.assertIn('categorical', metadata) - self.assertIn('ranges', metadata) - self.assertIn('total_count', metadata) + self.assertIn("categorical", metadata) + self.assertIn("ranges", metadata) + self.assertIn("total_count", metadata) # Categorical filters should be objects with value/label/count - categorical = metadata['categorical'] + categorical = metadata["categorical"] self.assertIsInstance(categorical, dict) # Test specific categorical filters that were problematic - critical_filters = ['categories', 'statuses', 'roller_coaster_types', 'track_materials'] + critical_filters = ["categories", "statuses", "roller_coaster_types", "track_materials"] for filter_name in critical_filters: if filter_name in categorical: @@ -89,40 +92,42 @@ class FilterMetadataContractTests(TestCase): for i, option in enumerate(filter_options): with self.subTest(filter_name=filter_name, option_index=i): - self.assertIsInstance(option, dict, - f"CRITICAL: Filter '{filter_name}' option {i} is {type(option).__name__} but should be dict") + self.assertIsInstance( + option, + dict, + f"CRITICAL: Filter '{filter_name}' option {i} is {type(option).__name__} but should be dict", + ) - self.assertIn('value', option) - self.assertIn('label', option) - self.assertIn('count', option) + self.assertIn("value", option) + self.assertIn("label", option) + self.assertIn("count", option) def test_range_metadata_structure(self): """Test that range metadata has correct structure.""" # Test parks ranges parks_metadata = smart_park_loader.get_filter_metadata() - ranges = parks_metadata['ranges'] + ranges = parks_metadata["ranges"] for range_name, range_data in ranges.items(): with self.subTest(range_name=range_name): - self.assertIsInstance(range_data, dict, - f"Range '{range_name}' should be an object") + self.assertIsInstance(range_data, dict, f"Range '{range_name}' should be an object") # Check required properties - self.assertIn('min', range_data) - self.assertIn('max', range_data) - self.assertIn('step', range_data) - self.assertIn('unit', range_data) + self.assertIn("min", range_data) + self.assertIn("max", range_data) + self.assertIn("step", range_data) + self.assertIn("unit", range_data) # Check types (min/max can be None) - if range_data['min'] is not None: - self.assertIsInstance(range_data['min'], (int, float)) - if range_data['max'] is not None: - self.assertIsInstance(range_data['max'], (int, float)) + if range_data["min"] is not None: + self.assertIsInstance(range_data["min"], (int, float)) + if range_data["max"] is not None: + self.assertIsInstance(range_data["max"], (int, float)) - self.assertIsInstance(range_data['step'], (int, float)) + self.assertIsInstance(range_data["step"], (int, float)) # Unit can be None or string - if range_data['unit'] is not None: - self.assertIsInstance(range_data['unit'], str) + if range_data["unit"] is not None: + self.assertIsInstance(range_data["unit"], str) class ContractValidationUtilityTests(TestCase): @@ -131,38 +136,29 @@ class ContractValidationUtilityTests(TestCase): def test_validate_filter_metadata_contract_valid(self): """Test validation passes for valid filter metadata.""" valid_metadata = { - 'categorical': { - 'statuses': [ - {'value': 'OPERATING', 'label': 'Operating', 'count': 5}, - {'value': 'CLOSED_TEMP', 'label': 'Temporarily Closed', 'count': 2} + "categorical": { + "statuses": [ + {"value": "OPERATING", "label": "Operating", "count": 5}, + {"value": "CLOSED_TEMP", "label": "Temporarily Closed", "count": 2}, ] }, - 'ranges': { - 'rating': { - 'min': 1.0, - 'max': 10.0, - 'step': 0.1, - 'unit': 'stars' - } - }, - 'total_count': 100 + "ranges": {"rating": {"min": 1.0, "max": 10.0, "step": 0.1, "unit": "stars"}}, + "total_count": 100, } # Should not raise an exception validated = validate_filter_metadata_contract(valid_metadata) self.assertIsInstance(validated, dict) - self.assertEqual(validated['total_count'], 100) + self.assertEqual(validated["total_count"], 100) def test_validate_filter_metadata_contract_invalid(self): """Test validation fails for invalid filter metadata.""" from rest_framework import serializers invalid_metadata = { - 'categorical': { - 'statuses': ['OPERATING', 'CLOSED_TEMP'] # Should be objects, not strings - }, - 'ranges': {}, - 'total_count': 100 + "categorical": {"statuses": ["OPERATING", "CLOSED_TEMP"]}, # Should be objects, not strings + "ranges": {}, + "total_count": 100, } # Should raise ValidationError @@ -171,82 +167,71 @@ class ContractValidationUtilityTests(TestCase): def test_ensure_filter_option_format_strings(self): """Test converting string arrays to proper format.""" - string_options = ['OPERATING', 'CLOSED_TEMP', 'UNDER_CONSTRUCTION'] + string_options = ["OPERATING", "CLOSED_TEMP", "UNDER_CONSTRUCTION"] formatted = ensure_filter_option_format(string_options) self.assertEqual(len(formatted), 3) for i, option in enumerate(formatted): self.assertIsInstance(option, dict) - self.assertIn('value', option) - self.assertIn('label', option) - self.assertIn('count', option) - self.assertIn('selected', option) + self.assertIn("value", option) + self.assertIn("label", option) + self.assertIn("count", option) + self.assertIn("selected", option) - self.assertEqual(option['value'], string_options[i]) - self.assertEqual(option['label'], string_options[i]) - self.assertIsNone(option['count']) - self.assertFalse(option['selected']) + self.assertEqual(option["value"], string_options[i]) + self.assertEqual(option["label"], string_options[i]) + self.assertIsNone(option["count"]) + self.assertFalse(option["selected"]) def test_ensure_filter_option_format_tuples(self): """Test converting tuple arrays to proper format.""" - tuple_options = [ - ('OPERATING', 'Operating', 5), - ('CLOSED_TEMP', 'Temporarily Closed', 2) - ] + tuple_options = [("OPERATING", "Operating", 5), ("CLOSED_TEMP", "Temporarily Closed", 2)] formatted = ensure_filter_option_format(tuple_options) self.assertEqual(len(formatted), 2) - self.assertEqual(formatted[0]['value'], 'OPERATING') - self.assertEqual(formatted[0]['label'], 'Operating') - self.assertEqual(formatted[0]['count'], 5) + self.assertEqual(formatted[0]["value"], "OPERATING") + self.assertEqual(formatted[0]["label"], "Operating") + self.assertEqual(formatted[0]["count"], 5) - self.assertEqual(formatted[1]['value'], 'CLOSED_TEMP') - self.assertEqual(formatted[1]['label'], 'Temporarily Closed') - self.assertEqual(formatted[1]['count'], 2) + self.assertEqual(formatted[1]["value"], "CLOSED_TEMP") + self.assertEqual(formatted[1]["label"], "Temporarily Closed") + self.assertEqual(formatted[1]["count"], 2) def test_ensure_filter_option_format_dicts(self): """Test that properly formatted dicts pass through correctly.""" dict_options = [ - {'value': 'OPERATING', 'label': 'Operating', 'count': 5}, - {'value': 'CLOSED_TEMP', 'label': 'Temporarily Closed', 'count': 2} + {"value": "OPERATING", "label": "Operating", "count": 5}, + {"value": "CLOSED_TEMP", "label": "Temporarily Closed", "count": 2}, ] formatted = ensure_filter_option_format(dict_options) self.assertEqual(len(formatted), 2) - self.assertEqual(formatted[0]['value'], 'OPERATING') - self.assertEqual(formatted[0]['label'], 'Operating') - self.assertEqual(formatted[0]['count'], 5) + self.assertEqual(formatted[0]["value"], "OPERATING") + self.assertEqual(formatted[0]["label"], "Operating") + self.assertEqual(formatted[0]["count"], 5) def test_ensure_range_format(self): """Test range format utility.""" - range_data = { - 'min': 1.0, - 'max': 10.0, - 'step': 0.5, - 'unit': 'stars' - } + range_data = {"min": 1.0, "max": 10.0, "step": 0.5, "unit": "stars"} formatted = ensure_range_format(range_data) - self.assertEqual(formatted['min'], 1.0) - self.assertEqual(formatted['max'], 10.0) - self.assertEqual(formatted['step'], 0.5) - self.assertEqual(formatted['unit'], 'stars') + self.assertEqual(formatted["min"], 1.0) + self.assertEqual(formatted["max"], 10.0) + self.assertEqual(formatted["step"], 0.5) + self.assertEqual(formatted["unit"], "stars") def test_ensure_range_format_missing_step(self): """Test range format with missing step defaults to 1.0.""" - range_data = { - 'min': 1, - 'max': 10 - } + range_data = {"min": 1, "max": 10} formatted = ensure_range_format(range_data) - self.assertEqual(formatted['step'], 1.0) - self.assertIsNone(formatted['unit']) + self.assertEqual(formatted["step"], 1.0) + self.assertIsNone(formatted["unit"]) class APIEndpointContractTests(APITestCase): @@ -278,26 +263,21 @@ class TypeScriptInterfaceComplianceTests(TestCase): # selected?: boolean; # } - option = { - 'value': 'OPERATING', - 'label': 'Operating', - 'count': 5, - 'selected': False - } + option = {"value": "OPERATING", "label": "Operating", "count": 5, "selected": False} # All required fields present - self.assertIn('value', option) - self.assertIn('label', option) + self.assertIn("value", option) + self.assertIn("label", option) # Correct types - self.assertIsInstance(option['value'], str) - self.assertIsInstance(option['label'], str) + self.assertIsInstance(option["value"], str) + self.assertIsInstance(option["label"], str) # Optional fields have correct types if present - if 'count' in option and option['count'] is not None: - self.assertIsInstance(option['count'], int) - if 'selected' in option: - self.assertIsInstance(option['selected'], bool) + if "count" in option and option["count"] is not None: + self.assertIsInstance(option["count"], int) + if "selected" in option: + self.assertIsInstance(option["selected"], bool) def test_filter_range_interface_compliance(self): """Test FilterRange interface compliance.""" @@ -309,29 +289,24 @@ class TypeScriptInterfaceComplianceTests(TestCase): # unit?: string; # } - range_data = { - 'min': 1.0, - 'max': 10.0, - 'step': 0.1, - 'unit': 'stars' - } + range_data = {"min": 1.0, "max": 10.0, "step": 0.1, "unit": "stars"} # All required fields present - self.assertIn('min', range_data) - self.assertIn('max', range_data) - self.assertIn('step', range_data) + self.assertIn("min", range_data) + self.assertIn("max", range_data) + self.assertIn("step", range_data) # Correct types (min/max can be null) - if range_data['min'] is not None: - self.assertIsInstance(range_data['min'], (int, float)) - if range_data['max'] is not None: - self.assertIsInstance(range_data['max'], (int, float)) + if range_data["min"] is not None: + self.assertIsInstance(range_data["min"], (int, float)) + if range_data["max"] is not None: + self.assertIsInstance(range_data["max"], (int, float)) - self.assertIsInstance(range_data['step'], (int, float)) + self.assertIsInstance(range_data["step"], (int, float)) # Optional unit field - if 'unit' in range_data and range_data['unit'] is not None: - self.assertIsInstance(range_data['unit'], str) + if "unit" in range_data and range_data["unit"] is not None: + self.assertIsInstance(range_data["unit"], str) class RegressionTests(TestCase): @@ -345,7 +320,7 @@ class RegressionTests(TestCase): # Test parks parks_metadata = smart_park_loader.get_filter_metadata() - categorical = parks_metadata.get('categorical', {}) + categorical = parks_metadata.get("categorical", {}) for filter_name, filter_options in categorical.items(): with self.subTest(filter_name=filter_name): @@ -353,19 +328,25 @@ class RegressionTests(TestCase): for i, option in enumerate(filter_options): with self.subTest(filter_name=filter_name, option_index=i): - self.assertIsInstance(option, dict, + self.assertIsInstance( + option, + dict, f"REGRESSION: Filter '{filter_name}' option {i} is a {type(option).__name__} " - f"but should be a dict. This causes frontend crashes!") + f"but should be a dict. This causes frontend crashes!", + ) # Must not be a string - self.assertNotIsInstance(option, str, + self.assertNotIsInstance( + option, + str, f"CRITICAL REGRESSION: Filter '{filter_name}' option {i} is a string '{option}' " - f"but frontend expects object with value/label/count properties!") + f"but frontend expects object with value/label/count properties!", + ) # Test rides rides_loader = SmartRideLoader() rides_metadata = rides_loader.get_filter_metadata() - categorical = rides_metadata.get('categorical', {}) + categorical = rides_metadata.get("categorical", {}) for filter_name, filter_options in categorical.items(): with self.subTest(filter_name=f"rides_{filter_name}"): @@ -373,9 +354,12 @@ class RegressionTests(TestCase): for i, option in enumerate(filter_options): with self.subTest(filter_name=f"rides_{filter_name}", option_index=i): - self.assertIsInstance(option, dict, + self.assertIsInstance( + option, + dict, f"REGRESSION: Rides filter '{filter_name}' option {i} is a {type(option).__name__} " - f"but should be a dict. This causes frontend crashes!") + f"but should be a dict. This causes frontend crashes!", + ) def test_ranges_have_step_and_unit(self): """Regression test: Ensure ranges have step and unit properties.""" @@ -383,18 +367,15 @@ class RegressionTests(TestCase): # Backend was sometimes missing step and unit parks_metadata = smart_park_loader.get_filter_metadata() - ranges = parks_metadata.get('ranges', {}) + ranges = parks_metadata.get("ranges", {}) for range_name, range_data in ranges.items(): with self.subTest(range_name=range_name): - self.assertIn('step', range_data, - f"Range '{range_name}' missing 'step' property required by frontend") - self.assertIn('unit', range_data, - f"Range '{range_name}' missing 'unit' property required by frontend") + self.assertIn("step", range_data, f"Range '{range_name}' missing 'step' property required by frontend") + self.assertIn("unit", range_data, f"Range '{range_name}' missing 'unit' property required by frontend") # Step should be a number - self.assertIsInstance(range_data['step'], (int, float), - f"Range '{range_name}' step should be a number") + self.assertIsInstance(range_data["step"], (int, float), f"Range '{range_name}' step should be a number") def test_no_undefined_values(self): """Regression test: Ensure no undefined values (should be null).""" diff --git a/backend/apps/api/v1/views/auth.py b/backend/apps/api/v1/views/auth.py index 4eef4f3d..2d7a2905 100644 --- a/backend/apps/api/v1/views/auth.py +++ b/backend/apps/api/v1/views/auth.py @@ -54,9 +54,8 @@ except ImportError: # Type hint for the mixin if TYPE_CHECKING: - from typing import Union - TurnstileMixinType = Union[type[FallbackTurnstileMixin], Any] + TurnstileMixinType = type[FallbackTurnstileMixin] | Any else: TurnstileMixinType = TurnstileMixin @@ -87,11 +86,9 @@ class LoginAPIView(TurnstileMixin, APIView): # type: ignore[misc] # Validate Turnstile if configured self.validate_turnstile(request) except ValidationError as e: - return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST) - serializer = LoginInputSerializer( - data=request.data, context={"request": request} - ) + serializer = LoginInputSerializer(data=request.data, context={"request": request}) if serializer.is_valid(): # The serializer handles authentication validation user = serializer.validated_data["user"] # type: ignore[index] @@ -106,7 +103,7 @@ class LoginAPIView(TurnstileMixin, APIView): # type: ignore[misc] { "token": token.key, "user": user, - "message": "Login successful", + "detail": "Login successful", } ) return Response(response_serializer.data) @@ -138,7 +135,7 @@ class SignupAPIView(TurnstileMixin, APIView): # type: ignore[misc] # Validate Turnstile if configured self.validate_turnstile(request) except ValidationError as e: - return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST) serializer = SignupInputSerializer(data=request.data) if serializer.is_valid(): @@ -152,7 +149,7 @@ class SignupAPIView(TurnstileMixin, APIView): # type: ignore[misc] { "token": token.key, "user": user, - "message": "Registration successful", + "detail": "Registration successful", } ) return Response(response_serializer.data, status=status.HTTP_201_CREATED) @@ -186,14 +183,10 @@ class LogoutAPIView(APIView): # Logout from session logout(request._request) # type: ignore[attr-defined] - response_serializer = LogoutOutputSerializer( - {"message": "Logout successful"} - ) + response_serializer = LogoutOutputSerializer({"detail": "Logout successful"}) return Response(response_serializer.data) except Exception: - return Response( - {"error": "Logout failed"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR - ) + return Response({"detail": "Logout failed"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) @extend_schema_view( @@ -237,15 +230,11 @@ class PasswordResetAPIView(APIView): serializer_class = PasswordResetInputSerializer def post(self, request: Request) -> Response: - serializer = PasswordResetInputSerializer( - data=request.data, context={"request": request} - ) + serializer = PasswordResetInputSerializer(data=request.data, context={"request": request}) if serializer.is_valid(): serializer.save() - response_serializer = PasswordResetOutputSerializer( - {"detail": "Password reset email sent"} - ) + response_serializer = PasswordResetOutputSerializer({"detail": "Password reset email sent"}) return Response(response_serializer.data) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -271,15 +260,11 @@ class PasswordChangeAPIView(APIView): serializer_class = PasswordChangeInputSerializer def post(self, request: Request) -> Response: - serializer = PasswordChangeInputSerializer( - data=request.data, context={"request": request} - ) + serializer = PasswordChangeInputSerializer(data=request.data, context={"request": request}) if serializer.is_valid(): serializer.save() - response_serializer = PasswordChangeOutputSerializer( - {"detail": "Password changed successfully"} - ) + response_serializer = PasswordChangeOutputSerializer({"detail": "Password changed successfully"}) return Response(response_serializer.data) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -338,9 +323,7 @@ class SocialProvidersAPIView(APIView): provider_name = social_app.name or social_app.provider.title() # Build auth URL efficiently - auth_url = request.build_absolute_uri( - f"/accounts/{social_app.provider}/login/" - ) + auth_url = request.build_absolute_uri(f"/accounts/{social_app.provider}/login/") providers_list.append( { @@ -370,13 +353,9 @@ class SocialProvidersAPIView(APIView): "status": "error", "error": { "code": "SOCIAL_PROVIDERS_ERROR", - "message": "Unable to retrieve social providers", + "detail": "Unable to retrieve social providers", "details": str(e) if str(e) else None, - "request_user": ( - str(request.user) - if hasattr(request, "user") - else "AnonymousUser" - ), + "request_user": (str(request.user) if hasattr(request, "user") else "AnonymousUser"), }, "data": None, }, diff --git a/backend/apps/api/v1/views/base.py b/backend/apps/api/v1/views/base.py index 9436b1ce..89369106 100644 --- a/backend/apps/api/v1/views/base.py +++ b/backend/apps/api/v1/views/base.py @@ -39,7 +39,7 @@ class ContractCompliantAPIView(APIView): response = super().dispatch(request, *args, **kwargs) # Validate contract in DEBUG mode - if settings.DEBUG and hasattr(response, 'data'): + if settings.DEBUG and hasattr(response, "data"): self._validate_response_contract(response.data) return response @@ -49,19 +49,18 @@ class ContractCompliantAPIView(APIView): logger.error( f"API error in {self.__class__.__name__}: {str(e)}", extra={ - 'view_class': self.__class__.__name__, - 'request_path': request.path, - 'request_method': request.method, - 'user': getattr(request, 'user', None), - 'error': str(e) + "view_class": self.__class__.__name__, + "request_path": request.path, + "request_method": request.method, + "user": getattr(request, "user", None), + "detail": str(e), }, - exc_info=True + exc_info=True, ) # Return standardized error response return self.error_response( - message="An internal error occurred", - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR + message="An internal error occurred", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR ) def success_response( @@ -69,7 +68,7 @@ class ContractCompliantAPIView(APIView): data: Any = None, message: str = None, status_code: int = status.HTTP_200_OK, - headers: dict[str, str] = None + headers: dict[str, str] = None, ) -> Response: """ Create a standardized success response. @@ -83,21 +82,15 @@ class ContractCompliantAPIView(APIView): Returns: Response with standardized format """ - response_data = { - 'success': True - } + response_data = {"success": True} if data is not None: - response_data['data'] = data + response_data["data"] = data if message: - response_data['message'] = message + response_data["message"] = message - return Response( - response_data, - status=status_code, - headers=headers - ) + return Response(response_data, status=status_code, headers=headers) def error_response( self, @@ -105,7 +98,7 @@ class ContractCompliantAPIView(APIView): status_code: int = status.HTTP_400_BAD_REQUEST, error_code: str = None, details: Any = None, - headers: dict[str, str] = None + headers: dict[str, str] = None, ) -> Response: """ Create a standardized error response. @@ -120,37 +113,22 @@ class ContractCompliantAPIView(APIView): Returns: Response with standardized error format """ - error_data = { - 'code': error_code or 'API_ERROR', - 'message': message - } + error_data = {"code": error_code or "API_ERROR", "message": message} if details: - error_data['details'] = details + error_data["details"] = details # Add user context if available - if hasattr(self, 'request') and hasattr(self.request, 'user'): + if hasattr(self, "request") and hasattr(self.request, "user"): user = self.request.user if user and user.is_authenticated: - error_data['request_user'] = user.username + error_data["request_user"] = user.username - response_data = { - 'status': 'error', - 'error': error_data, - 'data': None - } + response_data = {"status": "error", "error": error_data, "data": None} - return Response( - response_data, - status=status_code, - headers=headers - ) + return Response(response_data, status=status_code, headers=headers) - def validation_error_response( - self, - errors: dict[str, Any], - message: str = "Validation failed" - ) -> Response: + def validation_error_response(self, errors: dict[str, Any], message: str = "Validation failed") -> Response: """ Create a standardized validation error response. @@ -161,14 +139,7 @@ class ContractCompliantAPIView(APIView): Returns: Response with validation errors """ - return Response( - { - 'success': False, - 'message': message, - 'errors': errors - }, - status=status.HTTP_400_BAD_REQUEST - ) + return Response({"success": False, "message": message, "errors": errors}, status=status.HTTP_400_BAD_REQUEST) def _validate_response_contract(self, data: Any) -> None: """ @@ -179,7 +150,7 @@ class ContractCompliantAPIView(APIView): """ try: # Check if this looks like filter metadata - if isinstance(data, dict) and 'categorical' in data and 'ranges' in data: + if isinstance(data, dict) and "categorical" in data and "ranges" in data: validate_filter_metadata_contract(data) # Add more contract validations as needed @@ -188,10 +159,10 @@ class ContractCompliantAPIView(APIView): logger.warning( f"Contract validation failed in {self.__class__.__name__}: {str(e)}", extra={ - 'view_class': self.__class__.__name__, - 'validation_error': str(e), - 'response_data_type': type(data).__name__ - } + "view_class": self.__class__.__name__, + "validation_error": str(e), + "response_data_type": type(data).__name__, + }, ) @@ -225,17 +196,11 @@ class FilterMetadataAPIView(ContractCompliantAPIView): except Exception as e: logger.error( f"Error getting filter metadata in {self.__class__.__name__}: {str(e)}", - extra={ - 'view_class': self.__class__.__name__, - 'error': str(e) - }, - exc_info=True + extra={"view_class": self.__class__.__name__, "detail": str(e)}, + exc_info=True, ) - return self.error_response( - message="Failed to retrieve filter metadata", - error_code="FILTER_METADATA_ERROR" - ) + return self.error_response(message="Failed to retrieve filter metadata", error_code="FILTER_METADATA_ERROR") class HybridFilteringAPIView(ContractCompliantAPIView): @@ -276,17 +241,14 @@ class HybridFilteringAPIView(ContractCompliantAPIView): logger.error( f"Error in hybrid filtering for {self.__class__.__name__}: {str(e)}", extra={ - 'view_class': self.__class__.__name__, - 'filters': getattr(self, '_extracted_filters', {}), - 'error': str(e) + "view_class": self.__class__.__name__, + "filters": getattr(self, "_extracted_filters", {}), + "detail": str(e), }, - exc_info=True + exc_info=True, ) - return self.error_response( - message="Failed to retrieve filtered data", - error_code="HYBRID_FILTERING_ERROR" - ) + return self.error_response(message="Failed to retrieve filtered data", error_code="HYBRID_FILTERING_ERROR") def extract_filters(self, request) -> dict[str, Any]: """ @@ -313,19 +275,19 @@ class HybridFilteringAPIView(ContractCompliantAPIView): def _validate_hybrid_response(self, data: dict[str, Any]) -> None: """Validate hybrid response structure.""" - required_fields = ['strategy', 'total_count'] + required_fields = ["strategy", "total_count"] for field in required_fields: if field not in data: raise ValueError(f"Hybrid response missing required field: {field}") # Validate strategy value - if data['strategy'] not in ['client_side', 'server_side']: + if data["strategy"] not in ["client_side", "server_side"]: raise ValueError(f"Invalid strategy value: {data['strategy']}") # Validate filter metadata if present - if 'filter_metadata' in data: - validate_filter_metadata_contract(data['filter_metadata']) + if "filter_metadata" in data: + validate_filter_metadata_contract(data["filter_metadata"]) class PaginatedAPIView(ContractCompliantAPIView): @@ -340,11 +302,7 @@ class PaginatedAPIView(ContractCompliantAPIView): max_page_size = 100 def get_paginated_response( - self, - queryset, - serializer_class: type[Serializer], - request, - page_size: int = None + self, queryset, serializer_class: type[Serializer], request, page_size: int = None ) -> Response: """ Create a paginated response. @@ -362,13 +320,10 @@ class PaginatedAPIView(ContractCompliantAPIView): # Determine page size if page_size is None: - page_size = min( - int(request.query_params.get('page_size', self.default_page_size)), - self.max_page_size - ) + page_size = min(int(request.query_params.get("page_size", self.default_page_size)), self.max_page_size) # Get page number - page_number = request.query_params.get('page', 1) + page_number = request.query_params.get("page", 1) try: page_number = int(page_number) @@ -389,28 +344,28 @@ class PaginatedAPIView(ContractCompliantAPIView): serializer = serializer_class(page.object_list, many=True) # Build pagination URLs - request_url = request.build_absolute_uri().split('?')[0] + request_url = request.build_absolute_uri().split("?")[0] query_params = request.query_params.copy() next_url = None if page.has_next(): - query_params['page'] = page.next_page_number() + query_params["page"] = page.next_page_number() next_url = f"{request_url}?{query_params.urlencode()}" previous_url = None if page.has_previous(): - query_params['page'] = page.previous_page_number() + query_params["page"] = page.previous_page_number() previous_url = f"{request_url}?{query_params.urlencode()}" # Create response data response_data = { - 'count': paginator.count, - 'next': next_url, - 'previous': previous_url, - 'results': serializer.data, - 'page_size': page_size, - 'current_page': page.number, - 'total_pages': paginator.num_pages + "count": paginator.count, + "next": next_url, + "previous": previous_url, + "results": serializer.data, + "page_size": page_size, + "current_page": page.number, + "total_pages": paginator.num_pages, } return self.success_response(response_data) @@ -430,29 +385,23 @@ def contract_compliant_view(view_class): response = original_dispatch(self, request, *args, **kwargs) # Add contract validation in DEBUG mode - if settings.DEBUG and hasattr(response, 'data'): + if settings.DEBUG and hasattr(response, "data"): # Basic validation - can be extended pass return response except Exception as e: - logger.error( - f"Error in decorated view {view_class.__name__}: {str(e)}", - exc_info=True - ) + logger.error(f"Error in decorated view {view_class.__name__}: {str(e)}", exc_info=True) # Return basic error response return Response( { - 'status': 'error', - 'error': { - 'code': 'API_ERROR', - 'message': 'An internal error occurred' - }, - 'data': None + "status": "error", + "error": {"code": "API_ERROR", "detail": "An internal error occurred"}, + "data": None, }, - status=status.HTTP_500_INTERNAL_SERVER_ERROR + status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) view_class.dispatch = new_dispatch diff --git a/backend/apps/api/v1/views/discovery.py b/backend/apps/api/v1/views/discovery.py index 7ecf5be3..5c36195f 100644 --- a/backend/apps/api/v1/views/discovery.py +++ b/backend/apps/api/v1/views/discovery.py @@ -1,4 +1,3 @@ - from django.utils import timezone from drf_spectacular.utils import extend_schema from rest_framework.permissions import AllowAny @@ -13,6 +12,7 @@ class DiscoveryAPIView(APIView): """ API endpoint for discovery content (Top Lists, Opening/Closing Soon). """ + permission_classes = [AllowAny] @extend_schema( @@ -68,7 +68,7 @@ class DiscoveryAPIView(APIView): "recently_closed": { "parks": self._serialize(recently_closed_parks, "park"), "rides": self._serialize(recently_closed_rides, "ride"), - } + }, } return Response(data) @@ -83,14 +83,13 @@ class DiscoveryAPIView(APIView): "average_rating": item.average_rating, } if type_ == "park": - data.update({ - "city": item.location.city if item.location else None, - "state": item.location.state if item.location else None, - }) + data.update( + { + "city": item.location.city if item.location else None, + "state": item.location.state if item.location else None, + } + ) elif type_ == "ride": - data.update({ - "park_name": item.park.name, - "park_slug": item.park.slug - }) + data.update({"park_name": item.park.name, "park_slug": item.park.slug}) results.append(data) return results diff --git a/backend/apps/api/v1/views/health.py b/backend/apps/api/v1/views/health.py index c8078fc1..f98012c4 100644 --- a/backend/apps/api/v1/views/health.py +++ b/backend/apps/api/v1/views/health.py @@ -30,7 +30,7 @@ class FallbackCacheMonitor: """Fallback class if CacheMonitor is not available.""" def get_cache_stats(self): - return {"error": "Cache monitoring not available"} + return {"detail": "Cache monitoring not available"} class FallbackIndexAnalyzer: @@ -38,7 +38,7 @@ class FallbackIndexAnalyzer: @staticmethod def analyze_slow_queries(threshold): - return {"error": "Query analysis not available"} + return {"detail": "Query analysis not available"} # Try to import the real classes, use fallbacks if not available @@ -56,9 +56,7 @@ except ImportError: @extend_schema_view( get=extend_schema( summary="Health check", - description=( - "Get comprehensive health check information including system metrics." - ), + description=("Get comprehensive health check information including system metrics."), responses={ 200: HealthCheckOutputSerializer, 503: HealthCheckOutputSerializer, @@ -88,7 +86,7 @@ class HealthCheckAPIView(APIView): cache_monitor = CacheMonitor() cache_stats = cache_monitor.get_cache_stats() except Exception: - cache_stats = {"error": "Cache monitoring unavailable"} + cache_stats = {"detail": "Cache monitoring unavailable"} # Build comprehensive health data health_data = { @@ -120,9 +118,7 @@ class HealthCheckAPIView(APIView): critical_service = False response_time = None - plugin_errors = ( - errors.get(plugin_class_name, []) if isinstance(errors, dict) else [] - ) + plugin_errors = errors.get(plugin_class_name, []) if isinstance(errors, dict) else [] health_data["checks"][plugin_name] = { "status": "healthy" if not plugin_errors else "unhealthy", @@ -194,9 +190,7 @@ class HealthCheckAPIView(APIView): "transactions_committed": row[1], "transactions_rolled_back": row[2], "cache_hit_ratio": ( - round((row[4] / (row[3] + row[4])) * 100, 2) - if (row[3] + row[4]) > 0 - else 0 + round((row[4] / (row[3] + row[4])) * 100, 2) if (row[3] + row[4]) > 0 else 0 ), } ) @@ -206,7 +200,7 @@ class HealthCheckAPIView(APIView): return metrics except Exception as e: - return {"connection_status": "error", "error": str(e)} + return {"connection_status": "error", "detail": str(e)} def _get_system_metrics(self) -> dict: """Get system performance metrics.""" @@ -270,7 +264,7 @@ class PerformanceMetricsAPIView(APIView): def get(self, request: Request) -> Response: """Return performance metrics and analysis.""" if not settings.DEBUG: - return Response({"error": "Only available in debug mode"}, status=403) + return Response({"detail": "Only available in debug mode"}, status=403) metrics = { "timestamp": timezone.now(), @@ -306,7 +300,7 @@ class PerformanceMetricsAPIView(APIView): return analysis except Exception as e: - return {"error": str(e)} + return {"detail": str(e)} def _get_cache_performance(self): """Get cache performance metrics.""" @@ -314,14 +308,14 @@ class PerformanceMetricsAPIView(APIView): cache_monitor = CacheMonitor() return cache_monitor.get_cache_stats() except Exception as e: - return {"error": str(e)} + return {"detail": str(e)} def _get_slow_queries(self): """Get recent slow queries.""" try: return IndexAnalyzer.analyze_slow_queries(0.1) # 100ms threshold except Exception as e: - return {"error": str(e)} + return {"detail": str(e)} @extend_schema_view( @@ -336,9 +330,7 @@ class PerformanceMetricsAPIView(APIView): ), options=extend_schema( summary="CORS preflight for simple health check", - description=( - "Handle CORS preflight requests for the simple health check endpoint." - ), + description=("Handle CORS preflight requests for the simple health check endpoint."), responses={ 200: SimpleHealthOutputSerializer, }, @@ -370,7 +362,7 @@ class SimpleHealthAPIView(APIView): except Exception as e: response_data = { "status": "error", - "error": str(e), + "detail": str(e), "timestamp": timezone.now(), } serializer = SimpleHealthOutputSerializer(response_data) diff --git a/backend/apps/api/v1/views/leaderboard.py b/backend/apps/api/v1/views/leaderboard.py index 129f5882..0effe4d8 100644 --- a/backend/apps/api/v1/views/leaderboard.py +++ b/backend/apps/api/v1/views/leaderboard.py @@ -1,6 +1,7 @@ """ Leaderboard views for user rankings """ + from datetime import timedelta from django.db.models import Count, Sum @@ -15,7 +16,7 @@ from apps.reviews.models import Review from apps.rides.models import RideCredit -@api_view(['GET']) +@api_view(["GET"]) @permission_classes([AllowAny]) def leaderboard(request): """ @@ -26,25 +27,25 @@ def leaderboard(request): - period: 'all' | 'monthly' | 'weekly' (default: all) - limit: int (default: 25, max: 100) """ - category = request.query_params.get('category', 'credits') - period = request.query_params.get('period', 'all') - limit = min(int(request.query_params.get('limit', 25)), 100) + category = request.query_params.get("category", "credits") + period = request.query_params.get("period", "all") + limit = min(int(request.query_params.get("limit", 25)), 100) # Calculate date filter based on period date_filter = None - if period == 'weekly': + if period == "weekly": date_filter = timezone.now() - timedelta(days=7) - elif period == 'monthly': + elif period == "monthly": date_filter = timezone.now() - timedelta(days=30) - if category == 'credits': + if category == "credits": return _get_credits_leaderboard(date_filter, limit) - elif category == 'reviews': + elif category == "reviews": return _get_reviews_leaderboard(date_filter, limit) - elif category == 'contributions': + elif category == "contributions": return _get_contributions_leaderboard(date_filter, limit) else: - return Response({'error': 'Invalid category'}, status=400) + return Response({"detail": "Invalid category"}, status=400) def _get_credits_leaderboard(date_filter, limit): @@ -55,26 +56,34 @@ def _get_credits_leaderboard(date_filter, limit): queryset = queryset.filter(created_at__gte=date_filter) # Aggregate credits per user - users_data = queryset.values('user_id', 'user__username', 'user__display_name').annotate( - total_credits=Coalesce(Sum('count'), 0), - unique_rides=Count('ride', distinct=True), - ).order_by('-total_credits')[:limit] + users_data = ( + queryset.values("user_id", "user__username", "user__display_name") + .annotate( + total_credits=Coalesce(Sum("count"), 0), + unique_rides=Count("ride", distinct=True), + ) + .order_by("-total_credits")[:limit] + ) results = [] for rank, entry in enumerate(users_data, 1): - results.append({ - 'rank': rank, - 'user_id': entry['user_id'], - 'username': entry['user__username'], - 'display_name': entry['user__display_name'] or entry['user__username'], - 'total_credits': entry['total_credits'], - 'unique_rides': entry['unique_rides'], - }) + results.append( + { + "rank": rank, + "user_id": entry["user_id"], + "username": entry["user__username"], + "display_name": entry["user__display_name"] or entry["user__username"], + "total_credits": entry["total_credits"], + "unique_rides": entry["unique_rides"], + } + ) - return Response({ - 'category': 'credits', - 'results': results, - }) + return Response( + { + "category": "credits", + "results": results, + } + ) def _get_reviews_leaderboard(date_filter, limit): @@ -85,49 +94,65 @@ def _get_reviews_leaderboard(date_filter, limit): queryset = queryset.filter(created_at__gte=date_filter) # Count reviews per user - users_data = queryset.values('user_id', 'user__username', 'user__display_name').annotate( - review_count=Count('id'), - ).order_by('-review_count')[:limit] + users_data = ( + queryset.values("user_id", "user__username", "user__display_name") + .annotate( + review_count=Count("id"), + ) + .order_by("-review_count")[:limit] + ) results = [] for rank, entry in enumerate(users_data, 1): - results.append({ - 'rank': rank, - 'user_id': entry['user_id'], - 'username': entry['user__username'], - 'display_name': entry['user__display_name'] or entry['user__username'], - 'review_count': entry['review_count'], - }) + results.append( + { + "rank": rank, + "user_id": entry["user_id"], + "username": entry["user__username"], + "display_name": entry["user__display_name"] or entry["user__username"], + "review_count": entry["review_count"], + } + ) - return Response({ - 'category': 'reviews', - 'results': results, - }) + return Response( + { + "category": "reviews", + "results": results, + } + ) def _get_contributions_leaderboard(date_filter, limit): """Top users by approved contributions.""" - queryset = EditSubmission.objects.filter(status='approved') + queryset = EditSubmission.objects.filter(status="approved") if date_filter: queryset = queryset.filter(created_at__gte=date_filter) # Count contributions per user - users_data = queryset.values('submitted_by_id', 'submitted_by__username', 'submitted_by__display_name').annotate( - contribution_count=Count('id'), - ).order_by('-contribution_count')[:limit] + users_data = ( + queryset.values("submitted_by_id", "submitted_by__username", "submitted_by__display_name") + .annotate( + contribution_count=Count("id"), + ) + .order_by("-contribution_count")[:limit] + ) results = [] for rank, entry in enumerate(users_data, 1): - results.append({ - 'rank': rank, - 'user_id': entry['submitted_by_id'], - 'username': entry['submitted_by__username'], - 'display_name': entry['submitted_by__display_name'] or entry['submitted_by__username'], - 'contribution_count': entry['contribution_count'], - }) + results.append( + { + "rank": rank, + "user_id": entry["submitted_by_id"], + "username": entry["submitted_by__username"], + "display_name": entry["submitted_by__display_name"] or entry["submitted_by__username"], + "contribution_count": entry["contribution_count"], + } + ) - return Response({ - 'category': 'contributions', - 'results': results, - }) + return Response( + { + "category": "contributions", + "results": results, + } + ) diff --git a/backend/apps/api/v1/views/stats.py b/backend/apps/api/v1/views/stats.py index 42c7e766..9860961d 100644 --- a/backend/apps/api/v1/views/stats.py +++ b/backend/apps/api/v1/views/stats.py @@ -186,21 +186,13 @@ class StatsAPIView(APIView): total_rides = Ride.objects.count() # Company counts by role - total_manufacturers = RideCompany.objects.filter( - roles__contains=["MANUFACTURER"] - ).count() + total_manufacturers = RideCompany.objects.filter(roles__contains=["MANUFACTURER"]).count() - total_operators = ParkCompany.objects.filter( - roles__contains=["OPERATOR"] - ).count() + total_operators = ParkCompany.objects.filter(roles__contains=["OPERATOR"]).count() - total_designers = RideCompany.objects.filter( - roles__contains=["DESIGNER"] - ).count() + total_designers = RideCompany.objects.filter(roles__contains=["DESIGNER"]).count() - total_property_owners = ParkCompany.objects.filter( - roles__contains=["PROPERTY_OWNER"] - ).count() + total_property_owners = ParkCompany.objects.filter(roles__contains=["PROPERTY_OWNER"]).count() # Photo counts (combined) total_park_photos = ParkPhoto.objects.count() @@ -211,11 +203,7 @@ class StatsAPIView(APIView): total_roller_coasters = RollerCoasterStats.objects.count() # Ride category counts - ride_categories = ( - Ride.objects.values("category") - .annotate(count=Count("id")) - .exclude(category="") - ) + ride_categories = Ride.objects.values("category").annotate(count=Count("id")).exclude(category="") category_stats = {} for category in ride_categories: @@ -232,9 +220,7 @@ class StatsAPIView(APIView): "OT": "other_rides", } - category_name = category_names.get( - category_code, f"category_{category_code.lower()}" - ) + category_name = category_names.get(category_code, f"category_{category_code.lower()}") category_stats[category_name] = category_count # Park status counts @@ -281,9 +267,7 @@ class StatsAPIView(APIView): "RELOCATED": "relocated_rides", } - status_name = status_names.get( - status_code, f"ride_status_{status_code.lower()}" - ) + status_name = status_names.get(status_code, f"ride_status_{status_code.lower()}") ride_status_stats[status_name] = status_count # Review counts @@ -365,7 +349,7 @@ class StatsRecalculateAPIView(APIView): # Return success response with the fresh stats return Response( { - "message": "Platform statistics have been successfully recalculated", + "detail": "Platform statistics have been successfully recalculated", "stats": fresh_stats, "recalculated_at": timezone.now().isoformat(), }, diff --git a/backend/apps/api/v1/views/trending.py b/backend/apps/api/v1/views/trending.py index afd375ae..0a6345cd 100644 --- a/backend/apps/api/v1/views/trending.py +++ b/backend/apps/api/v1/views/trending.py @@ -127,18 +127,14 @@ class TriggerTrendingCalculationAPIView(APIView): try: # Run trending calculation command with redirect_stdout(trending_output), redirect_stderr(trending_output): - call_command( - "calculate_trending", "--content-type=all", "--limit=50" - ) + call_command("calculate_trending", "--content-type=all", "--limit=50") trending_completed = True except Exception as e: trending_output.write(f"Error: {str(e)}") try: # Run new content calculation command - with redirect_stdout(new_content_output), redirect_stderr( - new_content_output - ): + with redirect_stdout(new_content_output), redirect_stderr(new_content_output): call_command( "calculate_new_content", "--content-type=all", @@ -153,7 +149,7 @@ class TriggerTrendingCalculationAPIView(APIView): return Response( { - "message": "Trending content calculation completed", + "detail": "Trending content calculation completed", "trending_completed": trending_completed, "new_content_completed": new_content_completed, "completion_time": completion_time, @@ -166,7 +162,7 @@ class TriggerTrendingCalculationAPIView(APIView): except Exception as e: return Response( { - "error": "Failed to trigger trending content calculation", + "detail": "Failed to trigger trending content calculation", "details": str(e), }, status=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -213,9 +209,7 @@ class NewContentAPIView(APIView): days_back = min(int(request.query_params.get("days", 30)), 365) # Get new content using direct calculation service - all_new_content = trending_service.get_new_content( - limit=limit * 2, days_back=days_back - ) + all_new_content = trending_service.get_new_content(limit=limit * 2, days_back=days_back) recently_added = [] newly_opened = [] diff --git a/backend/apps/api/v1/viewsets.py b/backend/apps/api/v1/viewsets.py index 823b9c96..319593bd 100644 --- a/backend/apps/api/v1/viewsets.py +++ b/backend/apps/api/v1/viewsets.py @@ -26,7 +26,7 @@ class FallbackCacheMonitor: """Fallback class if CacheMonitor is not available.""" def get_cache_stats(self): - return {"error": "Cache monitoring not available"} + return {"detail": "Cache monitoring not available"} class FallbackIndexAnalyzer: @@ -34,7 +34,7 @@ class FallbackIndexAnalyzer: @staticmethod def analyze_slow_queries(threshold): - return {"error": "Query analysis not available"} + return {"detail": "Query analysis not available"} # Try to import the real classes, use fallbacks if not available diff --git a/backend/apps/api/v1/viewsets_rankings.py b/backend/apps/api/v1/viewsets_rankings.py index 8509d39b..2c911cf4 100644 --- a/backend/apps/api/v1/viewsets_rankings.py +++ b/backend/apps/api/v1/viewsets_rankings.py @@ -155,11 +155,7 @@ class RideRankingViewSet(ReadOnlyModelViewSet): from apps.rides.models import RankingSnapshot ranking = self.get_object() - history = RankingSnapshot.objects.filter(ride=ranking.ride).order_by( - "-snapshot_date" - )[ - :90 - ] # Last 3 months + history = RankingSnapshot.objects.filter(ride=ranking.ride).order_by("-snapshot_date")[:90] # Last 3 months serializer = self.get_serializer(history, many=True) return Response(serializer.data) @@ -180,11 +176,7 @@ class RideRankingViewSet(ReadOnlyModelViewSet): top_rated = RideRanking.objects.select_related("ride", "ride__park").first() # Get most compared ride - most_compared = ( - RideRanking.objects.select_related("ride", "ride__park") - .order_by("-comparison_count") - .first() - ) + most_compared = RideRanking.objects.select_related("ride", "ride__park").order_by("-comparison_count").first() # Get biggest rank change (last 7 days) from datetime import timedelta @@ -197,9 +189,7 @@ class RideRankingViewSet(ReadOnlyModelViewSet): current_rankings = RideRanking.objects.select_related("ride") for ranking in current_rankings[:100]: # Check top 100 for performance old_snapshot = ( - RankingSnapshot.objects.filter( - ride=ranking.ride, snapshot_date__lte=week_ago - ) + RankingSnapshot.objects.filter(ride=ranking.ride, snapshot_date__lte=week_ago) .order_by("-snapshot_date") .first() ) @@ -232,11 +222,7 @@ class RideRankingViewSet(ReadOnlyModelViewSet): "park": top_rated.ride.park.name, "rank": top_rated.rank, "winning_percentage": float(top_rated.winning_percentage), - "average_rating": ( - float(top_rated.average_rating) - if top_rated.average_rating - else None - ), + "average_rating": (float(top_rated.average_rating) if top_rated.average_rating else None), } if top_rated else None @@ -272,9 +258,7 @@ class RideRankingViewSet(ReadOnlyModelViewSet): ranking = self.get_object() comparisons = ( - RidePairComparison.objects.filter( - Q(ride_a=ranking.ride) | Q(ride_b=ranking.ride) - ) + RidePairComparison.objects.filter(Q(ride_a=ranking.ride) | Q(ride_b=ranking.ride)) .select_related("ride_a", "ride_b", "ride_a__park", "ride_b__park") .order_by("-mutual_riders_count")[:50] ) @@ -309,16 +293,8 @@ class RideRankingViewSet(ReadOnlyModelViewSet): "ties": comp.ties, "result": result, "mutual_riders": comp.mutual_riders_count, - "ride_a_avg_rating": ( - float(comp.ride_a_avg_rating) - if comp.ride_a_avg_rating - else None - ), - "ride_b_avg_rating": ( - float(comp.ride_b_avg_rating) - if comp.ride_b_avg_rating - else None - ), + "ride_a_avg_rating": (float(comp.ride_a_avg_rating) if comp.ride_a_avg_rating else None), + "ride_b_avg_rating": (float(comp.ride_b_avg_rating) if comp.ride_b_avg_rating else None), } ) @@ -345,9 +321,7 @@ class TriggerRankingCalculationView(APIView): def post(self, request): """Trigger ranking calculation.""" if not request.user.is_staff: - return Response( - {"error": "Admin access required"}, status=status.HTTP_403_FORBIDDEN - ) + return Response({"detail": "Admin access required"}, status=status.HTTP_403_FORBIDDEN) # Replace direct import with a guarded runtime import to avoid static-analysis/initialization errors try: @@ -367,7 +341,7 @@ class TriggerRankingCalculationView(APIView): if not RideRankingService: return Response( - {"error": "Ranking service unavailable"}, + {"detail": "Ranking service unavailable"}, status=status.HTTP_503_SERVICE_UNAVAILABLE, ) diff --git a/backend/apps/blog/models.py b/backend/apps/blog/models.py index 98573ac3..77f1f7dd 100644 --- a/backend/apps/blog/models.py +++ b/backend/apps/blog/models.py @@ -16,6 +16,7 @@ class Tag(SluggedModel): def __str__(self): return self.name + class Post(SluggedModel): title = models.CharField(max_length=255) content = models.TextField(help_text="Markdown content supported") @@ -27,14 +28,10 @@ class Post(SluggedModel): null=True, blank=True, related_name="blog_posts", - help_text="Featured image" + help_text="Featured image", ) - author = models.ForeignKey( - settings.AUTH_USER_MODEL, - on_delete=models.CASCADE, - related_name="blog_posts" - ) + author = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name="blog_posts") published_at = models.DateTimeField(null=True, blank=True, db_index=True) is_published = models.BooleanField(default=False, db_index=True) diff --git a/backend/apps/blog/serializers.py b/backend/apps/blog/serializers.py index 1e137ea9..f4c5a31c 100644 --- a/backend/apps/blog/serializers.py +++ b/backend/apps/blog/serializers.py @@ -12,8 +12,10 @@ class TagSerializer(serializers.ModelSerializer): model = Tag fields = ["id", "name", "slug"] + class PostListSerializer(serializers.ModelSerializer): """Lighter serializer for lists""" + author = UserSerializer(read_only=True) tags = TagSerializer(many=True, read_only=True) image = CloudflareImageSerializer(read_only=True) @@ -31,16 +33,13 @@ class PostListSerializer(serializers.ModelSerializer): "tags", ] + class PostDetailSerializer(serializers.ModelSerializer): author = UserSerializer(read_only=True) tags = TagSerializer(many=True, read_only=True) image = CloudflareImageSerializer(read_only=True) image_id = serializers.PrimaryKeyRelatedField( - queryset=CloudflareImage.objects.all(), - source='image', - write_only=True, - required=False, - allow_null=True + queryset=CloudflareImage.objects.all(), source="image", write_only=True, required=False, allow_null=True ) class Meta: diff --git a/backend/apps/blog/views.py b/backend/apps/blog/views.py index bdfa904b..c3a665f3 100644 --- a/backend/apps/blog/views.py +++ b/backend/apps/blog/views.py @@ -14,13 +14,15 @@ class TagViewSet(viewsets.ReadOnlyModelViewSet): permission_classes = [permissions.AllowAny] filter_backends = [filters.SearchFilter] search_fields = ["name"] - pagination_class = None # Tags are usually few + pagination_class = None # Tags are usually few + class PostViewSet(viewsets.ModelViewSet): """ Public API: Read Only (unless staff). Only published posts unless staff. """ + permission_classes = [IsStaffOrReadOnly] filter_backends = [DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter] search_fields = ["title", "excerpt", "content"] diff --git a/backend/apps/context_portal/alembic/versions/2025_06_17_initial_schema.py b/backend/apps/context_portal/alembic/versions/2025_06_17_initial_schema.py index 55d3b2e3..f49a0742 100644 --- a/backend/apps/context_portal/alembic/versions/2025_06_17_initial_schema.py +++ b/backend/apps/context_portal/alembic/versions/2025_06_17_initial_schema.py @@ -117,9 +117,7 @@ def upgrade() -> None: sa.Column("status", sa.String(length=50), nullable=False), sa.Column("description", sa.Text(), nullable=False), sa.Column("parent_id", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint( - ["parent_id"], ["progress_entries.id"], ondelete="SET NULL" - ), + sa.ForeignKeyConstraint(["parent_id"], ["progress_entries.id"], ondelete="SET NULL"), sa.PrimaryKeyConstraint("id"), ) op.create_table( diff --git a/backend/apps/core/__init__.py b/backend/apps/core/__init__.py index 240dd811..576e1bbc 100644 --- a/backend/apps/core/__init__.py +++ b/backend/apps/core/__init__.py @@ -9,4 +9,4 @@ system status, and other foundational features. from .choices import core_choices # Ensure choices are registered on app startup -__all__ = ['core_choices'] +__all__ = ["core_choices"] diff --git a/backend/apps/core/admin.py b/backend/apps/core/admin.py index f7a927b3..40129862 100644 --- a/backend/apps/core/admin.py +++ b/backend/apps/core/admin.py @@ -23,9 +23,7 @@ from .models import SlugHistory @admin.register(SlugHistory) -class SlugHistoryAdmin( - ReadOnlyAdminMixin, QueryOptimizationMixin, ExportActionMixin, BaseModelAdmin -): +class SlugHistoryAdmin(ReadOnlyAdminMixin, QueryOptimizationMixin, ExportActionMixin, BaseModelAdmin): """ Admin interface for SlugHistory management. diff --git a/backend/apps/core/admin/mixins.py b/backend/apps/core/admin/mixins.py index 410f9e25..433fb1a0 100644 --- a/backend/apps/core/admin/mixins.py +++ b/backend/apps/core/admin/mixins.py @@ -221,13 +221,9 @@ class ExportActionMixin: writer.writerow(row) response = HttpResponse(output.getvalue(), content_type="text/csv") - response["Content-Disposition"] = ( - f'attachment; filename="{self.get_export_filename("csv")}"' - ) + response["Content-Disposition"] = f'attachment; filename="{self.get_export_filename("csv")}"' - self.message_user( - request, f"Successfully exported {queryset.count()} records to CSV." - ) + self.message_user(request, f"Successfully exported {queryset.count()} records to CSV.") return response @admin.action(description="Export selected to JSON") @@ -250,13 +246,9 @@ class ExportActionMixin: json.dumps(data, indent=2, cls=DjangoJSONEncoder), content_type="application/json", ) - response["Content-Disposition"] = ( - f'attachment; filename="{self.get_export_filename("json")}"' - ) + response["Content-Disposition"] = f'attachment; filename="{self.get_export_filename("json")}"' - self.message_user( - request, f"Successfully exported {queryset.count()} records to JSON." - ) + self.message_user(request, f"Successfully exported {queryset.count()} records to JSON.") return response def get_actions(self, request): diff --git a/backend/apps/core/analytics.py b/backend/apps/core/analytics.py index 2ef8506b..29bb183e 100644 --- a/backend/apps/core/analytics.py +++ b/backend/apps/core/analytics.py @@ -10,9 +10,7 @@ from django.utils import timezone @pghistory.track() class PageView(models.Model): - content_type = models.ForeignKey( - ContentType, on_delete=models.CASCADE, related_name="page_views" - ) + content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, related_name="page_views") object_id = models.PositiveIntegerField() content_object = GenericForeignKey("content_type", "object_id") @@ -64,9 +62,7 @@ class PageView(models.Model): return model_class.objects.none() @classmethod - def get_views_growth( - cls, content_type, object_id, current_period_hours, previous_period_hours - ): + def get_views_growth(cls, content_type, object_id, current_period_hours, previous_period_hours): """Get view growth statistics between two time periods. Args: @@ -102,9 +98,7 @@ class PageView(models.Model): if previous_views == 0: growth_percentage = current_views * 100 if current_views > 0 else 0 else: - growth_percentage = ( - (current_views - previous_views) / previous_views - ) * 100 + growth_percentage = ((current_views - previous_views) / previous_views) * 100 return current_views, previous_views, growth_percentage @@ -121,6 +115,4 @@ class PageView(models.Model): int: Total view count """ cutoff = timezone.now() - timedelta(hours=hours) - return cls.objects.filter( - content_type=content_type, object_id=object_id, timestamp__gte=cutoff - ).count() + return cls.objects.filter(content_type=content_type, object_id=object_id, timestamp__gte=cutoff).count() diff --git a/backend/apps/core/api/exceptions.py b/backend/apps/core/api/exceptions.py index ab9bbc48..d9b8f814 100644 --- a/backend/apps/core/api/exceptions.py +++ b/backend/apps/core/api/exceptions.py @@ -31,9 +31,7 @@ from ..logging import get_logger, log_exception logger = get_logger(__name__) -def custom_exception_handler( - exc: Exception, context: dict[str, Any] -) -> Response | None: +def custom_exception_handler(exc: Exception, context: dict[str, Any]) -> Response | None: """ Custom exception handler for DRF that provides standardized error responses. @@ -172,9 +170,7 @@ def custom_exception_handler( request=request, ) - response = Response( - custom_response_data, status=status.HTTP_500_INTERNAL_SERVER_ERROR - ) + response = Response(custom_response_data, status=status.HTTP_500_INTERNAL_SERVER_ERROR) return response @@ -234,10 +230,7 @@ def _format_django_validation_errors( """Format Django ValidationError for API response.""" if hasattr(exc, "error_dict"): # Field-specific errors - return { - field: [str(error) for error in errors] - for field, errors in exc.error_dict.items() - } + return {field: [str(error) for error in errors] for field, errors in exc.error_dict.items()} elif hasattr(exc, "error_list"): # Non-field errors return {"non_field_errors": [str(error) for error in exc.error_list]} diff --git a/backend/apps/core/api/mixins.py b/backend/apps/core/api/mixins.py index 9e80526a..5aab1977 100644 --- a/backend/apps/core/api/mixins.py +++ b/backend/apps/core/api/mixins.py @@ -103,15 +103,11 @@ class ApiMixin: # These will raise if not implemented; they also inform static analyzers about their existence. def paginate_queryset(self, queryset): """Override / implement in subclass or provided base if pagination is needed.""" - raise NotImplementedError( - "Subclasses must implement paginate_queryset to enable pagination" - ) + raise NotImplementedError("Subclasses must implement paginate_queryset to enable pagination") def get_paginated_response(self, data): """Override / implement in subclass or provided base to return paginated responses.""" - raise NotImplementedError( - "Subclasses must implement get_paginated_response to enable pagination" - ) + raise NotImplementedError("Subclasses must implement get_paginated_response to enable pagination") def get_object(self): """Default placeholder; subclasses should implement this.""" @@ -168,9 +164,7 @@ class UpdateApiMixin(ApiMixin): def update(self, _request: Request, *_args, **_kwargs) -> Response: """Handle PUT/PATCH requests for updating resources.""" instance = self.get_object() - serializer = self.get_input_serializer( - data=_request.data, partial=_kwargs.get("partial", False) - ) + serializer = self.get_input_serializer(data=_request.data, partial=_kwargs.get("partial", False)) serializer.is_valid(raise_exception=True) # Update the object using the service layer @@ -229,9 +223,7 @@ class ListApiMixin(ApiMixin): Override this method to use selector patterns. Should call selector functions, not access model managers directly. """ - raise NotImplementedError( - "Subclasses must implement get_queryset using selectors" - ) + raise NotImplementedError("Subclasses must implement get_queryset using selectors") def get_output_serializer(self, *args, **kwargs): """Get the output serializer for response.""" diff --git a/backend/apps/core/checks.py b/backend/apps/core/checks.py index c695f97a..35cad755 100644 --- a/backend/apps/core/checks.py +++ b/backend/apps/core/checks.py @@ -28,6 +28,7 @@ from django.core.checks import Error, Tags, Warning, register # Secret Key Validation # ============================================================================= + @register(Tags.security) def check_secret_key(app_configs, **kwargs): """ @@ -38,30 +39,30 @@ def check_secret_key(app_configs, **kwargs): - Key has sufficient entropy (length and character variety) """ errors = [] - secret_key = getattr(settings, 'SECRET_KEY', '') + secret_key = getattr(settings, "SECRET_KEY", "") # Check for empty or missing key if not secret_key: errors.append( Error( - 'SECRET_KEY is not set.', - hint='Set a strong, random SECRET_KEY in your environment.', - id='security.E001', + "SECRET_KEY is not set.", + hint="Set a strong, random SECRET_KEY in your environment.", + id="security.E001", ) ) return errors # Check for known insecure default values insecure_defaults = [ - 'django-insecure', - 'your-secret-key', - 'change-me', - 'changeme', - 'secret', - 'xxx', - 'test', - 'development', - 'dev-key', + "django-insecure", + "your-secret-key", + "change-me", + "changeme", + "secret", + "xxx", + "test", + "development", + "dev-key", ] key_lower = secret_key.lower() @@ -71,7 +72,7 @@ def check_secret_key(app_configs, **kwargs): Error( f'SECRET_KEY appears to contain an insecure default value: "{default}"', hint='Generate a new secret key using: python -c "from django.core.management.utils import get_random_secret_key; print(get_random_secret_key())"', - id='security.E002', + id="security.E002", ) ) break @@ -80,25 +81,25 @@ def check_secret_key(app_configs, **kwargs): if len(secret_key) < 50: errors.append( Warning( - f'SECRET_KEY is only {len(secret_key)} characters long.', - hint='A secret key should be at least 50 characters for proper security.', - id='security.W001', + f"SECRET_KEY is only {len(secret_key)} characters long.", + hint="A secret key should be at least 50 characters for proper security.", + id="security.W001", ) ) # Check for sufficient character variety - has_upper = bool(re.search(r'[A-Z]', secret_key)) - has_lower = bool(re.search(r'[a-z]', secret_key)) - has_digit = bool(re.search(r'[0-9]', secret_key)) + has_upper = bool(re.search(r"[A-Z]", secret_key)) + has_lower = bool(re.search(r"[a-z]", secret_key)) + has_digit = bool(re.search(r"[0-9]", secret_key)) has_special = bool(re.search(r'[!@#$%^&*()_+\-=\[\]{};\':"\\|,.<>\/?]', secret_key)) char_types = sum([has_upper, has_lower, has_digit, has_special]) if char_types < 3: errors.append( Warning( - 'SECRET_KEY lacks character variety.', - hint='A good secret key should contain uppercase, lowercase, digits, and special characters.', - id='security.W002', + "SECRET_KEY lacks character variety.", + hint="A good secret key should contain uppercase, lowercase, digits, and special characters.", + id="security.W002", ) ) @@ -109,6 +110,7 @@ def check_secret_key(app_configs, **kwargs): # Debug Mode Check # ============================================================================= + @register(Tags.security) def check_debug_mode(app_configs, **kwargs): """ @@ -117,27 +119,27 @@ def check_debug_mode(app_configs, **kwargs): errors = [] # Check if we're in a production-like environment - env = os.environ.get('DJANGO_SETTINGS_MODULE', '') - is_production = 'production' in env.lower() or 'prod' in env.lower() + env = os.environ.get("DJANGO_SETTINGS_MODULE", "") + is_production = "production" in env.lower() or "prod" in env.lower() if is_production and settings.DEBUG: errors.append( Error( - 'DEBUG is True in what appears to be a production environment.', - hint='Set DEBUG=False in production settings.', - id='security.E003', + "DEBUG is True in what appears to be a production environment.", + hint="Set DEBUG=False in production settings.", + id="security.E003", ) ) # Also check if DEBUG is True with ALLOWED_HOSTS configured # (indicates possible production deployment with debug on) - if settings.DEBUG and settings.ALLOWED_HOSTS and '*' not in settings.ALLOWED_HOSTS: - if len(settings.ALLOWED_HOSTS) > 0 and 'localhost' not in settings.ALLOWED_HOSTS[0]: + if settings.DEBUG and settings.ALLOWED_HOSTS and "*" not in settings.ALLOWED_HOSTS: # noqa: SIM102 + if len(settings.ALLOWED_HOSTS) > 0 and "localhost" not in settings.ALLOWED_HOSTS[0]: errors.append( Warning( - 'DEBUG is True but ALLOWED_HOSTS contains non-localhost values.', - hint='This may indicate DEBUG is accidentally enabled in a deployed environment.', - id='security.W003', + "DEBUG is True but ALLOWED_HOSTS contains non-localhost values.", + hint="This may indicate DEBUG is accidentally enabled in a deployed environment.", + id="security.W003", ) ) @@ -148,30 +150,31 @@ def check_debug_mode(app_configs, **kwargs): # ALLOWED_HOSTS Check # ============================================================================= + @register(Tags.security) def check_allowed_hosts(app_configs, **kwargs): """ Check ALLOWED_HOSTS configuration. """ errors = [] - allowed_hosts = getattr(settings, 'ALLOWED_HOSTS', []) + allowed_hosts = getattr(settings, "ALLOWED_HOSTS", []) if not settings.DEBUG: # In non-debug mode, ALLOWED_HOSTS must be set if not allowed_hosts: errors.append( Error( - 'ALLOWED_HOSTS is empty but DEBUG is False.', - hint='Set ALLOWED_HOSTS to a list of allowed hostnames.', - id='security.E004', + "ALLOWED_HOSTS is empty but DEBUG is False.", + hint="Set ALLOWED_HOSTS to a list of allowed hostnames.", + id="security.E004", ) ) - elif '*' in allowed_hosts: + elif "*" in allowed_hosts: errors.append( Error( 'ALLOWED_HOSTS contains "*" which allows all hosts.', - hint='Specify explicit hostnames instead of wildcards.', - id='security.E005', + hint="Specify explicit hostnames instead of wildcards.", + id="security.E005", ) ) @@ -182,6 +185,7 @@ def check_allowed_hosts(app_configs, **kwargs): # Security Headers Check # ============================================================================= + @register(Tags.security) def check_security_headers(app_configs, **kwargs): """ @@ -190,34 +194,34 @@ def check_security_headers(app_configs, **kwargs): errors = [] # Check X-Frame-Options - x_frame_options = getattr(settings, 'X_FRAME_OPTIONS', None) - if x_frame_options not in ('DENY', 'SAMEORIGIN'): + x_frame_options = getattr(settings, "X_FRAME_OPTIONS", None) + if x_frame_options not in ("DENY", "SAMEORIGIN"): errors.append( Warning( f'X_FRAME_OPTIONS is set to "{x_frame_options}" or not set.', hint='Set X_FRAME_OPTIONS to "DENY" or "SAMEORIGIN" to prevent clickjacking.', - id='security.W004', + id="security.W004", ) ) # Check content type sniffing protection - if not getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False): + if not getattr(settings, "SECURE_CONTENT_TYPE_NOSNIFF", False): errors.append( Warning( - 'SECURE_CONTENT_TYPE_NOSNIFF is not enabled.', - hint='Set SECURE_CONTENT_TYPE_NOSNIFF = True to prevent MIME type sniffing.', - id='security.W005', + "SECURE_CONTENT_TYPE_NOSNIFF is not enabled.", + hint="Set SECURE_CONTENT_TYPE_NOSNIFF = True to prevent MIME type sniffing.", + id="security.W005", ) ) # Check referrer policy - referrer_policy = getattr(settings, 'SECURE_REFERRER_POLICY', None) + referrer_policy = getattr(settings, "SECURE_REFERRER_POLICY", None) if not referrer_policy: errors.append( Warning( - 'SECURE_REFERRER_POLICY is not set.', - hint='Set SECURE_REFERRER_POLICY to control referrer header behavior.', - id='security.W006', + "SECURE_REFERRER_POLICY is not set.", + hint="Set SECURE_REFERRER_POLICY to control referrer header behavior.", + id="security.W006", ) ) @@ -228,6 +232,7 @@ def check_security_headers(app_configs, **kwargs): # HTTPS Settings Check # ============================================================================= + @register(Tags.security) def check_https_settings(app_configs, **kwargs): """ @@ -240,32 +245,32 @@ def check_https_settings(app_configs, **kwargs): return errors # Check SSL redirect - if not getattr(settings, 'SECURE_SSL_REDIRECT', False): + if not getattr(settings, "SECURE_SSL_REDIRECT", False): errors.append( Warning( - 'SECURE_SSL_REDIRECT is not enabled.', - hint='Set SECURE_SSL_REDIRECT = True to redirect HTTP to HTTPS.', - id='security.W007', + "SECURE_SSL_REDIRECT is not enabled.", + hint="Set SECURE_SSL_REDIRECT = True to redirect HTTP to HTTPS.", + id="security.W007", ) ) # Check HSTS settings - hsts_seconds = getattr(settings, 'SECURE_HSTS_SECONDS', 0) + hsts_seconds = getattr(settings, "SECURE_HSTS_SECONDS", 0) if hsts_seconds < 31536000: # Less than 1 year errors.append( Warning( - f'SECURE_HSTS_SECONDS is {hsts_seconds} (less than 1 year).', - hint='Set SECURE_HSTS_SECONDS to at least 31536000 (1 year) for HSTS preload eligibility.', - id='security.W008', + f"SECURE_HSTS_SECONDS is {hsts_seconds} (less than 1 year).", + hint="Set SECURE_HSTS_SECONDS to at least 31536000 (1 year) for HSTS preload eligibility.", + id="security.W008", ) ) - if not getattr(settings, 'SECURE_HSTS_INCLUDE_SUBDOMAINS', False): + if not getattr(settings, "SECURE_HSTS_INCLUDE_SUBDOMAINS", False): errors.append( Warning( - 'SECURE_HSTS_INCLUDE_SUBDOMAINS is not enabled.', - hint='Set SECURE_HSTS_INCLUDE_SUBDOMAINS = True to include all subdomains in HSTS.', - id='security.W009', + "SECURE_HSTS_INCLUDE_SUBDOMAINS is not enabled.", + hint="Set SECURE_HSTS_INCLUDE_SUBDOMAINS = True to include all subdomains in HSTS.", + id="security.W009", ) ) @@ -276,6 +281,7 @@ def check_https_settings(app_configs, **kwargs): # Cookie Security Check # ============================================================================= + @register(Tags.security) def check_cookie_security(app_configs, **kwargs): """ @@ -288,42 +294,42 @@ def check_cookie_security(app_configs, **kwargs): return errors # Check session cookie security - if not getattr(settings, 'SESSION_COOKIE_SECURE', False): + if not getattr(settings, "SESSION_COOKIE_SECURE", False): errors.append( Warning( - 'SESSION_COOKIE_SECURE is not enabled.', - hint='Set SESSION_COOKIE_SECURE = True to only send session cookies over HTTPS.', - id='security.W010', + "SESSION_COOKIE_SECURE is not enabled.", + hint="Set SESSION_COOKIE_SECURE = True to only send session cookies over HTTPS.", + id="security.W010", ) ) - if not getattr(settings, 'SESSION_COOKIE_HTTPONLY', True): + if not getattr(settings, "SESSION_COOKIE_HTTPONLY", True): errors.append( Warning( - 'SESSION_COOKIE_HTTPONLY is disabled.', - hint='Set SESSION_COOKIE_HTTPONLY = True to prevent JavaScript access to session cookies.', - id='security.W011', + "SESSION_COOKIE_HTTPONLY is disabled.", + hint="Set SESSION_COOKIE_HTTPONLY = True to prevent JavaScript access to session cookies.", + id="security.W011", ) ) # Check CSRF cookie security - if not getattr(settings, 'CSRF_COOKIE_SECURE', False): + if not getattr(settings, "CSRF_COOKIE_SECURE", False): errors.append( Warning( - 'CSRF_COOKIE_SECURE is not enabled.', - hint='Set CSRF_COOKIE_SECURE = True to only send CSRF cookies over HTTPS.', - id='security.W012', + "CSRF_COOKIE_SECURE is not enabled.", + hint="Set CSRF_COOKIE_SECURE = True to only send CSRF cookies over HTTPS.", + id="security.W012", ) ) # Check SameSite attributes - session_samesite = getattr(settings, 'SESSION_COOKIE_SAMESITE', 'Lax') - if session_samesite not in ('Strict', 'Lax'): + session_samesite = getattr(settings, "SESSION_COOKIE_SAMESITE", "Lax") + if session_samesite not in ("Strict", "Lax"): errors.append( Warning( f'SESSION_COOKIE_SAMESITE is set to "{session_samesite}".', hint='Set SESSION_COOKIE_SAMESITE to "Strict" or "Lax" for CSRF protection.', - id='security.W013', + id="security.W013", ) ) @@ -334,6 +340,7 @@ def check_cookie_security(app_configs, **kwargs): # Database Security Check # ============================================================================= + @register(Tags.security) def check_database_security(app_configs, **kwargs): """ @@ -345,27 +352,27 @@ def check_database_security(app_configs, **kwargs): if settings.DEBUG: return errors - databases = getattr(settings, 'DATABASES', {}) - default_db = databases.get('default', {}) + databases = getattr(settings, "DATABASES", {}) + default_db = databases.get("default", {}) # Check for empty password - if not default_db.get('PASSWORD') and default_db.get('ENGINE', '').endswith('postgresql'): + if not default_db.get("PASSWORD") and default_db.get("ENGINE", "").endswith("postgresql"): errors.append( Warning( - 'Database password is empty.', - hint='Set a strong password for database authentication.', - id='security.W014', + "Database password is empty.", + hint="Set a strong password for database authentication.", + id="security.W014", ) ) # Check for SSL mode in PostgreSQL - options = default_db.get('OPTIONS', {}) - if 'sslmode' not in str(options) and default_db.get('ENGINE', '').endswith('postgresql'): + options = default_db.get("OPTIONS", {}) + if "sslmode" not in str(options) and default_db.get("ENGINE", "").endswith("postgresql"): errors.append( Warning( - 'Database SSL mode is not explicitly configured.', - hint='Consider setting sslmode in database OPTIONS for encrypted connections.', - id='security.W015', + "Database SSL mode is not explicitly configured.", + hint="Consider setting sslmode in database OPTIONS for encrypted connections.", + id="security.W015", ) ) diff --git a/backend/apps/core/choices/__init__.py b/backend/apps/core/choices/__init__.py index ffa24bdb..ac3bf6b2 100644 --- a/backend/apps/core/choices/__init__.py +++ b/backend/apps/core/choices/__init__.py @@ -19,14 +19,14 @@ from .serializers import RichChoiceOptionSerializer, RichChoiceSerializer from .utils import get_choice_display, validate_choice_value __all__ = [ - 'RichChoice', - 'ChoiceCategory', - 'ChoiceGroup', - 'ChoiceRegistry', - 'register_choices', - 'RichChoiceField', - 'RichChoiceSerializer', - 'RichChoiceOptionSerializer', - 'validate_choice_value', - 'get_choice_display', + "RichChoice", + "ChoiceCategory", + "ChoiceGroup", + "ChoiceRegistry", + "register_choices", + "RichChoiceField", + "RichChoiceSerializer", + "RichChoiceOptionSerializer", + "validate_choice_value", + "get_choice_display", ] diff --git a/backend/apps/core/choices/base.py b/backend/apps/core/choices/base.py index 4ce5cf70..ea826a16 100644 --- a/backend/apps/core/choices/base.py +++ b/backend/apps/core/choices/base.py @@ -11,6 +11,7 @@ from typing import Any class ChoiceCategory(Enum): """Categories for organizing choice types""" + STATUS = "status" TYPE = "type" CLASSIFICATION = "classification" @@ -42,6 +43,7 @@ class RichChoice: deprecated: Whether this choice is deprecated and should not be used for new entries category: Category for organizing related choices """ + value: str label: str description: str = "" @@ -59,40 +61,38 @@ class RichChoice: @property def color(self) -> str | None: """Get the color from metadata if available""" - return self.metadata.get('color') + return self.metadata.get("color") @property def icon(self) -> str | None: """Get the icon from metadata if available""" - return self.metadata.get('icon') + return self.metadata.get("icon") @property def css_class(self) -> str | None: """Get the CSS class from metadata if available""" - return self.metadata.get('css_class') + return self.metadata.get("css_class") @property def sort_order(self) -> int: """Get the sort order from metadata, defaulting to 0""" - return self.metadata.get('sort_order', 0) - + return self.metadata.get("sort_order", 0) def to_dict(self) -> dict[str, Any]: """Convert to dictionary representation for API serialization""" return { - 'value': self.value, - 'label': self.label, - 'description': self.description, - 'metadata': self.metadata, - 'deprecated': self.deprecated, - 'category': self.category.value, - 'color': self.color, - 'icon': self.icon, - 'css_class': self.css_class, - 'sort_order': self.sort_order, + "value": self.value, + "label": self.label, + "description": self.description, + "metadata": self.metadata, + "deprecated": self.deprecated, + "category": self.category.value, + "color": self.color, + "icon": self.icon, + "css_class": self.css_class, + "sort_order": self.sort_order, } - def __str__(self) -> str: return self.label @@ -108,6 +108,7 @@ class ChoiceGroup: This allows for organizing choices into logical groups with common properties and behaviors. """ + name: str choices: list[RichChoice] description: str = "" @@ -147,8 +148,8 @@ class ChoiceGroup: def to_dict(self) -> dict[str, Any]: """Convert to dictionary representation for API serialization""" return { - 'name': self.name, - 'description': self.description, - 'metadata': self.metadata, - 'choices': [choice.to_dict() for choice in self.choices] + "name": self.name, + "description": self.description, + "metadata": self.metadata, + "choices": [choice.to_dict() for choice in self.choices], } diff --git a/backend/apps/core/choices/core_choices.py b/backend/apps/core/choices/core_choices.py index c5a7f34f..6b337540 100644 --- a/backend/apps/core/choices/core_choices.py +++ b/backend/apps/core/choices/core_choices.py @@ -15,26 +15,26 @@ HEALTH_STATUSES = [ label="Healthy", description="System is operating normally with no issues detected", metadata={ - 'color': 'green', - 'icon': 'check-circle', - 'css_class': 'bg-green-100 text-green-800', - 'sort_order': 1, - 'http_status': 200 + "color": "green", + "icon": "check-circle", + "css_class": "bg-green-100 text-green-800", + "sort_order": 1, + "http_status": 200, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="unhealthy", label="Unhealthy", description="System has detected issues that may affect functionality", metadata={ - 'color': 'red', - 'icon': 'x-circle', - 'css_class': 'bg-red-100 text-red-800', - 'sort_order': 2, - 'http_status': 503 + "color": "red", + "icon": "x-circle", + "css_class": "bg-red-100 text-red-800", + "sort_order": 2, + "http_status": 503, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), ] @@ -45,26 +45,26 @@ SIMPLE_HEALTH_STATUSES = [ label="OK", description="Basic health check passed", metadata={ - 'color': 'green', - 'icon': 'check', - 'css_class': 'bg-green-100 text-green-800', - 'sort_order': 1, - 'http_status': 200 + "color": "green", + "icon": "check", + "css_class": "bg-green-100 text-green-800", + "sort_order": 1, + "http_status": 200, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="error", label="Error", description="Basic health check failed", metadata={ - 'color': 'red', - 'icon': 'x', - 'css_class': 'bg-red-100 text-red-800', - 'sort_order': 2, - 'http_status': 500 + "color": "red", + "icon": "x", + "css_class": "bg-red-100 text-red-800", + "sort_order": 2, + "http_status": 500, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), ] @@ -75,52 +75,52 @@ ENTITY_TYPES = [ label="Park", description="Theme parks and amusement parks", metadata={ - 'color': 'green', - 'icon': 'map-pin', - 'css_class': 'bg-green-100 text-green-800', - 'sort_order': 1, - 'search_weight': 1.0 + "color": "green", + "icon": "map-pin", + "css_class": "bg-green-100 text-green-800", + "sort_order": 1, + "search_weight": 1.0, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="ride", label="Ride", description="Individual rides and attractions", metadata={ - 'color': 'blue', - 'icon': 'activity', - 'css_class': 'bg-blue-100 text-blue-800', - 'sort_order': 2, - 'search_weight': 1.0 + "color": "blue", + "icon": "activity", + "css_class": "bg-blue-100 text-blue-800", + "sort_order": 2, + "search_weight": 1.0, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="company", label="Company", description="Manufacturers, operators, and designers", metadata={ - 'color': 'purple', - 'icon': 'building', - 'css_class': 'bg-purple-100 text-purple-800', - 'sort_order': 3, - 'search_weight': 0.8 + "color": "purple", + "icon": "building", + "css_class": "bg-purple-100 text-purple-800", + "sort_order": 3, + "search_weight": 0.8, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="user", label="User", description="User profiles and accounts", metadata={ - 'color': 'orange', - 'icon': 'user', - 'css_class': 'bg-orange-100 text-orange-800', - 'sort_order': 4, - 'search_weight': 0.5 + "color": "orange", + "icon": "user", + "css_class": "bg-orange-100 text-orange-800", + "sort_order": 4, + "search_weight": 0.5, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), ] @@ -133,7 +133,7 @@ def register_core_choices(): choices=HEALTH_STATUSES, domain="core", description="Health check status options", - metadata={'domain': 'core', 'type': 'health_status'} + metadata={"domain": "core", "type": "health_status"}, ) register_choices( @@ -141,7 +141,7 @@ def register_core_choices(): choices=SIMPLE_HEALTH_STATUSES, domain="core", description="Simple health check status options", - metadata={'domain': 'core', 'type': 'simple_health_status'} + metadata={"domain": "core", "type": "simple_health_status"}, ) register_choices( @@ -149,7 +149,7 @@ def register_core_choices(): choices=ENTITY_TYPES, domain="core", description="Entity type classifications for search functionality", - metadata={'domain': 'core', 'type': 'entity_type'} + metadata={"domain": "core", "type": "entity_type"}, ) diff --git a/backend/apps/core/choices/fields.py b/backend/apps/core/choices/fields.py index 6ef21af7..a3f54be7 100644 --- a/backend/apps/core/choices/fields.py +++ b/backend/apps/core/choices/fields.py @@ -23,12 +23,7 @@ class RichChoiceField(models.CharField): """ def __init__( - self, - choice_group: str, - domain: str = "core", - max_length: int = 50, - allow_deprecated: bool = False, - **kwargs + self, choice_group: str, domain: str = "core", max_length: int = 50, allow_deprecated: bool = False, **kwargs ): """ Initialize the RichChoiceField. @@ -52,8 +47,8 @@ class RichChoiceField(models.CharField): choices = [(choice.value, choice.label) for choice in choices_list] - kwargs['choices'] = choices - kwargs['max_length'] = max_length + kwargs["choices"] = choices + kwargs["max_length"] = max_length super().__init__(**kwargs) @@ -61,21 +56,17 @@ class RichChoiceField(models.CharField): """Validate the choice value""" super().validate(value, model_instance) - if value is None or value == '': + if value is None or value == "": return # Check if choice exists in registry choice = registry.get_choice(self.choice_group, value, self.domain) if choice is None: - raise ValidationError( - f"'{value}' is not a valid choice for {self.choice_group}" - ) + raise ValidationError(f"'{value}' is not a valid choice for {self.choice_group}") # Check if deprecated choices are allowed if choice.deprecated and not self.allow_deprecated: - raise ValidationError( - f"'{value}' is deprecated and cannot be used for new entries" - ) + raise ValidationError(f"'{value}' is deprecated and cannot be used for new entries") def get_rich_choice(self, value: str) -> RichChoice | None: """Get the RichChoice object for a value""" @@ -94,21 +85,21 @@ class RichChoiceField(models.CharField): value = getattr(instance, name) return self.get_rich_choice(value) if value else None - setattr(cls, f'get_{name}_rich_choice', get_rich_choice_method) + setattr(cls, f"get_{name}_rich_choice", get_rich_choice_method) # Add get_FOO_display method (Django provides this, but we enhance it) def get_display_method(instance): value = getattr(instance, name) - return self.get_choice_display(value) if value else '' + return self.get_choice_display(value) if value else "" - setattr(cls, f'get_{name}_display', get_display_method) + setattr(cls, f"get_{name}_display", get_display_method) def deconstruct(self): """Support for Django migrations""" name, path, args, kwargs = super().deconstruct() - kwargs['choice_group'] = self.choice_group - kwargs['domain'] = self.domain - kwargs['allow_deprecated'] = self.allow_deprecated + kwargs["choice_group"] = self.choice_group + kwargs["domain"] = self.domain + kwargs["allow_deprecated"] = self.allow_deprecated return name, path, args, kwargs @@ -123,7 +114,7 @@ class RichChoiceFormField(ChoiceField): domain: str = "core", allow_deprecated: bool = False, show_descriptions: bool = False, - **kwargs + **kwargs, ): """ Initialize the form field. @@ -154,36 +145,28 @@ class RichChoiceFormField(ChoiceField): label = f"{choice.label} - {choice.description}" choices.append((choice.value, label)) - kwargs['choices'] = choices + kwargs["choices"] = choices super().__init__(**kwargs) def validate(self, value: Any) -> None: """Validate the choice value""" super().validate(value) - if value is None or value == '': + if value is None or value == "": return # Check if choice exists in registry choice = registry.get_choice(self.choice_group, value, self.domain) if choice is None: - raise ValidationError( - f"'{value}' is not a valid choice for {self.choice_group}" - ) + raise ValidationError(f"'{value}' is not a valid choice for {self.choice_group}") # Check if deprecated choices are allowed if choice.deprecated and not self.allow_deprecated: - raise ValidationError( - f"'{value}' is deprecated and cannot be used" - ) + raise ValidationError(f"'{value}' is deprecated and cannot be used") def create_rich_choice_field( - choice_group: str, - domain: str = "core", - max_length: int = 50, - allow_deprecated: bool = False, - **kwargs + choice_group: str, domain: str = "core", max_length: int = 50, allow_deprecated: bool = False, **kwargs ) -> RichChoiceField: """ Factory function to create a RichChoiceField. @@ -192,9 +175,5 @@ def create_rich_choice_field( across multiple models. """ return RichChoiceField( - choice_group=choice_group, - domain=domain, - max_length=max_length, - allow_deprecated=allow_deprecated, - **kwargs + choice_group=choice_group, domain=domain, max_length=max_length, allow_deprecated=allow_deprecated, **kwargs ) diff --git a/backend/apps/core/choices/registry.py b/backend/apps/core/choices/registry.py index 9a15f1e4..6ab50041 100644 --- a/backend/apps/core/choices/registry.py +++ b/backend/apps/core/choices/registry.py @@ -29,7 +29,7 @@ class ChoiceRegistry: choices: list[RichChoice], domain: str = "core", description: str = "", - metadata: dict[str, Any] | None = None + metadata: dict[str, Any] | None = None, ) -> ChoiceGroup: """ Register a group of choices. @@ -65,12 +65,7 @@ class ChoiceRegistry: f"Existing: {existing_values}, New: {new_values}" ) - choice_group = ChoiceGroup( - name=full_name, - choices=choices, - description=description, - metadata=metadata or {} - ) + choice_group = ChoiceGroup(name=full_name, choices=choices, description=description, metadata=metadata or {}) self._choices[full_name] = choice_group @@ -103,7 +98,6 @@ class ChoiceRegistry: choice_group = self.get(name, domain) return choice_group.get_active_choices() if choice_group else [] - def get_domains(self) -> list[str]: """Get all registered domains""" return list(self._domains.keys()) @@ -113,10 +107,7 @@ class ChoiceRegistry: if domain not in self._domains: return {} - return { - name: self._choices[f"{domain}.{name}"] - for name in self._domains[domain] - } + return {name: self._choices[f"{domain}.{name}"] for name in self._domains[domain]} def list_all(self) -> dict[str, ChoiceGroup]: """Get all registered choice groups""" @@ -159,7 +150,7 @@ def register_choices( choices: list[RichChoice], domain: str = "core", description: str = "", - metadata: dict[str, Any] | None = None + metadata: dict[str, Any] | None = None, ) -> ChoiceGroup: """ Convenience function to register choices with the global registry. @@ -187,8 +178,6 @@ def get_choice(group_name: str, value: str, domain: str = "core") -> RichChoice return registry.get_choice(group_name, value, domain) - - def validate_choice(group_name: str, value: str, domain: str = "core") -> bool: """Validate a choice value using the global registry""" return registry.validate_choice(group_name, value, domain) diff --git a/backend/apps/core/choices/serializers.py b/backend/apps/core/choices/serializers.py index ffce354e..7b01ff1d 100644 --- a/backend/apps/core/choices/serializers.py +++ b/backend/apps/core/choices/serializers.py @@ -20,6 +20,7 @@ class RichChoiceSerializer(serializers.Serializer): This provides a consistent API representation for choice objects with all their metadata. """ + value = serializers.CharField() label = serializers.CharField() description = serializers.CharField() @@ -42,6 +43,7 @@ class RichChoiceOptionSerializer(serializers.Serializer): This replaces the legacy FilterOptionSerializer with rich choice support. """ + value = serializers.CharField() label = serializers.CharField() description = serializers.CharField(allow_blank=True) @@ -58,30 +60,30 @@ class RichChoiceOptionSerializer(serializers.Serializer): if isinstance(instance, RichChoice): # Convert RichChoice to option format return { - 'value': instance.value, - 'label': instance.label, - 'description': instance.description, - 'count': None, - 'selected': False, - 'deprecated': instance.deprecated, - 'color': instance.color, - 'icon': instance.icon, - 'css_class': instance.css_class, - 'metadata': instance.metadata, + "value": instance.value, + "label": instance.label, + "description": instance.description, + "count": None, + "selected": False, + "deprecated": instance.deprecated, + "color": instance.color, + "icon": instance.icon, + "css_class": instance.css_class, + "metadata": instance.metadata, } elif isinstance(instance, dict): # Handle dictionary input (for backwards compatibility) return { - 'value': instance.get('value', ''), - 'label': instance.get('label', ''), - 'description': instance.get('description', ''), - 'count': instance.get('count'), - 'selected': instance.get('selected', False), - 'deprecated': instance.get('deprecated', False), - 'color': instance.get('color'), - 'icon': instance.get('icon'), - 'css_class': instance.get('css_class'), - 'metadata': instance.get('metadata', {}), + "value": instance.get("value", ""), + "label": instance.get("label", ""), + "description": instance.get("description", ""), + "count": instance.get("count"), + "selected": instance.get("selected", False), + "deprecated": instance.get("deprecated", False), + "color": instance.get("color"), + "icon": instance.get("icon"), + "css_class": instance.get("css_class"), + "metadata": instance.get("metadata", {}), } else: return super().to_representation(instance) @@ -94,6 +96,7 @@ class ChoiceGroupSerializer(serializers.Serializer): This provides API representation for entire choice groups with all their choices and metadata. """ + name = serializers.CharField() description = serializers.CharField() metadata = serializers.DictField() @@ -112,13 +115,7 @@ class RichChoiceFieldSerializer(serializers.CharField): include rich choice metadata in the response. """ - def __init__( - self, - choice_group: str, - domain: str = "core", - include_metadata: bool = False, - **kwargs - ): + def __init__(self, choice_group: str, domain: str = "core", include_metadata: bool = False, **kwargs): """ Initialize the serializer field. @@ -146,16 +143,16 @@ class RichChoiceFieldSerializer(serializers.CharField): else: # Fallback for unknown values return { - 'value': value, - 'label': value, - 'description': '', - 'metadata': {}, - 'deprecated': False, - 'category': 'other', - 'color': None, - 'icon': None, - 'css_class': None, - 'sort_order': 0, + "value": value, + "label": value, + "description": "", + "metadata": {}, + "deprecated": False, + "category": "other", + "color": None, + "icon": None, + "css_class": None, + "sort_order": 0, } else: # Return just the value @@ -163,20 +160,16 @@ class RichChoiceFieldSerializer(serializers.CharField): def to_internal_value(self, data: Any) -> str: """Convert input data to choice value""" - if isinstance(data, dict) and 'value' in data: + if isinstance(data, dict) and "value" in data: # Handle rich choice object input - return data['value'] + return data["value"] else: # Handle string input return super().to_internal_value(data) def create_choice_options_serializer( - choice_group: str, - domain: str = "core", - include_counts: bool = False, - queryset=None, - count_field: str = 'id' + choice_group: str, domain: str = "core", include_counts: bool = False, queryset=None, count_field: str = "id" ) -> list[dict[str, Any]]: """ Create choice options for filter endpoints. @@ -199,47 +192,44 @@ def create_choice_options_serializer( for choice in choices: option_data = { - 'value': choice.value, - 'label': choice.label, - 'description': choice.description, - 'selected': False, - 'deprecated': choice.deprecated, - 'color': choice.color, - 'icon': choice.icon, - 'css_class': choice.css_class, - 'metadata': choice.metadata, + "value": choice.value, + "label": choice.label, + "description": choice.description, + "selected": False, + "deprecated": choice.deprecated, + "color": choice.color, + "icon": choice.icon, + "css_class": choice.css_class, + "metadata": choice.metadata, } if include_counts and queryset is not None: # Count items for this choice try: count = queryset.filter(**{count_field: choice.value}).count() - option_data['count'] = count + option_data["count"] = count except Exception: # If counting fails, set count to None - option_data['count'] = None + option_data["count"] = None else: - option_data['count'] = None + option_data["count"] = None options.append(option_data) # Sort by sort_order, then by label - options.sort(key=lambda x: ( - (lambda c: c.sort_order if (c is not None and hasattr(c, 'sort_order')) else 0)( - registry.get_choice(choice_group, x['value'], domain) - ), - x['label'] - )) + options.sort( + key=lambda x: ( + (lambda c: c.sort_order if (c is not None and hasattr(c, "sort_order")) else 0)( + registry.get_choice(choice_group, x["value"], domain) + ), + x["label"], + ) + ) return options -def serialize_choice_value( - value: str, - choice_group: str, - domain: str = "core", - include_metadata: bool = False -) -> Any: +def serialize_choice_value(value: str, choice_group: str, domain: str = "core", include_metadata: bool = False) -> Any: """ Serialize a single choice value. @@ -262,16 +252,16 @@ def serialize_choice_value( else: # Fallback for unknown values return { - 'value': value, - 'label': value, - 'description': '', - 'metadata': {}, - 'deprecated': False, - 'category': 'other', - 'color': None, - 'icon': None, - 'css_class': None, - 'sort_order': 0, + "value": value, + "label": value, + "description": "", + "metadata": {}, + "deprecated": False, + "category": "other", + "color": None, + "icon": None, + "css_class": None, + "sort_order": 0, } else: return value diff --git a/backend/apps/core/choices/utils.py b/backend/apps/core/choices/utils.py index 489c4935..3b81e470 100644 --- a/backend/apps/core/choices/utils.py +++ b/backend/apps/core/choices/utils.py @@ -10,12 +10,7 @@ from .base import ChoiceCategory, RichChoice from .registry import registry -def validate_choice_value( - value: str, - choice_group: str, - domain: str = "core", - allow_deprecated: bool = False -) -> bool: +def validate_choice_value(value: str, choice_group: str, domain: str = "core", allow_deprecated: bool = False) -> bool: """ Validate that a choice value is valid for a given choice group. @@ -38,11 +33,7 @@ def validate_choice_value( return not (choice.deprecated and not allow_deprecated) -def get_choice_display( - value: str, - choice_group: str, - domain: str = "core" -) -> str: +def get_choice_display(value: str, choice_group: str, domain: str = "core") -> str: """ Get the display label for a choice value. @@ -67,11 +58,8 @@ def get_choice_display( raise ValueError(f"Choice value '{value}' not found in group '{choice_group}' for domain '{domain}'") - - def create_status_choices( - statuses: dict[str, dict[str, Any]], - category: ChoiceCategory = ChoiceCategory.STATUS + statuses: dict[str, dict[str, Any]], category: ChoiceCategory = ChoiceCategory.STATUS ) -> list[RichChoice]: """ Create status choices with consistent color coding. @@ -86,28 +74,28 @@ def create_status_choices( choices = [] for value, config in statuses.items(): - metadata = config.get('metadata', {}) + metadata = config.get("metadata", {}) # Add default status colors if not specified - if 'color' not in metadata: - if 'operating' in value.lower() or 'active' in value.lower(): - metadata['color'] = 'green' - elif 'closed' in value.lower() or 'inactive' in value.lower(): - metadata['color'] = 'red' - elif 'temp' in value.lower() or 'pending' in value.lower(): - metadata['color'] = 'yellow' - elif 'construction' in value.lower(): - metadata['color'] = 'blue' + if "color" not in metadata: + if "operating" in value.lower() or "active" in value.lower(): + metadata["color"] = "green" + elif "closed" in value.lower() or "inactive" in value.lower(): + metadata["color"] = "red" + elif "temp" in value.lower() or "pending" in value.lower(): + metadata["color"] = "yellow" + elif "construction" in value.lower(): + metadata["color"] = "blue" else: - metadata['color'] = 'gray' + metadata["color"] = "gray" choice = RichChoice( value=value, - label=config['label'], - description=config.get('description', ''), + label=config["label"], + description=config.get("description", ""), metadata=metadata, - deprecated=config.get('deprecated', False), - category=category + deprecated=config.get("deprecated", False), + category=category, ) choices.append(choice) @@ -115,8 +103,7 @@ def create_status_choices( def create_type_choices( - types: dict[str, dict[str, Any]], - category: ChoiceCategory = ChoiceCategory.TYPE + types: dict[str, dict[str, Any]], category: ChoiceCategory = ChoiceCategory.TYPE ) -> list[RichChoice]: """ Create type/classification choices. @@ -133,21 +120,18 @@ def create_type_choices( for value, config in types.items(): choice = RichChoice( value=value, - label=config['label'], - description=config.get('description', ''), - metadata=config.get('metadata', {}), - deprecated=config.get('deprecated', False), - category=category + label=config["label"], + description=config.get("description", ""), + metadata=config.get("metadata", {}), + deprecated=config.get("deprecated", False), + category=category, ) choices.append(choice) return choices -def merge_choice_metadata( - base_metadata: dict[str, Any], - override_metadata: dict[str, Any] -) -> dict[str, Any]: +def merge_choice_metadata(base_metadata: dict[str, Any], override_metadata: dict[str, Any]) -> dict[str, Any]: """ Merge choice metadata dictionaries. @@ -163,10 +147,7 @@ def merge_choice_metadata( return merged -def filter_choices_by_category( - choices: list[RichChoice], - category: ChoiceCategory -) -> list[RichChoice]: +def filter_choices_by_category(choices: list[RichChoice], category: ChoiceCategory) -> list[RichChoice]: """ Filter choices by category. @@ -180,10 +161,7 @@ def filter_choices_by_category( return [choice for choice in choices if choice.category == category] -def sort_choices( - choices: list[RichChoice], - sort_by: str = "sort_order" -) -> list[RichChoice]: +def sort_choices(choices: list[RichChoice], sort_by: str = "sort_order") -> list[RichChoice]: """ Sort choices by specified criteria. @@ -204,10 +182,7 @@ def sort_choices( return choices -def get_choice_colors( - choice_group: str, - domain: str = "core" -) -> dict[str, str]: +def get_choice_colors(choice_group: str, domain: str = "core") -> dict[str, str]: """ Get a mapping of choice values to their colors. @@ -219,18 +194,10 @@ def get_choice_colors( Dictionary mapping choice values to colors """ choices = registry.get_choices(choice_group, domain) - return { - choice.value: choice.color - for choice in choices - if choice.color - } + return {choice.value: choice.color for choice in choices if choice.color} -def validate_choice_group_data( - name: str, - choices: list[RichChoice], - domain: str = "core" -) -> list[str]: +def validate_choice_group_data(name: str, choices: list[RichChoice], domain: str = "core") -> list[str]: """ Validate choice group data and return list of errors. @@ -267,7 +234,7 @@ def validate_choice_group_data( description=choice.description, metadata=choice.metadata, deprecated=choice.deprecated, - category=choice.category + category=choice.category, ) except ValueError as e: errors.append(f"Choice {i}: {str(e)}") @@ -286,19 +253,16 @@ def create_choice_from_config(config: dict[str, Any]) -> RichChoice: RichChoice object """ return RichChoice( - value=config['value'], - label=config['label'], - description=config.get('description', ''), - metadata=config.get('metadata', {}), - deprecated=config.get('deprecated', False), - category=ChoiceCategory(config.get('category', 'other')) + value=config["value"], + label=config["label"], + description=config.get("description", ""), + metadata=config.get("metadata", {}), + deprecated=config.get("deprecated", False), + category=ChoiceCategory(config.get("category", "other")), ) -def export_choices_to_dict( - choice_group: str, - domain: str = "core" -) -> dict[str, Any]: +def export_choices_to_dict(choice_group: str, domain: str = "core") -> dict[str, Any]: """ Export a choice group to a dictionary format. diff --git a/backend/apps/core/decorators/cache_decorators.py b/backend/apps/core/decorators/cache_decorators.py index 4b880878..89177761 100644 --- a/backend/apps/core/decorators/cache_decorators.py +++ b/backend/apps/core/decorators/cache_decorators.py @@ -48,11 +48,7 @@ def cache_api_response( cache_key_parts = [ key_prefix, view_func.__name__, - ( - str(getattr(request.user, "id", "anonymous")) - if request.user.is_authenticated - else "anonymous" - ), + (str(getattr(request.user, "id", "anonymous")) if request.user.is_authenticated else "anonymous"), str(hash(frozenset(request.GET.items()))), ] @@ -72,9 +68,7 @@ def cache_api_response( # Try to get from cache cache_service = EnhancedCacheService() - cached_response = getattr(cache_service, cache_backend + "_cache").get( - cache_key - ) + cached_response = getattr(cache_service, cache_backend + "_cache").get(cache_key) if cached_response: logger.debug( @@ -87,11 +81,8 @@ def cache_api_response( ) # If cached data is our dict format for DRF responses, reconstruct it - if isinstance(cached_response, dict) and '__drf_data__' in cached_response: - return DRFResponse( - data=cached_response['__drf_data__'], - status=cached_response.get('status', 200) - ) + if isinstance(cached_response, dict) and "__drf_data__" in cached_response: + return DRFResponse(data=cached_response["__drf_data__"], status=cached_response.get("status", 200)) return cached_response @@ -104,17 +95,12 @@ def cache_api_response( if hasattr(response, "status_code") and response.status_code == 200: # For DRF responses, we must cache the data, not the response object # because the response object is not rendered yet and cannot be pickled - if hasattr(response, 'data'): - cache_payload = { - '__drf_data__': response.data, - 'status': response.status_code - } + if hasattr(response, "data"): + cache_payload = {"__drf_data__": response.data, "status": response.status_code} else: cache_payload = response - getattr(cache_service, cache_backend + "_cache").set( - cache_key, cache_payload, timeout - ) + getattr(cache_service, cache_backend + "_cache").set(cache_key, cache_payload, timeout) logger.debug( f"Cached API response for view {view_func.__name__}", extra={ @@ -162,9 +148,7 @@ def cache_queryset_result( cache_key = f"{cache_key_template}:{hash(str(args) + str(kwargs))}" cache_service = EnhancedCacheService() - cached_result = getattr(cache_service, cache_backend + "_cache").get( - cache_key - ) + cached_result = getattr(cache_service, cache_backend + "_cache").get(cache_key) if cached_result is not None: logger.debug(f"Cache hit for queryset operation: {func.__name__}") @@ -175,9 +159,7 @@ def cache_queryset_result( result = func(*args, **kwargs) execution_time = time.time() - start_time - getattr(cache_service, cache_backend + "_cache").set( - cache_key, result, timeout - ) + getattr(cache_service, cache_backend + "_cache").set(cache_key, result, timeout) logger.debug( f"Cached queryset result for {func.__name__}", extra={ @@ -250,24 +232,18 @@ class CachedAPIViewMixin(View): cache_backend = "api" @method_decorator(vary_on_headers("User-Agent", "Accept-Language")) - def dispatch( - self, request: HttpRequest, *args: Any, **kwargs: Any - ) -> HttpResponseBase: + def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponseBase: """Add caching to the dispatch method""" if request.method == "GET" and getattr(self, "enable_caching", True): return self._cached_dispatch(request, *args, **kwargs) return super().dispatch(request, *args, **kwargs) - def _cached_dispatch( - self, request: HttpRequest, *args: Any, **kwargs: Any - ) -> HttpResponseBase: + def _cached_dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponseBase: """Handle cached dispatch for GET requests""" cache_key = self._generate_cache_key(request, *args, **kwargs) cache_service = EnhancedCacheService() - cached_response = getattr(cache_service, self.cache_backend + "_cache").get( - cache_key - ) + cached_response = getattr(cache_service, self.cache_backend + "_cache").get(cache_key) if cached_response: logger.debug(f"Cache hit for view {self.__class__.__name__}") @@ -278,26 +254,18 @@ class CachedAPIViewMixin(View): # Cache successful responses if hasattr(response, "status_code") and response.status_code == 200: - getattr(cache_service, self.cache_backend + "_cache").set( - cache_key, response, self.cache_timeout - ) + getattr(cache_service, self.cache_backend + "_cache").set(cache_key, response, self.cache_timeout) logger.debug(f"Cached response for view {self.__class__.__name__}") return response - def _generate_cache_key( - self, request: HttpRequest, *args: Any, **kwargs: Any - ) -> str: + def _generate_cache_key(self, request: HttpRequest, *args: Any, **kwargs: Any) -> str: """Generate cache key for the request""" key_parts = [ self.cache_key_prefix, self.__class__.__name__, request.method, - ( - str(getattr(request.user, "id", "anonymous")) - if request.user.is_authenticated - else "anonymous" - ), + (str(getattr(request.user, "id", "anonymous")) if request.user.is_authenticated else "anonymous"), str(hash(frozenset(request.GET.items()))), ] @@ -344,15 +312,11 @@ def smart_cache( "kwargs": json.dumps(kwargs, sort_keys=True, default=str), } key_string = json.dumps(key_data, sort_keys=True) - cache_key = ( - f"smart_cache:{hashlib.md5(key_string.encode()).hexdigest()}" - ) + cache_key = f"smart_cache:{hashlib.md5(key_string.encode()).hexdigest()}" # Try to get from cache cache_service = EnhancedCacheService() - cached_result = getattr(cache_service, cache_backend + "_cache").get( - cache_key - ) + cached_result = getattr(cache_service, cache_backend + "_cache").get(cache_key) if cached_result is not None: logger.debug(f"Smart cache hit for {func.__name__}") @@ -364,9 +328,7 @@ def smart_cache( execution_time = time.time() - start_time # Cache result - getattr(cache_service, cache_backend + "_cache").set( - cache_key, result, timeout - ) + getattr(cache_service, cache_backend + "_cache").set(cache_key, result, timeout) logger.debug( f"Smart cached result for {func.__name__}", @@ -426,16 +388,10 @@ def generate_model_cache_key(model_instance: Any, suffix: str = "") -> str: """Generate cache key based on model instance""" model_name = model_instance._meta.model_name instance_id = model_instance.id - return ( - f"{model_name}:{instance_id}:{suffix}" - if suffix - else f"{model_name}:{instance_id}" - ) + return f"{model_name}:{instance_id}:{suffix}" if suffix else f"{model_name}:{instance_id}" -def generate_queryset_cache_key( - queryset: Any, params: dict[str, Any] | None = None -) -> str: +def generate_queryset_cache_key(queryset: Any, params: dict[str, Any] | None = None) -> str: """Generate cache key for queryset with parameters""" model_name = queryset.model._meta.model_name params_str = json.dumps(params or {}, sort_keys=True, default=str) diff --git a/backend/apps/core/forms.py b/backend/apps/core/forms.py index 0e50a156..c332fb8c 100644 --- a/backend/apps/core/forms.py +++ b/backend/apps/core/forms.py @@ -22,9 +22,7 @@ class BaseAutocomplete(Autocomplete): # UI text configuration using gettext for i18n no_result_text = _("No matches found") - narrow_search_text = _( - "Showing %(page_size)s of %(total)s matches. Please refine your search." - ) + narrow_search_text = _("Showing %(page_size)s of %(total)s matches. Please refine your search.") type_at_least_n_characters = _("Type at least %(n)s characters...") # Project-wide component settings diff --git a/backend/apps/core/forms/htmx_forms.py b/backend/apps/core/forms/htmx_forms.py index eeabc122..8201edea 100644 --- a/backend/apps/core/forms/htmx_forms.py +++ b/backend/apps/core/forms/htmx_forms.py @@ -1,6 +1,7 @@ """ Base forms and views for HTMX integration. """ + from django.http import JsonResponse from django.views.generic.edit import FormView @@ -20,9 +21,6 @@ class HTMXFormView(FormView): def post(self, request, *args, **kwargs): # If HTMX field validation pattern: ?field=name - if ( - request.headers.get("HX-Request") == "true" - and request.GET.get("validate_field") - ): + if request.headers.get("HX-Request") == "true" and request.GET.get("validate_field"): return self.validate_field(request.GET.get("validate_field")) return super().post(request, *args, **kwargs) diff --git a/backend/apps/core/forms/search.py b/backend/apps/core/forms/search.py index ff275b3a..14da807c 100644 --- a/backend/apps/core/forms/search.py +++ b/backend/apps/core/forms/search.py @@ -42,12 +42,8 @@ class LocationSearchForm(forms.Form): ) # Hidden fields for coordinates - lat = forms.FloatField( - required=False, widget=forms.HiddenInput(attrs={"id": "lat-input"}) - ) - lng = forms.FloatField( - required=False, widget=forms.HiddenInput(attrs={"id": "lng-input"}) - ) + lat = forms.FloatField(required=False, widget=forms.HiddenInput(attrs={"id": "lat-input"})) + lng = forms.FloatField(required=False, widget=forms.HiddenInput(attrs={"id": "lng-input"})) # Search radius radius_km = forms.ChoiceField( @@ -81,8 +77,7 @@ class LocationSearchForm(forms.Form): widget=forms.CheckboxInput( attrs={ "class": ( - "rounded border-gray-300 text-blue-600 focus:ring-blue-500 " - "dark:border-gray-600 dark:bg-gray-700" + "rounded border-gray-300 text-blue-600 focus:ring-blue-500 " "dark:border-gray-600 dark:bg-gray-700" ) } ), @@ -93,8 +88,7 @@ class LocationSearchForm(forms.Form): widget=forms.CheckboxInput( attrs={ "class": ( - "rounded border-gray-300 text-blue-600 focus:ring-blue-500 " - "dark:border-gray-600 dark:bg-gray-700" + "rounded border-gray-300 text-blue-600 focus:ring-blue-500 " "dark:border-gray-600 dark:bg-gray-700" ) } ), @@ -105,8 +99,7 @@ class LocationSearchForm(forms.Form): widget=forms.CheckboxInput( attrs={ "class": ( - "rounded border-gray-300 text-blue-600 focus:ring-blue-500 " - "dark:border-gray-600 dark:bg-gray-700" + "rounded border-gray-300 text-blue-600 focus:ring-blue-500 " "dark:border-gray-600 dark:bg-gray-700" ) } ), diff --git a/backend/apps/core/health_checks/custom_checks.py b/backend/apps/core/health_checks/custom_checks.py index 0159e59d..18f752b8 100644 --- a/backend/apps/core/health_checks/custom_checks.py +++ b/backend/apps/core/health_checks/custom_checks.py @@ -58,13 +58,9 @@ class CacheHealthCheck(BaseHealthCheckBackend): if max_memory > 0: memory_usage_percent = (used_memory / max_memory) * 100 if memory_usage_percent > 90: - self.add_error( - f"Redis memory usage critical: {memory_usage_percent:.1f}%" - ) + self.add_error(f"Redis memory usage critical: {memory_usage_percent:.1f}%") elif memory_usage_percent > 80: - logger.warning( - f"Redis memory usage high: {memory_usage_percent:.1f}%" - ) + logger.warning(f"Redis memory usage high: {memory_usage_percent:.1f}%") except ImportError: # django-redis not available, skip additional checks @@ -160,9 +156,7 @@ class ApplicationHealthCheck(BaseHealthCheckBackend): try: __import__(module_name) except ImportError as e: - self.add_error( - f"Critical module import failed: {module_name} - {e}" - ) + self.add_error(f"Critical module import failed: {module_name} - {e}") # Check if we can access critical models try: @@ -179,9 +173,7 @@ class ApplicationHealthCheck(BaseHealthCheckBackend): ride_count = Ride.objects.count() user_count = User.objects.count() - logger.debug( - f"Model counts - Parks: {park_count}, Rides: {ride_count}, Users: {user_count}" - ) + logger.debug(f"Model counts - Parks: {park_count}, Rides: {ride_count}, Users: {user_count}") except Exception as e: self.add_error(f"Model access check failed: {e}") @@ -195,9 +187,7 @@ class ApplicationHealthCheck(BaseHealthCheckBackend): self.add_error(f"Media directory does not exist: {settings.MEDIA_ROOT}") if not os.path.exists(settings.STATIC_ROOT) and not settings.DEBUG: - self.add_error( - f"Static directory does not exist: {settings.STATIC_ROOT}" - ) + self.add_error(f"Static directory does not exist: {settings.STATIC_ROOT}") except Exception as e: self.add_error(f"Application health check failed: {e}") @@ -214,10 +204,7 @@ class ExternalServiceHealthCheck(BaseHealthCheckBackend): from django.conf import settings from django.core.mail import get_connection - if ( - hasattr(settings, "EMAIL_BACKEND") - and "console" not in settings.EMAIL_BACKEND - ): + if hasattr(settings, "EMAIL_BACKEND") and "console" not in settings.EMAIL_BACKEND: # Only check if not using console backend connection = get_connection() if hasattr(connection, "open"): @@ -304,9 +291,7 @@ class DiskSpaceHealthCheck(BaseHealthCheckBackend): media_free_percent:.1f}% free in media directory" ) elif media_free_percent < 20: - logger.warning( - f"Low disk space: {media_free_percent:.1f}% free in media directory" - ) + logger.warning(f"Low disk space: {media_free_percent:.1f}% free in media directory") if logs_free_percent < 10: self.add_error( @@ -314,9 +299,7 @@ class DiskSpaceHealthCheck(BaseHealthCheckBackend): logs_free_percent:.1f}% free in logs directory" ) elif logs_free_percent < 20: - logger.warning( - f"Low disk space: {logs_free_percent:.1f}% free in logs directory" - ) + logger.warning(f"Low disk space: {logs_free_percent:.1f}% free in logs directory") except Exception as e: logger.warning(f"Disk space check failed: {e}") diff --git a/backend/apps/core/history.py b/backend/apps/core/history.py index 53a1ffcb..4a7831bb 100644 --- a/backend/apps/core/history.py +++ b/backend/apps/core/history.py @@ -94,9 +94,7 @@ class TrackedModel(models.Model): event_model = getattr(events, "model", None) if event_model: - return event_model.objects.filter(pgh_obj_id=self.pk).order_by( - "-pgh_created_at" - ) + return event_model.objects.filter(pgh_obj_id=self.pk).order_by("-pgh_created_at") except (AttributeError, TypeError): pass return self.__class__.objects.none() diff --git a/backend/apps/core/logging.py b/backend/apps/core/logging.py index 511883ca..6f9eac2e 100644 --- a/backend/apps/core/logging.py +++ b/backend/apps/core/logging.py @@ -23,9 +23,7 @@ class ThrillWikiFormatter(logging.Formatter): if hasattr(record, "request"): record.request_id = getattr(record.request, "id", "unknown") record.user_id = ( - getattr(record.request.user, "id", "anonymous") - if hasattr(record.request, "user") - else "unknown" + getattr(record.request.user, "id", "anonymous") if hasattr(record.request, "user") else "unknown" ) record.path = getattr(record.request, "path", "unknown") record.method = getattr(record.request, "method", "unknown") @@ -52,9 +50,7 @@ def get_logger(name: str) -> logging.Logger: # Only configure if not already configured if not logger.handlers: handler = logging.StreamHandler(sys.stdout) - formatter = ThrillWikiFormatter( - fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) + formatter = ThrillWikiFormatter(fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) logger.setLevel(logging.INFO if settings.DEBUG else logging.WARNING) @@ -91,11 +87,7 @@ def log_exception( { "request_path": getattr(request, "path", "unknown"), "request_method": getattr(request, "method", "unknown"), - "user_id": ( - getattr(request.user, "id", "anonymous") - if hasattr(request, "user") - else "unknown" - ), + "user_id": (getattr(request.user, "id", "anonymous") if hasattr(request, "user") else "unknown"), } ) @@ -134,11 +126,7 @@ def log_business_event( { "request_path": getattr(request, "path", "unknown"), "request_method": getattr(request, "method", "unknown"), - "user_id": ( - getattr(request.user, "id", "anonymous") - if hasattr(request, "user") - else "unknown" - ), + "user_id": (getattr(request.user, "id", "anonymous") if hasattr(request, "user") else "unknown"), } ) @@ -196,11 +184,7 @@ def log_api_request( "request_type": "api", "path": getattr(request, "path", "unknown"), "method": getattr(request, "method", "unknown"), - "user_id": ( - getattr(request.user, "id", "anonymous") - if hasattr(request, "user") - else "unknown" - ), + "user_id": (getattr(request.user, "id", "anonymous") if hasattr(request, "user") else "unknown"), "response_status": response_status, "duration_ms": duration_ms, } @@ -246,11 +230,7 @@ def log_security_event( { "request_path": getattr(request, "path", "unknown"), "request_method": getattr(request, "method", "unknown"), - "user_id": ( - getattr(request.user, "id", "anonymous") - if hasattr(request, "user") - else "unknown" - ), + "user_id": (getattr(request.user, "id", "anonymous") if hasattr(request, "user") else "unknown"), "remote_addr": request.META.get("REMOTE_ADDR", "unknown"), "user_agent": request.META.get("HTTP_USER_AGENT", "unknown"), } diff --git a/backend/apps/core/management/commands/calculate_new_content.py b/backend/apps/core/management/commands/calculate_new_content.py index 2e86b0bd..5dce040c 100644 --- a/backend/apps/core/management/commands/calculate_new_content.py +++ b/backend/apps/core/management/commands/calculate_new_content.py @@ -43,9 +43,7 @@ class Command(BaseCommand): default=50, help="Maximum number of results to calculate (default: 50)", ) - parser.add_argument( - "--verbose", action="store_true", help="Enable verbose output" - ) + parser.add_argument("--verbose", action="store_true", help="Enable verbose output") def handle(self, *args, **options): content_type = options["content_type"] @@ -61,17 +59,13 @@ class Command(BaseCommand): new_items = [] if content_type in ["all", "parks"]: - parks = self._get_new_parks( - cutoff_date, limit if content_type == "parks" else limit * 2 - ) + parks = self._get_new_parks(cutoff_date, limit if content_type == "parks" else limit * 2) new_items.extend(parks) if verbose: self.stdout.write(f"Found {len(parks)} new parks") if content_type in ["all", "rides"]: - rides = self._get_new_rides( - cutoff_date, limit if content_type == "rides" else limit * 2 - ) + rides = self._get_new_rides(cutoff_date, limit if content_type == "rides" else limit * 2) new_items.extend(rides) if verbose: self.stdout.write(f"Found {len(rides)} new rides") @@ -88,27 +82,22 @@ class Command(BaseCommand): cache.set(cache_key, formatted_results, 1800) # Cache for 30 minutes self.stdout.write( - self.style.SUCCESS( - f"Successfully calculated {len(formatted_results)} new items for {content_type}" - ) + self.style.SUCCESS(f"Successfully calculated {len(formatted_results)} new items for {content_type}") ) if verbose: for item in formatted_results[:5]: # Show first 5 items - self.stdout.write( - f" {item['name']} ({item['park']}) - opened: {item['date_opened']}" - ) + self.stdout.write(f" {item['name']} ({item['park']}) - opened: {item['date_opened']}") except Exception as e: logger.error(f"Error calculating new content: {e}", exc_info=True) - raise CommandError(f"Failed to calculate new content: {e}") + raise CommandError(f"Failed to calculate new content: {e}") from None def _get_new_parks(self, cutoff_date: datetime, limit: int) -> list[dict[str, Any]]: """Get recently added parks using real data.""" new_parks = ( Park.objects.filter( - Q(created_at__gte=cutoff_date) - | Q(opening_date__gte=cutoff_date.date()), + Q(created_at__gte=cutoff_date) | Q(opening_date__gte=cutoff_date.date()), status="OPERATING", ) .select_related("location", "operator") @@ -146,8 +135,7 @@ class Command(BaseCommand): """Get recently added rides using real data.""" new_rides = ( Ride.objects.filter( - Q(created_at__gte=cutoff_date) - | Q(opening_date__gte=cutoff_date.date()), + Q(created_at__gte=cutoff_date) | Q(opening_date__gte=cutoff_date.date()), status="OPERATING", ) .select_related("park", "park__location") @@ -156,9 +144,7 @@ class Command(BaseCommand): results = [] for ride in new_rides: - date_added = getattr(ride, "opening_date", None) or getattr( - ride, "created_at", None - ) + date_added = getattr(ride, "opening_date", None) or getattr(ride, "created_at", None) if date_added and isinstance(date_added, datetime): date_added = date_added.date() @@ -184,9 +170,7 @@ class Command(BaseCommand): return results - def _format_new_content_results( - self, new_items: list[dict[str, Any]] - ) -> list[dict[str, Any]]: + def _format_new_content_results(self, new_items: list[dict[str, Any]]) -> list[dict[str, Any]]: """Format new content results for frontend consumption.""" formatted_results = [] diff --git a/backend/apps/core/management/commands/calculate_trending.py b/backend/apps/core/management/commands/calculate_trending.py index 16f1c8df..952f689c 100644 --- a/backend/apps/core/management/commands/calculate_trending.py +++ b/backend/apps/core/management/commands/calculate_trending.py @@ -37,9 +37,7 @@ class Command(BaseCommand): default=50, help="Maximum number of results to calculate (default: 50)", ) - parser.add_argument( - "--verbose", action="store_true", help="Enable verbose output" - ) + parser.add_argument("--verbose", action="store_true", help="Enable verbose output") def handle(self, *args, **options): content_type = options["content_type"] @@ -98,29 +96,23 @@ class Command(BaseCommand): if verbose: for item in formatted_results[:5]: # Show first 5 items - self.stdout.write( - f" {item['name']} (score: {item.get('views_change', 'N/A')})" - ) + self.stdout.write(f" {item['name']} (score: {item.get('views_change', 'N/A')})") except Exception as e: logger.error(f"Error calculating trending content: {e}", exc_info=True) - raise CommandError(f"Failed to calculate trending content: {e}") + raise CommandError(f"Failed to calculate trending content: {e}") from None def _calculate_trending_parks( self, current_period_hours: int, previous_period_hours: int, limit: int ) -> list[dict[str, Any]]: """Calculate trending scores for parks using real data.""" - parks = Park.objects.filter(status="OPERATING").select_related( - "location", "operator" - ) + parks = Park.objects.filter(status="OPERATING").select_related("location", "operator") trending_parks = [] for park in parks: try: - score = self._calculate_content_score( - park, "park", current_period_hours, previous_period_hours - ) + score = self._calculate_content_score(park, "park", current_period_hours, previous_period_hours) if score > 0: # Only include items with positive trending scores trending_parks.append( { @@ -132,16 +124,8 @@ class Command(BaseCommand): "slug": park.slug, "park": park.name, # For parks, park field is the park name itself "category": "park", - "rating": ( - float(park.average_rating) - if park.average_rating - else 0.0 - ), - "date_opened": ( - park.opening_date.isoformat() - if park.opening_date - else "" - ), + "rating": (float(park.average_rating) if park.average_rating else 0.0), + "date_opened": (park.opening_date.isoformat() if park.opening_date else ""), "url": park.url, } ) @@ -154,17 +138,13 @@ class Command(BaseCommand): self, current_period_hours: int, previous_period_hours: int, limit: int ) -> list[dict[str, Any]]: """Calculate trending scores for rides using real data.""" - rides = Ride.objects.filter(status="OPERATING").select_related( - "park", "park__location" - ) + rides = Ride.objects.filter(status="OPERATING").select_related("park", "park__location") trending_rides = [] for ride in rides: try: - score = self._calculate_content_score( - ride, "ride", current_period_hours, previous_period_hours - ) + score = self._calculate_content_score(ride, "ride", current_period_hours, previous_period_hours) if score > 0: # Only include items with positive trending scores trending_rides.append( { @@ -176,16 +156,8 @@ class Command(BaseCommand): "slug": ride.slug, "park": ride.park.name if ride.park else "", "category": "ride", - "rating": ( - float(ride.average_rating) - if ride.average_rating - else 0.0 - ), - "date_opened": ( - ride.opening_date.isoformat() - if ride.opening_date - else "" - ), + "rating": (float(ride.average_rating) if ride.average_rating else 0.0), + "date_opened": (ride.opening_date.isoformat() if ride.opening_date else ""), "url": ride.url, "park_url": ride.park.url if ride.park else "", } @@ -219,24 +191,15 @@ class Command(BaseCommand): recency_score = self._calculate_recency_score(content_obj) # 4. Popularity Score (10% weight) - popularity_score = self._calculate_popularity_score( - ct, content_obj.id, current_period_hours - ) + popularity_score = self._calculate_popularity_score(ct, content_obj.id, current_period_hours) # Calculate weighted final score - final_score = ( - view_growth_score * 0.4 - + rating_score * 0.3 - + recency_score * 0.2 - + popularity_score * 0.1 - ) + final_score = view_growth_score * 0.4 + rating_score * 0.3 + recency_score * 0.2 + popularity_score * 0.1 return final_score except Exception as e: - logger.error( - f"Error calculating score for {content_type} {content_obj.id}: {e}" - ) + logger.error(f"Error calculating score for {content_type} {content_obj.id}: {e}") return 0.0 def _calculate_view_growth_score( @@ -248,13 +211,11 @@ class Command(BaseCommand): ) -> float: """Calculate normalized view growth score using real PageView data.""" try: - current_views, previous_views, growth_percentage = ( - PageView.get_views_growth( - content_type, - object_id, - current_period_hours, - previous_period_hours, - ) + current_views, previous_views, growth_percentage = PageView.get_views_growth( + content_type, + object_id, + current_period_hours, + previous_period_hours, ) if previous_views == 0: @@ -262,9 +223,7 @@ class Command(BaseCommand): return min(current_views / 100.0, 1.0) if current_views > 0 else 0.0 # Normalize growth percentage to 0-1 scale - normalized_growth = ( - min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0 - ) + normalized_growth = min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0 return max(normalized_growth, 0.0) except Exception as e: @@ -317,14 +276,10 @@ class Command(BaseCommand): logger.warning(f"Error calculating recency score: {e}") return 0.5 - def _calculate_popularity_score( - self, content_type: ContentType, object_id: int, hours: int - ) -> float: + def _calculate_popularity_score(self, content_type: ContentType, object_id: int, hours: int) -> float: """Calculate popularity score based on total view count.""" try: - total_views = PageView.get_total_views_count( - content_type, object_id, hours=hours - ) + total_views = PageView.get_total_views_count(content_type, object_id, hours=hours) # Normalize views to 0-1 scale if total_views == 0: @@ -352,13 +307,11 @@ class Command(BaseCommand): # Get view change for display content_obj = item["content_object"] ct = ContentType.objects.get_for_model(content_obj) - current_views, previous_views, growth_percentage = ( - PageView.get_views_growth( - ct, - content_obj.id, - current_period_hours, - previous_period_hours, - ) + current_views, previous_views, growth_percentage = PageView.get_views_growth( + ct, + content_obj.id, + current_period_hours, + previous_period_hours, ) # Format exactly as frontend expects @@ -371,9 +324,7 @@ class Command(BaseCommand): "rank": rank, "views": current_views, "views_change": ( - f"+{growth_percentage:.1f}%" - if growth_percentage > 0 - else f"{growth_percentage:.1f}%" + f"+{growth_percentage:.1f}%" if growth_percentage > 0 else f"{growth_percentage:.1f}%" ), "slug": item["slug"], "date_opened": item["date_opened"], diff --git a/backend/apps/core/management/commands/clear_cache.py b/backend/apps/core/management/commands/clear_cache.py index 8d1fba7e..cfe5c234 100644 --- a/backend/apps/core/management/commands/clear_cache.py +++ b/backend/apps/core/management/commands/clear_cache.py @@ -21,10 +21,7 @@ from django.core.management.base import BaseCommand class Command(BaseCommand): - help = ( - "Clear all types of cache data including Django cache, " - "__pycache__, and build caches" - ) + help = "Clear all types of cache data including Django cache, " "__pycache__, and build caches" def add_arguments(self, parser): parser.add_argument( @@ -92,9 +89,7 @@ class Command(BaseCommand): ) if self.dry_run: - self.stdout.write( - self.style.WARNING("🔍 DRY RUN MODE - No files will be deleted") - ) + self.stdout.write(self.style.WARNING("🔍 DRY RUN MODE - No files will be deleted")) self.stdout.write("") self.stdout.write(self.style.SUCCESS("🧹 ThrillWiki Cache Clearing Utility")) @@ -129,9 +124,7 @@ class Command(BaseCommand): self.clear_opcache() self.stdout.write("") - self.stdout.write( - self.style.SUCCESS("✅ Cache clearing completed successfully!") - ) + self.stdout.write(self.style.SUCCESS("✅ Cache clearing completed successfully!")) def clear_django_cache(self): """Clear Django cache framework cache.""" @@ -154,23 +147,13 @@ class Command(BaseCommand): if not self.dry_run: cache_backend.clear() - cache_info = ( - f"{alias} cache ({cache_backend.__class__.__name__})" - ) - self.stdout.write( - self.style.SUCCESS(f" ✅ Cleared {cache_info}") - ) + cache_info = f"{alias} cache ({cache_backend.__class__.__name__})" + self.stdout.write(self.style.SUCCESS(f" ✅ Cleared {cache_info}")) except Exception as e: - self.stdout.write( - self.style.WARNING( - f" ⚠️ Could not clear {alias} cache: {e}" - ) - ) + self.stdout.write(self.style.WARNING(f" ⚠️ Could not clear {alias} cache: {e}")) except Exception as e: - self.stdout.write( - self.style.ERROR(f" ❌ Error clearing Django cache: {e}") - ) + self.stdout.write(self.style.ERROR(f" ❌ Error clearing Django cache: {e}")) def clear_pycache(self): """Clear Python __pycache__ directories and .pyc files.""" @@ -188,11 +171,7 @@ class Command(BaseCommand): if pycache_dir.is_dir(): try: # Calculate size before removal - dir_size = sum( - f.stat().st_size - for f in pycache_dir.rglob("*") - if f.is_file() - ) + dir_size = sum(f.stat().st_size for f in pycache_dir.rglob("*") if f.is_file()) removed_size += dir_size if self.verbose: @@ -203,11 +182,7 @@ class Command(BaseCommand): removed_count += 1 except Exception as e: - self.stdout.write( - self.style.WARNING( - f" ⚠️ Could not remove {pycache_dir}: {e}" - ) - ) + self.stdout.write(self.style.WARNING(f" ⚠️ Could not remove {pycache_dir}: {e}")) # Find and remove .pyc files for pyc_file in project_root.rglob("*.pyc"): @@ -223,22 +198,14 @@ class Command(BaseCommand): removed_count += 1 except Exception as e: - self.stdout.write( - self.style.WARNING(f" ⚠️ Could not remove {pyc_file}: {e}") - ) + self.stdout.write(self.style.WARNING(f" ⚠️ Could not remove {pyc_file}: {e}")) # Format file size size_mb = removed_size / (1024 * 1024) - self.stdout.write( - self.style.SUCCESS( - f" ✅ Removed {removed_count} Python cache items ({size_mb:.2f} MB)" - ) - ) + self.stdout.write(self.style.SUCCESS(f" ✅ Removed {removed_count} Python cache items ({size_mb:.2f} MB)")) except Exception as e: - self.stdout.write( - self.style.ERROR(f" ❌ Error clearing Python cache: {e}") - ) + self.stdout.write(self.style.ERROR(f" ❌ Error clearing Python cache: {e}")) def clear_static_cache(self): """Clear static files cache.""" @@ -251,9 +218,7 @@ class Command(BaseCommand): static_path = Path(static_root) # Calculate size - total_size = sum( - f.stat().st_size for f in static_path.rglob("*") if f.is_file() - ) + total_size = sum(f.stat().st_size for f in static_path.rglob("*") if f.is_file()) size_mb = total_size / (1024 * 1024) if self.verbose: @@ -263,22 +228,12 @@ class Command(BaseCommand): shutil.rmtree(static_path) static_path.mkdir(parents=True, exist_ok=True) - self.stdout.write( - self.style.SUCCESS( - f" ✅ Cleared static files cache ({size_mb:.2f} MB)" - ) - ) + self.stdout.write(self.style.SUCCESS(f" ✅ Cleared static files cache ({size_mb:.2f} MB)")) else: - self.stdout.write( - self.style.WARNING( - " ⚠️ No STATIC_ROOT configured or directory doesn't exist" - ) - ) + self.stdout.write(self.style.WARNING(" ⚠️ No STATIC_ROOT configured or directory doesn't exist")) except Exception as e: - self.stdout.write( - self.style.ERROR(f" ❌ Error clearing static cache: {e}") - ) + self.stdout.write(self.style.ERROR(f" ❌ Error clearing static cache: {e}")) def clear_sessions_cache(self): """Clear session cache if using cache-based sessions.""" @@ -289,9 +244,7 @@ class Command(BaseCommand): if "cache" in session_engine: # Using cache-based sessions - session_cache_alias = getattr( - settings, "SESSION_CACHE_ALIAS", "default" - ) + session_cache_alias = getattr(settings, "SESSION_CACHE_ALIAS", "default") session_cache = caches[session_cache_alias] if not self.dry_run: @@ -299,20 +252,12 @@ class Command(BaseCommand): # In production, you might want more sophisticated session clearing session_cache.clear() - self.stdout.write( - self.style.SUCCESS( - f" ✅ Cleared cache-based sessions ({session_cache_alias})" - ) - ) + self.stdout.write(self.style.SUCCESS(f" ✅ Cleared cache-based sessions ({session_cache_alias})")) else: - self.stdout.write( - self.style.WARNING(" ⚠️ Not using cache-based sessions") - ) + self.stdout.write(self.style.WARNING(" ⚠️ Not using cache-based sessions")) except Exception as e: - self.stdout.write( - self.style.ERROR(f" ❌ Error clearing session cache: {e}") - ) + self.stdout.write(self.style.ERROR(f" ❌ Error clearing session cache: {e}")) def clear_template_cache(self): """Clear template cache.""" @@ -332,18 +277,14 @@ class Command(BaseCommand): # Get engine instance safely engine_instance = getattr(engine, "engine", None) if engine_instance: - template_loaders = getattr( - engine_instance, "template_loaders", [] - ) + template_loaders = getattr(engine_instance, "template_loaders", []) for loader in template_loaders: if isinstance(loader, CachedLoader): if not self.dry_run: loader.reset() cleared_engines += 1 if self.verbose: - self.stdout.write( - f" 🗑️ Cleared cached loader: {loader}" - ) + self.stdout.write(f" 🗑️ Cleared cached loader: {loader}") # Check for Jinja2 engines (if present) elif "Jinja2" in engine_backend and hasattr(engine, "env"): @@ -353,34 +294,21 @@ class Command(BaseCommand): env.cache.clear() cleared_engines += 1 if self.verbose: - self.stdout.write( - f" 🗑️ Cleared Jinja2 cache: {engine}" - ) + self.stdout.write(f" 🗑️ Cleared Jinja2 cache: {engine}") except Exception as e: if self.verbose: - self.stdout.write( - self.style.WARNING( - f" ⚠️ Could not clear cache for engine {engine}: {e}" - ) - ) + self.stdout.write(self.style.WARNING(f" ⚠️ Could not clear cache for engine {engine}: {e}")) if cleared_engines > 0: self.stdout.write( - self.style.SUCCESS( - f" ✅ Cleared template cache for " - f"{cleared_engines} loaders/engines" - ) + self.style.SUCCESS(f" ✅ Cleared template cache for " f"{cleared_engines} loaders/engines") ) else: - self.stdout.write( - self.style.WARNING(" ⚠️ No cached template loaders found") - ) + self.stdout.write(self.style.WARNING(" ⚠️ No cached template loaders found")) except Exception as e: - self.stdout.write( - self.style.ERROR(f" ❌ Error clearing template cache: {e}") - ) + self.stdout.write(self.style.ERROR(f" ❌ Error clearing template cache: {e}")) def clear_tailwind_cache(self): """Clear Tailwind CSS build cache.""" @@ -410,27 +338,15 @@ class Command(BaseCommand): cleared_count += 1 except Exception as e: - self.stdout.write( - self.style.WARNING( - f" ⚠️ Could not remove {cache_path}: {e}" - ) - ) + self.stdout.write(self.style.WARNING(f" ⚠️ Could not remove {cache_path}: {e}")) if cleared_count > 0: - self.stdout.write( - self.style.SUCCESS( - f" ✅ Cleared {cleared_count} Tailwind cache directories" - ) - ) + self.stdout.write(self.style.SUCCESS(f" ✅ Cleared {cleared_count} Tailwind cache directories")) else: - self.stdout.write( - self.style.WARNING(" ⚠️ No Tailwind cache directories found") - ) + self.stdout.write(self.style.WARNING(" ⚠️ No Tailwind cache directories found")) except Exception as e: - self.stdout.write( - self.style.ERROR(f" ❌ Error clearing Tailwind cache: {e}") - ) + self.stdout.write(self.style.ERROR(f" ❌ Error clearing Tailwind cache: {e}")) def clear_opcache(self): """Clear PHP OPcache if available.""" @@ -452,21 +368,13 @@ class Command(BaseCommand): if result.returncode == 0: if "cleared" in result.stdout: - self.stdout.write( - self.style.SUCCESS(" ✅ OPcache cleared successfully") - ) + self.stdout.write(self.style.SUCCESS(" ✅ OPcache cleared successfully")) else: self.stdout.write(self.style.WARNING(" ⚠️ OPcache not available")) else: - self.stdout.write( - self.style.WARNING( - " ⚠️ PHP not available or OPcache not accessible" - ) - ) + self.stdout.write(self.style.WARNING(" ⚠️ PHP not available or OPcache not accessible")) except (subprocess.TimeoutExpired, FileNotFoundError): - self.stdout.write( - self.style.WARNING(" ⚠️ PHP not found or not accessible") - ) + self.stdout.write(self.style.WARNING(" ⚠️ PHP not found or not accessible")) except Exception as e: self.stdout.write(self.style.ERROR(f" ❌ Error clearing OPcache: {e}")) diff --git a/backend/apps/core/management/commands/list_transition_callbacks.py b/backend/apps/core/management/commands/list_transition_callbacks.py index 4f93d7e5..968c4569 100644 --- a/backend/apps/core/management/commands/list_transition_callbacks.py +++ b/backend/apps/core/management/commands/list_transition_callbacks.py @@ -15,69 +15,69 @@ from apps.core.state_machine.config import callback_config class Command(BaseCommand): - help = 'List all registered FSM transition callbacks' + help = "List all registered FSM transition callbacks" def add_arguments(self, parser: CommandParser) -> None: parser.add_argument( - '--model', + "--model", type=str, - help='Filter by model name (e.g., EditSubmission, Ride)', + help="Filter by model name (e.g., EditSubmission, Ride)", ) parser.add_argument( - '--stage', + "--stage", type=str, - choices=['pre', 'post', 'error', 'all'], - default='all', - help='Filter by callback stage', + choices=["pre", "post", "error", "all"], + default="all", + help="Filter by callback stage", ) parser.add_argument( - '--verbose', - '-v', - action='store_true', - help='Show detailed callback information', + "--verbose", + "-v", + action="store_true", + help="Show detailed callback information", ) parser.add_argument( - '--format', + "--format", type=str, - choices=['text', 'table', 'json'], - default='text', - help='Output format', + choices=["text", "table", "json"], + default="text", + help="Output format", ) def handle(self, *args, **options): - model_filter = options.get('model') - stage_filter = options.get('stage') - verbose = options.get('verbose', False) - output_format = options.get('format', 'text') + model_filter = options.get("model") + stage_filter = options.get("stage") + verbose = options.get("verbose", False) + output_format = options.get("format", "text") # Get all registrations all_registrations = callback_registry.get_all_registrations() - if output_format == 'json': + if output_format == "json": self._output_json(all_registrations, model_filter, stage_filter) - elif output_format == 'table': + elif output_format == "table": self._output_table(all_registrations, model_filter, stage_filter, verbose) else: self._output_text(all_registrations, model_filter, stage_filter, verbose) def _output_text(self, registrations, model_filter, stage_filter, verbose): """Output in text format.""" - self.stdout.write(self.style.SUCCESS('\n=== FSM Transition Callbacks ===\n')) + self.stdout.write(self.style.SUCCESS("\n=== FSM Transition Callbacks ===\n")) # Group by model models_seen = set() total_callbacks = 0 for stage in CallbackStage: - if stage_filter != 'all' and stage.value != stage_filter: + if stage_filter != "all" and stage.value != stage_filter: continue stage_regs = registrations.get(stage, []) if not stage_regs: continue - self.stdout.write(self.style.WARNING(f'\n{stage.value.upper()} Callbacks:')) - self.stdout.write('-' * 50) + self.stdout.write(self.style.WARNING(f"\n{stage.value.upper()} Callbacks:")) + self.stdout.write("-" * 50) # Group by model by_model = {} @@ -92,42 +92,34 @@ class Command(BaseCommand): total_callbacks += 1 for model_name, regs in sorted(by_model.items()): - self.stdout.write(f'\n {model_name}:') + self.stdout.write(f"\n {model_name}:") for reg in regs: - transition = f'{reg.source} → {reg.target}' + transition = f"{reg.source} → {reg.target}" callback_name = reg.callback.name priority = reg.callback.priority - self.stdout.write( - f' [{transition}] {callback_name} (priority: {priority})' - ) + self.stdout.write(f" [{transition}] {callback_name} (priority: {priority})") if verbose: - self.stdout.write( - f' continue_on_error: {reg.callback.continue_on_error}' - ) - if hasattr(reg.callback, 'patterns'): - self.stdout.write( - f' patterns: {reg.callback.patterns}' - ) + self.stdout.write(f" continue_on_error: {reg.callback.continue_on_error}") + if hasattr(reg.callback, "patterns"): + self.stdout.write(f" patterns: {reg.callback.patterns}") # Summary - self.stdout.write('\n' + '=' * 50) - self.stdout.write(self.style.SUCCESS( - f'Total: {total_callbacks} callbacks across {len(models_seen)} models' - )) + self.stdout.write("\n" + "=" * 50) + self.stdout.write(self.style.SUCCESS(f"Total: {total_callbacks} callbacks across {len(models_seen)} models")) # Configuration status - self.stdout.write(self.style.WARNING('\nConfiguration Status:')) - self.stdout.write(f' Callbacks enabled: {callback_config.enabled}') - self.stdout.write(f' Notifications enabled: {callback_config.notifications_enabled}') - self.stdout.write(f' Cache invalidation enabled: {callback_config.cache_invalidation_enabled}') - self.stdout.write(f' Related updates enabled: {callback_config.related_updates_enabled}') - self.stdout.write(f' Debug mode: {callback_config.debug_mode}') + self.stdout.write(self.style.WARNING("\nConfiguration Status:")) + self.stdout.write(f" Callbacks enabled: {callback_config.enabled}") + self.stdout.write(f" Notifications enabled: {callback_config.notifications_enabled}") + self.stdout.write(f" Cache invalidation enabled: {callback_config.cache_invalidation_enabled}") + self.stdout.write(f" Related updates enabled: {callback_config.related_updates_enabled}") + self.stdout.write(f" Debug mode: {callback_config.debug_mode}") def _output_table(self, registrations, model_filter, stage_filter, verbose): """Output in table format.""" - self.stdout.write(self.style.SUCCESS('\n=== FSM Transition Callbacks ===\n')) + self.stdout.write(self.style.SUCCESS("\n=== FSM Transition Callbacks ===\n")) # Header if verbose: @@ -136,10 +128,10 @@ class Command(BaseCommand): header = f"{'Model':<20} {'Source':<15} {'Target':<15} {'Stage':<8} {'Callback':<30}" self.stdout.write(self.style.WARNING(header)) - self.stdout.write('-' * len(header)) + self.stdout.write("-" * len(header)) for stage in CallbackStage: - if stage_filter != 'all' and stage.value != stage_filter: + if stage_filter != "all" and stage.value != stage_filter: continue stage_regs = registrations.get(stage, []) @@ -160,18 +152,18 @@ class Command(BaseCommand): import json output = { - 'callbacks': [], - 'configuration': { - 'enabled': callback_config.enabled, - 'notifications_enabled': callback_config.notifications_enabled, - 'cache_invalidation_enabled': callback_config.cache_invalidation_enabled, - 'related_updates_enabled': callback_config.related_updates_enabled, - 'debug_mode': callback_config.debug_mode, - } + "callbacks": [], + "configuration": { + "enabled": callback_config.enabled, + "notifications_enabled": callback_config.notifications_enabled, + "cache_invalidation_enabled": callback_config.cache_invalidation_enabled, + "related_updates_enabled": callback_config.related_updates_enabled, + "debug_mode": callback_config.debug_mode, + }, } for stage in CallbackStage: - if stage_filter != 'all' and stage.value != stage_filter: + if stage_filter != "all" and stage.value != stage_filter: continue stage_regs = registrations.get(stage, []) @@ -180,15 +172,17 @@ class Command(BaseCommand): if model_filter and model_name != model_filter: continue - output['callbacks'].append({ - 'model': model_name, - 'field': reg.field_name, - 'source': reg.source, - 'target': reg.target, - 'stage': stage.value, - 'callback': reg.callback.name, - 'priority': reg.callback.priority, - 'continue_on_error': reg.callback.continue_on_error, - }) + output["callbacks"].append( + { + "model": model_name, + "field": reg.field_name, + "source": reg.source, + "target": reg.target, + "stage": stage.value, + "callback": reg.callback.name, + "priority": reg.callback.priority, + "continue_on_error": reg.callback.continue_on_error, + } + ) self.stdout.write(json.dumps(output, indent=2)) diff --git a/backend/apps/core/management/commands/optimize_static.py b/backend/apps/core/management/commands/optimize_static.py index 5efd05cb..df184a41 100644 --- a/backend/apps/core/management/commands/optimize_static.py +++ b/backend/apps/core/management/commands/optimize_static.py @@ -52,26 +52,17 @@ class Command(BaseCommand): import rjsmin except ImportError: rjsmin = None - self.stdout.write( - self.style.WARNING( - "rjsmin not installed. Install with: pip install rjsmin" - ) - ) + self.stdout.write(self.style.WARNING("rjsmin not installed. Install with: pip install rjsmin")) try: import rcssmin except ImportError: rcssmin = None - self.stdout.write( - self.style.WARNING( - "rcssmin not installed. Install with: pip install rcssmin" - ) - ) + self.stdout.write(self.style.WARNING("rcssmin not installed. Install with: pip install rcssmin")) if not rjsmin and not rcssmin: raise CommandError( - "Neither rjsmin nor rcssmin is installed. " - "Install at least one: pip install rjsmin rcssmin" + "Neither rjsmin nor rcssmin is installed. " "Install at least one: pip install rjsmin rcssmin" ) # Get static file directories @@ -93,9 +84,7 @@ class Command(BaseCommand): if not css_only and rjsmin: js_dir = static_dir / "js" if js_dir.exists(): - saved, count = self._process_js_files( - js_dir, rjsmin, dry_run, force - ) + saved, count = self._process_js_files(js_dir, rjsmin, dry_run, force) total_js_saved += saved js_files_processed += count @@ -103,9 +92,7 @@ class Command(BaseCommand): if not js_only and rcssmin: css_dir = static_dir / "css" if css_dir.exists(): - saved, count = self._process_css_files( - css_dir, rcssmin, dry_run, force - ) + saved, count = self._process_css_files(css_dir, rcssmin, dry_run, force) total_css_saved += saved css_files_processed += count @@ -114,17 +101,11 @@ class Command(BaseCommand): self.stdout.write(self.style.SUCCESS("Static file optimization complete!")) self.stdout.write(f"JavaScript files processed: {js_files_processed}") self.stdout.write(f"CSS files processed: {css_files_processed}") - self.stdout.write( - f"Total JS savings: {self._format_size(total_js_saved)}" - ) - self.stdout.write( - f"Total CSS savings: {self._format_size(total_css_saved)}" - ) + self.stdout.write(f"Total JS savings: {self._format_size(total_js_saved)}") + self.stdout.write(f"Total CSS savings: {self._format_size(total_css_saved)}") if dry_run: - self.stdout.write( - self.style.WARNING("\nDry run - no files were modified") - ) + self.stdout.write(self.style.WARNING("\nDry run - no files were modified")) def _process_js_files(self, js_dir, rjsmin, dry_run, force): """Process JavaScript files for minification.""" @@ -140,9 +121,7 @@ class Command(BaseCommand): # Skip if minified version exists and not forcing if min_file.exists() and not force: - self.stdout.write( - f" Skipping {js_file.name} (min version exists)" - ) + self.stdout.write(f" Skipping {js_file.name} (min version exists)") continue try: @@ -169,9 +148,7 @@ class Command(BaseCommand): files_processed += 1 except Exception as e: - self.stdout.write( - self.style.ERROR(f" Error processing {js_file.name}: {e}") - ) + self.stdout.write(self.style.ERROR(f" Error processing {js_file.name}: {e}")) return total_saved, files_processed @@ -189,9 +166,7 @@ class Command(BaseCommand): # Skip if minified version exists and not forcing if min_file.exists() and not force: - self.stdout.write( - f" Skipping {css_file.name} (min version exists)" - ) + self.stdout.write(f" Skipping {css_file.name} (min version exists)") continue try: @@ -218,9 +193,7 @@ class Command(BaseCommand): files_processed += 1 except Exception as e: - self.stdout.write( - self.style.ERROR(f" Error processing {css_file.name}: {e}") - ) + self.stdout.write(self.style.ERROR(f" Error processing {css_file.name}: {e}")) return total_saved, files_processed diff --git a/backend/apps/core/management/commands/rundev.py b/backend/apps/core/management/commands/rundev.py index 83c7ba8a..64231d75 100644 --- a/backend/apps/core/management/commands/rundev.py +++ b/backend/apps/core/management/commands/rundev.py @@ -39,19 +39,13 @@ class Command(BaseCommand): def handle(self, *args, **options): """Run the development setup and start the server.""" if not options["skip_setup"]: - self.stdout.write( - self.style.SUCCESS( - "🚀 Setting up and starting ThrillWiki Development Server..." - ) - ) + self.stdout.write(self.style.SUCCESS("🚀 Setting up and starting ThrillWiki Development Server...")) # Run the setup_dev command first execute_from_command_line(["manage.py", "setup_dev"]) else: - self.stdout.write( - self.style.SUCCESS("🚀 Starting ThrillWiki Development Server...") - ) + self.stdout.write(self.style.SUCCESS("🚀 Starting ThrillWiki Development Server...")) # Determine which server command to use self.get_server_command(options) @@ -59,9 +53,7 @@ class Command(BaseCommand): # Start the server self.stdout.write("") self.stdout.write( - self.style.SUCCESS( - f"🌟 Starting Django development server on http://{options['host']}:{options['port']}" - ) + self.style.SUCCESS(f"🌟 Starting Django development server on http://{options['host']}:{options['port']}") ) self.stdout.write("Press Ctrl+C to stop the server") self.stdout.write("") @@ -76,9 +68,7 @@ class Command(BaseCommand): ] ) else: - execute_from_command_line( - ["manage.py", "runserver", f"{options['host']}:{options['port']}"] - ) + execute_from_command_line(["manage.py", "runserver", f"{options['host']}:{options['port']}"]) except KeyboardInterrupt: self.stdout.write("") self.stdout.write(self.style.SUCCESS("👋 Development server stopped")) diff --git a/backend/apps/core/management/commands/security_audit.py b/backend/apps/core/management/commands/security_audit.py index d54ac8ae..ce1613ff 100644 --- a/backend/apps/core/management/commands/security_audit.py +++ b/backend/apps/core/management/commands/security_audit.py @@ -16,23 +16,23 @@ from django.core.management.base import BaseCommand class Command(BaseCommand): - help = 'Run security audit and generate a report' + help = "Run security audit and generate a report" def add_arguments(self, parser): parser.add_argument( - '--output', + "--output", type=str, - help='Output file for the security report', + help="Output file for the security report", ) parser.add_argument( - '--verbose', - action='store_true', - help='Show detailed information for each check', + "--verbose", + action="store_true", + help="Show detailed information for each check", ) def handle(self, *args, **options): - self.verbose = options.get('verbose', False) - output_file = options.get('output') + self.verbose = options.get("verbose", False) + output_file = options.get("output") report_lines = [] @@ -66,11 +66,9 @@ class Command(BaseCommand): # Write to file if specified if output_file: - with open(output_file, 'w') as f: - f.write('\n'.join(report_lines)) - self.stdout.write( - self.style.SUCCESS(f'\nReport saved to: {output_file}') - ) + with open(output_file, "w") as f: + f.write("\n".join(report_lines)) + self.stdout.write(self.style.SUCCESS(f"\nReport saved to: {output_file}")) def log(self, message, report_lines): """Log message to both stdout and report.""" @@ -82,10 +80,7 @@ class Command(BaseCommand): errors = registry.run_checks(tags=[Tags.security]) if not errors: - self.log( - self.style.SUCCESS(" ✓ All Django security checks passed"), - report_lines - ) + self.log(self.style.SUCCESS(" ✓ All Django security checks passed"), report_lines) else: for error in errors: prefix = self.style.ERROR(" ✗ ERROR") if error.is_serious() else self.style.WARNING(" ! WARNING") @@ -97,73 +92,71 @@ class Command(BaseCommand): def check_configuration(self, report_lines): """Check various configuration settings.""" checks = [ - ('DEBUG mode', not settings.DEBUG, 'DEBUG should be False'), + ("DEBUG mode", not settings.DEBUG, "DEBUG should be False"), + ("SECRET_KEY length", len(settings.SECRET_KEY) >= 50, f"Length: {len(settings.SECRET_KEY)}"), ( - 'SECRET_KEY length', - len(settings.SECRET_KEY) >= 50, - f'Length: {len(settings.SECRET_KEY)}' + "ALLOWED_HOSTS", + bool(settings.ALLOWED_HOSTS) and "*" not in settings.ALLOWED_HOSTS, + str(settings.ALLOWED_HOSTS), ), ( - 'ALLOWED_HOSTS', - bool(settings.ALLOWED_HOSTS) and '*' not in settings.ALLOWED_HOSTS, - str(settings.ALLOWED_HOSTS) + "CSRF_TRUSTED_ORIGINS", + bool(getattr(settings, "CSRF_TRUSTED_ORIGINS", [])), + str(getattr(settings, "CSRF_TRUSTED_ORIGINS", [])), ), ( - 'CSRF_TRUSTED_ORIGINS', - bool(getattr(settings, 'CSRF_TRUSTED_ORIGINS', [])), - str(getattr(settings, 'CSRF_TRUSTED_ORIGINS', [])) + "X_FRAME_OPTIONS", + getattr(settings, "X_FRAME_OPTIONS", "") in ("DENY", "SAMEORIGIN"), + str(getattr(settings, "X_FRAME_OPTIONS", "Not set")), ), ( - 'X_FRAME_OPTIONS', - getattr(settings, 'X_FRAME_OPTIONS', '') in ('DENY', 'SAMEORIGIN'), - str(getattr(settings, 'X_FRAME_OPTIONS', 'Not set')) + "SECURE_CONTENT_TYPE_NOSNIFF", + getattr(settings, "SECURE_CONTENT_TYPE_NOSNIFF", False), + str(getattr(settings, "SECURE_CONTENT_TYPE_NOSNIFF", False)), ), ( - 'SECURE_CONTENT_TYPE_NOSNIFF', - getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False), - str(getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False)) + "SECURE_BROWSER_XSS_FILTER", + getattr(settings, "SECURE_BROWSER_XSS_FILTER", False), + str(getattr(settings, "SECURE_BROWSER_XSS_FILTER", False)), ), ( - 'SECURE_BROWSER_XSS_FILTER', - getattr(settings, 'SECURE_BROWSER_XSS_FILTER', False), - str(getattr(settings, 'SECURE_BROWSER_XSS_FILTER', False)) + "SESSION_COOKIE_HTTPONLY", + getattr(settings, "SESSION_COOKIE_HTTPONLY", True), + str(getattr(settings, "SESSION_COOKIE_HTTPONLY", "Not set")), ), ( - 'SESSION_COOKIE_HTTPONLY', - getattr(settings, 'SESSION_COOKIE_HTTPONLY', True), - str(getattr(settings, 'SESSION_COOKIE_HTTPONLY', 'Not set')) - ), - ( - 'CSRF_COOKIE_HTTPONLY', - getattr(settings, 'CSRF_COOKIE_HTTPONLY', True), - str(getattr(settings, 'CSRF_COOKIE_HTTPONLY', 'Not set')) + "CSRF_COOKIE_HTTPONLY", + getattr(settings, "CSRF_COOKIE_HTTPONLY", True), + str(getattr(settings, "CSRF_COOKIE_HTTPONLY", "Not set")), ), ] # Production-only checks if not settings.DEBUG: - checks.extend([ - ( - 'SECURE_SSL_REDIRECT', - getattr(settings, 'SECURE_SSL_REDIRECT', False), - str(getattr(settings, 'SECURE_SSL_REDIRECT', False)) - ), - ( - 'SESSION_COOKIE_SECURE', - getattr(settings, 'SESSION_COOKIE_SECURE', False), - str(getattr(settings, 'SESSION_COOKIE_SECURE', False)) - ), - ( - 'CSRF_COOKIE_SECURE', - getattr(settings, 'CSRF_COOKIE_SECURE', False), - str(getattr(settings, 'CSRF_COOKIE_SECURE', False)) - ), - ( - 'SECURE_HSTS_SECONDS', - getattr(settings, 'SECURE_HSTS_SECONDS', 0) >= 31536000, - str(getattr(settings, 'SECURE_HSTS_SECONDS', 0)) - ), - ]) + checks.extend( + [ + ( + "SECURE_SSL_REDIRECT", + getattr(settings, "SECURE_SSL_REDIRECT", False), + str(getattr(settings, "SECURE_SSL_REDIRECT", False)), + ), + ( + "SESSION_COOKIE_SECURE", + getattr(settings, "SESSION_COOKIE_SECURE", False), + str(getattr(settings, "SESSION_COOKIE_SECURE", False)), + ), + ( + "CSRF_COOKIE_SECURE", + getattr(settings, "CSRF_COOKIE_SECURE", False), + str(getattr(settings, "CSRF_COOKIE_SECURE", False)), + ), + ( + "SECURE_HSTS_SECONDS", + getattr(settings, "SECURE_HSTS_SECONDS", 0) >= 31536000, + str(getattr(settings, "SECURE_HSTS_SECONDS", 0)), + ), + ] + ) for name, is_secure, value in checks: status = self.style.SUCCESS("✓") if is_secure else self.style.WARNING("!") @@ -176,59 +169,43 @@ class Command(BaseCommand): def check_middleware(self, report_lines): """Check security-related middleware is properly configured.""" - middleware = getattr(settings, 'MIDDLEWARE', []) + middleware = getattr(settings, "MIDDLEWARE", []) required_middleware = [ - ('django.middleware.security.SecurityMiddleware', 'SecurityMiddleware'), - ('django.middleware.csrf.CsrfViewMiddleware', 'CSRF Middleware'), - ('django.middleware.clickjacking.XFrameOptionsMiddleware', 'X-Frame-Options'), + ("django.middleware.security.SecurityMiddleware", "SecurityMiddleware"), + ("django.middleware.csrf.CsrfViewMiddleware", "CSRF Middleware"), + ("django.middleware.clickjacking.XFrameOptionsMiddleware", "X-Frame-Options"), ] custom_security_middleware = [ - ('apps.core.middleware.security_headers.SecurityHeadersMiddleware', 'Security Headers'), - ('apps.core.middleware.rate_limiting.AuthRateLimitMiddleware', 'Rate Limiting'), + ("apps.core.middleware.security_headers.SecurityHeadersMiddleware", "Security Headers"), + ("apps.core.middleware.rate_limiting.AuthRateLimitMiddleware", "Rate Limiting"), ] # Check required middleware for mw_path, mw_name in required_middleware: if mw_path in middleware: - self.log( - f" {self.style.SUCCESS('✓')} {mw_name} is enabled", - report_lines - ) + self.log(f" {self.style.SUCCESS('✓')} {mw_name} is enabled", report_lines) else: - self.log( - f" {self.style.ERROR('✗')} {mw_name} is NOT enabled", - report_lines - ) + self.log(f" {self.style.ERROR('✗')} {mw_name} is NOT enabled", report_lines) # Check custom security middleware for mw_path, mw_name in custom_security_middleware: if mw_path in middleware: - self.log( - f" {self.style.SUCCESS('✓')} {mw_name} is enabled", - report_lines - ) + self.log(f" {self.style.SUCCESS('✓')} {mw_name} is enabled", report_lines) else: - self.log( - f" {self.style.WARNING('!')} {mw_name} is not enabled (optional)", - report_lines - ) + self.log(f" {self.style.WARNING('!')} {mw_name} is not enabled (optional)", report_lines) # Check middleware order try: - security_idx = middleware.index('django.middleware.security.SecurityMiddleware') - session_idx = middleware.index('django.contrib.sessions.middleware.SessionMiddleware') + security_idx = middleware.index("django.middleware.security.SecurityMiddleware") + session_idx = middleware.index("django.contrib.sessions.middleware.SessionMiddleware") if security_idx < session_idx: - self.log( - f" {self.style.SUCCESS('✓')} Middleware ordering is correct", - report_lines - ) + self.log(f" {self.style.SUCCESS('✓')} Middleware ordering is correct", report_lines) else: self.log( - f" {self.style.WARNING('!')} SecurityMiddleware should come before SessionMiddleware", - report_lines + f" {self.style.WARNING('!')} SecurityMiddleware should come before SessionMiddleware", report_lines ) except ValueError: pass # Middleware not found, already reported above diff --git a/backend/apps/core/management/commands/setup_dev.py b/backend/apps/core/management/commands/setup_dev.py index c85277cc..46669562 100644 --- a/backend/apps/core/management/commands/setup_dev.py +++ b/backend/apps/core/management/commands/setup_dev.py @@ -39,9 +39,7 @@ class Command(BaseCommand): def handle(self, *args, **options): """Run the development setup process.""" - self.stdout.write( - self.style.SUCCESS("🚀 Setting up ThrillWiki Development Environment...") - ) + self.stdout.write(self.style.SUCCESS("🚀 Setting up ThrillWiki Development Environment...")) # Create necessary directories self.create_directories() @@ -71,9 +69,7 @@ class Command(BaseCommand): # Display environment info self.display_environment_info() - self.stdout.write( - self.style.SUCCESS("✅ Development environment setup complete!") - ) + self.stdout.write(self.style.SUCCESS("✅ Development environment setup complete!")) def create_directories(self): """Create necessary directories.""" @@ -99,36 +95,24 @@ class Command(BaseCommand): ) if result.returncode == 0: - self.stdout.write( - self.style.SUCCESS("✅ Database migrations are up to date") - ) + self.stdout.write(self.style.SUCCESS("✅ Database migrations are up to date")) else: self.stdout.write("🔄 Running database migrations...") - subprocess.run( - ["uv", "run", "manage.py", "migrate", "--noinput"], check=True - ) - self.stdout.write( - self.style.SUCCESS("✅ Database migrations completed") - ) + subprocess.run(["uv", "run", "manage.py", "migrate", "--noinput"], check=True) + self.stdout.write(self.style.SUCCESS("✅ Database migrations completed")) except subprocess.CalledProcessError as e: - self.stdout.write( - self.style.WARNING(f"⚠️ Migration error (continuing): {e}") - ) + self.stdout.write(self.style.WARNING(f"⚠️ Migration error (continuing): {e}")) def seed_sample_data(self): """Seed sample data to the database.""" self.stdout.write("🌱 Seeding sample data...") try: - subprocess.run( - ["uv", "run", "manage.py", "seed_sample_data"], check=True - ) + subprocess.run(["uv", "run", "manage.py", "seed_sample_data"], check=True) self.stdout.write(self.style.SUCCESS("✅ Sample data seeded")) except subprocess.CalledProcessError: - self.stdout.write( - self.style.WARNING("⚠️ Could not seed sample data (continuing)") - ) + self.stdout.write(self.style.WARNING("⚠️ Could not seed sample data (continuing)")) def create_superuser(self): """Create development superuser if it doesn't exist.""" @@ -145,13 +129,9 @@ class Command(BaseCommand): self.stdout.write("👤 Creating development superuser (admin/admin)...") if not User.objects.filter(username="admin").exists(): User.objects.create_superuser("admin", "admin@example.com", "admin") - self.stdout.write( - self.style.SUCCESS("✅ Created superuser: admin/admin") - ) + self.stdout.write(self.style.SUCCESS("✅ Created superuser: admin/admin")) else: - self.stdout.write( - self.style.SUCCESS("✅ Admin user already exists") - ) + self.stdout.write(self.style.SUCCESS("✅ Admin user already exists")) except Exception as e: self.stdout.write(self.style.WARNING(f"⚠️ Could not create superuser: {e}")) @@ -167,9 +147,7 @@ class Command(BaseCommand): ) self.stdout.write(self.style.SUCCESS("✅ Static files collected")) except subprocess.CalledProcessError as e: - self.stdout.write( - self.style.WARNING(f"⚠️ Could not collect static files: {e}") - ) + self.stdout.write(self.style.WARNING(f"⚠️ Could not collect static files: {e}")) def build_tailwind(self): """Build Tailwind CSS if npm is available.""" @@ -180,17 +158,11 @@ class Command(BaseCommand): subprocess.run(["npm", "--version"], capture_output=True, check=True) # Build Tailwind CSS - subprocess.run( - ["uv", "run", "manage.py", "tailwind", "build"], check=True - ) + subprocess.run(["uv", "run", "manage.py", "tailwind", "build"], check=True) self.stdout.write(self.style.SUCCESS("✅ Tailwind CSS built")) except (subprocess.CalledProcessError, FileNotFoundError): - self.stdout.write( - self.style.WARNING( - "⚠️ npm not found or Tailwind build failed, skipping" - ) - ) + self.stdout.write(self.style.WARNING("⚠️ npm not found or Tailwind build failed, skipping")) def run_system_checks(self): """Run Django system checks.""" @@ -200,9 +172,7 @@ class Command(BaseCommand): subprocess.run(["uv", "run", "manage.py", "check"], check=True) self.stdout.write(self.style.SUCCESS("✅ System checks passed")) except subprocess.CalledProcessError: - self.stdout.write( - self.style.WARNING("❌ System checks failed, but continuing...") - ) + self.stdout.write(self.style.WARNING("❌ System checks failed, but continuing...")) def display_environment_info(self): """Display development environment information.""" diff --git a/backend/apps/core/management/commands/test_transition_callbacks.py b/backend/apps/core/management/commands/test_transition_callbacks.py index 081db37b..d4e05095 100644 --- a/backend/apps/core/management/commands/test_transition_callbacks.py +++ b/backend/apps/core/management/commands/test_transition_callbacks.py @@ -18,62 +18,62 @@ from apps.core.state_machine.monitoring import callback_monitor class Command(BaseCommand): - help = 'Test FSM transition callbacks for specific transitions' + help = "Test FSM transition callbacks for specific transitions" def add_arguments(self, parser: CommandParser) -> None: parser.add_argument( - 'model', + "model", type=str, - help='Model name (e.g., EditSubmission, Ride, Park)', + help="Model name (e.g., EditSubmission, Ride, Park)", ) parser.add_argument( - 'source', + "source", type=str, - help='Source state value', + help="Source state value", ) parser.add_argument( - 'target', + "target", type=str, - help='Target state value', + help="Target state value", ) parser.add_argument( - '--instance-id', + "--instance-id", type=int, - help='ID of an existing instance to use for testing', + help="ID of an existing instance to use for testing", ) parser.add_argument( - '--user-id', + "--user-id", type=int, - help='ID of user to use for testing', + help="ID of user to use for testing", ) parser.add_argument( - '--dry-run', - action='store_true', - help='Show what would be executed without running callbacks', + "--dry-run", + action="store_true", + help="Show what would be executed without running callbacks", ) parser.add_argument( - '--stage', + "--stage", type=str, - choices=['pre', 'post', 'error', 'all'], - default='all', - help='Which callback stage to test', + choices=["pre", "post", "error", "all"], + default="all", + help="Which callback stage to test", ) parser.add_argument( - '--field', + "--field", type=str, - default='status', - help='FSM field name (default: status)', + default="status", + help="FSM field name (default: status)", ) def handle(self, *args, **options): - model_name = options['model'] - source = options['source'] - target = options['target'] - instance_id = options.get('instance_id') - user_id = options.get('user_id') - dry_run = options.get('dry_run', False) - stage_filter = options.get('stage', 'all') - field_name = options.get('field', 'status') + model_name = options["model"] + source = options["source"] + target = options["target"] + instance_id = options.get("instance_id") + user_id = options.get("user_id") + dry_run = options.get("dry_run", False) + stage_filter = options.get("stage", "all") + field_name = options.get("field", "status") # Find the model class model_class = self._find_model(model_name) @@ -90,7 +90,7 @@ class Command(BaseCommand): try: user = User.objects.get(pk=user_id) except User.DoesNotExist: - raise CommandError(f"User with ID {user_id} not found") + raise CommandError(f"User with ID {user_id} not found") from None # Create transition context context = TransitionContext( @@ -101,19 +101,21 @@ class Command(BaseCommand): user=user, ) - self.stdout.write(self.style.SUCCESS( - f'\n=== Testing Transition Callbacks ===\n' - f'Model: {model_name}\n' - f'Transition: {source} → {target}\n' - f'Field: {field_name}\n' - f'Instance: {instance}\n' - f'User: {user}\n' - f'Dry Run: {dry_run}\n' - )) + self.stdout.write( + self.style.SUCCESS( + f"\n=== Testing Transition Callbacks ===\n" + f"Model: {model_name}\n" + f"Transition: {source} → {target}\n" + f"Field: {field_name}\n" + f"Instance: {instance}\n" + f"User: {user}\n" + f"Dry Run: {dry_run}\n" + ) + ) # Get callbacks for each stage stages_to_test = [] - if stage_filter == 'all': + if stage_filter == "all": stages_to_test = [CallbackStage.PRE, CallbackStage.POST, CallbackStage.ERROR] else: stages_to_test = [CallbackStage(stage_filter)] @@ -123,83 +125,69 @@ class Command(BaseCommand): total_failures = 0 for stage in stages_to_test: - callbacks = callback_registry.get_callbacks( - model_class, field_name, source, target, stage - ) + callbacks = callback_registry.get_callbacks(model_class, field_name, source, target, stage) if not callbacks: - self.stdout.write( - self.style.WARNING(f'\nNo {stage.value.upper()} callbacks registered') - ) + self.stdout.write(self.style.WARNING(f"\nNo {stage.value.upper()} callbacks registered")) continue - self.stdout.write( - self.style.WARNING(f'\n{stage.value.upper()} Callbacks ({len(callbacks)}):') - ) - self.stdout.write('-' * 50) + self.stdout.write(self.style.WARNING(f"\n{stage.value.upper()} Callbacks ({len(callbacks)}):")) + self.stdout.write("-" * 50) for callback in callbacks: total_callbacks += 1 callback_info = ( - f' {callback.name} (priority: {callback.priority}, ' - f'continue_on_error: {callback.continue_on_error})' + f" {callback.name} (priority: {callback.priority}, " + f"continue_on_error: {callback.continue_on_error})" ) if dry_run: self.stdout.write(callback_info) - self.stdout.write(self.style.NOTICE(' → Would execute (dry run)')) + self.stdout.write(self.style.NOTICE(" → Would execute (dry run)")) else: self.stdout.write(callback_info) # Check should_execute if not callback.should_execute(context): - self.stdout.write( - self.style.WARNING(' → Skipped (should_execute returned False)') - ) + self.stdout.write(self.style.WARNING(" → Skipped (should_execute returned False)")) continue # Execute callback try: if stage == CallbackStage.ERROR: - result = callback.execute( - context, - exception=Exception("Test exception") - ) + result = callback.execute(context, exception=Exception("Test exception")) else: result = callback.execute(context) if result: - self.stdout.write(self.style.SUCCESS(' → Success')) + self.stdout.write(self.style.SUCCESS(" → Success")) total_success += 1 else: - self.stdout.write(self.style.ERROR(' → Failed (returned False)')) + self.stdout.write(self.style.ERROR(" → Failed (returned False)")) total_failures += 1 except Exception as e: - self.stdout.write( - self.style.ERROR(f' → Exception: {type(e).__name__}: {e}') - ) + self.stdout.write(self.style.ERROR(f" → Exception: {type(e).__name__}: {e}")) total_failures += 1 # Summary - self.stdout.write('\n' + '=' * 50) - self.stdout.write(self.style.SUCCESS(f'Total callbacks: {total_callbacks}')) + self.stdout.write("\n" + "=" * 50) + self.stdout.write(self.style.SUCCESS(f"Total callbacks: {total_callbacks}")) if not dry_run: - self.stdout.write(self.style.SUCCESS(f'Successful: {total_success}')) + self.stdout.write(self.style.SUCCESS(f"Successful: {total_success}")) self.stdout.write( - self.style.ERROR(f'Failed: {total_failures}') if total_failures - else self.style.SUCCESS(f'Failed: {total_failures}') + self.style.ERROR(f"Failed: {total_failures}") + if total_failures + else self.style.SUCCESS(f"Failed: {total_failures}") ) # Show monitoring stats if available if not dry_run: - self.stdout.write(self.style.WARNING('\nRecent Executions:')) + self.stdout.write(self.style.WARNING("\nRecent Executions:")) recent = callback_monitor.get_recent_executions(limit=10) for record in recent: - status = '✓' if record.success else '✗' - self.stdout.write( - f' {status} {record.callback_name} [{record.duration_ms:.2f}ms]' - ) + status = "✓" if record.success else "✗" + self.stdout.write(f" {status} {record.callback_name} [{record.duration_ms:.2f}ms]") def _find_model(self, model_name): """Find a model class by name.""" @@ -217,9 +205,7 @@ class Command(BaseCommand): try: return model_class.objects.get(pk=instance_id) except model_class.DoesNotExist: - raise CommandError( - f"{model_class.__name__} with ID {instance_id} not found" - ) + raise CommandError(f"{model_class.__name__} with ID {instance_id} not found") from None # Create a mock instance for testing # This won't be saved to the database @@ -227,8 +213,6 @@ class Command(BaseCommand): instance.pk = 0 # Fake ID setattr(instance, field_name, source) - self.stdout.write(self.style.NOTICE( - 'Using mock instance (no --instance-id provided)' - )) + self.stdout.write(self.style.NOTICE("Using mock instance (no --instance-id provided)")) return instance diff --git a/backend/apps/core/management/commands/test_trending.py b/backend/apps/core/management/commands/test_trending.py index 588506b7..352163ba 100644 --- a/backend/apps/core/management/commands/test_trending.py +++ b/backend/apps/core/management/commands/test_trending.py @@ -37,9 +37,7 @@ class Command(BaseCommand): self.test_trending_algorithm() self.test_api_format() - self.stdout.write( - self.style.SUCCESS("✓ Trending system test completed successfully!") - ) + self.stdout.write(self.style.SUCCESS("✓ Trending system test completed successfully!")) def clean_test_data(self): """Clean existing test data.""" @@ -101,9 +99,7 @@ class Command(BaseCommand): # Create parks parks = [] for park_data in parks_data: - park, created = Park.objects.get_or_create( - name=park_data["name"], defaults=park_data - ) + park, created = Park.objects.get_or_create(name=park_data["name"], defaults=park_data) parks.append(park) if created and self.verbose: self.stdout.write(f" Created park: {park.name}") @@ -151,9 +147,7 @@ class Command(BaseCommand): # Create rides rides = [] for ride_data in rides_data: - ride, created = Ride.objects.get_or_create( - name=ride_data["name"], defaults=ride_data - ) + ride, created = Ride.objects.get_or_create(name=ride_data["name"], defaults=ride_data) rides.append(ride) if created and self.verbose: self.stdout.write(f" Created ride: {ride.name}") @@ -169,48 +163,34 @@ class Command(BaseCommand): # Pattern 1: Recently trending item (Steel Vengeance) steel_vengeance = next(r for r in rides if r.name == "Steel Vengeance") - self.create_views_for_content( - steel_vengeance, recent_views=50, older_views=10, base_time=now - ) + self.create_views_for_content(steel_vengeance, recent_views=50, older_views=10, base_time=now) # Pattern 2: Consistently popular item (Space Mountain) space_mountain = next(r for r in rides if r.name == "Space Mountain") - self.create_views_for_content( - space_mountain, recent_views=30, older_views=25, base_time=now - ) + self.create_views_for_content(space_mountain, recent_views=30, older_views=25, base_time=now) # Pattern 3: Declining popularity (Kingda Ka) kingda_ka = next(r for r in rides if r.name == "Kingda Ka") - self.create_views_for_content( - kingda_ka, recent_views=5, older_views=40, base_time=now - ) + self.create_views_for_content(kingda_ka, recent_views=5, older_views=40, base_time=now) # Pattern 4: New but growing (Millennium Force) millennium_force = next(r for r in rides if r.name == "Millennium Force") - self.create_views_for_content( - millennium_force, recent_views=25, older_views=5, base_time=now - ) + self.create_views_for_content(millennium_force, recent_views=25, older_views=5, base_time=now) # Create some park views too cedar_point = next(p for p in parks if p.name == "Cedar Point") - self.create_views_for_content( - cedar_point, recent_views=35, older_views=20, base_time=now - ) + self.create_views_for_content(cedar_point, recent_views=35, older_views=20, base_time=now) if self.verbose: self.stdout.write(" Created PageView data for trending analysis") - def create_views_for_content( - self, content_object, recent_views, older_views, base_time - ): + def create_views_for_content(self, content_object, recent_views, older_views, base_time): """Create PageViews for a content object with specified patterns.""" content_type = ContentType.objects.get_for_model(type(content_object)) # Create recent views (last 2 hours) for _i in range(recent_views): - view_time = base_time - timedelta( - minutes=random.randint(0, 120) # Last 2 hours - ) + view_time = base_time - timedelta(minutes=random.randint(0, 120)) # Last 2 hours PageView.objects.create( content_type=content_type, object_id=content_object.id, @@ -235,15 +215,9 @@ class Command(BaseCommand): self.stdout.write("Testing trending algorithm...") # Test trending content for different content types - trending_parks = trending_service.get_trending_content( - content_type="parks", limit=3 - ) - trending_rides = trending_service.get_trending_content( - content_type="rides", limit=3 - ) - trending_all = trending_service.get_trending_content( - content_type="all", limit=5 - ) + trending_parks = trending_service.get_trending_content(content_type="parks", limit=3) + trending_rides = trending_service.get_trending_content(content_type="rides", limit=3) + trending_all = trending_service.get_trending_content(content_type="all", limit=5) # Test new content new_parks = trending_service.get_new_content(content_type="parks", limit=3) @@ -265,12 +239,8 @@ class Command(BaseCommand): self.stdout.write("Testing API response format...") # Test trending content format - trending_parks = trending_service.get_trending_content( - content_type="parks", limit=3 - ) - trending_service.get_trending_content( - content_type="rides", limit=3 - ) + trending_parks = trending_service.get_trending_content(content_type="parks", limit=3) + trending_service.get_trending_content(content_type="rides", limit=3) # Test new content format new_parks = trending_service.get_new_content(content_type="parks", limit=3) diff --git a/backend/apps/core/management/commands/warm_cache.py b/backend/apps/core/management/commands/warm_cache.py index ba3dc207..3e5c85db 100644 --- a/backend/apps/core/management/commands/warm_cache.py +++ b/backend/apps/core/management/commands/warm_cache.py @@ -94,13 +94,21 @@ class Command(BaseCommand): try: parks_list = list( Park.objects.select_related("location", "operator") - .only("id", "name", "slug", "status", "location__city", "location__state_province", "location__country") + .only( + "id", + "name", + "slug", + "status", + "location__city", + "location__state_province", + "location__country", + ) .order_by("name")[:500] ) cache_service.default_cache.set( "warm:park_list", [{"id": p.id, "name": p.name, "slug": p.slug} for p in parks_list], - timeout=3600 + timeout=3600, ) warmed_count += 1 if verbose: @@ -116,11 +124,7 @@ class Command(BaseCommand): if not dry_run: try: status_counts = Park.objects.values("status").annotate(count=Count("id")) - cache_service.default_cache.set( - "warm:park_status_counts", - list(status_counts), - timeout=3600 - ) + cache_service.default_cache.set("warm:park_status_counts", list(status_counts), timeout=3600) warmed_count += 1 if verbose: self.stdout.write(" Cached park status counts") @@ -141,8 +145,11 @@ class Command(BaseCommand): ) cache_service.default_cache.set( "warm:popular_parks", - [{"id": p.id, "name": p.name, "slug": p.slug, "ride_count": p.ride_count} for p in popular_parks], - timeout=3600 + [ + {"id": p.id, "name": p.name, "slug": p.slug, "ride_count": p.ride_count} + for p in popular_parks + ], + timeout=3600, ) warmed_count += 1 if verbose: @@ -168,8 +175,11 @@ class Command(BaseCommand): ) cache_service.default_cache.set( "warm:ride_list", - [{"id": r.id, "name": r.name, "slug": r.slug, "park": r.park.name if r.park else None} for r in rides_list], - timeout=3600 + [ + {"id": r.id, "name": r.name, "slug": r.slug, "park": r.park.name if r.park else None} + for r in rides_list + ], + timeout=3600, ) warmed_count += 1 if verbose: @@ -185,11 +195,7 @@ class Command(BaseCommand): if not dry_run: try: category_counts = Ride.objects.values("category").annotate(count=Count("id")) - cache_service.default_cache.set( - "warm:ride_category_counts", - list(category_counts), - timeout=3600 - ) + cache_service.default_cache.set("warm:ride_category_counts", list(category_counts), timeout=3600) warmed_count += 1 if verbose: self.stdout.write(" Cached ride category counts") @@ -210,8 +216,16 @@ class Command(BaseCommand): ) cache_service.default_cache.set( "warm:top_rated_rides", - [{"id": r.id, "name": r.name, "slug": r.slug, "rating": float(r.average_rating) if r.average_rating else None} for r in top_rides], - timeout=3600 + [ + { + "id": r.id, + "name": r.name, + "slug": r.slug, + "rating": float(r.average_rating) if r.average_rating else None, + } + for r in top_rides + ], + timeout=3600, ) warmed_count += 1 if verbose: @@ -231,12 +245,9 @@ class Command(BaseCommand): try: # Park filter metadata from apps.parks.services.hybrid_loader import smart_park_loader + metadata = smart_park_loader.get_filter_metadata() - cache_service.default_cache.set( - "warm:park_filter_metadata", - metadata, - timeout=1800 - ) + cache_service.default_cache.set("warm:park_filter_metadata", metadata, timeout=1800) warmed_count += 1 if verbose: self.stdout.write(" Cached park filter metadata") @@ -251,13 +262,10 @@ class Command(BaseCommand): try: # Ride filter metadata from apps.rides.services.hybrid_loader import SmartRideLoader + ride_loader = SmartRideLoader() metadata = ride_loader.get_filter_metadata() - cache_service.default_cache.set( - "warm:ride_filter_metadata", - metadata, - timeout=1800 - ) + cache_service.default_cache.set("warm:ride_filter_metadata", metadata, timeout=1800) warmed_count += 1 if verbose: self.stdout.write(" Cached ride filter metadata") diff --git a/backend/apps/core/managers.py b/backend/apps/core/managers.py index 0b277831..18fd297a 100644 --- a/backend/apps/core/managers.py +++ b/backend/apps/core/managers.py @@ -92,9 +92,7 @@ class LocationQuerySet(BaseQuerySet): """Filter locations near a geographic point.""" if hasattr(self.model, "point"): return ( - self.filter(point__distance_lte=(point, Distance(km=distance_km))) - .distance(point) - .order_by("distance") + self.filter(point__distance_lte=(point, Distance(km=distance_km))).distance(point).order_by("distance") ) return self @@ -138,9 +136,7 @@ class LocationManager(BaseManager): return self.get_queryset().near_point(point=point, distance_km=distance_km) def within_bounds(self, *, north: float, south: float, east: float, west: float): - return self.get_queryset().within_bounds( - north=north, south=south, east=east, west=west - ) + return self.get_queryset().within_bounds(north=north, south=south, east=east, west=west) class ReviewableQuerySet(BaseQuerySet): @@ -151,9 +147,7 @@ class ReviewableQuerySet(BaseQuerySet): return self.annotate( review_count=Count("reviews", filter=Q(reviews__is_published=True)), average_rating=Avg("reviews__rating", filter=Q(reviews__is_published=True)), - latest_review_date=Max( - "reviews__created_at", filter=Q(reviews__is_published=True) - ), + latest_review_date=Max("reviews__created_at", filter=Q(reviews__is_published=True)), ) def highly_rated(self, *, min_rating: float = 8.0): @@ -163,9 +157,7 @@ class ReviewableQuerySet(BaseQuerySet): def recently_reviewed(self, *, days: int = 30): """Filter for items with recent reviews.""" cutoff_date = timezone.now() - timedelta(days=days) - return self.filter( - reviews__created_at__gte=cutoff_date, reviews__is_published=True - ).distinct() + return self.filter(reviews__created_at__gte=cutoff_date, reviews__is_published=True).distinct() class ReviewableManager(BaseManager): @@ -237,9 +229,7 @@ class TimestampedManager(BaseManager): return TimestampedQuerySet(self.model, using=self._db) def created_between(self, *, start_date, end_date): - return self.get_queryset().created_between( - start_date=start_date, end_date=end_date - ) + return self.get_queryset().created_between(start_date=start_date, end_date=end_date) class StatusQuerySet(BaseQuerySet): diff --git a/backend/apps/core/middleware/analytics.py b/backend/apps/core/middleware/analytics.py index e9be360f..c740d410 100644 --- a/backend/apps/core/middleware/analytics.py +++ b/backend/apps/core/middleware/analytics.py @@ -16,16 +16,10 @@ class RequestContextProvider(pghistory.context): def __call__(self, request: WSGIRequest) -> dict: return { - "user": ( - str(request.user) - if request.user and not isinstance(request.user, AnonymousUser) - else None - ), + "user": (str(request.user) if request.user and not isinstance(request.user, AnonymousUser) else None), "ip": request.META.get("REMOTE_ADDR"), "user_agent": request.META.get("HTTP_USER_AGENT"), - "session_key": ( - request.session.session_key if hasattr(request, "session") else None - ), + "session_key": (request.session.session_key if hasattr(request, "session") else None), } diff --git a/backend/apps/core/middleware/htmx_error_middleware.py b/backend/apps/core/middleware/htmx_error_middleware.py index 6b169a47..5a02c887 100644 --- a/backend/apps/core/middleware/htmx_error_middleware.py +++ b/backend/apps/core/middleware/htmx_error_middleware.py @@ -1,6 +1,7 @@ """ Middleware for handling errors in HTMX requests. """ + import logging from django.http import HttpResponseServerError diff --git a/backend/apps/core/middleware/nextjs.py b/backend/apps/core/middleware/nextjs.py index 8e025c66..f9b8f1a4 100644 --- a/backend/apps/core/middleware/nextjs.py +++ b/backend/apps/core/middleware/nextjs.py @@ -38,12 +38,8 @@ class APIResponseMiddleware(MiddlewareMixin): response["Vary"] = "Origin" # Helpful dev CORS headers (adjust for your frontend requests) - response["Access-Control-Allow-Methods"] = ( - "GET, POST, PUT, PATCH, DELETE, OPTIONS" - ) - response["Access-Control-Allow-Headers"] = ( - "Authorization, Content-Type, X-Requested-With" - ) + response["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" + response["Access-Control-Allow-Headers"] = "Authorization, Content-Type, X-Requested-With" # Uncomment if your dev frontend needs to send cookies/auth credentials # response['Access-Control-Allow-Credentials'] = 'true' logger.debug(f"Added CORS headers for origin: {origin}") diff --git a/backend/apps/core/middleware/performance_middleware.py b/backend/apps/core/middleware/performance_middleware.py index 76943981..ece5c9e3 100644 --- a/backend/apps/core/middleware/performance_middleware.py +++ b/backend/apps/core/middleware/performance_middleware.py @@ -19,9 +19,7 @@ class PerformanceMiddleware(MiddlewareMixin): def process_request(self, request): """Initialize performance tracking for the request""" request._performance_start_time = time.time() - request._performance_initial_queries = ( - len(connection.queries) if hasattr(connection, "queries") else 0 - ) + request._performance_initial_queries = len(connection.queries) if hasattr(connection, "queries") else 0 def process_response(self, request, response): """Log performance metrics after response is ready""" @@ -42,11 +40,7 @@ class PerformanceMiddleware(MiddlewareMixin): duration = end_time - start_time initial_queries = getattr(request, "_performance_initial_queries", 0) - total_queries = ( - len(connection.queries) - initial_queries - if hasattr(connection, "queries") - else 0 - ) + total_queries = len(connection.queries) - initial_queries if hasattr(connection, "queries") else 0 # Get content length content_length = 0 @@ -70,9 +64,7 @@ class PerformanceMiddleware(MiddlewareMixin): if hasattr(request, "user") and request.user.is_authenticated else None ), - "user_agent": request.META.get("HTTP_USER_AGENT", "")[ - :100 - ], # Truncate user agent + "user_agent": request.META.get("HTTP_USER_AGENT", "")[:100], # Truncate user agent "remote_addr": self._get_client_ip(request), } @@ -81,11 +73,7 @@ class PerformanceMiddleware(MiddlewareMixin): recent_queries = connection.queries[-total_queries:] performance_data["queries"] = [ { - "sql": ( - query["sql"][:200] + "..." - if len(query["sql"]) > 200 - else query["sql"] - ), + "sql": (query["sql"][:200] + "..." if len(query["sql"]) > 200 else query["sql"]), "time": float(query["time"]), } for query in recent_queries[-10:] # Last 10 queries only @@ -95,9 +83,7 @@ class PerformanceMiddleware(MiddlewareMixin): slow_queries = [q for q in recent_queries if float(q["time"]) > 0.1] if slow_queries: performance_data["slow_query_count"] = len(slow_queries) - performance_data["slowest_query_time"] = max( - float(q["time"]) for q in slow_queries - ) + performance_data["slowest_query_time"] = max(float(q["time"]) for q in slow_queries) # Determine log level based on performance log_level = self._get_log_level(duration, total_queries, response.status_code) @@ -115,9 +101,7 @@ class PerformanceMiddleware(MiddlewareMixin): response["X-Response-Time"] = f"{duration * 1000:.2f}ms" response["X-Query-Count"] = str(total_queries) if total_queries > 0 and hasattr(connection, "queries"): - total_query_time = sum( - float(q["time"]) for q in connection.queries[-total_queries:] - ) + total_query_time = sum(float(q["time"]) for q in connection.queries[-total_queries:]) response["X-Query-Time"] = f"{total_query_time * 1000:.2f}ms" return response @@ -129,11 +113,7 @@ class PerformanceMiddleware(MiddlewareMixin): duration = end_time - start_time initial_queries = getattr(request, "_performance_initial_queries", 0) - total_queries = ( - len(connection.queries) - initial_queries - if hasattr(connection, "queries") - else 0 - ) + total_queries = len(connection.queries) - initial_queries if hasattr(connection, "queries") else 0 performance_data = { "path": request.path, @@ -195,9 +175,7 @@ class QueryCountMiddleware(MiddlewareMixin): def process_request(self, request): """Initialize query tracking""" - request._query_count_start = ( - len(connection.queries) if hasattr(connection, "queries") else 0 - ) + request._query_count_start = len(connection.queries) if hasattr(connection, "queries") else 0 def process_response(self, request, response): """Check query count and warn if excessive""" @@ -267,9 +245,7 @@ class CachePerformanceMiddleware(MiddlewareMixin): def process_response(self, request, response): """Log cache performance metrics""" - cache_duration = time.time() - getattr( - request, "_cache_start_time", time.time() - ) + cache_duration = time.time() - getattr(request, "_cache_start_time", time.time()) cache_hits = getattr(request, "_cache_hits", 0) cache_misses = getattr(request, "_cache_misses", 0) diff --git a/backend/apps/core/middleware/rate_limiting.py b/backend/apps/core/middleware/rate_limiting.py index 6bc4d1ed..e12aef2e 100644 --- a/backend/apps/core/middleware/rate_limiting.py +++ b/backend/apps/core/middleware/rate_limiting.py @@ -35,20 +35,17 @@ class AuthRateLimitMiddleware: # Endpoints to rate limit RATE_LIMITED_PATHS = { # Login endpoints - '/api/v1/auth/login/': {'per_minute': 5, 'per_hour': 30, 'per_day': 100}, - '/accounts/login/': {'per_minute': 5, 'per_hour': 30, 'per_day': 100}, - + "/api/v1/auth/login/": {"per_minute": 5, "per_hour": 30, "per_day": 100}, + "/accounts/login/": {"per_minute": 5, "per_hour": 30, "per_day": 100}, # Signup endpoints - '/api/v1/auth/signup/': {'per_minute': 3, 'per_hour': 10, 'per_day': 20}, - '/accounts/signup/': {'per_minute': 3, 'per_hour': 10, 'per_day': 20}, - + "/api/v1/auth/signup/": {"per_minute": 3, "per_hour": 10, "per_day": 20}, + "/accounts/signup/": {"per_minute": 3, "per_hour": 10, "per_day": 20}, # Password reset endpoints - '/api/v1/auth/password-reset/': {'per_minute': 2, 'per_hour': 5, 'per_day': 10}, - '/accounts/password/reset/': {'per_minute': 2, 'per_hour': 5, 'per_day': 10}, - + "/api/v1/auth/password-reset/": {"per_minute": 2, "per_hour": 5, "per_day": 10}, + "/accounts/password/reset/": {"per_minute": 2, "per_hour": 5, "per_day": 10}, # Token endpoints - '/api/v1/auth/token/': {'per_minute': 10, 'per_hour': 60, 'per_day': 200}, - '/api/v1/auth/token/refresh/': {'per_minute': 20, 'per_hour': 120, 'per_day': 500}, + "/api/v1/auth/token/": {"per_minute": 10, "per_hour": 60, "per_day": 200}, + "/api/v1/auth/token/refresh/": {"per_minute": 20, "per_hour": 120, "per_day": 500}, } def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): @@ -56,7 +53,7 @@ class AuthRateLimitMiddleware: def __call__(self, request: HttpRequest) -> HttpResponse: # Only rate limit POST requests to auth endpoints - if request.method != 'POST': + if request.method != "POST": return self.get_response(request) # Check if this path should be rate limited @@ -68,14 +65,10 @@ class AuthRateLimitMiddleware: client_ip = self._get_client_ip(request) # Check rate limits - is_allowed, message = self._check_rate_limits( - client_ip, request.path, limits - ) + is_allowed, message = self._check_rate_limits(client_ip, request.path, limits) if not is_allowed: - logger.warning( - f"Rate limit exceeded for {client_ip} on {request.path}" - ) + logger.warning(f"Rate limit exceeded for {client_ip} on {request.path}") return self._rate_limit_response(message) # Process request @@ -94,9 +87,9 @@ class AuthRateLimitMiddleware: return self.RATE_LIMITED_PATHS[path] # Prefix match (for paths with trailing slashes) - path_without_slash = path.rstrip('/') + path_without_slash = path.rstrip("/") for limited_path, limits in self.RATE_LIMITED_PATHS.items(): - if path_without_slash == limited_path.rstrip('/'): + if path_without_slash == limited_path.rstrip("/"): return limits return None @@ -108,23 +101,18 @@ class AuthRateLimitMiddleware: Handles common proxy headers (X-Forwarded-For, X-Real-IP). """ # Check for forwarded headers (set by reverse proxies) - x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR') + x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR") if x_forwarded_for: # Take the first IP in the chain (client IP) - return x_forwarded_for.split(',')[0].strip() + return x_forwarded_for.split(",")[0].strip() - x_real_ip = request.META.get('HTTP_X_REAL_IP') + x_real_ip = request.META.get("HTTP_X_REAL_IP") if x_real_ip: return x_real_ip - return request.META.get('REMOTE_ADDR', 'unknown') + return request.META.get("REMOTE_ADDR", "unknown") - def _check_rate_limits( - self, - client_ip: str, - path: str, - limits: dict - ) -> tuple[bool, str]: + def _check_rate_limits(self, client_ip: str, path: str, limits: dict) -> tuple[bool, str]: """ Check if the client has exceeded rate limits. @@ -132,31 +120,31 @@ class AuthRateLimitMiddleware: Tuple of (is_allowed, reason_if_blocked) """ # Create a safe cache key from path - path_key = path.replace('/', '_').strip('_') + path_key = path.replace("/", "_").strip("_") # Check per-minute limit minute_key = f"auth_rate:{client_ip}:{path_key}:minute" minute_count = cache.get(minute_key, 0) - if minute_count >= limits.get('per_minute', 10): + if minute_count >= limits.get("per_minute", 10): return False, "Too many requests. Please wait a minute before trying again." # Check per-hour limit hour_key = f"auth_rate:{client_ip}:{path_key}:hour" hour_count = cache.get(hour_key, 0) - if hour_count >= limits.get('per_hour', 60): + if hour_count >= limits.get("per_hour", 60): return False, "Too many requests. Please try again later." # Check per-day limit day_key = f"auth_rate:{client_ip}:{path_key}:day" day_count = cache.get(day_key, 0) - if day_count >= limits.get('per_day', 200): + if day_count >= limits.get("per_day", 200): return False, "Daily limit exceeded. Please try again tomorrow." return True, "" def _increment_counters(self, client_ip: str, path: str) -> None: """Increment rate limit counters.""" - path_key = path.replace('/', '_').strip('_') + path_key = path.replace("/", "_").strip("_") # Increment per-minute counter minute_key = f"auth_rate:{client_ip}:{path_key}:minute" @@ -183,8 +171,8 @@ class AuthRateLimitMiddleware: """Generate a rate limit exceeded response.""" return JsonResponse( { - 'error': message, - 'code': 'RATE_LIMIT_EXCEEDED', + "error": message, + "code": "RATE_LIMIT_EXCEEDED", }, status=429, # Too Many Requests ) @@ -201,50 +189,31 @@ class SecurityEventLogger: """ @staticmethod - def log_failed_login( - request: HttpRequest, - username: str, - reason: str = "Invalid credentials" - ) -> None: + def log_failed_login(request: HttpRequest, username: str, reason: str = "Invalid credentials") -> None: """Log a failed login attempt.""" - client_ip = AuthRateLimitMiddleware._get_client_ip( - AuthRateLimitMiddleware, request - ) + client_ip = AuthRateLimitMiddleware._get_client_ip(AuthRateLimitMiddleware, request) logger.warning( f"Failed login attempt - IP: {client_ip}, Username: {username}, " f"Reason: {reason}, User-Agent: {request.META.get('HTTP_USER_AGENT', 'unknown')}" ) @staticmethod - def log_permission_denied( - request: HttpRequest, - resource: str, - action: str = "access" - ) -> None: + def log_permission_denied(request: HttpRequest, resource: str, action: str = "access") -> None: """Log a permission denied event.""" - client_ip = AuthRateLimitMiddleware._get_client_ip( - AuthRateLimitMiddleware, request - ) - user = getattr(request, 'user', None) - username = user.username if user and user.is_authenticated else 'anonymous' + client_ip = AuthRateLimitMiddleware._get_client_ip(AuthRateLimitMiddleware, request) + user = getattr(request, "user", None) + username = user.username if user and user.is_authenticated else "anonymous" logger.warning( - f"Permission denied - IP: {client_ip}, User: {username}, " - f"Resource: {resource}, Action: {action}" + f"Permission denied - IP: {client_ip}, User: {username}, " f"Resource: {resource}, Action: {action}" ) @staticmethod - def log_suspicious_activity( - request: HttpRequest, - activity_type: str, - details: str = "" - ) -> None: + def log_suspicious_activity(request: HttpRequest, activity_type: str, details: str = "") -> None: """Log suspicious activity.""" - client_ip = AuthRateLimitMiddleware._get_client_ip( - AuthRateLimitMiddleware, request - ) - user = getattr(request, 'user', None) - username = user.username if user and user.is_authenticated else 'anonymous' + client_ip = AuthRateLimitMiddleware._get_client_ip(AuthRateLimitMiddleware, request) + user = getattr(request, "user", None) + username = user.username if user and user.is_authenticated else "anonymous" logger.error( f"Suspicious activity detected - Type: {activity_type}, " diff --git a/backend/apps/core/middleware/request_logging.py b/backend/apps/core/middleware/request_logging.py index 504dacd1..8fd02ea6 100644 --- a/backend/apps/core/middleware/request_logging.py +++ b/backend/apps/core/middleware/request_logging.py @@ -9,7 +9,7 @@ import time from django.utils.deprecation import MiddlewareMixin -logger = logging.getLogger('request_logging') +logger = logging.getLogger("request_logging") class RequestLoggingMiddleware(MiddlewareMixin): @@ -20,17 +20,16 @@ class RequestLoggingMiddleware(MiddlewareMixin): # Paths to exclude from detailed logging (e.g., static files, health checks) EXCLUDE_DETAILED_LOGGING_PATHS = [ - '/static/', - '/media/', - '/favicon.ico', - '/health/', - '/admin/jsi18n/', + "/static/", + "/media/", + "/favicon.ico", + "/health/", + "/admin/jsi18n/", ] def _should_log_detailed(self, request): """Determine if detailed logging should be enabled for this request.""" - return not any( - path in request.path for path in self.EXCLUDE_DETAILED_LOGGING_PATHS) + return not any(path in request.path for path in self.EXCLUDE_DETAILED_LOGGING_PATHS) def process_request(self, request): """Store request start time and capture request data for detailed logging.""" @@ -44,14 +43,17 @@ class RequestLoggingMiddleware(MiddlewareMixin): try: # Log request data request_data = {} - if hasattr(request, 'data') and request.data: + if hasattr(request, "data") and request.data: request_data = dict(request.data) elif request.body: try: - request_data = json.loads(request.body.decode('utf-8')) + request_data = json.loads(request.body.decode("utf-8")) except (json.JSONDecodeError, UnicodeDecodeError): - request_data = {'body': str(request.body)[ - :200] + '...' if len(str(request.body)) > 200 else str(request.body)} + request_data = { + "body": ( + str(request.body)[:200] + "..." if len(str(request.body)) > 200 else str(request.body) + ) + } # Log query parameters query_params = dict(request.GET) if request.GET else {} @@ -61,9 +63,8 @@ class RequestLoggingMiddleware(MiddlewareMixin): logger.info(f" Body: {self._safe_log_data(request_data)}") if query_params: logger.info(f" Query: {query_params}") - if hasattr(request, 'user') and request.user.is_authenticated: - logger.info( - f" User: {request.user.username} (ID: {request.user.id})") + if hasattr(request, "user") and request.user.is_authenticated: + logger.info(f" User: {request.user.username} (ID: {request.user.id})") except Exception as e: logger.warning(f"Failed to log request data: {e}") @@ -75,34 +76,28 @@ class RequestLoggingMiddleware(MiddlewareMixin): try: # Calculate request duration duration = 0 - if hasattr(request, '_start_time'): + if hasattr(request, "_start_time"): duration = time.time() - request._start_time # Basic request logging - logger.info( - f"{request.method} {request.get_full_path()} -> {response.status_code} " - f"({duration:.3f}s)" - ) + logger.info(f"{request.method} {request.get_full_path()} -> {response.status_code} " f"({duration:.3f}s)") # Detailed response logging for specific endpoints - if getattr(request, '_log_request_data', False): + if getattr(request, "_log_request_data", False): try: # Log response data - if hasattr(response, 'data'): - logger.info( - f"RESPONSE DATA for {request.method} {request.path}:") + if hasattr(response, "data"): + logger.info(f"RESPONSE DATA for {request.method} {request.path}:") logger.info(f" Status: {response.status_code}") logger.info(f" Data: {self._safe_log_data(response.data)}") - elif hasattr(response, 'content'): + elif hasattr(response, "content"): try: - content = json.loads(response.content.decode('utf-8')) - logger.info( - f"RESPONSE DATA for {request.method} {request.path}:") + content = json.loads(response.content.decode("utf-8")) + logger.info(f"RESPONSE DATA for {request.method} {request.path}:") logger.info(f" Status: {response.status_code}") logger.info(f" Content: {self._safe_log_data(content)}") except (json.JSONDecodeError, UnicodeDecodeError): - logger.info( - f"RESPONSE DATA for {request.method} {request.path}:") + logger.info(f"RESPONSE DATA for {request.method} {request.path}:") logger.info(f" Status: {response.status_code}") logger.info(f" Content: {str(response.content)[:200]}...") @@ -118,31 +113,31 @@ class RequestLoggingMiddleware(MiddlewareMixin): # Sensitive field patterns that should be masked in logs # Security: Comprehensive list of sensitive data patterns SENSITIVE_PATTERNS = [ - 'password', - 'passwd', - 'pwd', - 'token', - 'secret', - 'key', - 'api_key', - 'apikey', - 'auth', - 'authorization', - 'credential', - 'ssn', - 'social_security', - 'credit_card', - 'creditcard', - 'card_number', - 'cvv', - 'cvc', - 'pin', - 'access_token', - 'refresh_token', - 'jwt', - 'session', - 'cookie', - 'private', + "password", + "passwd", + "pwd", + "token", + "secret", + "key", + "api_key", + "apikey", + "auth", + "authorization", + "credential", + "ssn", + "social_security", + "credit_card", + "creditcard", + "card_number", + "cvv", + "cvc", + "pin", + "access_token", + "refresh_token", + "jwt", + "session", + "cookie", + "private", ] def _safe_log_data(self, data): @@ -167,15 +162,15 @@ class RequestLoggingMiddleware(MiddlewareMixin): # Truncate if too long if len(data_str) > 1000: - return data_str[:1000] + '...[TRUNCATED]' + return data_str[:1000] + "...[TRUNCATED]" return data_str except Exception: - return str(data)[:500] + '...[ERROR_LOGGING]' + return str(data)[:500] + "...[ERROR_LOGGING]" def _mask_sensitive_dict(self, data, depth=0): """Recursively mask sensitive fields in a dictionary.""" if depth > 5: # Prevent infinite recursion - return '***DEPTH_LIMIT***' + return "***DEPTH_LIMIT***" safe_data = {} for key, value in data.items(): @@ -183,7 +178,7 @@ class RequestLoggingMiddleware(MiddlewareMixin): # Check if key contains any sensitive pattern if any(pattern in key_lower for pattern in self.SENSITIVE_PATTERNS): - safe_data[key] = '***MASKED***' + safe_data[key] = "***MASKED***" else: safe_data[key] = self._mask_sensitive_value(value, depth) @@ -197,11 +192,11 @@ class RequestLoggingMiddleware(MiddlewareMixin): return [self._mask_sensitive_value(item, depth + 1) for item in value[:10]] # Limit list items elif isinstance(value, str): # Mask email addresses (show only domain) - if '@' in value and '.' in value.split('@')[-1]: - parts = value.split('@') + if "@" in value and "." in value.split("@")[-1]: + parts = value.split("@") if len(parts) == 2: return f"***@{parts[1]}" # Truncate long strings if len(value) > 200: - return value[:200] + '...[TRUNCATED]' + return value[:200] + "...[TRUNCATED]" return value diff --git a/backend/apps/core/middleware/security_headers.py b/backend/apps/core/middleware/security_headers.py index 6793360d..70ff9eaf 100644 --- a/backend/apps/core/middleware/security_headers.py +++ b/backend/apps/core/middleware/security_headers.py @@ -49,9 +49,7 @@ class SecurityHeadersMiddleware: if not response.get("Content-Security-Policy"): response["Content-Security-Policy"] = self._csp_header else: - logger.warning( - f"CSP header already present for {request.path}, skipping" - ) + logger.warning(f"CSP header already present for {request.path}, skipping") # Permissions-Policy (successor to Feature-Policy) if not response.get("Permissions-Policy"): @@ -144,11 +142,13 @@ class SecurityHeadersMiddleware: # Add debug-specific relaxations if debug: # Allow webpack dev server connections in development - directives["connect-src"].extend([ - "ws://localhost:*", - "http://localhost:*", - "http://127.0.0.1:*", - ]) + directives["connect-src"].extend( + [ + "ws://localhost:*", + "http://localhost:*", + "http://127.0.0.1:*", + ] + ) # Build header string parts = [] @@ -168,30 +168,34 @@ class SecurityHeadersMiddleware: This header controls which browser features the page can use. """ # Get permissions policy from settings or use defaults - policy = getattr(settings, "PERMISSIONS_POLICY", { - "accelerometer": [], - "ambient-light-sensor": [], - "autoplay": [], - "camera": [], - "display-capture": [], - "document-domain": [], - "encrypted-media": [], - "fullscreen": ["self"], - "geolocation": ["self"], - "gyroscope": [], - "interest-cohort": [], - "magnetometer": [], - "microphone": [], - "midi": [], - "payment": [], - "picture-in-picture": [], - "publickey-credentials-get": [], - "screen-wake-lock": [], - "sync-xhr": [], - "usb": [], - "web-share": ["self"], - "xr-spatial-tracking": [], - }) + policy = getattr( + settings, + "PERMISSIONS_POLICY", + { + "accelerometer": [], + "ambient-light-sensor": [], + "autoplay": [], + "camera": [], + "display-capture": [], + "document-domain": [], + "encrypted-media": [], + "fullscreen": ["self"], + "geolocation": ["self"], + "gyroscope": [], + "interest-cohort": [], + "magnetometer": [], + "microphone": [], + "midi": [], + "payment": [], + "picture-in-picture": [], + "publickey-credentials-get": [], + "screen-wake-lock": [], + "sync-xhr": [], + "usb": [], + "web-share": ["self"], + "xr-spatial-tracking": [], + }, + ) parts = [] for feature, allowlist in policy.items(): diff --git a/backend/apps/core/middleware/view_tracking.py b/backend/apps/core/middleware/view_tracking.py index 6583a160..d5430049 100644 --- a/backend/apps/core/middleware/view_tracking.py +++ b/backend/apps/core/middleware/view_tracking.py @@ -9,7 +9,6 @@ analytics for the trending algorithm. import logging import re from datetime import timedelta -from typing import Union from django.conf import settings from django.contrib.contenttypes.models import ContentType @@ -22,7 +21,7 @@ from apps.parks.models import Park from apps.rides.models import Ride # Type alias for content objects -ContentObject = Union[Park, Ride] +ContentObject = Park | Ride logger = logging.getLogger(__name__) @@ -50,8 +49,7 @@ class ViewTrackingMiddleware: # Compile patterns for performance self.compiled_patterns = [ - (re.compile(pattern), content_type) - for pattern, content_type in self.tracked_patterns + (re.compile(pattern), content_type) for pattern, content_type in self.tracked_patterns ] # Cache configuration @@ -63,11 +61,7 @@ class ViewTrackingMiddleware: response = self.get_response(request) # Only track successful GET requests - if ( - request.method == "GET" - and 200 <= response.status_code < 300 - and not self._should_skip_tracking(request) - ): + if request.method == "GET" and 200 <= response.status_code < 300 and not self._should_skip_tracking(request): try: self._track_view_if_applicable(request) except Exception as e: @@ -119,9 +113,7 @@ class ViewTrackingMiddleware: self._record_page_view(request, content_type, slug) break - def _record_page_view( - self, request: HttpRequest, content_type: str, slug: str - ) -> None: + def _record_page_view(self, request: HttpRequest, content_type: str, slug: str) -> None: """Record a page view for the specified content.""" client_ip = self._get_client_ip(request) if not client_ip: @@ -131,33 +123,23 @@ class ViewTrackingMiddleware: # Get the content object content_obj = self._get_content_object(content_type, slug) if not content_obj: - self.logger.warning( - f"Content not found: {content_type} with slug '{slug}'" - ) + self.logger.warning(f"Content not found: {content_type} with slug '{slug}'") return # Check deduplication if self._is_duplicate_view(content_obj, client_ip): - self.logger.debug( - f"Duplicate view skipped for {content_type} {slug} from {client_ip}" - ) + self.logger.debug(f"Duplicate view skipped for {content_type} {slug} from {client_ip}") return # Create PageView record self._create_page_view(content_obj, client_ip, request) - self.logger.debug( - f"Recorded view for {content_type} {slug} from {client_ip}" - ) + self.logger.debug(f"Recorded view for {content_type} {slug} from {client_ip}") except Exception as e: - self.logger.error( - f"Failed to record page view for {content_type} {slug}: {e}" - ) + self.logger.error(f"Failed to record page view for {content_type} {slug}: {e}") - def _get_content_object( - self, content_type: str, slug: str - ) -> ContentObject | None: + def _get_content_object(self, content_type: str, slug: str) -> ContentObject | None: """Get the content object by type and slug.""" try: if content_type == "park": @@ -202,16 +184,12 @@ class ViewTrackingMiddleware: return existing_view - def _create_page_view( - self, content_obj: ContentObject, client_ip: str, request: HttpRequest - ) -> None: + def _create_page_view(self, content_obj: ContentObject, client_ip: str, request: HttpRequest) -> None: """Create a new PageView record.""" content_type = ContentType.objects.get_for_model(content_obj) # Extract additional metadata - user_agent = request.META.get("HTTP_USER_AGENT", "")[ - :500 - ] # Truncate long user agents + user_agent = request.META.get("HTTP_USER_AGENT", "")[:500] # Truncate long user agents referer = request.META.get("HTTP_REFERER", "")[:500] PageView.objects.create( @@ -267,11 +245,9 @@ class ViewTrackingMiddleware: return False # Skip localhost and private IPs in production - if getattr(settings, "SKIP_LOCAL_IPS", not settings.DEBUG): + if getattr(settings, "SKIP_LOCAL_IPS", not settings.DEBUG): # noqa: SIM102 if (ip.startswith(("127.", "192.168.", "10.")) or ip.startswith("172.")) and any( - 16 <= int(ip.split(".")[1]) <= 31 - for _ in [ip] - if ip.startswith("172.") + 16 <= int(ip.split(".")[1]) <= 31 for _ in [ip] if ip.startswith("172.") ): return False diff --git a/backend/apps/core/migrations/0004_alter_slughistory_options_and_more.py b/backend/apps/core/migrations/0004_alter_slughistory_options_and_more.py index 8bc33fc9..7c01f0f8 100644 --- a/backend/apps/core/migrations/0004_alter_slughistory_options_and_more.py +++ b/backend/apps/core/migrations/0004_alter_slughistory_options_and_more.py @@ -32,9 +32,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="slughistory", name="object_id", - field=models.CharField( - help_text="ID of the object this slug belongs to", max_length=50 - ), + field=models.CharField(help_text="ID of the object this slug belongs to", max_length=50), ), migrations.AlterField( model_name="slughistory", @@ -56,15 +54,11 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="slughistoryevent", name="object_id", - field=models.CharField( - help_text="ID of the object this slug belongs to", max_length=50 - ), + field=models.CharField(help_text="ID of the object this slug belongs to", max_length=50), ), migrations.AlterField( model_name="slughistoryevent", name="old_slug", - field=models.SlugField( - db_index=False, help_text="Previous slug value", max_length=200 - ), + field=models.SlugField(db_index=False, help_text="Previous slug value", max_length=200), ), ] diff --git a/backend/apps/core/mixins/__init__.py b/backend/apps/core/mixins/__init__.py index 2d2f7f10..b078b8ea 100644 --- a/backend/apps/core/mixins/__init__.py +++ b/backend/apps/core/mixins/__init__.py @@ -58,10 +58,7 @@ class HTMXFormMixin(FormMixin): def form_valid(self, form): """Add HX-Trigger header on successful form submission via HTMX.""" res = super().form_valid(form) - if ( - self.request.headers.get("HX-Request") == "true" - and self.htmx_success_trigger - ): + if self.request.headers.get("HX-Request") == "true" and self.htmx_success_trigger: res["HX-Trigger"] = self.htmx_success_trigger return res diff --git a/backend/apps/core/models.py b/backend/apps/core/models.py index d69c4719..5b720401 100644 --- a/backend/apps/core/models.py +++ b/backend/apps/core/models.py @@ -78,9 +78,7 @@ class SluggedModel(TrackedModel): Returns the name of the read-only ID field for this model. Should be overridden by subclasses. """ - raise NotImplementedError( - "Subclasses of SluggedModel must implement get_id_field_name()" - ) + raise NotImplementedError("Subclasses of SluggedModel must implement get_id_field_name()") @classmethod def get_by_slug(cls, slug): @@ -123,4 +121,4 @@ class SluggedModel(TrackedModel): True, ) - raise cls.DoesNotExist(f"{cls.__name__} with slug '{slug}' does not exist") + raise cls.DoesNotExist(f"{cls.__name__} with slug '{slug}' does not exist") from None diff --git a/backend/apps/core/permissions.py b/backend/apps/core/permissions.py index cc40da90..457e3d02 100644 --- a/backend/apps/core/permissions.py +++ b/backend/apps/core/permissions.py @@ -14,10 +14,11 @@ class IsOwnerOrReadOnly(permissions.BasePermission): # Write permissions are only allowed to the owner of the object. # Assumes the model instance has an `user` attribute. - if hasattr(obj, 'user'): + if hasattr(obj, "user"): return obj.user == request.user return False + class IsStaffOrReadOnly(permissions.BasePermission): """ Custom permission to only allow staff to edit it. diff --git a/backend/apps/core/selectors.py b/backend/apps/core/selectors.py index 34962ec6..306fd1e8 100644 --- a/backend/apps/core/selectors.py +++ b/backend/apps/core/selectors.py @@ -61,23 +61,20 @@ def unified_locations_for_map( # Rides if "ride" in location_types: - ride_queryset = Ride.objects.select_related( - "park", "manufacturer" - ).prefetch_related("park__location", "location") + ride_queryset = Ride.objects.select_related("park", "manufacturer").prefetch_related( + "park__location", "location" + ) if bounds: ride_queryset = ride_queryset.filter( - Q(location__coordinates__within=bounds) - | Q(park__location__coordinates__within=bounds) + Q(location__coordinates__within=bounds) | Q(park__location__coordinates__within=bounds) ) if filters: if "category" in filters: ride_queryset = ride_queryset.filter(category=filters["category"]) if "manufacturer" in filters: - ride_queryset = ride_queryset.filter( - manufacturer=filters["manufacturer"] - ) + ride_queryset = ride_queryset.filter(manufacturer=filters["manufacturer"]) if "park" in filters: ride_queryset = ride_queryset.filter(park=filters["park"]) @@ -248,12 +245,7 @@ def popular_pages_summary(*, days: int = 30) -> dict[str, Any]: total_views = PageView.objects.filter(timestamp__gte=cutoff_date).count() # Unique visitors (based on IP) - unique_visitors = ( - PageView.objects.filter(timestamp__gte=cutoff_date) - .values("ip_address") - .distinct() - .count() - ) + unique_visitors = PageView.objects.filter(timestamp__gte=cutoff_date).values("ip_address").distinct().count() return { "popular_pages": list(popular_pages), @@ -311,14 +303,10 @@ def system_health_metrics() -> dict[str, Any]: "page_views_7d": PageView.objects.filter(timestamp__gte=last_7d).count(), "data_freshness": { "latest_park_update": ( - Park.objects.order_by("-updated_at").first().updated_at - if Park.objects.exists() - else None + Park.objects.order_by("-updated_at").first().updated_at if Park.objects.exists() else None ), "latest_ride_update": ( - Ride.objects.order_by("-updated_at").first().updated_at - if Ride.objects.exists() - else None + Ride.objects.order_by("-updated_at").first().updated_at if Ride.objects.exists() else None ), }, } diff --git a/backend/apps/core/services/clustering_service.py b/backend/apps/core/services/clustering_service.py index 9a65373a..b3a1fc9e 100644 --- a/backend/apps/core/services/clustering_service.py +++ b/backend/apps/core/services/clustering_service.py @@ -63,9 +63,7 @@ class ClusteringService: if zoom_level < self.MIN_ZOOM_FOR_CLUSTERING: return True - config = self.ZOOM_CONFIGS.get( - zoom_level, {"min_points": self.MIN_POINTS_TO_CLUSTER} - ) + config = self.ZOOM_CONFIGS.get(zoom_level, {"min_points": self.MIN_POINTS_TO_CLUSTER}) return point_count >= config["min_points"] def cluster_locations( @@ -94,9 +92,7 @@ class ClusteringService: ) # Perform clustering - clustered_groups = self._cluster_points( - cluster_points, config["radius"], config["min_points"] - ) + clustered_groups = self._cluster_points(cluster_points, config["radius"], config["min_points"]) # Separate individual locations from clusters unclustered_locations = [] @@ -135,9 +131,7 @@ class ClusteringService: # Simple equirectangular projection (good enough for clustering) center_lat = (bounds.north + bounds.south) / 2 lat_scale = 111320 # meters per degree latitude - lng_scale = 111320 * math.cos( - math.radians(center_lat) - ) # meters per degree longitude + lng_scale = 111320 * math.cos(math.radians(center_lat)) # meters per degree longitude for location in locations: # Convert to meters relative to bounds center @@ -200,9 +194,7 @@ class ClusteringService: # Calculate cluster bounds lats = [loc.latitude for loc in locations] lngs = [loc.longitude for loc in locations] - cluster_bounds = GeoBounds( - north=max(lats), south=min(lats), east=max(lngs), west=min(lngs) - ) + cluster_bounds = GeoBounds(north=max(lats), south=min(lats), east=max(lngs), west=min(lngs)) # Collect location types in cluster types = {loc.type for loc in locations} @@ -223,9 +215,7 @@ class ClusteringService: representative_location=representative, ) - def _select_representative_location( - self, locations: list[UnifiedLocation] - ) -> UnifiedLocation | None: + def _select_representative_location(self, locations: list[UnifiedLocation]) -> UnifiedLocation | None: """Select the most representative location for a cluster.""" if not locations: return None @@ -291,9 +281,7 @@ class ClusteringService: "category_distribution": dict(category_counts), } - def expand_cluster( - self, cluster: ClusterData, zoom_level: int - ) -> list[UnifiedLocation]: + def expand_cluster(self, cluster: ClusterData, zoom_level: int) -> list[UnifiedLocation]: """ Expand a cluster to show individual locations (for drill-down functionality). This would typically require re-querying the database with the cluster bounds. @@ -320,14 +308,11 @@ class SmartClusteringRules: return True # Major parks should resist clustering unless very close - if ( - loc1.cluster_category == "major_park" - or loc2.cluster_category == "major_park" - ): + if loc1.cluster_category == "major_park" or loc2.cluster_category == "major_park": return False # Similar types cluster more readily - if loc1.type == loc2.type: + if loc1.type == loc2.type: # noqa: SIM103 return True # Different types can cluster but with higher threshold diff --git a/backend/apps/core/services/data_structures.py b/backend/apps/core/services/data_structures.py index 8505941b..1bf9e556 100644 --- a/backend/apps/core/services/data_structures.py +++ b/backend/apps/core/services/data_structures.py @@ -89,9 +89,7 @@ class MapFilters: def to_dict(self) -> dict[str, Any]: """Convert to dictionary for caching and serialization.""" return { - "location_types": ( - [t.value for t in self.location_types] if self.location_types else None - ), + "location_types": ([t.value for t in self.location_types] if self.location_types else None), "park_status": (list(self.park_status) if self.park_status else None), "ride_types": list(self.ride_types) if self.ride_types else None, "company_roles": (list(self.company_roles) if self.company_roles else None), @@ -183,11 +181,7 @@ class ClusterData: "count": self.count, "types": [t.value for t in self.types], "bounds": self.bounds.to_dict(), - "representative": ( - self.representative_location.to_dict() - if self.representative_location - else None - ), + "representative": (self.representative_location.to_dict() if self.representative_location else None), } diff --git a/backend/apps/core/services/enhanced_cache_service.py b/backend/apps/core/services/enhanced_cache_service.py index 0b426305..0589b5bb 100644 --- a/backend/apps/core/services/enhanced_cache_service.py +++ b/backend/apps/core/services/enhanced_cache_service.py @@ -103,9 +103,7 @@ class EnhancedCacheService: self.default_cache.set(cache_key, data, timeout) logger.debug(f"Cached geographic data for bounds {bounds}") - def get_cached_geographic_data( - self, bounds: "GeoBounds", zoom_level: int - ) -> Any | None: + def get_cached_geographic_data(self, bounds: "GeoBounds", zoom_level: int) -> Any | None: """Retrieve cached geographic data""" cache_key = f"geo:{bounds.min_lat}:{bounds.min_lng}:{bounds.max_lat}:{ bounds.max_lng @@ -119,20 +117,14 @@ class EnhancedCacheService: # For Redis cache backends if hasattr(self.default_cache, "delete_pattern"): deleted_count = self.default_cache.delete_pattern(pattern) - logger.info( - f"Invalidated {deleted_count} cache keys matching pattern '{pattern}'" - ) + logger.info(f"Invalidated {deleted_count} cache keys matching pattern '{pattern}'") return deleted_count else: - logger.warning( - f"Cache backend does not support pattern deletion for pattern '{pattern}'" - ) + logger.warning(f"Cache backend does not support pattern deletion for pattern '{pattern}'") except Exception as e: logger.error(f"Error invalidating cache pattern '{pattern}': {e}") - def invalidate_model_cache( - self, model_name: str, instance_id: int | None = None - ): + def invalidate_model_cache(self, model_name: str, instance_id: int | None = None): """Invalidate cache keys related to a specific model""" pattern = f"*{model_name}:{instance_id}*" if instance_id else f"*{model_name}*" @@ -175,11 +167,7 @@ def cache_api_response(timeout=1800, vary_on=None, key_prefix=""): # Generate cache key based on view, user, and parameters cache_key_parts = [ key_prefix or view_func.__name__, - ( - str(request.user.id) - if request.user.is_authenticated - else "anonymous" - ), + (str(request.user.id) if request.user.is_authenticated else "anonymous"), str(hash(frozenset(request.GET.items()))), ] @@ -219,9 +207,7 @@ def cache_queryset_result(cache_key_template: str, timeout: int = 3600): cache_key = cache_key_template.format(*args, **kwargs) cache_service = EnhancedCacheService() - return cache_service.cache_queryset( - cache_key, func, timeout, *args, **kwargs - ) + return cache_service.cache_queryset(cache_key, func, timeout, *args, **kwargs) return wrapper diff --git a/backend/apps/core/services/entity_fuzzy_matching.py b/backend/apps/core/services/entity_fuzzy_matching.py index 607963fa..54ae6290 100644 --- a/backend/apps/core/services/entity_fuzzy_matching.py +++ b/backend/apps/core/services/entity_fuzzy_matching.py @@ -218,17 +218,14 @@ class EntityFuzzyMatcher: return matches[: self.MAX_RESULTS], suggestion - def _get_candidates( - self, query: str, entity_type: EntityType - ) -> list[dict[str, Any]]: + def _get_candidates(self, query: str, entity_type: EntityType) -> list[dict[str, Any]]: """Get potential matching candidates for an entity type.""" candidates = [] if entity_type == EntityType.PARK: - parks = Park.objects.filter( - Q(name__icontains=query) - | Q(slug__icontains=query.lower().replace(" ", "-")) - )[: self.MAX_CANDIDATES] + parks = Park.objects.filter(Q(name__icontains=query) | Q(slug__icontains=query.lower().replace(" ", "-")))[ + : self.MAX_CANDIDATES + ] for park in parks: candidates.append( @@ -265,8 +262,7 @@ class EntityFuzzyMatcher: elif entity_type == EntityType.COMPANY: companies = Company.objects.filter( - Q(name__icontains=query) - | Q(slug__icontains=query.lower().replace(" ", "-")) + Q(name__icontains=query) | Q(slug__icontains=query.lower().replace(" ", "-")) )[: self.MAX_CANDIDATES] for company in companies: @@ -284,9 +280,7 @@ class EntityFuzzyMatcher: return candidates - def _score_and_rank_candidates( - self, query: str, candidates: list[dict[str, Any]] - ) -> list[FuzzyMatchResult]: + def _score_and_rank_candidates(self, query: str, candidates: list[dict[str, Any]]) -> list[FuzzyMatchResult]: """Score and rank all candidates using multiple algorithms.""" scored_matches = [] @@ -354,9 +348,7 @@ class EntityFuzzyMatcher: # Sort by score (highest first) and return return sorted(scored_matches, key=lambda x: x.score, reverse=True) - def _generate_entity_suggestion( - self, query: str, entity_types: list[EntityType], user - ) -> EntitySuggestion: + def _generate_entity_suggestion(self, query: str, entity_types: list[EntityType], user) -> EntitySuggestion: """Generate suggestion for creating new entity when no matches found.""" # Determine most likely entity type based on query characteristics @@ -364,14 +356,9 @@ class EntityFuzzyMatcher: # Simple heuristics for entity type detection query_lower = query.lower() - if any( - word in query_lower - for word in ["roller coaster", "ride", "coaster", "attraction"] - ): + if any(word in query_lower for word in ["roller coaster", "ride", "coaster", "attraction"]): suggested_type = EntityType.RIDE - elif any( - word in query_lower for word in ["inc", "corp", "company", "manufacturer"] - ): + elif any(word in query_lower for word in ["inc", "corp", "company", "manufacturer"]): suggested_type = EntityType.COMPANY elif EntityType.PARK in entity_types: suggested_type = EntityType.PARK @@ -382,21 +369,13 @@ class EntityFuzzyMatcher: suggested_name = " ".join(word.capitalize() for word in query.split()) # Check if user is authenticated - is_authenticated = ( - user and hasattr(user, "is_authenticated") and user.is_authenticated - ) + is_authenticated = user and hasattr(user, "is_authenticated") and user.is_authenticated # Generate appropriate prompts entity_name = suggested_type.value - login_prompt = ( - f"Log in to suggest adding '{suggested_name}' as a new {entity_name}" - ) - signup_prompt = ( - f"Sign up to contribute and add '{suggested_name}' to ThrillWiki" - ) - creation_hint = ( - f"Help expand ThrillWiki by adding information about '{suggested_name}'" - ) + login_prompt = f"Log in to suggest adding '{suggested_name}' as a new {entity_name}" + signup_prompt = f"Sign up to contribute and add '{suggested_name}' to ThrillWiki" + creation_hint = f"Help expand ThrillWiki by adding information about '{suggested_name}'" return EntitySuggestion( suggested_name=suggested_name, diff --git a/backend/apps/core/services/location_adapters.py b/backend/apps/core/services/location_adapters.py index 9ce6b433..02e2faee 100644 --- a/backend/apps/core/services/location_adapters.py +++ b/backend/apps/core/services/location_adapters.py @@ -2,7 +2,6 @@ Location adapters for converting between domain-specific models and UnifiedLocation. """ - from django.db.models import QuerySet from django.urls import reverse @@ -45,15 +44,9 @@ class BaseLocationAdapter: class ParkLocationAdapter(BaseLocationAdapter): """Converts Park/ParkLocation to UnifiedLocation.""" - def to_unified_location( - self, location_obj: ParkLocation - ) -> UnifiedLocation | None: + def to_unified_location(self, location_obj: ParkLocation) -> UnifiedLocation | None: """Convert ParkLocation to UnifiedLocation.""" - if ( - not location_obj.point - or location_obj.latitude is None - or location_obj.longitude is None - ): + if not location_obj.point or location_obj.latitude is None or location_obj.longitude is None: return None park = location_obj.park @@ -67,17 +60,11 @@ class ParkLocationAdapter(BaseLocationAdapter): metadata={ "status": getattr(park, "status", "UNKNOWN"), "rating": ( - float(park.average_rating) - if hasattr(park, "average_rating") and park.average_rating - else None + float(park.average_rating) if hasattr(park, "average_rating") and park.average_rating else None ), "ride_count": getattr(park, "ride_count", 0), "coaster_count": getattr(park, "coaster_count", 0), - "operator": ( - park.operator.name - if hasattr(park, "operator") and park.operator - else None - ), + "operator": (park.operator.name if hasattr(park, "operator") and park.operator else None), "city": location_obj.city, "state": location_obj.state, "country": location_obj.country, @@ -85,18 +72,14 @@ class ParkLocationAdapter(BaseLocationAdapter): type_data={ "slug": park.slug, "opening_date": ( - park.opening_date.isoformat() - if hasattr(park, "opening_date") and park.opening_date - else None + park.opening_date.isoformat() if hasattr(park, "opening_date") and park.opening_date else None ), "website": getattr(park, "website", ""), "operating_season": getattr(park, "operating_season", ""), "highway_exit": location_obj.highway_exit, "parking_notes": location_obj.parking_notes, "best_arrival_time": ( - location_obj.best_arrival_time.strftime("%H:%M") - if location_obj.best_arrival_time - else None + location_obj.best_arrival_time.strftime("%H:%M") if location_obj.best_arrival_time else None ), "seasonal_notes": location_obj.seasonal_notes, "url": self._get_park_url(park), @@ -111,9 +94,7 @@ class ParkLocationAdapter(BaseLocationAdapter): filters: MapFilters | None = None, ) -> QuerySet: """Get optimized queryset for park locations.""" - queryset = ParkLocation.objects.select_related("park", "park__operator").filter( - point__isnull=False - ) + queryset = ParkLocation.objects.select_related("park", "park__operator").filter(point__isnull=False) # Spatial filtering if bounds: @@ -139,17 +120,9 @@ class ParkLocationAdapter(BaseLocationAdapter): weight = 1 if hasattr(park, "ride_count") and park.ride_count and park.ride_count > 20: weight += 2 - if ( - hasattr(park, "coaster_count") - and park.coaster_count - and park.coaster_count > 5 - ): + if hasattr(park, "coaster_count") and park.coaster_count and park.coaster_count > 5: weight += 1 - if ( - hasattr(park, "average_rating") - and park.average_rating - and park.average_rating > 4.0 - ): + if hasattr(park, "average_rating") and park.average_rating and park.average_rating > 4.0: weight += 1 return min(weight, 5) # Cap at 5 @@ -176,15 +149,9 @@ class ParkLocationAdapter(BaseLocationAdapter): class RideLocationAdapter(BaseLocationAdapter): """Converts Ride/RideLocation to UnifiedLocation.""" - def to_unified_location( - self, location_obj: RideLocation - ) -> UnifiedLocation | None: + def to_unified_location(self, location_obj: RideLocation) -> UnifiedLocation | None: """Convert RideLocation to UnifiedLocation.""" - if ( - not location_obj.point - or location_obj.latitude is None - or location_obj.longitude is None - ): + if not location_obj.point or location_obj.latitude is None or location_obj.longitude is None: return None ride = location_obj.ride @@ -194,11 +161,7 @@ class RideLocationAdapter(BaseLocationAdapter): type=LocationType.RIDE, name=ride.name, coordinates=[float(location_obj.latitude), float(location_obj.longitude)], - address=( - f"{location_obj.park_area}, {ride.park.name}" - if location_obj.park_area - else ride.park.name - ), + address=(f"{location_obj.park_area}, {ride.park.name}" if location_obj.park_area else ride.park.name), metadata={ "park_id": ride.park.id, "park_name": ride.park.name, @@ -206,22 +169,16 @@ class RideLocationAdapter(BaseLocationAdapter): "ride_type": getattr(ride, "ride_type", "Unknown"), "status": getattr(ride, "status", "UNKNOWN"), "rating": ( - float(ride.average_rating) - if hasattr(ride, "average_rating") and ride.average_rating - else None + float(ride.average_rating) if hasattr(ride, "average_rating") and ride.average_rating else None ), "manufacturer": ( - getattr(ride, "manufacturer", {}).get("name") - if hasattr(ride, "manufacturer") - else None + getattr(ride, "manufacturer", {}).get("name") if hasattr(ride, "manufacturer") else None ), }, type_data={ "slug": ride.slug, "opening_date": ( - ride.opening_date.isoformat() - if hasattr(ride, "opening_date") and ride.opening_date - else None + ride.opening_date.isoformat() if hasattr(ride, "opening_date") and ride.opening_date else None ), "height_requirement": getattr(ride, "height_requirement", ""), "duration_minutes": getattr(ride, "duration_minutes", None), @@ -240,9 +197,9 @@ class RideLocationAdapter(BaseLocationAdapter): filters: MapFilters | None = None, ) -> QuerySet: """Get optimized queryset for ride locations.""" - queryset = RideLocation.objects.select_related( - "ride", "ride__park", "ride__park__operator" - ).filter(point__isnull=False) + queryset = RideLocation.objects.select_related("ride", "ride__park", "ride__park__operator").filter( + point__isnull=False + ) # Spatial filtering if bounds: @@ -263,11 +220,7 @@ class RideLocationAdapter(BaseLocationAdapter): ride_type = getattr(ride, "ride_type", "").lower() if "coaster" in ride_type or "roller" in ride_type: weight += 1 - if ( - hasattr(ride, "average_rating") - and ride.average_rating - and ride.average_rating > 4.0 - ): + if hasattr(ride, "average_rating") and ride.average_rating and ride.average_rating > 4.0: weight += 1 return min(weight, 3) # Cap at 3 for rides @@ -292,9 +245,7 @@ class RideLocationAdapter(BaseLocationAdapter): class CompanyLocationAdapter(BaseLocationAdapter): """Converts Company/CompanyHeadquarters to UnifiedLocation.""" - def to_unified_location( - self, location_obj: CompanyHeadquarters - ) -> UnifiedLocation | None: + def to_unified_location(self, location_obj: CompanyHeadquarters) -> UnifiedLocation | None: """Convert CompanyHeadquarters to UnifiedLocation.""" # Note: CompanyHeadquarters doesn't have coordinates, so we need to geocode # For now, we'll skip companies without coordinates @@ -312,13 +263,9 @@ class CompanyLocationAdapter(BaseLocationAdapter): # Company-specific filters if filters: if filters.company_roles: - queryset = queryset.filter( - company__roles__overlap=filters.company_roles - ) + queryset = queryset.filter(company__roles__overlap=filters.company_roles) if filters.search_query: - queryset = queryset.filter( - company__name__icontains=filters.search_query - ) + queryset = queryset.filter(company__name__icontains=filters.search_query) if filters.country: queryset = queryset.filter(country=filters.country) if filters.city: @@ -354,11 +301,7 @@ class LocationAbstractionLayer: all_locations = [] # Determine which location types to include - location_types = ( - filters.location_types - if filters and filters.location_types - else set(LocationType) - ) + location_types = filters.location_types if filters and filters.location_types else set(LocationType) for location_type in location_types: adapter = self.adapters[location_type] @@ -379,25 +322,17 @@ class LocationAbstractionLayer: queryset = adapter.get_queryset(bounds, filters) return adapter.bulk_convert(queryset) - def get_location_by_id( - self, location_type: LocationType, location_id: int - ) -> UnifiedLocation | None: + def get_location_by_id(self, location_type: LocationType, location_id: int) -> UnifiedLocation | None: """Get single location with full details.""" adapter = self.adapters[location_type] try: if location_type == LocationType.PARK: - obj = ParkLocation.objects.select_related("park", "park__operator").get( - park_id=location_id - ) + obj = ParkLocation.objects.select_related("park", "park__operator").get(park_id=location_id) elif location_type == LocationType.RIDE: - obj = RideLocation.objects.select_related("ride", "ride__park").get( - ride_id=location_id - ) + obj = RideLocation.objects.select_related("ride", "ride__park").get(ride_id=location_id) elif location_type == LocationType.COMPANY: - obj = CompanyHeadquarters.objects.select_related("company").get( - company_id=location_id - ) + obj = CompanyHeadquarters.objects.select_related("company").get(company_id=location_id) # LocationType.GENERIC removed - generic location app deprecated else: return None diff --git a/backend/apps/core/services/location_search.py b/backend/apps/core/services/location_search.py index a9f05b29..6c831e39 100644 --- a/backend/apps/core/services/location_search.py +++ b/backend/apps/core/services/location_search.py @@ -128,9 +128,7 @@ class LocationSearchService: # Apply max results limit return results[: filters.max_results] - def _search_parks( - self, filters: LocationSearchFilters - ) -> list[LocationSearchResult]: + def _search_parks(self, filters: LocationSearchFilters) -> list[LocationSearchResult]: """Search parks with location data.""" queryset = Park.objects.select_related("location", "operator").all() @@ -154,9 +152,9 @@ class LocationSearchService: # Add distance annotation if proximity search if filters.location_point and filters.include_distance: - queryset = queryset.annotate( - distance=Distance("location__point", filters.location_point) - ).order_by("distance") + queryset = queryset.annotate(distance=Distance("location__point", filters.location_point)).order_by( + "distance" + ) # Convert to search results results = [] @@ -166,11 +164,7 @@ class LocationSearchService: object_id=park.id, name=park.name, description=park.description, - url=( - park.get_absolute_url() - if hasattr(park, "get_absolute_url") - else None - ), + url=(park.get_absolute_url() if hasattr(park, "get_absolute_url") else None), status=park.get_status_display(), rating=(float(park.average_rating) if park.average_rating else None), tags=["park", park.status.lower()], @@ -187,20 +181,14 @@ class LocationSearchService: result.country = location.country # Add distance if proximity search - if ( - filters.location_point - and filters.include_distance - and hasattr(park, "distance") - ): + if filters.location_point and filters.include_distance and hasattr(park, "distance"): result.distance_km = float(park.distance.km) results.append(result) return results - def _search_rides( - self, filters: LocationSearchFilters - ) -> list[LocationSearchResult]: + def _search_rides(self, filters: LocationSearchFilters) -> list[LocationSearchResult]: """Search rides with location data.""" queryset = Ride.objects.select_related("park", "location").all() @@ -223,9 +211,9 @@ class LocationSearchService: # Add distance annotation if proximity search if filters.location_point and filters.include_distance: - queryset = queryset.annotate( - distance=Distance("location__point", filters.location_point) - ).order_by("distance") + queryset = queryset.annotate(distance=Distance("location__point", filters.location_point)).order_by( + "distance" + ) # Convert to search results results = [] @@ -235,11 +223,7 @@ class LocationSearchService: object_id=ride.id, name=ride.name, description=ride.description, - url=( - ride.get_absolute_url() - if hasattr(ride, "get_absolute_url") - else None - ), + url=(ride.get_absolute_url() if hasattr(ride, "get_absolute_url") else None), status=ride.status, tags=[ "ride", @@ -253,18 +237,10 @@ class LocationSearchService: location = ride.location result.latitude = location.latitude result.longitude = location.longitude - result.address = ( - f"{ride.park.name} - {location.park_area}" - if location.park_area - else ride.park.name - ) + result.address = f"{ride.park.name} - {location.park_area}" if location.park_area else ride.park.name # Add distance if proximity search - if ( - filters.location_point - and filters.include_distance - and hasattr(ride, "distance") - ): + if filters.location_point and filters.include_distance and hasattr(ride, "distance"): result.distance_km = float(ride.distance.km) # Fall back to park location if no specific ride location @@ -281,16 +257,12 @@ class LocationSearchService: return results - def _search_companies( - self, filters: LocationSearchFilters - ) -> list[LocationSearchResult]: + def _search_companies(self, filters: LocationSearchFilters) -> list[LocationSearchResult]: """Search companies with headquarters location data.""" queryset = Company.objects.select_related("headquarters").all() # Apply location filters - queryset = self._apply_location_filters( - queryset, filters, "headquarters__point" - ) + queryset = self._apply_location_filters(queryset, filters, "headquarters__point") # Apply text search if filters.search_query: @@ -309,9 +281,9 @@ class LocationSearchService: # Add distance annotation if proximity search if filters.location_point and filters.include_distance: - queryset = queryset.annotate( - distance=Distance("headquarters__point", filters.location_point) - ).order_by("distance") + queryset = queryset.annotate(distance=Distance("headquarters__point", filters.location_point)).order_by( + "distance" + ) # Convert to search results results = [] @@ -321,11 +293,7 @@ class LocationSearchService: object_id=company.id, name=company.name, description=company.description, - url=( - company.get_absolute_url() - if hasattr(company, "get_absolute_url") - else None - ), + url=(company.get_absolute_url() if hasattr(company, "get_absolute_url") else None), tags=["company"] + (company.roles or []), ) @@ -340,20 +308,14 @@ class LocationSearchService: result.country = hq.country # Add distance if proximity search - if ( - filters.location_point - and filters.include_distance - and hasattr(company, "distance") - ): + if filters.location_point and filters.include_distance and hasattr(company, "distance"): result.distance_km = float(company.distance.km) results.append(result) return results - def _apply_location_filters( - self, queryset, filters: LocationSearchFilters, point_field: str - ): + def _apply_location_filters(self, queryset, filters: LocationSearchFilters, point_field: str): """Apply common location filters to a queryset.""" # Proximity filter @@ -371,31 +333,21 @@ class LocationSearchService: # Geographic filters - adjust field names based on model if filters.country: if "headquarters" in point_field: - queryset = queryset.filter( - headquarters__country__icontains=filters.country - ) + queryset = queryset.filter(headquarters__country__icontains=filters.country) else: location_field = point_field.split("__")[0] - queryset = queryset.filter( - **{f"{location_field}__country__icontains": filters.country} - ) + queryset = queryset.filter(**{f"{location_field}__country__icontains": filters.country}) if filters.state: if "headquarters" in point_field: - queryset = queryset.filter( - headquarters__state_province__icontains=filters.state - ) + queryset = queryset.filter(headquarters__state_province__icontains=filters.state) else: location_field = point_field.split("__")[0] - queryset = queryset.filter( - **{f"{location_field}__state__icontains": filters.state} - ) + queryset = queryset.filter(**{f"{location_field}__state__icontains": filters.state}) if filters.city: location_field = point_field.split("__")[0] - queryset = queryset.filter( - **{f"{location_field}__city__icontains": filters.city} - ) + queryset = queryset.filter(**{f"{location_field}__city__icontains": filters.city}) return queryset @@ -417,9 +369,7 @@ class LocationSearchService: # Get park location suggestions park_locations = ParkLocation.objects.filter( - Q(park__name__icontains=query) - | Q(city__icontains=query) - | Q(state__icontains=query) + Q(park__name__icontains=query) | Q(city__icontains=query) | Q(state__icontains=query) ).select_related("park")[: limit // 3] for location in park_locations: @@ -429,11 +379,7 @@ class LocationSearchService: "name": location.park.name, "address": location.formatted_address, "coordinates": location.coordinates, - "url": ( - location.park.get_absolute_url() - if hasattr(location.park, "get_absolute_url") - else None - ), + "url": (location.park.get_absolute_url() if hasattr(location.park, "get_absolute_url") else None), } ) diff --git a/backend/apps/core/services/map_cache_service.py b/backend/apps/core/services/map_cache_service.py index 61e57e8f..08f687ef 100644 --- a/backend/apps/core/services/map_cache_service.py +++ b/backend/apps/core/services/map_cache_service.py @@ -93,9 +93,7 @@ class MapCacheService: return ":".join(key_parts) - def get_location_detail_cache_key( - self, location_type: str, location_id: int - ) -> str: + def get_location_detail_cache_key(self, location_type: str, location_id: int) -> str: """Generate cache key for individual location details.""" return f"{self.DETAIL_PREFIX}:{location_type}:{location_id}" @@ -137,9 +135,7 @@ class MapCacheService: except Exception as e: print(f"Cache write error for clusters {cache_key}: {e}") - def cache_map_response( - self, cache_key: str, response: MapResponse, ttl: int | None = None - ) -> None: + def cache_map_response(self, cache_key: str, response: MapResponse, ttl: int | None = None) -> None: """Cache complete map response.""" try: cache_data = response.to_dict() @@ -212,24 +208,18 @@ class MapCacheService: self.cache_stats["misses"] += 1 return None - def invalidate_location_cache( - self, location_type: str, location_id: int | None = None - ) -> None: + def invalidate_location_cache(self, location_type: str, location_id: int | None = None) -> None: """Invalidate cache for specific location or all locations of a type.""" try: if location_id: # Invalidate specific location detail - detail_key = self.get_location_detail_cache_key( - location_type, location_id - ) + detail_key = self.get_location_detail_cache_key(location_type, location_id) cache.delete(detail_key) # Invalidate related location and cluster caches # In a production system, you'd want more sophisticated cache # tagging - cache.delete_many( - [f"{self.LOCATIONS_PREFIX}:*", f"{self.CLUSTERS_PREFIX}:*"] - ) + cache.delete_many([f"{self.LOCATIONS_PREFIX}:*", f"{self.CLUSTERS_PREFIX}:*"]) self.cache_stats["invalidations"] += 1 @@ -271,11 +261,7 @@ class MapCacheService: def get_cache_stats(self) -> dict[str, Any]: """Get cache performance statistics.""" total_requests = self.cache_stats["hits"] + self.cache_stats["misses"] - hit_rate = ( - (self.cache_stats["hits"] / total_requests * 100) - if total_requests > 0 - else 0 - ) + hit_rate = (self.cache_stats["hits"] / total_requests * 100) if total_requests > 0 else 0 return { "hits": self.cache_stats["hits"], @@ -408,12 +394,8 @@ class MapCacheService: def _dict_to_map_response(self, data: dict[str, Any]) -> MapResponse: """Convert dictionary back to MapResponse object.""" - locations = [ - self._dict_to_unified_location(loc) for loc in data.get("locations", []) - ] - clusters = [ - self._dict_to_cluster_data(cluster) for cluster in data.get("clusters", []) - ] + locations = [self._dict_to_unified_location(loc) for loc in data.get("locations", [])] + clusters = [self._dict_to_cluster_data(cluster) for cluster in data.get("clusters", [])] bounds = None if data.get("bounds"): diff --git a/backend/apps/core/services/map_service.py b/backend/apps/core/services/map_service.py index 50cbf473..5de767e3 100644 --- a/backend/apps/core/services/map_service.py +++ b/backend/apps/core/services/map_service.py @@ -67,17 +67,13 @@ class UnifiedMapService: # Generate cache key cache_key = None if use_cache: - cache_key = self._generate_cache_key( - bounds, filters, zoom_level, cluster - ) + cache_key = self._generate_cache_key(bounds, filters, zoom_level, cluster) # Try to get from cache first cached_response = self.cache_service.get_cached_map_response(cache_key) if cached_response: cached_response.cache_hit = True - cached_response.query_time_ms = int( - (time.time() - start_time) * 1000 - ) + cached_response.query_time_ms = int((time.time() - start_time) * 1000) return cached_response # Get locations from database @@ -87,21 +83,15 @@ class UnifiedMapService: locations = self._apply_smart_limiting(locations, bounds, zoom_level) # Determine if clustering should be applied - should_cluster = cluster and self.clustering_service.should_cluster( - zoom_level, len(locations) - ) + should_cluster = cluster and self.clustering_service.should_cluster(zoom_level, len(locations)) # Apply clustering if needed clusters = [] if should_cluster: - locations, clusters = self.clustering_service.cluster_locations( - locations, zoom_level, bounds - ) + locations, clusters = self.clustering_service.cluster_locations(locations, zoom_level, bounds) # Calculate response bounds - response_bounds = self._calculate_response_bounds( - locations, clusters, bounds - ) + response_bounds = self._calculate_response_bounds(locations, clusters, bounds) # Create response response = MapResponse( @@ -144,9 +134,7 @@ class UnifiedMapService: cache_hit=False, ) - def get_location_details( - self, location_type: str, location_id: int - ) -> UnifiedLocation | None: + def get_location_details(self, location_type: str, location_id: int) -> UnifiedLocation | None: """ Get detailed information for a specific location. @@ -159,18 +147,14 @@ class UnifiedMapService: """ try: # Check cache first - cache_key = self.cache_service.get_location_detail_cache_key( - location_type, location_id - ) + cache_key = self.cache_service.get_location_detail_cache_key(location_type, location_id) cached_locations = self.cache_service.get_cached_locations(cache_key) if cached_locations: return cached_locations[0] if cached_locations else None # Get from database location_type_enum = LocationType(location_type.lower()) - location = self.location_layer.get_location_by_id( - location_type_enum, location_id - ) + location = self.location_layer.get_location_by_id(location_type_enum, location_id) # Cache the result if location: @@ -245,19 +229,13 @@ class UnifiedMapService: """ try: bounds = GeoBounds(north=north, south=south, east=east, west=west) - filters = ( - MapFilters(location_types=location_types) if location_types else None - ) + filters = MapFilters(location_types=location_types) if location_types else None - return self.get_map_data( - bounds=bounds, filters=filters, zoom_level=zoom_level - ) + return self.get_map_data(bounds=bounds, filters=filters, zoom_level=zoom_level) except ValueError: # Invalid bounds - return MapResponse( - locations=[], clusters=[], total_count=0, filtered_count=0 - ) + return MapResponse(locations=[], clusters=[], total_count=0, filtered_count=0) def get_clustered_locations( self, @@ -276,9 +254,7 @@ class UnifiedMapService: Returns: MapResponse with clustered data """ - return self.get_map_data( - bounds=bounds, filters=filters, zoom_level=zoom_level, cluster=True - ) + return self.get_map_data(bounds=bounds, filters=filters, zoom_level=zoom_level, cluster=True) def get_locations_by_type( self, @@ -299,9 +275,7 @@ class UnifiedMapService: """ try: filters = MapFilters(location_types={location_type}) - locations = self.location_layer.get_locations_by_type( - location_type, bounds, filters - ) + locations = self.location_layer.get_locations_by_type(location_type, bounds, filters) if limit: locations = locations[:limit] @@ -346,9 +320,7 @@ class UnifiedMapService: "service_version": "1.0.0", } - def _get_locations_from_db( - self, bounds: GeoBounds | None, filters: MapFilters | None - ) -> list[UnifiedLocation]: + def _get_locations_from_db(self, bounds: GeoBounds | None, filters: MapFilters | None) -> list[UnifiedLocation]: """Get locations from database using the abstraction layer.""" return self.location_layer.get_all_locations(bounds, filters) @@ -363,10 +335,7 @@ class UnifiedMapService: major_parks = [ loc for loc in locations - if ( - loc.type == LocationType.PARK - and loc.cluster_category in ["major_park", "theme_park"] - ) + if (loc.type == LocationType.PARK and loc.cluster_category in ["major_park", "theme_park"]) ] return major_parks[:200] elif zoom_level < 10: # Regional level @@ -398,9 +367,7 @@ class UnifiedMapService: return None lats, lngs = zip(*all_coords, strict=False) - return GeoBounds( - north=max(lats), south=min(lats), east=max(lngs), west=min(lngs) - ) + return GeoBounds(north=max(lats), south=min(lats), east=max(lngs), west=min(lngs)) def _get_applied_filters_list(self, filters: MapFilters | None) -> list[str]: """Get list of applied filter types for metadata.""" @@ -438,13 +405,9 @@ class UnifiedMapService: ) -> str: """Generate cache key for the request.""" if cluster: - return self.cache_service.get_clusters_cache_key( - bounds, filters, zoom_level - ) + return self.cache_service.get_clusters_cache_key(bounds, filters, zoom_level) else: - return self.cache_service.get_locations_cache_key( - bounds, filters, zoom_level - ) + return self.cache_service.get_locations_cache_key(bounds, filters, zoom_level) def _record_performance_metrics( self, diff --git a/backend/apps/core/services/media_service.py b/backend/apps/core/services/media_service.py index 8b198468..73ed37b1 100644 --- a/backend/apps/core/services/media_service.py +++ b/backend/apps/core/services/media_service.py @@ -21,9 +21,7 @@ class MediaService: """Shared service for media upload and processing operations.""" @staticmethod - def generate_upload_path( - domain: str, identifier: str, filename: str, subdirectory: str | None = None - ) -> str: + def generate_upload_path(domain: str, identifier: str, filename: str, subdirectory: str | None = None) -> str: """ Generate standardized upload path for media files. @@ -83,9 +81,7 @@ class MediaService: """ try: # Check file size - max_size = getattr( - settings, "MAX_PHOTO_SIZE", 10 * 1024 * 1024 - ) # 10MB default + max_size = getattr(settings, "MAX_PHOTO_SIZE", 10 * 1024 * 1024) # 10MB default if image_file.size > max_size: return ( False, diff --git a/backend/apps/core/services/media_url_service.py b/backend/apps/core/services/media_url_service.py index 6c8b0c75..9e2a32f4 100644 --- a/backend/apps/core/services/media_url_service.py +++ b/backend/apps/core/services/media_url_service.py @@ -32,7 +32,7 @@ class MediaURLService: slug = slugify(caption) # Limit length to avoid overly long URLs if len(slug) > 50: - slug = slug[:50].rsplit('-', 1)[0] # Cut at word boundary + slug = slug[:50].rsplit("-", 1)[0] # Cut at word boundary return f"{slug}-{photo_id}.{extension}" else: return f"photo-{photo_id}.{extension}" @@ -55,13 +55,15 @@ class MediaURLService: # Add variant to filename if not public if variant != "public": - name, ext = filename.rsplit('.', 1) + name, ext = filename.rsplit(".", 1) filename = f"{name}-{variant}.{ext}" return f"/parks/{park_slug}/photos/{filename}" @staticmethod - def generate_ride_photo_url(park_slug: str, ride_slug: str, caption: str, photo_id: int, variant: str = "public") -> str: + def generate_ride_photo_url( + park_slug: str, ride_slug: str, caption: str, photo_id: int, variant: str = "public" + ) -> str: """ Generate a friendly URL for a ride photo. @@ -78,7 +80,7 @@ class MediaURLService: filename = MediaURLService.generate_friendly_filename(caption, photo_id) if variant != "public": - name, ext = filename.rsplit('.', 1) + name, ext = filename.rsplit(".", 1) filename = f"{name}-{variant}.{ext}" return f"/parks/{park_slug}/rides/{ride_slug}/photos/{filename}" @@ -95,7 +97,7 @@ class MediaURLService: Dict with photo_id and variant, or None if parsing fails """ # Remove extension - name = filename.rsplit('.', 1)[0] + name = filename.rsplit(".", 1)[0] # Check for variant suffix variant = "public" @@ -104,17 +106,14 @@ class MediaURLService: for v in variant_patterns: if name.endswith(f"-{v}"): variant = v - name = name[:-len(f"-{v}")] + name = name[: -len(f"-{v}")] break # Extract photo ID (should be the last number) - match = re.search(r'-(\d+)$', name) + match = re.search(r"-(\d+)$", name) if match: photo_id = int(match.group(1)) - return { - "photo_id": photo_id, - "variant": variant - } + return {"photo_id": photo_id, "variant": variant} return None diff --git a/backend/apps/core/services/performance_monitoring.py b/backend/apps/core/services/performance_monitoring.py index 0a390b92..60379872 100644 --- a/backend/apps/core/services/performance_monitoring.py +++ b/backend/apps/core/services/performance_monitoring.py @@ -53,9 +53,7 @@ def monitor_performance(operation_name: str, **tags): ) # Log performance data - log_level = ( - logging.WARNING if duration > 2.0 or total_queries > 10 else logging.INFO - ) + log_level = logging.WARNING if duration > 2.0 or total_queries > 10 else logging.INFO logger.log( log_level, f"Performance: {operation_name} completed in {duration:.3f}s with { @@ -108,11 +106,7 @@ def track_queries(operation_name: str, warn_threshold: int = 10): recent_queries = connection.queries[-total_queries:] query_details = [ { - "sql": ( - query["sql"][:200] + "..." - if len(query["sql"]) > 200 - else query["sql"] - ), + "sql": (query["sql"][:200] + "..." if len(query["sql"]) > 200 else query["sql"]), "time": float(query["time"]), } for query in recent_queries @@ -127,14 +121,12 @@ def track_queries(operation_name: str, warn_threshold: int = 10): if total_queries > warn_threshold or execution_time > 1.0: logger.warning( - f"Performance concern in {operation_name}: " - f"{total_queries} queries, {execution_time:.2f}s", + f"Performance concern in {operation_name}: " f"{total_queries} queries, {execution_time:.2f}s", extra=performance_data, ) else: logger.debug( - f"Query tracking for {operation_name}: " - f"{total_queries} queries, {execution_time:.2f}s", + f"Query tracking for {operation_name}: " f"{total_queries} queries, {execution_time:.2f}s", extra=performance_data, ) @@ -221,9 +213,7 @@ class PerformanceProfiler: "total_queries": total_queries, "checkpoints": self.checkpoints, "memory_usage": self.memory_usage, - "queries_per_second": ( - total_queries / total_duration if total_duration > 0 else 0 - ), + "queries_per_second": (total_queries / total_duration if total_duration > 0 else 0), } # Calculate checkpoint intervals @@ -237,8 +227,7 @@ class PerformanceProfiler: "from": prev["name"], "to": curr["name"], "duration": curr["elapsed_seconds"] - prev["elapsed_seconds"], - "queries": curr["queries_since_start"] - - prev["queries_since_start"], + "queries": curr["queries_since_start"] - prev["queries_since_start"], } ) report["checkpoint_intervals"] = intervals @@ -288,9 +277,7 @@ class DatabaseQueryAnalyzer: query_types[query_type] = query_types.get(query_type, 0) + 1 # Find slow queries (top 10% by time) - sorted_queries = sorted( - queries, key=lambda q: float(q.get("time", 0)), reverse=True - ) + sorted_queries = sorted(queries, key=lambda q: float(q.get("time", 0)), reverse=True) slow_query_count = max(1, query_count // 10) slow_queries = sorted_queries[:slow_query_count] @@ -302,9 +289,7 @@ class DatabaseQueryAnalyzer: signature = " ".join(sql.split()) # Normalize whitespace query_signatures[signature] = query_signatures.get(signature, 0) + 1 - duplicates = { - sig: count for sig, count in query_signatures.items() if count > 1 - } + duplicates = {sig: count for sig, count in query_signatures.items() if count > 1} analysis = { "total_queries": query_count, @@ -313,21 +298,13 @@ class DatabaseQueryAnalyzer: "query_types": query_types, "slow_queries": [ { - "sql": ( - q.get("sql", "")[:200] + "..." - if len(q.get("sql", "")) > 200 - else q.get("sql", "") - ), + "sql": (q.get("sql", "")[:200] + "..." if len(q.get("sql", "")) > 200 else q.get("sql", "")), "time": float(q.get("time", 0)), } for q in slow_queries ], "duplicate_query_count": len(duplicates), - "duplicate_queries": ( - duplicates - if len(duplicates) <= 10 - else dict(list(duplicates.items())[:10]) - ), + "duplicate_queries": (duplicates if len(duplicates) <= 10 else dict(list(duplicates.items())[:10])), } return analysis @@ -348,9 +325,7 @@ def monitor_function_performance(operation_name: str | None = None): @wraps(func) def wrapper(*args, **kwargs): name = operation_name or f"{func.__module__}.{func.__name__}" - with monitor_performance( - name, function=func.__name__, module=func.__module__ - ): + with monitor_performance(name, function=func.__name__, module=func.__module__): return func(*args, **kwargs) return wrapper diff --git a/backend/apps/core/services/trending_service.py b/backend/apps/core/services/trending_service.py index 6607a2bd..efc5789d 100644 --- a/backend/apps/core/services/trending_service.py +++ b/backend/apps/core/services/trending_service.py @@ -74,9 +74,7 @@ class TrendingService: if not force_refresh: cached_result = cache.get(cache_key) if cached_result is not None: - self.logger.debug( - f"Returning cached trending results for {content_type}" - ) + self.logger.debug(f"Returning cached trending results for {content_type}") return cached_result self.logger.info(f"Getting trending content for {content_type}") @@ -86,15 +84,11 @@ class TrendingService: trending_items = [] if content_type in ["all", "parks"]: - park_items = self._calculate_trending_parks( - limit * 2 if content_type == "all" else limit - ) + park_items = self._calculate_trending_parks(limit * 2 if content_type == "all" else limit) trending_items.extend(park_items) if content_type in ["all", "rides"]: - ride_items = self._calculate_trending_rides( - limit * 2 if content_type == "all" else limit - ) + ride_items = self._calculate_trending_rides(limit * 2 if content_type == "all" else limit) trending_items.extend(ride_items) # Sort by trending score and apply limit @@ -107,9 +101,7 @@ class TrendingService: # Cache results cache.set(cache_key, formatted_results, self.CACHE_TTL) - self.logger.info( - f"Calculated {len(formatted_results)} trending items for {content_type}" - ) + self.logger.info(f"Calculated {len(formatted_results)} trending items for {content_type}") return formatted_results except Exception as e: @@ -140,9 +132,7 @@ class TrendingService: if not force_refresh: cached_result = cache.get(cache_key) if cached_result is not None: - self.logger.debug( - f"Returning cached new content results for {content_type}" - ) + self.logger.debug(f"Returning cached new content results for {content_type}") return cached_result self.logger.info(f"Getting new content for {content_type}") @@ -153,15 +143,11 @@ class TrendingService: new_items = [] if content_type in ["all", "parks"]: - parks = self._get_new_parks( - cutoff_date, limit * 2 if content_type == "all" else limit - ) + parks = self._get_new_parks(cutoff_date, limit * 2 if content_type == "all" else limit) new_items.extend(parks) if content_type in ["all", "rides"]: - rides = self._get_new_rides( - cutoff_date, limit * 2 if content_type == "all" else limit - ) + rides = self._get_new_rides(cutoff_date, limit * 2 if content_type == "all" else limit) new_items.extend(rides) # Sort by date added (most recent first) and apply limit @@ -174,9 +160,7 @@ class TrendingService: # Cache results cache.set(cache_key, formatted_results, 1800) # Cache for 30 minutes - self.logger.info( - f"Calculated {len(formatted_results)} new items for {content_type}" - ) + self.logger.info(f"Calculated {len(formatted_results)} new items for {content_type}") return formatted_results except Exception as e: @@ -185,9 +169,7 @@ class TrendingService: def _calculate_trending_parks(self, limit: int) -> list[dict[str, Any]]: """Calculate trending scores for parks.""" - parks = Park.objects.filter(status="OPERATING").select_related( - "location", "operator", "card_image" - ) + parks = Park.objects.filter(status="OPERATING").select_related("location", "operator", "card_image") trending_parks = [] @@ -216,9 +198,7 @@ class TrendingService: # Get card image URL card_image_url = "" if park.card_image and hasattr(park.card_image, "image"): - card_image_url = ( - park.card_image.image.url if park.card_image.image else "" - ) + card_image_url = park.card_image.image.url if park.card_image.image else "" # Get primary company (operator) primary_company = park.operator.name if park.operator else "" @@ -233,14 +213,8 @@ class TrendingService: "slug": park.slug, "park": park.name, # For parks, park field is the park name itself "category": "park", - "rating": ( - float(park.average_rating) - if park.average_rating - else 0.0 - ), - "date_opened": ( - opening_date.isoformat() if opening_date else "" - ), + "rating": (float(park.average_rating) if park.average_rating else 0.0), + "date_opened": (opening_date.isoformat() if opening_date else ""), "url": park.url, "card_image": card_image_url, "city": city, @@ -256,9 +230,7 @@ class TrendingService: def _calculate_trending_rides(self, limit: int) -> list[dict[str, Any]]: """Calculate trending scores for rides.""" - rides = Ride.objects.filter(status="OPERATING").select_related( - "park", "park__location", "card_image" - ) + rides = Ride.objects.filter(status="OPERATING").select_related("park", "park__location", "card_image") trending_rides = [] @@ -274,9 +246,7 @@ class TrendingService: # Get card image URL card_image_url = "" if ride.card_image and hasattr(ride.card_image, "image"): - card_image_url = ( - ride.card_image.image.url if ride.card_image.image else "" - ) + card_image_url = ride.card_image.image.url if ride.card_image.image else "" trending_rides.append( { @@ -288,14 +258,8 @@ class TrendingService: "slug": ride.slug, "park": ride.park.name if ride.park else "", "category": "ride", - "rating": ( - float(ride.average_rating) - if ride.average_rating - else 0.0 - ), - "date_opened": ( - opening_date.isoformat() if opening_date else "" - ), + "rating": (float(ride.average_rating) if ride.average_rating else 0.0), + "date_opened": (opening_date.isoformat() if opening_date else ""), "url": ride.url, "park_url": ride.park.url if ride.park else "", "card_image": card_image_url, @@ -347,23 +311,17 @@ class TrendingService: return final_score except Exception as e: - self.logger.error( - f"Error calculating score for {content_type} {content_obj.id}: {e}" - ) + self.logger.error(f"Error calculating score for {content_type} {content_obj.id}: {e}") return 0.0 - def _calculate_view_growth_score( - self, content_type: ContentType, object_id: int - ) -> float: + def _calculate_view_growth_score(self, content_type: ContentType, object_id: int) -> float: """Calculate normalized view growth score.""" try: - current_views, previous_views, growth_percentage = ( - PageView.get_views_growth( - content_type, - object_id, - self.CURRENT_PERIOD_HOURS, - self.PREVIOUS_PERIOD_HOURS, - ) + current_views, previous_views, growth_percentage = PageView.get_views_growth( + content_type, + object_id, + self.CURRENT_PERIOD_HOURS, + self.PREVIOUS_PERIOD_HOURS, ) if previous_views == 0: @@ -372,9 +330,7 @@ class TrendingService: # Normalize growth percentage to 0-1 scale # 100% growth = 0.5, 500% growth = 1.0 - normalized_growth = ( - min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0 - ) + normalized_growth = min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0 return max(normalized_growth, 0.0) except Exception as e: @@ -421,11 +377,7 @@ class TrendingService: elif days_since_added <= 30: return 1.0 - (days_since_added / 30.0) * 0.2 # 1.0 to 0.8 elif days_since_added <= self.RECENCY_BASELINE_DAYS: - return ( - 0.8 - - ((days_since_added - 30) / (self.RECENCY_BASELINE_DAYS - 30)) - * 0.7 - ) # 0.8 to 0.1 + return 0.8 - ((days_since_added - 30) / (self.RECENCY_BASELINE_DAYS - 30)) * 0.7 # 0.8 to 0.1 else: return 0.0 @@ -433,9 +385,7 @@ class TrendingService: self.logger.warning(f"Error calculating recency score: {e}") return 0.5 - def _calculate_popularity_score( - self, content_type: ContentType, object_id: int - ) -> float: + def _calculate_popularity_score(self, content_type: ContentType, object_id: int) -> float: """Calculate popularity score based on total view count.""" try: total_views = PageView.get_total_views_count( @@ -461,8 +411,7 @@ class TrendingService: """Get recently added parks.""" new_parks = ( Park.objects.filter( - Q(created_at__gte=cutoff_date) - | Q(opening_date__gte=cutoff_date.date()), + Q(created_at__gte=cutoff_date) | Q(opening_date__gte=cutoff_date.date()), status="OPERATING", ) .select_related("location", "operator", "card_image") @@ -473,7 +422,7 @@ class TrendingService: for park in new_parks: date_added = park.opening_date or park.created_at # Handle datetime to date conversion - if date_added: + if date_added: # noqa: SIM102 # If it's a datetime, convert to date if isinstance(date_added, datetime): date_added = date_added.date() @@ -500,9 +449,7 @@ class TrendingService: # Get card image URL card_image_url = "" if park.card_image and hasattr(park.card_image, "image"): - card_image_url = ( - park.card_image.image.url if park.card_image.image else "" - ) + card_image_url = park.card_image.image.url if park.card_image.image else "" # Get primary company (operator) primary_company = park.operator.name if park.operator else "" @@ -533,8 +480,7 @@ class TrendingService: """Get recently added rides.""" new_rides = ( Ride.objects.filter( - Q(created_at__gte=cutoff_date) - | Q(opening_date__gte=cutoff_date.date()), + Q(created_at__gte=cutoff_date) | Q(opening_date__gte=cutoff_date.date()), status="OPERATING", ) .select_related("park", "park__location", "card_image") @@ -543,11 +489,9 @@ class TrendingService: results = [] for ride in new_rides: - date_added = getattr(ride, "opening_date", None) or getattr( - ride, "created_at", None - ) + date_added = getattr(ride, "opening_date", None) or getattr(ride, "created_at", None) # Handle datetime to date conversion - if date_added: + if date_added: # noqa: SIM102 # If it's a datetime, convert to date if isinstance(date_added, datetime): date_added = date_added.date() @@ -561,9 +505,7 @@ class TrendingService: # Get card image URL card_image_url = "" if ride.card_image and hasattr(ride.card_image, "image"): - card_image_url = ( - ride.card_image.image.url if ride.card_image.image else "" - ) + card_image_url = ride.card_image.image.url if ride.card_image.image else "" results.append( { @@ -584,9 +526,7 @@ class TrendingService: return results - def _format_trending_results( - self, trending_items: list[dict[str, Any]] - ) -> list[dict[str, Any]]: + def _format_trending_results(self, trending_items: list[dict[str, Any]]) -> list[dict[str, Any]]: """Format trending results for frontend consumption.""" formatted_results = [] @@ -595,13 +535,11 @@ class TrendingService: # Get view change for display content_obj = item["content_object"] ct = ContentType.objects.get_for_model(content_obj) - current_views, previous_views, growth_percentage = ( - PageView.get_views_growth( - ct, - content_obj.id, - self.CURRENT_PERIOD_HOURS, - self.PREVIOUS_PERIOD_HOURS, - ) + current_views, previous_views, growth_percentage = PageView.get_views_growth( + ct, + content_obj.id, + self.CURRENT_PERIOD_HOURS, + self.PREVIOUS_PERIOD_HOURS, ) # Format exactly as frontend expects @@ -614,9 +552,7 @@ class TrendingService: "rank": rank, "views": current_views, "views_change": ( - f"+{growth_percentage:.1f}%" - if growth_percentage > 0 - else f"{growth_percentage:.1f}%" + f"+{growth_percentage:.1f}%" if growth_percentage > 0 else f"{growth_percentage:.1f}%" ), "slug": item["slug"], "date_opened": item["date_opened"], @@ -649,9 +585,7 @@ class TrendingService: return formatted_results - def _format_new_content_results( - self, new_items: list[dict[str, Any]] - ) -> list[dict[str, Any]]: + def _format_new_content_results(self, new_items: list[dict[str, Any]]) -> list[dict[str, Any]]: """Format new content results for frontend consumption.""" formatted_results = [] diff --git a/backend/apps/core/state_machine/__init__.py b/backend/apps/core/state_machine/__init__.py index e11bc327..a02488a1 100644 --- a/backend/apps/core/state_machine/__init__.py +++ b/backend/apps/core/state_machine/__init__.py @@ -1,4 +1,5 @@ """State machine utilities for core app.""" + from .builder import ( StateTransitionBuilder, determine_method_name_for_transition, diff --git a/backend/apps/core/state_machine/builder.py b/backend/apps/core/state_machine/builder.py index 0fc8a714..a337af15 100644 --- a/backend/apps/core/state_machine/builder.py +++ b/backend/apps/core/state_machine/builder.py @@ -60,6 +60,7 @@ See Also: - apps.core.choices.registry: Central choice registry - apps.core.state_machine.guards: Guard extraction from metadata """ + from typing import Any from django.core.exceptions import ImproperlyConfigured @@ -129,9 +130,7 @@ class StateTransitionBuilder: # Validate choice group exists group = registry.get(choice_group, domain) if group is None: - raise ImproperlyConfigured( - f"Choice group '{choice_group}' not found in domain '{domain}'" - ) + raise ImproperlyConfigured(f"Choice group '{choice_group}' not found in domain '{domain}'") self.choices = registry.get_choices(choice_group, domain) @@ -172,20 +171,15 @@ class StateTransitionBuilder: # Validate all target states exist for target in transitions: - target_choice = registry.get_choice( - self.choice_group, target, self.domain - ) + target_choice = registry.get_choice(self.choice_group, target, self.domain) if target_choice is None: raise ImproperlyConfigured( - f"State '{state_value}' references non-existent " - f"transition target '{target}'" + f"State '{state_value}' references non-existent " f"transition target '{target}'" ) return transitions - def extract_permission_requirements( - self, state_value: str - ) -> dict[str, bool]: + def extract_permission_requirements(self, state_value: str) -> dict[str, bool]: """ Extract permission requirements from metadata. @@ -198,9 +192,7 @@ class StateTransitionBuilder: metadata = self.get_choice_metadata(state_value) return { "requires_moderator": metadata.get("requires_moderator", False), - "requires_admin_approval": metadata.get( - "requires_admin_approval", False - ), + "requires_admin_approval": metadata.get("requires_admin_approval", False), } def is_terminal_state(self, state_value: str) -> bool: diff --git a/backend/apps/core/state_machine/callback_base.py b/backend/apps/core/state_machine/callback_base.py index 6f5db649..6db5c90a 100644 --- a/backend/apps/core/state_machine/callback_base.py +++ b/backend/apps/core/state_machine/callback_base.py @@ -181,10 +181,7 @@ class TransitionContext: return self.model_class.__name__ def __str__(self) -> str: - return ( - f"TransitionContext({self.model_name}.{self.field_name}: " - f"{self.source_state} → {self.target_state})" - ) + return f"TransitionContext({self.model_name}.{self.field_name}: " f"{self.source_state} → {self.target_state})" class BaseTransitionCallback(ABC): @@ -324,9 +321,9 @@ class CallbackRegistration: return False if self.field_name != field_name: return False - if self.source != '*' and self.source != source: + if self.source != "*" and self.source != source: return False - return not (self.target != '*' and self.target != target) + return not (self.target != "*" and self.target != target) class TransitionCallbackRegistry: @@ -337,10 +334,10 @@ class TransitionCallbackRegistry: for specific transitions. """ - _instance: Optional['TransitionCallbackRegistry'] = None + _instance: Optional["TransitionCallbackRegistry"] = None _initialized: bool = False - def __new__(cls) -> 'TransitionCallbackRegistry': + def __new__(cls) -> "TransitionCallbackRegistry": if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance @@ -483,10 +480,7 @@ class TransitionCallbackRegistry: try: # Check if callback should execute if not callback.should_execute(context): - logger.debug( - f"Skipping callback {callback.name} - " - f"should_execute returned False" - ) + logger.debug(f"Skipping callback {callback.name} - " f"should_execute returned False") continue # Execute callback @@ -498,30 +492,24 @@ class TransitionCallbackRegistry: result = callback.execute(context) if not result: - logger.warning( - f"Callback {callback.name} returned False for {context}" - ) + logger.warning(f"Callback {callback.name} returned False for {context}") failures.append((callback, None)) overall_success = False if not callback.continue_on_error: logger.error( - f"Aborting callback chain - {callback.name} failed " - f"and continue_on_error=False" + f"Aborting callback chain - {callback.name} failed " f"and continue_on_error=False" ) break except Exception as e: - logger.exception( - f"Callback {callback.name} raised exception for {context}: {e}" - ) + logger.exception(f"Callback {callback.name} raised exception for {context}: {e}") failures.append((callback, e)) overall_success = False if not callback.continue_on_error: logger.error( - f"Aborting callback chain - {callback.name} raised exception " - f"and continue_on_error=False" + f"Aborting callback chain - {callback.name} raised exception " f"and continue_on_error=False" ) break @@ -540,10 +528,7 @@ class TransitionCallbackRegistry: self._callbacks[stage] = [] else: for stage in CallbackStage: - self._callbacks[stage] = [ - r for r in self._callbacks[stage] - if r.model_class != model_class - ] + self._callbacks[stage] = [r for r in self._callbacks[stage] if r.model_class != model_class] def get_all_registrations( self, @@ -563,10 +548,7 @@ class TransitionCallbackRegistry: filtered = {} for stage, registrations in self._callbacks.items(): - filtered[stage] = [ - r for r in registrations - if r.model_class == model_class - ] + filtered[stage] = [r for r in registrations if r.model_class == model_class] return filtered @classmethod @@ -601,9 +583,7 @@ def register_pre_callback( callback: PreTransitionCallback, ) -> None: """Convenience function to register a pre-transition callback.""" - callback_registry.register( - model_class, field_name, source, target, callback, CallbackStage.PRE - ) + callback_registry.register(model_class, field_name, source, target, callback, CallbackStage.PRE) def register_post_callback( @@ -614,9 +594,7 @@ def register_post_callback( callback: PostTransitionCallback, ) -> None: """Convenience function to register a post-transition callback.""" - callback_registry.register( - model_class, field_name, source, target, callback, CallbackStage.POST - ) + callback_registry.register(model_class, field_name, source, target, callback, CallbackStage.POST) def register_error_callback( @@ -627,6 +605,4 @@ def register_error_callback( callback: ErrorTransitionCallback, ) -> None: """Convenience function to register an error callback.""" - callback_registry.register( - model_class, field_name, source, target, callback, CallbackStage.ERROR - ) + callback_registry.register(model_class, field_name, source, target, callback, CallbackStage.ERROR) diff --git a/backend/apps/core/state_machine/callbacks/cache.py b/backend/apps/core/state_machine/callbacks/cache.py index 72958803..15ed85a3 100644 --- a/backend/apps/core/state_machine/callbacks/cache.py +++ b/backend/apps/core/state_machine/callbacks/cache.py @@ -44,8 +44,8 @@ class CacheInvalidationCallback(PostTransitionCallback): def should_execute(self, context: TransitionContext) -> bool: """Check if cache invalidation is enabled.""" - callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {}) - if not callback_settings.get('cache_invalidation_enabled', True): + callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {}) + if not callback_settings.get("cache_invalidation_enabled", True): logger.debug("Cache invalidation disabled via settings") return False return True @@ -54,6 +54,7 @@ class CacheInvalidationCallback(PostTransitionCallback): """Get the EnhancedCacheService instance.""" try: from apps.core.services.enhanced_cache_service import EnhancedCacheService + return EnhancedCacheService() except ImportError: logger.warning("EnhancedCacheService not available") @@ -85,11 +86,7 @@ class CacheInvalidationCallback(PostTransitionCallback): substituted = set() for pattern in all_patterns: - substituted.add( - pattern - .replace('{id}', instance_id) - .replace('{model}', model_name) - ) + substituted.add(pattern.replace("{id}", instance_id).replace("{model}", model_name)) return substituted @@ -108,20 +105,13 @@ class CacheInvalidationCallback(PostTransitionCallback): cache_service.invalidate_pattern(pattern) logger.debug(f"Invalidated cache pattern: {pattern}") except Exception as e: - logger.warning( - f"Failed to invalidate cache pattern {pattern}: {e}" - ) + logger.warning(f"Failed to invalidate cache pattern {pattern}: {e}") - logger.info( - f"Cache invalidation completed for {context}: " - f"{len(patterns)} patterns" - ) + logger.info(f"Cache invalidation completed for {context}: " f"{len(patterns)} patterns") return True except Exception as e: - logger.exception( - f"Failed to invalidate cache for {context}: {e}" - ) + logger.exception(f"Failed to invalidate cache for {context}: {e}") return False def _fallback_invalidation(self, context: TransitionContext) -> bool: @@ -133,8 +123,7 @@ class CacheInvalidationCallback(PostTransitionCallback): # Django's default cache doesn't support pattern deletion # Log a warning and return True (don't fail the transition) logger.warning( - f"EnhancedCacheService not available, skipping pattern " - f"invalidation for {len(patterns)} patterns" + f"EnhancedCacheService not available, skipping pattern " f"invalidation for {len(patterns)} patterns" ) return True @@ -155,13 +144,13 @@ class ModelCacheInvalidation(CacheInvalidationCallback): # Default patterns by model type MODEL_PATTERNS = { - 'Park': ['*park:{id}*', '*parks*', 'geo:*'], - 'Ride': ['*ride:{id}*', '*rides*', '*park:*', 'geo:*'], - 'EditSubmission': ['*submission:{id}*', '*moderation*'], - 'PhotoSubmission': ['*photo:{id}*', '*moderation*'], - 'ModerationReport': ['*report:{id}*', '*moderation*'], - 'ModerationQueue': ['*queue*', '*moderation*'], - 'BulkOperation': ['*operation:{id}*', '*moderation*'], + "Park": ["*park:{id}*", "*parks*", "geo:*"], + "Ride": ["*ride:{id}*", "*rides*", "*park:*", "geo:*"], + "EditSubmission": ["*submission:{id}*", "*moderation*"], + "PhotoSubmission": ["*photo:{id}*", "*moderation*"], + "ModerationReport": ["*report:{id}*", "*moderation*"], + "ModerationQueue": ["*queue*", "*moderation*"], + "BulkOperation": ["*operation:{id}*", "*moderation*"], } def __init__(self, **kwargs): @@ -178,7 +167,7 @@ class ModelCacheInvalidation(CacheInvalidationCallback): # Substitute {id} placeholder instance_id = str(context.instance.pk) for pattern in model_patterns: - base_patterns.append(pattern.replace('{id}', instance_id)) + base_patterns.append(pattern.replace("{id}", instance_id)) return base_patterns @@ -217,14 +206,14 @@ class RelatedModelCacheInvalidation(CacheInvalidationCallback): continue # Handle foreign key relationships - if hasattr(related_obj, 'pk'): + if hasattr(related_obj, "pk"): related_model = type(related_obj).__name__.lower() related_id = related_obj.pk patterns.append(f"*{related_model}:{related_id}*") patterns.append(f"*{related_model}_{related_id}*") # Handle many-to-many relationships - elif hasattr(related_obj, 'all'): + elif hasattr(related_obj, "all"): for obj in related_obj.all(): related_model = type(obj).__name__.lower() related_id = obj.pk @@ -293,7 +282,7 @@ class APICacheInvalidation(CacheInvalidationCallback): **kwargs: Additional arguments. """ super().__init__(**kwargs) - self.api_prefixes = api_prefixes or ['api:*'] + self.api_prefixes = api_prefixes or ["api:*"] self.include_geo_cache = include_geo_cache def _get_all_patterns(self, context: TransitionContext) -> set[str]: @@ -306,8 +295,8 @@ class APICacheInvalidation(CacheInvalidationCallback): # Add geo cache if requested if self.include_geo_cache: - patterns.add('geo:*') - patterns.add('map:*') + patterns.add("geo:*") + patterns.add("map:*") # Add model-specific API patterns model_name = context.model_name.lower() @@ -329,10 +318,10 @@ class ParkCacheInvalidation(CacheInvalidationCallback): def __init__(self, **kwargs): super().__init__( patterns=[ - '*park:{id}*', - '*parks*', - 'api:*', - 'geo:*', + "*park:{id}*", + "*parks*", + "api:*", + "geo:*", ], **kwargs, ) @@ -346,10 +335,10 @@ class RideCacheInvalidation(CacheInvalidationCallback): def __init__(self, **kwargs): super().__init__( patterns=[ - '*ride:{id}*', - '*rides*', - 'api:*', - 'geo:*', + "*ride:{id}*", + "*rides*", + "api:*", + "geo:*", ], **kwargs, ) @@ -359,9 +348,9 @@ class RideCacheInvalidation(CacheInvalidationCallback): patterns = super()._get_instance_patterns(context) # Invalidate parent park's cache - park = getattr(context.instance, 'park', None) + park = getattr(context.instance, "park", None) if park: - park_id = park.pk if hasattr(park, 'pk') else park + park_id = park.pk if hasattr(park, "pk") else park patterns.append(f"*park:{park_id}*") patterns.append(f"*park_{park_id}*") @@ -376,9 +365,9 @@ class ModerationCacheInvalidation(CacheInvalidationCallback): def __init__(self, **kwargs): super().__init__( patterns=[ - '*submission*', - '*moderation*', - 'api:moderation*', + "*submission*", + "*moderation*", + "api:moderation*", ], **kwargs, ) diff --git a/backend/apps/core/state_machine/callbacks/notifications.py b/backend/apps/core/state_machine/callbacks/notifications.py index 0a5e47a2..ef290460 100644 --- a/backend/apps/core/state_machine/callbacks/notifications.py +++ b/backend/apps/core/state_machine/callbacks/notifications.py @@ -53,17 +53,15 @@ class NotificationCallback(PostTransitionCallback): def should_execute(self, context: TransitionContext) -> bool: """Check if notifications are enabled and recipient exists.""" # Check if notifications are disabled in settings - callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {}) - if not callback_settings.get('notifications_enabled', True): + callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {}) + if not callback_settings.get("notifications_enabled", True): logger.debug("Notifications disabled via settings") return False # Check if recipient exists recipient = self._get_recipient(context.instance) if not recipient: - logger.debug( - f"No recipient found at {self.recipient_field} for {context}" - ) + logger.debug(f"No recipient found at {self.recipient_field} for {context}") return False return True @@ -76,6 +74,7 @@ class NotificationCallback(PostTransitionCallback): """Get the NotificationService instance.""" try: from apps.accounts.services.notification_service import NotificationService + return NotificationService() except ImportError: logger.warning("NotificationService not available") @@ -86,18 +85,16 @@ class NotificationCallback(PostTransitionCallback): extra_data = {} if self.include_transition_data: - extra_data['transition'] = { - 'source_state': context.source_state, - 'target_state': context.target_state, - 'field_name': context.field_name, - 'timestamp': context.timestamp.isoformat(), + extra_data["transition"] = { + "source_state": context.source_state, + "target_state": context.target_state, + "field_name": context.field_name, + "timestamp": context.timestamp.isoformat(), } if context.user: - extra_data['transition']['by_user_id'] = context.user.id - extra_data['transition']['by_username'] = getattr( - context.user, 'username', str(context.user) - ) + extra_data["transition"]["by_user_id"] = context.user.id + extra_data["transition"]["by_username"] = getattr(context.user, "username", str(context.user)) # Include any extra data from the context extra_data.update(context.extra_data) @@ -112,10 +109,7 @@ class NotificationCallback(PostTransitionCallback): def _get_notification_message(self, context: TransitionContext) -> str: """Get the notification message based on context.""" model_name = context.model_name - return ( - f"The {model_name} has transitioned from {context.source_state} " - f"to {context.target_state}." - ) + return f"The {model_name} has transitioned from {context.source_state} " f"to {context.target_state}." def execute(self, context: TransitionContext) -> bool: """Execute the notification callback.""" @@ -140,16 +134,11 @@ class NotificationCallback(PostTransitionCallback): extra_data=extra_data, ) - logger.info( - f"Created {self.notification_type} notification for " - f"{recipient} on {context}" - ) + logger.info(f"Created {self.notification_type} notification for " f"{recipient} on {context}") return True except Exception as e: - logger.exception( - f"Failed to create notification for {context}: {e}" - ) + logger.exception(f"Failed to create notification for {context}: {e}") return False @@ -176,8 +165,8 @@ class SubmissionApprovedNotification(NotificationCallback): def _get_submission_type(self, context: TransitionContext) -> str: """Get the submission type from context or instance.""" # Try to get from extra_data first - if 'submission_type' in context.extra_data: - return context.extra_data['submission_type'] + if "submission_type" in context.extra_data: + return context.extra_data["submission_type"] # Fall back to model name return self.submission_type or context.model_name.lower() @@ -193,10 +182,10 @@ class SubmissionApprovedNotification(NotificationCallback): try: submission_type = self._get_submission_type(context) - additional_message = context.extra_data.get('comment', '') + additional_message = context.extra_data.get("comment", "") # Use the specific method if available - if hasattr(notification_service, 'create_submission_approved_notification'): + if hasattr(notification_service, "create_submission_approved_notification"): notification_service.create_submission_approved_notification( user=recipient, submission_object=context.instance, @@ -215,15 +204,11 @@ class SubmissionApprovedNotification(NotificationCallback): extra_data=extra_data, ) - logger.info( - f"Created approval notification for {recipient} on {context}" - ) + logger.info(f"Created approval notification for {recipient} on {context}") return True except Exception as e: - logger.exception( - f"Failed to create approval notification for {context}: {e}" - ) + logger.exception(f"Failed to create approval notification for {context}: {e}") return False @@ -250,8 +235,8 @@ class SubmissionRejectedNotification(NotificationCallback): def _get_submission_type(self, context: TransitionContext) -> str: """Get the submission type from context or instance.""" # Try to get from extra_data first - if 'submission_type' in context.extra_data: - return context.extra_data['submission_type'] + if "submission_type" in context.extra_data: + return context.extra_data["submission_type"] # Fall back to model name return self.submission_type or context.model_name.lower() @@ -268,11 +253,11 @@ class SubmissionRejectedNotification(NotificationCallback): try: submission_type = self._get_submission_type(context) # Extract rejection reason from extra_data - rejection_reason = context.extra_data.get('reason', 'No reason provided') - additional_message = context.extra_data.get('comment', '') + rejection_reason = context.extra_data.get("reason", "No reason provided") + additional_message = context.extra_data.get("comment", "") # Use the specific method if available - if hasattr(notification_service, 'create_submission_rejected_notification'): + if hasattr(notification_service, "create_submission_rejected_notification"): notification_service.create_submission_rejected_notification( user=recipient, submission_object=context.instance, @@ -291,15 +276,11 @@ class SubmissionRejectedNotification(NotificationCallback): extra_data=extra_data, ) - logger.info( - f"Created rejection notification for {recipient} on {context}" - ) + logger.info(f"Created rejection notification for {recipient} on {context}") return True except Exception as e: - logger.exception( - f"Failed to create rejection notification for {context}: {e}" - ) + logger.exception(f"Failed to create rejection notification for {context}: {e}") return False @@ -326,6 +307,7 @@ class SubmissionEscalatedNotification(NotificationCallback): """Get admin users to notify.""" try: from django.contrib.auth import get_user_model + user_model = get_user_model() return user_model.objects.filter(is_staff=True, is_active=True) except Exception as e: @@ -340,9 +322,9 @@ class SubmissionEscalatedNotification(NotificationCallback): try: extra_data = self._build_extra_data(context) - escalation_reason = context.extra_data.get('reason', '') + escalation_reason = context.extra_data.get("reason", "") if escalation_reason: - extra_data['escalation_reason'] = escalation_reason + extra_data["escalation_reason"] = escalation_reason title = f"{context.model_name} escalated for review" message = f"A {context.model_name} has been escalated and requires attention." @@ -361,9 +343,7 @@ class SubmissionEscalatedNotification(NotificationCallback): related_object=context.instance, extra_data=extra_data, ) - logger.info( - f"Created escalation notifications for {admins.count()} admins" - ) + logger.info(f"Created escalation notifications for {admins.count()} admins") else: # Notify the submitter recipient = self._get_recipient(context.instance) @@ -376,16 +356,12 @@ class SubmissionEscalatedNotification(NotificationCallback): related_object=context.instance, extra_data=extra_data, ) - logger.info( - f"Created escalation notification for {recipient}" - ) + logger.info(f"Created escalation notification for {recipient}") return True except Exception as e: - logger.exception( - f"Failed to create escalation notification for {context}: {e}" - ) + logger.exception(f"Failed to create escalation notification for {context}: {e}") return False @@ -415,16 +391,14 @@ class StatusChangeNotification(NotificationCallback): notification_type="status_change", **kwargs, ) - self.significant_states = significant_states or [ - 'CLOSED_PERM', 'DEMOLISHED', 'RELOCATED' - ] + self.significant_states = significant_states or ["CLOSED_PERM", "DEMOLISHED", "RELOCATED"] self.notify_admins = notify_admins def should_execute(self, context: TransitionContext) -> bool: """Only execute for significant state changes.""" # Check if notifications are disabled - callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {}) - if not callback_settings.get('notifications_enabled', True): + callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {}) + if not callback_settings.get("notifications_enabled", True): return False # Only notify for significant status changes @@ -441,16 +415,13 @@ class StatusChangeNotification(NotificationCallback): try: extra_data = self._build_extra_data(context) - extra_data['entity_type'] = context.model_name - extra_data['entity_id'] = context.instance.pk + extra_data["entity_type"] = context.model_name + extra_data["entity_id"] = context.instance.pk # Build title and message - entity_name = getattr(context.instance, 'name', str(context.instance)) + entity_name = getattr(context.instance, "name", str(context.instance)) title = f"{context.model_name} status changed to {context.target_state}" - message = ( - f"{entity_name} has changed status from {context.source_state} " - f"to {context.target_state}." - ) + message = f"{entity_name} has changed status from {context.source_state} " f"to {context.target_state}." # Notify admin users admins = self._get_admin_users() @@ -471,15 +442,14 @@ class StatusChangeNotification(NotificationCallback): return True except Exception as e: - logger.exception( - f"Failed to create status change notification for {context}: {e}" - ) + logger.exception(f"Failed to create status change notification for {context}: {e}") return False def _get_admin_users(self): """Get admin users to notify.""" try: from django.contrib.auth import get_user_model + user_model = get_user_model() return user_model.objects.filter(is_staff=True, is_active=True) except Exception as e: @@ -499,13 +469,13 @@ class ModerationNotificationCallback(NotificationCallback): # Mapping of (model_name, target_state) to notification type NOTIFICATION_MAPPING = { - ('ModerationReport', 'UNDER_REVIEW'): 'report_under_review', - ('ModerationReport', 'RESOLVED'): 'report_resolved', - ('ModerationQueue', 'IN_PROGRESS'): 'queue_in_progress', - ('ModerationQueue', 'COMPLETED'): 'queue_completed', - ('BulkOperation', 'RUNNING'): 'bulk_operation_started', - ('BulkOperation', 'COMPLETED'): 'bulk_operation_completed', - ('BulkOperation', 'FAILED'): 'bulk_operation_failed', + ("ModerationReport", "UNDER_REVIEW"): "report_under_review", + ("ModerationReport", "RESOLVED"): "report_resolved", + ("ModerationQueue", "IN_PROGRESS"): "queue_in_progress", + ("ModerationQueue", "COMPLETED"): "queue_completed", + ("BulkOperation", "RUNNING"): "bulk_operation_started", + ("BulkOperation", "COMPLETED"): "bulk_operation_completed", + ("BulkOperation", "FAILED"): "bulk_operation_failed", } def __init__(self, **kwargs): @@ -522,7 +492,7 @@ class ModerationNotificationCallback(NotificationCallback): def _get_recipient(self, instance: models.Model) -> Any | None: """Get the appropriate recipient based on model type.""" # Try common recipient fields - for field in ['reporter', 'assigned_to', 'created_by', 'submitted_by']: + for field in ["reporter", "assigned_to", "created_by", "submitted_by"]: recipient = getattr(instance, field, None) if recipient: return recipient @@ -531,31 +501,28 @@ class ModerationNotificationCallback(NotificationCallback): def _get_notification_title(self, context: TransitionContext, notification_type: str) -> str: """Get the notification title based on notification type.""" titles = { - 'report_under_review': 'Your report is under review', - 'report_resolved': 'Your report has been resolved', - 'queue_in_progress': 'Moderation queue item in progress', - 'queue_completed': 'Moderation queue item completed', - 'bulk_operation_started': 'Bulk operation started', - 'bulk_operation_completed': 'Bulk operation completed', - 'bulk_operation_failed': 'Bulk operation failed', + "report_under_review": "Your report is under review", + "report_resolved": "Your report has been resolved", + "queue_in_progress": "Moderation queue item in progress", + "queue_completed": "Moderation queue item completed", + "bulk_operation_started": "Bulk operation started", + "bulk_operation_completed": "Bulk operation completed", + "bulk_operation_failed": "Bulk operation failed", } return titles.get(notification_type, f"{context.model_name} status updated") def _get_notification_message(self, context: TransitionContext, notification_type: str) -> str: """Get the notification message based on notification type.""" messages = { - 'report_under_review': 'Your moderation report is now being reviewed by our team.', - 'report_resolved': 'Your moderation report has been reviewed and resolved.', - 'queue_in_progress': 'A moderation queue item is now being processed.', - 'queue_completed': 'A moderation queue item has been completed.', - 'bulk_operation_started': 'Your bulk operation has started processing.', - 'bulk_operation_completed': 'Your bulk operation has completed successfully.', - 'bulk_operation_failed': 'Your bulk operation encountered an error and could not complete.', + "report_under_review": "Your moderation report is now being reviewed by our team.", + "report_resolved": "Your moderation report has been reviewed and resolved.", + "queue_in_progress": "A moderation queue item is now being processed.", + "queue_completed": "A moderation queue item has been completed.", + "bulk_operation_started": "Your bulk operation has started processing.", + "bulk_operation_completed": "Your bulk operation has completed successfully.", + "bulk_operation_failed": "Your bulk operation encountered an error and could not complete.", } - return messages.get( - notification_type, - f"The {context.model_name} has been updated to {context.target_state}." - ) + return messages.get(notification_type, f"The {context.model_name} has been updated to {context.target_state}.") def execute(self, context: TransitionContext) -> bool: """Execute the moderation notification.""" @@ -565,10 +532,7 @@ class ModerationNotificationCallback(NotificationCallback): notification_type = self._get_notification_type(context) if not notification_type: - logger.debug( - f"No notification type defined for {context.model_name} " - f"→ {context.target_state}" - ) + logger.debug(f"No notification type defined for {context.model_name} " f"→ {context.target_state}") return True # Not an error, just no notification needed recipient = self._get_recipient(context.instance) @@ -587,13 +551,9 @@ class ModerationNotificationCallback(NotificationCallback): extra_data=extra_data, ) - logger.info( - f"Created {notification_type} notification for {recipient}" - ) + logger.info(f"Created {notification_type} notification for {recipient}") return True except Exception as e: - logger.exception( - f"Failed to create moderation notification for {context}: {e}" - ) + logger.exception(f"Failed to create moderation notification for {context}: {e}") return False diff --git a/backend/apps/core/state_machine/callbacks/related_updates.py b/backend/apps/core/state_machine/callbacks/related_updates.py index 993a43f2..b07500c5 100644 --- a/backend/apps/core/state_machine/callbacks/related_updates.py +++ b/backend/apps/core/state_machine/callbacks/related_updates.py @@ -45,8 +45,8 @@ class RelatedModelUpdateCallback(PostTransitionCallback): def should_execute(self, context: TransitionContext) -> bool: """Check if related updates are enabled.""" - callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {}) - if not callback_settings.get('related_updates_enabled', True): + callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {}) + if not callback_settings.get("related_updates_enabled", True): logger.debug("Related model updates disabled via settings") return False return True @@ -77,9 +77,7 @@ class RelatedModelUpdateCallback(PostTransitionCallback): return self.perform_update(context) except Exception as e: - logger.exception( - f"Failed to update related models for {context}: {e}" - ) + logger.exception(f"Failed to update related models for {context}: {e}") return False @@ -94,10 +92,10 @@ class ParkCountUpdateCallback(RelatedModelUpdateCallback): name: str = "ParkCountUpdateCallback" # Status values that count as "active" rides - ACTIVE_STATUSES = {'OPERATING', 'SEASONAL', 'UNDER_CONSTRUCTION'} + ACTIVE_STATUSES = {"OPERATING", "SEASONAL", "UNDER_CONSTRUCTION"} # Status values that indicate a ride is no longer countable - INACTIVE_STATUSES = {'CLOSED_PERM', 'DEMOLISHED', 'RELOCATED', 'REMOVED'} + INACTIVE_STATUSES = {"CLOSED_PERM", "DEMOLISHED", "RELOCATED", "REMOVED"} def should_execute(self, context: TransitionContext) -> bool: """Only execute when status affects ride counts.""" @@ -115,14 +113,14 @@ class ParkCountUpdateCallback(RelatedModelUpdateCallback): return source_affects or target_affects # Category value for roller coasters (from rides domain choices) - COASTER_CATEGORY = 'RC' + COASTER_CATEGORY = "RC" def perform_update(self, context: TransitionContext) -> bool: """Update park ride counts.""" instance = context.instance # Get the parent park - park = getattr(instance, 'park', None) + park = getattr(instance, "park", None) if not park: logger.debug(f"No park found for ride {instance.pk}") return True @@ -133,22 +131,17 @@ class ParkCountUpdateCallback(RelatedModelUpdateCallback): from apps.rides.models.rides import Ride # Get the park ID (handle both object and ID) - park_id = park.pk if hasattr(park, 'pk') else park + park_id = park.pk if hasattr(park, "pk") else park # Calculate new counts efficiently ride_queryset = Ride.objects.filter(park_id=park_id) # Count active rides active_statuses = list(self.ACTIVE_STATUSES) - ride_count = ride_queryset.filter( - status__in=active_statuses - ).count() + ride_count = ride_queryset.filter(status__in=active_statuses).count() # Count active coasters (category='RC' for Roller Coaster) - coaster_count = ride_queryset.filter( - status__in=active_statuses, - category=self.COASTER_CATEGORY - ).count() + coaster_count = ride_queryset.filter(status__in=active_statuses, category=self.COASTER_CATEGORY).count() # Update park counts Park.objects.filter(id=park_id).update( @@ -156,16 +149,11 @@ class ParkCountUpdateCallback(RelatedModelUpdateCallback): coaster_count=coaster_count, ) - logger.info( - f"Updated park {park_id} counts: " - f"ride_count={ride_count}, coaster_count={coaster_count}" - ) + logger.info(f"Updated park {park_id} counts: " f"ride_count={ride_count}, coaster_count={coaster_count}") return True except Exception as e: - logger.exception( - f"Failed to update park counts for {instance.pk}: {e}" - ) + logger.exception(f"Failed to update park counts for {instance.pk}: {e}") return False @@ -184,20 +172,16 @@ class SearchTextUpdateCallback(RelatedModelUpdateCallback): instance = context.instance # Check if instance has search_text field - if not hasattr(instance, 'search_text'): - logger.debug( - f"{context.model_name} has no search_text field" - ) + if not hasattr(instance, "search_text"): + logger.debug(f"{context.model_name} has no search_text field") return True try: # Call the model's update_search_text method if available - if hasattr(instance, 'update_search_text'): + if hasattr(instance, "update_search_text"): instance.update_search_text() - instance.save(update_fields=['search_text']) - logger.info( - f"Updated search_text for {context.model_name} {instance.pk}" - ) + instance.save(update_fields=["search_text"]) + logger.info(f"Updated search_text for {context.model_name} {instance.pk}") else: # Build search text manually self._build_search_text(instance, context) @@ -205,9 +189,7 @@ class SearchTextUpdateCallback(RelatedModelUpdateCallback): return True except Exception as e: - logger.exception( - f"Failed to update search_text for {instance.pk}: {e}" - ) + logger.exception(f"Failed to update search_text for {instance.pk}: {e}") return False def _build_search_text( @@ -219,7 +201,7 @@ class SearchTextUpdateCallback(RelatedModelUpdateCallback): parts = [] # Common searchable fields - for field in ['name', 'title', 'description', 'location']: + for field in ["name", "title", "description", "location"]: value = getattr(instance, field, None) if value: parts.append(str(value)) @@ -228,15 +210,15 @@ class SearchTextUpdateCallback(RelatedModelUpdateCallback): status_field = getattr(instance, context.field_name, None) if status_field: # Try to get the display label - display_method = f'get_{context.field_name}_display' + display_method = f"get_{context.field_name}_display" if hasattr(instance, display_method): parts.append(getattr(instance, display_method)()) else: parts.append(str(status_field)) # Update search_text - instance.search_text = ' '.join(parts) - instance.save(update_fields=['search_text']) + instance.search_text = " ".join(parts) + instance.save(update_fields=["search_text"]) class ComputedFieldUpdateCallback(RelatedModelUpdateCallback): @@ -280,7 +262,7 @@ class ComputedFieldUpdateCallback(RelatedModelUpdateCallback): # Update specific fields updated_fields = [] for field_name in self.computed_fields: - update_method_name = f'compute_{field_name}' + update_method_name = f"compute_{field_name}" if hasattr(instance, update_method_name): method = getattr(instance, update_method_name) if callable(method): @@ -291,17 +273,12 @@ class ComputedFieldUpdateCallback(RelatedModelUpdateCallback): # Save updated fields if updated_fields: instance.save(update_fields=updated_fields) - logger.info( - f"Updated computed fields {updated_fields} for " - f"{context.model_name} {instance.pk}" - ) + logger.info(f"Updated computed fields {updated_fields} for " f"{context.model_name} {instance.pk}") return True except Exception as e: - logger.exception( - f"Failed to update computed fields for {instance.pk}: {e}" - ) + logger.exception(f"Failed to update computed fields for {instance.pk}: {e}") return False @@ -320,7 +297,7 @@ class RideStatusUpdateCallback(RelatedModelUpdateCallback): return False # Only execute for Ride model - return context.model_name == 'Ride' + return context.model_name == "Ride" def perform_update(self, context: TransitionContext) -> bool: """Perform ride-specific status updates.""" @@ -329,22 +306,18 @@ class RideStatusUpdateCallback(RelatedModelUpdateCallback): try: # Handle CLOSING → post_closing_status transition - if context.source_state == 'CLOSING' and target != 'CLOSING': - post_closing_status = getattr(instance, 'post_closing_status', None) + if context.source_state == "CLOSING" and target != "CLOSING": + post_closing_status = getattr(instance, "post_closing_status", None) if post_closing_status and target == post_closing_status: # Clear post_closing_status after applying it instance.post_closing_status = None - instance.save(update_fields=['post_closing_status']) - logger.info( - f"Cleared post_closing_status for ride {instance.pk}" - ) + instance.save(update_fields=["post_closing_status"]) + logger.info(f"Cleared post_closing_status for ride {instance.pk}") return True except Exception as e: - logger.exception( - f"Failed to update ride status fields for {instance.pk}: {e}" - ) + logger.exception(f"Failed to update ride status fields for {instance.pk}: {e}") return False @@ -362,9 +335,7 @@ class ModerationQueueUpdateCallback(RelatedModelUpdateCallback): # Only for submission and report models model_name = context.model_name - return model_name in ( - 'EditSubmission', 'PhotoSubmission', 'ModerationReport' - ) + return model_name in ("EditSubmission", "PhotoSubmission", "ModerationReport") def perform_update(self, context: TransitionContext) -> bool: """Update moderation queue entries.""" @@ -373,15 +344,13 @@ class ModerationQueueUpdateCallback(RelatedModelUpdateCallback): try: # Mark related queue items as completed when submission is resolved - if target in ('APPROVED', 'REJECTED', 'RESOLVED'): + if target in ("APPROVED", "REJECTED", "RESOLVED"): self._update_queue_items(instance, context) return True except Exception as e: - logger.exception( - f"Failed to update moderation queue for {instance.pk}: {e}" - ) + logger.exception(f"Failed to update moderation queue for {instance.pk}: {e}") return False def _update_queue_items( @@ -401,20 +370,18 @@ class ModerationQueueUpdateCallback(RelatedModelUpdateCallback): queue_items = ModerationQueue.objects.filter( content_type_id=content_type_id, object_id=instance.pk, - status='IN_PROGRESS', + status="IN_PROGRESS", ) for item in queue_items: - if hasattr(item, 'complete'): + if hasattr(item, "complete"): item.complete(user=context.user) else: - item.status = 'COMPLETED' - item.save(update_fields=['status']) + item.status = "COMPLETED" + item.save(update_fields=["status"]) if queue_items.exists(): - logger.info( - f"Marked {queue_items.count()} queue items as completed" - ) + logger.info(f"Marked {queue_items.count()} queue items as completed") except ImportError: logger.debug("ModerationQueue model not available") @@ -425,6 +392,7 @@ class ModerationQueueUpdateCallback(RelatedModelUpdateCallback): """Get content type ID for the instance.""" try: from django.contrib.contenttypes.models import ContentType + content_type = ContentType.objects.get_for_model(type(instance)) return content_type.pk except Exception: diff --git a/backend/apps/core/state_machine/config.py b/backend/apps/core/state_machine/config.py index 2a0ff62e..ee1dba96 100644 --- a/backend/apps/core/state_machine/config.py +++ b/backend/apps/core/state_machine/config.py @@ -33,7 +33,7 @@ class ModelCallbackConfig: """Configuration for all callbacks on a model.""" model_name: str - field_name: str = 'status' + field_name: str = "status" transitions: dict[tuple, TransitionCallbackConfig] = field(default_factory=dict) default_config: TransitionCallbackConfig = field(default_factory=TransitionCallbackConfig) @@ -53,12 +53,12 @@ class CallbackConfig: # Default settings DEFAULT_SETTINGS = { - 'enabled': True, - 'notifications_enabled': True, - 'cache_invalidation_enabled': True, - 'related_updates_enabled': True, - 'debug_mode': False, - 'log_callbacks': False, + "enabled": True, + "notifications_enabled": True, + "cache_invalidation_enabled": True, + "related_updates_enabled": True, + "debug_mode": False, + "log_callbacks": False, } # Model-specific configurations @@ -70,7 +70,7 @@ class CallbackConfig: def _load_settings(self) -> dict[str, Any]: """Load settings from Django configuration.""" - django_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {}) + django_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {}) merged = dict(self.DEFAULT_SETTINGS) merged.update(django_settings) return merged @@ -78,123 +78,123 @@ class CallbackConfig: def _build_model_configs(self) -> dict[str, ModelCallbackConfig]: """Build model-specific configurations.""" return { - 'EditSubmission': ModelCallbackConfig( - model_name='EditSubmission', - field_name='status', + "EditSubmission": ModelCallbackConfig( + model_name="EditSubmission", + field_name="status", transitions={ - ('PENDING', 'APPROVED'): TransitionCallbackConfig( - notification_template='submission_approved', - cache_patterns=['*submission*', '*moderation*'], + ("PENDING", "APPROVED"): TransitionCallbackConfig( + notification_template="submission_approved", + cache_patterns=["*submission*", "*moderation*"], ), - ('PENDING', 'REJECTED'): TransitionCallbackConfig( - notification_template='submission_rejected', - cache_patterns=['*submission*', '*moderation*'], + ("PENDING", "REJECTED"): TransitionCallbackConfig( + notification_template="submission_rejected", + cache_patterns=["*submission*", "*moderation*"], ), - ('PENDING', 'ESCALATED'): TransitionCallbackConfig( - notification_template='submission_escalated', - cache_patterns=['*submission*', '*moderation*'], + ("PENDING", "ESCALATED"): TransitionCallbackConfig( + notification_template="submission_escalated", + cache_patterns=["*submission*", "*moderation*"], ), }, ), - 'PhotoSubmission': ModelCallbackConfig( - model_name='PhotoSubmission', - field_name='status', + "PhotoSubmission": ModelCallbackConfig( + model_name="PhotoSubmission", + field_name="status", transitions={ - ('PENDING', 'APPROVED'): TransitionCallbackConfig( - notification_template='photo_approved', - cache_patterns=['*photo*', '*moderation*'], + ("PENDING", "APPROVED"): TransitionCallbackConfig( + notification_template="photo_approved", + cache_patterns=["*photo*", "*moderation*"], ), - ('PENDING', 'REJECTED'): TransitionCallbackConfig( - notification_template='photo_rejected', - cache_patterns=['*photo*', '*moderation*'], + ("PENDING", "REJECTED"): TransitionCallbackConfig( + notification_template="photo_rejected", + cache_patterns=["*photo*", "*moderation*"], ), }, ), - 'ModerationReport': ModelCallbackConfig( - model_name='ModerationReport', - field_name='status', + "ModerationReport": ModelCallbackConfig( + model_name="ModerationReport", + field_name="status", transitions={ - ('PENDING', 'UNDER_REVIEW'): TransitionCallbackConfig( - notification_template='report_under_review', - cache_patterns=['*report*', '*moderation*'], + ("PENDING", "UNDER_REVIEW"): TransitionCallbackConfig( + notification_template="report_under_review", + cache_patterns=["*report*", "*moderation*"], ), - ('UNDER_REVIEW', 'RESOLVED'): TransitionCallbackConfig( - notification_template='report_resolved', - cache_patterns=['*report*', '*moderation*'], + ("UNDER_REVIEW", "RESOLVED"): TransitionCallbackConfig( + notification_template="report_resolved", + cache_patterns=["*report*", "*moderation*"], ), }, ), - 'ModerationQueue': ModelCallbackConfig( - model_name='ModerationQueue', - field_name='status', + "ModerationQueue": ModelCallbackConfig( + model_name="ModerationQueue", + field_name="status", transitions={ - ('PENDING', 'IN_PROGRESS'): TransitionCallbackConfig( - notification_template='queue_in_progress', - cache_patterns=['*queue*', '*moderation*'], + ("PENDING", "IN_PROGRESS"): TransitionCallbackConfig( + notification_template="queue_in_progress", + cache_patterns=["*queue*", "*moderation*"], ), - ('IN_PROGRESS', 'COMPLETED'): TransitionCallbackConfig( - notification_template='queue_completed', - cache_patterns=['*queue*', '*moderation*'], + ("IN_PROGRESS", "COMPLETED"): TransitionCallbackConfig( + notification_template="queue_completed", + cache_patterns=["*queue*", "*moderation*"], ), }, ), - 'BulkOperation': ModelCallbackConfig( - model_name='BulkOperation', - field_name='status', + "BulkOperation": ModelCallbackConfig( + model_name="BulkOperation", + field_name="status", transitions={ - ('PENDING', 'RUNNING'): TransitionCallbackConfig( - notification_template='bulk_operation_started', - cache_patterns=['*operation*', '*moderation*'], + ("PENDING", "RUNNING"): TransitionCallbackConfig( + notification_template="bulk_operation_started", + cache_patterns=["*operation*", "*moderation*"], ), - ('RUNNING', 'COMPLETED'): TransitionCallbackConfig( - notification_template='bulk_operation_completed', - cache_patterns=['*operation*', '*moderation*'], + ("RUNNING", "COMPLETED"): TransitionCallbackConfig( + notification_template="bulk_operation_completed", + cache_patterns=["*operation*", "*moderation*"], ), - ('RUNNING', 'FAILED'): TransitionCallbackConfig( - notification_template='bulk_operation_failed', - cache_patterns=['*operation*', '*moderation*'], + ("RUNNING", "FAILED"): TransitionCallbackConfig( + notification_template="bulk_operation_failed", + cache_patterns=["*operation*", "*moderation*"], ), }, ), - 'Park': ModelCallbackConfig( - model_name='Park', - field_name='status', + "Park": ModelCallbackConfig( + model_name="Park", + field_name="status", default_config=TransitionCallbackConfig( - cache_patterns=['*park*', 'api:*', 'geo:*'], + cache_patterns=["*park*", "api:*", "geo:*"], ), transitions={ - ('*', 'CLOSED_PERM'): TransitionCallbackConfig( + ("*", "CLOSED_PERM"): TransitionCallbackConfig( notifications_enabled=True, - notification_template='park_closed_permanently', - cache_patterns=['*park*', 'api:*', 'geo:*'], + notification_template="park_closed_permanently", + cache_patterns=["*park*", "api:*", "geo:*"], ), - ('*', 'OPERATING'): TransitionCallbackConfig( + ("*", "OPERATING"): TransitionCallbackConfig( notifications_enabled=False, - cache_patterns=['*park*', 'api:*', 'geo:*'], + cache_patterns=["*park*", "api:*", "geo:*"], ), }, ), - 'Ride': ModelCallbackConfig( - model_name='Ride', - field_name='status', + "Ride": ModelCallbackConfig( + model_name="Ride", + field_name="status", default_config=TransitionCallbackConfig( - cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'], + cache_patterns=["*ride*", "*park*", "api:*", "geo:*"], ), transitions={ - ('*', 'OPERATING'): TransitionCallbackConfig( - cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'], + ("*", "OPERATING"): TransitionCallbackConfig( + cache_patterns=["*ride*", "*park*", "api:*", "geo:*"], related_updates_enabled=True, ), - ('*', 'CLOSED_PERM'): TransitionCallbackConfig( - cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'], + ("*", "CLOSED_PERM"): TransitionCallbackConfig( + cache_patterns=["*ride*", "*park*", "api:*", "geo:*"], related_updates_enabled=True, ), - ('*', 'DEMOLISHED'): TransitionCallbackConfig( - cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'], + ("*", "DEMOLISHED"): TransitionCallbackConfig( + cache_patterns=["*ride*", "*park*", "api:*", "geo:*"], related_updates_enabled=True, ), - ('*', 'RELOCATED'): TransitionCallbackConfig( - cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'], + ("*", "RELOCATED"): TransitionCallbackConfig( + cache_patterns=["*ride*", "*park*", "api:*", "geo:*"], related_updates_enabled=True, ), }, @@ -204,32 +204,32 @@ class CallbackConfig: @property def enabled(self) -> bool: """Check if callbacks are globally enabled.""" - return self._settings.get('enabled', True) + return self._settings.get("enabled", True) @property def notifications_enabled(self) -> bool: """Check if notification callbacks are enabled.""" - return self._settings.get('notifications_enabled', True) + return self._settings.get("notifications_enabled", True) @property def cache_invalidation_enabled(self) -> bool: """Check if cache invalidation is enabled.""" - return self._settings.get('cache_invalidation_enabled', True) + return self._settings.get("cache_invalidation_enabled", True) @property def related_updates_enabled(self) -> bool: """Check if related model updates are enabled.""" - return self._settings.get('related_updates_enabled', True) + return self._settings.get("related_updates_enabled", True) @property def debug_mode(self) -> bool: """Check if debug mode is enabled.""" - return self._settings.get('debug_mode', False) + return self._settings.get("debug_mode", False) @property def log_callbacks(self) -> bool: """Check if callback logging is enabled.""" - return self._settings.get('log_callbacks', False) + return self._settings.get("log_callbacks", False) def get_config( self, @@ -258,12 +258,12 @@ class CallbackConfig: return config # Try wildcard source - config = model_config.transitions.get(('*', target)) + config = model_config.transitions.get(("*", target)) if config: return config # Try wildcard target - config = model_config.transitions.get((source, '*')) + config = model_config.transitions.get((source, "*")) if config: return config @@ -362,9 +362,7 @@ class CallbackConfig: **kwargs: Configuration values to update. """ if model_name not in self._model_configs: - self._model_configs[model_name] = ModelCallbackConfig( - model_name=model_name - ) + self._model_configs[model_name] = ModelCallbackConfig(model_name=model_name) model_config = self._model_configs[model_name] transition_key = (source, target) @@ -394,9 +392,9 @@ def get_callback_config() -> CallbackConfig: __all__ = [ - 'TransitionCallbackConfig', - 'ModelCallbackConfig', - 'CallbackConfig', - 'callback_config', - 'get_callback_config', + "TransitionCallbackConfig", + "ModelCallbackConfig", + "CallbackConfig", + "callback_config", + "get_callback_config", ] diff --git a/backend/apps/core/state_machine/decorators.py b/backend/apps/core/state_machine/decorators.py index c0afdd51..72838b6b 100644 --- a/backend/apps/core/state_machine/decorators.py +++ b/backend/apps/core/state_machine/decorators.py @@ -1,4 +1,5 @@ """Transition decorator generation for django-fsm integration.""" + import logging from collections.abc import Callable from functools import wraps @@ -51,51 +52,42 @@ def with_callbacks( @wraps(func) def wrapper(instance, *args, **kwargs): # Extract user from kwargs - user = kwargs.get('user') + user = kwargs.get("user") # Get source state before transition source_state = getattr(instance, field_name, None) # Get target state from the transition decorator # The @transition decorator sets _django_fsm_target - target_state = getattr(func, '_django_fsm', {}).get('target', None) + target_state = getattr(func, "_django_fsm", {}).get("target", None) # If we can't determine the target from decorator metadata, # we'll capture it after the transition if target_state is None: # This happens when decorators are applied in wrong order - logger.debug( - f"Could not determine target state from decorator for {func.__name__}" - ) + logger.debug(f"Could not determine target state from decorator for {func.__name__}") # Create transition context context = TransitionContext( instance=instance, field_name=field_name, - source_state=str(source_state) if source_state else '', - target_state=str(target_state) if target_state else '', + source_state=str(source_state) if source_state else "", + target_state=str(target_state) if target_state else "", user=user, extra_data=dict(kwargs), ) # Execute pre-transition callbacks - pre_success, pre_failures = callback_registry.execute_callbacks( - context, CallbackStage.PRE - ) + pre_success, pre_failures = callback_registry.execute_callbacks(context, CallbackStage.PRE) # If pre-callbacks fail with continue_on_error=False, abort if not pre_success and pre_failures: for callback, exc in pre_failures: if not callback.continue_on_error: - logger.error( - f"Pre-transition callback {callback.name} failed, " - f"aborting transition" - ) + logger.error(f"Pre-transition callback {callback.name} failed, " f"aborting transition") if exc: raise exc - raise RuntimeError( - f"Pre-transition callback {callback.name} failed" - ) + raise RuntimeError(f"Pre-transition callback {callback.name} failed") # Emit pre-transition signal if emit_signals: @@ -114,19 +106,14 @@ def with_callbacks( # Update context with actual target state after transition actual_target = getattr(instance, field_name, None) - context.target_state = str(actual_target) if actual_target else '' + context.target_state = str(actual_target) if actual_target else "" # Execute post-transition callbacks - post_success, post_failures = callback_registry.execute_callbacks( - context, CallbackStage.POST - ) + post_success, post_failures = callback_registry.execute_callbacks(context, CallbackStage.POST) if not post_success: - for callback, exc in post_failures: - logger.warning( - f"Post-transition callback {callback.name} failed " - f"for {context}" - ) + for callback, _exc in post_failures: + logger.warning(f"Post-transition callback {callback.name} failed " f"for {context}") # Emit post-transition signal if emit_signals: @@ -236,9 +223,7 @@ def create_transition_method( on_success(instance, user=user, **kwargs) transition_method.__name__ = method_name - transition_method.__doc__ = ( - f"Transition from {source} to {target} on field {field_name}" - ) + transition_method.__doc__ = f"Transition from {source} to {target} on field {field_name}" # Apply callback wrapper if enabled if enable_callbacks: @@ -249,10 +234,10 @@ def create_transition_method( # Store metadata for callback registration transition_method._fsm_metadata = { - 'source': source, - 'target': target, - 'field_name': field_name, - 'callbacks': callbacks or [], + "source": source, + "target": target, + "field_name": field_name, + "callbacks": callbacks or [], } return transition_method @@ -271,21 +256,21 @@ def register_method_callbacks( model_class: The model class containing the method. method: The transition method with _fsm_metadata. """ - metadata = getattr(method, '_fsm_metadata', None) - if not metadata or not metadata.get('callbacks'): + metadata = getattr(method, "_fsm_metadata", None) + if not metadata or not metadata.get("callbacks"): return from .callback_base import CallbackStage, PreTransitionCallback - for callback in metadata['callbacks']: + for callback in metadata["callbacks"]: # Determine stage from callback type stage = CallbackStage.PRE if isinstance(callback, PreTransitionCallback) else CallbackStage.POST callback_registry.register( model_class=model_class, - field_name=metadata['field_name'], - source=metadata['source'], - target=metadata['target'], + field_name=metadata["field_name"], + source=metadata["source"], + target=metadata["target"], callback=callback, stage=stage, ) @@ -490,9 +475,7 @@ class TransitionMethodFactory: if docstring: generic_transition.__doc__ = docstring else: - generic_transition.__doc__ = ( - f"Transition from {source} to {target}" - ) + generic_transition.__doc__ = f"Transition from {source} to {target}" # Apply callback wrapper if enabled if enable_callbacks: diff --git a/backend/apps/core/state_machine/exceptions.py b/backend/apps/core/state_machine/exceptions.py index a36b43e6..7c3e89db 100644 --- a/backend/apps/core/state_machine/exceptions.py +++ b/backend/apps/core/state_machine/exceptions.py @@ -12,6 +12,7 @@ Example usage: 'code': e.error_code }, status=403) """ + from typing import Any from django_fsm import TransitionNotAllowed @@ -214,29 +215,18 @@ ERROR_MESSAGES = { "You need {required_role} permissions to {action}. " "Please contact an administrator if you believe this is an error." ), - "PERMISSION_DENIED_OWNERSHIP": ( - "You must be the owner of this item to perform this action." - ), + "PERMISSION_DENIED_OWNERSHIP": ("You must be the owner of this item to perform this action."), "PERMISSION_DENIED_ASSIGNMENT": ( - "This item must be assigned to you before you can {action}. " - "Please assign it to yourself first." - ), - "NO_ASSIGNMENT": ( - "This item must be assigned before this action can be performed." + "This item must be assigned to you before you can {action}. " "Please assign it to yourself first." ), + "NO_ASSIGNMENT": ("This item must be assigned before this action can be performed."), "INVALID_STATE_TRANSITION": ( "This action cannot be performed from the current state. " "The item is currently '{current_state}' and cannot be modified." ), - "TRANSITION_NOT_AVAILABLE": ( - "This {item_type} has already been {state} and cannot be modified." - ), - "MISSING_REQUIRED_FIELD": ( - "{field_name} is required to complete this action." - ), - "EMPTY_REQUIRED_FIELD": ( - "{field_name} must not be empty." - ), + "TRANSITION_NOT_AVAILABLE": ("This {item_type} has already been {state} and cannot be modified."), + "MISSING_REQUIRED_FIELD": ("{field_name} is required to complete this action."), + "EMPTY_REQUIRED_FIELD": ("{field_name} must not be empty."), "ESCALATED_REQUIRES_ADMIN": ( "This submission has been escalated and requires admin review. " "Only administrators can approve or reject escalated items." diff --git a/backend/apps/core/state_machine/fields.py b/backend/apps/core/state_machine/fields.py index ddbb53a1..e3bc10c8 100644 --- a/backend/apps/core/state_machine/fields.py +++ b/backend/apps/core/state_machine/fields.py @@ -47,6 +47,7 @@ See Also: - apps.core.choices.registry: The central choice registry - apps.core.state_machine.mixins.StateMachineMixin: Convenience helpers """ + from typing import Any from django.core.exceptions import ValidationError @@ -138,14 +139,10 @@ class RichFSMField(DjangoFSMField): choice = registry.get_choice(self.choice_group, value, self.domain) if choice is None: - raise ValidationError( - f"'{value}' is not a valid state for {self.choice_group}" - ) + raise ValidationError(f"'{value}' is not a valid state for {self.choice_group}") if choice.deprecated and not self.allow_deprecated: - raise ValidationError( - f"'{value}' is deprecated and cannot be used for new entries" - ) + raise ValidationError(f"'{value}' is deprecated and cannot be used for new entries") def get_rich_choice(self, value: str) -> RichChoice | None: """Return the RichChoice object for a given state value.""" @@ -155,9 +152,7 @@ class RichFSMField(DjangoFSMField): """Return the label for the given state value.""" return registry.get_choice_display(self.choice_group, value, self.domain) - def contribute_to_class( - self, cls: Any, name: str, private_only: bool = False, **kwargs: Any - ) -> None: + def contribute_to_class(self, cls: Any, name: str, private_only: bool = False, **kwargs: Any) -> None: """Attach helpers to the model for convenience.""" super().contribute_to_class(cls, name, private_only=private_only, **kwargs) diff --git a/backend/apps/core/state_machine/guards.py b/backend/apps/core/state_machine/guards.py index 82d0b189..f39a3f7b 100644 --- a/backend/apps/core/state_machine/guards.py +++ b/backend/apps/core/state_machine/guards.py @@ -18,6 +18,7 @@ Example usage: OwnershipGuard() ], operator='OR') """ + from collections.abc import Callable from typing import Any, Optional @@ -534,7 +535,7 @@ class MetadataGuard: self._last_error_code = self.ERROR_CODE_EMPTY_FIELD self._failed_field = field_name return False - if isinstance(value, (list, dict)) and not value: + if isinstance(value, list | dict) and not value: self._last_error_code = self.ERROR_CODE_EMPTY_FIELD self._failed_field = field_name return False @@ -787,8 +788,7 @@ def extract_guards_from_metadata(metadata: dict[str, Any]) -> list[Callable]: # Zero tolerance requires superuser if zero_tolerance: guard = PermissionGuard( - requires_superuser=True, - error_message="Zero tolerance violations require superuser permissions" + requires_superuser=True, error_message="Zero tolerance violations require superuser permissions" ) guards.append(guard) elif requires_moderator or requires_admin or escalation_level: @@ -801,7 +801,7 @@ def extract_guards_from_metadata(metadata: dict[str, Any]) -> list[Callable]: assignment_guard = AssignmentGuard( require_assignment=True, allow_admin_override=True, - error_message="This item must be assigned to you before this action can be performed" + error_message="This item must be assigned to you before this action can be performed", ) guards.append(assignment_guard) @@ -814,7 +814,7 @@ def extract_guards_from_metadata(metadata: dict[str, Any]) -> list[Callable]: perm_guard = PermissionGuard( custom_check=check_permissions, - error_message=f"Missing required permissions: {', '.join(required_permissions)}" + error_message=f"Missing required permissions: {', '.join(required_permissions)}", ) guards.append(perm_guard) @@ -1072,7 +1072,7 @@ def has_role(user: Any, required_roles: list[str]) -> bool: # Only apply if role field is not set if user_role is None: # Check for superuser (Django's is_superuser flag) - if hasattr(user, "is_superuser") and user.is_superuser: + if hasattr(user, "is_superuser") and user.is_superuser: # noqa: SIM102 if "SUPERUSER" in required_roles or "ADMIN" in required_roles: return True @@ -1248,7 +1248,7 @@ def create_guard_from_drf_permission( self._last_error_code = "PERMISSION_DENIED" return False - if hasattr(permission, "has_object_permission"): + if hasattr(permission, "has_object_permission"): # noqa: SIM102 if not permission.has_object_permission(mock_request, None, instance): self._last_error_code = "OBJECT_PERMISSION_DENIED" return False diff --git a/backend/apps/core/state_machine/integration.py b/backend/apps/core/state_machine/integration.py index e9fddf38..fecae859 100644 --- a/backend/apps/core/state_machine/integration.py +++ b/backend/apps/core/state_machine/integration.py @@ -1,4 +1,5 @@ """Model integration utilities for applying state machines to Django models.""" + from collections.abc import Callable from typing import Any @@ -46,18 +47,13 @@ def apply_state_machine( if not result.is_valid: error_messages = [str(e) for e in result.errors] - raise ValueError( - "Cannot apply state machine - validation failed:\n" - + "\n".join(error_messages) - ) + raise ValueError("Cannot apply state machine - validation failed:\n" + "\n".join(error_messages)) # Build transition registry registry_instance.build_registry_from_choices(choice_group, domain) # Generate and attach transition methods - generate_transition_methods_for_model( - model_class, field_name, choice_group, domain - ) + generate_transition_methods_for_model(model_class, field_name, choice_group, domain) def generate_transition_methods_for_model( @@ -140,15 +136,10 @@ def generate_transition_methods_for_model( setattr(model_class, method_name, method) - - - class StateMachineModelMixin: """Mixin providing state machine helper methods for models.""" - def get_available_state_transitions( - self, field_name: str = "status" - ) -> list[TransitionInfo]: + def get_available_state_transitions(self, field_name: str = "status") -> list[TransitionInfo]: """ Get available transitions from current state. @@ -167,9 +158,7 @@ class StateMachineModelMixin: domain = field.domain current_state = getattr(self, field_name) - return registry_instance.get_available_transitions( - choice_group, domain, current_state - ) + return registry_instance.get_available_transitions(choice_group, domain, current_state) def can_transition_to( self, @@ -199,9 +188,7 @@ class StateMachineModelMixin: domain = field.domain # Check if transition exists in registry - transition = registry_instance.get_transition( - choice_group, domain, current_state, target_state - ) + transition = registry_instance.get_transition(choice_group, domain, current_state, target_state) if not transition: return False @@ -216,9 +203,7 @@ class StateMachineModelMixin: # Use django-fsm's can_proceed return can_proceed(method) - def get_transition_method( - self, target_state: str, field_name: str = "status" - ) -> Callable | None: + def get_transition_method(self, target_state: str, field_name: str = "status") -> Callable | None: """ Get the transition method for moving to target state. @@ -238,9 +223,7 @@ class StateMachineModelMixin: choice_group = field.choice_group domain = field.domain - transition = registry_instance.get_transition( - choice_group, domain, current_state, target_state - ) + transition = registry_instance.get_transition(choice_group, domain, current_state, target_state) if not transition: return None @@ -270,9 +253,7 @@ class StateMachineModelMixin: ValueError: If transition is not allowed """ if not self.can_transition_to(target_state, field_name, user): - raise ValueError( - f"Cannot transition to {target_state} from current state" - ) + raise ValueError(f"Cannot transition to {target_state} from current state") method = self.get_transition_method(target_state, field_name) if method is None: @@ -283,9 +264,7 @@ class StateMachineModelMixin: return True -def state_machine_model( - field_name: str, choice_group: str, domain: str = "core" -): +def state_machine_model(field_name: str, choice_group: str, domain: str = "core"): """ Class decorator to automatically apply state machine to models. @@ -306,9 +285,7 @@ def state_machine_model( return decorator -def validate_model_state_machine( - model_class: type[models.Model], field_name: str -) -> bool: +def validate_model_state_machine(model_class: type[models.Model], field_name: str) -> bool: """ Ensure model is properly configured with state machine. @@ -326,13 +303,11 @@ def validate_model_state_machine( try: field = model_class._meta.get_field(field_name) except Exception: - raise ValueError(f"Field {field_name} not found on {model_class}") + raise ValueError(f"Field {field_name} not found on {model_class}") from None # Check if field has choice_group attribute if not hasattr(field, "choice_group"): - raise ValueError( - f"Field {field_name} is not a RichFSMField or RichChoiceField" - ) + raise ValueError(f"Field {field_name} is not a RichFSMField or RichChoiceField") # Validate metadata choice_group = field.choice_group @@ -343,9 +318,7 @@ def validate_model_state_machine( if not result.is_valid: error_messages = [str(e) for e in result.errors] - raise ValueError( - "State machine validation failed:\n" + "\n".join(error_messages) - ) + raise ValueError("State machine validation failed:\n" + "\n".join(error_messages)) return True diff --git a/backend/apps/core/state_machine/mixins.py b/backend/apps/core/state_machine/mixins.py index 7705970d..bde30a9e 100644 --- a/backend/apps/core/state_machine/mixins.py +++ b/backend/apps/core/state_machine/mixins.py @@ -38,6 +38,7 @@ See Also: - apps.core.state_machine.fields.RichFSMField: The FSM field implementation - django_fsm.can_proceed: FSM transition checking utility """ + from collections.abc import Iterable from typing import Any @@ -47,25 +48,75 @@ from django_fsm import can_proceed # Default transition metadata for styling TRANSITION_METADATA = { # Approval transitions - "approve": {"style": "green", "icon": "check", "requires_confirm": True, "confirm_message": "Are you sure you want to approve this?"}, - "transition_to_approved": {"style": "green", "icon": "check", "requires_confirm": True, "confirm_message": "Are you sure you want to approve this?"}, + "approve": { + "style": "green", + "icon": "check", + "requires_confirm": True, + "confirm_message": "Are you sure you want to approve this?", + }, + "transition_to_approved": { + "style": "green", + "icon": "check", + "requires_confirm": True, + "confirm_message": "Are you sure you want to approve this?", + }, # Rejection transitions - "reject": {"style": "red", "icon": "times", "requires_confirm": True, "confirm_message": "Are you sure you want to reject this?"}, - "transition_to_rejected": {"style": "red", "icon": "times", "requires_confirm": True, "confirm_message": "Are you sure you want to reject this?"}, + "reject": { + "style": "red", + "icon": "times", + "requires_confirm": True, + "confirm_message": "Are you sure you want to reject this?", + }, + "transition_to_rejected": { + "style": "red", + "icon": "times", + "requires_confirm": True, + "confirm_message": "Are you sure you want to reject this?", + }, # Escalation transitions - "escalate": {"style": "yellow", "icon": "arrow-up", "requires_confirm": True, "confirm_message": "Are you sure you want to escalate this?"}, - "transition_to_escalated": {"style": "yellow", "icon": "arrow-up", "requires_confirm": True, "confirm_message": "Are you sure you want to escalate this?"}, + "escalate": { + "style": "yellow", + "icon": "arrow-up", + "requires_confirm": True, + "confirm_message": "Are you sure you want to escalate this?", + }, + "transition_to_escalated": { + "style": "yellow", + "icon": "arrow-up", + "requires_confirm": True, + "confirm_message": "Are you sure you want to escalate this?", + }, # Assignment transitions "assign": {"style": "blue", "icon": "user-plus", "requires_confirm": False}, "unassign": {"style": "gray", "icon": "user-minus", "requires_confirm": False}, # Status transitions "start": {"style": "blue", "icon": "play", "requires_confirm": False}, - "complete": {"style": "green", "icon": "check-circle", "requires_confirm": True, "confirm_message": "Are you sure you want to complete this?"}, - "cancel": {"style": "red", "icon": "ban", "requires_confirm": True, "confirm_message": "Are you sure you want to cancel this?"}, + "complete": { + "style": "green", + "icon": "check-circle", + "requires_confirm": True, + "confirm_message": "Are you sure you want to complete this?", + }, + "cancel": { + "style": "red", + "icon": "ban", + "requires_confirm": True, + "confirm_message": "Are you sure you want to cancel this?", + }, "reopen": {"style": "blue", "icon": "redo", "requires_confirm": False}, # Resolution transitions - "resolve": {"style": "green", "icon": "check-double", "requires_confirm": True, "confirm_message": "Are you sure you want to resolve this?"}, - "dismiss": {"style": "gray", "icon": "times-circle", "requires_confirm": True, "confirm_message": "Are you sure you want to dismiss this?"}, + "resolve": { + "style": "green", + "icon": "check-double", + "requires_confirm": True, + "confirm_message": "Are you sure you want to resolve this?", + }, + "dismiss": { + "style": "gray", + "icon": "times-circle", + "requires_confirm": True, + "confirm_message": "Are you sure you want to dismiss this?", + }, # Default "default": {"style": "gray", "icon": "arrow-right", "requires_confirm": False}, } @@ -86,22 +137,22 @@ def _get_transition_metadata(transition_name: str) -> dict[str, Any]: def _format_transition_label(transition_name: str) -> str: """Format a transition method name into a human-readable label.""" label = transition_name - for prefix in ['transition_to_', 'transition_', 'do_']: + for prefix in ["transition_to_", "transition_", "do_"]: if label.startswith(prefix): - label = label[len(prefix):] + label = label[len(prefix) :] break - if label.endswith('ed') and len(label) > 3: - if label.endswith('ied'): - label = label[:-3] + 'y' + if label.endswith("ed") and len(label) > 3: + if label.endswith("ied"): + label = label[:-3] + "y" elif label[-3] == label[-4]: label = label[:-3] else: label = label[:-1] - if not label.endswith('e'): + if not label.endswith("e"): label = label[:-1] - return label.replace('_', ' ').title() + return label.replace("_", " ").title() class StateMachineMixin(models.Model): @@ -187,14 +238,10 @@ class StateMachineMixin(models.Model): """Check if a transition method can proceed for the current instance.""" method = getattr(self, transition_method_name, None) if method is None or not callable(method): - raise AttributeError( - f"Transition method '{transition_method_name}' not found" - ) + raise AttributeError(f"Transition method '{transition_method_name}' not found") return can_proceed(method) - def get_available_transitions( - self, field_name: str | None = None - ) -> Iterable[Any]: + def get_available_transitions(self, field_name: str | None = None) -> Iterable[Any]: """Return available transitions when helpers are present.""" name = field_name or self.state_field_name helper_name = f"get_available_{name}_transitions" @@ -246,14 +293,16 @@ class StateMachineMixin(models.Model): try: if can_proceed(method, user): metadata = _get_transition_metadata(transition_name) - transitions.append({ - 'name': transition_name, - 'label': _format_transition_label(transition_name), - 'icon': metadata.get('icon', 'arrow-right'), - 'style': metadata.get('style', 'gray'), - 'requires_confirm': metadata.get('requires_confirm', False), - 'confirm_message': metadata.get('confirm_message', 'Are you sure?'), - }) + transitions.append( + { + "name": transition_name, + "label": _format_transition_label(transition_name), + "icon": metadata.get("icon", "arrow-right"), + "style": metadata.get("style", "gray"), + "requires_confirm": metadata.get("requires_confirm", False), + "confirm_message": metadata.get("confirm_message", "Are you sure?"), + } + ) except Exception: # Skip transitions that raise errors pass diff --git a/backend/apps/core/state_machine/monitoring.py b/backend/apps/core/state_machine/monitoring.py index dadbcbea..65c756d9 100644 --- a/backend/apps/core/state_machine/monitoring.py +++ b/backend/apps/core/state_machine/monitoring.py @@ -47,7 +47,7 @@ class CallbackStats: successful_executions: int = 0 failed_executions: int = 0 total_duration_ms: float = 0.0 - min_duration_ms: float = float('inf') + min_duration_ms: float = float("inf") max_duration_ms: float = 0.0 last_execution: datetime | None = None last_error: str | None = None @@ -97,10 +97,10 @@ class CallbackMonitor: - Performance statistics """ - _instance: Optional['CallbackMonitor'] = None + _instance: Optional["CallbackMonitor"] = None _lock = threading.Lock() - def __new__(cls) -> 'CallbackMonitor': + def __new__(cls) -> "CallbackMonitor": if cls._instance is None: with cls._lock: if cls._instance is None: @@ -112,9 +112,7 @@ class CallbackMonitor: if self._initialized: return - self._stats: dict[str, CallbackStats] = defaultdict( - lambda: CallbackStats(callback_name="") - ) + self._stats: dict[str, CallbackStats] = defaultdict(lambda: CallbackStats(callback_name="")) self._recent_executions: list[CallbackExecutionRecord] = [] self._max_recent_records = 1000 self._enabled = self._check_enabled() @@ -123,13 +121,13 @@ class CallbackMonitor: def _check_enabled(self) -> bool: """Check if monitoring is enabled.""" - callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {}) - return callback_settings.get('monitoring_enabled', True) + callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {}) + return callback_settings.get("monitoring_enabled", True) def _check_debug_mode(self) -> bool: """Check if debug mode is enabled.""" - callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {}) - return callback_settings.get('debug_mode', settings.DEBUG) + callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {}) + return callback_settings.get("debug_mode", settings.DEBUG) def is_enabled(self) -> bool: """Check if monitoring is currently enabled.""" @@ -197,7 +195,7 @@ class CallbackMonitor: # Store recent executions (with size limit) self._recent_executions.append(record) if len(self._recent_executions) > self._max_recent_records: - self._recent_executions = self._recent_executions[-self._max_recent_records:] + self._recent_executions = self._recent_executions[-self._max_recent_records :] # Log in debug mode if self._debug_mode: @@ -277,12 +275,12 @@ class CallbackMonitor: # Build summary summary = { - 'total_failures': len(failures), - 'by_callback': { + "total_failures": len(failures), + "by_callback": { name: { - 'count': len(records), - 'last_error': records[-1].error_message if records else None, - 'last_occurrence': records[-1].timestamp if records else None, + "count": len(records), + "last_error": records[-1].error_message if records else None, + "last_occurrence": records[-1].timestamp if records else None, } for name, records in by_callback.items() }, @@ -293,12 +291,12 @@ class CallbackMonitor: def get_performance_report(self) -> dict[str, Any]: """Get a performance report for all callbacks.""" report = { - 'callbacks': {}, - 'summary': { - 'total_callbacks': len(self._stats), - 'total_executions': sum(s.total_executions for s in self._stats.values()), - 'total_failures': sum(s.failed_executions for s in self._stats.values()), - 'avg_duration_ms': 0.0, + "callbacks": {}, + "summary": { + "total_callbacks": len(self._stats), + "total_executions": sum(s.total_executions for s in self._stats.values()), + "total_failures": sum(s.failed_executions for s in self._stats.values()), + "avg_duration_ms": 0.0, }, } @@ -306,19 +304,19 @@ class CallbackMonitor: total_count = 0 for name, stats in self._stats.items(): - report['callbacks'][name] = { - 'executions': stats.total_executions, - 'success_rate': f"{stats.success_rate:.1f}%", - 'avg_duration_ms': f"{stats.avg_duration_ms:.2f}", - 'min_duration_ms': f"{stats.min_duration_ms:.2f}" if stats.min_duration_ms != float('inf') else "N/A", - 'max_duration_ms': f"{stats.max_duration_ms:.2f}", - 'last_execution': stats.last_execution.isoformat() if stats.last_execution else None, + report["callbacks"][name] = { + "executions": stats.total_executions, + "success_rate": f"{stats.success_rate:.1f}%", + "avg_duration_ms": f"{stats.avg_duration_ms:.2f}", + "min_duration_ms": f"{stats.min_duration_ms:.2f}" if stats.min_duration_ms != float("inf") else "N/A", + "max_duration_ms": f"{stats.max_duration_ms:.2f}", + "last_execution": stats.last_execution.isoformat() if stats.last_execution else None, } total_duration += stats.total_duration_ms total_count += stats.total_executions if total_count > 0: - report['summary']['avg_duration_ms'] = total_duration / total_count + report["summary"]["avg_duration_ms"] = total_duration / total_count return report @@ -361,7 +359,7 @@ class TimedCallbackExecution: self.success = True self.error_message: str | None = None - def __enter__(self) -> 'TimedCallbackExecution': + def __enter__(self) -> "TimedCallbackExecution": self.start_time = time.perf_counter() return self @@ -442,12 +440,12 @@ def get_callback_execution_order( __all__ = [ - 'CallbackExecutionRecord', - 'CallbackStats', - 'CallbackMonitor', - 'callback_monitor', - 'TimedCallbackExecution', - 'log_transition_start', - 'log_transition_end', - 'get_callback_execution_order', + "CallbackExecutionRecord", + "CallbackStats", + "CallbackMonitor", + "callback_monitor", + "TimedCallbackExecution", + "log_transition_start", + "log_transition_end", + "get_callback_execution_order", ] diff --git a/backend/apps/core/state_machine/registry.py b/backend/apps/core/state_machine/registry.py index f069b04b..6fba18bc 100644 --- a/backend/apps/core/state_machine/registry.py +++ b/backend/apps/core/state_machine/registry.py @@ -1,4 +1,5 @@ """TransitionRegistry - Centralized registry for managing FSM transitions.""" + import logging from collections.abc import Callable from dataclasses import dataclass, field @@ -86,9 +87,7 @@ class TransitionRegistry: self._transitions[key][transition_key] = transition_info return transition_info - def get_transition( - self, choice_group: str, domain: str, source: str, target: str - ) -> TransitionInfo | None: + def get_transition(self, choice_group: str, domain: str, source: str, target: str) -> TransitionInfo | None: """ Retrieve transition info. @@ -109,9 +108,7 @@ class TransitionRegistry: return self._transitions[key].get(transition_key) - def get_available_transitions( - self, choice_group: str, domain: str, current_state: str - ) -> list[TransitionInfo]: + def get_available_transitions(self, choice_group: str, domain: str, current_state: str) -> list[TransitionInfo]: """ Get all valid transitions from a state. @@ -135,9 +132,7 @@ class TransitionRegistry: return available - def get_transition_method_name( - self, choice_group: str, domain: str, source: str, target: str - ) -> str | None: + def get_transition_method_name(self, choice_group: str, domain: str, source: str, target: str) -> str | None: """ Get the method name for a transition. @@ -153,9 +148,7 @@ class TransitionRegistry: transition = self.get_transition(choice_group, domain, source, target) return transition.method_name if transition else None - def validate_transition( - self, choice_group: str, domain: str, source: str, target: str - ) -> bool: + def validate_transition(self, choice_group: str, domain: str, source: str, target: str) -> bool: """ Check if a transition is valid. @@ -168,13 +161,9 @@ class TransitionRegistry: Returns: True if transition is valid """ - return ( - self.get_transition(choice_group, domain, source, target) is not None - ) + return self.get_transition(choice_group, domain, source, target) is not None - def build_registry_from_choices( - self, choice_group: str, domain: str = "core" - ) -> None: + def build_registry_from_choices(self, choice_group: str, domain: str = "core") -> None: """ Automatically populate registry from RichChoice metadata. @@ -194,9 +183,7 @@ class TransitionRegistry: for target in targets: # Use shared method name determination - method_name = determine_method_name_for_transition( - source, target - ) + method_name = determine_method_name_for_transition(source, target) self.register_transition( choice_group=choice_group, @@ -226,9 +213,7 @@ class TransitionRegistry: else: self._transitions.clear() - def export_transition_graph( - self, choice_group: str, domain: str, format: str = "dict" - ) -> Any: + def export_transition_graph(self, choice_group: str, domain: str, format: str = "dict") -> Any: """ Export state machine graph for visualization. @@ -247,7 +232,7 @@ class TransitionRegistry: if format == "dict": graph: dict[str, list[str]] = {} - for (source, target), info in self._transitions[key].items(): + for (source, target), _info in self._transitions[key].items(): if source not in graph: graph[source] = [] graph[source].append(target) @@ -262,10 +247,7 @@ class TransitionRegistry: elif format == "dot": lines = ["digraph {"] for (source, target), info in self._transitions[key].items(): - lines.append( - f' "{source}" -> "{target}" ' - f'[label="{info.method_name}"];' - ) + lines.append(f' "{source}" -> "{target}" ' f'[label="{info.method_name}"];') lines.append("}") return "\n".join(lines) @@ -288,13 +270,14 @@ registry_instance = TransitionRegistry() # Callback registration helpers + def register_callback( model_class: type[models.Model], field_name: str, source: str, target: str, callback: Any, - stage: str = 'post', + stage: str = "post", ) -> None: """ Register a callback for a specific state transition. @@ -325,7 +308,7 @@ def register_notification_callback( source: str, target: str, notification_type: str, - recipient_field: str = 'submitted_by', + recipient_field: str = "submitted_by", ) -> None: """ Register a notification callback for a state transition. @@ -344,15 +327,15 @@ def register_notification_callback( notification_type=notification_type, recipient_field=recipient_field, ) - register_callback(model_class, field_name, source, target, callback, 'post') + register_callback(model_class, field_name, source, target, callback, "post") def register_cache_invalidation( model_class: type[models.Model], field_name: str, cache_patterns: list[str] | None = None, - source: str = '*', - target: str = '*', + source: str = "*", + target: str = "*", ) -> None: """ Register cache invalidation for state transitions. @@ -367,15 +350,15 @@ def register_cache_invalidation( from .callbacks.cache import CacheInvalidationCallback callback = CacheInvalidationCallback(patterns=cache_patterns or []) - register_callback(model_class, field_name, source, target, callback, 'post') + register_callback(model_class, field_name, source, target, callback, "post") def register_related_update( model_class: type[models.Model], field_name: str, update_func: Callable, - source: str = '*', - target: str = '*', + source: str = "*", + target: str = "*", ) -> None: """ Register a related model update callback. @@ -390,7 +373,7 @@ def register_related_update( from .callbacks.related_updates import RelatedModelUpdateCallback callback = RelatedModelUpdateCallback(update_function=update_func) - register_callback(model_class, field_name, source, target, callback, 'post') + register_callback(model_class, field_name, source, target, callback, "post") def register_transition_callbacks(cls: type[models.Model]) -> type[models.Model]: @@ -414,20 +397,20 @@ def register_transition_callbacks(cls: type[models.Model]) -> type[models.Model] Returns: The decorated model class. """ - meta = getattr(cls, 'Meta', None) + meta = getattr(cls, "Meta", None) if not meta: return cls - transition_callbacks = getattr(meta, 'transition_callbacks', None) + transition_callbacks = getattr(meta, "transition_callbacks", None) if not transition_callbacks: return cls # Get the FSM field name - field_name = getattr(meta, 'fsm_field', 'status') + field_name = getattr(meta, "fsm_field", "status") # Register each callback for (source, target), callbacks in transition_callbacks.items(): - if not isinstance(callbacks, (list, tuple)): + if not isinstance(callbacks, list | tuple): callbacks = [callbacks] for callback in callbacks: @@ -455,23 +438,23 @@ def discover_and_register_callbacks() -> None: for model in apps.get_models(): # Check if model has StateMachineMixin - if not hasattr(model, '_fsm_metadata') and not hasattr(model, 'Meta'): + if not hasattr(model, "_fsm_metadata") and not hasattr(model, "Meta"): continue - meta = getattr(model, 'Meta', None) + meta = getattr(model, "Meta", None) if not meta: continue - transition_callbacks = getattr(meta, 'transition_callbacks', None) + transition_callbacks = getattr(meta, "transition_callbacks", None) if not transition_callbacks: continue # Get the FSM field name - field_name = getattr(meta, 'fsm_field', 'status') + field_name = getattr(meta, "fsm_field", "status") # Register callbacks for (source, target), callbacks in transition_callbacks.items(): - if not isinstance(callbacks, (list, tuple)): + if not isinstance(callbacks, list | tuple): callbacks = [callbacks] for callback in callbacks: diff --git a/backend/apps/core/state_machine/signals.py b/backend/apps/core/state_machine/signals.py index 45f19176..da86e745 100644 --- a/backend/apps/core/state_machine/signals.py +++ b/backend/apps/core/state_machine/signals.py @@ -74,7 +74,7 @@ class TransitionSignalHandler: source: str, target: str, handler: Callable, - stage: str = 'post', + stage: str = "post", ) -> None: """ Register a handler for a specific transition. @@ -95,10 +95,7 @@ class TransitionSignalHandler: signal = self._get_signal(stage) self._connect_signal(signal, model_class, source, target, handler) - logger.debug( - f"Registered {stage} transition handler for " - f"{model_class.__name__}: {source} → {target}" - ) + logger.debug(f"Registered {stage} transition handler for " f"{model_class.__name__}: {source} → {target}") def unregister( self, @@ -106,7 +103,7 @@ class TransitionSignalHandler: source: str, target: str, handler: Callable, - stage: str = 'post', + stage: str = "post", ) -> None: """Unregister a previously registered handler.""" key = self._make_key(model_class, source, target, stage) @@ -128,9 +125,9 @@ class TransitionSignalHandler: def _get_signal(self, stage: str) -> Signal: """Get the signal for a given stage.""" - if stage == 'pre': + if stage == "pre": return pre_state_transition - elif stage == 'error': + elif stage == "error": return state_transition_failed return post_state_transition @@ -150,13 +147,13 @@ class TransitionSignalHandler: return # Check source state - signal_source = kwargs.get('source', '') - if source != '*' and str(signal_source) != source: + signal_source = kwargs.get("source", "") + if source != "*" and str(signal_source) != source: return # Check target state - signal_target = kwargs.get('target', '') - if target != '*' and str(signal_target) != target: + signal_target = kwargs.get("target", "") + if target != "*" and str(signal_target) != target: return # Call the handler @@ -174,7 +171,7 @@ def register_transition_handler( source: str, target: str, handler: Callable, - stage: str = 'post', + stage: str = "post", ) -> None: """ Convenience function to register a transition signal handler. @@ -186,9 +183,7 @@ def register_transition_handler( handler: The handler function to call. stage: 'pre', 'post', or 'error'. """ - transition_signal_handler.register( - model_class, source, target, handler, stage - ) + transition_signal_handler.register(model_class, source, target, handler, stage) def connect_fsm_log_signals() -> None: @@ -231,9 +226,9 @@ class TransitionHandlerDecorator: def __init__( self, model_class: type[models.Model], - source: str = '*', - target: str = '*', - stage: str = 'post', + source: str = "*", + target: str = "*", + stage: str = "post", ): """ Initialize the decorator. @@ -263,9 +258,9 @@ class TransitionHandlerDecorator: def on_transition( model_class: type[models.Model], - source: str = '*', - target: str = '*', - stage: str = 'post', + source: str = "*", + target: str = "*", + stage: str = "post", ) -> TransitionHandlerDecorator: """ Decorator factory for registering transition handlers. @@ -289,44 +284,44 @@ def on_transition( def on_pre_transition( model_class: type[models.Model], - source: str = '*', - target: str = '*', + source: str = "*", + target: str = "*", ) -> TransitionHandlerDecorator: """Decorator for pre-transition handlers.""" - return on_transition(model_class, source, target, stage='pre') + return on_transition(model_class, source, target, stage="pre") def on_post_transition( model_class: type[models.Model], - source: str = '*', - target: str = '*', + source: str = "*", + target: str = "*", ) -> TransitionHandlerDecorator: """Decorator for post-transition handlers.""" - return on_transition(model_class, source, target, stage='post') + return on_transition(model_class, source, target, stage="post") def on_transition_error( model_class: type[models.Model], - source: str = '*', - target: str = '*', + source: str = "*", + target: str = "*", ) -> TransitionHandlerDecorator: """Decorator for transition error handlers.""" - return on_transition(model_class, source, target, stage='error') + return on_transition(model_class, source, target, stage="error") __all__ = [ # Signals - 'pre_state_transition', - 'post_state_transition', - 'state_transition_failed', + "pre_state_transition", + "post_state_transition", + "state_transition_failed", # Handler registration - 'TransitionSignalHandler', - 'transition_signal_handler', - 'register_transition_handler', - 'connect_fsm_log_signals', + "TransitionSignalHandler", + "transition_signal_handler", + "register_transition_handler", + "connect_fsm_log_signals", # Decorators - 'on_transition', - 'on_pre_transition', - 'on_post_transition', - 'on_transition_error', + "on_transition", + "on_pre_transition", + "on_post_transition", + "on_transition_error", ] diff --git a/backend/apps/core/state_machine/tests/fixtures.py b/backend/apps/core/state_machine/tests/fixtures.py index 8ebe6400..cd9ddab5 100644 --- a/backend/apps/core/state_machine/tests/fixtures.py +++ b/backend/apps/core/state_machine/tests/fixtures.py @@ -29,11 +29,11 @@ class UserFactory: @classmethod def create_user( cls, - role: str = 'USER', + role: str = "USER", username: str | None = None, email: str | None = None, - password: str = 'testpass123', - **kwargs + password: str = "testpass123", + **kwargs, ) -> User: """ Create a user with specified role. @@ -54,33 +54,27 @@ class UserFactory: if email is None: email = f"{role.lower()}_{uid}@example.com" - return User.objects.create_user( - username=username, - email=email, - password=password, - role=role, - **kwargs - ) + return User.objects.create_user(username=username, email=email, password=password, role=role, **kwargs) @classmethod def create_regular_user(cls, **kwargs) -> User: """Create a regular user.""" - return cls.create_user(role='USER', **kwargs) + return cls.create_user(role="USER", **kwargs) @classmethod def create_moderator(cls, **kwargs) -> User: """Create a moderator user.""" - return cls.create_user(role='MODERATOR', **kwargs) + return cls.create_user(role="MODERATOR", **kwargs) @classmethod def create_admin(cls, **kwargs) -> User: """Create an admin user.""" - return cls.create_user(role='ADMIN', **kwargs) + return cls.create_user(role="ADMIN", **kwargs) @classmethod def create_superuser(cls, **kwargs) -> User: """Create a superuser.""" - return cls.create_user(role='SUPERUSER', **kwargs) + return cls.create_user(role="SUPERUSER", **kwargs) class CompanyFactory: @@ -102,11 +96,7 @@ class CompanyFactory: if name is None: name = f"Test Operator {uid}" - defaults = { - 'name': name, - 'description': f'Test operator company {uid}', - 'roles': ['OPERATOR'] - } + defaults = {"name": name, "description": f"Test operator company {uid}", "roles": ["OPERATOR"]} defaults.update(kwargs) return Company.objects.create(**defaults) @@ -119,11 +109,7 @@ class CompanyFactory: if name is None: name = f"Test Manufacturer {uid}" - defaults = { - 'name': name, - 'description': f'Test manufacturer company {uid}', - 'roles': ['MANUFACTURER'] - } + defaults = {"name": name, "description": f"Test manufacturer company {uid}", "roles": ["MANUFACTURER"]} defaults.update(kwargs) return Company.objects.create(**defaults) @@ -140,11 +126,7 @@ class ParkFactory: @classmethod def create_park( - cls, - name: str | None = None, - operator: Any | None = None, - status: str = 'OPERATING', - **kwargs + cls, name: str | None = None, operator: Any | None = None, status: str = "OPERATING", **kwargs ) -> Any: """ Create a park with specified status. @@ -167,12 +149,12 @@ class ParkFactory: operator = CompanyFactory.create_operator() defaults = { - 'name': name, - 'slug': f'test-park-{uid}', - 'description': f'A test park {uid}', - 'operator': operator, - 'status': status, - 'timezone': 'America/New_York' + "name": name, + "slug": f"test-park-{uid}", + "description": f"A test park {uid}", + "operator": operator, + "status": status, + "timezone": "America/New_York", } defaults.update(kwargs) return Park.objects.create(**defaults) @@ -194,8 +176,8 @@ class RideFactory: name: str | None = None, park: Any | None = None, manufacturer: Any | None = None, - status: str = 'OPERATING', - **kwargs + status: str = "OPERATING", + **kwargs, ) -> Any: """ Create a ride with specified status. @@ -221,12 +203,12 @@ class RideFactory: manufacturer = CompanyFactory.create_manufacturer() defaults = { - 'name': name, - 'slug': f'test-ride-{uid}', - 'description': f'A test ride {uid}', - 'park': park, - 'manufacturer': manufacturer, - 'status': status + "name": name, + "slug": f"test-ride-{uid}", + "description": f"A test ride {uid}", + "park": park, + "manufacturer": manufacturer, + "status": status, } defaults.update(kwargs) return Ride.objects.create(**defaults) @@ -247,9 +229,9 @@ class EditSubmissionFactory: cls, user: Any | None = None, target_object: Any | None = None, - status: str = 'PENDING', + status: str = "PENDING", changes: dict[str, Any] | None = None, - **kwargs + **kwargs, ) -> Any: """ Create an edit submission. @@ -271,23 +253,20 @@ class EditSubmissionFactory: if user is None: user = UserFactory.create_regular_user() if target_object is None: - target_object = Company.objects.create( - name=f'Target Company {uid}', - description='Test company' - ) + target_object = Company.objects.create(name=f"Target Company {uid}", description="Test company") if changes is None: - changes = {'name': f'Updated Name {uid}'} + changes = {"name": f"Updated Name {uid}"} content_type = ContentType.objects.get_for_model(target_object) defaults = { - 'user': user, - 'content_type': content_type, - 'object_id': target_object.id, - 'submission_type': 'EDIT', - 'changes': changes, - 'status': status, - 'reason': f'Test reason {uid}' + "user": user, + "content_type": content_type, + "object_id": target_object.id, + "submission_type": "EDIT", + "changes": changes, + "status": status, + "reason": f"Test reason {uid}", } defaults.update(kwargs) return EditSubmission.objects.create(**defaults) @@ -305,11 +284,7 @@ class ModerationReportFactory: @classmethod def create_report( - cls, - reporter: Any | None = None, - target_object: Any | None = None, - status: str = 'PENDING', - **kwargs + cls, reporter: Any | None = None, target_object: Any | None = None, status: str = "PENDING", **kwargs ) -> Any: """ Create a moderation report. @@ -330,23 +305,20 @@ class ModerationReportFactory: if reporter is None: reporter = UserFactory.create_regular_user() if target_object is None: - target_object = Company.objects.create( - name=f'Reported Company {uid}', - description='Test company' - ) + target_object = Company.objects.create(name=f"Reported Company {uid}", description="Test company") content_type = ContentType.objects.get_for_model(target_object) defaults = { - 'report_type': 'CONTENT', - 'status': status, - 'priority': 'MEDIUM', - 'reported_entity_type': target_object._meta.model_name, - 'reported_entity_id': target_object.id, - 'content_type': content_type, - 'reason': f'Test reason {uid}', - 'description': f'Test report description {uid}', - 'reported_by': reporter + "report_type": "CONTENT", + "status": status, + "priority": "MEDIUM", + "reported_entity_type": target_object._meta.model_name, + "reported_entity_id": target_object.id, + "content_type": content_type, + "reason": f"Test reason {uid}", + "description": f"Test report description {uid}", + "reported_by": reporter, } defaults.update(kwargs) return ModerationReport.objects.create(**defaults) @@ -369,5 +341,5 @@ class MockInstance: setattr(self, key, value) def __repr__(self): - attrs = ', '.join(f'{k}={v!r}' for k, v in self.__dict__.items()) - return f'MockInstance({attrs})' + attrs = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items()) + return f"MockInstance({attrs})" diff --git a/backend/apps/core/state_machine/tests/helpers.py b/backend/apps/core/state_machine/tests/helpers.py index 3c6d038f..9c13a4da 100644 --- a/backend/apps/core/state_machine/tests/helpers.py +++ b/backend/apps/core/state_machine/tests/helpers.py @@ -13,11 +13,7 @@ from typing import Any from django.contrib.contenttypes.models import ContentType -def assert_transition_allowed( - instance: Any, - method_name: str, - user: Any | None = None -) -> bool: +def assert_transition_allowed(instance: Any, method_name: str, user: Any | None = None) -> bool: """ Assert that a transition is allowed. @@ -43,11 +39,7 @@ def assert_transition_allowed( return True -def assert_transition_denied( - instance: Any, - method_name: str, - user: Any | None = None -) -> bool: +def assert_transition_denied(instance: Any, method_name: str, user: Any | None = None) -> bool: """ Assert that a transition is denied. @@ -73,11 +65,7 @@ def assert_transition_denied( return True -def assert_state_log_created( - instance: Any, - expected_state: str, - user: Any | None = None -) -> Any: +def assert_state_log_created(instance: Any, expected_state: str, user: Any | None = None) -> Any: """ Assert that a StateLog entry was created for a transition. @@ -98,11 +86,7 @@ def assert_state_log_created( from django_fsm_log.models import StateLog ct = ContentType.objects.get_for_model(instance) - log = StateLog.objects.filter( - content_type=ct, - object_id=instance.id, - state=expected_state - ).first() + log = StateLog.objects.filter(content_type=ct, object_id=instance.id, state=expected_state).first() assert log is not None, f"StateLog for state '{expected_state}' not found" @@ -132,22 +116,15 @@ def assert_state_log_count(instance: Any, expected_count: int) -> list[Any]: from django_fsm_log.models import StateLog ct = ContentType.objects.get_for_model(instance) - logs = list(StateLog.objects.filter( - content_type=ct, - object_id=instance.id - ).order_by('timestamp')) + logs = list(StateLog.objects.filter(content_type=ct, object_id=instance.id).order_by("timestamp")) actual_count = len(logs) - assert actual_count == expected_count, \ - f"Expected {expected_count} StateLog entries, got {actual_count}" + assert actual_count == expected_count, f"Expected {expected_count} StateLog entries, got {actual_count}" return logs -def assert_state_transition_sequence( - instance: Any, - expected_states: list[str] -) -> list[Any]: +def assert_state_transition_sequence(instance: Any, expected_states: list[str]) -> list[Any]: """ Assert that state transitions occurred in a specific sequence. @@ -167,24 +144,15 @@ def assert_state_transition_sequence( from django_fsm_log.models import StateLog ct = ContentType.objects.get_for_model(instance) - logs = list(StateLog.objects.filter( - content_type=ct, - object_id=instance.id - ).order_by('timestamp')) + logs = list(StateLog.objects.filter(content_type=ct, object_id=instance.id).order_by("timestamp")) actual_states = [log.state for log in logs] - assert actual_states == expected_states, \ - f"Expected state sequence {expected_states}, got {actual_states}" + assert actual_states == expected_states, f"Expected state sequence {expected_states}, got {actual_states}" return logs -def assert_guard_passes( - guard: Callable, - instance: Any, - user: Any | None = None, - message: str = "" -) -> bool: +def assert_guard_passes(guard: Callable, instance: Any, user: Any | None = None, message: str = "") -> bool: """ Assert that a guard function passes. @@ -210,11 +178,7 @@ def assert_guard_passes( def assert_guard_fails( - guard: Callable, - instance: Any, - user: Any | None = None, - expected_error_code: str | None = None, - message: str = "" + guard: Callable, instance: Any, user: Any | None = None, expected_error_code: str | None = None, message: str = "" ) -> bool: """ Assert that a guard function fails. @@ -239,19 +203,15 @@ def assert_guard_fails( fail_message = message or f"Guard should fail but returned {result}" assert result is False, fail_message - if expected_error_code and hasattr(guard, 'error_code'): - assert guard.error_code == expected_error_code, \ - f"Expected error code {expected_error_code}, got {guard.error_code}" + if expected_error_code and hasattr(guard, "error_code"): + assert ( + guard.error_code == expected_error_code + ), f"Expected error code {expected_error_code}, got {guard.error_code}" return True -def transition_and_save( - instance: Any, - transition_method: str, - user: Any | None = None, - **kwargs -) -> Any: +def transition_and_save(instance: Any, transition_method: str, user: Any | None = None, **kwargs) -> Any: """ Execute a transition and save the instance. @@ -290,10 +250,10 @@ def get_available_transitions(instance: Any) -> list[str]: """ # Get the state field name from the instance - state_field = getattr(instance, 'state_field_name', 'status') + state_field = getattr(instance, "state_field_name", "status") # Build the function name dynamically - func_name = f'get_available_{state_field}_transitions' + func_name = f"get_available_{state_field}_transitions" if hasattr(instance, func_name): get_transitions = getattr(instance, func_name) return [t.name for t in get_transitions()] @@ -301,19 +261,13 @@ def get_available_transitions(instance: Any) -> list[str]: # Fallback: look for transition methods transitions = [] for attr_name in dir(instance): - if attr_name.startswith('transition_to_'): + if attr_name.startswith("transition_to_"): transitions.append(attr_name) return transitions -def create_transition_context( - instance: Any, - from_state: str, - to_state: str, - user: Any | None = None, - **extra -) -> dict: +def create_transition_context(instance: Any, from_state: str, to_state: str, user: Any | None = None, **extra) -> dict: """ Create a mock transition context dictionary. @@ -331,11 +285,11 @@ def create_transition_context( context = create_transition_context(submission, 'PENDING', 'APPROVED', moderator) """ return { - 'instance': instance, - 'from_state': from_state, - 'to_state': to_state, - 'user': user, - 'model_class': type(instance), - 'transition_name': f'transition_to_{to_state.lower()}', - **extra + "instance": instance, + "from_state": from_state, + "to_state": to_state, + "user": user, + "model_class": type(instance), + "transition_name": f"transition_to_{to_state.lower()}", + **extra, } diff --git a/backend/apps/core/state_machine/tests/test_builder.py b/backend/apps/core/state_machine/tests/test_builder.py index b71e461a..b94cacce 100644 --- a/backend/apps/core/state_machine/tests/test_builder.py +++ b/backend/apps/core/state_machine/tests/test_builder.py @@ -1,4 +1,5 @@ """Tests for StateTransitionBuilder.""" + import pytest from django.core.exceptions import ImproperlyConfigured diff --git a/backend/apps/core/state_machine/tests/test_callbacks.py b/backend/apps/core/state_machine/tests/test_callbacks.py index 58d5bb05..f62495f8 100644 --- a/backend/apps/core/state_machine/tests/test_callbacks.py +++ b/backend/apps/core/state_machine/tests/test_callbacks.py @@ -19,12 +19,7 @@ class CallbackContext: """Mock context for testing callbacks.""" def __init__( - self, - instance: Any = None, - from_state: str = 'PENDING', - to_state: str = 'APPROVED', - user: Any = None, - **extra + self, instance: Any = None, from_state: str = "PENDING", to_state: str = "APPROVED", user: Any = None, **extra ): self.instance = instance or Mock() self.from_state = from_state @@ -34,18 +29,18 @@ class CallbackContext: def to_dict(self) -> dict[str, Any]: return { - 'instance': self.instance, - 'from_state': self.from_state, - 'to_state': self.to_state, - 'user': self.user, - **self.extra + "instance": self.instance, + "from_state": self.from_state, + "to_state": self.to_state, + "user": self.user, + **self.extra, } class MockCallback: """Mock callback for testing.""" - def __init__(self, name: str = 'callback', should_raise: bool = False): + def __init__(self, name: str = "callback", should_raise: bool = False): self.name = name self.calls: list[dict] = [] self.should_raise = should_raise @@ -71,40 +66,40 @@ class PreTransitionCallbackTests(TestCase): def test_pre_callback_executes_before_state_change(self): """Test that pre-transition callback executes before state changes.""" - callback = MockCallback('pre_callback') - context = CallbackContext(from_state='PENDING', to_state='APPROVED') + callback = MockCallback("pre_callback") + context = CallbackContext(from_state="PENDING", to_state="APPROVED") # Simulate pre-transition execution callback(context.to_dict()) self.assertTrue(callback.was_called()) - self.assertEqual(callback.calls[0]['from_state'], 'PENDING') - self.assertEqual(callback.calls[0]['to_state'], 'APPROVED') + self.assertEqual(callback.calls[0]["from_state"], "PENDING") + self.assertEqual(callback.calls[0]["to_state"], "APPROVED") def test_pre_callback_receives_instance(self): """Test that pre-callback receives the model instance.""" mock_instance = Mock() mock_instance.id = 123 - mock_instance.status = 'PENDING' + mock_instance.status = "PENDING" callback = MockCallback() context = CallbackContext(instance=mock_instance) callback(context.to_dict()) - self.assertEqual(callback.calls[0]['instance'], mock_instance) + self.assertEqual(callback.calls[0]["instance"], mock_instance) def test_pre_callback_receives_user(self): """Test that pre-callback receives the user performing transition.""" mock_user = Mock() - mock_user.username = 'moderator' + mock_user.username = "moderator" callback = MockCallback() context = CallbackContext(user=mock_user) callback(context.to_dict()) - self.assertEqual(callback.calls[0]['user'], mock_user) + self.assertEqual(callback.calls[0]["user"], mock_user) def test_pre_callback_can_prevent_transition(self): """Test that pre-callback can prevent transition by raising exception.""" @@ -119,13 +114,13 @@ class PreTransitionCallbackTests(TestCase): execution_order = [] def callback_1(ctx): - execution_order.append('first') + execution_order.append("first") def callback_2(ctx): - execution_order.append('second') + execution_order.append("second") def callback_3(ctx): - execution_order.append('third') + execution_order.append("third") context = CallbackContext().to_dict() @@ -134,7 +129,7 @@ class PreTransitionCallbackTests(TestCase): callback_2(context) callback_3(context) - self.assertEqual(execution_order, ['first', 'second', 'third']) + self.assertEqual(execution_order, ["first", "second", "third"]) class PostTransitionCallbackTests(TestCase): @@ -142,28 +137,24 @@ class PostTransitionCallbackTests(TestCase): def test_post_callback_executes_after_state_change(self): """Test that post-transition callback executes after state changes.""" - callback = MockCallback('post_callback') + callback = MockCallback("post_callback") # Simulate instance after transition mock_instance = Mock() - mock_instance.status = 'APPROVED' # Already changed + mock_instance.status = "APPROVED" # Already changed - context = CallbackContext( - instance=mock_instance, - from_state='PENDING', - to_state='APPROVED' - ) + context = CallbackContext(instance=mock_instance, from_state="PENDING", to_state="APPROVED") callback(context.to_dict()) self.assertTrue(callback.was_called()) - self.assertEqual(callback.calls[0]['instance'].status, 'APPROVED') + self.assertEqual(callback.calls[0]["instance"].status, "APPROVED") def test_post_callback_receives_updated_instance(self): """Test that post-callback receives instance with new state.""" mock_instance = Mock() - mock_instance.status = 'APPROVED' - mock_instance.approved_at = '2025-01-15' + mock_instance.status = "APPROVED" + mock_instance.approved_at = "2025-01-15" mock_instance.handled_by_id = 456 callback = MockCallback() @@ -171,9 +162,9 @@ class PostTransitionCallbackTests(TestCase): callback(context.to_dict()) - instance = callback.calls[0]['instance'] - self.assertEqual(instance.status, 'APPROVED') - self.assertEqual(instance.approved_at, '2025-01-15') + instance = callback.calls[0]["instance"] + self.assertEqual(instance.status, "APPROVED") + self.assertEqual(instance.approved_at, "2025-01-15") def test_post_callback_failure_does_not_rollback(self): """Test that post-callback failures don't rollback the transition.""" @@ -193,13 +184,13 @@ class PostTransitionCallbackTests(TestCase): execution_order = [] def notification_callback(ctx): - execution_order.append('notification') + execution_order.append("notification") def cache_callback(ctx): - execution_order.append('cache') + execution_order.append("cache") def analytics_callback(ctx): - execution_order.append('analytics') + execution_order.append("analytics") context = CallbackContext().to_dict() @@ -207,7 +198,7 @@ class PostTransitionCallbackTests(TestCase): cache_callback(context) analytics_callback(context) - self.assertEqual(execution_order, ['notification', 'cache', 'analytics']) + self.assertEqual(execution_order, ["notification", "cache", "analytics"]) class ErrorCallbackTests(TestCase): @@ -221,17 +212,17 @@ class ErrorCallbackTests(TestCase): raise ValueError("Transition failed") except ValueError as e: error_context = { - 'instance': Mock(), - 'from_state': 'PENDING', - 'to_state': 'APPROVED', - 'exception': e, - 'exception_type': type(e).__name__ + "instance": Mock(), + "from_state": "PENDING", + "to_state": "APPROVED", + "exception": e, + "exception_type": type(e).__name__, } error_callback(error_context) self.assertTrue(error_callback.was_called()) - self.assertIn('exception', error_callback.calls[0]) - self.assertEqual(error_callback.calls[0]['exception_type'], 'ValueError') + self.assertIn("exception", error_callback.calls[0]) + self.assertEqual(error_callback.calls[0]["exception_type"], "ValueError") def test_error_callback_for_cleanup(self): """Test that error callbacks can perform cleanup.""" @@ -244,7 +235,7 @@ class ErrorCallbackTests(TestCase): try: raise ValueError("Transition failed") except ValueError: - cleanup_callback({'exception': 'test'}) + cleanup_callback({"exception": "test"}) self.assertTrue(cleanup_performed) @@ -256,17 +247,17 @@ class ErrorCallbackTests(TestCase): error_callback = MockCallback() error_context = { - 'instance': mock_instance, - 'from_state': 'PENDING', - 'to_state': 'APPROVED', - 'user': mock_user, - 'exception': ValueError("Test error") + "instance": mock_instance, + "from_state": "PENDING", + "to_state": "APPROVED", + "user": mock_user, + "exception": ValueError("Test error"), } error_callback(error_context) - self.assertEqual(error_callback.calls[0]['instance'], mock_instance) - self.assertEqual(error_callback.calls[0]['user'], mock_user) + self.assertEqual(error_callback.calls[0]["instance"], mock_instance) + self.assertEqual(error_callback.calls[0]["user"], mock_user) class ConditionalCallbackTests(TestCase): @@ -277,15 +268,15 @@ class ConditionalCallbackTests(TestCase): execution_log = [] def approval_only_callback(ctx): - if ctx.get('to_state') == 'APPROVED': - execution_log.append('approved') + if ctx.get("to_state") == "APPROVED": + execution_log.append("approved") # Transition to APPROVED - should execute - approval_only_callback({'to_state': 'APPROVED'}) + approval_only_callback({"to_state": "APPROVED"}) self.assertEqual(len(execution_log), 1) # Transition to REJECTED - should not execute - approval_only_callback({'to_state': 'REJECTED'}) + approval_only_callback({"to_state": "REJECTED"}) self.assertEqual(len(execution_log), 1) # Still 1 def test_callback_with_transition_filter(self): @@ -293,15 +284,15 @@ class ConditionalCallbackTests(TestCase): execution_log = [] def escalation_callback(ctx): - if ctx.get('to_state') == 'ESCALATED': - execution_log.append('escalated') + if ctx.get("to_state") == "ESCALATED": + execution_log.append("escalated") # Escalation - should execute - escalation_callback({'to_state': 'ESCALATED'}) + escalation_callback({"to_state": "ESCALATED"}) self.assertEqual(len(execution_log), 1) # Other transitions - should not execute - escalation_callback({'to_state': 'APPROVED'}) + escalation_callback({"to_state": "APPROVED"}) self.assertEqual(len(execution_log), 1) def test_callback_with_user_role_filter(self): @@ -309,17 +300,17 @@ class ConditionalCallbackTests(TestCase): admin_notifications = [] def admin_only_notification(ctx): - user = ctx.get('user') - if user and getattr(user, 'role', None) == 'ADMIN': + user = ctx.get("user") + if user and getattr(user, "role", None) == "ADMIN": admin_notifications.append(ctx) - admin_user = Mock(role='ADMIN') - moderator_user = Mock(role='MODERATOR') + admin_user = Mock(role="ADMIN") + moderator_user = Mock(role="MODERATOR") - admin_only_notification({'user': admin_user}) + admin_only_notification({"user": admin_user}) self.assertEqual(len(admin_notifications), 1) - admin_only_notification({'user': moderator_user}) + admin_only_notification({"user": moderator_user}) self.assertEqual(len(admin_notifications), 1) # Still 1 @@ -331,29 +322,29 @@ class CallbackChainTests(TestCase): results = [] callbacks = [ - lambda ctx: results.append('a'), - lambda ctx: results.append('b'), - lambda ctx: results.append('c'), + lambda ctx: results.append("a"), + lambda ctx: results.append("b"), + lambda ctx: results.append("c"), ] context = {} for cb in callbacks: cb(context) - self.assertEqual(results, ['a', 'b', 'c']) + self.assertEqual(results, ["a", "b", "c"]) def test_callback_chain_stops_on_failure(self): """Test that callback chain stops when a callback fails.""" results = [] def callback_a(ctx): - results.append('a') + results.append("a") def callback_b(ctx): raise ValueError("B failed") def callback_c(ctx): - results.append('c') + results.append("c") callbacks = [callback_a, callback_b, callback_c] @@ -364,7 +355,7 @@ class CallbackChainTests(TestCase): except ValueError: break - self.assertEqual(results, ['a']) # c never executed + self.assertEqual(results, ["a"]) # c never executed def test_callback_chain_with_continue_on_error(self): """Test callback chain that continues despite errors.""" @@ -372,13 +363,13 @@ class CallbackChainTests(TestCase): errors = [] def callback_a(ctx): - results.append('a') + results.append("a") def callback_b(ctx): raise ValueError("B failed") def callback_c(ctx): - results.append('c') + results.append("c") callbacks = [callback_a, callback_b, callback_c] @@ -389,7 +380,7 @@ class CallbackChainTests(TestCase): except Exception as e: errors.append(str(e)) - self.assertEqual(results, ['a', 'c']) + self.assertEqual(results, ["a", "c"]) self.assertEqual(len(errors), 1) @@ -399,36 +390,30 @@ class CallbackContextEnrichmentTests(TestCase): def test_context_includes_model_class(self): """Test that context includes the model class.""" mock_instance = Mock() - mock_instance.__class__.__name__ = 'EditSubmission' + mock_instance.__class__.__name__ = "EditSubmission" - context = { - 'instance': mock_instance, - 'model_class': type(mock_instance) - } + context = {"instance": mock_instance, "model_class": type(mock_instance)} - self.assertIn('model_class', context) + self.assertIn("model_class", context) def test_context_includes_transition_name(self): """Test that context includes the transition method name.""" context = { - 'instance': Mock(), - 'from_state': 'PENDING', - 'to_state': 'APPROVED', - 'transition_name': 'transition_to_approved' + "instance": Mock(), + "from_state": "PENDING", + "to_state": "APPROVED", + "transition_name": "transition_to_approved", } - self.assertEqual(context['transition_name'], 'transition_to_approved') + self.assertEqual(context["transition_name"], "transition_to_approved") def test_context_includes_timestamp(self): """Test that context includes transition timestamp.""" from django.utils import timezone - context = { - 'instance': Mock(), - 'timestamp': timezone.now() - } + context = {"instance": Mock(), "timestamp": timezone.now()} - self.assertIn('timestamp', context) + self.assertIn("timestamp", context) # ============================================================================ @@ -446,9 +431,9 @@ class NotificationCallbackTests(TestCase): def _create_transition_context( self, - model_name: str = 'EditSubmission', - source_state: str = 'PENDING', - target_state: str = 'APPROVED', + model_name: str = "EditSubmission", + source_state: str = "PENDING", + target_state: str = "APPROVED", user=None, instance=None, ): @@ -465,18 +450,18 @@ class NotificationCallbackTests(TestCase): if user is None: user = Mock() user.pk = 1 - user.username = 'moderator' + user.username = "moderator" return TransitionContext( instance=instance, - field_name='status', + field_name="status", source_state=source_state, target_state=target_state, user=user, timestamp=timezone.now(), ) - @patch('apps.core.state_machine.callbacks.notifications.NotificationService') + @patch("apps.core.state_machine.callbacks.notifications.NotificationService") def test_notification_callback_approval_title(self, mock_service_class): """Test NotificationCallback generates correct title for approvals.""" from ..callbacks.notifications import NotificationCallback @@ -487,8 +472,8 @@ class NotificationCallbackTests(TestCase): callback = NotificationCallback() context = self._create_transition_context( - source_state='PENDING', - target_state='APPROVED', + source_state="PENDING", + target_state="APPROVED", ) callback.execute(context) @@ -496,9 +481,9 @@ class NotificationCallbackTests(TestCase): # Check that notification was sent with correct title if mock_service.send_notification.called: call_args = mock_service.send_notification.call_args - self.assertIn('approved', call_args[1].get('title', '').lower()) + self.assertIn("approved", call_args[1].get("title", "").lower()) - @patch('apps.core.state_machine.callbacks.notifications.NotificationService') + @patch("apps.core.state_machine.callbacks.notifications.NotificationService") def test_notification_callback_rejection_title(self, mock_service_class): """Test NotificationCallback generates correct title for rejections.""" from ..callbacks.notifications import NotificationCallback @@ -509,17 +494,17 @@ class NotificationCallbackTests(TestCase): callback = NotificationCallback() context = self._create_transition_context( - source_state='PENDING', - target_state='REJECTED', + source_state="PENDING", + target_state="REJECTED", ) callback.execute(context) if mock_service.send_notification.called: call_args = mock_service.send_notification.call_args - self.assertIn('rejected', call_args[1].get('title', '').lower()) + self.assertIn("rejected", call_args[1].get("title", "").lower()) - @patch('apps.core.state_machine.callbacks.notifications.NotificationService') + @patch("apps.core.state_machine.callbacks.notifications.NotificationService") def test_moderation_notification_recipient_selection(self, mock_service_class): """Test ModerationNotificationCallback sends to correct recipient.""" from ..callbacks.notifications import ModerationNotificationCallback @@ -530,16 +515,16 @@ class NotificationCallbackTests(TestCase): submitter = Mock() submitter.pk = 999 - submitter.username = 'submitter' + submitter.username = "submitter" instance = Mock() instance.pk = 123 - instance.__class__.__name__ = 'EditSubmission' + instance.__class__.__name__ = "EditSubmission" instance.user = submitter # The submitter who should receive notification callback = ModerationNotificationCallback() context = self._create_transition_context( - target_state='APPROVED', + target_state="APPROVED", instance=instance, ) @@ -548,10 +533,10 @@ class NotificationCallbackTests(TestCase): if mock_service.send_notification.called: call_args = mock_service.send_notification.call_args # Should notify the submitter about their submission - recipient = call_args[1].get('user') or call_args[0][0] if call_args[0] else None + recipient = call_args[1].get("user") or call_args[0][0] if call_args[0] else None self.assertIsNotNone(recipient) - @patch('apps.core.state_machine.callbacks.notifications.NotificationService') + @patch("apps.core.state_machine.callbacks.notifications.NotificationService") def test_notification_callback_handles_service_error(self, mock_service_class): """Test NotificationCallback handles service errors gracefully.""" from ..callbacks.notifications import NotificationCallback @@ -568,7 +553,7 @@ class NotificationCallbackTests(TestCase): # Callback may return False on error but should not raise self.assertIsNotNone(result) - @patch('apps.core.state_machine.callbacks.notifications.NotificationService') + @patch("apps.core.state_machine.callbacks.notifications.NotificationService") def test_notification_callback_message_includes_model_info(self, mock_service_class): """Test notification message includes model information.""" from ..callbacks.notifications import NotificationCallback @@ -578,13 +563,13 @@ class NotificationCallbackTests(TestCase): mock_service_class.return_value = mock_service callback = NotificationCallback() - context = self._create_transition_context(model_name='PhotoSubmission') + context = self._create_transition_context(model_name="PhotoSubmission") callback.execute(context) if mock_service.send_notification.called: call_args = mock_service.send_notification.call_args - message = call_args[1].get('message', '') + message = call_args[1].get("message", "") # Should reference the submission type or model self.assertIsInstance(message, str) @@ -599,10 +584,10 @@ class CacheCallbackTests(TestCase): def _create_transition_context( self, - model_name: str = 'Park', + model_name: str = "Park", instance_id: int = 123, - source_state: str = 'OPERATING', - target_state: str = 'CLOSED_TEMP', + source_state: str = "OPERATING", + target_state: str = "CLOSED_TEMP", ): """Helper to create a TransitionContext.""" from django.utils import timezone @@ -615,14 +600,14 @@ class CacheCallbackTests(TestCase): return TransitionContext( instance=instance, - field_name='status', + field_name="status", source_state=source_state, target_state=target_state, user=Mock(), timestamp=timezone.now(), ) - @patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service') + @patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service") def test_cache_callback_invalidates_model_patterns(self, mock_get_service): """Test CacheInvalidationCallback invalidates correct patterns.""" from ..callbacks.cache import CacheInvalidationCallback @@ -631,9 +616,7 @@ class CacheCallbackTests(TestCase): mock_cache.invalidate_pattern = Mock() mock_get_service.return_value = mock_cache - callback = CacheInvalidationCallback( - patterns=['*park:123*', '*parks*'] - ) + callback = CacheInvalidationCallback(patterns=["*park:123*", "*parks*"]) context = self._create_transition_context() callback.execute(context) @@ -641,7 +624,7 @@ class CacheCallbackTests(TestCase): # Should have called invalidate_pattern for each pattern self.assertTrue(mock_cache.invalidate_pattern.called) - @patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service') + @patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service") def test_cache_callback_generates_instance_patterns(self, mock_get_service): """Test CacheInvalidationCallback generates instance-specific patterns.""" from ..callbacks.cache import CacheInvalidationCallback @@ -651,30 +634,25 @@ class CacheCallbackTests(TestCase): mock_get_service.return_value = mock_cache callback = CacheInvalidationCallback(include_instance_patterns=True) - context = self._create_transition_context( - model_name='Park', - instance_id=456 - ) + context = self._create_transition_context(model_name="Park", instance_id=456) callback.execute(context) # Should have called invalidate_pattern with instance-specific patterns self.assertTrue(mock_cache.invalidate_pattern.called) - patterns_called = [ - call[0][0] for call in mock_cache.invalidate_pattern.call_args_list - ] + patterns_called = [call[0][0] for call in mock_cache.invalidate_pattern.call_args_list] # Should include patterns containing the instance ID - has_instance_pattern = any('456' in p for p in patterns_called) + has_instance_pattern = any("456" in p for p in patterns_called) self.assertTrue(has_instance_pattern, f"No pattern with instance ID in {patterns_called}") - @patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service') + @patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service") def test_cache_callback_handles_service_unavailable(self, mock_get_service): """Test CacheInvalidationCallback handles unavailable cache service.""" from ..callbacks.cache import CacheInvalidationCallback mock_get_service.return_value = None - callback = CacheInvalidationCallback(patterns=['*test*']) + callback = CacheInvalidationCallback(patterns=["*test*"]) context = self._create_transition_context() # Should not raise, uses fallback @@ -682,7 +660,7 @@ class CacheCallbackTests(TestCase): # Should return True (fallback succeeds) self.assertTrue(result) - @patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service') + @patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service") def test_cache_callback_continues_on_pattern_error(self, mock_get_service): """Test CacheInvalidationCallback continues if individual pattern fails.""" from ..callbacks.cache import CacheInvalidationCallback @@ -693,16 +671,13 @@ class CacheCallbackTests(TestCase): def invalidate_side_effect(pattern): nonlocal call_count call_count += 1 - if 'bad' in pattern: + if "bad" in pattern: raise Exception("Pattern invalid") mock_cache.invalidate_pattern = Mock(side_effect=invalidate_side_effect) mock_get_service.return_value = mock_cache - callback = CacheInvalidationCallback( - patterns=['good:*', 'bad:*', 'another:*'], - include_instance_patterns=False - ) + callback = CacheInvalidationCallback(patterns=["good:*", "bad:*", "another:*"], include_instance_patterns=False) context = self._create_transition_context() # Should not raise overall @@ -716,7 +691,7 @@ class ModelCacheInvalidationTests(TestCase): def _create_transition_context( self, - model_name: str = 'Ride', + model_name: str = "Ride", instance_id: int = 789, ): from django.utils import timezone @@ -728,20 +703,20 @@ class ModelCacheInvalidationTests(TestCase): instance.__class__.__name__ = model_name # Add park reference for rides - if model_name == 'Ride': + if model_name == "Ride": instance.park = Mock() instance.park.pk = 111 return TransitionContext( instance=instance, - field_name='status', - source_state='OPERATING', - target_state='CLOSED_TEMP', + field_name="status", + source_state="OPERATING", + target_state="CLOSED_TEMP", user=Mock(), timestamp=timezone.now(), ) - @patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service') + @patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service") def test_ride_cache_includes_park_patterns(self, mock_get_service): """Test RideCacheInvalidation includes parent park patterns.""" from ..callbacks.cache import RideCacheInvalidation @@ -755,12 +730,10 @@ class ModelCacheInvalidationTests(TestCase): callback.execute(context) - patterns_called = [ - call[0][0] for call in mock_cache.invalidate_pattern.call_args_list - ] + patterns_called = [call[0][0] for call in mock_cache.invalidate_pattern.call_args_list] # Should include park patterns (parent park ID is 111) - has_park_pattern = any('park' in p.lower() for p in patterns_called) + has_park_pattern = any("park" in p.lower() for p in patterns_called) self.assertTrue(has_park_pattern, f"No park pattern in {patterns_called}") @@ -775,17 +748,18 @@ class RelatedUpdateCallbackTests(TestCase): def setUp(self): """Set up test fixtures.""" from django.contrib.auth import get_user_model + get_user_model() self.user = Mock() self.user.pk = 1 - self.user.username = 'testuser' + self.user.username = "testuser" def _create_transition_context( self, - model_name: str = 'Ride', + model_name: str = "Ride", instance=None, - target_state: str = 'OPERATING', + target_state: str = "OPERATING", ): from django.utils import timezone @@ -798,8 +772,8 @@ class RelatedUpdateCallbackTests(TestCase): return TransitionContext( instance=instance, - field_name='status', - source_state='UNDER_CONSTRUCTION', + field_name="status", + source_state="UNDER_CONSTRUCTION", target_state=target_state, user=self.user, timestamp=timezone.now(), @@ -819,7 +793,7 @@ class RelatedUpdateCallbackTests(TestCase): # Create mock ride that belongs to park mock_ride = Mock() mock_ride.pk = 200 - mock_ride.__class__.__name__ = 'Ride' + mock_ride.__class__.__name__ = "Ride" mock_ride.park = mock_park mock_ride.is_coaster = True @@ -830,9 +804,9 @@ class RelatedUpdateCallbackTests(TestCase): callback = ParkCountUpdateCallback() context = self._create_transition_context( - model_name='Ride', + model_name="Ride", instance=mock_ride, - target_state='OPERATING', + target_state="OPERATING", ) # Execute callback @@ -847,12 +821,12 @@ class RelatedUpdateCallbackTests(TestCase): mock_ride = Mock() mock_ride.pk = 200 - mock_ride.__class__.__name__ = 'Ride' + mock_ride.__class__.__name__ = "Ride" mock_ride.park = None # No park callback = ParkCountUpdateCallback() context = self._create_transition_context( - model_name='Ride', + model_name="Ride", instance=mock_ride, ) @@ -872,15 +846,15 @@ class RelatedUpdateCallbackTests(TestCase): mock_ride = Mock() mock_ride.pk = 200 - mock_ride.__class__.__name__ = 'Ride' + mock_ride.__class__.__name__ = "Ride" mock_ride.park = mock_park mock_ride.is_coaster = False callback = ParkCountUpdateCallback() context = self._create_transition_context( - model_name='Ride', + model_name="Ride", instance=mock_ride, - target_state='OPERATING', + target_state="OPERATING", ) callback.execute(context) @@ -901,15 +875,15 @@ class RelatedUpdateCallbackTests(TestCase): mock_ride = Mock() mock_ride.pk = 200 - mock_ride.__class__.__name__ = 'Ride' + mock_ride.__class__.__name__ = "Ride" mock_ride.park = mock_park mock_ride.is_coaster = True callback = ParkCountUpdateCallback() context = self._create_transition_context( - model_name='Ride', + model_name="Ride", instance=mock_ride, - target_state='CLOSED_PERM', + target_state="CLOSED_PERM", ) result = callback.execute(context) @@ -931,18 +905,18 @@ class CallbackErrorHandlingTests(TestCase): instance = Mock() instance.pk = 1 - instance.__class__.__name__ = 'EditSubmission' + instance.__class__.__name__ = "EditSubmission" return TransitionContext( instance=instance, - field_name='status', - source_state='PENDING', - target_state='APPROVED', + field_name="status", + source_state="PENDING", + target_state="APPROVED", user=Mock(), timestamp=timezone.now(), ) - @patch('apps.core.state_machine.callbacks.notifications.NotificationService') + @patch("apps.core.state_machine.callbacks.notifications.NotificationService") def test_notification_callback_logs_error_on_failure(self, mock_service_class): """Test NotificationCallback logs errors when service fails.""" import logging @@ -957,7 +931,7 @@ class CallbackErrorHandlingTests(TestCase): context = self._create_transition_context() with self.assertLogs(level=logging.WARNING): - try: + try: # noqa: SIM105 callback.execute(context) except Exception: pass # May or may not raise depending on implementation @@ -965,7 +939,7 @@ class CallbackErrorHandlingTests(TestCase): # Should have logged something about the error # (Logging behavior depends on implementation) - @patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service') + @patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service") def test_cache_callback_returns_false_on_total_failure(self, mock_get_service): """Test CacheInvalidationCallback returns False on complete failure.""" from ..callbacks.cache import CacheInvalidationCallback @@ -974,10 +948,7 @@ class CallbackErrorHandlingTests(TestCase): mock_cache.invalidate_pattern = Mock(side_effect=Exception("Cache error")) mock_get_service.return_value = mock_cache - callback = CacheInvalidationCallback( - patterns=['*test*'], - include_instance_patterns=False - ) + callback = CacheInvalidationCallback(patterns=["*test*"], include_instance_patterns=False) context = self._create_transition_context() result = callback.execute(context) @@ -993,18 +964,18 @@ class CallbackErrorHandlingTests(TestCase): instance = Mock() instance.pk = 1 - instance.__class__.__name__ = 'EditSubmission' + instance.__class__.__name__ = "EditSubmission" context = TransitionContext( instance=instance, - field_name='status', - source_state='PENDING', - target_state='APPROVED', + field_name="status", + source_state="PENDING", + target_state="APPROVED", user=None, # No user timestamp=timezone.now(), ) - with patch('apps.core.state_machine.callbacks.notifications.NotificationService'): + with patch("apps.core.state_machine.callbacks.notifications.NotificationService"): callback = NotificationCallback() # Should not raise with None user try: diff --git a/backend/apps/core/state_machine/tests/test_decorators.py b/backend/apps/core/state_machine/tests/test_decorators.py index fc4f7c6d..9cc69d89 100644 --- a/backend/apps/core/state_machine/tests/test_decorators.py +++ b/backend/apps/core/state_machine/tests/test_decorators.py @@ -1,4 +1,5 @@ """Tests for transition decorator generation.""" + from unittest.mock import Mock from apps.core.state_machine.decorators import ( @@ -11,9 +12,7 @@ from apps.core.state_machine.decorators import ( def test_generate_transition_decorator(): """Test basic transition decorator generation.""" - decorator = generate_transition_decorator( - source="pending", target="approved", field_name="status" - ) + decorator = generate_transition_decorator(source="pending", target="approved", field_name="status") assert callable(decorator) @@ -72,9 +71,7 @@ def test_create_transition_method_with_callbacks(): def test_factory_create_approve_method(): """Test approval method creation.""" factory = TransitionMethodFactory() - method = factory.create_approve_method( - source="pending", target="approved", field_name="status" - ) + method = factory.create_approve_method(source="pending", target="approved", field_name="status") assert callable(method) assert method.__name__ == "approve" @@ -82,9 +79,7 @@ def test_factory_create_approve_method(): def test_factory_create_reject_method(): """Test rejection method creation.""" factory = TransitionMethodFactory() - method = factory.create_reject_method( - source="pending", target="rejected", field_name="status" - ) + method = factory.create_reject_method(source="pending", target="rejected", field_name="status") assert callable(method) assert method.__name__ == "reject" @@ -92,9 +87,7 @@ def test_factory_create_reject_method(): def test_factory_create_escalate_method(): """Test escalation method creation.""" factory = TransitionMethodFactory() - method = factory.create_escalate_method( - source="pending", target="escalated", field_name="status" - ) + method = factory.create_escalate_method(source="pending", target="escalated", field_name="status") assert callable(method) assert method.__name__ == "escalate" @@ -145,16 +138,14 @@ def test_with_transition_logging(): def test_method_signature_generation(): """Test that generated methods have proper signatures.""" factory = TransitionMethodFactory() - method = factory.create_approve_method( - source="pending", target="approved" - ) + method = factory.create_approve_method(source="pending", target="approved") # Check method accepts expected parameters mock_instance = Mock() mock_user = Mock() # Should not raise - try: + try: # noqa: SIM105 method(mock_instance, user=mock_user, comment="test") except Exception: # May fail due to django-fsm not being fully configured diff --git a/backend/apps/core/state_machine/tests/test_guards.py b/backend/apps/core/state_machine/tests/test_guards.py index 8fd75565..bc13484d 100644 --- a/backend/apps/core/state_machine/tests/test_guards.py +++ b/backend/apps/core/state_machine/tests/test_guards.py @@ -58,28 +58,16 @@ class PermissionGuardTests(TestCase): def setUp(self): """Set up test fixtures.""" self.regular_user = User.objects.create_user( - username='user', - email='user@example.com', - password='testpass123', - role='USER' + username="user", email="user@example.com", password="testpass123", role="USER" ) self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) self.admin = User.objects.create_user( - username='admin', - email='admin@example.com', - password='testpass123', - role='ADMIN' + username="admin", email="admin@example.com", password="testpass123", role="ADMIN" ) self.superuser = User.objects.create_user( - username='superuser', - email='superuser@example.com', - password='testpass123', - role='SUPERUSER' + username="superuser", email="superuser@example.com", password="testpass123", role="SUPERUSER" ) self.instance = MockInstance() @@ -168,7 +156,7 @@ class PermissionGuardTests(TestCase): def test_required_roles_explicit_list(self): """Test using explicit required_roles list.""" - guard = PermissionGuard(required_roles=['ADMIN', 'SUPERUSER']) + guard = PermissionGuard(required_roles=["ADMIN", "SUPERUSER"]) self.assertTrue(guard(self.instance, user=self.admin)) self.assertTrue(guard(self.instance, user=self.superuser)) @@ -177,8 +165,9 @@ class PermissionGuardTests(TestCase): def test_custom_check_passes(self): """Test custom check function that passes.""" + def custom_check(instance, user): - return hasattr(instance, 'allow_access') and instance.allow_access + return hasattr(instance, "allow_access") and instance.allow_access guard = PermissionGuard(custom_check=custom_check) instance = MockInstance(allow_access=True) @@ -189,8 +178,9 @@ class PermissionGuardTests(TestCase): def test_custom_check_fails(self): """Test custom check function that fails.""" + def custom_check(instance, user): - return hasattr(instance, 'allow_access') and instance.allow_access + return hasattr(instance, "allow_access") and instance.allow_access guard = PermissionGuard(custom_check=custom_check) instance = MockInstance(allow_access=False) @@ -237,28 +227,16 @@ class OwnershipGuardTests(TestCase): def setUp(self): """Set up test fixtures.""" self.owner = User.objects.create_user( - username='owner', - email='owner@example.com', - password='testpass123', - role='USER' + username="owner", email="owner@example.com", password="testpass123", role="USER" ) self.other_user = User.objects.create_user( - username='other', - email='other@example.com', - password='testpass123', - role='USER' + username="other", email="other@example.com", password="testpass123", role="USER" ) self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) self.admin = User.objects.create_user( - username='admin', - email='admin@example.com', - password='testpass123', - role='ADMIN' + username="admin", email="admin@example.com", password="testpass123", role="ADMIN" ) def test_no_user_fails(self): @@ -329,7 +307,7 @@ class OwnershipGuardTests(TestCase): def test_custom_owner_fields(self): """Test custom owner field names.""" instance = MockInstance(author=self.owner) - guard = OwnershipGuard(owner_fields=['author']) + guard = OwnershipGuard(owner_fields=["author"]) result = guard(instance, user=self.owner) @@ -357,22 +335,13 @@ class AssignmentGuardTests(TestCase): def setUp(self): """Set up test fixtures.""" self.assigned_user = User.objects.create_user( - username='assigned', - email='assigned@example.com', - password='testpass123', - role='MODERATOR' + username="assigned", email="assigned@example.com", password="testpass123", role="MODERATOR" ) self.other_user = User.objects.create_user( - username='other', - email='other@example.com', - password='testpass123', - role='MODERATOR' + username="other", email="other@example.com", password="testpass123", role="MODERATOR" ) self.admin = User.objects.create_user( - username='admin', - email='admin@example.com', - password='testpass123', - role='ADMIN' + username="admin", email="admin@example.com", password="testpass123", role="ADMIN" ) def test_no_user_fails(self): @@ -426,7 +395,7 @@ class AssignmentGuardTests(TestCase): def test_custom_assignment_fields(self): """Test custom assignment field names.""" instance = MockInstance(reviewer=self.assigned_user) - guard = AssignmentGuard(assignment_fields=['reviewer']) + guard = AssignmentGuard(assignment_fields=["reviewer"]) result = guard(instance, user=self.assigned_user) @@ -439,7 +408,7 @@ class AssignmentGuardTests(TestCase): guard(instance, user=self.assigned_user) - self.assertIn('assigned', guard.get_error_message().lower()) + self.assertIn("assigned", guard.get_error_message().lower()) # ============================================================================ @@ -453,16 +422,13 @@ class StateGuardTests(TestCase): def setUp(self): """Set up test fixtures.""" self.user = User.objects.create_user( - username='user', - email='user@example.com', - password='testpass123', - role='USER' + username="user", email="user@example.com", password="testpass123", role="USER" ) def test_allowed_states_passes(self): """Test that guard passes when in allowed state.""" - instance = MockInstance(status='PENDING') - guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW']) + instance = MockInstance(status="PENDING") + guard = StateGuard(allowed_states=["PENDING", "UNDER_REVIEW"]) result = guard(instance, user=self.user) @@ -470,8 +436,8 @@ class StateGuardTests(TestCase): def test_allowed_states_fails(self): """Test that guard fails when not in allowed state.""" - instance = MockInstance(status='COMPLETED') - guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW']) + instance = MockInstance(status="COMPLETED") + guard = StateGuard(allowed_states=["PENDING", "UNDER_REVIEW"]) result = guard(instance, user=self.user) @@ -480,8 +446,8 @@ class StateGuardTests(TestCase): def test_blocked_states_passes(self): """Test that guard passes when not in blocked state.""" - instance = MockInstance(status='PENDING') - guard = StateGuard(blocked_states=['COMPLETED', 'CANCELLED']) + instance = MockInstance(status="PENDING") + guard = StateGuard(blocked_states=["COMPLETED", "CANCELLED"]) result = guard(instance, user=self.user) @@ -489,8 +455,8 @@ class StateGuardTests(TestCase): def test_blocked_states_fails(self): """Test that guard fails when in blocked state.""" - instance = MockInstance(status='COMPLETED') - guard = StateGuard(blocked_states=['COMPLETED', 'CANCELLED']) + instance = MockInstance(status="COMPLETED") + guard = StateGuard(blocked_states=["COMPLETED", "CANCELLED"]) result = guard(instance, user=self.user) @@ -499,8 +465,8 @@ class StateGuardTests(TestCase): def test_custom_state_field(self): """Test using custom state field name.""" - instance = MockInstance(workflow_status='ACTIVE') - guard = StateGuard(allowed_states=['ACTIVE'], state_field='workflow_status') + instance = MockInstance(workflow_status="ACTIVE") + guard = StateGuard(allowed_states=["ACTIVE"], state_field="workflow_status") result = guard(instance, user=self.user) @@ -508,14 +474,14 @@ class StateGuardTests(TestCase): def test_error_message_includes_states(self): """Test that error message includes allowed states.""" - instance = MockInstance(status='COMPLETED') - guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW']) + instance = MockInstance(status="COMPLETED") + guard = StateGuard(allowed_states=["PENDING", "UNDER_REVIEW"]) guard(instance, user=self.user) message = guard.get_error_message() - self.assertIn('PENDING', message) - self.assertIn('UNDER_REVIEW', message) + self.assertIn("PENDING", message) + self.assertIn("UNDER_REVIEW", message) # ============================================================================ @@ -529,16 +495,13 @@ class MetadataGuardTests(TestCase): def setUp(self): """Set up test fixtures.""" self.user = User.objects.create_user( - username='user', - email='user@example.com', - password='testpass123', - role='USER' + username="user", email="user@example.com", password="testpass123", role="USER" ) def test_required_fields_present(self): """Test that guard passes when required fields are present.""" - instance = MockInstance(resolution_notes='Fixed', assigned_to='user') - guard = MetadataGuard(required_fields=['resolution_notes', 'assigned_to']) + instance = MockInstance(resolution_notes="Fixed", assigned_to="user") + guard = MetadataGuard(required_fields=["resolution_notes", "assigned_to"]) result = guard(instance, user=self.user) @@ -546,8 +509,8 @@ class MetadataGuardTests(TestCase): def test_required_field_missing(self): """Test that guard fails when required field is missing.""" - instance = MockInstance(resolution_notes='Fixed') - guard = MetadataGuard(required_fields=['resolution_notes', 'assigned_to']) + instance = MockInstance(resolution_notes="Fixed") + guard = MetadataGuard(required_fields=["resolution_notes", "assigned_to"]) result = guard(instance, user=self.user) @@ -557,7 +520,7 @@ class MetadataGuardTests(TestCase): def test_required_field_none(self): """Test that guard fails when required field is None.""" instance = MockInstance(resolution_notes=None) - guard = MetadataGuard(required_fields=['resolution_notes']) + guard = MetadataGuard(required_fields=["resolution_notes"]) result = guard(instance, user=self.user) @@ -566,8 +529,8 @@ class MetadataGuardTests(TestCase): def test_empty_string_fails_check_not_empty(self): """Test that empty string fails when check_not_empty is True.""" - instance = MockInstance(resolution_notes=' ') - guard = MetadataGuard(required_fields=['resolution_notes'], check_not_empty=True) + instance = MockInstance(resolution_notes=" ") + guard = MetadataGuard(required_fields=["resolution_notes"], check_not_empty=True) result = guard(instance, user=self.user) @@ -577,7 +540,7 @@ class MetadataGuardTests(TestCase): def test_empty_list_fails_check_not_empty(self): """Test that empty list fails when check_not_empty is True.""" instance = MockInstance(tags=[]) - guard = MetadataGuard(required_fields=['tags'], check_not_empty=True) + guard = MetadataGuard(required_fields=["tags"], check_not_empty=True) result = guard(instance, user=self.user) @@ -587,7 +550,7 @@ class MetadataGuardTests(TestCase): def test_empty_dict_fails_check_not_empty(self): """Test that empty dict fails when check_not_empty is True.""" instance = MockInstance(metadata={}) - guard = MetadataGuard(required_fields=['metadata'], check_not_empty=True) + guard = MetadataGuard(required_fields=["metadata"], check_not_empty=True) result = guard(instance, user=self.user) @@ -597,12 +560,12 @@ class MetadataGuardTests(TestCase): def test_error_message_includes_field_name(self): """Test that error message includes the field name.""" instance = MockInstance(resolution_notes=None) - guard = MetadataGuard(required_fields=['resolution_notes']) + guard = MetadataGuard(required_fields=["resolution_notes"]) guard(instance, user=self.user) message = guard.get_error_message() - self.assertIn('Resolution Notes', message) + self.assertIn("Resolution Notes", message) # ============================================================================ @@ -616,32 +579,23 @@ class CompositeGuardTests(TestCase): def setUp(self): """Set up test fixtures.""" self.owner = User.objects.create_user( - username='owner', - email='owner@example.com', - password='testpass123', - role='USER' + username="owner", email="owner@example.com", password="testpass123", role="USER" ) self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) self.non_owner_moderator = User.objects.create_user( - username='non_owner_moderator', - email='non_owner_moderator@example.com', - password='testpass123', - role='MODERATOR' + username="non_owner_moderator", + email="non_owner_moderator@example.com", + password="testpass123", + role="MODERATOR", ) def test_and_operator_all_pass(self): """Test AND operator when all guards pass.""" instance = MockInstance(created_by=self.moderator) - guards = [ - PermissionGuard(requires_moderator=True), - OwnershipGuard() - ] - composite = CompositeGuard(guards, operator='AND') + guards = [PermissionGuard(requires_moderator=True), OwnershipGuard()] + composite = CompositeGuard(guards, operator="AND") result = composite(instance, user=self.moderator) @@ -652,9 +606,9 @@ class CompositeGuardTests(TestCase): instance = MockInstance(created_by=self.owner) guards = [ PermissionGuard(requires_moderator=True), # Will pass for moderator - OwnershipGuard() # Will fail - moderator is not owner + OwnershipGuard(), # Will fail - moderator is not owner ] - composite = CompositeGuard(guards, operator='AND') + composite = CompositeGuard(guards, operator="AND") result = composite(instance, user=self.non_owner_moderator) @@ -666,9 +620,9 @@ class CompositeGuardTests(TestCase): instance = MockInstance(created_by=self.owner) guards = [ PermissionGuard(requires_moderator=True), # Will fail for owner - OwnershipGuard() # Will pass - user is owner + OwnershipGuard(), # Will pass - user is owner ] - composite = CompositeGuard(guards, operator='OR') + composite = CompositeGuard(guards, operator="OR") result = composite(instance, user=self.owner) @@ -677,11 +631,8 @@ class CompositeGuardTests(TestCase): def test_or_operator_all_fail(self): """Test OR operator when all guards fail.""" instance = MockInstance(created_by=self.moderator) - guards = [ - PermissionGuard(requires_admin=True), # Regular user fails - OwnershipGuard() # Not the owner fails - ] - composite = CompositeGuard(guards, operator='OR') + guards = [PermissionGuard(requires_admin=True), OwnershipGuard()] # Regular user fails # Not the owner fails + composite = CompositeGuard(guards, operator="OR") result = composite(instance, user=self.owner) @@ -690,19 +641,13 @@ class CompositeGuardTests(TestCase): def test_nested_composite_guards(self): """Test nested composite guards.""" - instance = MockInstance(created_by=self.moderator, status='PENDING') + instance = MockInstance(created_by=self.moderator, status="PENDING") # Inner composite: moderator OR owner - inner = CompositeGuard([ - PermissionGuard(requires_moderator=True), - OwnershipGuard() - ], operator='OR') + inner = CompositeGuard([PermissionGuard(requires_moderator=True), OwnershipGuard()], operator="OR") # Outer composite: (moderator OR owner) AND valid state - outer = CompositeGuard([ - inner, - StateGuard(allowed_states=['PENDING']) - ], operator='AND') + outer = CompositeGuard([inner, StateGuard(allowed_states=["PENDING"])], operator="AND") result = outer(instance, user=self.moderator) @@ -713,12 +658,12 @@ class CompositeGuardTests(TestCase): instance = MockInstance(created_by=self.owner) perm_guard = PermissionGuard(requires_admin=True) guards = [perm_guard] - composite = CompositeGuard(guards, operator='AND') + composite = CompositeGuard(guards, operator="AND") composite(instance, user=self.owner) message = composite.get_error_message() - self.assertIn('admin', message.lower()) + self.assertIn("admin", message.lower()) # ============================================================================ @@ -732,15 +677,12 @@ class GuardFactoryTests(TestCase): def setUp(self): """Set up test fixtures.""" self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) def test_create_permission_guard_moderator(self): """Test create_permission_guard with moderator requirement.""" - metadata = {'requires_moderator': True} + metadata = {"requires_moderator": True} guard = create_permission_guard(metadata) instance = MockInstance() @@ -750,14 +692,14 @@ class GuardFactoryTests(TestCase): def test_create_permission_guard_admin(self): """Test create_permission_guard with admin requirement.""" - metadata = {'requires_admin_approval': True} + metadata = {"requires_admin_approval": True} guard = create_permission_guard(metadata) self.assertTrue(guard.requires_admin) def test_create_permission_guard_escalation_level(self): """Test create_permission_guard with escalation level.""" - metadata = {'escalation_level': 'admin'} + metadata = {"escalation_level": "admin"} guard = create_permission_guard(metadata) self.assertTrue(guard.requires_admin) @@ -777,9 +719,9 @@ class GuardFactoryTests(TestCase): def test_create_composite_guard(self): """Test create_composite_guard factory.""" guards = [PermissionGuard(), OwnershipGuard()] - composite = create_composite_guard(guards, operator='OR') + composite = create_composite_guard(guards, operator="OR") - self.assertEqual(composite.operator, 'OR') + self.assertEqual(composite.operator, "OR") self.assertEqual(len(composite.guards), 2) @@ -793,7 +735,7 @@ class MetadataExtractionTests(TestCase): def test_extract_moderator_guard(self): """Test extracting guard for moderator requirement.""" - metadata = {'requires_moderator': True} + metadata = {"requires_moderator": True} guards = extract_guards_from_metadata(metadata) self.assertEqual(len(guards), 1) @@ -801,7 +743,7 @@ class MetadataExtractionTests(TestCase): def test_extract_admin_guard(self): """Test extracting guard for admin requirement.""" - metadata = {'requires_admin_approval': True} + metadata = {"requires_admin_approval": True} guards = extract_guards_from_metadata(metadata) self.assertEqual(len(guards), 1) @@ -809,7 +751,7 @@ class MetadataExtractionTests(TestCase): def test_extract_assignment_guard(self): """Test extracting assignment guard.""" - metadata = {'requires_assignment': True} + metadata = {"requires_assignment": True} guards = extract_guards_from_metadata(metadata) self.assertEqual(len(guards), 1) @@ -817,17 +759,14 @@ class MetadataExtractionTests(TestCase): def test_extract_multiple_guards(self): """Test extracting multiple guards.""" - metadata = { - 'requires_moderator': True, - 'requires_assignment': True - } + metadata = {"requires_moderator": True, "requires_assignment": True} guards = extract_guards_from_metadata(metadata) self.assertEqual(len(guards), 2) def test_extract_zero_tolerance_guard(self): """Test extracting guard for zero tolerance (superuser required).""" - metadata = {'zero_tolerance': True} + metadata = {"zero_tolerance": True} guards = extract_guards_from_metadata(metadata) self.assertEqual(len(guards), 1) @@ -835,7 +774,7 @@ class MetadataExtractionTests(TestCase): def test_invalid_escalation_level_raises(self): """Test that invalid escalation level raises ValueError.""" - metadata = {'escalation_level': 'invalid'} + metadata = {"escalation_level": "invalid"} with self.assertRaises(ValueError): extract_guards_from_metadata(metadata) @@ -851,11 +790,7 @@ class MetadataValidationTests(TestCase): def test_valid_metadata(self): """Test that valid metadata passes validation.""" - metadata = { - 'requires_moderator': True, - 'escalation_level': 'admin', - 'requires_assignment': False - } + metadata = {"requires_moderator": True, "escalation_level": "admin", "requires_assignment": False} is_valid, errors = validate_guard_metadata(metadata) @@ -864,30 +799,30 @@ class MetadataValidationTests(TestCase): def test_invalid_escalation_level(self): """Test that invalid escalation level fails validation.""" - metadata = {'escalation_level': 'invalid_level'} + metadata = {"escalation_level": "invalid_level"} is_valid, errors = validate_guard_metadata(metadata) self.assertFalse(is_valid) - self.assertTrue(any('escalation_level' in e for e in errors)) + self.assertTrue(any("escalation_level" in e for e in errors)) def test_invalid_boolean_field(self): """Test that non-boolean value for boolean field fails validation.""" - metadata = {'requires_moderator': 'yes'} + metadata = {"requires_moderator": "yes"} is_valid, errors = validate_guard_metadata(metadata) self.assertFalse(is_valid) - self.assertTrue(any('requires_moderator' in e for e in errors)) + self.assertTrue(any("requires_moderator" in e for e in errors)) def test_required_permissions_not_list(self): """Test that non-list required_permissions fails validation.""" - metadata = {'required_permissions': 'app.permission'} + metadata = {"required_permissions": "app.permission"} is_valid, errors = validate_guard_metadata(metadata) self.assertFalse(is_valid) - self.assertTrue(any('required_permissions' in e for e in errors)) + self.assertTrue(any("required_permissions" in e for e in errors)) # ============================================================================ @@ -901,42 +836,30 @@ class RoleHelperTests(TestCase): def setUp(self): """Set up test fixtures.""" self.regular_user = User.objects.create_user( - username='user', - email='user@example.com', - password='testpass123', - role='USER' + username="user", email="user@example.com", password="testpass123", role="USER" ) self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) self.admin = User.objects.create_user( - username='admin', - email='admin@example.com', - password='testpass123', - role='ADMIN' + username="admin", email="admin@example.com", password="testpass123", role="ADMIN" ) self.superuser = User.objects.create_user( - username='superuser', - email='superuser@example.com', - password='testpass123', - role='SUPERUSER' + username="superuser", email="superuser@example.com", password="testpass123", role="SUPERUSER" ) def test_get_user_role(self): """Test get_user_role returns correct role.""" - self.assertEqual(get_user_role(self.regular_user), 'USER') - self.assertEqual(get_user_role(self.moderator), 'MODERATOR') - self.assertEqual(get_user_role(self.admin), 'ADMIN') - self.assertEqual(get_user_role(self.superuser), 'SUPERUSER') + self.assertEqual(get_user_role(self.regular_user), "USER") + self.assertEqual(get_user_role(self.moderator), "MODERATOR") + self.assertEqual(get_user_role(self.admin), "ADMIN") + self.assertEqual(get_user_role(self.superuser), "SUPERUSER") self.assertIsNone(get_user_role(None)) def test_has_role(self): """Test has_role function.""" - self.assertTrue(has_role(self.moderator, ['MODERATOR', 'ADMIN'])) - self.assertFalse(has_role(self.regular_user, ['MODERATOR', 'ADMIN'])) + self.assertTrue(has_role(self.moderator, ["MODERATOR", "ADMIN"])) + self.assertFalse(has_role(self.regular_user, ["MODERATOR", "ADMIN"])) def test_is_moderator_or_above(self): """Test is_moderator_or_above function.""" @@ -963,7 +886,7 @@ class RoleHelperTests(TestCase): """Test that anonymous user has no role.""" anonymous = AnonymousUser() - self.assertFalse(has_role(anonymous, ['USER'])) + self.assertFalse(has_role(anonymous, ["USER"])) self.assertFalse(is_moderator_or_above(anonymous)) self.assertFalse(is_admin_or_above(anonymous)) self.assertFalse(is_superuser_role(anonymous)) diff --git a/backend/apps/core/state_machine/tests/test_integration.py b/backend/apps/core/state_machine/tests/test_integration.py index 14031fa1..c16e1c94 100644 --- a/backend/apps/core/state_machine/tests/test_integration.py +++ b/backend/apps/core/state_machine/tests/test_integration.py @@ -1,4 +1,5 @@ """Integration tests for state machine model integration.""" + from unittest.mock import Mock, patch import pytest @@ -74,31 +75,23 @@ def test_generate_transition_methods(sample_choices): """Test generating transition methods on model.""" mock_model = type("MockModel", (), {}) - generate_transition_methods_for_model( - mock_model, "status", "test_states", "test" - ) + generate_transition_methods_for_model(mock_model, "status", "test_states", "test") # Check that transition methods were added # Method names may vary based on implementation - assert hasattr(mock_model, "approve") or hasattr( - mock_model, "transition_to_approved" - ) + assert hasattr(mock_model, "approve") or hasattr(mock_model, "transition_to_approved") def test_state_machine_model_decorator(sample_choices): """Test state_machine_model decorator.""" - @state_machine_model( - field_name="status", choice_group="test_states", domain="test" - ) + @state_machine_model(field_name="status", choice_group="test_states", domain="test") class TestModel: pass # Decorator should apply state machine # Check for transition methods - assert hasattr(TestModel, "approve") or hasattr( - TestModel, "transition_to_approved" - ) + assert hasattr(TestModel, "approve") or hasattr(TestModel, "transition_to_approved") def test_state_machine_mixin_get_available_transitions(): diff --git a/backend/apps/core/state_machine/tests/test_registry.py b/backend/apps/core/state_machine/tests/test_registry.py index b622b54a..469e85ff 100644 --- a/backend/apps/core/state_machine/tests/test_registry.py +++ b/backend/apps/core/state_machine/tests/test_registry.py @@ -1,4 +1,5 @@ """Tests for TransitionRegistry.""" + import pytest from apps.core.choices.base import RichChoice @@ -55,12 +56,8 @@ def test_transition_info_creation(): def test_transition_info_hashable(): """Test TransitionInfo is hashable.""" - info1 = TransitionInfo( - source="pending", target="approved", method_name="approve" - ) - info2 = TransitionInfo( - source="pending", target="approved", method_name="approve" - ) + info1 = TransitionInfo(source="pending", target="approved", method_name="approve") + info2 = TransitionInfo(source="pending", target="approved", method_name="approve") assert hash(info1) == hash(info2) @@ -82,9 +79,7 @@ def test_register_transition(): metadata={"requires_moderator": True}, ) - transition = registry_instance.get_transition( - "test_states", "test", "pending", "approved" - ) + transition = registry_instance.get_transition("test_states", "test", "pending", "approved") assert transition is not None assert transition.method_name == "approve" assert transition.requires_moderator is True @@ -92,9 +87,7 @@ def test_register_transition(): def test_get_transition_not_found(): """Test getting non-existent transition.""" - transition = registry_instance.get_transition( - "nonexistent", "test", "pending", "approved" - ) + transition = registry_instance.get_transition("nonexistent", "test", "pending", "approved") assert transition is None @@ -102,9 +95,7 @@ def test_get_available_transitions(sample_choices): """Test getting available transitions from a state.""" registry_instance.build_registry_from_choices("test_states", "test") - available = registry_instance.get_available_transitions( - "test_states", "test", "pending" - ) + available = registry_instance.get_available_transitions("test_states", "test", "pending") assert len(available) == 2 targets = [t.target for t in available] assert "approved" in targets @@ -121,9 +112,7 @@ def test_get_transition_method_name(): method_name="approve", ) - method_name = registry_instance.get_transition_method_name( - "test_states", "test", "pending", "approved" - ) + method_name = registry_instance.get_transition_method_name("test_states", "test", "pending", "approved") assert method_name == "approve" @@ -137,12 +126,8 @@ def test_validate_transition(): method_name="approve", ) - assert registry_instance.validate_transition( - "test_states", "test", "pending", "approved" - ) - assert not registry_instance.validate_transition( - "test_states", "test", "pending", "nonexistent" - ) + assert registry_instance.validate_transition("test_states", "test", "pending", "approved") + assert not registry_instance.validate_transition("test_states", "test", "pending", "nonexistent") def test_build_registry_from_choices(sample_choices): @@ -150,9 +135,7 @@ def test_build_registry_from_choices(sample_choices): registry_instance.build_registry_from_choices("test_states", "test") # Check transitions were registered - transition = registry_instance.get_transition( - "test_states", "test", "pending", "approved" - ) + transition = registry_instance.get_transition("test_states", "test", "pending", "approved") assert transition is not None @@ -168,9 +151,7 @@ def test_clear_registry_specific(): registry_instance.clear_registry(choice_group="test_states", domain="test") - transition = registry_instance.get_transition( - "test_states", "test", "pending", "approved" - ) + transition = registry_instance.get_transition("test_states", "test", "pending", "approved") assert transition is None @@ -186,9 +167,7 @@ def test_clear_registry_all(): registry_instance.clear_registry() - transition = registry_instance.get_transition( - "test_states", "test", "pending", "approved" - ) + transition = registry_instance.get_transition("test_states", "test", "pending", "approved") assert transition is None @@ -196,9 +175,7 @@ def test_export_transition_graph_dict(sample_choices): """Test exporting transition graph as dict.""" registry_instance.build_registry_from_choices("test_states", "test") - graph = registry_instance.export_transition_graph( - "test_states", "test", format="dict" - ) + graph = registry_instance.export_transition_graph("test_states", "test", format="dict") assert isinstance(graph, dict) assert "pending" in graph assert set(graph["pending"]) == {"approved", "rejected"} @@ -208,9 +185,7 @@ def test_export_transition_graph_mermaid(sample_choices): """Test exporting transition graph as mermaid.""" registry_instance.build_registry_from_choices("test_states", "test") - graph = registry_instance.export_transition_graph( - "test_states", "test", format="mermaid" - ) + graph = registry_instance.export_transition_graph("test_states", "test", format="mermaid") assert isinstance(graph, str) assert "stateDiagram-v2" in graph assert "pending" in graph @@ -220,9 +195,7 @@ def test_export_transition_graph_dot(sample_choices): """Test exporting transition graph as DOT.""" registry_instance.build_registry_from_choices("test_states", "test") - graph = registry_instance.export_transition_graph( - "test_states", "test", format="dot" - ) + graph = registry_instance.export_transition_graph("test_states", "test", format="dot") assert isinstance(graph, str) assert "digraph" in graph assert "pending" in graph @@ -233,9 +206,7 @@ def test_export_invalid_format(sample_choices): registry_instance.build_registry_from_choices("test_states", "test") with pytest.raises(ValueError): - registry_instance.export_transition_graph( - "test_states", "test", format="invalid" - ) + registry_instance.export_transition_graph("test_states", "test", format="invalid") def test_get_all_registered_groups(): diff --git a/backend/apps/core/state_machine/tests/test_validators.py b/backend/apps/core/state_machine/tests/test_validators.py index 8efc3a95..c00b163c 100644 --- a/backend/apps/core/state_machine/tests/test_validators.py +++ b/backend/apps/core/state_machine/tests/test_validators.py @@ -1,4 +1,5 @@ """Tests for metadata validators.""" + import pytest from apps.core.choices.base import RichChoice @@ -70,9 +71,7 @@ def terminal_with_transitions(): def test_validation_error_creation(): """Test ValidationError creation.""" - error = ValidationError( - code="TEST_ERROR", message="Test message", state="pending" - ) + error = ValidationError(code="TEST_ERROR", message="Test message", state="pending") assert error.code == "TEST_ERROR" assert error.message == "Test message" assert error.state == "pending" @@ -81,9 +80,7 @@ def test_validation_error_creation(): def test_validation_warning_creation(): """Test ValidationWarning creation.""" - warning = ValidationWarning( - code="TEST_WARNING", message="Test warning", state="pending" - ) + warning = ValidationWarning(code="TEST_WARNING", message="Test warning", state="pending") assert warning.code == "TEST_WARNING" assert warning.message == "Test warning" @@ -166,15 +163,9 @@ def test_validate_no_cycles(valid_choices): def test_validate_no_cycles_with_cycle(): """Test cycle detection finds cycles.""" choices = [ - RichChoice( - value="a", label="A", metadata={"can_transition_to": ["b"]} - ), - RichChoice( - value="b", label="B", metadata={"can_transition_to": ["c"]} - ), - RichChoice( - value="c", label="C", metadata={"can_transition_to": ["a"]} - ), + RichChoice(value="a", label="A", metadata={"can_transition_to": ["b"]}), + RichChoice(value="b", label="B", metadata={"can_transition_to": ["c"]}), + RichChoice(value="c", label="C", metadata={"can_transition_to": ["a"]}), ] registry.register("cycle_states", choices, domain="test") @@ -202,9 +193,7 @@ def test_validate_reachability_unreachable(): label="Pending", metadata={"can_transition_to": ["approved"]}, ), - RichChoice( - value="approved", label="Approved", metadata={"is_final": True} - ), + RichChoice(value="approved", label="Approved", metadata={"is_final": True}), RichChoice( value="orphan", label="Orphan", diff --git a/backend/apps/core/state_machine/validators.py b/backend/apps/core/state_machine/validators.py index 705c8e00..048cc1d5 100644 --- a/backend/apps/core/state_machine/validators.py +++ b/backend/apps/core/state_machine/validators.py @@ -1,4 +1,5 @@ """Metadata validators for ensuring RichChoice metadata meets FSM requirements.""" + from dataclasses import dataclass, field from typing import Any @@ -110,8 +111,7 @@ class MetadataValidator: ValidationError( code="MISSING_CAN_TRANSITION_TO", message=( - "State metadata must explicitly define " - "'can_transition_to' (use [] for terminal states)" + "State metadata must explicitly define " "'can_transition_to' (use [] for terminal states)" ), state=state, ) @@ -138,9 +138,7 @@ class MetadataValidator: errors.append( ValidationError( code="INVALID_TRANSITION_TARGET", - message=( - f"Transition target '{target}' does not exist" - ), + message=(f"Transition target '{target}' does not exist"), state=state, ) ) @@ -188,17 +186,11 @@ class MetadataValidator: perms = self.builder.extract_permission_requirements(state) # Check for contradictory permissions - if ( - perms.get("requires_admin_approval") - and not perms.get("requires_moderator") - ): + if perms.get("requires_admin_approval") and not perms.get("requires_moderator"): errors.append( ValidationError( code="PERMISSION_INCONSISTENCY", - message=( - "State requires admin approval but not moderator " - "(admin should imply moderator)" - ), + message=("State requires admin approval but not moderator " "(admin should imply moderator)"), state=state, ) ) @@ -251,9 +243,7 @@ class MetadataValidator: errors.append( ValidationError( code="STATE_CYCLE_DETECTED", - message=( - f"Cycle detected: {' -> '.join(cycle)}" - ), + message=(f"Cycle detected: {' -> '.join(cycle)}"), state=cycle[0], ) ) @@ -278,9 +268,7 @@ class MetadataValidator: for target in targets: incoming[target].append(source) - initial_states = [ - state for state in all_states if not incoming[state] - ] + initial_states = [state for state in all_states if not incoming[state]] if not initial_states: errors.append( @@ -327,9 +315,7 @@ class MetadataValidator: result = self.validate_choice_group() lines = [] - lines.append( - f"Validation Report for {self.domain}.{self.choice_group}" - ) + lines.append(f"Validation Report for {self.domain}.{self.choice_group}") lines.append("=" * 60) lines.append(f"Status: {'VALID' if result.is_valid else 'INVALID'}") lines.append(f"Errors: {len(result.errors)}") @@ -372,10 +358,7 @@ def validate_on_registration(choice_group: str, domain: str = "core") -> bool: if not result.is_valid: error_messages = [str(e) for e in result.errors] - raise ValueError( - f"Validation failed for {domain}.{choice_group}:\n" - + "\n".join(error_messages) - ) + raise ValueError(f"Validation failed for {domain}.{choice_group}:\n" + "\n".join(error_messages)) return True diff --git a/backend/apps/core/tasks/trending.py b/backend/apps/core/tasks/trending.py index 4354f033..f2340b29 100644 --- a/backend/apps/core/tasks/trending.py +++ b/backend/apps/core/tasks/trending.py @@ -23,9 +23,7 @@ logger = logging.getLogger(__name__) @shared_task(bind=True, max_retries=3, default_retry_delay=60) -def calculate_trending_content( - self, content_type: str = "all", limit: int = 50 -) -> dict[str, Any]: +def calculate_trending_content(self, content_type: str = "all", limit: int = 50) -> dict[str, Any]: """ Calculate trending content using real analytics data. @@ -72,17 +70,13 @@ def calculate_trending_content( trending_items = trending_items[:limit] # Format results for API consumption - formatted_results = _format_trending_results( - trending_items, current_period_hours, previous_period_hours - ) + formatted_results = _format_trending_results(trending_items, current_period_hours, previous_period_hours) # Cache results cache_key = f"trending:calculated:{content_type}:{limit}" cache.set(cache_key, formatted_results, 3600) # Cache for 1 hour - logger.info( - f"Calculated {len(formatted_results)} trending items for {content_type}" - ) + logger.info(f"Calculated {len(formatted_results)} trending items for {content_type}") return { "success": True, @@ -95,13 +89,11 @@ def calculate_trending_content( except Exception as e: logger.error(f"Error calculating trending content: {e}", exc_info=True) # Retry the task - raise self.retry(exc=e) + raise self.retry(exc=e) from None @shared_task(bind=True, max_retries=3, default_retry_delay=30) -def calculate_new_content( - self, content_type: str = "all", days_back: int = 30, limit: int = 50 -) -> dict[str, Any]: +def calculate_new_content(self, content_type: str = "all", days_back: int = 30, limit: int = 50) -> dict[str, Any]: """ Calculate new content based on opening dates and creation dates. @@ -120,15 +112,11 @@ def calculate_new_content( new_items = [] if content_type in ["all", "parks"]: - parks = _get_new_parks( - cutoff_date, limit if content_type == "parks" else limit * 2 - ) + parks = _get_new_parks(cutoff_date, limit if content_type == "parks" else limit * 2) new_items.extend(parks) if content_type in ["all", "rides"]: - rides = _get_new_rides( - cutoff_date, limit if content_type == "rides" else limit * 2 - ) + rides = _get_new_rides(cutoff_date, limit if content_type == "rides" else limit * 2) new_items.extend(rides) # Sort by date added (most recent first) and apply limit @@ -154,7 +142,7 @@ def calculate_new_content( except Exception as e: logger.error(f"Error calculating new content: {e}", exc_info=True) - raise self.retry(exc=e) + raise self.retry(exc=e) from None @shared_task(bind=True) @@ -185,9 +173,7 @@ def warm_trending_cache(self) -> dict[str, Any]: calculate_new_content.delay(**query) results[f"trending_{query['content_type']}_{query['limit']}"] = "scheduled" - results[f"new_content_{query['content_type']}_{query['limit']}"] = ( - "scheduled" - ) + results[f"new_content_{query['content_type']}_{query['limit']}"] = "scheduled" logger.info("Trending cache warming completed") @@ -211,17 +197,13 @@ def _calculate_trending_parks( current_period_hours: int, previous_period_hours: int, limit: int ) -> list[dict[str, Any]]: """Calculate trending scores for parks using real data.""" - parks = Park.objects.filter(status="OPERATING").select_related( - "location", "operator" - ) + parks = Park.objects.filter(status="OPERATING").select_related("location", "operator") trending_parks = [] for park in parks: try: - score = _calculate_content_score( - park, "park", current_period_hours, previous_period_hours - ) + score = _calculate_content_score(park, "park", current_period_hours, previous_period_hours) if score > 0: # Only include items with positive trending scores trending_parks.append( { @@ -231,13 +213,9 @@ def _calculate_trending_parks( "id": park.id, "name": park.name, "slug": park.slug, - "location": ( - park.formatted_location if hasattr(park, "location") else "" - ), + "location": (park.formatted_location if hasattr(park, "location") else ""), "category": "park", - "rating": ( - float(park.average_rating) if park.average_rating else 0.0 - ), + "rating": (float(park.average_rating) if park.average_rating else 0.0), } ) except Exception as e: @@ -250,17 +228,13 @@ def _calculate_trending_rides( current_period_hours: int, previous_period_hours: int, limit: int ) -> list[dict[str, Any]]: """Calculate trending scores for rides using real data.""" - rides = Ride.objects.filter(status="OPERATING").select_related( - "park", "park__location" - ) + rides = Ride.objects.filter(status="OPERATING").select_related("park", "park__location") trending_rides = [] for ride in rides: try: - score = _calculate_content_score( - ride, "ride", current_period_hours, previous_period_hours - ) + score = _calculate_content_score(ride, "ride", current_period_hours, previous_period_hours) if score > 0: # Only include items with positive trending scores # Get location from park location = "" @@ -277,9 +251,7 @@ def _calculate_trending_rides( "slug": ride.slug, "location": location, "category": "ride", - "rating": ( - float(ride.average_rating) if ride.average_rating else 0.0 - ), + "rating": (float(ride.average_rating) if ride.average_rating else 0.0), } ) except Exception as e: @@ -322,17 +294,10 @@ def _calculate_content_score( recency_score = _calculate_recency_score(content_obj) # 4. Popularity Score (10% weight) - popularity_score = _calculate_popularity_score( - ct, content_obj.id, current_period_hours - ) + popularity_score = _calculate_popularity_score(ct, content_obj.id, current_period_hours) # Calculate weighted final score - final_score = ( - view_growth_score * 0.4 - + rating_score * 0.3 - + recency_score * 0.2 - + popularity_score * 0.1 - ) + final_score = view_growth_score * 0.4 + rating_score * 0.3 + recency_score * 0.2 + popularity_score * 0.1 logger.debug( f"{content_type} {content_obj.id}: " @@ -344,9 +309,7 @@ def _calculate_content_score( return final_score except Exception as e: - logger.error( - f"Error calculating score for {content_type} {content_obj.id}: {e}" - ) + logger.error(f"Error calculating score for {content_type} {content_obj.id}: {e}") return 0.0 @@ -371,9 +334,7 @@ def _calculate_view_growth_score( # Normalize growth percentage to 0-1 scale # 100% growth = 0.5, 500% growth = 1.0 - normalized_growth = ( - min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0 - ) + normalized_growth = min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0 return max(normalized_growth, 0.0) except Exception as e: @@ -431,14 +392,10 @@ def _calculate_recency_score(content_obj: Any) -> float: return 0.5 -def _calculate_popularity_score( - content_type: ContentType, object_id: int, hours: int -) -> float: +def _calculate_popularity_score(content_type: ContentType, object_id: int, hours: int) -> float: """Calculate popularity score based on total view count.""" try: - total_views = PageView.get_total_views_count( - content_type, object_id, hours=hours - ) + total_views = PageView.get_total_views_count(content_type, object_id, hours=hours) # Normalize views to 0-1 scale # 0 views = 0.0, 100 views = 0.5, 1000+ views = 1.0 @@ -505,9 +462,7 @@ def _get_new_rides(cutoff_date: datetime, limit: int) -> list[dict[str, Any]]: results = [] for ride in new_rides: - date_added = getattr(ride, "opening_date", None) or getattr( - ride, "created_at", None - ) + date_added = getattr(ride, "opening_date", None) or getattr(ride, "created_at", None) if date_added and isinstance(date_added, datetime): date_added = date_added.date() @@ -545,13 +500,11 @@ def _format_trending_results( # Get view change for display content_obj = item["content_object"] ct = ContentType.objects.get_for_model(content_obj) - current_views, previous_views, growth_percentage = ( - PageView.get_views_growth( - ct, - content_obj.id, - current_period_hours, - previous_period_hours, - ) + current_views, previous_views, growth_percentage = PageView.get_views_growth( + ct, + content_obj.id, + current_period_hours, + previous_period_hours, ) # Format exactly as frontend expects @@ -564,9 +517,7 @@ def _format_trending_results( "rank": rank, "views": current_views, "views_change": ( - f"+{growth_percentage:.1f}%" - if growth_percentage > 0 - else f"{growth_percentage:.1f}%" + f"+{growth_percentage:.1f}%" if growth_percentage > 0 else f"{growth_percentage:.1f}%" ), "slug": item["slug"], } diff --git a/backend/apps/core/templatetags/common_filters.py b/backend/apps/core/templatetags/common_filters.py index 14c2f0a4..53eb1bbc 100644 --- a/backend/apps/core/templatetags/common_filters.py +++ b/backend/apps/core/templatetags/common_filters.py @@ -26,6 +26,7 @@ register = template.Library() # Time and Date Filters # ============================================================================= + @register.filter def humanize_timedelta(value): """ @@ -42,26 +43,26 @@ def humanize_timedelta(value): Human-readable string like "2 hours ago" """ if value is None: - return '' + return "" # Convert datetime to timedelta from now - if hasattr(value, 'tzinfo'): # It's a datetime + if hasattr(value, "tzinfo"): # It's a datetime now = timezone.now() if value > now: - return 'in the future' + return "in the future" value = now - value # Convert seconds to timedelta - if isinstance(value, (int, float)): + if isinstance(value, int | float): value = timedelta(seconds=value) if not isinstance(value, timedelta): - return '' + return "" seconds = int(value.total_seconds()) if seconds < 60: - return 'just now' + return "just now" elif seconds < 3600: minutes = seconds // 60 return f'{minutes} minute{"s" if minutes != 1 else ""} ago' @@ -92,22 +93,23 @@ def time_until(value): Output: "in 2 days", "in 3 hours" """ if value is None: - return '' + return "" - if hasattr(value, 'tzinfo'): + if hasattr(value, "tzinfo"): now = timezone.now() if value <= now: - return 'now' + return "now" diff = value - now - return humanize_timedelta(diff).replace(' ago', '') + return humanize_timedelta(diff).replace(" ago", "") - return '' + return "" # ============================================================================= # Text Manipulation Filters # ============================================================================= + @register.filter @stringfilter def truncate_smart(value, max_length=50): @@ -130,12 +132,12 @@ def truncate_smart(value, max_length=50): # Find the last space before max_length truncated = value[:max_length] - last_space = truncated.rfind(' ') + last_space = truncated.rfind(" ") if last_space > max_length * 0.5: # Only use word boundary if reasonable truncated = truncated[:last_space] - return truncated.rstrip('.,!?;:') + '...' + return truncated.rstrip(".,!?;:") + "..." @register.filter @@ -153,7 +155,7 @@ def truncate_middle(value, max_length=50): return value keep_chars = (max_length - 3) // 2 - return f'{value[:keep_chars]}...{value[-keep_chars:]}' + return f"{value[:keep_chars]}...{value[-keep_chars:]}" @register.filter @@ -167,13 +169,14 @@ def initials(value, max_initials=2): Output: "JD" for "John Doe" """ words = value.split() - return ''.join(word[0].upper() for word in words[:max_initials] if word) + return "".join(word[0].upper() for word in words[:max_initials] if word) # ============================================================================= # Number Formatting Filters # ============================================================================= + @register.filter def format_number(value, decimals=0): """ @@ -187,14 +190,14 @@ def format_number(value, decimals=0): Output: "1,234.56" """ if value is None: - return '' + return "" try: value = float(value) decimals = int(decimals) if decimals > 0: - return f'{value:,.{decimals}f}' - return f'{int(value):,}' + return f"{value:,.{decimals}f}" + return f"{int(value):,}" except (ValueError, TypeError): return value @@ -209,16 +212,16 @@ def format_compact(value): Output: "1.2K", "3.4M", "2.1B" """ if value is None: - return '' + return "" try: value = float(value) if value >= 1_000_000_000: - return f'{value / 1_000_000_000:.1f}B' + return f"{value / 1_000_000_000:.1f}B" elif value >= 1_000_000: - return f'{value / 1_000_000:.1f}M' + return f"{value / 1_000_000:.1f}M" elif value >= 1_000: - return f'{value / 1_000:.1f}K' + return f"{value / 1_000:.1f}K" return str(int(value)) except (ValueError, TypeError): return value @@ -237,16 +240,17 @@ def percentage(value, total): value = float(value) total = float(total) if total == 0: - return '0%' - return f'{(value / total * 100):.0f}%' + return "0%" + return f"{(value / total * 100):.0f}%" except (ValueError, TypeError, ZeroDivisionError): - return '0%' + return "0%" # ============================================================================= # Dictionary/List Filters # ============================================================================= + @register.filter def get_item(dictionary, key): """ @@ -278,7 +282,7 @@ def getlist(querydict, key): """ if querydict is None: return [] - if hasattr(querydict, 'getlist'): + if hasattr(querydict, "getlist"): return querydict.getlist(key) return [] @@ -314,6 +318,7 @@ def index(sequence, i): # Pluralization Filters # ============================================================================= + @register.filter def pluralize_custom(count, forms): """ @@ -330,7 +335,7 @@ def pluralize_custom(count, forms): """ try: count = int(count) - singular, plural = forms.split(',') + singular, plural = forms.split(",") return singular if count == 1 else plural except (ValueError, AttributeError): return forms @@ -347,9 +352,9 @@ def count_with_label(count, forms): """ try: count = int(count) - singular, plural = forms.split(',') + singular, plural = forms.split(",") label = singular if count == 1 else plural - return f'{count} {label}' + return f"{count} {label}" except (ValueError, AttributeError): return str(count) @@ -358,6 +363,7 @@ def count_with_label(count, forms): # CSS Class Manipulation # ============================================================================= + @register.filter def add_class(field, css_class): """ @@ -366,10 +372,10 @@ def add_class(field, css_class): Usage: {{ form.email|add_class:"form-control" }} """ - if hasattr(field, 'as_widget'): - existing = field.field.widget.attrs.get('class', '') - new_classes = f'{existing} {css_class}'.strip() - return field.as_widget(attrs={'class': new_classes}) + if hasattr(field, "as_widget"): + existing = field.field.widget.attrs.get("class", "") + new_classes = f"{existing} {css_class}".strip() + return field.as_widget(attrs={"class": new_classes}) return field @@ -381,8 +387,8 @@ def set_attr(field, attr_value): Usage: {{ form.email|set_attr:"placeholder:Enter email" }} """ - if hasattr(field, 'as_widget'): - attr, value = attr_value.split(':') + if hasattr(field, "as_widget"): + attr, value = attr_value.split(":") return field.as_widget(attrs={attr: value}) return field @@ -391,6 +397,7 @@ def set_attr(field, attr_value): # Conditional Filters # ============================================================================= + @register.filter def default_if_none(value, default): """ @@ -413,5 +420,5 @@ def yesno_icon(value, icons="fa-check,fa-times"): {{ has_feature|yesno_icon:"fa-star,fa-star-o" }} """ - true_icon, false_icon = icons.split(',') + true_icon, false_icon = icons.split(",") return true_icon if value else false_icon diff --git a/backend/apps/core/templatetags/fsm_tags.py b/backend/apps/core/templatetags/fsm_tags.py index 9ceac39e..8a28da33 100644 --- a/backend/apps/core/templatetags/fsm_tags.py +++ b/backend/apps/core/templatetags/fsm_tags.py @@ -23,6 +23,7 @@ Usage: {# Render a transition button #} {% transition_button submission 'approve' request.user %} """ + from typing import Any from django import template @@ -53,12 +54,12 @@ def get_state_value(obj) -> str | None: Returns: The current state value or None """ - if hasattr(obj, 'get_state_value'): + if hasattr(obj, "get_state_value"): return obj.get_state_value() - if hasattr(obj, 'state_field_name'): + if hasattr(obj, "state_field_name"): return getattr(obj, obj.state_field_name, None) # Try common field names - for field in ['status', 'state']: + for field in ["status", "state"]: if hasattr(obj, field): return getattr(obj, field, None) return None @@ -78,19 +79,19 @@ def get_state_display(obj) -> str: Returns: The human-readable state display value """ - if hasattr(obj, 'get_state_display_value'): + if hasattr(obj, "get_state_display_value"): return obj.get_state_display_value() - if hasattr(obj, 'state_field_name'): + if hasattr(obj, "state_field_name"): field_name = obj.state_field_name - getter = getattr(obj, f'get_{field_name}_display', None) + getter = getattr(obj, f"get_{field_name}_display", None) if callable(getter): return getter() # Try common field names - for field in ['status', 'state']: - getter = getattr(obj, f'get_{field}_display', None) + for field in ["status", "state"]: + getter = getattr(obj, f"get_{field}_display", None) if callable(getter): return getter() - return str(get_state_value(obj) or '') + return str(get_state_value(obj) or "") @register.filter @@ -109,7 +110,7 @@ def get_state_choice(obj): Returns: The RichChoice object or None """ - if hasattr(obj, 'get_state_choice'): + if hasattr(obj, "get_state_choice"): return obj.get_state_choice() return None @@ -205,18 +206,24 @@ def get_available_transitions(obj, user) -> list[dict[str, Any]]: # Get list of available transitions available_transition_names = [] - if hasattr(obj, 'get_available_user_transitions'): + if hasattr(obj, "get_available_user_transitions"): # Use the helper method if available return obj.get_available_user_transitions(user) - if hasattr(obj, 'get_available_transitions'): + if hasattr(obj, "get_available_transitions"): available_transition_names = list(obj.get_available_transitions()) else: # Fallback: look for transition methods by convention for attr_name in dir(obj): - if attr_name.startswith('transition_to_') or attr_name in ['approve', 'reject', 'escalate', 'complete', 'cancel']: + if attr_name.startswith("transition_to_") or attr_name in [ + "approve", + "reject", + "escalate", + "complete", + "cancel", + ]: method = getattr(obj, attr_name, None) - if callable(method) and hasattr(method, '_django_fsm'): + if callable(method) and hasattr(method, "_django_fsm"): available_transition_names.append(attr_name) # Filter transitions by user permission @@ -226,14 +233,16 @@ def get_available_transitions(obj, user) -> list[dict[str, Any]]: try: if can_proceed(method, user): metadata = get_transition_metadata(transition_name) - transitions.append({ - 'name': transition_name, - 'label': _format_transition_label(transition_name), - 'icon': metadata.get('icon', 'arrow-right'), - 'style': metadata.get('style', 'gray'), - 'requires_confirm': metadata.get('requires_confirm', False), - 'confirm_message': metadata.get('confirm_message', 'Are you sure?'), - }) + transitions.append( + { + "name": transition_name, + "label": _format_transition_label(transition_name), + "icon": metadata.get("icon", "arrow-right"), + "style": metadata.get("style", "gray"), + "requires_confirm": metadata.get("requires_confirm", False), + "confirm_message": metadata.get("confirm_message", "Are you sure?"), + } + ) except Exception: # Skip transitions that raise errors during can_proceed check pass @@ -289,14 +298,17 @@ def get_transition_url(obj, transition_name: str) -> str: The URL string for the transition endpoint """ try: - return reverse('core:fsm_transition', kwargs={ - 'app_label': obj._meta.app_label, - 'model_name': obj._meta.model_name, - 'pk': obj.pk, - 'transition_name': transition_name, - }) + return reverse( + "core:fsm_transition", + kwargs={ + "app_label": obj._meta.app_label, + "model_name": obj._meta.model_name, + "pk": obj.pk, + "transition_name": transition_name, + }, + ) except NoReverseMatch: - return '' + return "" # ============================================================================= @@ -304,7 +316,7 @@ def get_transition_url(obj, transition_name: str) -> str: # ============================================================================= -@register.inclusion_tag('htmx/state_actions.html', takes_context=True) +@register.inclusion_tag("htmx/state_actions.html", takes_context=True) def render_state_actions(context, obj, user=None, **kwargs): """ Render the state action buttons for an FSM-enabled object. @@ -323,17 +335,17 @@ def render_state_actions(context, obj, user=None, **kwargs): Context for the state_actions.html template """ if user is None: - user = context.get('request', {}).user if 'request' in context else None + user = context.get("request", {}).user if "request" in context else None return { - 'object': obj, - 'user': user, - 'request': context.get('request'), + "object": obj, + "user": user, + "request": context.get("request"), **kwargs, } -@register.inclusion_tag('htmx/status_with_actions.html', takes_context=True) +@register.inclusion_tag("htmx/status_with_actions.html", takes_context=True) def render_status_with_actions(context, obj, user=None, **kwargs): """ Render the status badge with action buttons for an FSM-enabled object. @@ -352,12 +364,12 @@ def render_status_with_actions(context, obj, user=None, **kwargs): Context for the status_with_actions.html template """ if user is None: - user = context.get('request', {}).user if 'request' in context else None + user = context.get("request", {}).user if "request" in context else None return { - 'object': obj, - 'user': user, - 'request': context.get('request'), + "object": obj, + "user": user, + "request": context.get("request"), **kwargs, } @@ -384,28 +396,28 @@ def _format_transition_label(transition_name: str) -> str: """ # Remove common prefixes label = transition_name - for prefix in ['transition_to_', 'transition_', 'do_']: + for prefix in ["transition_to_", "transition_", "do_"]: if label.startswith(prefix): - label = label[len(prefix):] + label = label[len(prefix) :] break # Remove past tense suffix and capitalize # e.g., 'approved' -> 'Approve' - if label.endswith('ed') and len(label) > 3: + if label.endswith("ed") and len(label) > 3: # Handle special cases - if label.endswith('ied'): - label = label[:-3] + 'y' + if label.endswith("ied"): + label = label[:-3] + "y" elif label[-3] == label[-4]: # doubled consonant (e.g., 'submitted') label = label[:-3] else: label = label[:-1] # Remove 'd' - if label.endswith('e'): + if label.endswith("e"): pass # Keep the 'e' for words like 'approve' else: label = label[:-1] # Remove 'e' for words like 'rejected' -> 'reject' # Replace underscores with spaces and title case - label = label.replace('_', ' ').title() + label = label.replace("_", " ").title() return label @@ -418,17 +430,17 @@ def _format_transition_label(transition_name: str) -> str: # Ensure all tags and filters are registered __all__ = [ # Filters - 'get_state_value', - 'get_state_display', - 'get_state_choice', - 'app_label', - 'model_name', - 'default_target_id', + "get_state_value", + "get_state_display", + "get_state_choice", + "app_label", + "model_name", + "default_target_id", # Tags - 'get_available_transitions', - 'can_transition', - 'get_transition_url', + "get_available_transitions", + "can_transition", + "get_transition_url", # Inclusion tags - 'render_state_actions', - 'render_status_with_actions', + "render_state_actions", + "render_status_with_actions", ] diff --git a/backend/apps/core/templatetags/safe_html.py b/backend/apps/core/templatetags/safe_html.py index f21786a7..16004e10 100644 --- a/backend/apps/core/templatetags/safe_html.py +++ b/backend/apps/core/templatetags/safe_html.py @@ -57,7 +57,8 @@ register = template.Library() # HTML Sanitization Filters # ============================================================================= -@register.filter(name='sanitize', is_safe=True) + +@register.filter(name="sanitize", is_safe=True) def sanitize_filter(value): """ Sanitize HTML content to prevent XSS attacks. @@ -68,11 +69,11 @@ def sanitize_filter(value): {{ user_content|sanitize }} """ if not value: - return '' + return "" return mark_safe(sanitize_html(str(value))) -@register.filter(name='sanitize_minimal', is_safe=True) +@register.filter(name="sanitize_minimal", is_safe=True) def sanitize_minimal_filter(value): """ Sanitize HTML with minimal allowed tags. @@ -83,11 +84,11 @@ def sanitize_minimal_filter(value): {{ comment|sanitize_minimal }} """ if not value: - return '' + return "" return mark_safe(_sanitize_minimal(str(value))) -@register.filter(name='sanitize_svg', is_safe=True) +@register.filter(name="sanitize_svg", is_safe=True) def sanitize_svg_filter(value): """ Sanitize SVG content for safe inline rendering. @@ -96,11 +97,11 @@ def sanitize_svg_filter(value): {{ icon_svg|sanitize_svg }} """ if not value: - return '' + return "" return mark_safe(sanitize_svg(str(value))) -@register.filter(name='strip_html') +@register.filter(name="strip_html") def strip_html_filter(value): """ Remove all HTML tags from content. @@ -109,7 +110,7 @@ def strip_html_filter(value): {{ html_content|strip_html }} """ if not value: - return '' + return "" return _strip_html(str(value)) @@ -117,7 +118,8 @@ def strip_html_filter(value): # JavaScript/JSON Context Filters # ============================================================================= -@register.filter(name='json_safe', is_safe=True) + +@register.filter(name="json_safe", is_safe=True) def json_safe_filter(value): """ Safely serialize data for embedding in JavaScript. @@ -131,11 +133,11 @@ def json_safe_filter(value): """ if value is None: - return 'null' + return "null" return mark_safe(sanitize_for_json(value)) -@register.filter(name='escapejs_safe') +@register.filter(name="escapejs_safe") def escapejs_safe_filter(value): """ Escape a string for safe use in JavaScript string literals. @@ -146,7 +148,7 @@ def escapejs_safe_filter(value): """ if not value: - return '' + return "" return _escape_js_string(str(value)) @@ -154,7 +156,8 @@ def escapejs_safe_filter(value): # URL and Attribute Filters # ============================================================================= -@register.filter(name='sanitize_url') + +@register.filter(name="sanitize_url") def sanitize_url_filter(value): """ Sanitize a URL to prevent javascript: and other dangerous protocols. @@ -163,11 +166,11 @@ def sanitize_url_filter(value): Link """ if not value: - return '' + return "" return _sanitize_url(str(value)) -@register.filter(name='attr_safe') +@register.filter(name="attr_safe") def attr_safe_filter(value): """ Escape a value for safe use in HTML attributes. @@ -176,7 +179,7 @@ def attr_safe_filter(value):
""" if not value: - return '' + return "" return sanitize_attribute_value(str(value)) @@ -187,36 +190,36 @@ def attr_safe_filter(value): # Predefined safe SVG icons # These are trusted and can be rendered without sanitization BUILTIN_ICONS = { - 'check': '''''', - 'x': '''''', - 'plus': '''''', - 'minus': '''''', - 'chevron-down': '''''', - 'chevron-up': '''''', - 'chevron-left': '''''', - 'chevron-right': '''''', - 'search': '''''', - 'menu': '''''', - 'user': '''''', - 'cog': '''''', - 'trash': '''''', - 'pencil': '''''', - 'eye': '''''', - 'eye-slash': '''''', - 'arrow-left': '''''', - 'arrow-right': '''''', - 'info': '''''', - 'warning': '''''', - 'error': '''''', - 'success': '''''', - 'loading': '''''', - 'external-link': '''''', - 'download': '''''', - 'upload': '''''', - 'star': '''''', - 'star-filled': '''''', - 'heart': '''''', - 'heart-filled': '''''', + "check": """""", + "x": """""", + "plus": """""", + "minus": """""", + "chevron-down": """""", + "chevron-up": """""", + "chevron-left": """""", + "chevron-right": """""", + "search": """""", + "menu": """""", + "user": """""", + "cog": """""", + "trash": """""", + "pencil": """""", + "eye": """""", + "eye-slash": """""", + "arrow-left": """""", + "arrow-right": """""", + "info": """""", + "warning": """""", + "error": """""", + "success": """""", + "loading": """""", + "external-link": """""", + "download": """""", + "upload": """""", + "star": """""", + "star-filled": """""", + "heart": """""", + "heart-filled": """""", } @@ -243,18 +246,18 @@ def icon(name, **kwargs): if not svg_template: # Return empty string for unknown icons (fail silently) - return '' + return "" # Build attributes string attrs_list = [] for key, value in kwargs.items(): # Convert underscore to hyphen for HTML attributes (e.g., aria_hidden -> aria-hidden) - attr_name = key.replace('_', '-') + attr_name = key.replace("_", "-") # Escape attribute values to prevent XSS safe_value = sanitize_attribute_value(str(value)) attrs_list.append(f'{attr_name}="{safe_value}"') - attrs_str = ' '.join(attrs_list) + attrs_str = " ".join(attrs_list) # Substitute attributes into template svg = svg_template.format(attrs=attrs_str) @@ -263,7 +266,7 @@ def icon(name, **kwargs): @register.simple_tag -def icon_class(name, size='w-5 h-5', extra_class=''): +def icon_class(name, size="w-5 h-5", extra_class=""): """ Render a trusted SVG icon with common class presets. @@ -278,5 +281,5 @@ def icon_class(name, size='w-5 h-5', extra_class=''): Returns: Safe HTML for the icon SVG """ - classes = f'{size} {extra_class}'.strip() - return icon(name, **{'class': classes}) + classes = f"{size} {extra_class}".strip() + return icon(name, **{"class": classes}) diff --git a/backend/apps/core/tests/test_history.py b/backend/apps/core/tests/test_history.py index 79a71f7e..4811d33a 100644 --- a/backend/apps/core/tests/test_history.py +++ b/backend/apps/core/tests/test_history.py @@ -1,4 +1,3 @@ - import pghistory import pytest from django.contrib.auth import get_user_model @@ -7,6 +6,7 @@ from apps.parks.models import Company, Park User = get_user_model() + @pytest.mark.django_db class TestTrackedModel: """ @@ -20,10 +20,7 @@ class TestTrackedModel: with pghistory.context(user=user.id): park = Park.objects.create( - name="History Test Park", - description="Testing history", - operating_season="Summer", - operator=company + name="History Test Park", description="Testing history", operating_season="Summer", operator=company ) # Verify history using the helper method from TrackedModel @@ -50,6 +47,5 @@ class TestTrackedModel: park.save() assert park.get_history().count() == 2 - latest = park.get_history().first() # Ordered by -pgh_created_at + latest = park.get_history().first() # Ordered by -pgh_created_at assert latest.name == "Updated" - diff --git a/backend/apps/core/urls/__init__.py b/backend/apps/core/urls/__init__.py index f70fe166..4b413933 100644 --- a/backend/apps/core/urls/__init__.py +++ b/backend/apps/core/urls/__init__.py @@ -17,9 +17,7 @@ app_name = "core" entity_patterns = [ path("search/", EntityFuzzySearchView.as_view(), name="entity_fuzzy_search"), path("not-found/", EntityNotFoundView.as_view(), name="entity_not_found"), - path( - "suggestions/", QuickEntitySuggestionView.as_view(), name="entity_suggestions" - ), + path("suggestions/", QuickEntitySuggestionView.as_view(), name="entity_suggestions"), ] # FSM transition endpoints diff --git a/backend/apps/core/utils/cloudflare.py b/backend/apps/core/utils/cloudflare.py index 54136ddc..187898ef 100644 --- a/backend/apps/core/utils/cloudflare.py +++ b/backend/apps/core/utils/cloudflare.py @@ -6,6 +6,7 @@ from django.core.exceptions import ImproperlyConfigured logger = logging.getLogger(__name__) + def get_direct_upload_url(user_id=None): """ Generates a direct upload URL for Cloudflare Images. @@ -20,13 +21,11 @@ def get_direct_upload_url(user_id=None): ImproperlyConfigured: If Cloudflare settings are missing. requests.RequestException: If the Cloudflare API request fails. """ - account_id = getattr(settings, 'CLOUDFLARE_IMAGES_ACCOUNT_ID', None) - api_token = getattr(settings, 'CLOUDFLARE_IMAGES_API_TOKEN', None) + account_id = getattr(settings, "CLOUDFLARE_IMAGES_ACCOUNT_ID", None) + api_token = getattr(settings, "CLOUDFLARE_IMAGES_API_TOKEN", None) if not account_id or not api_token: - raise ImproperlyConfigured( - "CLOUDFLARE_IMAGES_ACCOUNT_ID and CLOUDFLARE_IMAGES_API_TOKEN must be set." - ) + raise ImproperlyConfigured("CLOUDFLARE_IMAGES_ACCOUNT_ID and CLOUDFLARE_IMAGES_API_TOKEN must be set.") url = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/images/v2/direct_upload" diff --git a/backend/apps/core/utils/file_scanner.py b/backend/apps/core/utils/file_scanner.py index 0954a651..414dcd4c 100644 --- a/backend/apps/core/utils/file_scanner.py +++ b/backend/apps/core/utils/file_scanner.py @@ -37,6 +37,7 @@ from django.core.files.uploadedfile import UploadedFile class FileValidationError(ValidationError): """Custom exception for file validation errors.""" + pass @@ -47,41 +48,49 @@ class FileValidationError(ValidationError): # Magic number signatures for common image formats # Format: (magic_bytes, offset, description) IMAGE_SIGNATURES = { - 'jpeg': [ - (b'\xFF\xD8\xFF\xE0', 0, 'JPEG (JFIF)'), - (b'\xFF\xD8\xFF\xE1', 0, 'JPEG (EXIF)'), - (b'\xFF\xD8\xFF\xE2', 0, 'JPEG (ICC)'), - (b'\xFF\xD8\xFF\xE3', 0, 'JPEG (Samsung)'), - (b'\xFF\xD8\xFF\xE8', 0, 'JPEG (SPIFF)'), - (b'\xFF\xD8\xFF\xDB', 0, 'JPEG (Raw)'), + "jpeg": [ + (b"\xff\xd8\xff\xe0", 0, "JPEG (JFIF)"), + (b"\xff\xd8\xff\xe1", 0, "JPEG (EXIF)"), + (b"\xff\xd8\xff\xe2", 0, "JPEG (ICC)"), + (b"\xff\xd8\xff\xe3", 0, "JPEG (Samsung)"), + (b"\xff\xd8\xff\xe8", 0, "JPEG (SPIFF)"), + (b"\xff\xd8\xff\xdb", 0, "JPEG (Raw)"), ], - 'png': [ - (b'\x89PNG\r\n\x1a\n', 0, 'PNG'), + "png": [ + (b"\x89PNG\r\n\x1a\n", 0, "PNG"), ], - 'gif': [ - (b'GIF87a', 0, 'GIF87a'), - (b'GIF89a', 0, 'GIF89a'), + "gif": [ + (b"GIF87a", 0, "GIF87a"), + (b"GIF89a", 0, "GIF89a"), ], - 'webp': [ - (b'RIFF', 0, 'RIFF'), # WebP starts with RIFF header + "webp": [ + (b"RIFF", 0, "RIFF"), # WebP starts with RIFF header ], - 'bmp': [ - (b'BM', 0, 'BMP'), + "bmp": [ + (b"BM", 0, "BMP"), ], } # All allowed MIME types -ALLOWED_IMAGE_MIME_TYPES: set[str] = frozenset({ - 'image/jpeg', - 'image/png', - 'image/gif', - 'image/webp', -}) +ALLOWED_IMAGE_MIME_TYPES: set[str] = frozenset( + { + "image/jpeg", + "image/png", + "image/gif", + "image/webp", + } +) # Allowed file extensions -ALLOWED_IMAGE_EXTENSIONS: set[str] = frozenset({ - '.jpg', '.jpeg', '.png', '.gif', '.webp', -}) +ALLOWED_IMAGE_EXTENSIONS: set[str] = frozenset( + { + ".jpg", + ".jpeg", + ".png", + ".gif", + ".webp", + } +) # Maximum file size (10MB) MAX_FILE_SIZE = 10 * 1024 * 1024 @@ -94,6 +103,7 @@ MIN_FILE_SIZE = 100 # 100 bytes # File Validation Functions # ============================================================================= + def validate_image_upload( file: UploadedFile, max_size: int = MAX_FILE_SIZE, @@ -133,39 +143,29 @@ def validate_image_upload( # 2. Check file size if file.size > max_size: - raise FileValidationError( - f"File too large. Maximum size is {max_size // (1024 * 1024)}MB" - ) + raise FileValidationError(f"File too large. Maximum size is {max_size // (1024 * 1024)}MB") if file.size < MIN_FILE_SIZE: raise FileValidationError("File too small or empty") # 3. Check file extension - filename = file.name or '' + filename = file.name or "" ext = os.path.splitext(filename)[1].lower() if ext not in allowed_extensions: - raise FileValidationError( - f"Invalid file extension '{ext}'. Allowed: {', '.join(allowed_extensions)}" - ) + raise FileValidationError(f"Invalid file extension '{ext}'. Allowed: {', '.join(allowed_extensions)}") # 4. Check Content-Type header - content_type = getattr(file, 'content_type', '') + content_type = getattr(file, "content_type", "") if content_type and content_type not in allowed_types: - raise FileValidationError( - f"Invalid file type '{content_type}'. Allowed: {', '.join(allowed_types)}" - ) + raise FileValidationError(f"Invalid file type '{content_type}'. Allowed: {', '.join(allowed_types)}") # 5. Validate magic numbers (actual file content) if not _validate_magic_number(file): - raise FileValidationError( - "File content doesn't match file extension. File may be corrupted or malicious." - ) + raise FileValidationError("File content doesn't match file extension. File may be corrupted or malicious.") # 6. Validate image integrity using PIL if not _validate_image_integrity(file): - raise FileValidationError( - "Invalid or corrupted image file" - ) + raise FileValidationError("Invalid or corrupted image file") return True @@ -191,10 +191,10 @@ def _validate_magic_number(file: UploadedFile) -> bool: # Check against known signatures for format_name, signatures in IMAGE_SIGNATURES.items(): for magic, offset, _description in signatures: - if len(header) >= offset + len(magic) and header[offset:offset + len(magic)] == magic: + if len(header) >= offset + len(magic) and header[offset : offset + len(magic)] == magic: # Special handling for WebP (must also have WEBP marker) - if format_name == 'webp': - if len(header) >= 12 and header[8:12] == b'WEBP': + if format_name == "webp": + if len(header) >= 12 and header[8:12] == b"WEBP": return True else: return True @@ -233,9 +233,7 @@ def _validate_image_integrity(file: UploadedFile) -> bool: # Prevent decompression bombs max_dimension = 10000 if img2.width > max_dimension or img2.height > max_dimension: - raise FileValidationError( - f"Image dimensions too large. Maximum is {max_dimension}x{max_dimension}" - ) + raise FileValidationError(f"Image dimensions too large. Maximum is {max_dimension}x{max_dimension}") # Check for very small dimensions (might be suspicious) if img2.width < 1 or img2.height < 1: @@ -253,6 +251,7 @@ def _validate_image_integrity(file: UploadedFile) -> bool: # Filename Sanitization # ============================================================================= + def sanitize_filename(filename: str, max_length: int = 100) -> str: """ Sanitize a filename to prevent directory traversal and other attacks. @@ -281,13 +280,13 @@ def sanitize_filename(filename: str, max_length: int = 100) -> str: # Remove or replace dangerous characters from name # Allow alphanumeric, hyphens, underscores, dots - name = re.sub(r'[^\w\-.]', '_', name) + name = re.sub(r"[^\w\-.]", "_", name) # Remove leading dots and underscores (hidden file prevention) - name = name.lstrip('._') + name = name.lstrip("._") # Collapse multiple underscores - name = re.sub(r'_+', '_', name) + name = re.sub(r"_+", "_", name) # Ensure name is not empty if not name: @@ -295,7 +294,7 @@ def sanitize_filename(filename: str, max_length: int = 100) -> str: # Sanitize extension ext = ext.lower() - ext = re.sub(r'[^\w.]', '', ext) + ext = re.sub(r"[^\w.]", "", ext) # Combine and truncate result = f"{name[:max_length - len(ext)]}{ext}" @@ -303,7 +302,7 @@ def sanitize_filename(filename: str, max_length: int = 100) -> str: return result -def generate_unique_filename(original_filename: str, prefix: str = '') -> str: +def generate_unique_filename(original_filename: str, prefix: str = "") -> str: """ Generate a unique filename using UUID while preserving extension. @@ -317,7 +316,7 @@ def generate_unique_filename(original_filename: str, prefix: str = '') -> str: ext = os.path.splitext(original_filename)[1].lower() # Sanitize extension - ext = re.sub(r'[^\w.]', '', ext) + ext = re.sub(r"[^\w.]", "", ext) # Generate unique filename unique_id = uuid.uuid4().hex[:12] @@ -332,9 +331,9 @@ def generate_unique_filename(original_filename: str, prefix: str = '') -> str: # Rate limiting configuration UPLOAD_RATE_LIMITS = { - 'per_minute': 10, - 'per_hour': 100, - 'per_day': 500, + "per_minute": 10, + "per_hour": 100, + "per_day": 500, } @@ -351,24 +350,25 @@ def check_upload_rate_limit(user_id: int, cache_backend=None) -> tuple[bool, str """ if cache_backend is None: from django.core.cache import cache + cache_backend = cache # Check per-minute limit minute_key = f"upload_rate:{user_id}:minute" minute_count = cache_backend.get(minute_key, 0) - if minute_count >= UPLOAD_RATE_LIMITS['per_minute']: + if minute_count >= UPLOAD_RATE_LIMITS["per_minute"]: return False, "Upload rate limit exceeded. Please wait a minute." # Check per-hour limit hour_key = f"upload_rate:{user_id}:hour" hour_count = cache_backend.get(hour_key, 0) - if hour_count >= UPLOAD_RATE_LIMITS['per_hour']: + if hour_count >= UPLOAD_RATE_LIMITS["per_hour"]: return False, "Hourly upload limit exceeded. Please try again later." # Check per-day limit day_key = f"upload_rate:{user_id}:day" day_count = cache_backend.get(day_key, 0) - if day_count >= UPLOAD_RATE_LIMITS['per_day']: + if day_count >= UPLOAD_RATE_LIMITS["per_day"]: return False, "Daily upload limit exceeded. Please try again tomorrow." return True, "" @@ -384,6 +384,7 @@ def increment_upload_count(user_id: int, cache_backend=None) -> None: """ if cache_backend is None: from django.core.cache import cache + cache_backend = cache # Increment per-minute counter (expires in 60 seconds) @@ -412,6 +413,7 @@ def increment_upload_count(user_id: int, cache_backend=None) -> None: # Antivirus Integration Point # ============================================================================= + def scan_file_for_malware(file: UploadedFile) -> tuple[bool, str]: """ Placeholder for antivirus/malware scanning integration. diff --git a/backend/apps/core/utils/html_sanitizer.py b/backend/apps/core/utils/html_sanitizer.py index c37c510f..5c992458 100644 --- a/backend/apps/core/utils/html_sanitizer.py +++ b/backend/apps/core/utils/html_sanitizer.py @@ -26,6 +26,7 @@ from typing import Any try: import bleach + BLEACH_AVAILABLE = True except ImportError: BLEACH_AVAILABLE = False @@ -36,71 +37,135 @@ except ImportError: # ============================================================================= # Default allowed HTML tags for user-generated content -ALLOWED_TAGS = frozenset([ - # Text formatting - 'p', 'br', 'hr', - 'strong', 'b', 'em', 'i', 'u', 's', 'strike', - 'sub', 'sup', 'small', 'mark', - - # Headers - 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', - - # Lists - 'ul', 'ol', 'li', - - # Links (with restrictions on attributes) - 'a', - - # Block elements - 'blockquote', 'pre', 'code', - 'div', 'span', - - # Tables - 'table', 'thead', 'tbody', 'tfoot', 'tr', 'th', 'td', -]) +ALLOWED_TAGS = frozenset( + [ + # Text formatting + "p", + "br", + "hr", + "strong", + "b", + "em", + "i", + "u", + "s", + "strike", + "sub", + "sup", + "small", + "mark", + # Headers + "h1", + "h2", + "h3", + "h4", + "h5", + "h6", + # Lists + "ul", + "ol", + "li", + # Links (with restrictions on attributes) + "a", + # Block elements + "blockquote", + "pre", + "code", + "div", + "span", + # Tables + "table", + "thead", + "tbody", + "tfoot", + "tr", + "th", + "td", + ] +) # Allowed attributes for each tag ALLOWED_ATTRIBUTES = { - 'a': ['href', 'title', 'rel', 'target'], - 'img': ['src', 'alt', 'title', 'width', 'height'], - 'div': ['class'], - 'span': ['class'], - 'p': ['class'], - 'table': ['class'], - 'th': ['class', 'colspan', 'rowspan'], - 'td': ['class', 'colspan', 'rowspan'], - '*': ['class'], # Allow class on all elements + "a": ["href", "title", "rel", "target"], + "img": ["src", "alt", "title", "width", "height"], + "div": ["class"], + "span": ["class"], + "p": ["class"], + "table": ["class"], + "th": ["class", "colspan", "rowspan"], + "td": ["class", "colspan", "rowspan"], + "*": ["class"], # Allow class on all elements } # Allowed URL protocols -ALLOWED_PROTOCOLS = frozenset([ - 'http', 'https', 'mailto', 'tel', -]) +ALLOWED_PROTOCOLS = frozenset( + [ + "http", + "https", + "mailto", + "tel", + ] +) # Minimal tags for comments and short text -MINIMAL_TAGS = frozenset([ - 'p', 'br', 'strong', 'b', 'em', 'i', 'a', -]) +MINIMAL_TAGS = frozenset( + [ + "p", + "br", + "strong", + "b", + "em", + "i", + "a", + ] +) # Tags allowed in icon SVGs (for icon template rendering) -SVG_TAGS = frozenset([ - 'svg', 'path', 'g', 'circle', 'rect', 'line', 'polyline', 'polygon', - 'ellipse', 'text', 'tspan', 'defs', 'use', 'symbol', 'clipPath', - 'mask', 'linearGradient', 'radialGradient', 'stop', 'title', -]) +SVG_TAGS = frozenset( + [ + "svg", + "path", + "g", + "circle", + "rect", + "line", + "polyline", + "polygon", + "ellipse", + "text", + "tspan", + "defs", + "use", + "symbol", + "clipPath", + "mask", + "linearGradient", + "radialGradient", + "stop", + "title", + ] +) SVG_ATTRIBUTES = { - 'svg': ['viewBox', 'width', 'height', 'fill', 'stroke', 'class', - 'xmlns', 'aria-hidden', 'role'], - 'path': ['d', 'fill', 'stroke', 'stroke-width', 'stroke-linecap', - 'stroke-linejoin', 'class', 'fill-rule', 'clip-rule'], - 'g': ['fill', 'stroke', 'transform', 'class'], - 'circle': ['cx', 'cy', 'r', 'fill', 'stroke', 'class'], - 'rect': ['x', 'y', 'width', 'height', 'rx', 'ry', 'fill', 'stroke', 'class'], - 'line': ['x1', 'y1', 'x2', 'y2', 'stroke', 'stroke-width', 'class'], - 'polyline': ['points', 'fill', 'stroke', 'class'], - 'polygon': ['points', 'fill', 'stroke', 'class'], - '*': ['class', 'fill', 'stroke'], + "svg": ["viewBox", "width", "height", "fill", "stroke", "class", "xmlns", "aria-hidden", "role"], + "path": [ + "d", + "fill", + "stroke", + "stroke-width", + "stroke-linecap", + "stroke-linejoin", + "class", + "fill-rule", + "clip-rule", + ], + "g": ["fill", "stroke", "transform", "class"], + "circle": ["cx", "cy", "r", "fill", "stroke", "class"], + "rect": ["x", "y", "width", "height", "rx", "ry", "fill", "stroke", "class"], + "line": ["x1", "y1", "x2", "y2", "stroke", "stroke-width", "class"], + "polyline": ["points", "fill", "stroke", "class"], + "polygon": ["points", "fill", "stroke", "class"], + "*": ["class", "fill", "stroke"], } @@ -108,6 +173,7 @@ SVG_ATTRIBUTES = { # Sanitization Functions # ============================================================================= + def sanitize_html( html: str | None, allowed_tags: frozenset | None = None, @@ -133,7 +199,7 @@ def sanitize_html( '

Hello

' """ if not html: - return '' + return "" if not isinstance(html, str): html = str(html) @@ -170,7 +236,7 @@ def sanitize_minimal(html: str | None) -> str: return sanitize_html( html, allowed_tags=MINIMAL_TAGS, - allowed_attributes={'a': ['href', 'title']}, + allowed_attributes={"a": ["href", "title"]}, ) @@ -188,7 +254,7 @@ def sanitize_svg(svg: str | None) -> str: Sanitized SVG string safe for inline rendering """ if not svg: - return '' + return "" if not isinstance(svg, str): svg = str(svg) @@ -218,7 +284,7 @@ def strip_html(html: str | None) -> str: Plain text with all HTML tags removed """ if not html: - return '' + return "" if not isinstance(html, str): html = str(html) @@ -227,13 +293,14 @@ def strip_html(html: str | None) -> str: return bleach.clean(html, tags=[], strip=True) else: # Fallback: use regex to strip tags - return re.sub(r'<[^>]+>', '', html) + return re.sub(r"<[^>]+>", "", html) # ============================================================================= # JSON/JavaScript Context Sanitization # ============================================================================= + def sanitize_for_json(data: Any) -> str: """ Safely serialize data for embedding in JavaScript/JSON contexts. @@ -251,14 +318,12 @@ def sanitize_for_json(data: Any) -> str: '{"name": "\\u003c/script\\u003e\\u003cscript\\u003ealert(\\"xss\\")"}' """ # JSON encode with safe characters escaped - return json.dumps(data, ensure_ascii=False).replace( - '<', '\\u003c' - ).replace( - '>', '\\u003e' - ).replace( - '&', '\\u0026' - ).replace( - "'", '\\u0027' + return ( + json.dumps(data, ensure_ascii=False) + .replace("<", "\\u003c") + .replace(">", "\\u003e") + .replace("&", "\\u0026") + .replace("'", "\\u0027") ) @@ -273,26 +338,21 @@ def escape_js_string(s: str | None) -> str: Escaped string safe for JavaScript contexts """ if not s: - return '' + return "" if not isinstance(s, str): s = str(s) # Escape backslashes first, then other special characters - return s.replace('\\', '\\\\').replace( - "'", "\\'" - ).replace( - '"', '\\"' - ).replace( - '\n', '\\n' - ).replace( - '\r', '\\r' - ).replace( - '<', '\\u003c' - ).replace( - '>', '\\u003e' - ).replace( - '&', '\\u0026' + return ( + s.replace("\\", "\\\\") + .replace("'", "\\'") + .replace('"', '\\"') + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("<", "\\u003c") + .replace(">", "\\u003e") + .replace("&", "\\u0026") ) @@ -300,6 +360,7 @@ def escape_js_string(s: str | None) -> str: # URL Sanitization # ============================================================================= + def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) -> str: """ Sanitize a URL to prevent javascript: and other dangerous protocols. @@ -312,7 +373,7 @@ def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) -> Sanitized URL or empty string if unsafe """ if not url: - return '' + return "" if not isinstance(url, str): url = str(url) @@ -320,7 +381,7 @@ def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) -> url = url.strip() if not url: - return '' + return "" protocols = allowed_protocols if allowed_protocols is not None else ALLOWED_PROTOCOLS @@ -328,12 +389,12 @@ def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) -> url_lower = url.lower() # Check for javascript:, data:, vbscript:, etc. - if ':' in url_lower: - protocol = url_lower.split(':')[0] - if protocol not in protocols: + if ":" in url_lower: + protocol = url_lower.split(":")[0] + if protocol not in protocols: # noqa: SIM102 # Allow relative URLs and anchor links - if not (url.startswith('/') or url.startswith('#') or url.startswith('?')): - return '' + if not (url.startswith("/") or url.startswith("#") or url.startswith("?")): + return "" return url @@ -342,6 +403,7 @@ def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) -> # Attribute Sanitization # ============================================================================= + def sanitize_attribute_value(value: str | None) -> str: """ Sanitize a value for use in HTML attributes. @@ -353,7 +415,7 @@ def sanitize_attribute_value(value: str | None) -> str: Sanitized value safe for HTML attribute contexts """ if not value: - return '' + return "" if not isinstance(value, str): value = str(value) @@ -373,10 +435,10 @@ def sanitize_class_name(name: str | None) -> str: Sanitized class name containing only safe characters """ if not name: - return '' + return "" if not isinstance(name, str): name = str(name) # Only allow alphanumeric, hyphens, and underscores - return re.sub(r'[^a-zA-Z0-9_-]', '', name) + return re.sub(r"[^a-zA-Z0-9_-]", "", name) diff --git a/backend/apps/core/utils/query_optimization.py b/backend/apps/core/utils/query_optimization.py index c021e194..63b132ac 100644 --- a/backend/apps/core/utils/query_optimization.py +++ b/backend/apps/core/utils/query_optimization.py @@ -16,9 +16,7 @@ logger = logging.getLogger("query_optimization") @contextmanager -def track_queries( - operation_name: str, warn_threshold: int = 10, time_threshold: float = 1.0 -): +def track_queries(operation_name: str, warn_threshold: int = 10, time_threshold: float = 1.0): """ Context manager to track database queries for specific operations @@ -47,15 +45,9 @@ def track_queries( recent_queries = connection.queries[-total_queries:] query_details = [ { - "sql": ( - query["sql"][:500] + "..." - if len(query["sql"]) > 500 - else query["sql"] - ), + "sql": (query["sql"][:500] + "..." if len(query["sql"]) > 500 else query["sql"]), "time": float(query["time"]), - "duplicate_count": sum( - 1 for q in recent_queries if q["sql"] == query["sql"] - ), + "duplicate_count": sum(1 for q in recent_queries if q["sql"] == query["sql"]), } for query in recent_queries ] @@ -65,22 +57,18 @@ def track_queries( "query_count": total_queries, "execution_time": execution_time, "queries": query_details if settings.DEBUG else [], - "slow_queries": [ - q for q in query_details if q["time"] > 0.1 - ], # Queries slower than 100ms + "slow_queries": [q for q in query_details if q["time"] > 0.1], # Queries slower than 100ms } # Log warnings for performance issues if total_queries > warn_threshold or execution_time > time_threshold: logger.warning( - f"Performance concern in {operation_name}: " - f"{total_queries} queries, {execution_time:.2f}s", + f"Performance concern in {operation_name}: " f"{total_queries} queries, {execution_time:.2f}s", extra=performance_data, ) else: logger.debug( - f"Query tracking for {operation_name}: " - f"{total_queries} queries, {execution_time:.2f}s", + f"Query tracking for {operation_name}: " f"{total_queries} queries, {execution_time:.2f}s", extra=performance_data, ) @@ -109,9 +97,7 @@ class QueryOptimizer: Optimize Ride queryset with proper relationships """ return ( - queryset.select_related( - "park", "park__location", "manufacturer", "created_by" - ) + queryset.select_related("park", "park__location", "manufacturer", "created_by") .prefetch_related("reviews__user", "media_items") .annotate( review_count=Count("reviews"), @@ -158,9 +144,7 @@ class QueryCache: """Caching utilities for expensive queries""" @staticmethod - def cache_queryset_result( - cache_key: str, queryset_func, timeout: int = 3600, **kwargs - ): + def cache_queryset_result(cache_key: str, queryset_func, timeout: int = 3600, **kwargs): """ Cache the result of an expensive queryset operation @@ -202,13 +186,9 @@ class QueryCache: # For Redis cache backends that support pattern deletion if hasattr(cache, "delete_pattern"): deleted_count = cache.delete_pattern(pattern) - logger.info( - f"Invalidated {deleted_count} cache keys for pattern: {pattern}" - ) + logger.info(f"Invalidated {deleted_count} cache keys for pattern: {pattern}") else: - logger.warning( - f"Cache backend does not support pattern deletion: {pattern}" - ) + logger.warning(f"Cache backend does not support pattern deletion: {pattern}") except Exception as e: logger.error(f"Error invalidating cache pattern {pattern}: {e}") @@ -249,10 +229,7 @@ class IndexAnalyzer: sql_upper = sql.upper() analysis = { "has_where_clause": "WHERE" in sql_upper, - "has_join": any( - join in sql_upper - for join in ["JOIN", "INNER JOIN", "LEFT JOIN", "RIGHT JOIN"] - ), + "has_join": any(join in sql_upper for join in ["JOIN", "INNER JOIN", "LEFT JOIN", "RIGHT JOIN"]), "has_order_by": "ORDER BY" in sql_upper, "has_group_by": "GROUP BY" in sql_upper, "has_like": "LIKE" in sql_upper, @@ -266,19 +243,13 @@ class IndexAnalyzer: # Suggest indexes based on patterns if analysis["has_where_clause"] and not analysis["has_join"]: - analysis["suggestions"].append( - "Consider adding indexes on WHERE clause columns" - ) + analysis["suggestions"].append("Consider adding indexes on WHERE clause columns") if analysis["has_order_by"]: - analysis["suggestions"].append( - "Consider adding indexes on ORDER BY columns" - ) + analysis["suggestions"].append("Consider adding indexes on ORDER BY columns") if analysis["has_like"] and "%" not in sql[: sql.find("LIKE") + 10]: - analysis["suggestions"].append( - "LIKE queries with leading wildcards cannot use indexes efficiently" - ) + analysis["suggestions"].append("LIKE queries with leading wildcards cannot use indexes efficiently") return analysis @@ -294,28 +265,16 @@ class IndexAnalyzer: # automatically) for field in opts.fields: if isinstance(field, models.ForeignKey): - suggestions.append( - f"Index on {field.name} (automatically created by Django)" - ) + suggestions.append(f"Index on {field.name} (automatically created by Django)") # Suggest composite indexes for common query patterns - date_fields = [ - f.name - for f in opts.fields - if isinstance(f, (models.DateField, models.DateTimeField)) - ] - status_fields = [ - f.name - for f in opts.fields - if f.name in ["status", "is_active", "is_published"] - ] + date_fields = [f.name for f in opts.fields if isinstance(f, models.DateField | models.DateTimeField)] + status_fields = [f.name for f in opts.fields if f.name in ["status", "is_active", "is_published"]] if date_fields and status_fields: for date_field in date_fields: for status_field in status_fields: - suggestions.append( - f"Composite index on ({status_field}, {date_field}) for filtered date queries" - ) + suggestions.append(f"Composite index on ({status_field}, {date_field}) for filtered date queries") # Suggest indexes for fields commonly used in WHERE clauses common_filter_fields = ["slug", "name", "created_at", "updated_at"] @@ -340,9 +299,7 @@ def log_query_performance(): return decorator -def optimize_queryset_for_serialization( - queryset: QuerySet, fields: list[str] -) -> QuerySet: +def optimize_queryset_for_serialization(queryset: QuerySet, fields: list[str]) -> QuerySet: """ Optimize a queryset for API serialization by only selecting needed fields @@ -362,9 +319,7 @@ def optimize_queryset_for_serialization( field = opts.get_field(field_name) if isinstance(field, models.ForeignKey): select_related_fields.append(field_name) - elif isinstance( - field, (models.ManyToManyField, models.reverse.ManyToManyRel) - ): + elif isinstance(field, models.ManyToManyField | models.reverse.ManyToManyRel): prefetch_related_fields.append(field_name) except models.FieldDoesNotExist: # Field might be a property or method, skip optimization @@ -421,7 +376,6 @@ def monitor_db_performance(operation_name: str): ) else: logger.debug( - f"DB performance for {operation_name}: " - f"{duration:.3f}s, {total_queries} queries", + f"DB performance for {operation_name}: " f"{duration:.3f}s, {total_queries} queries", extra=performance_data, ) diff --git a/backend/apps/core/utils/turnstile.py b/backend/apps/core/utils/turnstile.py index 5e7b071b..196d8d6e 100644 --- a/backend/apps/core/utils/turnstile.py +++ b/backend/apps/core/utils/turnstile.py @@ -4,6 +4,7 @@ Cloudflare Turnstile validation utilities. This module provides a function to validate Turnstile tokens on the server side before processing form submissions. """ + import requests from django.conf import settings @@ -20,45 +21,41 @@ def validate_turnstile_token(token: str, ip: str = None) -> dict: dict with 'success' boolean and optional 'error' message """ # Skip validation if configured (dev mode) - if getattr(settings, 'TURNSTILE_SKIP_VALIDATION', False): - return {'success': True} + if getattr(settings, "TURNSTILE_SKIP_VALIDATION", False): + return {"success": True} - secret = getattr(settings, 'TURNSTILE_SECRET', '') + secret = getattr(settings, "TURNSTILE_SECRET", "") if not secret: - return {'success': True} # Skip if no secret configured + return {"success": True} # Skip if no secret configured if not token: - return {'success': False, 'error': 'Captcha verification required'} + return {"success": False, "error": "Captcha verification required"} try: response = requests.post( - 'https://challenges.cloudflare.com/turnstile/v0/siteverify', + "https://challenges.cloudflare.com/turnstile/v0/siteverify", data={ - 'secret': secret, - 'response': token, - 'remoteip': ip, + "secret": secret, + "response": token, + "remoteip": ip, }, - timeout=10 + timeout=10, ) result = response.json() - if result.get('success'): - return {'success': True} + if result.get("success"): + return {"success": True} else: - error_codes = result.get('error-codes', []) - return { - 'success': False, - 'error': 'Captcha verification failed', - 'error_codes': error_codes - } + error_codes = result.get("error-codes", []) + return {"success": False, "error": "Captcha verification failed", "error_codes": error_codes} except requests.RequestException: # Log error but don't block user on network issues - return {'success': True} # Fail open to avoid blocking legitimate users + return {"success": True} # Fail open to avoid blocking legitimate users def get_client_ip(request): """Extract client IP from request, handling proxies.""" - x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR') + x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR") if x_forwarded_for: - return x_forwarded_for.split(',')[0].strip() - return request.META.get('REMOTE_ADDR') + return x_forwarded_for.split(",")[0].strip() + return request.META.get("REMOTE_ADDR") diff --git a/backend/apps/core/views/base.py b/backend/apps/core/views/base.py index 4e4cb4bc..8f1a3c00 100644 --- a/backend/apps/core/views/base.py +++ b/backend/apps/core/views/base.py @@ -5,7 +5,6 @@ This module provides base view classes that implement common patterns such as automatic query optimization with select_related and prefetch_related. """ - from django.db.models import QuerySet from django.views.generic import DetailView, ListView diff --git a/backend/apps/core/views/entity_search.py b/backend/apps/core/views/entity_search.py index bed79b66..4bb3e7fe 100644 --- a/backend/apps/core/views/entity_search.py +++ b/backend/apps/core/views/entity_search.py @@ -2,7 +2,6 @@ Entity search views with fuzzy matching and authentication prompts. """ - import contextlib from rest_framework import status @@ -67,9 +66,7 @@ class EntityFuzzySearchView(APIView): try: # Parse request data query = request.data.get("query", "").strip() - entity_types_raw = request.data.get( - "entity_types", ["park", "ride", "company"] - ) + entity_types_raw = request.data.get("entity_types", ["park", "ride", "company"]) include_suggestions = request.data.get("include_suggestions", True) # Validate query @@ -105,9 +102,7 @@ class EntityFuzzySearchView(APIView): "query": query, "matches": [match.to_dict() for match in matches], "user_authenticated": ( - request.user.is_authenticated - if hasattr(request.user, "is_authenticated") - else False + request.user.is_authenticated if hasattr(request.user, "is_authenticated") else False ), } @@ -211,9 +206,7 @@ class EntityNotFoundView(APIView): "context": context, "matches": [match.to_dict() for match in matches], "user_authenticated": ( - request.user.is_authenticated - if hasattr(request.user, "is_authenticated") - else False + request.user.is_authenticated if hasattr(request.user, "is_authenticated") else False ), "has_matches": len(matches) > 0, } @@ -267,9 +260,7 @@ class QuickEntitySuggestionView(APIView): limit = min(int(request.GET.get("limit", 5)), 10) # Cap at 10 if not query or len(query) < 2: - return Response( - {"suggestions": [], "query": query}, status=status.HTTP_200_OK - ) + return Response({"suggestions": [], "query": query}, status=status.HTTP_200_OK) # Parse entity types entity_types = [] @@ -282,9 +273,7 @@ class QuickEntitySuggestionView(APIView): entity_types = [EntityType.PARK, EntityType.RIDE, EntityType.COMPANY] # Get fuzzy matches - matches, _ = entity_fuzzy_matcher.find_entity( - query=query, entity_types=entity_types, user=request.user - ) + matches, _ = entity_fuzzy_matcher.find_entity(query=query, entity_types=entity_types, user=request.user) # Format as simple suggestions suggestions = [] @@ -313,9 +302,7 @@ class QuickEntitySuggestionView(APIView): # Utility function for other views to use -def get_entity_suggestions( - query: str, entity_types: list[str] | None = None, user=None -): +def get_entity_suggestions(query: str, entity_types: list[str] | None = None, user=None): """ Utility function for other Django views to get entity suggestions. @@ -340,8 +327,6 @@ def get_entity_suggestions( if not parsed_types: parsed_types = [EntityType.PARK, EntityType.RIDE, EntityType.COMPANY] - return entity_fuzzy_matcher.find_entity( - query=query, entity_types=parsed_types, user=user - ) + return entity_fuzzy_matcher.find_entity(query=query, entity_types=parsed_types, user=user) except Exception: return [], None diff --git a/backend/apps/core/views/map_views.py b/backend/apps/core/views/map_views.py index 4318f98b..2ec2a16e 100644 --- a/backend/apps/core/views/map_views.py +++ b/backend/apps/core/views/map_views.py @@ -86,7 +86,7 @@ class MapAPIView(View): return bounds return None except (ValueError, TypeError) as e: - raise ValidationError(f"Invalid bounds parameters: {e}") + raise ValidationError(f"Invalid bounds parameters: {e}") from None def _parse_pagination(self, request: HttpRequest) -> dict[str, int]: """Parse pagination parameters from request.""" @@ -125,11 +125,7 @@ class MapAPIView(View): if location_types_param: type_strings = location_types_param.split(",") valid_types = {lt.value for lt in LocationType} - filters.location_types = { - LocationType(t.strip()) - for t in type_strings - if t.strip() in valid_types - } + filters.location_types = {LocationType(t.strip()) for t in type_strings if t.strip() in valid_types} # Park status park_status_param = request.GET.get("park_status") @@ -199,7 +195,7 @@ class MapAPIView(View): ) except (ValueError, TypeError) as e: - raise ValidationError(f"Invalid filter parameters: {e}") + raise ValidationError(f"Invalid filter parameters: {e}") from None def _parse_zoom_level(self, request: HttpRequest) -> int: """Parse zoom level from request with default.""" @@ -218,9 +214,7 @@ class MapAPIView(View): request: HttpRequest, ) -> dict[str, Any]: """Create paginated response with metadata.""" - total_pages = (total_count + pagination["page_size"] - 1) // pagination[ - "page_size" - ] + total_pages = (total_count + pagination["page_size"] - 1) // pagination["page_size"] # Build pagination URLs base_url = request.build_absolute_uri(request.path) @@ -278,9 +272,7 @@ class MapAPIView(View): return JsonResponse(response_data, status=status) - def _success_response( - self, data: Any, message: str = None, metadata: dict[str, Any] = None - ) -> JsonResponse: + def _success_response(self, data: Any, message: str = None, metadata: dict[str, Any] = None) -> JsonResponse: """Return standardized success response.""" response_data = { "status": "success", @@ -397,9 +389,7 @@ class MapLocationDetailView(MapAPIView): """ @method_decorator(cache_page(600)) # Cache for 10 minutes - def get( - self, request: HttpRequest, location_type: str, location_id: int - ) -> JsonResponse: + def get(self, request: HttpRequest, location_type: str, location_id: int) -> JsonResponse: """Get detailed information for a specific location.""" try: # Validate location type @@ -422,9 +412,7 @@ class MapLocationDetailView(MapAPIView): ) # Get location details - location = unified_map_service.get_location_details( - location_type, location_id - ) + location = unified_map_service.get_location_details(location_type, location_id) if not location: return self._error_response( @@ -499,9 +487,7 @@ class MapSearchView(MapAPIView): try: valid_types = {lt.value for lt in LocationType} location_types = { - LocationType(t.strip()) - for t in types_param.split(",") - if t.strip() in valid_types + LocationType(t.strip()) for t in types_param.split(",") if t.strip() in valid_types } except ValueError: return self._error_response( @@ -569,9 +555,7 @@ class MapBoundsView(MapAPIView): # Parse required bounds bounds = self._parse_bounds(request) if not bounds: - return self._error_response( - "Bounds parameters required: north, south, east, west", 400 - ) + return self._error_response("Bounds parameters required: north, south, east, west", 400) # Parse optional filters location_types = None diff --git a/backend/apps/core/views/maps.py b/backend/apps/core/views/maps.py index bd30125a..1001881b 100644 --- a/backend/apps/core/views/maps.py +++ b/backend/apps/core/views/maps.py @@ -74,9 +74,7 @@ class UniversalMapView(MapViewMixin, TemplateView): ) # Handle initial bounds from query parameters - if all( - param in self.request.GET for param in ["north", "south", "east", "west"] - ): + if all(param in self.request.GET for param in ["north", "south", "east", "west"]): with contextlib.suppress(ValueError, TypeError): context["initial_bounds"] = { "north": float(self.request.GET["north"]), @@ -243,9 +241,7 @@ class LocationSearchView(MapViewMixin, View): limit = min(20, max(5, int(request.GET.get("limit", "10")))) # Perform search - results = unified_map_service.search_locations( - query=query, location_types=location_types, limit=limit - ) + results = unified_map_service.search_locations(query=query, location_types=location_types, limit=limit) return render( request, @@ -285,11 +281,7 @@ class MapBoundsUpdateView(MapViewMixin, View): zoom_level = int(data.get("zoom", 10)) location_types = None if "types" in data: - location_types = { - LocationType(t) - for t in data["types"] - if t in [lt.value for lt in LocationType] - } + location_types = {LocationType(t) for t in data["types"] if t in [lt.value for lt in LocationType]} # Location types are used directly in the service call @@ -324,9 +316,7 @@ class LocationDetailModalView(MapViewMixin, View): URL: /maps/htmx/location/// """ - def get( - self, request: HttpRequest, location_type: str, location_id: int - ) -> HttpResponse: + def get(self, request: HttpRequest, location_type: str, location_id: int) -> HttpResponse: """Return location detail modal content.""" try: # Validate location type @@ -338,9 +328,7 @@ class LocationDetailModalView(MapViewMixin, View): ) # Get location details - location = unified_map_service.get_location_details( - location_type, location_id - ) + location = unified_map_service.get_location_details(location_type, location_id) if not location: return render( @@ -356,9 +344,7 @@ class LocationDetailModalView(MapViewMixin, View): ) except Exception as e: - return render( - request, "maps/partials/location_modal.html", {"error": str(e)} - ) + return render(request, "maps/partials/location_modal.html", {"error": str(e)}) class LocationListView(MapViewMixin, TemplateView): @@ -392,9 +378,7 @@ class LocationListView(MapViewMixin, TemplateView): ) # Get locations without clustering - map_response = unified_map_service.get_map_data( - filters=filters, cluster=False, use_cache=True - ) + map_response = unified_map_service.get_map_data(filters=filters, cluster=False, use_cache=True) # Paginate results paginator = Paginator(map_response.locations, self.paginate_by) diff --git a/backend/apps/core/views/performance_dashboard.py b/backend/apps/core/views/performance_dashboard.py index f629b29c..31c95fda 100644 --- a/backend/apps/core/views/performance_dashboard.py +++ b/backend/apps/core/views/performance_dashboard.py @@ -83,13 +83,15 @@ class PerformanceDashboardView(TemplateView): try: client = cache._cache.get_client() info = client.info() - cache_stats.update({ - "connected_clients": info.get("connected_clients"), - "used_memory_human": info.get("used_memory_human"), - "keyspace_hits": info.get("keyspace_hits", 0), - "keyspace_misses": info.get("keyspace_misses", 0), - "total_commands": info.get("total_commands_processed"), - }) + cache_stats.update( + { + "connected_clients": info.get("connected_clients"), + "used_memory_human": info.get("used_memory_human"), + "keyspace_hits": info.get("keyspace_hits", 0), + "keyspace_misses": info.get("keyspace_misses", 0), + "total_commands": info.get("total_commands_processed"), + } + ) # Calculate hit rate hits = info.get("keyspace_hits", 0) @@ -127,8 +129,7 @@ class PerformanceDashboardView(TemplateView): # Get connection count (PostgreSQL specific) try: cursor.execute( - "SELECT count(*) FROM pg_stat_activity WHERE datname = %s;", - [db_settings.get("NAME")] + "SELECT count(*) FROM pg_stat_activity WHERE datname = %s;", [db_settings.get("NAME")] ) stats["active_connections"] = cursor.fetchone()[0] except Exception: @@ -244,16 +245,18 @@ class CacheStatsAPIView(View): client = cache._cache.get_client() info = client.info() - cache_info.update({ - "used_memory": info.get("used_memory_human"), - "connected_clients": info.get("connected_clients"), - "keyspace_hits": info.get("keyspace_hits", 0), - "keyspace_misses": info.get("keyspace_misses", 0), - "expired_keys": info.get("expired_keys", 0), - "evicted_keys": info.get("evicted_keys", 0), - "total_connections_received": info.get("total_connections_received"), - "total_commands_processed": info.get("total_commands_processed"), - }) + cache_info.update( + { + "used_memory": info.get("used_memory_human"), + "connected_clients": info.get("connected_clients"), + "keyspace_hits": info.get("keyspace_hits", 0), + "keyspace_misses": info.get("keyspace_misses", 0), + "expired_keys": info.get("expired_keys", 0), + "evicted_keys": info.get("evicted_keys", 0), + "total_connections_received": info.get("total_connections_received"), + "total_commands_processed": info.get("total_commands_processed"), + } + ) # Calculate metrics hits = info.get("keyspace_hits", 0) diff --git a/backend/apps/core/views/search.py b/backend/apps/core/views/search.py index ad59e123..6da3b57a 100644 --- a/backend/apps/core/views/search.py +++ b/backend/apps/core/views/search.py @@ -18,11 +18,7 @@ class AdaptiveSearchView(TemplateView): """ Get the base queryset, optimized with select_related and prefetch_related """ - return ( - Park.objects.select_related("operator", "property_owner") - .prefetch_related("location", "photos") - .all() - ) + return Park.objects.select_related("operator", "property_owner").prefetch_related("location", "photos").all() def get_filterset(self): """ @@ -46,9 +42,7 @@ class AdaptiveSearchView(TemplateView): { "results": filterset.qs, "filters": filterset, - "applied_filters": bool( - self.request.GET - ), # Check if any filters are applied + "applied_filters": bool(self.request.GET), # Check if any filters are applied "is_location_search": bool(location_search or near_location), "location_search_query": location_search or near_location, } diff --git a/backend/apps/core/views/views.py b/backend/apps/core/views/views.py index 8b204933..93db646f 100644 --- a/backend/apps/core/views/views.py +++ b/backend/apps/core/views/views.py @@ -54,9 +54,7 @@ class SlugRedirectMixin(View): # Build kwargs for reverse() reverse_kwargs = self.get_redirect_url_kwargs() # Redirect to the current slug URL - return redirect( - reverse(url_pattern, kwargs=reverse_kwargs), permanent=True - ) + return redirect(reverse(url_pattern, kwargs=reverse_kwargs), permanent=True) return super().dispatch(request, *args, **kwargs) except Exception: # pylint: disable=broad-exception-caught # Fallback to default dispatch on any error (e.g. object not found) @@ -67,9 +65,7 @@ class SlugRedirectMixin(View): Get the URL pattern name for redirects. Should be overridden by subclasses. """ - raise NotImplementedError( - "Subclasses must implement get_redirect_url_pattern()" - ) + raise NotImplementedError("Subclasses must implement get_redirect_url_pattern()") def get_redirect_url_kwargs(self) -> dict[str, Any]: """ @@ -202,9 +198,7 @@ def get_transition_metadata(transition_name: str) -> dict[str, Any]: return TRANSITION_METADATA["default"].copy() -def add_toast_trigger( - response: HttpResponse, message: str, toast_type: str = "success" -) -> HttpResponse: +def add_toast_trigger(response: HttpResponse, message: str, toast_type: str = "success") -> HttpResponse: """ Add HX-Trigger header to trigger Alpine.js toast. @@ -256,16 +250,12 @@ class FSMTransitionView(View): The model class or None if not found """ try: - content_type = ContentType.objects.get( - app_label=app_label, model=model_name - ) + content_type = ContentType.objects.get(app_label=app_label, model=model_name) return content_type.model_class() except ContentType.DoesNotExist: return None - def get_object( - self, model_class: type[Model], pk: Any, slug: str | None = None - ) -> Model: + def get_object(self, model_class: type[Model], pk: Any, slug: str | None = None) -> Model: """ Get the model instance. @@ -297,9 +287,7 @@ class FSMTransitionView(View): """ return getattr(obj, transition_name, None) - def validate_transition( - self, obj: Model, transition_name: str, user - ) -> tuple[bool, str | None]: + def validate_transition(self, obj: Model, transition_name: str, user) -> tuple[bool, str | None]: """ Validate that the transition can proceed. @@ -331,9 +319,7 @@ class FSMTransitionView(View): return True, None - def execute_transition( - self, obj: Model, transition_name: str, user, **kwargs - ) -> None: + def execute_transition(self, obj: Model, transition_name: str, user, **kwargs) -> None: """ Execute the transition on the object. @@ -355,9 +341,7 @@ class FSMTransitionView(View): def get_success_message(self, obj: Model, transition_name: str) -> str: """Generate a success message for the transition.""" # Clean up transition name for display - display_name = ( - transition_name.replace("transition_to_", "").replace("_", " ").title() - ) + display_name = transition_name.replace("transition_to_", "").replace("_", " ").title() model_name = obj._meta.verbose_name.title() return f"{model_name} has been {display_name.lower()}d successfully." @@ -404,9 +388,7 @@ class FSMTransitionView(View): except TemplateDoesNotExist: return "htmx/updated_row.html" - def format_success_response( - self, request: HttpRequest, obj: Model, transition_name: str - ) -> HttpResponse: + def format_success_response(self, request: HttpRequest, obj: Model, transition_name: str) -> HttpResponse: """ Format a successful transition response. @@ -443,17 +425,11 @@ class FSMTransitionView(View): { "success": True, "message": message, - "new_state": ( - getattr(obj, obj.state_field_name, None) - if hasattr(obj, "state_field_name") - else None - ), + "new_state": (getattr(obj, obj.state_field_name, None) if hasattr(obj, "state_field_name") else None), } ) - def format_error_response( - self, request: HttpRequest, error: Exception, status_code: int = 400 - ) -> HttpResponse: + def format_error_response(self, request: HttpRequest, error: Exception, status_code: int = 400) -> HttpResponse: """ Format an error response. @@ -489,36 +465,26 @@ class FSMTransitionView(View): if not all([app_label, model_name, transition_name]): return self.format_error_response( request, - ValueError( - "Missing required parameters: app_label, model_name, and transition_name" - ), + ValueError("Missing required parameters: app_label, model_name, and transition_name"), 400, ) if not pk and not slug: - return self.format_error_response( - request, ValueError("Missing required parameter: pk or slug"), 400 - ) + return self.format_error_response(request, ValueError("Missing required parameter: pk or slug"), 400) # Get the model class model_class = self.get_model_class(app_label, model_name) if model_class is None: - return self.format_error_response( - request, ValueError(f"Model '{app_label}.{model_name}' not found"), 404 - ) + return self.format_error_response(request, ValueError(f"Model '{app_label}.{model_name}' not found"), 404) # Get the object try: obj = self.get_object(model_class, pk, slug) except ObjectDoesNotExist: - return self.format_error_response( - request, ValueError(f"Object not found: {model_name} with pk={pk}"), 404 - ) + return self.format_error_response(request, ValueError(f"Object not found: {model_name} with pk={pk}"), 404) # Validate the transition - can_execute, error_msg = self.validate_transition( - obj, transition_name, request.user - ) + can_execute, error_msg = self.validate_transition(obj, transition_name, request.user) if not can_execute: return self.format_error_response( request, @@ -561,15 +527,11 @@ class FSMTransitionView(View): return self.format_error_response(request, e, 400) except TransitionNotAllowed as e: - logger.warning( - f"Transition not allowed: '{transition_name}' on {model_class.__name__}(pk={obj.pk}): {e}" - ) + logger.warning(f"Transition not allowed: '{transition_name}' on {model_class.__name__}(pk={obj.pk}): {e}") return self.format_error_response(request, e, 400) except Exception as e: logger.exception( f"Unexpected error during transition '{transition_name}' on {model_class.__name__}(pk={obj.pk})" ) - return self.format_error_response( - request, ValueError(f"An unexpected error occurred: {str(e)}"), 500 - ) + return self.format_error_response(request, ValueError(f"An unexpected error occurred: {str(e)}"), 500) diff --git a/backend/apps/lists/admin.py b/backend/apps/lists/admin.py index 28964510..19d4cf3d 100644 --- a/backend/apps/lists/admin.py +++ b/backend/apps/lists/admin.py @@ -14,6 +14,7 @@ from .models import ListItem, UserList class ListItemInline(admin.TabularInline): """Inline admin for ListItem within UserList admin.""" + model = ListItem extra = 1 fields = ("content_type", "object_id", "rank", "notes") @@ -24,6 +25,7 @@ class ListItemInline(admin.TabularInline): @admin.register(UserList) class UserListAdmin(QueryOptimizationMixin, ExportActionMixin, TimestampFieldsMixin, BaseModelAdmin): """Admin interface for UserList.""" + list_display = ( "title", "user_link", @@ -65,6 +67,7 @@ class UserListAdmin(QueryOptimizationMixin, ExportActionMixin, TimestampFieldsMi def user_link(self, obj): if obj.user: from django.urls import reverse + url = reverse("admin:accounts_customuser_change", args=[obj.user.pk]) return format_html('{}', url, obj.user.username) return "-" @@ -82,6 +85,7 @@ class UserListAdmin(QueryOptimizationMixin, ExportActionMixin, TimestampFieldsMi @admin.register(ListItem) class ListItemAdmin(QueryOptimizationMixin, BaseModelAdmin): """Admin interface for ListItem.""" + list_display = ( "user_list", "content_type", diff --git a/backend/apps/lists/views.py b/backend/apps/lists/views.py index 01a9a883..b6fd682f 100644 --- a/backend/apps/lists/views.py +++ b/backend/apps/lists/views.py @@ -28,4 +28,6 @@ class ListItemViewSet(viewsets.ModelViewSet): lookup_field = "id" def get_queryset(self): - return ListItem.objects.filter(user_list__is_public=True) | ListItem.objects.filter(user_list__user=self.request.user) + return ListItem.objects.filter(user_list__is_public=True) | ListItem.objects.filter( + user_list__user=self.request.user + ) diff --git a/backend/apps/media/commands/download_photos.py b/backend/apps/media/commands/download_photos.py index 09d9df86..a6853d2f 100644 --- a/backend/apps/media/commands/download_photos.py +++ b/backend/apps/media/commands/download_photos.py @@ -52,9 +52,7 @@ class Command(BaseCommand): park.name}: { photo.image.name}" ) - self.stdout.write( - f"Database record created with ID: {photo.id}" - ) + self.stdout.write(f"Database record created with ID: {photo.id}") else: self.stdout.write( f"Error downloading image. Status code: { @@ -112,9 +110,7 @@ class Command(BaseCommand): ) except Exception as e: - self.stdout.write( - f"Error downloading ride photo: {str(e)}" - ) + self.stdout.write(f"Error downloading ride photo: {str(e)}") except Ride.DoesNotExist: self.stdout.write( diff --git a/backend/apps/media/commands/fix_photo_paths.py b/backend/apps/media/commands/fix_photo_paths.py index 138d0042..891fbb50 100644 --- a/backend/apps/media/commands/fix_photo_paths.py +++ b/backend/apps/media/commands/fix_photo_paths.py @@ -49,9 +49,7 @@ class Command(BaseCommand): if files: # Get the first file and update the database # record - file_path = os.path.join( - content_type, identifier, files[0] - ) + file_path = os.path.join(content_type, identifier, files[0]) if os.path.exists(os.path.join("media", file_path)): photo.image.name = file_path photo.save() @@ -111,9 +109,7 @@ class Command(BaseCommand): if files: # Get the first file and update the database # record - file_path = os.path.join( - content_type, identifier, files[0] - ) + file_path = os.path.join(content_type, identifier, files[0]) if os.path.exists(os.path.join("media", file_path)): photo.image.name = file_path photo.save() diff --git a/backend/apps/media/commands/move_photos.py b/backend/apps/media/commands/move_photos.py index f99edf17..e840c75b 100644 --- a/backend/apps/media/commands/move_photos.py +++ b/backend/apps/media/commands/move_photos.py @@ -37,9 +37,7 @@ class Command(BaseCommand): identifier = photo.park.slug # Look for any files in that directory - old_dir = os.path.join( - settings.MEDIA_ROOT, content_type, identifier - ) + old_dir = os.path.join(settings.MEDIA_ROOT, content_type, identifier) if os.path.exists(old_dir): files = [ f @@ -83,9 +81,7 @@ class Command(BaseCommand): # Move the file if current_path != new_full_path: - shutil.copy2( - current_path, new_full_path - ) # Use copy2 to preserve metadata + shutil.copy2(current_path, new_full_path) # Use copy2 to preserve metadata processed_files.add(current_path) else: processed_files.add(current_path) @@ -116,9 +112,7 @@ class Command(BaseCommand): identifier = parts[1] # e.g., 'alton-towers' # Look for any files in that directory - old_dir = os.path.join( - settings.MEDIA_ROOT, content_type, identifier - ) + old_dir = os.path.join(settings.MEDIA_ROOT, content_type, identifier) if os.path.exists(old_dir): files = [ f @@ -162,9 +156,7 @@ class Command(BaseCommand): # Move the file if current_path != new_full_path: - shutil.copy2( - current_path, new_full_path - ) # Use copy2 to preserve metadata + shutil.copy2(current_path, new_full_path) # Use copy2 to preserve metadata processed_files.add(current_path) else: processed_files.add(current_path) @@ -192,8 +184,6 @@ class Command(BaseCommand): os.remove(file_path) self.stdout.write(f"Removed old file: {file_path}") except Exception as e: - self.stdout.write( - f"Error removing {file_path}: {str(e)}" - ) + self.stdout.write(f"Error removing {file_path}: {str(e)}") self.stdout.write("Finished moving photo files and cleaning up") diff --git a/backend/apps/media/models.py b/backend/apps/media/models.py index 8ebb387a..d4af84cb 100644 --- a/backend/apps/media/models.py +++ b/backend/apps/media/models.py @@ -23,10 +23,7 @@ class Photo(TrackedModel): # The actual image image = models.ForeignKey( - CloudflareImage, - on_delete=models.CASCADE, - related_name="photos_usage", - help_text="Cloudflare Image reference" + CloudflareImage, on_delete=models.CASCADE, related_name="photos_usage", help_text="Cloudflare Image reference" ) # Generic relation to target object (Park, Ride, etc.) @@ -40,10 +37,7 @@ class Photo(TrackedModel): # Metadata caption = models.CharField(max_length=255, blank=True, help_text="Photo caption") - is_public = models.BooleanField( - default=True, - help_text="Whether this photo is visible to others" - ) + is_public = models.BooleanField(default=True, help_text="Whether this photo is visible to others") # We might want credit/source info if not taken by user source = models.CharField(max_length=100, blank=True, help_text="Source/Credit if applicable") diff --git a/backend/apps/media/serializers.py b/backend/apps/media/serializers.py index e0bfdf9c..c6691d69 100644 --- a/backend/apps/media/serializers.py +++ b/backend/apps/media/serializers.py @@ -14,6 +14,7 @@ class CloudflareImageSerializer(serializers.ModelSerializer): model = CloudflareImage fields = ["id", "cloudflare_id", "variants"] + class PhotoSerializer(serializers.ModelSerializer): user = UserSerializer(read_only=True) image = CloudflareImageSerializer(read_only=True) @@ -56,10 +57,10 @@ class PhotoSerializer(serializers.ModelSerializer): # Return public variant or default if obj.image: # Check if get_url method exists or we construct strictly - return getattr(obj.image, 'get_url', lambda x: None)('public') + return getattr(obj.image, "get_url", lambda x: None)("public") return None def get_thumbnail(self, obj): if obj.image: - return getattr(obj.image, 'get_url', lambda x: None)('thumbnail') + return getattr(obj.image, "get_url", lambda x: None)("thumbnail") return None diff --git a/backend/apps/moderation/admin.py b/backend/apps/moderation/admin.py index 1ca3ba2a..d0577989 100644 --- a/backend/apps/moderation/admin.py +++ b/backend/apps/moderation/admin.py @@ -51,20 +51,16 @@ class ModerationAdminSite(AdminSite): extra_context = extra_context or {} # Get pending counts - extra_context["pending_edits"] = EditSubmission.objects.filter( - status="PENDING" - ).count() - extra_context["pending_photos"] = PhotoSubmission.objects.filter( - status="PENDING" - ).count() + extra_context["pending_edits"] = EditSubmission.objects.filter(status="PENDING").count() + extra_context["pending_photos"] = PhotoSubmission.objects.filter(status="PENDING").count() # Get recent activity - extra_context["recent_edits"] = EditSubmission.objects.select_related( - "user", "handled_by" - ).order_by("-created_at")[:5] - extra_context["recent_photos"] = PhotoSubmission.objects.select_related( - "user", "handled_by" - ).order_by("-created_at")[:5] + extra_context["recent_edits"] = EditSubmission.objects.select_related("user", "handled_by").order_by( + "-created_at" + )[:5] + extra_context["recent_photos"] = PhotoSubmission.objects.select_related("user", "handled_by").order_by( + "-created_at" + )[:5] return super().index(request, extra_context) @@ -639,9 +635,7 @@ class StateLogAdmin(admin.ModelAdmin): output = StringIO() writer = csv.writer(output) - writer.writerow( - ["ID", "Timestamp", "Model", "Object ID", "State", "Transition", "User"] - ) + writer.writerow(["ID", "Timestamp", "Model", "Object ID", "State", "Transition", "User"]) for log in queryset: writer.writerow( diff --git a/backend/apps/moderation/apps.py b/backend/apps/moderation/apps.py index b4317f14..ad4511b3 100644 --- a/backend/apps/moderation/apps.py +++ b/backend/apps/moderation/apps.py @@ -82,82 +82,31 @@ class ModerationConfig(AppConfig): ) # EditSubmission callbacks (transitions from CLAIMED state) - register_callback( - EditSubmission, 'status', 'CLAIMED', 'APPROVED', - SubmissionApprovedNotification() - ) - register_callback( - EditSubmission, 'status', 'CLAIMED', 'APPROVED', - ModerationCacheInvalidation() - ) - register_callback( - EditSubmission, 'status', 'CLAIMED', 'REJECTED', - SubmissionRejectedNotification() - ) - register_callback( - EditSubmission, 'status', 'CLAIMED', 'REJECTED', - ModerationCacheInvalidation() - ) - register_callback( - EditSubmission, 'status', 'CLAIMED', 'ESCALATED', - SubmissionEscalatedNotification() - ) - register_callback( - EditSubmission, 'status', 'CLAIMED', 'ESCALATED', - ModerationCacheInvalidation() - ) + register_callback(EditSubmission, "status", "CLAIMED", "APPROVED", SubmissionApprovedNotification()) + register_callback(EditSubmission, "status", "CLAIMED", "APPROVED", ModerationCacheInvalidation()) + register_callback(EditSubmission, "status", "CLAIMED", "REJECTED", SubmissionRejectedNotification()) + register_callback(EditSubmission, "status", "CLAIMED", "REJECTED", ModerationCacheInvalidation()) + register_callback(EditSubmission, "status", "CLAIMED", "ESCALATED", SubmissionEscalatedNotification()) + register_callback(EditSubmission, "status", "CLAIMED", "ESCALATED", ModerationCacheInvalidation()) # PhotoSubmission callbacks (transitions from CLAIMED state) - register_callback( - PhotoSubmission, 'status', 'CLAIMED', 'APPROVED', - SubmissionApprovedNotification() - ) - register_callback( - PhotoSubmission, 'status', 'CLAIMED', 'APPROVED', - ModerationCacheInvalidation() - ) - register_callback( - PhotoSubmission, 'status', 'CLAIMED', 'REJECTED', - SubmissionRejectedNotification() - ) - register_callback( - PhotoSubmission, 'status', 'CLAIMED', 'REJECTED', - ModerationCacheInvalidation() - ) - register_callback( - PhotoSubmission, 'status', 'CLAIMED', 'ESCALATED', - SubmissionEscalatedNotification() - ) + register_callback(PhotoSubmission, "status", "CLAIMED", "APPROVED", SubmissionApprovedNotification()) + register_callback(PhotoSubmission, "status", "CLAIMED", "APPROVED", ModerationCacheInvalidation()) + register_callback(PhotoSubmission, "status", "CLAIMED", "REJECTED", SubmissionRejectedNotification()) + register_callback(PhotoSubmission, "status", "CLAIMED", "REJECTED", ModerationCacheInvalidation()) + register_callback(PhotoSubmission, "status", "CLAIMED", "ESCALATED", SubmissionEscalatedNotification()) # ModerationReport callbacks - register_callback( - ModerationReport, 'status', '*', '*', - ModerationNotificationCallback() - ) - register_callback( - ModerationReport, 'status', '*', '*', - ModerationCacheInvalidation() - ) + register_callback(ModerationReport, "status", "*", "*", ModerationNotificationCallback()) + register_callback(ModerationReport, "status", "*", "*", ModerationCacheInvalidation()) # ModerationQueue callbacks - register_callback( - ModerationQueue, 'status', '*', '*', - ModerationNotificationCallback() - ) - register_callback( - ModerationQueue, 'status', '*', '*', - ModerationCacheInvalidation() - ) + register_callback(ModerationQueue, "status", "*", "*", ModerationNotificationCallback()) + register_callback(ModerationQueue, "status", "*", "*", ModerationCacheInvalidation()) # BulkOperation callbacks - register_callback( - BulkOperation, 'status', '*', '*', - ModerationNotificationCallback() - ) - register_callback( - BulkOperation, 'status', '*', '*', - ModerationCacheInvalidation() - ) + register_callback(BulkOperation, "status", "*", "*", ModerationNotificationCallback()) + register_callback(BulkOperation, "status", "*", "*", ModerationCacheInvalidation()) logger.debug("Registered moderation transition callbacks") diff --git a/backend/apps/moderation/choices.py b/backend/apps/moderation/choices.py index 289215b6..bc09a8cc 100644 --- a/backend/apps/moderation/choices.py +++ b/backend/apps/moderation/choices.py @@ -18,80 +18,80 @@ EDIT_SUBMISSION_STATUSES = [ label="Pending", description="Submission awaiting moderator review", metadata={ - 'color': 'yellow', - 'icon': 'clock', - 'css_class': 'bg-yellow-100 text-yellow-800 border-yellow-200', - 'sort_order': 1, - 'can_transition_to': ['CLAIMED'], # Must be claimed before any action - 'requires_moderator': True, - 'is_actionable': True + "color": "yellow", + "icon": "clock", + "css_class": "bg-yellow-100 text-yellow-800 border-yellow-200", + "sort_order": 1, + "can_transition_to": ["CLAIMED"], # Must be claimed before any action + "requires_moderator": True, + "is_actionable": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="CLAIMED", label="Claimed", description="Submission has been claimed by a moderator for review", metadata={ - 'color': 'blue', - 'icon': 'user-check', - 'css_class': 'bg-blue-100 text-blue-800 border-blue-200', - 'sort_order': 2, + "color": "blue", + "icon": "user-check", + "css_class": "bg-blue-100 text-blue-800 border-blue-200", + "sort_order": 2, # Note: PENDING not included to avoid cycle - unclaim uses direct status update - 'can_transition_to': ['APPROVED', 'REJECTED', 'ESCALATED'], - 'requires_moderator': True, - 'is_actionable': True, - 'is_locked': True # Indicates this submission is locked for editing by others + "can_transition_to": ["APPROVED", "REJECTED", "ESCALATED"], + "requires_moderator": True, + "is_actionable": True, + "is_locked": True, # Indicates this submission is locked for editing by others }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="APPROVED", label="Approved", description="Submission has been approved and changes applied", metadata={ - 'color': 'green', - 'icon': 'check-circle', - 'css_class': 'bg-green-100 text-green-800 border-green-200', - 'sort_order': 3, - 'can_transition_to': [], - 'requires_moderator': True, - 'is_actionable': False, - 'is_final': True + "color": "green", + "icon": "check-circle", + "css_class": "bg-green-100 text-green-800 border-green-200", + "sort_order": 3, + "can_transition_to": [], + "requires_moderator": True, + "is_actionable": False, + "is_final": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="REJECTED", label="Rejected", description="Submission has been rejected and will not be applied", metadata={ - 'color': 'red', - 'icon': 'x-circle', - 'css_class': 'bg-red-100 text-red-800 border-red-200', - 'sort_order': 4, - 'can_transition_to': [], - 'requires_moderator': True, - 'is_actionable': False, - 'is_final': True + "color": "red", + "icon": "x-circle", + "css_class": "bg-red-100 text-red-800 border-red-200", + "sort_order": 4, + "can_transition_to": [], + "requires_moderator": True, + "is_actionable": False, + "is_final": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="ESCALATED", label="Escalated", description="Submission has been escalated for higher-level review", metadata={ - 'color': 'purple', - 'icon': 'arrow-up', - 'css_class': 'bg-purple-100 text-purple-800 border-purple-200', - 'sort_order': 5, - 'can_transition_to': ['APPROVED', 'REJECTED'], - 'requires_moderator': True, - 'is_actionable': True, - 'escalation_level': 'admin' + "color": "purple", + "icon": "arrow-up", + "css_class": "bg-purple-100 text-purple-800 border-purple-200", + "sort_order": 5, + "can_transition_to": ["APPROVED", "REJECTED"], + "requires_moderator": True, + "is_actionable": True, + "escalation_level": "admin", }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), ] @@ -101,28 +101,28 @@ SUBMISSION_TYPES = [ label="Edit Existing", description="Modification to existing content", metadata={ - 'color': 'blue', - 'icon': 'pencil', - 'css_class': 'bg-blue-100 text-blue-800 border-blue-200', - 'sort_order': 1, - 'requires_existing_object': True, - 'complexity_level': 'medium' + "color": "blue", + "icon": "pencil", + "css_class": "bg-blue-100 text-blue-800 border-blue-200", + "sort_order": 1, + "requires_existing_object": True, + "complexity_level": "medium", }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="CREATE", label="Create New", description="Creation of new content", metadata={ - 'color': 'green', - 'icon': 'plus-circle', - 'css_class': 'bg-green-100 text-green-800 border-green-200', - 'sort_order': 2, - 'requires_existing_object': False, - 'complexity_level': 'high' + "color": "green", + "icon": "plus-circle", + "css_class": "bg-green-100 text-green-800 border-green-200", + "sort_order": 2, + "requires_existing_object": False, + "complexity_level": "high", }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), ] @@ -136,62 +136,62 @@ MODERATION_REPORT_STATUSES = [ label="Pending Review", description="Report awaiting initial moderator review", metadata={ - 'color': 'yellow', - 'icon': 'clock', - 'css_class': 'bg-yellow-100 text-yellow-800 border-yellow-200', - 'sort_order': 1, - 'can_transition_to': ['UNDER_REVIEW', 'DISMISSED'], - 'requires_assignment': False, - 'is_actionable': True + "color": "yellow", + "icon": "clock", + "css_class": "bg-yellow-100 text-yellow-800 border-yellow-200", + "sort_order": 1, + "can_transition_to": ["UNDER_REVIEW", "DISMISSED"], + "requires_assignment": False, + "is_actionable": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="UNDER_REVIEW", label="Under Review", description="Report is actively being investigated by a moderator", metadata={ - 'color': 'blue', - 'icon': 'eye', - 'css_class': 'bg-blue-100 text-blue-800 border-blue-200', - 'sort_order': 2, - 'can_transition_to': ['RESOLVED', 'DISMISSED'], - 'requires_assignment': True, - 'is_actionable': True + "color": "blue", + "icon": "eye", + "css_class": "bg-blue-100 text-blue-800 border-blue-200", + "sort_order": 2, + "can_transition_to": ["RESOLVED", "DISMISSED"], + "requires_assignment": True, + "is_actionable": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="RESOLVED", label="Resolved", description="Report has been resolved with appropriate action taken", metadata={ - 'color': 'green', - 'icon': 'check-circle', - 'css_class': 'bg-green-100 text-green-800 border-green-200', - 'sort_order': 3, - 'can_transition_to': [], - 'requires_assignment': True, - 'is_actionable': False, - 'is_final': True + "color": "green", + "icon": "check-circle", + "css_class": "bg-green-100 text-green-800 border-green-200", + "sort_order": 3, + "can_transition_to": [], + "requires_assignment": True, + "is_actionable": False, + "is_final": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="DISMISSED", label="Dismissed", description="Report was reviewed but no action was necessary", metadata={ - 'color': 'gray', - 'icon': 'x-circle', - 'css_class': 'bg-gray-100 text-gray-800 border-gray-200', - 'sort_order': 4, - 'can_transition_to': [], - 'requires_assignment': True, - 'is_actionable': False, - 'is_final': True + "color": "gray", + "icon": "x-circle", + "css_class": "bg-gray-100 text-gray-800 border-gray-200", + "sort_order": 4, + "can_transition_to": [], + "requires_assignment": True, + "is_actionable": False, + "is_final": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), ] @@ -201,61 +201,61 @@ PRIORITY_LEVELS = [ label="Low", description="Low priority - can be handled in regular workflow", metadata={ - 'color': 'green', - 'icon': 'arrow-down', - 'css_class': 'bg-green-100 text-green-800 border-green-200', - 'sort_order': 1, - 'sla_hours': 168, # 7 days - 'escalation_threshold': 240, # 10 days - 'urgency_level': 1 + "color": "green", + "icon": "arrow-down", + "css_class": "bg-green-100 text-green-800 border-green-200", + "sort_order": 1, + "sla_hours": 168, # 7 days + "escalation_threshold": 240, # 10 days + "urgency_level": 1, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="MEDIUM", label="Medium", description="Medium priority - standard response time expected", metadata={ - 'color': 'yellow', - 'icon': 'minus', - 'css_class': 'bg-yellow-100 text-yellow-800 border-yellow-200', - 'sort_order': 2, - 'sla_hours': 72, # 3 days - 'escalation_threshold': 120, # 5 days - 'urgency_level': 2 + "color": "yellow", + "icon": "minus", + "css_class": "bg-yellow-100 text-yellow-800 border-yellow-200", + "sort_order": 2, + "sla_hours": 72, # 3 days + "escalation_threshold": 120, # 5 days + "urgency_level": 2, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="HIGH", label="High", description="High priority - requires prompt attention", metadata={ - 'color': 'orange', - 'icon': 'arrow-up', - 'css_class': 'bg-orange-100 text-orange-800 border-orange-200', - 'sort_order': 3, - 'sla_hours': 24, # 1 day - 'escalation_threshold': 48, # 2 days - 'urgency_level': 3 + "color": "orange", + "icon": "arrow-up", + "css_class": "bg-orange-100 text-orange-800 border-orange-200", + "sort_order": 3, + "sla_hours": 24, # 1 day + "escalation_threshold": 48, # 2 days + "urgency_level": 3, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="URGENT", label="Urgent", description="Urgent priority - immediate attention required", metadata={ - 'color': 'red', - 'icon': 'exclamation', - 'css_class': 'bg-red-100 text-red-800 border-red-200', - 'sort_order': 4, - 'sla_hours': 4, # 4 hours - 'escalation_threshold': 8, # 8 hours - 'urgency_level': 4, - 'requires_immediate_notification': True + "color": "red", + "icon": "exclamation", + "css_class": "bg-red-100 text-red-800 border-red-200", + "sort_order": 4, + "sla_hours": 4, # 4 hours + "escalation_threshold": 8, # 8 hours + "urgency_level": 4, + "requires_immediate_notification": True, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), ] @@ -265,145 +265,145 @@ REPORT_TYPES = [ label="Spam", description="Unwanted or repetitive content", metadata={ - 'color': 'yellow', - 'icon': 'ban', - 'css_class': 'bg-yellow-100 text-yellow-800 border-yellow-200', - 'sort_order': 1, - 'default_priority': 'MEDIUM', - 'auto_actions': ['content_review'], - 'severity_level': 2 + "color": "yellow", + "icon": "ban", + "css_class": "bg-yellow-100 text-yellow-800 border-yellow-200", + "sort_order": 1, + "default_priority": "MEDIUM", + "auto_actions": ["content_review"], + "severity_level": 2, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="HARASSMENT", label="Harassment", description="Targeted harassment or bullying behavior", metadata={ - 'color': 'red', - 'icon': 'shield-exclamation', - 'css_class': 'bg-red-100 text-red-800 border-red-200', - 'sort_order': 2, - 'default_priority': 'HIGH', - 'auto_actions': ['user_review', 'content_review'], - 'severity_level': 4, - 'requires_user_action': True + "color": "red", + "icon": "shield-exclamation", + "css_class": "bg-red-100 text-red-800 border-red-200", + "sort_order": 2, + "default_priority": "HIGH", + "auto_actions": ["user_review", "content_review"], + "severity_level": 4, + "requires_user_action": True, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="INAPPROPRIATE_CONTENT", label="Inappropriate Content", description="Content that violates community guidelines", metadata={ - 'color': 'orange', - 'icon': 'exclamation-triangle', - 'css_class': 'bg-orange-100 text-orange-800 border-orange-200', - 'sort_order': 3, - 'default_priority': 'HIGH', - 'auto_actions': ['content_review'], - 'severity_level': 3 + "color": "orange", + "icon": "exclamation-triangle", + "css_class": "bg-orange-100 text-orange-800 border-orange-200", + "sort_order": 3, + "default_priority": "HIGH", + "auto_actions": ["content_review"], + "severity_level": 3, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="MISINFORMATION", label="Misinformation", description="False or misleading information", metadata={ - 'color': 'purple', - 'icon': 'information-circle', - 'css_class': 'bg-purple-100 text-purple-800 border-purple-200', - 'sort_order': 4, - 'default_priority': 'HIGH', - 'auto_actions': ['content_review', 'fact_check'], - 'severity_level': 3, - 'requires_expert_review': True + "color": "purple", + "icon": "information-circle", + "css_class": "bg-purple-100 text-purple-800 border-purple-200", + "sort_order": 4, + "default_priority": "HIGH", + "auto_actions": ["content_review", "fact_check"], + "severity_level": 3, + "requires_expert_review": True, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="COPYRIGHT", label="Copyright Violation", description="Unauthorized use of copyrighted material", metadata={ - 'color': 'indigo', - 'icon': 'document-duplicate', - 'css_class': 'bg-indigo-100 text-indigo-800 border-indigo-200', - 'sort_order': 5, - 'default_priority': 'HIGH', - 'auto_actions': ['content_review', 'legal_review'], - 'severity_level': 4, - 'requires_legal_review': True + "color": "indigo", + "icon": "document-duplicate", + "css_class": "bg-indigo-100 text-indigo-800 border-indigo-200", + "sort_order": 5, + "default_priority": "HIGH", + "auto_actions": ["content_review", "legal_review"], + "severity_level": 4, + "requires_legal_review": True, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="PRIVACY", label="Privacy Violation", description="Unauthorized sharing of private information", metadata={ - 'color': 'pink', - 'icon': 'lock-closed', - 'css_class': 'bg-pink-100 text-pink-800 border-pink-200', - 'sort_order': 6, - 'default_priority': 'URGENT', - 'auto_actions': ['content_removal', 'user_review'], - 'severity_level': 5, - 'requires_immediate_action': True + "color": "pink", + "icon": "lock-closed", + "css_class": "bg-pink-100 text-pink-800 border-pink-200", + "sort_order": 6, + "default_priority": "URGENT", + "auto_actions": ["content_removal", "user_review"], + "severity_level": 5, + "requires_immediate_action": True, }, - category=ChoiceCategory.SECURITY + category=ChoiceCategory.SECURITY, ), RichChoice( value="HATE_SPEECH", label="Hate Speech", description="Content promoting hatred or discrimination", metadata={ - 'color': 'red', - 'icon': 'fire', - 'css_class': 'bg-red-100 text-red-800 border-red-200', - 'sort_order': 7, - 'default_priority': 'URGENT', - 'auto_actions': ['content_removal', 'user_suspension'], - 'severity_level': 5, - 'requires_immediate_action': True, - 'zero_tolerance': True + "color": "red", + "icon": "fire", + "css_class": "bg-red-100 text-red-800 border-red-200", + "sort_order": 7, + "default_priority": "URGENT", + "auto_actions": ["content_removal", "user_suspension"], + "severity_level": 5, + "requires_immediate_action": True, + "zero_tolerance": True, }, - category=ChoiceCategory.SECURITY + category=ChoiceCategory.SECURITY, ), RichChoice( value="VIOLENCE", label="Violence or Threats", description="Content containing violence or threatening behavior", metadata={ - 'color': 'red', - 'icon': 'exclamation', - 'css_class': 'bg-red-100 text-red-800 border-red-200', - 'sort_order': 8, - 'default_priority': 'URGENT', - 'auto_actions': ['content_removal', 'user_ban', 'law_enforcement_notification'], - 'severity_level': 5, - 'requires_immediate_action': True, - 'zero_tolerance': True, - 'requires_law_enforcement': True + "color": "red", + "icon": "exclamation", + "css_class": "bg-red-100 text-red-800 border-red-200", + "sort_order": 8, + "default_priority": "URGENT", + "auto_actions": ["content_removal", "user_ban", "law_enforcement_notification"], + "severity_level": 5, + "requires_immediate_action": True, + "zero_tolerance": True, + "requires_law_enforcement": True, }, - category=ChoiceCategory.SECURITY + category=ChoiceCategory.SECURITY, ), RichChoice( value="OTHER", label="Other", description="Other issues not covered by specific categories", metadata={ - 'color': 'gray', - 'icon': 'dots-horizontal', - 'css_class': 'bg-gray-100 text-gray-800 border-gray-200', - 'sort_order': 9, - 'default_priority': 'MEDIUM', - 'auto_actions': ['manual_review'], - 'severity_level': 1, - 'requires_manual_categorization': True + "color": "gray", + "icon": "dots-horizontal", + "css_class": "bg-gray-100 text-gray-800 border-gray-200", + "sort_order": 9, + "default_priority": "MEDIUM", + "auto_actions": ["manual_review"], + "severity_level": 1, + "requires_manual_categorization": True, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), ] @@ -417,62 +417,62 @@ MODERATION_QUEUE_STATUSES = [ label="Pending", description="Queue item awaiting assignment or action", metadata={ - 'color': 'yellow', - 'icon': 'clock', - 'css_class': 'bg-yellow-100 text-yellow-800 border-yellow-200', - 'sort_order': 1, - 'can_transition_to': ['IN_PROGRESS', 'CANCELLED'], - 'requires_assignment': False, - 'is_actionable': True + "color": "yellow", + "icon": "clock", + "css_class": "bg-yellow-100 text-yellow-800 border-yellow-200", + "sort_order": 1, + "can_transition_to": ["IN_PROGRESS", "CANCELLED"], + "requires_assignment": False, + "is_actionable": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="IN_PROGRESS", label="In Progress", description="Queue item is actively being worked on", metadata={ - 'color': 'blue', - 'icon': 'play', - 'css_class': 'bg-blue-100 text-blue-800 border-blue-200', - 'sort_order': 2, - 'can_transition_to': ['COMPLETED', 'CANCELLED'], - 'requires_assignment': True, - 'is_actionable': True + "color": "blue", + "icon": "play", + "css_class": "bg-blue-100 text-blue-800 border-blue-200", + "sort_order": 2, + "can_transition_to": ["COMPLETED", "CANCELLED"], + "requires_assignment": True, + "is_actionable": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="COMPLETED", label="Completed", description="Queue item has been successfully completed", metadata={ - 'color': 'green', - 'icon': 'check-circle', - 'css_class': 'bg-green-100 text-green-800 border-green-200', - 'sort_order': 3, - 'can_transition_to': [], - 'requires_assignment': True, - 'is_actionable': False, - 'is_final': True + "color": "green", + "icon": "check-circle", + "css_class": "bg-green-100 text-green-800 border-green-200", + "sort_order": 3, + "can_transition_to": [], + "requires_assignment": True, + "is_actionable": False, + "is_final": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="CANCELLED", label="Cancelled", description="Queue item was cancelled and will not be completed", metadata={ - 'color': 'gray', - 'icon': 'x-circle', - 'css_class': 'bg-gray-100 text-gray-800 border-gray-200', - 'sort_order': 4, - 'can_transition_to': [], - 'requires_assignment': False, - 'is_actionable': False, - 'is_final': True + "color": "gray", + "icon": "x-circle", + "css_class": "bg-gray-100 text-gray-800 border-gray-200", + "sort_order": 4, + "can_transition_to": [], + "requires_assignment": False, + "is_actionable": False, + "is_final": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), ] @@ -482,90 +482,90 @@ QUEUE_ITEM_TYPES = [ label="Content Review", description="Review of user-submitted content for policy compliance", metadata={ - 'color': 'blue', - 'icon': 'document-text', - 'css_class': 'bg-blue-100 text-blue-800 border-blue-200', - 'sort_order': 1, - 'estimated_time_minutes': 15, - 'required_permissions': ['content_moderation'], - 'complexity_level': 'medium' + "color": "blue", + "icon": "document-text", + "css_class": "bg-blue-100 text-blue-800 border-blue-200", + "sort_order": 1, + "estimated_time_minutes": 15, + "required_permissions": ["content_moderation"], + "complexity_level": "medium", }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="USER_REVIEW", label="User Review", description="Review of user account or behavior", metadata={ - 'color': 'purple', - 'icon': 'user', - 'css_class': 'bg-purple-100 text-purple-800 border-purple-200', - 'sort_order': 2, - 'estimated_time_minutes': 30, - 'required_permissions': ['user_moderation'], - 'complexity_level': 'high' + "color": "purple", + "icon": "user", + "css_class": "bg-purple-100 text-purple-800 border-purple-200", + "sort_order": 2, + "estimated_time_minutes": 30, + "required_permissions": ["user_moderation"], + "complexity_level": "high", }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="BULK_ACTION", label="Bulk Action", description="Large-scale administrative operation", metadata={ - 'color': 'indigo', - 'icon': 'collection', - 'css_class': 'bg-indigo-100 text-indigo-800 border-indigo-200', - 'sort_order': 3, - 'estimated_time_minutes': 60, - 'required_permissions': ['bulk_operations'], - 'complexity_level': 'high' + "color": "indigo", + "icon": "collection", + "css_class": "bg-indigo-100 text-indigo-800 border-indigo-200", + "sort_order": 3, + "estimated_time_minutes": 60, + "required_permissions": ["bulk_operations"], + "complexity_level": "high", }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="POLICY_VIOLATION", label="Policy Violation", description="Investigation of potential policy violations", metadata={ - 'color': 'red', - 'icon': 'shield-exclamation', - 'css_class': 'bg-red-100 text-red-800 border-red-200', - 'sort_order': 4, - 'estimated_time_minutes': 45, - 'required_permissions': ['policy_enforcement'], - 'complexity_level': 'high' + "color": "red", + "icon": "shield-exclamation", + "css_class": "bg-red-100 text-red-800 border-red-200", + "sort_order": 4, + "estimated_time_minutes": 45, + "required_permissions": ["policy_enforcement"], + "complexity_level": "high", }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="APPEAL", label="Appeal", description="Review of user appeal against moderation action", metadata={ - 'color': 'orange', - 'icon': 'scale', - 'css_class': 'bg-orange-100 text-orange-800 border-orange-200', - 'sort_order': 5, - 'estimated_time_minutes': 30, - 'required_permissions': ['appeal_review'], - 'complexity_level': 'high' + "color": "orange", + "icon": "scale", + "css_class": "bg-orange-100 text-orange-800 border-orange-200", + "sort_order": 5, + "estimated_time_minutes": 30, + "required_permissions": ["appeal_review"], + "complexity_level": "high", }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="OTHER", label="Other", description="Other moderation tasks not covered by specific types", metadata={ - 'color': 'gray', - 'icon': 'dots-horizontal', - 'css_class': 'bg-gray-100 text-gray-800 border-gray-200', - 'sort_order': 6, - 'estimated_time_minutes': 20, - 'required_permissions': ['general_moderation'], - 'complexity_level': 'medium' + "color": "gray", + "icon": "dots-horizontal", + "css_class": "bg-gray-100 text-gray-800 border-gray-200", + "sort_order": 6, + "estimated_time_minutes": 20, + "required_permissions": ["general_moderation"], + "complexity_level": "medium", }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), ] @@ -579,133 +579,133 @@ MODERATION_ACTION_TYPES = [ label="Warning", description="Formal warning issued to user", metadata={ - 'color': 'yellow', - 'icon': 'exclamation-triangle', - 'css_class': 'bg-yellow-100 text-yellow-800 border-yellow-200', - 'sort_order': 1, - 'severity_level': 1, - 'is_temporary': False, - 'affects_privileges': False, - 'escalation_path': ['USER_SUSPENSION'] + "color": "yellow", + "icon": "exclamation-triangle", + "css_class": "bg-yellow-100 text-yellow-800 border-yellow-200", + "sort_order": 1, + "severity_level": 1, + "is_temporary": False, + "affects_privileges": False, + "escalation_path": ["USER_SUSPENSION"], }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="USER_SUSPENSION", label="User Suspension", description="Temporary suspension of user account", metadata={ - 'color': 'orange', - 'icon': 'pause', - 'css_class': 'bg-orange-100 text-orange-800 border-orange-200', - 'sort_order': 2, - 'severity_level': 3, - 'is_temporary': True, - 'affects_privileges': True, - 'requires_duration': True, - 'escalation_path': ['USER_BAN'] + "color": "orange", + "icon": "pause", + "css_class": "bg-orange-100 text-orange-800 border-orange-200", + "sort_order": 2, + "severity_level": 3, + "is_temporary": True, + "affects_privileges": True, + "requires_duration": True, + "escalation_path": ["USER_BAN"], }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="USER_BAN", label="User Ban", description="Permanent ban of user account", metadata={ - 'color': 'red', - 'icon': 'ban', - 'css_class': 'bg-red-100 text-red-800 border-red-200', - 'sort_order': 3, - 'severity_level': 5, - 'is_temporary': False, - 'affects_privileges': True, - 'is_permanent': True, - 'requires_admin_approval': True + "color": "red", + "icon": "ban", + "css_class": "bg-red-100 text-red-800 border-red-200", + "sort_order": 3, + "severity_level": 5, + "is_temporary": False, + "affects_privileges": True, + "is_permanent": True, + "requires_admin_approval": True, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="CONTENT_REMOVAL", label="Content Removal", description="Removal of specific content", metadata={ - 'color': 'red', - 'icon': 'trash', - 'css_class': 'bg-red-100 text-red-800 border-red-200', - 'sort_order': 4, - 'severity_level': 2, - 'is_temporary': False, - 'affects_privileges': False, - 'is_content_action': True + "color": "red", + "icon": "trash", + "css_class": "bg-red-100 text-red-800 border-red-200", + "sort_order": 4, + "severity_level": 2, + "is_temporary": False, + "affects_privileges": False, + "is_content_action": True, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="CONTENT_EDIT", label="Content Edit", description="Modification of content to comply with policies", metadata={ - 'color': 'blue', - 'icon': 'pencil', - 'css_class': 'bg-blue-100 text-blue-800 border-blue-200', - 'sort_order': 5, - 'severity_level': 1, - 'is_temporary': False, - 'affects_privileges': False, - 'is_content_action': True, - 'preserves_content': True + "color": "blue", + "icon": "pencil", + "css_class": "bg-blue-100 text-blue-800 border-blue-200", + "sort_order": 5, + "severity_level": 1, + "is_temporary": False, + "affects_privileges": False, + "is_content_action": True, + "preserves_content": True, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="CONTENT_RESTRICTION", label="Content Restriction", description="Restriction of content visibility or access", metadata={ - 'color': 'purple', - 'icon': 'eye-off', - 'css_class': 'bg-purple-100 text-purple-800 border-purple-200', - 'sort_order': 6, - 'severity_level': 2, - 'is_temporary': True, - 'affects_privileges': False, - 'is_content_action': True, - 'requires_duration': True + "color": "purple", + "icon": "eye-off", + "css_class": "bg-purple-100 text-purple-800 border-purple-200", + "sort_order": 6, + "severity_level": 2, + "is_temporary": True, + "affects_privileges": False, + "is_content_action": True, + "requires_duration": True, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="ACCOUNT_RESTRICTION", label="Account Restriction", description="Restriction of specific account privileges", metadata={ - 'color': 'indigo', - 'icon': 'lock-closed', - 'css_class': 'bg-indigo-100 text-indigo-800 border-indigo-200', - 'sort_order': 7, - 'severity_level': 3, - 'is_temporary': True, - 'affects_privileges': True, - 'requires_duration': True, - 'escalation_path': ['USER_SUSPENSION'] + "color": "indigo", + "icon": "lock-closed", + "css_class": "bg-indigo-100 text-indigo-800 border-indigo-200", + "sort_order": 7, + "severity_level": 3, + "is_temporary": True, + "affects_privileges": True, + "requires_duration": True, + "escalation_path": ["USER_SUSPENSION"], }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="OTHER", label="Other", description="Other moderation actions not covered by specific types", metadata={ - 'color': 'gray', - 'icon': 'dots-horizontal', - 'css_class': 'bg-gray-100 text-gray-800 border-gray-200', - 'sort_order': 8, - 'severity_level': 1, - 'is_temporary': False, - 'affects_privileges': False, - 'requires_manual_review': True + "color": "gray", + "icon": "dots-horizontal", + "css_class": "bg-gray-100 text-gray-800 border-gray-200", + "sort_order": 8, + "severity_level": 1, + "is_temporary": False, + "affects_privileges": False, + "requires_manual_review": True, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), ] @@ -719,80 +719,80 @@ BULK_OPERATION_STATUSES = [ label="Pending", description="Operation is queued and waiting to start", metadata={ - 'color': 'yellow', - 'icon': 'clock', - 'css_class': 'bg-yellow-100 text-yellow-800 border-yellow-200', - 'sort_order': 1, - 'can_transition_to': ['RUNNING', 'CANCELLED'], - 'is_actionable': True, - 'can_cancel': True + "color": "yellow", + "icon": "clock", + "css_class": "bg-yellow-100 text-yellow-800 border-yellow-200", + "sort_order": 1, + "can_transition_to": ["RUNNING", "CANCELLED"], + "is_actionable": True, + "can_cancel": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="RUNNING", label="Running", description="Operation is currently executing", metadata={ - 'color': 'blue', - 'icon': 'play', - 'css_class': 'bg-blue-100 text-blue-800 border-blue-200', - 'sort_order': 2, - 'can_transition_to': ['COMPLETED', 'FAILED', 'CANCELLED'], - 'is_actionable': True, - 'can_cancel': True, - 'shows_progress': True + "color": "blue", + "icon": "play", + "css_class": "bg-blue-100 text-blue-800 border-blue-200", + "sort_order": 2, + "can_transition_to": ["COMPLETED", "FAILED", "CANCELLED"], + "is_actionable": True, + "can_cancel": True, + "shows_progress": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="COMPLETED", label="Completed", description="Operation completed successfully", metadata={ - 'color': 'green', - 'icon': 'check-circle', - 'css_class': 'bg-green-100 text-green-800 border-green-200', - 'sort_order': 3, - 'can_transition_to': [], - 'is_actionable': False, - 'can_cancel': False, - 'is_final': True + "color": "green", + "icon": "check-circle", + "css_class": "bg-green-100 text-green-800 border-green-200", + "sort_order": 3, + "can_transition_to": [], + "is_actionable": False, + "can_cancel": False, + "is_final": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="FAILED", label="Failed", description="Operation failed with errors", metadata={ - 'color': 'red', - 'icon': 'x-circle', - 'css_class': 'bg-red-100 text-red-800 border-red-200', - 'sort_order': 4, - 'can_transition_to': [], - 'is_actionable': False, - 'can_cancel': False, - 'is_final': True, - 'requires_investigation': True + "color": "red", + "icon": "x-circle", + "css_class": "bg-red-100 text-red-800 border-red-200", + "sort_order": 4, + "can_transition_to": [], + "is_actionable": False, + "can_cancel": False, + "is_final": True, + "requires_investigation": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="CANCELLED", label="Cancelled", description="Operation was cancelled before completion", metadata={ - 'color': 'gray', - 'icon': 'stop', - 'css_class': 'bg-gray-100 text-gray-800 border-gray-200', - 'sort_order': 5, - 'can_transition_to': [], - 'is_actionable': False, - 'can_cancel': False, - 'is_final': True + "color": "gray", + "icon": "stop", + "css_class": "bg-gray-100 text-gray-800 border-gray-200", + "sort_order": 5, + "can_transition_to": [], + "is_actionable": False, + "can_cancel": False, + "is_final": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), ] @@ -802,128 +802,128 @@ BULK_OPERATION_TYPES = [ label="Update Parks", description="Bulk update operations on park data", metadata={ - 'color': 'green', - 'icon': 'map', - 'css_class': 'bg-green-100 text-green-800 border-green-200', - 'sort_order': 1, - 'estimated_duration_minutes': 30, - 'required_permissions': ['bulk_park_operations'], - 'affects_data': ['parks'], - 'risk_level': 'medium' + "color": "green", + "icon": "map", + "css_class": "bg-green-100 text-green-800 border-green-200", + "sort_order": 1, + "estimated_duration_minutes": 30, + "required_permissions": ["bulk_park_operations"], + "affects_data": ["parks"], + "risk_level": "medium", }, - category=ChoiceCategory.TECHNICAL + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="UPDATE_RIDES", label="Update Rides", description="Bulk update operations on ride data", metadata={ - 'color': 'blue', - 'icon': 'cog', - 'css_class': 'bg-blue-100 text-blue-800 border-blue-200', - 'sort_order': 2, - 'estimated_duration_minutes': 45, - 'required_permissions': ['bulk_ride_operations'], - 'affects_data': ['rides'], - 'risk_level': 'medium' + "color": "blue", + "icon": "cog", + "css_class": "bg-blue-100 text-blue-800 border-blue-200", + "sort_order": 2, + "estimated_duration_minutes": 45, + "required_permissions": ["bulk_ride_operations"], + "affects_data": ["rides"], + "risk_level": "medium", }, - category=ChoiceCategory.TECHNICAL + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="IMPORT_DATA", label="Import Data", description="Import data from external sources", metadata={ - 'color': 'purple', - 'icon': 'download', - 'css_class': 'bg-purple-100 text-purple-800 border-purple-200', - 'sort_order': 3, - 'estimated_duration_minutes': 60, - 'required_permissions': ['data_import'], - 'affects_data': ['parks', 'rides', 'users'], - 'risk_level': 'high' + "color": "purple", + "icon": "download", + "css_class": "bg-purple-100 text-purple-800 border-purple-200", + "sort_order": 3, + "estimated_duration_minutes": 60, + "required_permissions": ["data_import"], + "affects_data": ["parks", "rides", "users"], + "risk_level": "high", }, - category=ChoiceCategory.TECHNICAL + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="EXPORT_DATA", label="Export Data", description="Export data for backup or analysis", metadata={ - 'color': 'indigo', - 'icon': 'upload', - 'css_class': 'bg-indigo-100 text-indigo-800 border-indigo-200', - 'sort_order': 4, - 'estimated_duration_minutes': 20, - 'required_permissions': ['data_export'], - 'affects_data': [], - 'risk_level': 'low' + "color": "indigo", + "icon": "upload", + "css_class": "bg-indigo-100 text-indigo-800 border-indigo-200", + "sort_order": 4, + "estimated_duration_minutes": 20, + "required_permissions": ["data_export"], + "affects_data": [], + "risk_level": "low", }, - category=ChoiceCategory.TECHNICAL + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="MODERATE_CONTENT", label="Moderate Content", description="Bulk moderation actions on content", metadata={ - 'color': 'orange', - 'icon': 'shield-check', - 'css_class': 'bg-orange-100 text-orange-800 border-orange-200', - 'sort_order': 5, - 'estimated_duration_minutes': 40, - 'required_permissions': ['bulk_moderation'], - 'affects_data': ['content', 'users'], - 'risk_level': 'high' + "color": "orange", + "icon": "shield-check", + "css_class": "bg-orange-100 text-orange-800 border-orange-200", + "sort_order": 5, + "estimated_duration_minutes": 40, + "required_permissions": ["bulk_moderation"], + "affects_data": ["content", "users"], + "risk_level": "high", }, - category=ChoiceCategory.TECHNICAL + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="USER_ACTIONS", label="User Actions", description="Bulk actions on user accounts", metadata={ - 'color': 'red', - 'icon': 'users', - 'css_class': 'bg-red-100 text-red-800 border-red-200', - 'sort_order': 6, - 'estimated_duration_minutes': 50, - 'required_permissions': ['bulk_user_operations'], - 'affects_data': ['users'], - 'risk_level': 'high' + "color": "red", + "icon": "users", + "css_class": "bg-red-100 text-red-800 border-red-200", + "sort_order": 6, + "estimated_duration_minutes": 50, + "required_permissions": ["bulk_user_operations"], + "affects_data": ["users"], + "risk_level": "high", }, - category=ChoiceCategory.TECHNICAL + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="CLEANUP", label="Cleanup", description="System cleanup and maintenance operations", metadata={ - 'color': 'gray', - 'icon': 'trash', - 'css_class': 'bg-gray-100 text-gray-800 border-gray-200', - 'sort_order': 7, - 'estimated_duration_minutes': 25, - 'required_permissions': ['system_maintenance'], - 'affects_data': ['system'], - 'risk_level': 'low' + "color": "gray", + "icon": "trash", + "css_class": "bg-gray-100 text-gray-800 border-gray-200", + "sort_order": 7, + "estimated_duration_minutes": 25, + "required_permissions": ["system_maintenance"], + "affects_data": ["system"], + "risk_level": "low", }, - category=ChoiceCategory.TECHNICAL + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="OTHER", label="Other", description="Other bulk operations not covered by specific types", metadata={ - 'color': 'gray', - 'icon': 'dots-horizontal', - 'css_class': 'bg-gray-100 text-gray-800 border-gray-200', - 'sort_order': 8, - 'estimated_duration_minutes': 30, - 'required_permissions': ['general_operations'], - 'affects_data': [], - 'risk_level': 'medium' + "color": "gray", + "icon": "dots-horizontal", + "css_class": "bg-gray-100 text-gray-800 border-gray-200", + "sort_order": 8, + "estimated_duration_minutes": 30, + "required_permissions": ["general_operations"], + "affects_data": [], + "risk_level": "medium", }, - category=ChoiceCategory.TECHNICAL + category=ChoiceCategory.TECHNICAL, ), ] @@ -941,12 +941,20 @@ PHOTO_SUBMISSION_STATUSES = EDIT_SUBMISSION_STATUSES # Register all choice groups with the global registry register_choices("edit_submission_statuses", EDIT_SUBMISSION_STATUSES, "moderation", "Edit submission status options") register_choices("submission_types", SUBMISSION_TYPES, "moderation", "Submission type classifications") -register_choices("moderation_report_statuses", MODERATION_REPORT_STATUSES, "moderation", "Moderation report status options") +register_choices( + "moderation_report_statuses", MODERATION_REPORT_STATUSES, "moderation", "Moderation report status options" +) register_choices("priority_levels", PRIORITY_LEVELS, "moderation", "Priority level classifications") register_choices("report_types", REPORT_TYPES, "moderation", "Report type classifications") -register_choices("moderation_queue_statuses", MODERATION_QUEUE_STATUSES, "moderation", "Moderation queue status options") +register_choices( + "moderation_queue_statuses", MODERATION_QUEUE_STATUSES, "moderation", "Moderation queue status options" +) register_choices("queue_item_types", QUEUE_ITEM_TYPES, "moderation", "Queue item type classifications") -register_choices("moderation_action_types", MODERATION_ACTION_TYPES, "moderation", "Moderation action type classifications") +register_choices( + "moderation_action_types", MODERATION_ACTION_TYPES, "moderation", "Moderation action type classifications" +) register_choices("bulk_operation_statuses", BULK_OPERATION_STATUSES, "moderation", "Bulk operation status options") register_choices("bulk_operation_types", BULK_OPERATION_TYPES, "moderation", "Bulk operation type classifications") -register_choices("photo_submission_statuses", PHOTO_SUBMISSION_STATUSES, "moderation", "Photo submission status options") +register_choices( + "photo_submission_statuses", PHOTO_SUBMISSION_STATUSES, "moderation", "Photo submission status options" +) diff --git a/backend/apps/moderation/context_processors.py b/backend/apps/moderation/context_processors.py index 5d5d99a9..327caf42 100644 --- a/backend/apps/moderation/context_processors.py +++ b/backend/apps/moderation/context_processors.py @@ -11,14 +11,9 @@ def moderation_access(request): context["user_role"] = request.user.role # Check both role-based and Django's built-in superuser status context["has_moderation_access"] = ( - request.user.role in ["MODERATOR", "ADMIN", "SUPERUSER"] - or request.user.is_superuser - ) - context["has_admin_access"] = ( - request.user.role in ["ADMIN", "SUPERUSER"] or request.user.is_superuser - ) - context["has_superuser_access"] = ( - request.user.role == "SUPERUSER" or request.user.is_superuser + request.user.role in ["MODERATOR", "ADMIN", "SUPERUSER"] or request.user.is_superuser ) + context["has_admin_access"] = request.user.role in ["ADMIN", "SUPERUSER"] or request.user.is_superuser + context["has_superuser_access"] = request.user.role == "SUPERUSER" or request.user.is_superuser return context diff --git a/backend/apps/moderation/filters.py b/backend/apps/moderation/filters.py index cb297038..d60e282d 100644 --- a/backend/apps/moderation/filters.py +++ b/backend/apps/moderation/filters.py @@ -29,20 +29,22 @@ class ModerationReportFilter(django_filters.FilterSet): # Status filters status = django_filters.ChoiceFilter( - choices=lambda: [(choice.value, choice.label) for choice in get_choices("moderation_report_statuses", "moderation")], - help_text="Filter by report status" + choices=lambda: [ + (choice.value, choice.label) for choice in get_choices("moderation_report_statuses", "moderation") + ], + help_text="Filter by report status", ) # Priority filters priority = django_filters.ChoiceFilter( choices=lambda: [(choice.value, choice.label) for choice in get_choices("priority_levels", "moderation")], - help_text="Filter by report priority" + help_text="Filter by report priority", ) # Report type filters report_type = django_filters.ChoiceFilter( choices=lambda: [(choice.value, choice.label) for choice in get_choices("report_types", "moderation")], - help_text="Filter by report type" + help_text="Filter by report type", ) # User filters @@ -87,13 +89,9 @@ class ModerationReportFilter(django_filters.FilterSet): ) # Special filters - unassigned = django_filters.BooleanFilter( - method="filter_unassigned", help_text="Filter for unassigned reports" - ) + unassigned = django_filters.BooleanFilter(method="filter_unassigned", help_text="Filter for unassigned reports") - overdue = django_filters.BooleanFilter( - method="filter_overdue", help_text="Filter for overdue reports based on SLA" - ) + overdue = django_filters.BooleanFilter(method="filter_overdue", help_text="Filter for overdue reports based on SLA") has_resolution = django_filters.BooleanFilter( method="filter_has_resolution", @@ -143,12 +141,8 @@ class ModerationReportFilter(django_filters.FilterSet): def filter_has_resolution(self, queryset, name, value): """Filter reports with/without resolution.""" if value: - return queryset.exclude( - resolution_action__isnull=True, resolution_action="" - ) - return queryset.filter( - Q(resolution_action__isnull=True) | Q(resolution_action="") - ) + return queryset.exclude(resolution_action__isnull=True, resolution_action="") + return queryset.filter(Q(resolution_action__isnull=True) | Q(resolution_action="")) class ModerationQueueFilter(django_filters.FilterSet): @@ -156,8 +150,10 @@ class ModerationQueueFilter(django_filters.FilterSet): # Status filters status = django_filters.ChoiceFilter( - choices=lambda: [(choice.value, choice.label) for choice in get_choices("moderation_queue_statuses", "moderation")], - help_text="Filter by queue item status" + choices=lambda: [ + (choice.value, choice.label) for choice in get_choices("moderation_queue_statuses", "moderation") + ], + help_text="Filter by queue item status", ) # Priority filters @@ -169,7 +165,7 @@ class ModerationQueueFilter(django_filters.FilterSet): # Item type filters item_type = django_filters.ChoiceFilter( choices=lambda: [(choice.value, choice.label) for choice in get_choices("queue_item_types", "moderation")], - help_text="Filter by queue item type" + help_text="Filter by queue item type", ) # Assignment filters @@ -178,9 +174,7 @@ class ModerationQueueFilter(django_filters.FilterSet): help_text="Filter by assigned moderator", ) - unassigned = django_filters.BooleanFilter( - method="filter_unassigned", help_text="Filter for unassigned queue items" - ) + unassigned = django_filters.BooleanFilter(method="filter_unassigned", help_text="Filter for unassigned queue items") # Date filters created_after = django_filters.DateTimeFilter( @@ -208,9 +202,7 @@ class ModerationQueueFilter(django_filters.FilterSet): ) # Content type filters - content_type = django_filters.CharFilter( - field_name="content_type__model", help_text="Filter by content type" - ) + content_type = django_filters.CharFilter(field_name="content_type__model", help_text="Filter by content type") # Related report filters has_related_report = django_filters.BooleanFilter( @@ -248,8 +240,10 @@ class ModerationActionFilter(django_filters.FilterSet): # Action type filters action_type = django_filters.ChoiceFilter( - choices=lambda: [(choice.value, choice.label) for choice in get_choices("moderation_action_types", "moderation")], - help_text="Filter by action type" + choices=lambda: [ + (choice.value, choice.label) for choice in get_choices("moderation_action_types", "moderation") + ], + help_text="Filter by action type", ) # User filters @@ -258,9 +252,7 @@ class ModerationActionFilter(django_filters.FilterSet): help_text="Filter by moderator who took the action", ) - target_user = django_filters.ModelChoiceFilter( - queryset=User.objects.all(), help_text="Filter by target user" - ) + target_user = django_filters.ModelChoiceFilter(queryset=User.objects.all(), help_text="Filter by target user") # Status filters is_active = django_filters.BooleanFilter(help_text="Filter by active status") @@ -291,9 +283,7 @@ class ModerationActionFilter(django_filters.FilterSet): ) # Special filters - expired = django_filters.BooleanFilter( - method="filter_expired", help_text="Filter for expired actions" - ) + expired = django_filters.BooleanFilter(method="filter_expired", help_text="Filter for expired actions") expiring_soon = django_filters.BooleanFilter( method="filter_expiring_soon", @@ -345,8 +335,10 @@ class BulkOperationFilter(django_filters.FilterSet): # Status filters status = django_filters.ChoiceFilter( - choices=lambda: [(choice.value, choice.label) for choice in get_choices("bulk_operation_statuses", "moderation")], - help_text="Filter by operation status" + choices=lambda: [ + (choice.value, choice.label) for choice in get_choices("bulk_operation_statuses", "moderation") + ], + help_text="Filter by operation status", ) # Operation type filters @@ -358,7 +350,7 @@ class BulkOperationFilter(django_filters.FilterSet): # Priority filters priority = django_filters.ChoiceFilter( choices=lambda: [(choice.value, choice.label) for choice in get_choices("priority_levels", "moderation")], - help_text="Filter by operation priority" + help_text="Filter by operation priority", ) # User filters @@ -405,9 +397,7 @@ class BulkOperationFilter(django_filters.FilterSet): ) # Special filters - can_cancel = django_filters.BooleanFilter( - help_text="Filter by cancellation capability" - ) + can_cancel = django_filters.BooleanFilter(help_text="Filter by cancellation capability") has_failures = django_filters.BooleanFilter( method="filter_has_failures", diff --git a/backend/apps/moderation/management/commands/analyze_transitions.py b/backend/apps/moderation/management/commands/analyze_transitions.py index 53115741..78f4e3ff 100644 --- a/backend/apps/moderation/management/commands/analyze_transitions.py +++ b/backend/apps/moderation/management/commands/analyze_transitions.py @@ -16,36 +16,25 @@ from django_fsm_log.models import StateLog class Command(BaseCommand): - help = 'Analyze state transition patterns and generate statistics' + help = "Analyze state transition patterns and generate statistics" def add_arguments(self, parser): + parser.add_argument("--days", type=int, default=30, help="Number of days to analyze (default: 30)") + parser.add_argument("--model", type=str, help="Specific model to analyze (e.g., editsubmission)") parser.add_argument( - '--days', - type=int, - default=30, - help='Number of days to analyze (default: 30)' - ) - parser.add_argument( - '--model', + "--output", type=str, - help='Specific model to analyze (e.g., editsubmission)' - ) - parser.add_argument( - '--output', - type=str, - choices=['console', 'json', 'csv'], - default='console', - help='Output format (default: console)' + choices=["console", "json", "csv"], + default="console", + help="Output format (default: console)", ) def handle(self, *args, **options): - days = options['days'] - model_filter = options['model'] - output_format = options['output'] + days = options["days"] + model_filter = options["model"] + output_format = options["output"] - self.stdout.write( - self.style.SUCCESS(f'\n=== State Transition Analysis (Last {days} days) ===\n') - ) + self.stdout.write(self.style.SUCCESS(f"\n=== State Transition Analysis (Last {days} days) ===\n")) # Filter by date range start_date = timezone.now() - timedelta(days=days) @@ -56,173 +45,134 @@ class Command(BaseCommand): try: content_type = ContentType.objects.get(model=model_filter.lower()) queryset = queryset.filter(content_type=content_type) - self.stdout.write(f'Filtering for model: {model_filter}\n') + self.stdout.write(f"Filtering for model: {model_filter}\n") except ContentType.DoesNotExist: - self.stdout.write( - self.style.ERROR(f'Model "{model_filter}" not found') - ) + self.stdout.write(self.style.ERROR(f'Model "{model_filter}" not found')) return # Total transitions total_transitions = queryset.count() - self.stdout.write( - self.style.SUCCESS(f'Total Transitions: {total_transitions}\n') - ) + self.stdout.write(self.style.SUCCESS(f"Total Transitions: {total_transitions}\n")) if total_transitions == 0: - self.stdout.write( - self.style.WARNING('No transitions found in the specified period.') - ) + self.stdout.write(self.style.WARNING("No transitions found in the specified period.")) return # Most common transitions - self.stdout.write(self.style.SUCCESS('\n--- Most Common Transitions ---')) + self.stdout.write(self.style.SUCCESS("\n--- Most Common Transitions ---")) common_transitions = ( - queryset.values('transition', 'content_type__model') - .annotate(count=Count('id')) - .order_by('-count')[:10] + queryset.values("transition", "content_type__model").annotate(count=Count("id")).order_by("-count")[:10] ) for t in common_transitions: - model_name = t['content_type__model'] - transition_name = t['transition'] or 'N/A' - count = t['count'] + model_name = t["content_type__model"] + transition_name = t["transition"] or "N/A" + count = t["count"] percentage = (count / total_transitions) * 100 - self.stdout.write( - f" {model_name}.{transition_name}: {count} ({percentage:.1f}%)" - ) + self.stdout.write(f" {model_name}.{transition_name}: {count} ({percentage:.1f}%)") # Transitions by model - self.stdout.write(self.style.SUCCESS('\n--- Transitions by Model ---')) - by_model = ( - queryset.values('content_type__model') - .annotate(count=Count('id')) - .order_by('-count') - ) + self.stdout.write(self.style.SUCCESS("\n--- Transitions by Model ---")) + by_model = queryset.values("content_type__model").annotate(count=Count("id")).order_by("-count") for m in by_model: - model_name = m['content_type__model'] - count = m['count'] + model_name = m["content_type__model"] + count = m["count"] percentage = (count / total_transitions) * 100 - self.stdout.write( - f" {model_name}: {count} ({percentage:.1f}%)" - ) + self.stdout.write(f" {model_name}: {count} ({percentage:.1f}%)") # Transitions by state - self.stdout.write(self.style.SUCCESS('\n--- Final States Distribution ---')) - by_state = ( - queryset.values('state') - .annotate(count=Count('id')) - .order_by('-count') - ) + self.stdout.write(self.style.SUCCESS("\n--- Final States Distribution ---")) + by_state = queryset.values("state").annotate(count=Count("id")).order_by("-count") for s in by_state: - state_name = s['state'] - count = s['count'] + state_name = s["state"] + count = s["count"] percentage = (count / total_transitions) * 100 - self.stdout.write( - f" {state_name}: {count} ({percentage:.1f}%)" - ) + self.stdout.write(f" {state_name}: {count} ({percentage:.1f}%)") # Most active users - self.stdout.write(self.style.SUCCESS('\n--- Most Active Users ---')) + self.stdout.write(self.style.SUCCESS("\n--- Most Active Users ---")) active_users = ( queryset.exclude(by__isnull=True) - .values('by__username', 'by__id') - .annotate(count=Count('id')) - .order_by('-count')[:10] + .values("by__username", "by__id") + .annotate(count=Count("id")) + .order_by("-count")[:10] ) for u in active_users: - username = u['by__username'] - user_id = u['by__id'] - count = u['count'] - self.stdout.write( - f" {username} (ID: {user_id}): {count} transitions" - ) + username = u["by__username"] + user_id = u["by__id"] + count = u["count"] + self.stdout.write(f" {username} (ID: {user_id}): {count} transitions") # System vs User transitions system_count = queryset.filter(by__isnull=True).count() user_count = queryset.exclude(by__isnull=True).count() - self.stdout.write(self.style.SUCCESS('\n--- Transition Attribution ---')) + self.stdout.write(self.style.SUCCESS("\n--- Transition Attribution ---")) self.stdout.write(f" User-initiated: {user_count} ({(user_count/total_transitions)*100:.1f}%)") self.stdout.write(f" System-initiated: {system_count} ({(system_count/total_transitions)*100:.1f}%)") # Daily transition volume # Security: Using Django ORM functions instead of raw SQL .extra() to prevent SQL injection - self.stdout.write(self.style.SUCCESS('\n--- Daily Transition Volume ---')) + self.stdout.write(self.style.SUCCESS("\n--- Daily Transition Volume ---")) daily_stats = ( - queryset.annotate(day=TruncDate('timestamp')) - .values('day') - .annotate(count=Count('id')) - .order_by('-day')[:7] + queryset.annotate(day=TruncDate("timestamp")).values("day").annotate(count=Count("id")).order_by("-day")[:7] ) for day in daily_stats: - date = day['day'] - count = day['count'] + date = day["day"] + count = day["count"] self.stdout.write(f" {date}: {count} transitions") # Busiest hours # Security: Using Django ORM functions instead of raw SQL .extra() to prevent SQL injection - self.stdout.write(self.style.SUCCESS('\n--- Busiest Hours (UTC) ---')) + self.stdout.write(self.style.SUCCESS("\n--- Busiest Hours (UTC) ---")) hourly_stats = ( - queryset.annotate(hour=ExtractHour('timestamp')) - .values('hour') - .annotate(count=Count('id')) - .order_by('-count')[:5] + queryset.annotate(hour=ExtractHour("timestamp")) + .values("hour") + .annotate(count=Count("id")) + .order_by("-count")[:5] ) for hour in hourly_stats: - hour_val = int(hour['hour']) - count = hour['count'] + hour_val = int(hour["hour"]) + count = hour["count"] self.stdout.write(f" Hour {hour_val:02d}:00: {count} transitions") # Transition patterns (common sequences) - self.stdout.write(self.style.SUCCESS('\n--- Common Transition Patterns ---')) - self.stdout.write(' Analyzing transition sequences...') + self.stdout.write(self.style.SUCCESS("\n--- Common Transition Patterns ---")) + self.stdout.write(" Analyzing transition sequences...") # Get recent objects and their transition sequences - recent_objects = ( - queryset.values('content_type', 'object_id') - .distinct()[:100] - ) + recent_objects = queryset.values("content_type", "object_id").distinct()[:100] pattern_counts = {} for obj in recent_objects: transitions = list( - StateLog.objects.filter( - content_type=obj['content_type'], - object_id=obj['object_id'] - ) - .order_by('timestamp') - .values_list('transition', flat=True) + StateLog.objects.filter(content_type=obj["content_type"], object_id=obj["object_id"]) + .order_by("timestamp") + .values_list("transition", flat=True) ) # Create pattern from consecutive transitions if len(transitions) >= 2: - pattern = ' → '.join([t or 'N/A' for t in transitions[:3]]) + pattern = " → ".join([t or "N/A" for t in transitions[:3]]) pattern_counts[pattern] = pattern_counts.get(pattern, 0) + 1 # Display top patterns - sorted_patterns = sorted( - pattern_counts.items(), - key=lambda x: x[1], - reverse=True - )[:5] + sorted_patterns = sorted(pattern_counts.items(), key=lambda x: x[1], reverse=True)[:5] for pattern, count in sorted_patterns: self.stdout.write(f" {pattern}: {count} occurrences") - self.stdout.write( - self.style.SUCCESS('\n=== Analysis Complete ===\n') - ) + self.stdout.write(self.style.SUCCESS("\n=== Analysis Complete ===\n")) # Export options - if output_format == 'json': + if output_format == "json": self._export_json(queryset, days) - elif output_format == 'csv': + elif output_format == "csv": self._export_csv(queryset, days) def _export_json(self, queryset, days): @@ -231,24 +181,21 @@ class Command(BaseCommand): from datetime import datetime data = { - 'analysis_date': datetime.now().isoformat(), - 'period_days': days, - 'total_transitions': queryset.count(), - 'transitions': list( + "analysis_date": datetime.now().isoformat(), + "period_days": days, + "total_transitions": queryset.count(), + "transitions": list( queryset.values( - 'id', 'timestamp', 'state', 'transition', - 'content_type__model', 'object_id', 'by__username' + "id", "timestamp", "state", "transition", "content_type__model", "object_id", "by__username" ) - ) + ), } filename = f'transition_analysis_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json' - with open(filename, 'w') as f: + with open(filename, "w") as f: json.dump(data, f, indent=2, default=str) - self.stdout.write( - self.style.SUCCESS(f'Exported to {filename}') - ) + self.stdout.write(self.style.SUCCESS(f"Exported to {filename}")) def _export_csv(self, queryset, days): """Export analysis results as CSV.""" @@ -257,24 +204,21 @@ class Command(BaseCommand): filename = f'transition_analysis_{datetime.now().strftime("%Y%m%d_%H%M%S")}.csv' - with open(filename, 'w', newline='') as f: + with open(filename, "w", newline="") as f: writer = csv.writer(f) - writer.writerow([ - 'ID', 'Timestamp', 'Model', 'Object ID', - 'State', 'Transition', 'User' - ]) + writer.writerow(["ID", "Timestamp", "Model", "Object ID", "State", "Transition", "User"]) - for log in queryset.select_related('content_type', 'by'): - writer.writerow([ - log.id, - log.timestamp, - log.content_type.model, - log.object_id, - log.state, - log.transition or 'N/A', - log.by.username if log.by else 'System' - ]) + for log in queryset.select_related("content_type", "by"): + writer.writerow( + [ + log.id, + log.timestamp, + log.content_type.model, + log.object_id, + log.state, + log.transition or "N/A", + log.by.username if log.by else "System", + ] + ) - self.stdout.write( - self.style.SUCCESS(f'Exported to {filename}') - ) + self.stdout.write(self.style.SUCCESS(f"Exported to {filename}")) diff --git a/backend/apps/moderation/management/commands/seed_submissions.py b/backend/apps/moderation/management/commands/seed_submissions.py index bb1b11db..246cd3fc 100644 --- a/backend/apps/moderation/management/commands/seed_submissions.py +++ b/backend/apps/moderation/management/commands/seed_submissions.py @@ -17,9 +17,7 @@ class Command(BaseCommand): def handle(self, *args, **kwargs): # Ensure we have a test user - user, created = User.objects.get_or_create( - username="test_user", email="test@example.com" - ) + user, created = User.objects.get_or_create(username="test_user", email="test@example.com") if created: user.set_password("testpass123") user.save() @@ -215,9 +213,7 @@ class Command(BaseCommand): "audio system, and increased capacity due to improved loading efficiency." ), source=( - "Park operations manual\n" - "Maintenance records\n" - "Personal observation and timing of new ride cycle" + "Park operations manual\n" "Maintenance records\n" "Personal observation and timing of new ride cycle" ), status="PENDING", ) @@ -225,10 +221,10 @@ class Command(BaseCommand): # Create PhotoSubmissions with detailed captions # Park photo submission - image_data = b"GIF87a\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00ccc,\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;" - dummy_image = SimpleUploadedFile( - "park_entrance.gif", image_data, content_type="image/gif" + image_data = ( + b"GIF87a\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00ccc,\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;" ) + dummy_image = SimpleUploadedFile("park_entrance.gif", image_data, content_type="image/gif") PhotoSubmission.objects.create( user=user, @@ -244,9 +240,7 @@ class Command(BaseCommand): ) # Ride photo submission - dummy_image2 = SimpleUploadedFile( - "coaster_track.gif", image_data, content_type="image/gif" - ) + dummy_image2 = SimpleUploadedFile("coaster_track.gif", image_data, content_type="image/gif") PhotoSubmission.objects.create( user=user, content_type=ride_ct, diff --git a/backend/apps/moderation/management/commands/validate_state_machines.py b/backend/apps/moderation/management/commands/validate_state_machines.py index 17c0dee2..1f6188a6 100644 --- a/backend/apps/moderation/management/commands/validate_state_machines.py +++ b/backend/apps/moderation/management/commands/validate_state_machines.py @@ -1,4 +1,5 @@ """Management command to validate state machine configurations for moderation models.""" + from django.core.management import CommandError from django.core.management.base import BaseCommand @@ -76,18 +77,15 @@ class Command(BaseCommand): model_key = model_name.lower() if model_key not in models_to_validate: raise CommandError( - f"Unknown model: {model_name}. " - f"Valid options: {', '.join(models_to_validate.keys())}" + f"Unknown model: {model_name}. " f"Valid options: {', '.join(models_to_validate.keys())}" ) models_to_validate = {model_key: models_to_validate[model_key]} - self.stdout.write( - self.style.SUCCESS("\nValidating State Machine Configurations\n") - ) + self.stdout.write(self.style.SUCCESS("\nValidating State Machine Configurations\n")) self.stdout.write("=" * 60 + "\n") all_valid = True - for model_key, ( + for _model_key, ( model_class, choice_group, domain, @@ -101,61 +99,34 @@ class Command(BaseCommand): result = validator.validate_choice_group() if result.is_valid: - self.stdout.write( - self.style.SUCCESS( - f" ✓ {model_class.__name__} validation passed" - ) - ) + self.stdout.write(self.style.SUCCESS(f" ✓ {model_class.__name__} validation passed")) if verbose: self._show_transition_graph(choice_group, domain) else: all_valid = False - self.stdout.write( - self.style.ERROR( - f" ✗ {model_class.__name__} validation failed" - ) - ) + self.stdout.write(self.style.ERROR(f" ✗ {model_class.__name__} validation failed")) for error in result.errors: - self.stdout.write( - self.style.ERROR(f" - {error.message}") - ) + self.stdout.write(self.style.ERROR(f" - {error.message}")) # Check FSM field if not self._check_fsm_field(model_class): all_valid = False - self.stdout.write( - self.style.ERROR( - f" - FSM field 'status' not found on " - f"{model_class.__name__}" - ) - ) + self.stdout.write(self.style.ERROR(f" - FSM field 'status' not found on " f"{model_class.__name__}")) # Check mixin if not self._check_state_machine_mixin(model_class): all_valid = False self.stdout.write( - self.style.WARNING( - f" - StateMachineMixin not found on " - f"{model_class.__name__}" - ) + self.style.WARNING(f" - StateMachineMixin not found on " f"{model_class.__name__}") ) self.stdout.write("\n" + "=" * 60) if all_valid: - self.stdout.write( - self.style.SUCCESS( - "\n✓ All validations passed successfully!\n" - ) - ) + self.stdout.write(self.style.SUCCESS("\n✓ All validations passed successfully!\n")) else: - self.stdout.write( - self.style.ERROR( - "\n✗ Some validations failed. " - "Please review the errors above.\n" - ) - ) + self.stdout.write(self.style.ERROR("\n✗ Some validations failed. " "Please review the errors above.\n")) raise CommandError("State machine validation failed") def _check_fsm_field(self, model_class): @@ -177,9 +148,7 @@ class Command(BaseCommand): self.stdout.write("\n Transition Graph:") - graph = registry_instance.export_transition_graph( - choice_group, domain - ) + graph = registry_instance.export_transition_graph(choice_group, domain) for source, targets in sorted(graph.items()): if targets: diff --git a/backend/apps/moderation/migrations/0001_initial.py b/backend/apps/moderation/migrations/0001_initial.py index 0553be8b..3f8ceff6 100644 --- a/backend/apps/moderation/migrations/0001_initial.py +++ b/backend/apps/moderation/migrations/0001_initial.py @@ -47,9 +47,7 @@ class Migration(migrations.Migration): ), ( "changes", - models.JSONField( - help_text="JSON representation of the changes or new object data" - ), + models.JSONField(help_text="JSON representation of the changes or new object data"), ), ( "moderator_changes", @@ -150,9 +148,7 @@ class Migration(migrations.Migration): ), ( "changes", - models.JSONField( - help_text="JSON representation of the changes or new object data" - ), + models.JSONField(help_text="JSON representation of the changes or new object data"), ), ( "moderator_changes", diff --git a/backend/apps/moderation/migrations/0003_bulkoperation_bulkoperationevent_moderationaction_and_more.py b/backend/apps/moderation/migrations/0003_bulkoperation_bulkoperationevent_moderationaction_and_more.py index bb5c8a84..65aff7f1 100644 --- a/backend/apps/moderation/migrations/0003_bulkoperation_bulkoperationevent_moderationaction_and_more.py +++ b/backend/apps/moderation/migrations/0003_bulkoperation_bulkoperationevent_moderationaction_and_more.py @@ -812,21 +812,15 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="bulkoperation", - index=models.Index( - fields=["status", "priority"], name="moderation__status_f11ee8_idx" - ), + index=models.Index(fields=["status", "priority"], name="moderation__status_f11ee8_idx"), ), migrations.AddIndex( model_name="bulkoperation", - index=models.Index( - fields=["created_by"], name="moderation__created_4fe5d2_idx" - ), + index=models.Index(fields=["created_by"], name="moderation__created_4fe5d2_idx"), ), migrations.AddIndex( model_name="bulkoperation", - index=models.Index( - fields=["operation_type"], name="moderation__operati_bc84d9_idx" - ), + index=models.Index(fields=["operation_type"], name="moderation__operati_bc84d9_idx"), ), pgtrigger.migrations.AddTrigger( model_name="bulkoperation", @@ -859,9 +853,7 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="moderationreport", - index=models.Index( - fields=["status", "priority"], name="moderation__status_6aa18c_idx" - ), + index=models.Index(fields=["status", "priority"], name="moderation__status_6aa18c_idx"), ), migrations.AddIndex( model_name="moderationreport", @@ -872,9 +864,7 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="moderationreport", - index=models.Index( - fields=["assigned_moderator"], name="moderation__assigne_c43cdf_idx" - ), + index=models.Index(fields=["assigned_moderator"], name="moderation__assigne_c43cdf_idx"), ), pgtrigger.migrations.AddTrigger( model_name="moderationreport", @@ -907,9 +897,7 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="moderationqueue", - index=models.Index( - fields=["status", "priority"], name="moderation__status_6f2a75_idx" - ), + index=models.Index(fields=["status", "priority"], name="moderation__status_6f2a75_idx"), ), migrations.AddIndex( model_name="moderationqueue", @@ -920,15 +908,11 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="moderationqueue", - index=models.Index( - fields=["assigned_to"], name="moderation__assigne_2fc958_idx" - ), + index=models.Index(fields=["assigned_to"], name="moderation__assigne_2fc958_idx"), ), migrations.AddIndex( model_name="moderationqueue", - index=models.Index( - fields=["flagged_by"], name="moderation__flagged_169834_idx" - ), + index=models.Index(fields=["flagged_by"], name="moderation__flagged_169834_idx"), ), pgtrigger.migrations.AddTrigger( model_name="moderationqueue", @@ -975,9 +959,7 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="moderationaction", - index=models.Index( - fields=["expires_at"], name="moderation__expires_963efb_idx" - ), + index=models.Index(fields=["expires_at"], name="moderation__expires_963efb_idx"), ), pgtrigger.migrations.AddTrigger( model_name="moderationaction", diff --git a/backend/apps/moderation/migrations/0004_alter_moderationqueue_options_and_more.py b/backend/apps/moderation/migrations/0004_alter_moderationqueue_options_and_more.py index f4159efd..b9f8041c 100644 --- a/backend/apps/moderation/migrations/0004_alter_moderationqueue_options_and_more.py +++ b/backend/apps/moderation/migrations/0004_alter_moderationqueue_options_and_more.py @@ -55,9 +55,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="bulkoperation", name="can_cancel", - field=models.BooleanField( - default=True, help_text="Whether this operation can be cancelled" - ), + field=models.BooleanField(default=True, help_text="Whether this operation can be cancelled"), ), migrations.AlterField( model_name="bulkoperation", @@ -67,23 +65,17 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="bulkoperation", name="estimated_duration_minutes", - field=models.PositiveIntegerField( - blank=True, help_text="Estimated duration in minutes", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="Estimated duration in minutes", null=True), ), migrations.AlterField( model_name="bulkoperation", name="failed_items", - field=models.PositiveIntegerField( - default=0, help_text="Number of items that failed" - ), + field=models.PositiveIntegerField(default=0, help_text="Number of items that failed"), ), migrations.AlterField( model_name="bulkoperation", name="id", - field=models.BigAutoField( - auto_created=True, primary_key=True, serialize=False, verbose_name="ID" - ), + field=models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID"), ), migrations.AlterField( model_name="bulkoperation", @@ -105,9 +97,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="bulkoperation", name="parameters", - field=models.JSONField( - default=dict, help_text="Parameters for the operation" - ), + field=models.JSONField(default=dict, help_text="Parameters for the operation"), ), migrations.AlterField( model_name="bulkoperation", @@ -126,9 +116,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="bulkoperation", name="processed_items", - field=models.PositiveIntegerField( - default=0, help_text="Number of items processed" - ), + field=models.PositiveIntegerField(default=0, help_text="Number of items processed"), ), migrations.AlterField( model_name="bulkoperation", @@ -142,23 +130,17 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="bulkoperation", name="schedule_for", - field=models.DateTimeField( - blank=True, help_text="When to run this operation", null=True - ), + field=models.DateTimeField(blank=True, help_text="When to run this operation", null=True), ), migrations.AlterField( model_name="bulkoperation", name="total_items", - field=models.PositiveIntegerField( - default=0, help_text="Total number of items to process" - ), + field=models.PositiveIntegerField(default=0, help_text="Total number of items to process"), ), migrations.AlterField( model_name="bulkoperationevent", name="can_cancel", - field=models.BooleanField( - default=True, help_text="Whether this operation can be cancelled" - ), + field=models.BooleanField(default=True, help_text="Whether this operation can be cancelled"), ), migrations.AlterField( model_name="bulkoperationevent", @@ -168,16 +150,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="bulkoperationevent", name="estimated_duration_minutes", - field=models.PositiveIntegerField( - blank=True, help_text="Estimated duration in minutes", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="Estimated duration in minutes", null=True), ), migrations.AlterField( model_name="bulkoperationevent", name="failed_items", - field=models.PositiveIntegerField( - default=0, help_text="Number of items that failed" - ), + field=models.PositiveIntegerField(default=0, help_text="Number of items that failed"), ), migrations.AlterField( model_name="bulkoperationevent", @@ -204,9 +182,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="bulkoperationevent", name="parameters", - field=models.JSONField( - default=dict, help_text="Parameters for the operation" - ), + field=models.JSONField(default=dict, help_text="Parameters for the operation"), ), migrations.AlterField( model_name="bulkoperationevent", @@ -225,9 +201,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="bulkoperationevent", name="processed_items", - field=models.PositiveIntegerField( - default=0, help_text="Number of items processed" - ), + field=models.PositiveIntegerField(default=0, help_text="Number of items processed"), ), migrations.AlterField( model_name="bulkoperationevent", @@ -241,16 +215,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="bulkoperationevent", name="schedule_for", - field=models.DateTimeField( - blank=True, help_text="When to run this operation", null=True - ), + field=models.DateTimeField(blank=True, help_text="When to run this operation", null=True), ), migrations.AlterField( model_name="bulkoperationevent", name="total_items", - field=models.PositiveIntegerField( - default=0, help_text="Total number of items to process" - ), + field=models.PositiveIntegerField(default=0, help_text="Total number of items to process"), ), migrations.AlterField( model_name="moderationaction", @@ -286,23 +256,17 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationaction", name="expires_at", - field=models.DateTimeField( - blank=True, help_text="When this action expires", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this action expires", null=True), ), migrations.AlterField( model_name="moderationaction", name="is_active", - field=models.BooleanField( - default=True, help_text="Whether this action is currently active" - ), + field=models.BooleanField(default=True, help_text="Whether this action is currently active"), ), migrations.AlterField( model_name="moderationaction", name="reason", - field=models.CharField( - help_text="Brief reason for the action", max_length=200 - ), + field=models.CharField(help_text="Brief reason for the action", max_length=200), ), migrations.AlterField( model_name="moderationactionevent", @@ -338,44 +302,32 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationactionevent", name="expires_at", - field=models.DateTimeField( - blank=True, help_text="When this action expires", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this action expires", null=True), ), migrations.AlterField( model_name="moderationactionevent", name="is_active", - field=models.BooleanField( - default=True, help_text="Whether this action is currently active" - ), + field=models.BooleanField(default=True, help_text="Whether this action is currently active"), ), migrations.AlterField( model_name="moderationactionevent", name="reason", - field=models.CharField( - help_text="Brief reason for the action", max_length=200 - ), + field=models.CharField(help_text="Brief reason for the action", max_length=200), ), migrations.AlterField( model_name="moderationqueue", name="description", - field=models.TextField( - help_text="Detailed description of what needs to be done" - ), + field=models.TextField(help_text="Detailed description of what needs to be done"), ), migrations.AlterField( model_name="moderationqueue", name="entity_id", - field=models.PositiveIntegerField( - blank=True, help_text="ID of the related entity", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="ID of the related entity", null=True), ), migrations.AlterField( model_name="moderationqueue", name="entity_preview", - field=models.JSONField( - blank=True, default=dict, help_text="Preview data for the entity" - ), + field=models.JSONField(blank=True, default=dict, help_text="Preview data for the entity"), ), migrations.AlterField( model_name="moderationqueue", @@ -389,9 +341,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationqueue", name="estimated_review_time", - field=models.PositiveIntegerField( - default=30, help_text="Estimated time in minutes" - ), + field=models.PositiveIntegerField(default=30, help_text="Estimated time in minutes"), ), migrations.AlterField( model_name="moderationqueue", @@ -436,37 +386,27 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationqueue", name="tags", - field=models.JSONField( - blank=True, default=list, help_text="Tags for categorization" - ), + field=models.JSONField(blank=True, default=list, help_text="Tags for categorization"), ), migrations.AlterField( model_name="moderationqueue", name="title", - field=models.CharField( - help_text="Brief title for the queue item", max_length=200 - ), + field=models.CharField(help_text="Brief title for the queue item", max_length=200), ), migrations.AlterField( model_name="moderationqueueevent", name="description", - field=models.TextField( - help_text="Detailed description of what needs to be done" - ), + field=models.TextField(help_text="Detailed description of what needs to be done"), ), migrations.AlterField( model_name="moderationqueueevent", name="entity_id", - field=models.PositiveIntegerField( - blank=True, help_text="ID of the related entity", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="ID of the related entity", null=True), ), migrations.AlterField( model_name="moderationqueueevent", name="entity_preview", - field=models.JSONField( - blank=True, default=dict, help_text="Preview data for the entity" - ), + field=models.JSONField(blank=True, default=dict, help_text="Preview data for the entity"), ), migrations.AlterField( model_name="moderationqueueevent", @@ -480,9 +420,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationqueueevent", name="estimated_review_time", - field=models.PositiveIntegerField( - default=30, help_text="Estimated time in minutes" - ), + field=models.PositiveIntegerField(default=30, help_text="Estimated time in minutes"), ), migrations.AlterField( model_name="moderationqueueevent", @@ -529,16 +467,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationqueueevent", name="tags", - field=models.JSONField( - blank=True, default=list, help_text="Tags for categorization" - ), + field=models.JSONField(blank=True, default=list, help_text="Tags for categorization"), ), migrations.AlterField( model_name="moderationqueueevent", name="title", - field=models.CharField( - help_text="Brief title for the queue item", max_length=200 - ), + field=models.CharField(help_text="Brief title for the queue item", max_length=200), ), migrations.AlterField( model_name="moderationreport", @@ -557,9 +491,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationreport", name="reason", - field=models.CharField( - help_text="Brief reason for the report", max_length=200 - ), + field=models.CharField(help_text="Brief reason for the report", max_length=200), ), migrations.AlterField( model_name="moderationreport", @@ -582,9 +514,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationreport", name="reported_entity_id", - field=models.PositiveIntegerField( - help_text="ID of the entity being reported" - ), + field=models.PositiveIntegerField(help_text="ID of the entity being reported"), ), migrations.AlterField( model_name="moderationreport", @@ -641,9 +571,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationreportevent", name="reason", - field=models.CharField( - help_text="Brief reason for the report", max_length=200 - ), + field=models.CharField(help_text="Brief reason for the report", max_length=200), ), migrations.AlterField( model_name="moderationreportevent", @@ -666,9 +594,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationreportevent", name="reported_entity_id", - field=models.PositiveIntegerField( - help_text="ID of the entity being reported" - ), + field=models.PositiveIntegerField(help_text="ID of the entity being reported"), ), migrations.AlterField( model_name="moderationreportevent", @@ -710,45 +636,31 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="bulkoperation", - index=models.Index( - fields=["schedule_for"], name="moderation__schedul_350704_idx" - ), + index=models.Index(fields=["schedule_for"], name="moderation__schedul_350704_idx"), ), migrations.AddIndex( model_name="bulkoperation", - index=models.Index( - fields=["created_at"], name="moderation__created_b705f4_idx" - ), + index=models.Index(fields=["created_at"], name="moderation__created_b705f4_idx"), ), migrations.AddIndex( model_name="moderationaction", - index=models.Index( - fields=["moderator"], name="moderation__moderat_1c19b0_idx" - ), + index=models.Index(fields=["moderator"], name="moderation__moderat_1c19b0_idx"), ), migrations.AddIndex( model_name="moderationaction", - index=models.Index( - fields=["created_at"], name="moderation__created_6378e6_idx" - ), + index=models.Index(fields=["created_at"], name="moderation__created_6378e6_idx"), ), migrations.AddIndex( model_name="moderationqueue", - index=models.Index( - fields=["created_at"], name="moderation__created_fe6dd0_idx" - ), + index=models.Index(fields=["created_at"], name="moderation__created_fe6dd0_idx"), ), migrations.AddIndex( model_name="moderationreport", - index=models.Index( - fields=["reported_by"], name="moderation__reporte_81af56_idx" - ), + index=models.Index(fields=["reported_by"], name="moderation__reporte_81af56_idx"), ), migrations.AddIndex( model_name="moderationreport", - index=models.Index( - fields=["created_at"], name="moderation__created_ae337c_idx" - ), + index=models.Index(fields=["created_at"], name="moderation__created_ae337c_idx"), ), pgtrigger.migrations.AddTrigger( model_name="moderationqueue", diff --git a/backend/apps/moderation/migrations/0008_alter_bulkoperation_options_and_more.py b/backend/apps/moderation/migrations/0008_alter_bulkoperation_options_and_more.py index e0893669..a7ee7398 100644 --- a/backend/apps/moderation/migrations/0008_alter_bulkoperation_options_and_more.py +++ b/backend/apps/moderation/migrations/0008_alter_bulkoperation_options_and_more.py @@ -67,9 +67,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="bulkoperation", name="completed_at", - field=models.DateTimeField( - blank=True, help_text="When this operation completed", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this operation completed", null=True), ), migrations.AlterField( model_name="bulkoperation", @@ -84,23 +82,17 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="bulkoperation", name="started_at", - field=models.DateTimeField( - blank=True, help_text="When this operation started", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this operation started", null=True), ), migrations.AlterField( model_name="bulkoperation", name="updated_at", - field=models.DateTimeField( - auto_now=True, help_text="When this operation was last updated" - ), + field=models.DateTimeField(auto_now=True, help_text="When this operation was last updated"), ), migrations.AlterField( model_name="bulkoperationevent", name="completed_at", - field=models.DateTimeField( - blank=True, help_text="When this operation completed", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this operation completed", null=True), ), migrations.AlterField( model_name="bulkoperationevent", @@ -117,9 +109,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="bulkoperationevent", name="started_at", - field=models.DateTimeField( - blank=True, help_text="When this operation started", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this operation started", null=True), ), migrations.AlterField( model_name="bulkoperationevent", @@ -142,9 +132,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="bulkoperationevent", name="updated_at", - field=models.DateTimeField( - auto_now=True, help_text="When this operation was last updated" - ), + field=models.DateTimeField(auto_now=True, help_text="When this operation was last updated"), ), migrations.AlterField( model_name="editsubmission", @@ -158,9 +146,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="editsubmission", name="handled_at", - field=models.DateTimeField( - blank=True, help_text="When this submission was handled", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this submission was handled", null=True), ), migrations.AlterField( model_name="editsubmission", @@ -208,9 +194,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="editsubmissionevent", name="handled_at", - field=models.DateTimeField( - blank=True, help_text="When this submission was handled", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this submission was handled", null=True), ), migrations.AlterField( model_name="editsubmissionevent", @@ -267,9 +251,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationaction", name="created_at", - field=models.DateTimeField( - auto_now_add=True, help_text="When this action was created" - ), + field=models.DateTimeField(auto_now_add=True, help_text="When this action was created"), ), migrations.AlterField( model_name="moderationaction", @@ -306,16 +288,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationaction", name="updated_at", - field=models.DateTimeField( - auto_now=True, help_text="When this action was last updated" - ), + field=models.DateTimeField(auto_now=True, help_text="When this action was last updated"), ), migrations.AlterField( model_name="moderationactionevent", name="created_at", - field=models.DateTimeField( - auto_now_add=True, help_text="When this action was created" - ), + field=models.DateTimeField(auto_now_add=True, help_text="When this action was created"), ), migrations.AlterField( model_name="moderationactionevent", @@ -358,16 +336,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationactionevent", name="updated_at", - field=models.DateTimeField( - auto_now=True, help_text="When this action was last updated" - ), + field=models.DateTimeField(auto_now=True, help_text="When this action was last updated"), ), migrations.AlterField( model_name="moderationqueue", name="assigned_at", - field=models.DateTimeField( - blank=True, help_text="When this item was assigned", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this item was assigned", null=True), ), migrations.AlterField( model_name="moderationqueue", @@ -384,9 +358,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationqueue", name="created_at", - field=models.DateTimeField( - auto_now_add=True, help_text="When this item was created" - ), + field=models.DateTimeField(auto_now_add=True, help_text="When this item was created"), ), migrations.AlterField( model_name="moderationqueue", @@ -415,16 +387,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationqueue", name="updated_at", - field=models.DateTimeField( - auto_now=True, help_text="When this item was last updated" - ), + field=models.DateTimeField(auto_now=True, help_text="When this item was last updated"), ), migrations.AlterField( model_name="moderationqueueevent", name="assigned_at", - field=models.DateTimeField( - blank=True, help_text="When this item was assigned", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this item was assigned", null=True), ), migrations.AlterField( model_name="moderationqueueevent", @@ -443,9 +411,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationqueueevent", name="created_at", - field=models.DateTimeField( - auto_now_add=True, help_text="When this item was created" - ), + field=models.DateTimeField(auto_now_add=True, help_text="When this item was created"), ), migrations.AlterField( model_name="moderationqueueevent", @@ -495,9 +461,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationqueueevent", name="updated_at", - field=models.DateTimeField( - auto_now=True, help_text="When this item was last updated" - ), + field=models.DateTimeField(auto_now=True, help_text="When this item was last updated"), ), migrations.AlterField( model_name="moderationreport", @@ -514,9 +478,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationreport", name="created_at", - field=models.DateTimeField( - auto_now_add=True, help_text="When this report was created" - ), + field=models.DateTimeField(auto_now_add=True, help_text="When this report was created"), ), migrations.AlterField( model_name="moderationreport", @@ -531,16 +493,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationreport", name="resolved_at", - field=models.DateTimeField( - blank=True, help_text="When this report was resolved", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this report was resolved", null=True), ), migrations.AlterField( model_name="moderationreport", name="updated_at", - field=models.DateTimeField( - auto_now=True, help_text="When this report was last updated" - ), + field=models.DateTimeField(auto_now=True, help_text="When this report was last updated"), ), migrations.AlterField( model_name="moderationreportevent", @@ -559,9 +517,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationreportevent", name="created_at", - field=models.DateTimeField( - auto_now_add=True, help_text="When this report was created" - ), + field=models.DateTimeField(auto_now_add=True, help_text="When this report was created"), ), migrations.AlterField( model_name="moderationreportevent", @@ -578,9 +534,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationreportevent", name="resolved_at", - field=models.DateTimeField( - blank=True, help_text="When this report was resolved", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this report was resolved", null=True), ), migrations.AlterField( model_name="moderationreportevent", @@ -602,16 +556,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="moderationreportevent", name="updated_at", - field=models.DateTimeField( - auto_now=True, help_text="When this report was last updated" - ), + field=models.DateTimeField(auto_now=True, help_text="When this report was last updated"), ), migrations.AlterField( model_name="photosubmission", name="caption", - field=models.CharField( - blank=True, help_text="Photo caption", max_length=255 - ), + field=models.CharField(blank=True, help_text="Photo caption", max_length=255), ), migrations.AlterField( model_name="photosubmission", @@ -625,16 +575,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="photosubmission", name="date_taken", - field=models.DateField( - blank=True, help_text="Date the photo was taken", null=True - ), + field=models.DateField(blank=True, help_text="Date the photo was taken", null=True), ), migrations.AlterField( model_name="photosubmission", name="handled_at", - field=models.DateTimeField( - blank=True, help_text="When this submission was handled", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this submission was handled", null=True), ), migrations.AlterField( model_name="photosubmission", @@ -651,9 +597,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="photosubmission", name="object_id", - field=models.PositiveIntegerField( - help_text="ID of object this photo is for" - ), + field=models.PositiveIntegerField(help_text="ID of object this photo is for"), ), migrations.AlterField( model_name="photosubmission", @@ -668,9 +612,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="photosubmissionevent", name="caption", - field=models.CharField( - blank=True, help_text="Photo caption", max_length=255 - ), + field=models.CharField(blank=True, help_text="Photo caption", max_length=255), ), migrations.AlterField( model_name="photosubmissionevent", @@ -687,16 +629,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="photosubmissionevent", name="date_taken", - field=models.DateField( - blank=True, help_text="Date the photo was taken", null=True - ), + field=models.DateField(blank=True, help_text="Date the photo was taken", null=True), ), migrations.AlterField( model_name="photosubmissionevent", name="handled_at", - field=models.DateTimeField( - blank=True, help_text="When this submission was handled", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this submission was handled", null=True), ), migrations.AlterField( model_name="photosubmissionevent", @@ -715,9 +653,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="photosubmissionevent", name="object_id", - field=models.PositiveIntegerField( - help_text="ID of object this photo is for" - ), + field=models.PositiveIntegerField(help_text="ID of object this photo is for"), ), migrations.AlterField( model_name="photosubmissionevent", diff --git a/backend/apps/moderation/mixins.py b/backend/apps/moderation/mixins.py index c9544149..80a92d79 100644 --- a/backend/apps/moderation/mixins.py +++ b/backend/apps/moderation/mixins.py @@ -132,9 +132,7 @@ class EditSubmissionMixin(DetailView): status=400, ) - return self.handle_edit_submission( - request, changes, reason, source, submission_type - ) + return self.handle_edit_submission(request, changes, reason, source, submission_type) except json.JSONDecodeError: return JsonResponse( @@ -169,9 +167,7 @@ class PhotoSubmissionMixin(DetailView): try: obj = self.get_object() except (AttributeError, self.model.DoesNotExist): - return JsonResponse( - {"status": "error", "message": "Invalid object."}, status=400 - ) + return JsonResponse({"status": "error", "message": "Invalid object."}, status=400) if not request.FILES.get("photo"): return JsonResponse( diff --git a/backend/apps/moderation/models.py b/backend/apps/moderation/models.py index ac01631d..79bab15b 100644 --- a/backend/apps/moderation/models.py +++ b/backend/apps/moderation/models.py @@ -17,7 +17,7 @@ are registered via the callback configuration defined in each model's Meta class """ from datetime import timedelta -from typing import Any, Union +from typing import Any import pghistory from django.conf import settings @@ -33,7 +33,7 @@ from apps.core.choices.fields import RichChoiceField from apps.core.history import TrackedModel from apps.core.state_machine import RichFSMField, StateMachineMixin -UserType = Union[AbstractBaseUser, AnonymousUser] +UserType = AbstractBaseUser | AnonymousUser # Lazy callback imports to avoid circular dependencies @@ -45,11 +45,12 @@ def _get_notification_callbacks(): SubmissionEscalatedNotification, SubmissionRejectedNotification, ) + return { - 'approved': SubmissionApprovedNotification, - 'rejected': SubmissionRejectedNotification, - 'escalated': SubmissionEscalatedNotification, - 'moderation': ModerationNotificationCallback, + "approved": SubmissionApprovedNotification, + "rejected": SubmissionRejectedNotification, + "escalated": SubmissionEscalatedNotification, + "moderation": ModerationNotificationCallback, } @@ -59,9 +60,10 @@ def _get_cache_callbacks(): CacheInvalidationCallback, ModerationCacheInvalidation, ) + return { - 'generic': CacheInvalidationCallback, - 'moderation': ModerationCacheInvalidation, + "generic": CacheInvalidationCallback, + "moderation": ModerationCacheInvalidation, } @@ -69,6 +71,7 @@ def _get_cache_callbacks(): # Original EditSubmission Model (Preserved) # ============================================================================ + @pghistory.track() # Track all changes by default class EditSubmission(StateMachineMixin, TrackedModel): """Edit submission model with FSM-managed status transitions.""" @@ -98,16 +101,11 @@ class EditSubmission(StateMachineMixin, TrackedModel): # Type of submission submission_type = RichChoiceField( - choice_group="submission_types", - domain="moderation", - max_length=10, - default="EDIT" + choice_group="submission_types", domain="moderation", max_length=10, default="EDIT" ) # The actual changes/data - changes = models.JSONField( - help_text="JSON representation of the changes or new object data" - ) + changes = models.JSONField(help_text="JSON representation of the changes or new object data") # Moderator's edited version of changes before approval moderator_changes = models.JSONField( @@ -118,14 +116,9 @@ class EditSubmission(StateMachineMixin, TrackedModel): # Metadata reason = models.TextField(help_text="Why this edit/addition is needed") - source = models.TextField( - blank=True, help_text="Source of information (if applicable)" - ) + source = models.TextField(blank=True, help_text="Source of information (if applicable)") status = RichFSMField( - choice_group="edit_submission_statuses", - domain="moderation", - max_length=20, - default="PENDING" + choice_group="edit_submission_statuses", domain="moderation", max_length=20, default="PENDING" ) created_at = models.DateTimeField(auto_now_add=True) @@ -138,12 +131,8 @@ class EditSubmission(StateMachineMixin, TrackedModel): related_name="handled_submissions", help_text="Moderator who handled this submission", ) - handled_at = models.DateTimeField( - null=True, blank=True, help_text="When this submission was handled" - ) - notes = models.TextField( - blank=True, help_text="Notes from the moderator about this submission" - ) + handled_at = models.DateTimeField(null=True, blank=True, help_text="When this submission was handled") + notes = models.TextField(blank=True, help_text="Notes from the moderator about this submission") # Claim tracking for concurrency control claimed_by = models.ForeignKey( @@ -154,9 +143,7 @@ class EditSubmission(StateMachineMixin, TrackedModel): related_name="claimed_edit_submissions", help_text="Moderator who has claimed this submission for review", ) - claimed_at = models.DateTimeField( - null=True, blank=True, help_text="When this submission was claimed" - ) + claimed_at = models.DateTimeField(null=True, blank=True, help_text="When this submission was claimed") class Meta(TrackedModel.Meta): verbose_name = "Edit Submission" @@ -187,12 +174,12 @@ class EditSubmission(StateMachineMixin, TrackedModel): field = model_class._meta.get_field(field_name) if isinstance(field, models.ForeignKey) and value is not None: try: - related_obj = field.related_model.objects.get(pk=value) # type: ignore + related_obj = field.related_model.objects.get(pk=value) # type: ignore resolved_data[field_name] = related_obj except ObjectDoesNotExist: raise ValueError( - f"Related object {field.related_model.__name__} with pk={value} does not exist" # type: ignore - ) + f"Related object {field.related_model.__name__} with pk={value} does not exist" # type: ignore + ) from None except FieldDoesNotExist: # Field doesn't exist on model, skip it continue @@ -217,9 +204,7 @@ class EditSubmission(StateMachineMixin, TrackedModel): from django.core.exceptions import ValidationError if self.status != "PENDING": - raise ValidationError( - f"Cannot claim submission: current status is {self.status}, expected PENDING" - ) + raise ValidationError(f"Cannot claim submission: current status is {self.status}, expected PENDING") self.transition_to_claimed(user=user) self.claimed_by = user @@ -240,9 +225,7 @@ class EditSubmission(StateMachineMixin, TrackedModel): from django.core.exceptions import ValidationError if self.status != "CLAIMED": - raise ValidationError( - f"Cannot unclaim submission: current status is {self.status}, expected CLAIMED" - ) + raise ValidationError(f"Cannot unclaim submission: current status is {self.status}, expected CLAIMED") # Set status directly (not via FSM transition to avoid cycle) # This is intentional - the unclaim action is a special "rollback" operation @@ -274,9 +257,7 @@ class EditSubmission(StateMachineMixin, TrackedModel): # Validate state - must be CLAIMED before approval if self.status != "CLAIMED": - raise ValidationError( - f"Cannot approve submission: must be CLAIMED first (current status: {self.status})" - ) + raise ValidationError(f"Cannot approve submission: must be CLAIMED first (current status: {self.status})") model_class = self.content_type.model_class() if not model_class: @@ -341,9 +322,7 @@ class EditSubmission(StateMachineMixin, TrackedModel): # Validate state - must be CLAIMED before rejection if self.status != "CLAIMED": - raise ValidationError( - f"Cannot reject submission: must be CLAIMED first (current status: {self.status})" - ) + raise ValidationError(f"Cannot reject submission: must be CLAIMED first (current status: {self.status})") # Use FSM transition to update status self.transition_to_rejected(user=rejecter) @@ -369,9 +348,7 @@ class EditSubmission(StateMachineMixin, TrackedModel): # Validate state - must be CLAIMED before escalation if self.status != "CLAIMED": - raise ValidationError( - f"Cannot escalate submission: must be CLAIMED first (current status: {self.status})" - ) + raise ValidationError(f"Cannot escalate submission: must be CLAIMED first (current status: {self.status})") # Use FSM transition to update status self.transition_to_escalated(user=escalator) @@ -395,6 +372,7 @@ class EditSubmission(StateMachineMixin, TrackedModel): # New Moderation System Models # ============================================================================ + @pghistory.track() class ModerationReport(StateMachineMixin, TrackedModel): """ @@ -407,43 +385,29 @@ class ModerationReport(StateMachineMixin, TrackedModel): state_field_name = "status" # Report details - report_type = RichChoiceField( - choice_group="report_types", - domain="moderation", - max_length=50 - ) + report_type = RichChoiceField(choice_group="report_types", domain="moderation", max_length=50) status = RichFSMField( - choice_group="moderation_report_statuses", - domain="moderation", - max_length=20, - default='PENDING' - ) - priority = RichChoiceField( - choice_group="priority_levels", - domain="moderation", - max_length=10, - default='MEDIUM' + choice_group="moderation_report_statuses", domain="moderation", max_length=20, default="PENDING" ) + priority = RichChoiceField(choice_group="priority_levels", domain="moderation", max_length=10, default="MEDIUM") # What is being reported reported_entity_type = models.CharField( - max_length=50, help_text="Type of entity being reported (park, ride, user, etc.)") - reported_entity_id = models.PositiveIntegerField( - help_text="ID of the entity being reported") - content_type = models.ForeignKey( - ContentType, on_delete=models.CASCADE, null=True, blank=True) + max_length=50, help_text="Type of entity being reported (park, ride, user, etc.)" + ) + reported_entity_id = models.PositiveIntegerField(help_text="ID of the entity being reported") + content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, null=True, blank=True) # Report content reason = models.CharField(max_length=200, help_text="Brief reason for the report") description = models.TextField(help_text="Detailed description of the issue") - evidence_urls = models.JSONField( - default=list, blank=True, help_text="URLs to evidence (screenshots, etc.)") + evidence_urls = models.JSONField(default=list, blank=True, help_text="URLs to evidence (screenshots, etc.)") # Users involved reported_by = models.ForeignKey( settings.AUTH_USER_MODEL, on_delete=models.CASCADE, - related_name='moderation_reports_made', + related_name="moderation_reports_made", help_text="User who made this report", ) assigned_moderator = models.ForeignKey( @@ -451,40 +415,32 @@ class ModerationReport(StateMachineMixin, TrackedModel): on_delete=models.SET_NULL, null=True, blank=True, - related_name='assigned_moderation_reports', + related_name="assigned_moderation_reports", help_text="Moderator assigned to handle this report", ) # Resolution - resolution_action = models.CharField( - max_length=100, blank=True, help_text="Action taken to resolve") - resolution_notes = models.TextField( - blank=True, help_text="Notes about the resolution") - resolved_at = models.DateTimeField( - null=True, blank=True, help_text="When this report was resolved" - ) + resolution_action = models.CharField(max_length=100, blank=True, help_text="Action taken to resolve") + resolution_notes = models.TextField(blank=True, help_text="Notes about the resolution") + resolved_at = models.DateTimeField(null=True, blank=True, help_text="When this report was resolved") # Timestamps - created_at = models.DateTimeField( - auto_now_add=True, help_text="When this report was created" - ) - updated_at = models.DateTimeField( - auto_now=True, help_text="When this report was last updated" - ) + created_at = models.DateTimeField(auto_now_add=True, help_text="When this report was created") + updated_at = models.DateTimeField(auto_now=True, help_text="When this report was last updated") class Meta(TrackedModel.Meta): verbose_name = "Moderation Report" verbose_name_plural = "Moderation Reports" - ordering = ['-created_at'] + ordering = ["-created_at"] indexes = [ - models.Index(fields=['status', 'priority']), - models.Index(fields=['reported_by']), - models.Index(fields=['assigned_moderator']), - models.Index(fields=['created_at']), + models.Index(fields=["status", "priority"]), + models.Index(fields=["reported_by"]), + models.Index(fields=["assigned_moderator"]), + models.Index(fields=["created_at"]), ] def __str__(self): - return f"{self.get_report_type_display()} report by {self.reported_by.username}" # type: ignore + return f"{self.get_report_type_display()} report by {self.reported_by.username}" # type: ignore @pghistory.track() @@ -499,37 +455,20 @@ class ModerationQueue(StateMachineMixin, TrackedModel): state_field_name = "status" # Queue item details - item_type = RichChoiceField( - choice_group="queue_item_types", - domain="moderation", - max_length=50 - ) + item_type = RichChoiceField(choice_group="queue_item_types", domain="moderation", max_length=50) status = RichFSMField( - choice_group="moderation_queue_statuses", - domain="moderation", - max_length=20, - default='PENDING' - ) - priority = RichChoiceField( - choice_group="priority_levels", - domain="moderation", - max_length=10, - default='MEDIUM' + choice_group="moderation_queue_statuses", domain="moderation", max_length=20, default="PENDING" ) + priority = RichChoiceField(choice_group="priority_levels", domain="moderation", max_length=10, default="MEDIUM") title = models.CharField(max_length=200, help_text="Brief title for the queue item") - description = models.TextField( - help_text="Detailed description of what needs to be done") + description = models.TextField(help_text="Detailed description of what needs to be done") # What entity this relates to - entity_type = models.CharField( - max_length=50, blank=True, help_text="Type of entity (park, ride, user, etc.)") - entity_id = models.PositiveIntegerField( - null=True, blank=True, help_text="ID of the related entity") - entity_preview = models.JSONField( - default=dict, blank=True, help_text="Preview data for the entity") - content_type = models.ForeignKey( - ContentType, on_delete=models.CASCADE, null=True, blank=True) + entity_type = models.CharField(max_length=50, blank=True, help_text="Type of entity (park, ride, user, etc.)") + entity_id = models.PositiveIntegerField(null=True, blank=True, help_text="ID of the related entity") + entity_preview = models.JSONField(default=dict, blank=True, help_text="Preview data for the entity") + content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, null=True, blank=True) # Assignment and timing assigned_to = models.ForeignKey( @@ -537,14 +476,11 @@ class ModerationQueue(StateMachineMixin, TrackedModel): on_delete=models.SET_NULL, null=True, blank=True, - related_name='assigned_queue_items', + related_name="assigned_queue_items", help_text="Moderator assigned to this item", ) - assigned_at = models.DateTimeField( - null=True, blank=True, help_text="When this item was assigned" - ) - estimated_review_time = models.PositiveIntegerField( - default=30, help_text="Estimated time in minutes") + assigned_at = models.DateTimeField(null=True, blank=True, help_text="When this item was assigned") + estimated_review_time = models.PositiveIntegerField(default=30, help_text="Estimated time in minutes") # Metadata flagged_by = models.ForeignKey( @@ -552,11 +488,10 @@ class ModerationQueue(StateMachineMixin, TrackedModel): on_delete=models.SET_NULL, null=True, blank=True, - related_name='flagged_queue_items', + related_name="flagged_queue_items", help_text="User who flagged this item", ) - tags = models.JSONField(default=list, blank=True, - help_text="Tags for categorization") + tags = models.JSONField(default=list, blank=True, help_text="Tags for categorization") # Related objects related_report = models.ForeignKey( @@ -564,30 +499,26 @@ class ModerationQueue(StateMachineMixin, TrackedModel): on_delete=models.CASCADE, null=True, blank=True, - related_name='queue_items', + related_name="queue_items", help_text="Related moderation report", ) # Timestamps - created_at = models.DateTimeField( - auto_now_add=True, help_text="When this item was created" - ) - updated_at = models.DateTimeField( - auto_now=True, help_text="When this item was last updated" - ) + created_at = models.DateTimeField(auto_now_add=True, help_text="When this item was created") + updated_at = models.DateTimeField(auto_now=True, help_text="When this item was last updated") class Meta(TrackedModel.Meta): verbose_name = "Moderation Queue Item" verbose_name_plural = "Moderation Queue Items" - ordering = ['priority', 'created_at'] + ordering = ["priority", "created_at"] indexes = [ - models.Index(fields=['status', 'priority']), - models.Index(fields=['assigned_to']), - models.Index(fields=['created_at']), + models.Index(fields=["status", "priority"]), + models.Index(fields=["assigned_to"]), + models.Index(fields=["created_at"]), ] def __str__(self): - return f"{self.get_item_type_display()}: {self.title}" # type: ignore + return f"{self.get_item_type_display()}: {self.title}" # type: ignore @pghistory.track() @@ -600,36 +531,28 @@ class ModerationAction(TrackedModel): """ # Action details - action_type = RichChoiceField( - choice_group="moderation_action_types", - domain="moderation", - max_length=50 - ) + action_type = RichChoiceField(choice_group="moderation_action_types", domain="moderation", max_length=50) reason = models.CharField(max_length=200, help_text="Brief reason for the action") details = models.TextField(help_text="Detailed explanation of the action") # Duration (for temporary actions) duration_hours = models.PositiveIntegerField( - null=True, - blank=True, - help_text="Duration in hours for temporary actions" + null=True, blank=True, help_text="Duration in hours for temporary actions" ) - expires_at = models.DateTimeField( - null=True, blank=True, help_text="When this action expires") - is_active = models.BooleanField( - default=True, help_text="Whether this action is currently active") + expires_at = models.DateTimeField(null=True, blank=True, help_text="When this action expires") + is_active = models.BooleanField(default=True, help_text="Whether this action is currently active") # Users involved moderator = models.ForeignKey( settings.AUTH_USER_MODEL, on_delete=models.CASCADE, - related_name='moderation_actions_taken', + related_name="moderation_actions_taken", help_text="Moderator who took this action", ) target_user = models.ForeignKey( settings.AUTH_USER_MODEL, on_delete=models.CASCADE, - related_name='moderation_actions_received', + related_name="moderation_actions_received", help_text="User this action was taken against", ) @@ -639,31 +562,27 @@ class ModerationAction(TrackedModel): on_delete=models.SET_NULL, null=True, blank=True, - related_name='actions_taken', + related_name="actions_taken", help_text="Related moderation report", ) # Timestamps - created_at = models.DateTimeField( - auto_now_add=True, help_text="When this action was created" - ) - updated_at = models.DateTimeField( - auto_now=True, help_text="When this action was last updated" - ) + created_at = models.DateTimeField(auto_now_add=True, help_text="When this action was created") + updated_at = models.DateTimeField(auto_now=True, help_text="When this action was last updated") class Meta(TrackedModel.Meta): verbose_name = "Moderation Action" verbose_name_plural = "Moderation Actions" - ordering = ['-created_at'] + ordering = ["-created_at"] indexes = [ - models.Index(fields=['target_user', 'is_active']), - models.Index(fields=['moderator']), - models.Index(fields=['expires_at']), - models.Index(fields=['created_at']), + models.Index(fields=["target_user", "is_active"]), + models.Index(fields=["moderator"]), + models.Index(fields=["expires_at"]), + models.Index(fields=["created_at"]), ] def __str__(self): - return f"{self.get_action_type_display()} against {self.target_user.username} by {self.moderator.username}" # type: ignore + return f"{self.get_action_type_display()} against {self.target_user.username} by {self.moderator.username}" # type: ignore def save(self, *args, **kwargs): # Set expiration time if duration is provided @@ -684,85 +603,56 @@ class BulkOperation(StateMachineMixin, TrackedModel): state_field_name = "status" # Operation details - operation_type = RichChoiceField( - choice_group="bulk_operation_types", - domain="moderation", - max_length=50 - ) - status = RichFSMField( - choice_group="bulk_operation_statuses", - domain="moderation", - max_length=20, - default='PENDING' - ) - priority = RichChoiceField( - choice_group="priority_levels", - domain="moderation", - max_length=10, - default='MEDIUM' - ) + operation_type = RichChoiceField(choice_group="bulk_operation_types", domain="moderation", max_length=50) + status = RichFSMField(choice_group="bulk_operation_statuses", domain="moderation", max_length=20, default="PENDING") + priority = RichChoiceField(choice_group="priority_levels", domain="moderation", max_length=10, default="MEDIUM") description = models.TextField(help_text="Description of what this operation does") # Operation parameters and results - parameters = models.JSONField( - default=dict, help_text="Parameters for the operation") - results = models.JSONField(default=dict, blank=True, - help_text="Results and output from the operation") + parameters = models.JSONField(default=dict, help_text="Parameters for the operation") + results = models.JSONField(default=dict, blank=True, help_text="Results and output from the operation") # Progress tracking - total_items = models.PositiveIntegerField( - default=0, help_text="Total number of items to process") - processed_items = models.PositiveIntegerField( - default=0, help_text="Number of items processed") - failed_items = models.PositiveIntegerField( - default=0, help_text="Number of items that failed") + total_items = models.PositiveIntegerField(default=0, help_text="Total number of items to process") + processed_items = models.PositiveIntegerField(default=0, help_text="Number of items processed") + failed_items = models.PositiveIntegerField(default=0, help_text="Number of items that failed") # Timing estimated_duration_minutes = models.PositiveIntegerField( - null=True, - blank=True, - help_text="Estimated duration in minutes" + null=True, blank=True, help_text="Estimated duration in minutes" ) - schedule_for = models.DateTimeField( - null=True, blank=True, help_text="When to run this operation") + schedule_for = models.DateTimeField(null=True, blank=True, help_text="When to run this operation") # Control - can_cancel = models.BooleanField( - default=True, help_text="Whether this operation can be cancelled") + can_cancel = models.BooleanField(default=True, help_text="Whether this operation can be cancelled") # User who created the operation created_by = models.ForeignKey( settings.AUTH_USER_MODEL, on_delete=models.CASCADE, - related_name='bulk_operations_created', + related_name="bulk_operations_created", help_text="User who created this operation", ) # Timestamps created_at = models.DateTimeField(auto_now_add=True) - started_at = models.DateTimeField( - null=True, blank=True, help_text="When this operation started" - ) - completed_at = models.DateTimeField( - null=True, blank=True, help_text="When this operation completed" - ) - updated_at = models.DateTimeField( - auto_now=True, help_text="When this operation was last updated" - ) + started_at = models.DateTimeField(null=True, blank=True, help_text="When this operation started") + completed_at = models.DateTimeField(null=True, blank=True, help_text="When this operation completed") + updated_at = models.DateTimeField(auto_now=True, help_text="When this operation was last updated") class Meta(TrackedModel.Meta): verbose_name = "Bulk Operation" verbose_name_plural = "Bulk Operations" - ordering = ['-created_at'] + ordering = ["-created_at"] indexes = [ - models.Index(fields=['status', 'priority']), - models.Index(fields=['created_by']), - models.Index(fields=['schedule_for']), - models.Index(fields=['created_at']), + models.Index(fields=["status", "priority"]), + models.Index(fields=["created_by"]), + models.Index(fields=["schedule_for"]), + models.Index(fields=["created_at"]), ] def __str__(self): - return f"{self.get_operation_type_display()}: {self.description[:50]}" # type: ignore + return f"{self.get_operation_type_display()}: {self.description[:50]}" # type: ignore @property def progress_percentage(self): @@ -792,28 +682,21 @@ class PhotoSubmission(StateMachineMixin, TrackedModel): on_delete=models.CASCADE, help_text="Type of object this photo is for", ) - object_id = models.PositiveIntegerField( - help_text="ID of object this photo is for" - ) + object_id = models.PositiveIntegerField(help_text="ID of object this photo is for") content_object = GenericForeignKey("content_type", "object_id") # The photo itself photo = models.ForeignKey( - 'django_cloudflareimages_toolkit.CloudflareImage', + "django_cloudflareimages_toolkit.CloudflareImage", on_delete=models.CASCADE, - help_text="Photo submission stored on Cloudflare Images" + help_text="Photo submission stored on Cloudflare Images", ) caption = models.CharField(max_length=255, blank=True, help_text="Photo caption") - date_taken = models.DateField( - null=True, blank=True, help_text="Date the photo was taken" - ) + date_taken = models.DateField(null=True, blank=True, help_text="Date the photo was taken") # Metadata status = RichFSMField( - choice_group="photo_submission_statuses", - domain="moderation", - max_length=20, - default="PENDING" + choice_group="photo_submission_statuses", domain="moderation", max_length=20, default="PENDING" ) created_at = models.DateTimeField(auto_now_add=True) @@ -826,9 +709,7 @@ class PhotoSubmission(StateMachineMixin, TrackedModel): related_name="handled_photos", help_text="Moderator who handled this submission", ) - handled_at = models.DateTimeField( - null=True, blank=True, help_text="When this submission was handled" - ) + handled_at = models.DateTimeField(null=True, blank=True, help_text="When this submission was handled") notes = models.TextField( blank=True, help_text="Notes from the moderator about this photo submission", @@ -843,9 +724,7 @@ class PhotoSubmission(StateMachineMixin, TrackedModel): related_name="claimed_photo_submissions", help_text="Moderator who has claimed this submission for review", ) - claimed_at = models.DateTimeField( - null=True, blank=True, help_text="When this submission was claimed" - ) + claimed_at = models.DateTimeField(null=True, blank=True, help_text="When this submission was claimed") class Meta(TrackedModel.Meta): verbose_name = "Photo Submission" @@ -873,9 +752,7 @@ class PhotoSubmission(StateMachineMixin, TrackedModel): from django.core.exceptions import ValidationError if self.status != "PENDING": - raise ValidationError( - f"Cannot claim submission: current status is {self.status}, expected PENDING" - ) + raise ValidationError(f"Cannot claim submission: current status is {self.status}, expected PENDING") self.transition_to_claimed(user=user) self.claimed_by = user @@ -896,9 +773,7 @@ class PhotoSubmission(StateMachineMixin, TrackedModel): from django.core.exceptions import ValidationError if self.status != "CLAIMED": - raise ValidationError( - f"Cannot unclaim submission: current status is {self.status}, expected CLAIMED" - ) + raise ValidationError(f"Cannot unclaim submission: current status is {self.status}, expected CLAIMED") # Set status directly (not via FSM transition to avoid cycle) # This is intentional - the unclaim action is a special "rollback" operation diff --git a/backend/apps/moderation/permissions.py b/backend/apps/moderation/permissions.py index 0f83a6fa..23db0bb1 100644 --- a/backend/apps/moderation/permissions.py +++ b/backend/apps/moderation/permissions.py @@ -88,7 +88,7 @@ class PermissionGuardAdapter: return False # Check object permission if available - if hasattr(permission, "has_object_permission"): + if hasattr(permission, "has_object_permission"): # noqa: SIM102 if not permission.has_object_permission(mock_request, None, instance): self._last_error_code = "OBJECT_PERMISSION_DENIED" return False @@ -318,9 +318,7 @@ class CanAssignModerationTasks(GuardMixin, permissions.BasePermission): # Moderators can only assign to themselves if user_role == "MODERATOR": # Check if they're trying to assign to themselves - assignee_id = request.data.get("moderator_id") or request.data.get( - "assigned_to" - ) + assignee_id = request.data.get("moderator_id") or request.data.get("assigned_to") if assignee_id: return str(assignee_id) == str(request.user.id) return True @@ -362,7 +360,7 @@ class CanPerformBulkOperations(GuardMixin, permissions.BasePermission): # Add any admin-specific restrictions for bulk operations here # For example, admins might not be able to perform certain destructive operations operation_type = getattr(obj, "operation_type", None) - if operation_type in ["DELETE_USERS", "PURGE_DATA"]: + if operation_type in ["DELETE_USERS", "PURGE_DATA"]: # noqa: SIM103 return False # Only superusers can perform these operations return True diff --git a/backend/apps/moderation/selectors.py b/backend/apps/moderation/selectors.py index 18cb08e3..cab642b5 100644 --- a/backend/apps/moderation/selectors.py +++ b/backend/apps/moderation/selectors.py @@ -14,9 +14,7 @@ from django.utils import timezone from .models import EditSubmission -def pending_submissions_for_review( - *, content_type: str | None = None, limit: int = 50 -) -> QuerySet[EditSubmission]: +def pending_submissions_for_review(*, content_type: str | None = None, limit: int = 50) -> QuerySet[EditSubmission]: """ Get pending submissions that need moderation review. @@ -39,9 +37,7 @@ def pending_submissions_for_review( return queryset.order_by("created_at")[:limit] -def submissions_by_user( - *, user_id: int, status: str | None = None -) -> QuerySet[EditSubmission]: +def submissions_by_user(*, user_id: int, status: str | None = None) -> QuerySet[EditSubmission]: """ Get submissions created by a specific user. @@ -52,9 +48,7 @@ def submissions_by_user( Returns: QuerySet of user's submissions """ - queryset = EditSubmission.objects.filter(user_id=user_id).select_related( - "content_type", "handled_by" - ) + queryset = EditSubmission.objects.filter(user_id=user_id).select_related("content_type", "handled_by") if status: queryset = queryset.filter(status=status) @@ -62,9 +56,7 @@ def submissions_by_user( return queryset.order_by("-created_at") -def submissions_handled_by_moderator( - *, moderator_id: int, days: int = 30 -) -> QuerySet[EditSubmission]: +def submissions_handled_by_moderator(*, moderator_id: int, days: int = 30) -> QuerySet[EditSubmission]: """ Get submissions handled by a specific moderator in the last N days. @@ -78,9 +70,7 @@ def submissions_handled_by_moderator( cutoff_date = timezone.now() - timedelta(days=days) return ( - EditSubmission.objects.filter( - handled_by_id=moderator_id, handled_at__gte=cutoff_date - ) + EditSubmission.objects.filter(handled_by_id=moderator_id, handled_at__gte=cutoff_date) .select_related("user", "content_type") .order_by("-handled_at") ) @@ -105,9 +95,7 @@ def recent_submissions(*, days: int = 7) -> QuerySet[EditSubmission]: ) -def submissions_by_content_type( - *, content_type: str, status: str | None = None -) -> QuerySet[EditSubmission]: +def submissions_by_content_type(*, content_type: str, status: str | None = None) -> QuerySet[EditSubmission]: """ Get submissions for a specific content type. @@ -118,9 +106,9 @@ def submissions_by_content_type( Returns: QuerySet of submissions for the content type """ - queryset = EditSubmission.objects.filter( - content_type__model=content_type.lower() - ).select_related("user", "handled_by") + queryset = EditSubmission.objects.filter(content_type__model=content_type.lower()).select_related( + "user", "handled_by" + ) if status: queryset = queryset.filter(status=status) @@ -136,12 +124,8 @@ def moderation_queue_summary() -> dict[str, Any]: Dictionary containing queue statistics """ pending_count = EditSubmission.objects.filter(status="PENDING").count() - approved_today = EditSubmission.objects.filter( - status="APPROVED", handled_at__date=timezone.now().date() - ).count() - rejected_today = EditSubmission.objects.filter( - status="REJECTED", handled_at__date=timezone.now().date() - ).count() + approved_today = EditSubmission.objects.filter(status="APPROVED", handled_at__date=timezone.now().date()).count() + rejected_today = EditSubmission.objects.filter(status="REJECTED", handled_at__date=timezone.now().date()).count() # Submissions by content type submissions_by_type = ( @@ -159,9 +143,7 @@ def moderation_queue_summary() -> dict[str, Any]: } -def moderation_statistics_summary( - *, days: int = 30, moderator: User | None = None -) -> dict[str, Any]: +def moderation_statistics_summary(*, days: int = 30, moderator: User | None = None) -> dict[str, Any]: """ Get comprehensive moderation statistics for a time period. @@ -189,8 +171,7 @@ def moderation_statistics_summary( handled_queryset.exclude(handled_at__isnull=True) .annotate( response_hours=ExpressionWrapper( - Extract(F('handled_at') - F('created_at'), 'epoch') / 3600.0, - output_field=FloatField() + Extract(F("handled_at") - F("created_at"), "epoch") / 3600.0, output_field=FloatField() ) ) .values_list("response_hours", flat=True) diff --git a/backend/apps/moderation/serializers.py b/backend/apps/moderation/serializers.py index 8f97a425..fb49a627 100644 --- a/backend/apps/moderation/serializers.py +++ b/backend/apps/moderation/serializers.py @@ -68,9 +68,7 @@ class EditSubmissionSerializer(serializers.ModelSerializer): submitted_by = UserBasicSerializer(source="user", read_only=True) claimed_by = UserBasicSerializer(read_only=True) - content_type_name = serializers.CharField( - source="content_type.model", read_only=True - ) + content_type_name = serializers.CharField(source="content_type.model", read_only=True) # UI Metadata fields for Nuxt rendering status_color = serializers.SerializerMethodField() @@ -117,10 +115,10 @@ class EditSubmissionSerializer(serializers.ModelSerializer): def get_status_color(self, obj) -> str: """Return hex color based on status for UI badges.""" colors = { - "PENDING": "#f59e0b", # Amber - "CLAIMED": "#3b82f6", # Blue - "APPROVED": "#10b981", # Emerald - "REJECTED": "#ef4444", # Red + "PENDING": "#f59e0b", # Amber + "CLAIMED": "#3b82f6", # Blue + "APPROVED": "#10b981", # Emerald + "REJECTED": "#ef4444", # Red "ESCALATED": "#8b5cf6", # Violet } return colors.get(obj.status, "#6b7280") # Default gray @@ -154,15 +152,9 @@ class EditSubmissionSerializer(serializers.ModelSerializer): class EditSubmissionListSerializer(serializers.ModelSerializer): """Optimized serializer for EditSubmission lists.""" - submitted_by_username = serializers.CharField( - source="user.username", read_only=True - ) - claimed_by_username = serializers.CharField( - source="claimed_by.username", read_only=True, allow_null=True - ) - content_type_name = serializers.CharField( - source="content_type.model", read_only=True - ) + submitted_by_username = serializers.CharField(source="user.username", read_only=True) + claimed_by_username = serializers.CharField(source="claimed_by.username", read_only=True, allow_null=True) + content_type_name = serializers.CharField(source="content_type.model", read_only=True) status_color = serializers.SerializerMethodField() status_icon = serializers.SerializerMethodField() @@ -218,13 +210,9 @@ class ModerationReportSerializer(serializers.ModelSerializer): # Computed fields is_overdue = serializers.SerializerMethodField() time_since_created = serializers.SerializerMethodField() - priority_display = serializers.CharField( - source="get_priority_display", read_only=True - ) + priority_display = serializers.CharField(source="get_priority_display", read_only=True) status_display = serializers.CharField(source="get_status_display", read_only=True) - report_type_display = serializers.CharField( - source="get_report_type_display", read_only=True - ) + report_type_display = serializers.CharField(source="get_report_type_display", read_only=True) class Meta: model = ModerationReport @@ -318,17 +306,13 @@ class CreateModerationReportSerializer(serializers.ModelSerializer): valid_entity_types = ["park", "ride", "review", "photo", "user", "comment"] if attrs["reported_entity_type"] not in valid_entity_types: raise serializers.ValidationError( - { - "reported_entity_type": f'Must be one of: {", ".join(valid_entity_types)}' - } + {"reported_entity_type": f'Must be one of: {", ".join(valid_entity_types)}'} ) # Validate evidence URLs evidence_urls = attrs.get("evidence_urls", []) if not isinstance(evidence_urls, list): - raise serializers.ValidationError( - {"evidence_urls": "Must be a list of URLs"} - ) + raise serializers.ValidationError({"evidence_urls": "Must be a list of URLs"}) return attrs @@ -351,9 +335,7 @@ class CreateModerationReportSerializer(serializers.ModelSerializer): if entity_type in app_label_map: try: - content_type = ContentType.objects.get( - app_label=app_label_map[entity_type], model=entity_type - ) + content_type = ContentType.objects.get(app_label=app_label_map[entity_type], model=entity_type) validated_data["content_type"] = content_type except ContentType.DoesNotExist: pass @@ -377,9 +359,7 @@ class UpdateModerationReportSerializer(serializers.ModelSerializer): def validate_status(self, value): """Validate status transitions.""" if self.instance and self.instance.status == "RESOLVED" and value != "RESOLVED": - raise serializers.ValidationError( - "Cannot change status of resolved report" - ) + raise serializers.ValidationError("Cannot change status of resolved report") return value def update(self, instance, validated_data): @@ -462,13 +442,9 @@ class ModerationQueueSerializer(serializers.ModelSerializer): def get_estimated_completion(self, obj) -> str: """Estimated completion time.""" if obj.assigned_at: - completion_time = obj.assigned_at + timedelta( - minutes=obj.estimated_review_time - ) + completion_time = obj.assigned_at + timedelta(minutes=obj.estimated_review_time) else: - completion_time = timezone.now() + timedelta( - minutes=obj.estimated_review_time - ) + completion_time = timezone.now() + timedelta(minutes=obj.estimated_review_time) return completion_time.isoformat() @@ -484,12 +460,10 @@ class AssignQueueItemSerializer(serializers.Serializer): user = User.objects.get(id=value) user_role = getattr(user, "role", "USER") if user_role not in ["MODERATOR", "ADMIN", "SUPERUSER"]: - raise serializers.ValidationError( - "User must be a moderator, admin, or superuser" - ) + raise serializers.ValidationError("User must be a moderator, admin, or superuser") return value except User.DoesNotExist: - raise serializers.ValidationError("Moderator not found") + raise serializers.ValidationError("Moderator not found") from None class CompleteQueueItemSerializer(serializers.Serializer): @@ -514,9 +488,7 @@ class CompleteQueueItemSerializer(serializers.Serializer): # Require notes for certain actions if action in ["USER_WARNING", "USER_SUSPENDED", "USER_BANNED"] and not notes: - raise serializers.ValidationError( - {"notes": f"Notes are required for action: {action}"} - ) + raise serializers.ValidationError({"notes": f"Notes are required for action: {action}"}) return attrs @@ -536,9 +508,7 @@ class ModerationActionSerializer(serializers.ModelSerializer): # Computed fields is_expired = serializers.SerializerMethodField() time_remaining = serializers.SerializerMethodField() - action_type_display = serializers.CharField( - source="get_action_type_display", read_only=True - ) + action_type_display = serializers.CharField(source="get_action_type_display", read_only=True) class Meta: model = ModerationAction @@ -620,7 +590,7 @@ class CreateModerationActionSerializer(serializers.ModelSerializer): User.objects.get(id=value) return value except User.DoesNotExist: - raise serializers.ValidationError("Target user not found") + raise serializers.ValidationError("Target user not found") from None def validate_related_report_id(self, value): """Validate related report exists.""" @@ -629,7 +599,7 @@ class CreateModerationActionSerializer(serializers.ModelSerializer): ModerationReport.objects.get(id=value) return value except ModerationReport.DoesNotExist: - raise serializers.ValidationError("Related report not found") + raise serializers.ValidationError("Related report not found") from None return value def validate(self, attrs): @@ -640,17 +610,11 @@ class CreateModerationActionSerializer(serializers.ModelSerializer): # Validate duration for temporary actions temporary_actions = ["USER_SUSPENSION", "CONTENT_RESTRICTION"] if action_type in temporary_actions and not duration_hours: - raise serializers.ValidationError( - {"duration_hours": f"Duration is required for {action_type}"} - ) + raise serializers.ValidationError({"duration_hours": f"Duration is required for {action_type}"}) # Validate duration range - if duration_hours and ( - duration_hours < 1 or duration_hours > 8760 - ): # 1 hour to 1 year - raise serializers.ValidationError( - {"duration_hours": "Duration must be between 1 and 8760 hours (1 year)"} - ) + if duration_hours and (duration_hours < 1 or duration_hours > 8760): # 1 hour to 1 year + raise serializers.ValidationError({"duration_hours": "Duration must be between 1 and 8760 hours (1 year)"}) return attrs @@ -668,9 +632,7 @@ class CreateModerationActionSerializer(serializers.ModelSerializer): # Set expiration time for temporary actions if validated_data.get("duration_hours"): - validated_data["expires_at"] = timezone.now() + timedelta( - hours=validated_data["duration_hours"] - ) + validated_data["expires_at"] = timezone.now() + timedelta(hours=validated_data["duration_hours"]) return super().create(validated_data) @@ -688,9 +650,7 @@ class BulkOperationSerializer(serializers.ModelSerializer): # Computed fields progress_percentage = serializers.SerializerMethodField() estimated_completion = serializers.SerializerMethodField() - operation_type_display = serializers.CharField( - source="get_operation_type_display", read_only=True - ) + operation_type_display = serializers.CharField(source="get_operation_type_display", read_only=True) status_display = serializers.CharField(source="get_status_display", read_only=True) class Meta: @@ -741,17 +701,13 @@ class BulkOperationSerializer(serializers.ModelSerializer): if obj.status == "COMPLETED": return obj.completed_at.isoformat() if obj.completed_at else None - if obj.status == "RUNNING" and obj.started_at: + if obj.status == "RUNNING" and obj.started_at: # noqa: SIM102 # Calculate based on current progress if obj.processed_items > 0: elapsed_minutes = (timezone.now() - obj.started_at).total_seconds() / 60 rate = obj.processed_items / elapsed_minutes remaining_items = obj.total_items - obj.processed_items - remaining_minutes = ( - remaining_items / rate - if rate > 0 - else obj.estimated_duration_minutes - ) + remaining_minutes = remaining_items / rate if rate > 0 else obj.estimated_duration_minutes completion_time = timezone.now() + timedelta(minutes=remaining_minutes) return completion_time.isoformat() @@ -759,9 +715,7 @@ class BulkOperationSerializer(serializers.ModelSerializer): if obj.schedule_for: return obj.schedule_for.isoformat() elif obj.estimated_duration_minutes: - completion_time = timezone.now() + timedelta( - minutes=obj.estimated_duration_minutes - ) + completion_time = timezone.now() + timedelta(minutes=obj.estimated_duration_minutes) return completion_time.isoformat() return None @@ -801,9 +755,7 @@ class CreateBulkOperationSerializer(serializers.ModelSerializer): if operation_type in required_params: for param in required_params[operation_type]: if param not in value: - raise serializers.ValidationError( - f'Parameter "{param}" is required for {operation_type}' - ) + raise serializers.ValidationError(f'Parameter "{param}" is required for {operation_type}') return value @@ -902,27 +854,28 @@ class UserModerationProfileSerializer(serializers.Serializer): class StateLogSerializer(serializers.ModelSerializer): """Serializer for FSM transition history.""" - user = serializers.CharField(source='by.username', read_only=True) - model = serializers.CharField(source='content_type.model', read_only=True) - from_state = serializers.CharField(source='source_state', read_only=True) - to_state = serializers.CharField(source='state', read_only=True) - reason = serializers.CharField(source='description', read_only=True) + user = serializers.CharField(source="by.username", read_only=True) + model = serializers.CharField(source="content_type.model", read_only=True) + from_state = serializers.CharField(source="source_state", read_only=True) + to_state = serializers.CharField(source="state", read_only=True) + reason = serializers.CharField(source="description", read_only=True) class Meta: from django_fsm_log.models import StateLog + model = StateLog fields = [ - 'id', - 'timestamp', - 'model', - 'object_id', - 'state', - 'from_state', - 'to_state', - 'transition', - 'user', - 'description', - 'reason', + "id", + "timestamp", + "model", + "object_id", + "state", + "from_state", + "to_state", + "transition", + "user", + "description", + "reason", ] read_only_fields = fields @@ -931,9 +884,7 @@ class PhotoSubmissionSerializer(serializers.ModelSerializer): """Serializer for PhotoSubmission.""" submitted_by = UserBasicSerializer(source="user", read_only=True) - content_type_name = serializers.CharField( - source="content_type.model", read_only=True - ) + content_type_name = serializers.CharField(source="content_type.model", read_only=True) photo_url = serializers.SerializerMethodField() # UI Metadata @@ -1012,4 +963,3 @@ class PhotoSubmissionSerializer(serializers.ModelSerializer): else: minutes = diff.seconds // 60 return f"{minutes} minutes ago" - diff --git a/backend/apps/moderation/services.py b/backend/apps/moderation/services.py index 46778b0f..bead54d6 100644 --- a/backend/apps/moderation/services.py +++ b/backend/apps/moderation/services.py @@ -19,9 +19,7 @@ class ModerationService: """Service for handling content moderation workflows.""" @staticmethod - def approve_submission( - *, submission_id: int, moderator: User, notes: str | None = None - ) -> object | None: + def approve_submission(*, submission_id: int, moderator: User, notes: str | None = None) -> object | None: """ Approve a content submission and apply changes. @@ -39,9 +37,7 @@ class ModerationService: ValueError: If submission cannot be processed """ with transaction.atomic(): - submission = EditSubmission.objects.select_for_update().get( - id=submission_id - ) + submission = EditSubmission.objects.select_for_update().get(id=submission_id) if submission.status != "PENDING": raise ValueError(f"Submission {submission_id} is not pending approval") @@ -75,9 +71,7 @@ class ModerationService: raise @staticmethod - def reject_submission( - *, submission_id: int, moderator: User, reason: str - ) -> EditSubmission: + def reject_submission(*, submission_id: int, moderator: User, reason: str) -> EditSubmission: """ Reject a content submission. @@ -94,9 +88,7 @@ class ModerationService: ValueError: If submission cannot be rejected """ with transaction.atomic(): - submission = EditSubmission.objects.select_for_update().get( - id=submission_id - ) + submission = EditSubmission.objects.select_for_update().get(id=submission_id) if submission.status != "PENDING": raise ValueError(f"Submission {submission_id} is not pending review") @@ -175,9 +167,7 @@ class ModerationService: ValueError: If submission cannot be modified """ with transaction.atomic(): - submission = EditSubmission.objects.select_for_update().get( - id=submission_id - ) + submission = EditSubmission.objects.select_for_update().get(id=submission_id) if submission.status != "PENDING": raise ValueError(f"Submission {submission_id} is not pending review") @@ -220,9 +210,7 @@ class ModerationService: return pending_submissions_for_review(content_type=content_type, limit=limit) @staticmethod - def get_submission_statistics( - *, days: int = 30, moderator: User | None = None - ) -> dict[str, Any]: + def get_submission_statistics(*, days: int = 30, moderator: User | None = None) -> dict[str, Any]: """ Get moderation statistics for a time period. @@ -248,7 +236,7 @@ class ModerationService: Returns: True if user is MODERATOR, ADMIN, or SUPERUSER """ - return user.role in ['MODERATOR', 'ADMIN', 'SUPERUSER'] + return user.role in ["MODERATOR", "ADMIN", "SUPERUSER"] @staticmethod def create_edit_submission_with_queue( @@ -297,33 +285,32 @@ class ModerationService: try: created_object = submission.approve(submitter) return { - 'submission': submission, - 'status': 'auto_approved', - 'created_object': created_object, - 'queue_item': None, - 'message': 'Submission auto-approved for moderator' + "submission": submission, + "status": "auto_approved", + "created_object": created_object, + "queue_item": None, + "message": "Submission auto-approved for moderator", } except Exception as e: return { - 'submission': submission, - 'status': 'failed', - 'created_object': None, - 'queue_item': None, - 'message': f'Auto-approval failed: {str(e)}' + "submission": submission, + "status": "failed", + "created_object": None, + "queue_item": None, + "message": f"Auto-approval failed: {str(e)}", } else: # Create queue item for regular users queue_item = ModerationService._create_queue_item_for_submission( - submission=submission, - submitter=submitter + submission=submission, submitter=submitter ) return { - 'submission': submission, - 'status': 'queued', - 'created_object': None, - 'queue_item': queue_item, - 'message': 'Submission added to moderation queue' + "submission": submission, + "status": "queued", + "created_object": None, + "queue_item": queue_item, + "message": "Submission added to moderation queue", } @staticmethod @@ -370,36 +357,33 @@ class ModerationService: try: submission.auto_approve() return { - 'submission': submission, - 'status': 'auto_approved', - 'queue_item': None, - 'message': 'Photo submission auto-approved for moderator' + "submission": submission, + "status": "auto_approved", + "queue_item": None, + "message": "Photo submission auto-approved for moderator", } except Exception as e: return { - 'submission': submission, - 'status': 'failed', - 'queue_item': None, - 'message': f'Auto-approval failed: {str(e)}' + "submission": submission, + "status": "failed", + "queue_item": None, + "message": f"Auto-approval failed: {str(e)}", } else: # Create queue item for regular users queue_item = ModerationService._create_queue_item_for_photo_submission( - submission=submission, - submitter=submitter + submission=submission, submitter=submitter ) return { - 'submission': submission, - 'status': 'queued', - 'queue_item': queue_item, - 'message': 'Photo submission added to moderation queue' + "submission": submission, + "status": "queued", + "queue_item": queue_item, + "message": "Photo submission added to moderation queue", } @staticmethod - def _create_queue_item_for_submission( - *, submission: EditSubmission, submitter: User - ) -> ModerationQueue: + def _create_queue_item_for_submission(*, submission: EditSubmission, submitter: User) -> ModerationQueue: """ Create a moderation queue item for an edit submission. @@ -417,13 +401,13 @@ class ModerationService: # Create preview data entity_preview = { - 'submission_type': submission.submission_type, - 'changes_count': len(submission.changes) if submission.changes else 0, - 'reason': submission.reason[:100] if submission.reason else "", + "submission_type": submission.submission_type, + "changes_count": len(submission.changes) if submission.changes else 0, + "reason": submission.reason[:100] if submission.reason else "", } if submission.content_object: - entity_preview['object_name'] = str(submission.content_object) + entity_preview["object_name"] = str(submission.content_object) # Determine title and description action = "creation" if submission.submission_type == "CREATE" else "edit" @@ -435,7 +419,7 @@ class ModerationService: # Create queue item queue_item = ModerationQueue( - item_type='CONTENT_REVIEW', + item_type="CONTENT_REVIEW", title=title, description=description, entity_type=entity_type, @@ -443,9 +427,9 @@ class ModerationService: entity_preview=entity_preview, content_type=content_type, flagged_by=submitter, - priority='MEDIUM', + priority="MEDIUM", estimated_review_time=15, # 15 minutes default - tags=['edit_submission', submission.submission_type.lower()], + tags=["edit_submission", submission.submission_type.lower()], ) queue_item.full_clean() @@ -454,9 +438,7 @@ class ModerationService: return queue_item @staticmethod - def _create_queue_item_for_photo_submission( - *, submission: PhotoSubmission, submitter: User - ) -> ModerationQueue: + def _create_queue_item_for_photo_submission(*, submission: PhotoSubmission, submitter: User) -> ModerationQueue: """ Create a moderation queue item for a photo submission. @@ -474,13 +456,13 @@ class ModerationService: # Create preview data entity_preview = { - 'caption': submission.caption, - 'date_taken': submission.date_taken.isoformat() if submission.date_taken else None, - 'photo_url': submission.photo.url if submission.photo else None, + "caption": submission.caption, + "date_taken": submission.date_taken.isoformat() if submission.date_taken else None, + "photo_url": submission.photo.url if submission.photo else None, } if submission.content_object: - entity_preview['object_name'] = str(submission.content_object) + entity_preview["object_name"] = str(submission.content_object) # Create title and description title = f"Photo submission for {entity_type} by {submitter.username}" @@ -490,7 +472,7 @@ class ModerationService: # Create queue item queue_item = ModerationQueue( - item_type='CONTENT_REVIEW', + item_type="CONTENT_REVIEW", title=title, description=description, entity_type=entity_type, @@ -498,9 +480,9 @@ class ModerationService: entity_preview=entity_preview, content_type=content_type, flagged_by=submitter, - priority='LOW', # Photos typically lower priority + priority="LOW", # Photos typically lower priority estimated_review_time=5, # 5 minutes default for photos - tags=['photo_submission'], + tags=["photo_submission"], ) queue_item.full_clean() @@ -525,11 +507,9 @@ class ModerationService: Dictionary with processing results """ with transaction.atomic(): - queue_item = ModerationQueue.objects.select_for_update().get( - id=queue_item_id - ) + queue_item = ModerationQueue.objects.select_for_update().get(id=queue_item_id) - if queue_item.status != 'PENDING': + if queue_item.status != "PENDING": raise ValueError(f"Queue item {queue_item_id} is not pending") # Transition queue item into an active state before processing @@ -542,7 +522,7 @@ class ModerationService: pass except AttributeError: # Fallback for environments without the generated transition method - queue_item.status = 'IN_PROGRESS' + queue_item.status = "IN_PROGRESS" moved_to_in_progress = True if moved_to_in_progress: @@ -554,116 +534,94 @@ class ModerationService: try: queue_item.transition_to_completed(user=moderator) except TransitionNotAllowed: - queue_item.status = 'COMPLETED' + queue_item.status = "COMPLETED" except AttributeError: - queue_item.status = 'COMPLETED' + queue_item.status = "COMPLETED" # Find related submission - if 'edit_submission' in queue_item.tags: + if "edit_submission" in queue_item.tags: # Find EditSubmission submissions = EditSubmission.objects.filter( user=queue_item.flagged_by, content_type=queue_item.content_type, object_id=queue_item.entity_id, - status='PENDING' - ).order_by('-created_at') + status="PENDING", + ).order_by("-created_at") if not submissions.exists(): - raise ValueError( - "No pending edit submission found for this queue item") + raise ValueError("No pending edit submission found for this queue item") submission = submissions.first() - if action == 'approve': + if action == "approve": try: created_object = submission.approve(moderator) # Use FSM transition for queue status _complete_queue_item() result = { - 'status': 'approved', - 'created_object': created_object, - 'message': 'Submission approved successfully' + "status": "approved", + "created_object": created_object, + "message": "Submission approved successfully", } except Exception as e: # Use FSM transition for queue status _complete_queue_item() - result = { - 'status': 'failed', - 'created_object': None, - 'message': f'Approval failed: {str(e)}' - } - elif action == 'reject': + result = {"status": "failed", "created_object": None, "message": f"Approval failed: {str(e)}"} + elif action == "reject": submission.reject(moderator, notes or "Rejected by moderator") # Use FSM transition for queue status _complete_queue_item() - result = { - 'status': 'rejected', - 'created_object': None, - 'message': 'Submission rejected' - } - elif action == 'escalate': + result = {"status": "rejected", "created_object": None, "message": "Submission rejected"} + elif action == "escalate": submission.escalate(moderator, notes or "Escalated for review") - queue_item.priority = 'HIGH' + queue_item.priority = "HIGH" # Keep status as PENDING for escalation - result = { - 'status': 'escalated', - 'created_object': None, - 'message': 'Submission escalated' - } + result = {"status": "escalated", "created_object": None, "message": "Submission escalated"} else: raise ValueError(f"Unknown action: {action}") - elif 'photo_submission' in queue_item.tags: + elif "photo_submission" in queue_item.tags: # Find PhotoSubmission submissions = PhotoSubmission.objects.filter( user=queue_item.flagged_by, content_type=queue_item.content_type, object_id=queue_item.entity_id, - status='PENDING' - ).order_by('-created_at') + status="PENDING", + ).order_by("-created_at") if not submissions.exists(): - raise ValueError( - "No pending photo submission found for this queue item") + raise ValueError("No pending photo submission found for this queue item") submission = submissions.first() - if action == 'approve': + if action == "approve": try: submission.approve(moderator, notes or "") # Use FSM transition for queue status _complete_queue_item() result = { - 'status': 'approved', - 'created_object': None, - 'message': 'Photo submission approved successfully' + "status": "approved", + "created_object": None, + "message": "Photo submission approved successfully", } except Exception as e: # Use FSM transition for queue status _complete_queue_item() result = { - 'status': 'failed', - 'created_object': None, - 'message': f'Photo approval failed: {str(e)}' + "status": "failed", + "created_object": None, + "message": f"Photo approval failed: {str(e)}", } - elif action == 'reject': + elif action == "reject": submission.reject(moderator, notes or "Rejected by moderator") # Use FSM transition for queue status _complete_queue_item() - result = { - 'status': 'rejected', - 'created_object': None, - 'message': 'Photo submission rejected' - } - elif action == 'escalate': + result = {"status": "rejected", "created_object": None, "message": "Photo submission rejected"} + elif action == "escalate": submission.escalate(moderator, notes or "Escalated for review") - queue_item.priority = 'HIGH' + queue_item.priority = "HIGH" # Keep status as PENDING for escalation - result = { - 'status': 'escalated', - 'created_object': None, - 'message': 'Photo submission escalated' - } + result = {"status": "escalated", "created_object": None, "message": "Photo submission escalated"} else: raise ValueError(f"Unknown action: {action}") else: @@ -678,5 +636,5 @@ class ModerationService: queue_item.full_clean() queue_item.save() - result['queue_item'] = queue_item + result["queue_item"] = queue_item return result diff --git a/backend/apps/moderation/signals.py b/backend/apps/moderation/signals.py index 91dddb22..5e8bb60b 100644 --- a/backend/apps/moderation/signals.py +++ b/backend/apps/moderation/signals.py @@ -48,12 +48,10 @@ def handle_submission_claimed(instance, source, target, user, context=None, **kw user: The user who claimed. context: Optional TransitionContext. """ - if target != 'CLAIMED': + if target != "CLAIMED": return - logger.info( - f"Submission {instance.pk} claimed by {user.username if user else 'system'}" - ) + logger.info(f"Submission {instance.pk} claimed by {user.username if user else 'system'}") # Broadcast for real-time dashboard updates _broadcast_submission_status_change(instance, source, target, user) @@ -72,12 +70,10 @@ def handle_submission_unclaimed(instance, source, target, user, context=None, ** user: The user who unclaimed. context: Optional TransitionContext. """ - if source != 'CLAIMED' or target != 'PENDING': + if source != "CLAIMED" or target != "PENDING": return - logger.info( - f"Submission {instance.pk} unclaimed by {user.username if user else 'system'}" - ) + logger.info(f"Submission {instance.pk} unclaimed by {user.username if user else 'system'}") # Broadcast for real-time dashboard updates _broadcast_submission_status_change(instance, source, target, user) @@ -96,25 +92,21 @@ def handle_submission_approved(instance, source, target, user, context=None, **k user: The user who approved. context: Optional TransitionContext. """ - if target != 'APPROVED': + if target != "APPROVED": return - logger.info( - f"Submission {instance.pk} approved by {user if user else 'system'}" - ) + logger.info(f"Submission {instance.pk} approved by {user if user else 'system'}") # Trigger notification (handled by NotificationCallback) # Invalidate cache (handled by CacheInvalidationCallback) # Apply the submission changes if applicable - if hasattr(instance, 'apply_changes'): + if hasattr(instance, "apply_changes"): try: instance.apply_changes() logger.info(f"Applied changes for submission {instance.pk}") except Exception as e: - logger.exception( - f"Failed to apply changes for submission {instance.pk}: {e}" - ) + logger.exception(f"Failed to apply changes for submission {instance.pk}: {e}") def handle_submission_rejected(instance, source, target, user, context=None, **kwargs): @@ -130,13 +122,12 @@ def handle_submission_rejected(instance, source, target, user, context=None, **k user: The user who rejected. context: Optional TransitionContext. """ - if target != 'REJECTED': + if target != "REJECTED": return - reason = context.extra_data.get('reason', '') if context else '' + reason = context.extra_data.get("reason", "") if context else "" logger.info( - f"Submission {instance.pk} rejected by {user if user else 'system'}" - f"{f': {reason}' if reason else ''}" + f"Submission {instance.pk} rejected by {user if user else 'system'}" f"{f': {reason}' if reason else ''}" ) @@ -153,13 +144,12 @@ def handle_submission_escalated(instance, source, target, user, context=None, ** user: The user who escalated. context: Optional TransitionContext. """ - if target != 'ESCALATED': + if target != "ESCALATED": return - reason = context.extra_data.get('reason', '') if context else '' + reason = context.extra_data.get("reason", "") if context else "" logger.info( - f"Submission {instance.pk} escalated by {user if user else 'system'}" - f"{f': {reason}' if reason else ''}" + f"Submission {instance.pk} escalated by {user if user else 'system'}" f"{f': {reason}' if reason else ''}" ) # Create escalation task if task system is available @@ -179,15 +169,13 @@ def handle_report_resolved(instance, source, target, user, context=None, **kwarg user: The user who resolved. context: Optional TransitionContext. """ - if target != 'RESOLVED': + if target != "RESOLVED": return - logger.info( - f"ModerationReport {instance.pk} resolved by {user if user else 'system'}" - ) + logger.info(f"ModerationReport {instance.pk} resolved by {user if user else 'system'}") # Update related queue items - _update_related_queue_items(instance, 'COMPLETED') + _update_related_queue_items(instance, "COMPLETED") def handle_queue_completed(instance, source, target, user, context=None, **kwargs): @@ -203,12 +191,10 @@ def handle_queue_completed(instance, source, target, user, context=None, **kwarg user: The user who completed. context: Optional TransitionContext. """ - if target != 'COMPLETED': + if target != "COMPLETED": return - logger.info( - f"ModerationQueue {instance.pk} completed by {user if user else 'system'}" - ) + logger.info(f"ModerationQueue {instance.pk} completed by {user if user else 'system'}") # Update moderation statistics _update_moderation_stats(instance, user) @@ -227,18 +213,17 @@ def handle_bulk_operation_status(instance, source, target, user, context=None, * user: The user who initiated the change. context: Optional TransitionContext. """ - logger.info( - f"BulkOperation {instance.pk} transitioned: {source} → {target}" - ) + logger.info(f"BulkOperation {instance.pk} transitioned: {source} → {target}") - if target == 'COMPLETED': + if target == "COMPLETED": _finalize_bulk_operation(instance, success=True) - elif target == 'FAILED': + elif target == "FAILED": _finalize_bulk_operation(instance, success=False) # Helper functions + def _create_escalation_task(instance, user, reason): """Create an escalation task for admin review.""" try: @@ -247,7 +232,7 @@ def _create_escalation_task(instance, user, reason): # Create a queue item for the escalated submission ModerationQueue.objects.create( content_object=instance, - priority='HIGH', + priority="HIGH", reason=f"Escalated: {reason}" if reason else "Escalated for review", created_by=user, ) @@ -287,10 +272,10 @@ def _update_moderation_stats(instance, user): try: # Update user's moderation count if they have a profile - profile = getattr(user, 'profile', None) - if profile and hasattr(profile, 'moderation_count'): + profile = getattr(user, "profile", None) + if profile and hasattr(profile, "moderation_count"): profile.moderation_count += 1 - profile.save(update_fields=['moderation_count']) + profile.save(update_fields=["moderation_count"]) logger.debug(f"Updated moderation count for {user}") except Exception as e: logger.warning(f"Failed to update moderation stats: {e}") @@ -302,7 +287,7 @@ def _finalize_bulk_operation(instance, success): from django.utils import timezone instance.completed_at = timezone.now() - instance.save(update_fields=['completed_at']) + instance.save(update_fields=["completed_at"]) if success: logger.info( @@ -312,8 +297,7 @@ def _finalize_bulk_operation(instance, success): ) else: logger.warning( - f"BulkOperation {instance.pk} failed: " - f"{getattr(instance, 'error_message', 'Unknown error')}" + f"BulkOperation {instance.pk} failed: " f"{getattr(instance, 'error_message', 'Unknown error')}" ) except Exception as e: logger.warning(f"Failed to finalize bulk operation: {e}") @@ -355,9 +339,9 @@ def _broadcast_submission_status_change(instance, source, target, user): } # Add claim information if available - if hasattr(instance, 'claimed_by') and instance.claimed_by: + if hasattr(instance, "claimed_by") and instance.claimed_by: payload["locked_by"] = instance.claimed_by.username - if hasattr(instance, 'claimed_at') and instance.claimed_at: + if hasattr(instance, "claimed_at") and instance.claimed_at: payload["locked_at"] = instance.claimed_at.isoformat() # Emit the signal for downstream notification handlers @@ -371,16 +355,14 @@ def _broadcast_submission_status_change(instance, source, target, user): payload=payload, ) - logger.debug( - f"Broadcast status change: {submission_type}#{instance.pk} " - f"{source} -> {target}" - ) + logger.debug(f"Broadcast status change: {submission_type}#{instance.pk} " f"{source} -> {target}") except Exception as e: logger.warning(f"Failed to broadcast submission status change: {e}") # Signal handler registration + def register_moderation_signal_handlers(): """ Register all moderation signal handlers. @@ -399,70 +381,31 @@ def register_moderation_signal_handlers(): ) # EditSubmission handlers - register_transition_handler( - EditSubmission, '*', 'APPROVED', - handle_submission_approved, stage='post' - ) - register_transition_handler( - EditSubmission, '*', 'REJECTED', - handle_submission_rejected, stage='post' - ) - register_transition_handler( - EditSubmission, '*', 'ESCALATED', - handle_submission_escalated, stage='post' - ) + register_transition_handler(EditSubmission, "*", "APPROVED", handle_submission_approved, stage="post") + register_transition_handler(EditSubmission, "*", "REJECTED", handle_submission_rejected, stage="post") + register_transition_handler(EditSubmission, "*", "ESCALATED", handle_submission_escalated, stage="post") # PhotoSubmission handlers - register_transition_handler( - PhotoSubmission, '*', 'APPROVED', - handle_submission_approved, stage='post' - ) - register_transition_handler( - PhotoSubmission, '*', 'REJECTED', - handle_submission_rejected, stage='post' - ) - register_transition_handler( - PhotoSubmission, '*', 'ESCALATED', - handle_submission_escalated, stage='post' - ) + register_transition_handler(PhotoSubmission, "*", "APPROVED", handle_submission_approved, stage="post") + register_transition_handler(PhotoSubmission, "*", "REJECTED", handle_submission_rejected, stage="post") + register_transition_handler(PhotoSubmission, "*", "ESCALATED", handle_submission_escalated, stage="post") # ModerationReport handlers - register_transition_handler( - ModerationReport, '*', 'RESOLVED', - handle_report_resolved, stage='post' - ) + register_transition_handler(ModerationReport, "*", "RESOLVED", handle_report_resolved, stage="post") # ModerationQueue handlers - register_transition_handler( - ModerationQueue, '*', 'COMPLETED', - handle_queue_completed, stage='post' - ) + register_transition_handler(ModerationQueue, "*", "COMPLETED", handle_queue_completed, stage="post") # BulkOperation handlers - register_transition_handler( - BulkOperation, '*', '*', - handle_bulk_operation_status, stage='post' - ) + register_transition_handler(BulkOperation, "*", "*", handle_bulk_operation_status, stage="post") # Claim/Unclaim handlers for EditSubmission - register_transition_handler( - EditSubmission, 'PENDING', 'CLAIMED', - handle_submission_claimed, stage='post' - ) - register_transition_handler( - EditSubmission, 'CLAIMED', 'PENDING', - handle_submission_unclaimed, stage='post' - ) + register_transition_handler(EditSubmission, "PENDING", "CLAIMED", handle_submission_claimed, stage="post") + register_transition_handler(EditSubmission, "CLAIMED", "PENDING", handle_submission_unclaimed, stage="post") # Claim/Unclaim handlers for PhotoSubmission - register_transition_handler( - PhotoSubmission, 'PENDING', 'CLAIMED', - handle_submission_claimed, stage='post' - ) - register_transition_handler( - PhotoSubmission, 'CLAIMED', 'PENDING', - handle_submission_unclaimed, stage='post' - ) + register_transition_handler(PhotoSubmission, "PENDING", "CLAIMED", handle_submission_claimed, stage="post") + register_transition_handler(PhotoSubmission, "CLAIMED", "PENDING", handle_submission_unclaimed, stage="post") logger.info("Registered moderation signal handlers") @@ -471,14 +414,14 @@ def register_moderation_signal_handlers(): __all__ = [ - 'submission_status_changed', - 'register_moderation_signal_handlers', - 'handle_submission_approved', - 'handle_submission_rejected', - 'handle_submission_escalated', - 'handle_submission_claimed', - 'handle_submission_unclaimed', - 'handle_report_resolved', - 'handle_queue_completed', - 'handle_bulk_operation_status', + "submission_status_changed", + "register_moderation_signal_handlers", + "handle_submission_approved", + "handle_submission_rejected", + "handle_submission_escalated", + "handle_submission_claimed", + "handle_submission_unclaimed", + "handle_report_resolved", + "handle_queue_completed", + "handle_bulk_operation_status", ] diff --git a/backend/apps/moderation/sse.py b/backend/apps/moderation/sse.py index 17567a8a..63ea15f7 100644 --- a/backend/apps/moderation/sse.py +++ b/backend/apps/moderation/sse.py @@ -4,6 +4,7 @@ Server-Sent Events (SSE) endpoint for real-time moderation dashboard updates. This module provides a streaming HTTP response that broadcasts submission status changes to connected moderators in real-time. """ + import json import logging import queue @@ -103,6 +104,7 @@ class ModerationSSEView(APIView): Sends a heartbeat every 30 seconds to keep the connection alive. """ + def event_stream() -> Generator[str]: client_queue = sse_broadcaster.subscribe() @@ -124,13 +126,10 @@ class ModerationSSEView(APIView): finally: sse_broadcaster.unsubscribe(client_queue) - response = StreamingHttpResponse( - event_stream(), - content_type='text/event-stream' - ) - response['Cache-Control'] = 'no-cache' - response['X-Accel-Buffering'] = 'no' # Disable nginx buffering - response['Connection'] = 'keep-alive' + response = StreamingHttpResponse(event_stream(), content_type="text/event-stream") + response["Cache-Control"] = "no-cache" + response["X-Accel-Buffering"] = "no" # Disable nginx buffering + response["Connection"] = "keep-alive" return response @@ -168,15 +167,17 @@ class ModerationSSETestView(APIView): sse_broadcaster.broadcast(test_payload) - return JsonResponse({ - "status": "ok", - "message": f"Test event broadcast to {len(sse_broadcaster._subscribers)} clients", - "payload": test_payload, - }) + return JsonResponse( + { + "status": "ok", + "message": f"Test event broadcast to {len(sse_broadcaster._subscribers)} clients", + "payload": test_payload, + } + ) __all__ = [ - 'ModerationSSEView', - 'ModerationSSETestView', - 'sse_broadcaster', + "ModerationSSEView", + "ModerationSSETestView", + "sse_broadcaster", ] diff --git a/backend/apps/moderation/templatetags/moderation_tags.py b/backend/apps/moderation/templatetags/moderation_tags.py index f7f39017..7f0a927a 100644 --- a/backend/apps/moderation/templatetags/moderation_tags.py +++ b/backend/apps/moderation/templatetags/moderation_tags.py @@ -14,9 +14,7 @@ def get_object_name(value: int | None, model_path: str) -> str | None: app_label, model = model_path.split(".") try: - content_type = ContentType.objects.get( - app_label=app_label.lower(), model=model.lower() - ) + content_type = ContentType.objects.get(app_label=app_label.lower(), model=model.lower()) model_class = content_type.model_class() if not model_class: return None @@ -60,9 +58,7 @@ def get_park_area_name(value: int | None, park_id: int | None) -> str | None: @register.filter -def get_item( - dictionary: dict[str, Any] | None, key: str | int | None -) -> list[Any]: +def get_item(dictionary: dict[str, Any] | None, key: str | int | None) -> list[Any]: """Get item from dictionary by key.""" if not dictionary or not isinstance(dictionary, dict) or not key: return [] diff --git a/backend/apps/moderation/tests.py b/backend/apps/moderation/tests.py index d9870b03..f0b6cf1d 100644 --- a/backend/apps/moderation/tests.py +++ b/backend/apps/moderation/tests.py @@ -147,9 +147,7 @@ class ModerationMixinsTests(TestCase): view.setup(request, pk=self.operator.pk) view.kwargs = {"pk": self.operator.pk} changes = {"name": "New Name"} - response = view.handle_edit_submission( - request, changes, "Test reason", "Test source" - ) + response = view.handle_edit_submission(request, changes, "Test reason", "Test source") self.assertIsInstance(response, JsonResponse) self.assertEqual(response.status_code, 200) data = json.loads(response.content.decode()) @@ -163,9 +161,7 @@ class ModerationMixinsTests(TestCase): view.setup(request, pk=self.operator.pk) view.kwargs = {"pk": self.operator.pk} changes = {"name": "New Name"} - response = view.handle_edit_submission( - request, changes, "Test reason", "Test source" - ) + response = view.handle_edit_submission(request, changes, "Test reason", "Test source") self.assertIsInstance(response, JsonResponse) self.assertEqual(response.status_code, 200) data = json.loads(response.content.decode()) @@ -177,9 +173,7 @@ class ModerationMixinsTests(TestCase): view.kwargs = {"pk": self.operator.pk} view.object = self.operator - request = self.factory.post( - f"/test/{self.operator.pk}/", data={}, format="multipart" - ) + request = self.factory.post(f"/test/{self.operator.pk}/", data={}, format="multipart") request.user = AnonymousUser() view.setup(request, pk=self.operator.pk) response = view.handle_photo_submission(request) @@ -192,9 +186,7 @@ class ModerationMixinsTests(TestCase): view.kwargs = {"pk": self.operator.pk} view.object = self.operator - request = self.factory.post( - f"/test/{self.operator.pk}/", data={}, format="multipart" - ) + request = self.factory.post(f"/test/{self.operator.pk}/", data={}, format="multipart") request.user = self.user view.setup(request, pk=self.operator.pk) response = view.handle_photo_submission(request) @@ -384,45 +376,33 @@ class EditSubmissionTransitionTests(TestCase): def setUp(self): """Set up test fixtures.""" self.user = User.objects.create_user( - username='testuser', - email='test@example.com', - password='testpass123', - role='USER' + username="testuser", email="test@example.com", password="testpass123", role="USER" ) self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) self.admin = User.objects.create_user( - username='admin', - email='admin@example.com', - password='testpass123', - role='ADMIN' - ) - self.operator = Operator.objects.create( - name='Test Operator', - description='Test Description' + username="admin", email="admin@example.com", password="testpass123", role="ADMIN" ) + self.operator = Operator.objects.create(name="Test Operator", description="Test Description") self.content_type = ContentType.objects.get_for_model(Operator) - def _create_submission(self, status='PENDING'): + def _create_submission(self, status="PENDING"): """Helper to create an EditSubmission.""" return EditSubmission.objects.create( user=self.user, content_type=self.content_type, object_id=self.operator.id, - submission_type='EDIT', - changes={'name': 'Updated Name'}, + submission_type="EDIT", + changes={"name": "Updated Name"}, status=status, - reason='Test reason' + reason="Test reason", ) def test_pending_to_approved_transition(self): """Test transition from PENDING to APPROVED.""" submission = self._create_submission() - self.assertEqual(submission.status, 'PENDING') + self.assertEqual(submission.status, "PENDING") submission.transition_to_approved(user=self.moderator) submission.handled_by = self.moderator @@ -430,43 +410,43 @@ class EditSubmissionTransitionTests(TestCase): submission.save() submission.refresh_from_db() - self.assertEqual(submission.status, 'APPROVED') + self.assertEqual(submission.status, "APPROVED") self.assertEqual(submission.handled_by, self.moderator) self.assertIsNotNone(submission.handled_at) def test_pending_to_rejected_transition(self): """Test transition from PENDING to REJECTED.""" submission = self._create_submission() - self.assertEqual(submission.status, 'PENDING') + self.assertEqual(submission.status, "PENDING") submission.transition_to_rejected(user=self.moderator) submission.handled_by = self.moderator submission.handled_at = timezone.now() - submission.notes = 'Rejected: Insufficient evidence' + submission.notes = "Rejected: Insufficient evidence" submission.save() submission.refresh_from_db() - self.assertEqual(submission.status, 'REJECTED') + self.assertEqual(submission.status, "REJECTED") self.assertEqual(submission.handled_by, self.moderator) - self.assertIn('Rejected', submission.notes) + self.assertIn("Rejected", submission.notes) def test_pending_to_escalated_transition(self): """Test transition from PENDING to ESCALATED.""" submission = self._create_submission() - self.assertEqual(submission.status, 'PENDING') + self.assertEqual(submission.status, "PENDING") submission.transition_to_escalated(user=self.moderator) submission.handled_by = self.moderator submission.handled_at = timezone.now() - submission.notes = 'Escalated: Needs admin review' + submission.notes = "Escalated: Needs admin review" submission.save() submission.refresh_from_db() - self.assertEqual(submission.status, 'ESCALATED') + self.assertEqual(submission.status, "ESCALATED") def test_escalated_to_approved_transition(self): """Test transition from ESCALATED to APPROVED.""" - submission = self._create_submission(status='ESCALATED') + submission = self._create_submission(status="ESCALATED") submission.transition_to_approved(user=self.admin) submission.handled_by = self.admin @@ -474,25 +454,25 @@ class EditSubmissionTransitionTests(TestCase): submission.save() submission.refresh_from_db() - self.assertEqual(submission.status, 'APPROVED') + self.assertEqual(submission.status, "APPROVED") self.assertEqual(submission.handled_by, self.admin) def test_escalated_to_rejected_transition(self): """Test transition from ESCALATED to REJECTED.""" - submission = self._create_submission(status='ESCALATED') + submission = self._create_submission(status="ESCALATED") submission.transition_to_rejected(user=self.admin) submission.handled_by = self.admin submission.handled_at = timezone.now() - submission.notes = 'Rejected by admin' + submission.notes = "Rejected by admin" submission.save() submission.refresh_from_db() - self.assertEqual(submission.status, 'REJECTED') + self.assertEqual(submission.status, "REJECTED") def test_invalid_transition_from_approved(self): """Test that transitions from APPROVED state fail.""" - submission = self._create_submission(status='APPROVED') + submission = self._create_submission(status="APPROVED") # Attempting to transition from APPROVED should raise TransitionNotAllowed with self.assertRaises(TransitionNotAllowed): @@ -500,7 +480,7 @@ class EditSubmissionTransitionTests(TestCase): def test_invalid_transition_from_rejected(self): """Test that transitions from REJECTED state fail.""" - submission = self._create_submission(status='REJECTED') + submission = self._create_submission(status="REJECTED") # Attempting to transition from REJECTED should raise TransitionNotAllowed with self.assertRaises(TransitionNotAllowed): @@ -513,7 +493,7 @@ class EditSubmissionTransitionTests(TestCase): submission.approve(self.moderator) submission.refresh_from_db() - self.assertEqual(submission.status, 'APPROVED') + self.assertEqual(submission.status, "APPROVED") self.assertEqual(submission.handled_by, self.moderator) self.assertIsNotNone(submission.handled_at) @@ -521,21 +501,21 @@ class EditSubmissionTransitionTests(TestCase): """Test the reject() wrapper method.""" submission = self._create_submission() - submission.reject(self.moderator, reason='Not enough evidence') + submission.reject(self.moderator, reason="Not enough evidence") submission.refresh_from_db() - self.assertEqual(submission.status, 'REJECTED') - self.assertIn('Not enough evidence', submission.notes) + self.assertEqual(submission.status, "REJECTED") + self.assertIn("Not enough evidence", submission.notes) def test_escalate_wrapper_method(self): """Test the escalate() wrapper method.""" submission = self._create_submission() - submission.escalate(self.moderator, reason='Needs admin approval') + submission.escalate(self.moderator, reason="Needs admin approval") submission.refresh_from_db() - self.assertEqual(submission.status, 'ESCALATED') - self.assertIn('Needs admin approval', submission.notes) + self.assertEqual(submission.status, "ESCALATED") + self.assertIn("Needs admin approval", submission.notes) # ============================================================================ @@ -549,90 +529,81 @@ class ModerationReportTransitionTests(TestCase): def setUp(self): """Set up test fixtures.""" self.user = User.objects.create_user( - username='reporter', - email='reporter@example.com', - password='testpass123', - role='USER' + username="reporter", email="reporter@example.com", password="testpass123", role="USER" ) self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' - ) - self.operator = Operator.objects.create( - name='Test Operator', - description='Test Description' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) + self.operator = Operator.objects.create(name="Test Operator", description="Test Description") self.content_type = ContentType.objects.get_for_model(Operator) - def _create_report(self, status='PENDING'): + def _create_report(self, status="PENDING"): """Helper to create a ModerationReport.""" return ModerationReport.objects.create( - report_type='CONTENT', + report_type="CONTENT", status=status, - priority='MEDIUM', - reported_entity_type='company', + priority="MEDIUM", + reported_entity_type="company", reported_entity_id=self.operator.id, content_type=self.content_type, - reason='Inaccurate information', - description='The company information is incorrect', - reported_by=self.user + reason="Inaccurate information", + description="The company information is incorrect", + reported_by=self.user, ) def test_pending_to_under_review_transition(self): """Test transition from PENDING to UNDER_REVIEW.""" report = self._create_report() - self.assertEqual(report.status, 'PENDING') + self.assertEqual(report.status, "PENDING") report.transition_to_under_review(user=self.moderator) report.assigned_moderator = self.moderator report.save() report.refresh_from_db() - self.assertEqual(report.status, 'UNDER_REVIEW') + self.assertEqual(report.status, "UNDER_REVIEW") self.assertEqual(report.assigned_moderator, self.moderator) def test_under_review_to_resolved_transition(self): """Test transition from UNDER_REVIEW to RESOLVED.""" - report = self._create_report(status='UNDER_REVIEW') + report = self._create_report(status="UNDER_REVIEW") report.assigned_moderator = self.moderator report.save() report.transition_to_resolved(user=self.moderator) - report.resolution_action = 'Content updated' - report.resolution_notes = 'Fixed the incorrect information' + report.resolution_action = "Content updated" + report.resolution_notes = "Fixed the incorrect information" report.resolved_at = timezone.now() report.save() report.refresh_from_db() - self.assertEqual(report.status, 'RESOLVED') + self.assertEqual(report.status, "RESOLVED") self.assertIsNotNone(report.resolved_at) def test_under_review_to_dismissed_transition(self): """Test transition from UNDER_REVIEW to DISMISSED.""" - report = self._create_report(status='UNDER_REVIEW') + report = self._create_report(status="UNDER_REVIEW") report.assigned_moderator = self.moderator report.save() report.transition_to_dismissed(user=self.moderator) - report.resolution_notes = 'Report is not valid' + report.resolution_notes = "Report is not valid" report.resolved_at = timezone.now() report.save() report.refresh_from_db() - self.assertEqual(report.status, 'DISMISSED') + self.assertEqual(report.status, "DISMISSED") def test_invalid_transition_from_resolved(self): """Test that transitions from RESOLVED state fail.""" - report = self._create_report(status='RESOLVED') + report = self._create_report(status="RESOLVED") with self.assertRaises(TransitionNotAllowed): report.transition_to_dismissed(user=self.moderator) def test_invalid_transition_from_dismissed(self): """Test that transitions from DISMISSED state fail.""" - report = self._create_report(status='DISMISSED') + report = self._create_report(status="DISMISSED") with self.assertRaises(TransitionNotAllowed): report.transition_to_resolved(user=self.moderator) @@ -649,27 +620,24 @@ class ModerationQueueTransitionTests(TestCase): def setUp(self): """Set up test fixtures.""" self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) - def _create_queue_item(self, status='PENDING'): + def _create_queue_item(self, status="PENDING"): """Helper to create a ModerationQueue item.""" return ModerationQueue.objects.create( - item_type='EDIT_SUBMISSION', + item_type="EDIT_SUBMISSION", status=status, - priority='MEDIUM', - title='Review edit submission', - description='User submitted an edit that needs review', - flagged_by=self.moderator + priority="MEDIUM", + title="Review edit submission", + description="User submitted an edit that needs review", + flagged_by=self.moderator, ) def test_pending_to_in_progress_transition(self): """Test transition from PENDING to IN_PROGRESS.""" item = self._create_queue_item() - self.assertEqual(item.status, 'PENDING') + self.assertEqual(item.status, "PENDING") item.transition_to_in_progress(user=self.moderator) item.assigned_to = self.moderator @@ -677,12 +645,12 @@ class ModerationQueueTransitionTests(TestCase): item.save() item.refresh_from_db() - self.assertEqual(item.status, 'IN_PROGRESS') + self.assertEqual(item.status, "IN_PROGRESS") self.assertEqual(item.assigned_to, self.moderator) def test_in_progress_to_completed_transition(self): """Test transition from IN_PROGRESS to COMPLETED.""" - item = self._create_queue_item(status='IN_PROGRESS') + item = self._create_queue_item(status="IN_PROGRESS") item.assigned_to = self.moderator item.save() @@ -690,11 +658,11 @@ class ModerationQueueTransitionTests(TestCase): item.save() item.refresh_from_db() - self.assertEqual(item.status, 'COMPLETED') + self.assertEqual(item.status, "COMPLETED") def test_in_progress_to_cancelled_transition(self): """Test transition from IN_PROGRESS to CANCELLED.""" - item = self._create_queue_item(status='IN_PROGRESS') + item = self._create_queue_item(status="IN_PROGRESS") item.assigned_to = self.moderator item.save() @@ -702,7 +670,7 @@ class ModerationQueueTransitionTests(TestCase): item.save() item.refresh_from_db() - self.assertEqual(item.status, 'CANCELLED') + self.assertEqual(item.status, "CANCELLED") def test_pending_to_cancelled_transition(self): """Test transition from PENDING to CANCELLED.""" @@ -712,11 +680,11 @@ class ModerationQueueTransitionTests(TestCase): item.save() item.refresh_from_db() - self.assertEqual(item.status, 'CANCELLED') + self.assertEqual(item.status, "CANCELLED") def test_invalid_transition_from_completed(self): """Test that transitions from COMPLETED state fail.""" - item = self._create_queue_item(status='COMPLETED') + item = self._create_queue_item(status="COMPLETED") with self.assertRaises(TransitionNotAllowed): item.transition_to_in_progress(user=self.moderator) @@ -733,40 +701,37 @@ class BulkOperationTransitionTests(TestCase): def setUp(self): """Set up test fixtures.""" self.admin = User.objects.create_user( - username='admin', - email='admin@example.com', - password='testpass123', - role='ADMIN' + username="admin", email="admin@example.com", password="testpass123", role="ADMIN" ) - def _create_bulk_operation(self, status='PENDING'): + def _create_bulk_operation(self, status="PENDING"): """Helper to create a BulkOperation.""" return BulkOperation.objects.create( - operation_type='BULK_UPDATE', + operation_type="BULK_UPDATE", status=status, - priority='MEDIUM', - description='Bulk update park statuses', - parameters={'target': 'parks', 'action': 'update_status'}, + priority="MEDIUM", + description="Bulk update park statuses", + parameters={"target": "parks", "action": "update_status"}, created_by=self.admin, - total_items=100 + total_items=100, ) def test_pending_to_running_transition(self): """Test transition from PENDING to RUNNING.""" operation = self._create_bulk_operation() - self.assertEqual(operation.status, 'PENDING') + self.assertEqual(operation.status, "PENDING") operation.transition_to_running(user=self.admin) operation.started_at = timezone.now() operation.save() operation.refresh_from_db() - self.assertEqual(operation.status, 'RUNNING') + self.assertEqual(operation.status, "RUNNING") self.assertIsNotNone(operation.started_at) def test_running_to_completed_transition(self): """Test transition from RUNNING to COMPLETED.""" - operation = self._create_bulk_operation(status='RUNNING') + operation = self._create_bulk_operation(status="RUNNING") operation.started_at = timezone.now() operation.save() @@ -776,24 +741,24 @@ class BulkOperationTransitionTests(TestCase): operation.save() operation.refresh_from_db() - self.assertEqual(operation.status, 'COMPLETED') + self.assertEqual(operation.status, "COMPLETED") self.assertIsNotNone(operation.completed_at) self.assertEqual(operation.processed_items, 100) def test_running_to_failed_transition(self): """Test transition from RUNNING to FAILED.""" - operation = self._create_bulk_operation(status='RUNNING') + operation = self._create_bulk_operation(status="RUNNING") operation.started_at = timezone.now() operation.save() operation.transition_to_failed(user=self.admin) operation.completed_at = timezone.now() - operation.results = {'error': 'Database connection failed'} + operation.results = {"error": "Database connection failed"} operation.failed_items = 50 operation.save() operation.refresh_from_db() - self.assertEqual(operation.status, 'FAILED') + self.assertEqual(operation.status, "FAILED") self.assertEqual(operation.failed_items, 50) def test_pending_to_cancelled_transition(self): @@ -804,11 +769,11 @@ class BulkOperationTransitionTests(TestCase): operation.save() operation.refresh_from_db() - self.assertEqual(operation.status, 'CANCELLED') + self.assertEqual(operation.status, "CANCELLED") def test_running_to_cancelled_transition(self): """Test transition from RUNNING to CANCELLED when cancellable.""" - operation = self._create_bulk_operation(status='RUNNING') + operation = self._create_bulk_operation(status="RUNNING") operation.can_cancel = True operation.save() @@ -816,18 +781,18 @@ class BulkOperationTransitionTests(TestCase): operation.save() operation.refresh_from_db() - self.assertEqual(operation.status, 'CANCELLED') + self.assertEqual(operation.status, "CANCELLED") def test_invalid_transition_from_completed(self): """Test that transitions from COMPLETED state fail.""" - operation = self._create_bulk_operation(status='COMPLETED') + operation = self._create_bulk_operation(status="COMPLETED") with self.assertRaises(TransitionNotAllowed): operation.transition_to_running(user=self.admin) def test_invalid_transition_from_failed(self): """Test that transitions from FAILED state fail.""" - operation = self._create_bulk_operation(status='FAILED') + operation = self._create_bulk_operation(status="FAILED") with self.assertRaises(TransitionNotAllowed): operation.transition_to_completed(user=self.admin) @@ -858,21 +823,12 @@ class TransitionLoggingTestCase(TestCase): def setUp(self): """Set up test fixtures.""" self.user = User.objects.create_user( - username='testuser', - email='test@example.com', - password='testpass123', - role='USER' + username="testuser", email="test@example.com", password="testpass123", role="USER" ) self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' - ) - self.operator = Operator.objects.create( - name='Test Operator', - description='Test Description' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) + self.operator = Operator.objects.create(name="Test Operator", description="Test Description") self.content_type = ContentType.objects.get_for_model(Operator) def test_transition_creates_log(self): @@ -884,10 +840,10 @@ class TransitionLoggingTestCase(TestCase): user=self.user, content_type=self.content_type, object_id=self.operator.id, - submission_type='EDIT', - changes={'name': 'Updated Name'}, - status='PENDING', - reason='Test reason' + submission_type="EDIT", + changes={"name": "Updated Name"}, + status="PENDING", + reason="Test reason", ) # Perform transition @@ -896,15 +852,12 @@ class TransitionLoggingTestCase(TestCase): # Check log was created submission_ct = ContentType.objects.get_for_model(submission) - log = StateLog.objects.filter( - content_type=submission_ct, - object_id=submission.id - ).first() + log = StateLog.objects.filter(content_type=submission_ct, object_id=submission.id).first() self.assertIsNotNone(log, "StateLog entry should be created") - self.assertEqual(log.state, 'APPROVED') + self.assertEqual(log.state, "APPROVED") self.assertEqual(log.by, self.moderator) - self.assertIn('approved', log.transition.lower()) + self.assertIn("approved", log.transition.lower()) def test_multiple_transitions_logged(self): """Test that multiple transitions are all logged.""" @@ -914,10 +867,10 @@ class TransitionLoggingTestCase(TestCase): user=self.user, content_type=self.content_type, object_id=self.operator.id, - submission_type='EDIT', - changes={'name': 'Updated Name'}, - status='PENDING', - reason='Test reason' + submission_type="EDIT", + changes={"name": "Updated Name"}, + status="PENDING", + reason="Test reason", ) submission_ct = ContentType.objects.get_for_model(submission) @@ -931,14 +884,11 @@ class TransitionLoggingTestCase(TestCase): submission.save() # Check multiple logs created - logs = StateLog.objects.filter( - content_type=submission_ct, - object_id=submission.id - ).order_by('timestamp') + logs = StateLog.objects.filter(content_type=submission_ct, object_id=submission.id).order_by("timestamp") self.assertEqual(logs.count(), 2, "Should have 2 log entries") - self.assertEqual(logs[0].state, 'ESCALATED') - self.assertEqual(logs[1].state, 'APPROVED') + self.assertEqual(logs[0].state, "ESCALATED") + self.assertEqual(logs[1].state, "APPROVED") def test_history_endpoint_returns_logs(self): """Test history API endpoint returns transition logs.""" @@ -951,10 +901,10 @@ class TransitionLoggingTestCase(TestCase): user=self.user, content_type=self.content_type, object_id=self.operator.id, - submission_type='EDIT', - changes={'name': 'Updated Name'}, - status='PENDING', - reason='Test reason' + submission_type="EDIT", + changes={"name": "Updated Name"}, + status="PENDING", + reason="Test reason", ) # Perform transition to create log @@ -963,7 +913,7 @@ class TransitionLoggingTestCase(TestCase): # Note: This assumes EditSubmission has a history endpoint # Adjust URL pattern based on actual implementation - response = api_client.get('/api/moderation/reports/all_history/') + response = api_client.get("/api/moderation/reports/all_history/") self.assertEqual(response.status_code, 200) @@ -975,10 +925,10 @@ class TransitionLoggingTestCase(TestCase): user=self.user, content_type=self.content_type, object_id=self.operator.id, - submission_type='EDIT', - changes={'name': 'Updated Name'}, - status='PENDING', - reason='Test reason' + submission_type="EDIT", + changes={"name": "Updated Name"}, + status="PENDING", + reason="Test reason", ) # Perform transition without user @@ -987,13 +937,10 @@ class TransitionLoggingTestCase(TestCase): # Check log was created even without user submission_ct = ContentType.objects.get_for_model(submission) - log = StateLog.objects.filter( - content_type=submission_ct, - object_id=submission.id - ).first() + log = StateLog.objects.filter(content_type=submission_ct, object_id=submission.id).first() self.assertIsNotNone(log) - self.assertEqual(log.state, 'REJECTED') + self.assertEqual(log.state, "REJECTED") self.assertIsNone(log.by, "System transitions should have no user") def test_transition_log_includes_description(self): @@ -1004,10 +951,10 @@ class TransitionLoggingTestCase(TestCase): user=self.user, content_type=self.content_type, object_id=self.operator.id, - submission_type='EDIT', - changes={'name': 'Updated Name'}, - status='PENDING', - reason='Test reason' + submission_type="EDIT", + changes={"name": "Updated Name"}, + status="PENDING", + reason="Test reason", ) # Perform transition @@ -1016,14 +963,11 @@ class TransitionLoggingTestCase(TestCase): # Check log submission_ct = ContentType.objects.get_for_model(submission) - log = StateLog.objects.filter( - content_type=submission_ct, - object_id=submission.id - ).first() + log = StateLog.objects.filter(content_type=submission_ct, object_id=submission.id).first() self.assertIsNotNone(log) # Description field exists and can be used for audit trails - self.assertTrue(hasattr(log, 'description')) + self.assertTrue(hasattr(log, "description")) def test_log_ordering_by_timestamp(self): """Test that logs are properly ordered by timestamp.""" @@ -1034,10 +978,10 @@ class TransitionLoggingTestCase(TestCase): user=self.user, content_type=self.content_type, object_id=self.operator.id, - submission_type='EDIT', - changes={'name': 'Updated Name'}, - status='PENDING', - reason='Test reason' + submission_type="EDIT", + changes={"name": "Updated Name"}, + status="PENDING", + reason="Test reason", ) submission_ct = ContentType.objects.get_for_model(submission) @@ -1050,10 +994,7 @@ class TransitionLoggingTestCase(TestCase): submission.save() # Get logs ordered by timestamp - logs = list(StateLog.objects.filter( - content_type=submission_ct, - object_id=submission.id - ).order_by('timestamp')) + logs = list(StateLog.objects.filter(content_type=submission_ct, object_id=submission.id).order_by("timestamp")) # Verify ordering self.assertEqual(len(logs), 2) @@ -1071,27 +1012,21 @@ class ModerationActionTests(TestCase): def setUp(self): """Set up test fixtures.""" self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) self.target_user = User.objects.create_user( - username='target', - email='target@example.com', - password='testpass123', - role='USER' + username="target", email="target@example.com", password="testpass123", role="USER" ) def test_create_action_with_duration(self): """Test creating an action with duration sets expires_at.""" action = ModerationAction.objects.create( - action_type='TEMPORARY_BAN', - reason='Spam', - details='User was spamming the forums', + action_type="TEMPORARY_BAN", + reason="Spam", + details="User was spamming the forums", duration_hours=24, moderator=self.moderator, - target_user=self.target_user + target_user=self.target_user, ) self.assertIsNotNone(action.expires_at) @@ -1102,11 +1037,11 @@ class ModerationActionTests(TestCase): def test_create_action_without_duration(self): """Test creating an action without duration has no expires_at.""" action = ModerationAction.objects.create( - action_type='WARNING', - reason='First offense', - details='Warning issued for minor violation', + action_type="WARNING", + reason="First offense", + details="Warning issued for minor violation", moderator=self.moderator, - target_user=self.target_user + target_user=self.target_user, ) self.assertIsNone(action.expires_at) @@ -1114,11 +1049,11 @@ class ModerationActionTests(TestCase): def test_action_is_active_by_default(self): """Test that new actions are active by default.""" action = ModerationAction.objects.create( - action_type='WARNING', - reason='Test', - details='Test warning', + action_type="WARNING", + reason="Test", + details="Test warning", moderator=self.moderator, - target_user=self.target_user + target_user=self.target_user, ) self.assertTrue(action.is_active) @@ -1135,56 +1070,46 @@ class PhotoSubmissionTransitionTests(TestCase): def setUp(self): """Set up test fixtures.""" self.user = User.objects.create_user( - username='testuser', - email='test@example.com', - password='testpass123', - role='USER' + username="testuser", email="test@example.com", password="testpass123", role="USER" ) self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) self.admin = User.objects.create_user( - username='admin', - email='admin@example.com', - password='testpass123', - role='ADMIN' + username="admin", email="admin@example.com", password="testpass123", role="ADMIN" ) self.operator = Operator.objects.create( - name='Test Operator', - description='Test Description', - roles=['OPERATOR'] + name="Test Operator", description="Test Description", roles=["OPERATOR"] ) self.content_type = ContentType.objects.get_for_model(Operator) def _create_mock_photo(self): """Create a mock CloudflareImage for testing.""" from unittest.mock import Mock + mock_photo = Mock() mock_photo.pk = 1 mock_photo.id = 1 return mock_photo - def _create_submission(self, status='PENDING'): + def _create_submission(self, status="PENDING"): """Helper to create a PhotoSubmission.""" # Create using direct database creation to bypass FK validation from unittest.mock import Mock, patch - with patch.object(PhotoSubmission, 'photo', Mock()): + with patch.object(PhotoSubmission, "photo", Mock()): submission = PhotoSubmission( user=self.user, content_type=self.content_type, object_id=self.operator.id, - caption='Test Photo', + caption="Test Photo", status=status, ) # Bypass model save to avoid FK constraint on photo submission.photo_id = 1 submission.save(update_fields=None) # Force status after creation for non-PENDING states - if status != 'PENDING': + if status != "PENDING": PhotoSubmission.objects.filter(pk=submission.pk).update(status=status) submission.refresh_from_db() return submission @@ -1192,7 +1117,7 @@ class PhotoSubmissionTransitionTests(TestCase): def test_pending_to_approved_transition(self): """Test transition from PENDING to APPROVED.""" submission = self._create_submission() - self.assertEqual(submission.status, 'PENDING') + self.assertEqual(submission.status, "PENDING") submission.transition_to_approved(user=self.moderator) submission.handled_by = self.moderator @@ -1200,43 +1125,43 @@ class PhotoSubmissionTransitionTests(TestCase): submission.save() submission.refresh_from_db() - self.assertEqual(submission.status, 'APPROVED') + self.assertEqual(submission.status, "APPROVED") self.assertEqual(submission.handled_by, self.moderator) self.assertIsNotNone(submission.handled_at) def test_pending_to_rejected_transition(self): """Test transition from PENDING to REJECTED.""" submission = self._create_submission() - self.assertEqual(submission.status, 'PENDING') + self.assertEqual(submission.status, "PENDING") submission.transition_to_rejected(user=self.moderator) submission.handled_by = self.moderator submission.handled_at = timezone.now() - submission.notes = 'Rejected: Image quality too low' + submission.notes = "Rejected: Image quality too low" submission.save() submission.refresh_from_db() - self.assertEqual(submission.status, 'REJECTED') + self.assertEqual(submission.status, "REJECTED") self.assertEqual(submission.handled_by, self.moderator) - self.assertIn('Rejected', submission.notes) + self.assertIn("Rejected", submission.notes) def test_pending_to_escalated_transition(self): """Test transition from PENDING to ESCALATED.""" submission = self._create_submission() - self.assertEqual(submission.status, 'PENDING') + self.assertEqual(submission.status, "PENDING") submission.transition_to_escalated(user=self.moderator) submission.handled_by = self.moderator submission.handled_at = timezone.now() - submission.notes = 'Escalated: Copyright concerns' + submission.notes = "Escalated: Copyright concerns" submission.save() submission.refresh_from_db() - self.assertEqual(submission.status, 'ESCALATED') + self.assertEqual(submission.status, "ESCALATED") def test_escalated_to_approved_transition(self): """Test transition from ESCALATED to APPROVED.""" - submission = self._create_submission(status='ESCALATED') + submission = self._create_submission(status="ESCALATED") submission.transition_to_approved(user=self.admin) submission.handled_by = self.admin @@ -1244,32 +1169,32 @@ class PhotoSubmissionTransitionTests(TestCase): submission.save() submission.refresh_from_db() - self.assertEqual(submission.status, 'APPROVED') + self.assertEqual(submission.status, "APPROVED") self.assertEqual(submission.handled_by, self.admin) def test_escalated_to_rejected_transition(self): """Test transition from ESCALATED to REJECTED.""" - submission = self._create_submission(status='ESCALATED') + submission = self._create_submission(status="ESCALATED") submission.transition_to_rejected(user=self.admin) submission.handled_by = self.admin submission.handled_at = timezone.now() - submission.notes = 'Rejected by admin after review' + submission.notes = "Rejected by admin after review" submission.save() submission.refresh_from_db() - self.assertEqual(submission.status, 'REJECTED') + self.assertEqual(submission.status, "REJECTED") def test_invalid_transition_from_approved(self): """Test that transitions from APPROVED state fail.""" - submission = self._create_submission(status='APPROVED') + submission = self._create_submission(status="APPROVED") with self.assertRaises(TransitionNotAllowed): submission.transition_to_rejected(user=self.moderator) def test_invalid_transition_from_rejected(self): """Test that transitions from REJECTED state fail.""" - submission = self._create_submission(status='REJECTED') + submission = self._create_submission(status="REJECTED") with self.assertRaises(TransitionNotAllowed): submission.transition_to_approved(user=self.moderator) @@ -1281,12 +1206,12 @@ class PhotoSubmissionTransitionTests(TestCase): submission = self._create_submission() # Mock the photo creation part since we don't have actual photos - with patch.object(submission, 'transition_to_rejected'): - submission.reject(self.moderator, notes='Not suitable') + with patch.object(submission, "transition_to_rejected"): + submission.reject(self.moderator, notes="Not suitable") submission.refresh_from_db() - self.assertEqual(submission.status, 'REJECTED') - self.assertIn('Not suitable', submission.notes) + self.assertEqual(submission.status, "REJECTED") + self.assertIn("Not suitable", submission.notes) def test_escalate_wrapper_method(self): """Test the escalate() wrapper method.""" @@ -1294,12 +1219,12 @@ class PhotoSubmissionTransitionTests(TestCase): submission = self._create_submission() - with patch.object(submission, 'transition_to_escalated'): - submission.escalate(self.moderator, notes='Needs admin review') + with patch.object(submission, "transition_to_escalated"): + submission.escalate(self.moderator, notes="Needs admin review") submission.refresh_from_db() - self.assertEqual(submission.status, 'ESCALATED') - self.assertIn('Needs admin review', submission.notes) + self.assertEqual(submission.status, "ESCALATED") + self.assertIn("Needs admin review", submission.notes) def test_transition_creates_state_log(self): """Test that transitions create StateLog entries.""" @@ -1313,13 +1238,10 @@ class PhotoSubmissionTransitionTests(TestCase): # Check log was created submission_ct = ContentType.objects.get_for_model(submission) - log = StateLog.objects.filter( - content_type=submission_ct, - object_id=submission.id - ).first() + log = StateLog.objects.filter(content_type=submission_ct, object_id=submission.id).first() self.assertIsNotNone(log, "StateLog entry should be created") - self.assertEqual(log.state, 'APPROVED') + self.assertEqual(log.state, "APPROVED") self.assertEqual(log.by, self.moderator) def test_multiple_transitions_logged(self): @@ -1338,14 +1260,11 @@ class PhotoSubmissionTransitionTests(TestCase): submission.save() # Check multiple logs created - logs = StateLog.objects.filter( - content_type=submission_ct, - object_id=submission.id - ).order_by('timestamp') + logs = StateLog.objects.filter(content_type=submission_ct, object_id=submission.id).order_by("timestamp") self.assertEqual(logs.count(), 2, "Should have 2 log entries") - self.assertEqual(logs[0].state, 'ESCALATED') - self.assertEqual(logs[1].state, 'APPROVED') + self.assertEqual(logs[0].state, "ESCALATED") + self.assertEqual(logs[1].state, "APPROVED") def test_handled_by_and_handled_at_updated(self): """Test that handled_by and handled_at are properly updated.""" @@ -1369,7 +1288,7 @@ class PhotoSubmissionTransitionTests(TestCase): def test_notes_field_updated_on_rejection(self): """Test that notes field is updated with rejection reason.""" submission = self._create_submission() - rejection_reason = 'Image contains watermarks' + rejection_reason = "Image contains watermarks" submission.transition_to_rejected(user=self.moderator) submission.notes = rejection_reason @@ -1381,7 +1300,7 @@ class PhotoSubmissionTransitionTests(TestCase): def test_notes_field_updated_on_escalation(self): """Test that notes field is updated with escalation reason.""" submission = self._create_submission() - escalation_reason = 'Potentially copyrighted content' + escalation_reason = "Potentially copyrighted content" submission.transition_to_escalated(user=self.moderator) submission.notes = escalation_reason diff --git a/backend/apps/moderation/tests/test_admin.py b/backend/apps/moderation/tests/test_admin.py index 9d9cf1e4..944aba6b 100644 --- a/backend/apps/moderation/tests/test_admin.py +++ b/backend/apps/moderation/tests/test_admin.py @@ -43,24 +43,15 @@ class TestModerationAdminSite(TestCase): assert moderation_site.has_permission(request) is False # Regular user - request.user = type("obj", (object,), { - "is_authenticated": True, - "role": "USER" - })() + request.user = type("obj", (object,), {"is_authenticated": True, "role": "USER"})() assert moderation_site.has_permission(request) is False # Moderator - request.user = type("obj", (object,), { - "is_authenticated": True, - "role": "MODERATOR" - })() + request.user = type("obj", (object,), {"is_authenticated": True, "role": "MODERATOR"})() assert moderation_site.has_permission(request) is True # Admin - request.user = type("obj", (object,), { - "is_authenticated": True, - "role": "ADMIN" - })() + request.user = type("obj", (object,), {"is_authenticated": True, "role": "ADMIN"})() assert moderation_site.has_permission(request) is True @@ -146,6 +137,7 @@ class TestStateLogAdmin(TestCase): self.site = AdminSite() # Note: StateLog is from django_fsm_log from django_fsm_log.models import StateLog + self.admin = StateLogAdmin(model=StateLog, admin_site=self.site) def test_readonly_permissions(self): @@ -215,4 +207,5 @@ class TestRegisteredModels(TestCase): def test_state_log_registered(self): """Verify StateLog is registered with moderation site.""" from django_fsm_log.models import StateLog + assert StateLog in moderation_site._registry diff --git a/backend/apps/moderation/tests/test_workflows.py b/backend/apps/moderation/tests/test_workflows.py index 7ff807ce..5c98da11 100644 --- a/backend/apps/moderation/tests/test_workflows.py +++ b/backend/apps/moderation/tests/test_workflows.py @@ -9,7 +9,6 @@ This module tests end-to-end moderation workflows including: - Bulk operation workflow """ - from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.test import TestCase @@ -25,22 +24,13 @@ class SubmissionApprovalWorkflowTests(TestCase): def setUpTestData(cls): """Set up test data for all tests.""" cls.regular_user = User.objects.create_user( - username='regular_user', - email='user@example.com', - password='testpass123', - role='USER' + username="regular_user", email="user@example.com", password="testpass123", role="USER" ) cls.moderator = User.objects.create_user( - username='moderator', - email='mod@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="mod@example.com", password="testpass123", role="MODERATOR" ) cls.admin = User.objects.create_user( - username='admin', - email='admin@example.com', - password='testpass123', - role='ADMIN' + username="admin", email="admin@example.com", password="testpass123", role="ADMIN" ) def test_edit_submission_approval_workflow(self): @@ -53,10 +43,7 @@ class SubmissionApprovalWorkflowTests(TestCase): from apps.parks.models import Company # Create target object - company = Company.objects.create( - name='Test Company', - description='Original description' - ) + company = Company.objects.create(name="Test Company", description="Original description") # User submits an edit content_type = ContentType.objects.get_for_model(company) @@ -64,13 +51,13 @@ class SubmissionApprovalWorkflowTests(TestCase): user=self.regular_user, content_type=content_type, object_id=company.id, - submission_type='EDIT', - changes={'description': 'Updated description'}, - status='PENDING', - reason='Fixing typo' + submission_type="EDIT", + changes={"description": "Updated description"}, + status="PENDING", + reason="Fixing typo", ) - self.assertEqual(submission.status, 'PENDING') + self.assertEqual(submission.status, "PENDING") self.assertIsNone(submission.handled_by) self.assertIsNone(submission.handled_at) @@ -81,7 +68,7 @@ class SubmissionApprovalWorkflowTests(TestCase): submission.save() submission.refresh_from_db() - self.assertEqual(submission.status, 'APPROVED') + self.assertEqual(submission.status, "APPROVED") self.assertEqual(submission.handled_by, self.moderator) self.assertIsNotNone(submission.handled_at) @@ -95,16 +82,9 @@ class SubmissionApprovalWorkflowTests(TestCase): from apps.parks.models import Company, Park # Create target park - operator = Company.objects.create( - name='Test Operator', - roles=['OPERATOR'] - ) + operator = Company.objects.create(name="Test Operator", roles=["OPERATOR"]) park = Park.objects.create( - name='Test Park', - slug='test-park', - operator=operator, - status='OPERATING', - timezone='America/New_York' + name="Test Park", slug="test-park", operator=operator, status="OPERATING", timezone="America/New_York" ) # User submits a photo @@ -113,12 +93,12 @@ class SubmissionApprovalWorkflowTests(TestCase): user=self.regular_user, content_type=content_type, object_id=park.id, - status='PENDING', - photo_type='GENERAL', - description='Beautiful park entrance' + status="PENDING", + photo_type="GENERAL", + description="Beautiful park entrance", ) - self.assertEqual(submission.status, 'PENDING') + self.assertEqual(submission.status, "PENDING") # Moderator approves submission.transition_to_approved(user=self.moderator) @@ -127,7 +107,7 @@ class SubmissionApprovalWorkflowTests(TestCase): submission.save() submission.refresh_from_db() - self.assertEqual(submission.status, 'APPROVED') + self.assertEqual(submission.status, "APPROVED") class SubmissionRejectionWorkflowTests(TestCase): @@ -136,16 +116,10 @@ class SubmissionRejectionWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.regular_user = User.objects.create_user( - username='user_rej', - email='user_rej@example.com', - password='testpass123', - role='USER' + username="user_rej", email="user_rej@example.com", password="testpass123", role="USER" ) cls.moderator = User.objects.create_user( - username='mod_rej', - email='mod_rej@example.com', - password='testpass123', - role='MODERATOR' + username="mod_rej", email="mod_rej@example.com", password="testpass123", role="MODERATOR" ) def test_edit_submission_rejection_with_reason(self): @@ -157,32 +131,29 @@ class SubmissionRejectionWorkflowTests(TestCase): from apps.moderation.models import EditSubmission from apps.parks.models import Company - company = Company.objects.create( - name='Test Company', - description='Original' - ) + company = Company.objects.create(name="Test Company", description="Original") content_type = ContentType.objects.get_for_model(company) submission = EditSubmission.objects.create( user=self.regular_user, content_type=content_type, object_id=company.id, - submission_type='EDIT', - changes={'name': 'Spam Content'}, - status='PENDING', - reason='Name change request' + submission_type="EDIT", + changes={"name": "Spam Content"}, + status="PENDING", + reason="Name change request", ) # Moderator rejects submission.transition_to_rejected(user=self.moderator) submission.handled_by = self.moderator submission.handled_at = timezone.now() - submission.notes = 'Rejected: Content appears to be spam' + submission.notes = "Rejected: Content appears to be spam" submission.save() submission.refresh_from_db() - self.assertEqual(submission.status, 'REJECTED') - self.assertIn('spam', submission.notes.lower()) + self.assertEqual(submission.status, "REJECTED") + self.assertIn("spam", submission.notes.lower()) class SubmissionEscalationWorkflowTests(TestCase): @@ -191,22 +162,13 @@ class SubmissionEscalationWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.regular_user = User.objects.create_user( - username='user_esc', - email='user_esc@example.com', - password='testpass123', - role='USER' + username="user_esc", email="user_esc@example.com", password="testpass123", role="USER" ) cls.moderator = User.objects.create_user( - username='mod_esc', - email='mod_esc@example.com', - password='testpass123', - role='MODERATOR' + username="mod_esc", email="mod_esc@example.com", password="testpass123", role="MODERATOR" ) cls.admin = User.objects.create_user( - username='admin_esc', - email='admin_esc@example.com', - password='testpass123', - role='ADMIN' + username="admin_esc", email="admin_esc@example.com", password="testpass123", role="ADMIN" ) def test_escalation_workflow(self): @@ -218,28 +180,25 @@ class SubmissionEscalationWorkflowTests(TestCase): from apps.moderation.models import EditSubmission from apps.parks.models import Company - company = Company.objects.create( - name='Sensitive Company', - description='Original' - ) + company = Company.objects.create(name="Sensitive Company", description="Original") content_type = ContentType.objects.get_for_model(company) submission = EditSubmission.objects.create( user=self.regular_user, content_type=content_type, object_id=company.id, - submission_type='EDIT', - changes={'name': 'New Sensitive Name'}, - status='PENDING', - reason='Major name change' + submission_type="EDIT", + changes={"name": "New Sensitive Name"}, + status="PENDING", + reason="Major name change", ) # Moderator escalates submission.transition_to_escalated(user=self.moderator) - submission.notes = 'Escalated: Major change needs admin review' + submission.notes = "Escalated: Major change needs admin review" submission.save() - self.assertEqual(submission.status, 'ESCALATED') + self.assertEqual(submission.status, "ESCALATED") # Admin approves submission.transition_to_approved(user=self.admin) @@ -248,7 +207,7 @@ class SubmissionEscalationWorkflowTests(TestCase): submission.save() submission.refresh_from_db() - self.assertEqual(submission.status, 'APPROVED') + self.assertEqual(submission.status, "APPROVED") self.assertEqual(submission.handled_by, self.admin) @@ -258,16 +217,10 @@ class ReportHandlingWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.reporter = User.objects.create_user( - username='reporter', - email='reporter@example.com', - password='testpass123', - role='USER' + username="reporter", email="reporter@example.com", password="testpass123", role="USER" ) cls.moderator = User.objects.create_user( - username='mod_report', - email='mod_report@example.com', - password='testpass123', - role='MODERATOR' + username="mod_report", email="mod_report@example.com", password="testpass123", role="MODERATOR" ) def test_report_resolution_workflow(self): @@ -279,45 +232,42 @@ class ReportHandlingWorkflowTests(TestCase): from apps.moderation.models import ModerationReport from apps.parks.models import Company - reported_company = Company.objects.create( - name='Problematic Company', - description='Some inappropriate content' - ) + reported_company = Company.objects.create(name="Problematic Company", description="Some inappropriate content") content_type = ContentType.objects.get_for_model(reported_company) # User reports content report = ModerationReport.objects.create( - report_type='CONTENT', - status='PENDING', - priority='HIGH', - reported_entity_type='company', + report_type="CONTENT", + status="PENDING", + priority="HIGH", + reported_entity_type="company", reported_entity_id=reported_company.id, content_type=content_type, - reason='INAPPROPRIATE', - description='This content is inappropriate', - reported_by=self.reporter + reason="INAPPROPRIATE", + description="This content is inappropriate", + reported_by=self.reporter, ) - self.assertEqual(report.status, 'PENDING') + self.assertEqual(report.status, "PENDING") # Moderator claims and starts review report.transition_to_under_review(user=self.moderator) report.assigned_moderator = self.moderator report.save() - self.assertEqual(report.status, 'UNDER_REVIEW') + self.assertEqual(report.status, "UNDER_REVIEW") self.assertEqual(report.assigned_moderator, self.moderator) # Moderator resolves report.transition_to_resolved(user=self.moderator) - report.resolution_action = 'CONTENT_REMOVED' - report.resolution_notes = 'Content was removed' + report.resolution_action = "CONTENT_REMOVED" + report.resolution_notes = "Content was removed" report.resolved_at = timezone.now() report.save() report.refresh_from_db() - self.assertEqual(report.status, 'RESOLVED') + self.assertEqual(report.status, "RESOLVED") self.assertIsNotNone(report.resolved_at) def test_report_dismissal_workflow(self): @@ -329,23 +279,20 @@ class ReportHandlingWorkflowTests(TestCase): from apps.moderation.models import ModerationReport from apps.parks.models import Company - company = Company.objects.create( - name='Valid Company', - description='Normal content' - ) + company = Company.objects.create(name="Valid Company", description="Normal content") content_type = ContentType.objects.get_for_model(company) report = ModerationReport.objects.create( - report_type='CONTENT', - status='PENDING', - priority='LOW', - reported_entity_type='company', + report_type="CONTENT", + status="PENDING", + priority="LOW", + reported_entity_type="company", reported_entity_id=company.id, content_type=content_type, - reason='OTHER', - description='I just do not like this', - reported_by=self.reporter + reason="OTHER", + description="I just do not like this", + reported_by=self.reporter, ) # Moderator claims @@ -355,12 +302,12 @@ class ReportHandlingWorkflowTests(TestCase): # Moderator dismisses as invalid report.transition_to_dismissed(user=self.moderator) - report.resolution_notes = 'Report does not violate any guidelines' + report.resolution_notes = "Report does not violate any guidelines" report.resolved_at = timezone.now() report.save() report.refresh_from_db() - self.assertEqual(report.status, 'DISMISSED') + self.assertEqual(report.status, "DISMISSED") class BulkOperationWorkflowTests(TestCase): @@ -369,10 +316,7 @@ class BulkOperationWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.admin = User.objects.create_user( - username='admin_bulk', - email='admin_bulk@example.com', - password='testpass123', - role='ADMIN' + username="admin_bulk", email="admin_bulk@example.com", password="testpass123", role="ADMIN" ) def test_bulk_operation_success_workflow(self): @@ -384,22 +328,22 @@ class BulkOperationWorkflowTests(TestCase): from apps.moderation.models import BulkOperation operation = BulkOperation.objects.create( - operation_type='APPROVE_SUBMISSIONS', - status='PENDING', + operation_type="APPROVE_SUBMISSIONS", + status="PENDING", total_items=10, processed_items=0, created_by=self.admin, - parameters={'submission_ids': list(range(1, 11))} + parameters={"submission_ids": list(range(1, 11))}, ) - self.assertEqual(operation.status, 'PENDING') + self.assertEqual(operation.status, "PENDING") # Start operation operation.transition_to_running(user=self.admin) operation.started_at = timezone.now() operation.save() - self.assertEqual(operation.status, 'RUNNING') + self.assertEqual(operation.status, "RUNNING") # Simulate progress for i in range(1, 11): @@ -409,11 +353,11 @@ class BulkOperationWorkflowTests(TestCase): # Complete operation operation.transition_to_completed(user=self.admin) operation.completed_at = timezone.now() - operation.results = {'approved': 10, 'failed': 0} + operation.results = {"approved": 10, "failed": 0} operation.save() operation.refresh_from_db() - self.assertEqual(operation.status, 'COMPLETED') + self.assertEqual(operation.status, "COMPLETED") self.assertEqual(operation.processed_items, 10) def test_bulk_operation_failure_workflow(self): @@ -425,12 +369,12 @@ class BulkOperationWorkflowTests(TestCase): from apps.moderation.models import BulkOperation operation = BulkOperation.objects.create( - operation_type='DELETE_CONTENT', - status='PENDING', + operation_type="DELETE_CONTENT", + status="PENDING", total_items=5, processed_items=0, created_by=self.admin, - parameters={'content_ids': list(range(1, 6))} + parameters={"content_ids": list(range(1, 6))}, ) operation.transition_to_running(user=self.admin) @@ -442,11 +386,11 @@ class BulkOperationWorkflowTests(TestCase): operation.failed_items = 3 operation.transition_to_failed(user=self.admin) operation.completed_at = timezone.now() - operation.results = {'error': 'Database connection lost', 'processed': 2} + operation.results = {"error": "Database connection lost", "processed": 2} operation.save() operation.refresh_from_db() - self.assertEqual(operation.status, 'FAILED') + self.assertEqual(operation.status, "FAILED") self.assertEqual(operation.failed_items, 3) def test_bulk_operation_cancellation_workflow(self): @@ -458,13 +402,13 @@ class BulkOperationWorkflowTests(TestCase): from apps.moderation.models import BulkOperation operation = BulkOperation.objects.create( - operation_type='BATCH_UPDATE', - status='PENDING', + operation_type="BATCH_UPDATE", + status="PENDING", total_items=100, processed_items=0, created_by=self.admin, - parameters={'update_field': 'status'}, - can_cancel=True + parameters={"update_field": "status"}, + can_cancel=True, ) operation.transition_to_running(user=self.admin) @@ -477,11 +421,11 @@ class BulkOperationWorkflowTests(TestCase): # Admin cancels operation.transition_to_cancelled(user=self.admin) operation.completed_at = timezone.now() - operation.results = {'cancelled_at': 30, 'reason': 'User requested cancellation'} + operation.results = {"cancelled_at": 30, "reason": "User requested cancellation"} operation.save() operation.refresh_from_db() - self.assertEqual(operation.status, 'CANCELLED') + self.assertEqual(operation.status, "CANCELLED") self.assertEqual(operation.processed_items, 30) @@ -491,10 +435,7 @@ class ModerationQueueWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.moderator = User.objects.create_user( - username='mod_queue', - email='mod_queue@example.com', - password='testpass123', - role='MODERATOR' + username="mod_queue", email="mod_queue@example.com", password="testpass123", role="MODERATOR" ) def test_queue_completion_workflow(self): @@ -506,14 +447,14 @@ class ModerationQueueWorkflowTests(TestCase): from apps.moderation.models import ModerationQueue queue_item = ModerationQueue.objects.create( - queue_type='SUBMISSION_REVIEW', - status='PENDING', - priority='MEDIUM', - item_type='edit_submission', - item_id=123 + queue_type="SUBMISSION_REVIEW", + status="PENDING", + priority="MEDIUM", + item_type="edit_submission", + item_id=123, ) - self.assertEqual(queue_item.status, 'PENDING') + self.assertEqual(queue_item.status, "PENDING") # Moderator claims queue_item.transition_to_in_progress(user=self.moderator) @@ -521,7 +462,7 @@ class ModerationQueueWorkflowTests(TestCase): queue_item.assigned_at = timezone.now() queue_item.save() - self.assertEqual(queue_item.status, 'IN_PROGRESS') + self.assertEqual(queue_item.status, "IN_PROGRESS") # Work completed queue_item.transition_to_completed(user=self.moderator) @@ -529,4 +470,4 @@ class ModerationQueueWorkflowTests(TestCase): queue_item.save() queue_item.refresh_from_db() - self.assertEqual(queue_item.status, 'COMPLETED') + self.assertEqual(queue_item.status, "COMPLETED") diff --git a/backend/apps/moderation/urls.py b/backend/apps/moderation/urls.py index 548fa6b8..22038517 100644 --- a/backend/apps/moderation/urls.py +++ b/backend/apps/moderation/urls.py @@ -26,6 +26,7 @@ from .views import ( class ModerationDashboardView(TemplateView): """Moderation dashboard view with HTMX integration.""" + template_name = "moderation/dashboard.html" def get_context_data(self, **kwargs): @@ -38,6 +39,7 @@ class ModerationDashboardView(TemplateView): class SubmissionListView(TemplateView): """Submission list view with filtering.""" + template_name = "moderation/partials/dashboard_content.html" def get_context_data(self, **kwargs): @@ -63,8 +65,10 @@ class SubmissionListView(TemplateView): class HistoryPageView(TemplateView): """Main history page view.""" + template_name = "moderation/history.html" + # Create router and register viewsets router = DefaultRouter() router.register(r"reports", ModerationReportViewSet, basename="moderation-reports") diff --git a/backend/apps/moderation/views.py b/backend/apps/moderation/views.py index 8a7b6c87..b15c5493 100644 --- a/backend/apps/moderation/views.py +++ b/backend/apps/moderation/views.py @@ -86,9 +86,7 @@ class ModerationReportViewSet(viewsets.ModelViewSet): filtering, search, and permission controls. """ - queryset = ModerationReport.objects.select_related( - "reported_by", "assigned_moderator", "content_type" - ).all() + queryset = ModerationReport.objects.select_related("reported_by", "assigned_moderator", "content_type").all() filter_backends = [DjangoFilterBackend, SearchFilter, OrderingFilter] filterset_class = ModerationReportFilter @@ -207,9 +205,7 @@ class ModerationReportViewSet(viewsets.ModelViewSet): return Response(serializer.data) except User.DoesNotExist: - return Response( - {"error": "Moderator not found"}, status=status.HTTP_404_NOT_FOUND - ) + return Response({"error": "Moderator not found"}, status=status.HTTP_404_NOT_FOUND) @action(detail=True, methods=["post"], permission_classes=[IsModeratorOrAdmin]) def resolve(self, request, pk=None): @@ -313,17 +309,11 @@ class ModerationReportViewSet(viewsets.ModelViewSet): overdue_reports += 1 # Reports by priority and type - reports_by_priority = dict( - queryset.values_list("priority").annotate(count=Count("id")) - ) - reports_by_type = dict( - queryset.values_list("report_type").annotate(count=Count("id")) - ) + reports_by_priority = dict(queryset.values_list("priority").annotate(count=Count("id"))) + reports_by_type = dict(queryset.values_list("report_type").annotate(count=Count("id"))) # Average resolution time - resolved_queryset = queryset.filter( - status="RESOLVED", resolved_at__isnull=False - ) + resolved_queryset = queryset.filter(status="RESOLVED", resolved_at__isnull=False) avg_resolution_time = 0 if resolved_queryset.exists(): @@ -430,9 +420,7 @@ class ModerationReportViewSet(viewsets.ModelViewSet): "log": None, }, ) - return Response( - {"error": "Log not found"}, status=status.HTTP_404_NOT_FOUND - ) + return Response({"error": "Log not found"}, status=status.HTTP_404_NOT_FOUND) # Filter by model type with app_label support for correct ContentType resolution model_type = request.query_params.get("model_type") @@ -441,9 +429,7 @@ class ModerationReportViewSet(viewsets.ModelViewSet): try: if app_label: # Use both app_label and model for precise matching - content_type = ContentType.objects.get_by_natural_key( - app_label, model_type - ) + content_type = ContentType.objects.get_by_natural_key(app_label, model_type) else: # Map common model names to their app_labels for correct resolution model_app_mapping = { @@ -457,9 +443,7 @@ class ModerationReportViewSet(viewsets.ModelViewSet): } mapped_app_label = model_app_mapping.get(model_type.lower()) if mapped_app_label: - content_type = ContentType.objects.get_by_natural_key( - mapped_app_label, model_type.lower() - ) + content_type = ContentType.objects.get_by_natural_key(mapped_app_label, model_type.lower()) else: # Fallback to model-only lookup content_type = ContentType.objects.get(model=model_type) @@ -576,9 +560,7 @@ class ModerationQueueViewSet(viewsets.ModelViewSet): completion, and progress tracking. """ - queryset = ModerationQueue.objects.select_related( - "assigned_to", "related_report", "content_type" - ).all() + queryset = ModerationQueue.objects.select_related("assigned_to", "related_report", "content_type").all() serializer_class = ModerationQueueSerializer permission_classes = [CanViewModerationData] @@ -871,9 +853,7 @@ class ModerationActionViewSet(viewsets.ModelViewSet): and status management. """ - queryset = ModerationAction.objects.select_related( - "moderator", "target_user", "related_report" - ).all() + queryset = ModerationAction.objects.select_related("moderator", "target_user", "related_report").all() filter_backends = [DjangoFilterBackend, SearchFilter, OrderingFilter] filterset_class = ModerationActionFilter @@ -907,9 +887,7 @@ class ModerationActionViewSet(viewsets.ModelViewSet): @action(detail=False, methods=["get"], permission_classes=[CanViewModerationData]) def active(self, request): """Get all active moderation actions.""" - queryset = self.get_queryset().filter( - is_active=True, expires_at__gt=timezone.now() - ) + queryset = self.get_queryset().filter(is_active=True, expires_at__gt=timezone.now()) page = self.paginate_queryset(queryset) if page is not None: @@ -922,9 +900,7 @@ class ModerationActionViewSet(viewsets.ModelViewSet): @action(detail=False, methods=["get"], permission_classes=[CanViewModerationData]) def expired(self, request): """Get all expired moderation actions.""" - queryset = self.get_queryset().filter( - expires_at__lte=timezone.now(), is_active=True - ) + queryset = self.get_queryset().filter(expires_at__lte=timezone.now(), is_active=True) page = self.paginate_queryset(queryset) if page is not None: @@ -1173,9 +1149,7 @@ class UserModerationViewSet(viewsets.ViewSet): if not query: return Response([]) - queryset = User.objects.filter( - Q(username__icontains=query) | Q(email__icontains=query) - )[:20] + queryset = User.objects.filter(Q(username__icontains=query) | Q(email__icontains=query))[:20] users_data = [ { @@ -1194,9 +1168,7 @@ class UserModerationViewSet(viewsets.ViewSet): try: user = User.objects.get(pk=pk) except User.DoesNotExist: - return Response( - {"error": "User not found"}, status=status.HTTP_404_NOT_FOUND - ) + return Response({"error": "User not found"}, status=status.HTTP_404_NOT_FOUND) # Gather user moderation data reports_made = ModerationReport.objects.filter(reported_by=user).count() @@ -1206,12 +1178,8 @@ class UserModerationViewSet(viewsets.ViewSet): actions_against = ModerationAction.objects.filter(target_user=user) warnings_received = actions_against.filter(action_type="WARNING").count() - suspensions_received = actions_against.filter( - action_type="USER_SUSPENSION" - ).count() - active_restrictions = actions_against.filter( - is_active=True, expires_at__gt=timezone.now() - ).count() + suspensions_received = actions_against.filter(action_type="USER_SUSPENSION").count() + active_restrictions = actions_against.filter(is_active=True, expires_at__gt=timezone.now()).count() # Risk assessment (simplified) risk_factors = [] @@ -1230,9 +1198,7 @@ class UserModerationViewSet(viewsets.ViewSet): risk_level = "HIGH" # Recent activity - recent_reports = ModerationReport.objects.filter(reported_by=user).order_by( - "-created_at" - )[:5] + recent_reports = ModerationReport.objects.filter(reported_by=user).order_by("-created_at")[:5] recent_actions = actions_against.order_by("-created_at")[:5] @@ -1244,9 +1210,7 @@ class UserModerationViewSet(viewsets.ViewSet): account_status = "RESTRICTED" last_violation = ( - actions_against.filter( - action_type__in=["WARNING", "USER_SUSPENSION", "USER_BAN"] - ) + actions_against.filter(action_type__in=["WARNING", "USER_SUSPENSION", "USER_BAN"]) .order_by("-created_at") .first() ) @@ -1266,16 +1230,10 @@ class UserModerationViewSet(viewsets.ViewSet): "active_restrictions": active_restrictions, "risk_level": risk_level, "risk_factors": risk_factors, - "recent_reports": ModerationReportSerializer( - recent_reports, many=True - ).data, - "recent_actions": ModerationActionSerializer( - recent_actions, many=True - ).data, + "recent_reports": ModerationReportSerializer(recent_reports, many=True).data, + "recent_actions": ModerationActionSerializer(recent_actions, many=True).data, "account_status": account_status, - "last_violation_date": ( - last_violation.created_at if last_violation else None - ), + "last_violation_date": (last_violation.created_at if last_violation else None), "next_review_date": None, # Would be calculated based on business rules } @@ -1287,13 +1245,9 @@ class UserModerationViewSet(viewsets.ViewSet): try: user = User.objects.get(pk=pk) except User.DoesNotExist: - return Response( - {"error": "User not found"}, status=status.HTTP_404_NOT_FOUND - ) + return Response({"error": "User not found"}, status=status.HTTP_404_NOT_FOUND) - serializer = CreateModerationActionSerializer( - data=request.data, context={"request": request} - ) + serializer = CreateModerationActionSerializer(data=request.data, context={"request": request}) if serializer.is_valid(): # Override target_user_id with the user from URL @@ -1331,9 +1285,7 @@ class UserModerationViewSet(viewsets.ViewSet): queryset = User.objects.all() if query: - queryset = queryset.filter( - Q(username__icontains=query) | Q(email__icontains=query) - ) + queryset = queryset.filter(Q(username__icontains=query) | Q(email__icontains=query)) if role: queryset = queryset.filter(role=role) @@ -1376,12 +1328,8 @@ class UserModerationViewSet(viewsets.ViewSet): def stats(self, request): """Get overall user moderation statistics.""" total_actions = ModerationAction.objects.count() - active_actions = ModerationAction.objects.filter( - is_active=True, expires_at__gt=timezone.now() - ).count() - expired_actions = ModerationAction.objects.filter( - expires_at__lte=timezone.now() - ).count() + active_actions = ModerationAction.objects.filter(is_active=True, expires_at__gt=timezone.now()).count() + expired_actions = ModerationAction.objects.filter(expires_at__lte=timezone.now()).count() stats_data = { "total_actions": total_actions, @@ -1404,6 +1352,7 @@ class EditSubmissionViewSet(viewsets.ModelViewSet): Includes claim/unclaim endpoints with concurrency protection using database row locking (select_for_update) to prevent race conditions. """ + queryset = EditSubmission.objects.all() filter_backends = [DjangoFilterBackend, SearchFilter, OrderingFilter] search_fields = ["reason", "changes"] @@ -1425,7 +1374,7 @@ class EditSubmissionViewSet(viewsets.ModelViewSet): # User filter user_id = self.request.query_params.get("user") if user_id: - queryset = queryset.filter(user_id=user_id) + queryset = queryset.filter(user_id=user_id) return queryset @@ -1452,15 +1401,12 @@ class EditSubmissionViewSet(viewsets.ModelViewSet): # Lock the row for update - other transactions will fail immediately submission = EditSubmission.objects.select_for_update(nowait=True).get(pk=pk) except EditSubmission.DoesNotExist: - return Response( - {"error": "Submission not found"}, - status=status.HTTP_404_NOT_FOUND - ) + return Response({"error": "Submission not found"}, status=status.HTTP_404_NOT_FOUND) except DatabaseError: # Row is already locked by another transaction return Response( {"error": "Submission is being claimed by another moderator. Please try again."}, - status=status.HTTP_409_CONFLICT + status=status.HTTP_409_CONFLICT, ) # Check if already claimed @@ -1471,14 +1417,14 @@ class EditSubmissionViewSet(viewsets.ModelViewSet): "claimed_by": submission.claimed_by.username if submission.claimed_by else None, "claimed_at": submission.claimed_at.isoformat() if submission.claimed_at else None, }, - status=status.HTTP_409_CONFLICT + status=status.HTTP_409_CONFLICT, ) # Check if in valid state for claiming if submission.status != "PENDING": return Response( {"error": f"Cannot claim submission in {submission.status} state"}, - status=status.HTTP_400_BAD_REQUEST + status=status.HTTP_400_BAD_REQUEST, ) try: @@ -1512,15 +1458,11 @@ class EditSubmissionViewSet(viewsets.ModelViewSet): # Only the claiming user or an admin can unclaim if submission.claimed_by != request.user and not request.user.is_staff: return Response( - {"error": "Only the claiming moderator or an admin can unclaim"}, - status=status.HTTP_403_FORBIDDEN + {"error": "Only the claiming moderator or an admin can unclaim"}, status=status.HTTP_403_FORBIDDEN ) if submission.status != "CLAIMED": - return Response( - {"error": "Submission is not claimed"}, - status=status.HTTP_400_BAD_REQUEST - ) + return Response({"error": "Submission is not claimed"}, status=status.HTTP_400_BAD_REQUEST) try: submission.unclaim(user=request.user) @@ -1557,8 +1499,8 @@ class EditSubmissionViewSet(viewsets.ModelViewSet): reason = request.data.get("reason", "") try: - submission.reject(moderator=user, reason=reason) - return Response(self.get_serializer(submission).data) + submission.reject(moderator=user, reason=reason) + return Response(self.get_serializer(submission).data) except Exception as e: return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) @@ -1569,8 +1511,8 @@ class EditSubmissionViewSet(viewsets.ModelViewSet): reason = request.data.get("reason", "") try: - submission.escalate(moderator=user, reason=reason) - return Response(self.get_serializer(submission).data) + submission.escalate(moderator=user, reason=reason) + return Response(self.get_serializer(submission).data) except Exception as e: return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) @@ -1582,6 +1524,7 @@ class PhotoSubmissionViewSet(viewsets.ModelViewSet): Includes claim/unclaim endpoints with concurrency protection using database row locking (select_for_update) to prevent race conditions. """ + queryset = PhotoSubmission.objects.all() serializer_class = PhotoSubmissionSerializer filter_backends = [DjangoFilterBackend, SearchFilter, OrderingFilter] @@ -1599,7 +1542,7 @@ class PhotoSubmissionViewSet(viewsets.ModelViewSet): # User filter user_id = self.request.query_params.get("user") if user_id: - queryset = queryset.filter(user_id=user_id) + queryset = queryset.filter(user_id=user_id) return queryset @@ -1617,14 +1560,11 @@ class PhotoSubmissionViewSet(viewsets.ModelViewSet): try: submission = PhotoSubmission.objects.select_for_update(nowait=True).get(pk=pk) except PhotoSubmission.DoesNotExist: - return Response( - {"error": "Submission not found"}, - status=status.HTTP_404_NOT_FOUND - ) + return Response({"error": "Submission not found"}, status=status.HTTP_404_NOT_FOUND) except DatabaseError: return Response( {"error": "Submission is being claimed by another moderator. Please try again."}, - status=status.HTTP_409_CONFLICT + status=status.HTTP_409_CONFLICT, ) if submission.status == "CLAIMED": @@ -1634,13 +1574,13 @@ class PhotoSubmissionViewSet(viewsets.ModelViewSet): "claimed_by": submission.claimed_by.username if submission.claimed_by else None, "claimed_at": submission.claimed_at.isoformat() if submission.claimed_at else None, }, - status=status.HTTP_409_CONFLICT + status=status.HTTP_409_CONFLICT, ) if submission.status != "PENDING": return Response( {"error": f"Cannot claim submission in {submission.status} state"}, - status=status.HTTP_400_BAD_REQUEST + status=status.HTTP_400_BAD_REQUEST, ) try: @@ -1669,15 +1609,11 @@ class PhotoSubmissionViewSet(viewsets.ModelViewSet): if submission.claimed_by != request.user and not request.user.is_staff: return Response( - {"error": "Only the claiming moderator or an admin can unclaim"}, - status=status.HTTP_403_FORBIDDEN + {"error": "Only the claiming moderator or an admin can unclaim"}, status=status.HTTP_403_FORBIDDEN ) if submission.status != "CLAIMED": - return Response( - {"error": "Submission is not claimed"}, - status=status.HTTP_400_BAD_REQUEST - ) + return Response({"error": "Submission is not claimed"}, status=status.HTTP_400_BAD_REQUEST) try: submission.unclaim(user=request.user) @@ -1731,4 +1667,3 @@ class PhotoSubmissionViewSet(viewsets.ModelViewSet): return Response(self.get_serializer(submission).data) except Exception as e: return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) - diff --git a/backend/apps/parks/admin.py b/backend/apps/parks/admin.py index 9921c85a..021b1c95 100644 --- a/backend/apps/parks/admin.py +++ b/backend/apps/parks/admin.py @@ -156,7 +156,6 @@ class ParkLocationAdmin(QueryOptimizationMixin, GISModelAdmin): "description": "OpenStreetMap identifiers for data synchronization.", }, ), - ) @admin.display(description="Park") @@ -358,9 +357,7 @@ class ParkAdmin( for park in queryset: # Statistics are auto-calculated, so just touch the record park.save(update_fields=["updated_at"]) - self.message_user( - request, f"Successfully recalculated statistics for {queryset.count()} parks." - ) + self.message_user(request, f"Successfully recalculated statistics for {queryset.count()} parks.") def get_actions(self, request): """Add custom actions to the admin.""" @@ -482,9 +479,7 @@ class CompanyHeadquartersInline(admin.StackedInline): ) -class CompanyHeadquartersAdmin( - QueryOptimizationMixin, TimestampFieldsMixin, BaseModelAdmin -): +class CompanyHeadquartersAdmin(QueryOptimizationMixin, TimestampFieldsMixin, BaseModelAdmin): """ Admin interface for standalone CompanyHeadquarters management. @@ -661,7 +656,7 @@ class CompanyAdmin( color = colors.get(role, "#6c757d") badges.append( f'{role}' ) return format_html("".join(badges)) @@ -702,9 +697,7 @@ class CompanyAdmin( """Refresh park count statistics for selected companies.""" for company in queryset: company.save(update_fields=["updated_at"]) - self.message_user( - request, f"Successfully updated counts for {queryset.count()} companies." - ) + self.message_user(request, f"Successfully updated counts for {queryset.count()} companies.") def get_actions(self, request): """Add custom actions to the admin.""" @@ -840,12 +833,8 @@ class ParkReviewAdmin(QueryOptimizationMixin, ExportActionMixin, BaseModelAdmin) """Display moderation status with color coding.""" if obj.moderated_by: if obj.is_published: - return format_html( - 'Approved' - ) - return format_html( - 'Rejected' - ) + return format_html('Approved') + return format_html('Rejected') return format_html('Pending') def save_model(self, request, obj, form, change): diff --git a/backend/apps/parks/apps.py b/backend/apps/parks/apps.py index a88f63c7..7ccbe6a4 100644 --- a/backend/apps/parks/apps.py +++ b/backend/apps/parks/apps.py @@ -22,9 +22,7 @@ class ParksConfig(AppConfig): from apps.parks.models import Park # Register FSM transitions for Park - apply_state_machine( - Park, field_name="status", choice_group="statuses", domain="parks" - ) + apply_state_machine(Park, field_name="status", choice_group="statuses", domain="parks") def _register_callbacks(self): """Register FSM transition callbacks for park models.""" @@ -42,31 +40,16 @@ class ParksConfig(AppConfig): from apps.parks.models import Park # Cache invalidation for all park status changes - register_callback( - Park, 'status', '*', '*', - ParkCacheInvalidation() - ) + register_callback(Park, "status", "*", "*", ParkCacheInvalidation()) # API cache invalidation - register_callback( - Park, 'status', '*', '*', - APICacheInvalidation(include_geo_cache=True) - ) + register_callback(Park, "status", "*", "*", APICacheInvalidation(include_geo_cache=True)) # Search text update - register_callback( - Park, 'status', '*', '*', - SearchTextUpdateCallback() - ) + register_callback(Park, "status", "*", "*", SearchTextUpdateCallback()) # Notification for significant status changes - register_callback( - Park, 'status', '*', 'CLOSED_PERM', - StatusChangeNotification(notify_admins=True) - ) - register_callback( - Park, 'status', '*', 'DEMOLISHED', - StatusChangeNotification(notify_admins=True) - ) + register_callback(Park, "status", "*", "CLOSED_PERM", StatusChangeNotification(notify_admins=True)) + register_callback(Park, "status", "*", "DEMOLISHED", StatusChangeNotification(notify_admins=True)) logger.debug("Registered park transition callbacks") diff --git a/backend/apps/parks/choices.py b/backend/apps/parks/choices.py index 37e83da8..78930077 100644 --- a/backend/apps/parks/choices.py +++ b/backend/apps/parks/choices.py @@ -15,101 +15,101 @@ PARK_STATUSES = [ label="Operating", description="Park is currently open and operating normally", metadata={ - 'color': 'green', - 'icon': 'check-circle', - 'css_class': 'bg-green-100 text-green-800', - 'sort_order': 1, - 'can_transition_to': [ - 'CLOSED_TEMP', - 'CLOSED_PERM', + "color": "green", + "icon": "check-circle", + "css_class": "bg-green-100 text-green-800", + "sort_order": 1, + "can_transition_to": [ + "CLOSED_TEMP", + "CLOSED_PERM", ], - 'requires_moderator': False, - 'is_final': False, - 'is_initial': True, + "requires_moderator": False, + "is_final": False, + "is_initial": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="CLOSED_TEMP", label="Temporarily Closed", description="Park is temporarily closed for maintenance, weather, or seasonal reasons", metadata={ - 'color': 'yellow', - 'icon': 'pause-circle', - 'css_class': 'bg-yellow-100 text-yellow-800', - 'sort_order': 2, - 'can_transition_to': [ - 'CLOSED_PERM', + "color": "yellow", + "icon": "pause-circle", + "css_class": "bg-yellow-100 text-yellow-800", + "sort_order": 2, + "can_transition_to": [ + "CLOSED_PERM", ], - 'requires_moderator': False, - 'is_final': False, + "requires_moderator": False, + "is_final": False, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="CLOSED_PERM", label="Permanently Closed", description="Park has been permanently closed and will not reopen", metadata={ - 'color': 'red', - 'icon': 'x-circle', - 'css_class': 'bg-red-100 text-red-800', - 'sort_order': 3, - 'can_transition_to': [ - 'DEMOLISHED', - 'RELOCATED', + "color": "red", + "icon": "x-circle", + "css_class": "bg-red-100 text-red-800", + "sort_order": 3, + "can_transition_to": [ + "DEMOLISHED", + "RELOCATED", ], - 'requires_moderator': True, - 'is_final': False, + "requires_moderator": True, + "is_final": False, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="UNDER_CONSTRUCTION", label="Under Construction", description="Park is currently being built or undergoing major renovation", metadata={ - 'color': 'blue', - 'icon': 'tool', - 'css_class': 'bg-blue-100 text-blue-800', - 'sort_order': 4, - 'can_transition_to': [ - 'OPERATING', + "color": "blue", + "icon": "tool", + "css_class": "bg-blue-100 text-blue-800", + "sort_order": 4, + "can_transition_to": [ + "OPERATING", ], - 'requires_moderator': False, - 'is_final': False, + "requires_moderator": False, + "is_final": False, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="DEMOLISHED", label="Demolished", description="Park has been completely demolished and removed", metadata={ - 'color': 'gray', - 'icon': 'trash', - 'css_class': 'bg-gray-100 text-gray-800', - 'sort_order': 5, - 'can_transition_to': [], - 'requires_moderator': True, - 'is_final': True, + "color": "gray", + "icon": "trash", + "css_class": "bg-gray-100 text-gray-800", + "sort_order": 5, + "can_transition_to": [], + "requires_moderator": True, + "is_final": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="RELOCATED", label="Relocated", description="Park has been moved to a different location", metadata={ - 'color': 'purple', - 'icon': 'arrow-right', - 'css_class': 'bg-purple-100 text-purple-800', - 'sort_order': 6, - 'can_transition_to': [], - 'requires_moderator': True, - 'is_final': True, + "color": "purple", + "icon": "arrow-right", + "css_class": "bg-purple-100 text-purple-800", + "sort_order": 6, + "can_transition_to": [], + "requires_moderator": True, + "is_final": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), ] @@ -119,133 +119,88 @@ PARK_TYPES = [ value="THEME_PARK", label="Theme Park", description="Large-scale amusement park with themed areas and attractions", - metadata={ - 'color': 'red', - 'icon': 'castle', - 'css_class': 'bg-red-100 text-red-800', - 'sort_order': 1 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "red", "icon": "castle", "css_class": "bg-red-100 text-red-800", "sort_order": 1}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="AMUSEMENT_PARK", label="Amusement Park", description="Traditional amusement park with rides and games", - metadata={ - 'color': 'blue', - 'icon': 'ferris-wheel', - 'css_class': 'bg-blue-100 text-blue-800', - 'sort_order': 2 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "blue", "icon": "ferris-wheel", "css_class": "bg-blue-100 text-blue-800", "sort_order": 2}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="WATER_PARK", label="Water Park", description="Park featuring water-based attractions and activities", - metadata={ - 'color': 'cyan', - 'icon': 'water', - 'css_class': 'bg-cyan-100 text-cyan-800', - 'sort_order': 3 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "cyan", "icon": "water", "css_class": "bg-cyan-100 text-cyan-800", "sort_order": 3}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="FAMILY_ENTERTAINMENT_CENTER", label="Family Entertainment Center", description="Indoor entertainment facility with games and family attractions", - metadata={ - 'color': 'green', - 'icon': 'family', - 'css_class': 'bg-green-100 text-green-800', - 'sort_order': 4 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "green", "icon": "family", "css_class": "bg-green-100 text-green-800", "sort_order": 4}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="CARNIVAL", label="Carnival", description="Traveling amusement show with rides, games, and entertainment", - metadata={ - 'color': 'yellow', - 'icon': 'carnival', - 'css_class': 'bg-yellow-100 text-yellow-800', - 'sort_order': 5 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "yellow", "icon": "carnival", "css_class": "bg-yellow-100 text-yellow-800", "sort_order": 5}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="FAIR", label="Fair", description="Temporary event featuring rides, games, and agricultural exhibits", - metadata={ - 'color': 'orange', - 'icon': 'fair', - 'css_class': 'bg-orange-100 text-orange-800', - 'sort_order': 6 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "orange", "icon": "fair", "css_class": "bg-orange-100 text-orange-800", "sort_order": 6}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="PIER", label="Pier", description="Seaside entertainment pier with rides and attractions", - metadata={ - 'color': 'teal', - 'icon': 'pier', - 'css_class': 'bg-teal-100 text-teal-800', - 'sort_order': 7 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "teal", "icon": "pier", "css_class": "bg-teal-100 text-teal-800", "sort_order": 7}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="BOARDWALK", label="Boardwalk", description="Waterfront entertainment area with rides and attractions", metadata={ - 'color': 'indigo', - 'icon': 'boardwalk', - 'css_class': 'bg-indigo-100 text-indigo-800', - 'sort_order': 8 + "color": "indigo", + "icon": "boardwalk", + "css_class": "bg-indigo-100 text-indigo-800", + "sort_order": 8, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="SAFARI_PARK", label="Safari Park", description="Wildlife park with drive-through animal experiences", metadata={ - 'color': 'emerald', - 'icon': 'safari', - 'css_class': 'bg-emerald-100 text-emerald-800', - 'sort_order': 9 + "color": "emerald", + "icon": "safari", + "css_class": "bg-emerald-100 text-emerald-800", + "sort_order": 9, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="ZOO", label="Zoo", description="Zoological park with animal exhibits and educational programs", - metadata={ - 'color': 'lime', - 'icon': 'zoo', - 'css_class': 'bg-lime-100 text-lime-800', - 'sort_order': 10 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "lime", "icon": "zoo", "css_class": "bg-lime-100 text-lime-800", "sort_order": 10}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="OTHER", label="Other", description="Park type that doesn't fit into standard categories", - metadata={ - 'color': 'gray', - 'icon': 'other', - 'css_class': 'bg-gray-100 text-gray-800', - 'sort_order': 11 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "gray", "icon": "other", "css_class": "bg-gray-100 text-gray-800", "sort_order": 11}, + category=ChoiceCategory.CLASSIFICATION, ), ] @@ -256,30 +211,30 @@ PARKS_COMPANY_ROLES = [ label="Park Operator", description="Company that operates and manages theme parks and amusement facilities", metadata={ - 'color': 'blue', - 'icon': 'building-office', - 'css_class': 'bg-blue-100 text-blue-800', - 'sort_order': 1, - 'domain': 'parks', - 'permissions': ['manage_parks', 'view_operations'], - 'url_pattern': '/parks/operators/{slug}/' + "color": "blue", + "icon": "building-office", + "css_class": "bg-blue-100 text-blue-800", + "sort_order": 1, + "domain": "parks", + "permissions": ["manage_parks", "view_operations"], + "url_pattern": "/parks/operators/{slug}/", }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="PROPERTY_OWNER", label="Property Owner", description="Company that owns the land and property where parks are located", metadata={ - 'color': 'green', - 'icon': 'home', - 'css_class': 'bg-green-100 text-green-800', - 'sort_order': 2, - 'domain': 'parks', - 'permissions': ['manage_property', 'view_ownership'], - 'url_pattern': '/parks/owners/{slug}/' + "color": "green", + "icon": "home", + "css_class": "bg-green-100 text-green-800", + "sort_order": 2, + "domain": "parks", + "permissions": ["manage_property", "view_ownership"], + "url_pattern": "/parks/owners/{slug}/", }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), ] @@ -292,7 +247,7 @@ def register_parks_choices(): choices=PARK_STATUSES, domain="parks", description="Park operational status options", - metadata={'domain': 'parks', 'type': 'status'} + metadata={"domain": "parks", "type": "status"}, ) register_choices( @@ -300,7 +255,7 @@ def register_parks_choices(): choices=PARK_TYPES, domain="parks", description="Park type and category classifications", - metadata={'domain': 'parks', 'type': 'park_type'} + metadata={"domain": "parks", "type": "park_type"}, ) register_choices( @@ -308,7 +263,7 @@ def register_parks_choices(): choices=PARKS_COMPANY_ROLES, domain="parks", description="Company role classifications for parks domain (OPERATOR and PROPERTY_OWNER only)", - metadata={'domain': 'parks', 'type': 'company_role'} + metadata={"domain": "parks", "type": "company_role"}, ) diff --git a/backend/apps/parks/filters.py b/backend/apps/parks/filters.py index 25148dd9..80bcb450 100644 --- a/backend/apps/parks/filters.py +++ b/backend/apps/parks/filters.py @@ -29,7 +29,7 @@ def validate_positive_integer(value): raise ValidationError(_("Value must be a positive integer")) return int(value) except (TypeError, ValueError): - raise ValidationError(_("Invalid number format")) + raise ValidationError(_("Invalid number format")) from None class ParkFilter(FilterSet): @@ -341,9 +341,7 @@ class ParkFilter(FilterSet): if value: return queryset.filter(coaster_count__gt=0) else: - return queryset.filter( - models.Q(coaster_count__isnull=True) | models.Q(coaster_count=0) - ) + return queryset.filter(models.Q(coaster_count__isnull=True) | models.Q(coaster_count=0)) def filter_min_rating(self, queryset, name, value): """Filter parks by minimum rating""" diff --git a/backend/apps/parks/forms.py b/backend/apps/parks/forms.py index 1be322eb..be2fcbb0 100644 --- a/backend/apps/parks/forms.py +++ b/backend/apps/parks/forms.py @@ -256,9 +256,7 @@ class ParkForm(forms.ModelForm): # Validate range if latitude < -90 or latitude > 90: - raise forms.ValidationError( - "Latitude must be between -90 and 90 degrees." - ) + raise forms.ValidationError("Latitude must be between -90 and 90 degrees.") # Convert to string to preserve exact decimal places return str(latitude) @@ -277,9 +275,7 @@ class ParkForm(forms.ModelForm): # Validate range if longitude < -180 or longitude > 180: - raise forms.ValidationError( - "Longitude must be between -180 and 180 degrees." - ) + raise forms.ValidationError("Longitude must be between -180 and 180 degrees.") # Convert to string to preserve exact decimal places return str(longitude) @@ -314,7 +310,7 @@ class ParkForm(forms.ModelForm): setattr(park_location, key, value) # Handle coordinates if provided - if "latitude" in location_data and "longitude" in location_data: + if "latitude" in location_data and "longitude" in location_data: # noqa: SIM102 if location_data["latitude"] and location_data["longitude"]: park_location.set_coordinates( float(location_data["latitude"]), @@ -324,7 +320,7 @@ class ParkForm(forms.ModelForm): except ParkLocation.DoesNotExist: # Create new ParkLocation coordinates_data = {} - if "latitude" in location_data and "longitude" in location_data: + if "latitude" in location_data and "longitude" in location_data: # noqa: SIM102 if location_data["latitude"] and location_data["longitude"]: coordinates_data = { "latitude": float(location_data["latitude"]), @@ -332,19 +328,13 @@ class ParkForm(forms.ModelForm): } # Remove coordinate fields from location_data for creation - creation_data = { - k: v - for k, v in location_data.items() - if k not in ["latitude", "longitude"] - } + creation_data = {k: v for k, v in location_data.items() if k not in ["latitude", "longitude"]} creation_data.setdefault("country", "USA") park_location = ParkLocation.objects.create(park=park, **creation_data) if coordinates_data: - park_location.set_coordinates( - coordinates_data["latitude"], coordinates_data["longitude"] - ) + park_location.set_coordinates(coordinates_data["latitude"], coordinates_data["longitude"]) park_location.save() if commit: diff --git a/backend/apps/parks/management/commands/create_sample_data.py b/backend/apps/parks/management/commands/create_sample_data.py index d89aa716..a0e6e21c 100644 --- a/backend/apps/parks/management/commands/create_sample_data.py +++ b/backend/apps/parks/management/commands/create_sample_data.py @@ -27,9 +27,7 @@ class Command(BaseCommand): self.create_park_areas() self.create_reviews() - self.stdout.write( - self.style.SUCCESS("Successfully created comprehensive sample data!") - ) + self.stdout.write(self.style.SUCCESS("Successfully created comprehensive sample data!")) self.print_summary() except Exception as e: @@ -101,13 +99,9 @@ class Command(BaseCommand): ] for company_data in park_operators_data: - company, created = ParkCompany.objects.get_or_create( - slug=company_data["slug"], defaults=company_data - ) + company, created = ParkCompany.objects.get_or_create(slug=company_data["slug"], defaults=company_data) self.created_companies[company.slug] = company - self.stdout.write( - f" {'Created' if created else 'Found'} park company: {company.name}" - ) + self.stdout.write(f" {'Created' if created else 'Found'} park company: {company.name}") # Ride manufacturers and designers (using rides.models.Company) ride_companies_data = [ @@ -194,13 +188,9 @@ class Command(BaseCommand): ] for company_data in ride_companies_data: - company, created = RideCompany.objects.get_or_create( - slug=company_data["slug"], defaults=company_data - ) + company, created = RideCompany.objects.get_or_create(slug=company_data["slug"], defaults=company_data) self.created_companies[company.slug] = company - self.stdout.write( - f" {'Created' if created else 'Found'} ride company: {company.name}" - ) + self.stdout.write(f" {'Created' if created else 'Found'} ride company: {company.name}") def create_parks(self): """Create parks with proper operator relationships.""" diff --git a/backend/apps/parks/management/commands/fix_migrations.py b/backend/apps/parks/management/commands/fix_migrations.py index 6bec9c84..c8d45b91 100644 --- a/backend/apps/parks/management/commands/fix_migrations.py +++ b/backend/apps/parks/management/commands/fix_migrations.py @@ -31,6 +31,4 @@ class Command(BaseCommand): """ ) - self.stdout.write( - self.style.SUCCESS("Successfully fixed migration history") - ) + self.stdout.write(self.style.SUCCESS("Successfully fixed migration history")) diff --git a/backend/apps/parks/management/commands/seed_initial_data.py b/backend/apps/parks/management/commands/seed_initial_data.py index 3f386436..0f164597 100644 --- a/backend/apps/parks/management/commands/seed_initial_data.py +++ b/backend/apps/parks/management/commands/seed_initial_data.py @@ -50,13 +50,9 @@ class Command(BaseCommand): companies = {} for company_data in companies_data: - operator, created = Operator.objects.get_or_create( - name=company_data["name"], defaults=company_data - ) + operator, created = Operator.objects.get_or_create(name=company_data["name"], defaults=company_data) companies[operator.name] = operator - self.stdout.write( - f"{'Created' if created else 'Found'} company: {operator.name}" - ) + self.stdout.write(f"{'Created' if created else 'Found'} company: {operator.name}") # Create parks with their locations parks_data = [ @@ -317,9 +313,7 @@ class Command(BaseCommand): postal_code=loc_data["postal_code"], ) # Set coordinates using the helper method - park_location.set_coordinates( - loc_data["latitude"], loc_data["longitude"] - ) + park_location.set_coordinates(loc_data["latitude"], loc_data["longitude"]) park_location.save() # Create areas for park @@ -329,8 +323,6 @@ class Command(BaseCommand): park=park, defaults={"description": area_data["description"]}, ) - self.stdout.write( - f"{'Created' if created else 'Found'} area: {area.name} in {park.name}" - ) + self.stdout.write(f"{'Created' if created else 'Found'} area: {area.name} in {park.name}") self.stdout.write(self.style.SUCCESS("Successfully seeded initial park data")) diff --git a/backend/apps/parks/management/commands/seed_sample_data.py b/backend/apps/parks/management/commands/seed_sample_data.py index 623ab271..52454d59 100644 --- a/backend/apps/parks/management/commands/seed_sample_data.py +++ b/backend/apps/parks/management/commands/seed_sample_data.py @@ -43,19 +43,13 @@ class Command(BaseCommand): # Log what will be deleted self.stdout.write(f" Found {park_review_count} park reviews to delete") self.stdout.write(f" Found {ride_review_count} ride reviews to delete") - self.stdout.write( - f" Found {rollercoaster_stats_count} roller coaster stats to delete" - ) + self.stdout.write(f" Found {rollercoaster_stats_count} roller coaster stats to delete") self.stdout.write(f" Found {ride_count} rides to delete") self.stdout.write(f" Found {ride_model_count} ride models to delete") self.stdout.write(f" Found {park_area_count} park areas to delete") - self.stdout.write( - f" Found {park_location_count} park locations to delete" - ) + self.stdout.write(f" Found {park_location_count} park locations to delete") self.stdout.write(f" Found {park_count} parks to delete") - self.stdout.write( - f" Found {ride_company_count} ride companies to delete" - ) + self.stdout.write(f" Found {ride_company_count} ride companies to delete") self.stdout.write(f" Found {company_count} park companies to delete") self.stdout.write(f" Found {test_user_count} test users to delete") @@ -72,9 +66,7 @@ class Command(BaseCommand): # Roller coaster stats (references Ride) if rollercoaster_stats_count > 0: RollerCoasterStats.objects.all().delete() - self.stdout.write( - f" Deleted {rollercoaster_stats_count} roller coaster stats" - ) + self.stdout.write(f" Deleted {rollercoaster_stats_count} roller coaster stats") # Rides (references Park, RideCompany, RideModel) if ride_count > 0: @@ -116,18 +108,14 @@ class Command(BaseCommand): User.objects.filter(username="testuser").delete() self.stdout.write(f" Deleted {test_user_count} test users") - self.stdout.write( - self.style.SUCCESS("Successfully cleaned up existing sample data!") - ) + self.stdout.write(self.style.SUCCESS("Successfully cleaned up existing sample data!")) except Exception as e: self.logger.error( f"Error during data cleanup: {str(e)}", exc_info=True, ) - self.stdout.write( - self.style.ERROR(f"Failed to clean up existing data: {str(e)}") - ) + self.stdout.write(self.style.ERROR(f"Failed to clean up existing data: {str(e)}")) raise def handle(self, *args, **options): @@ -137,9 +125,7 @@ class Command(BaseCommand): # Check if required tables exist if not self.check_required_tables(): self.stdout.write( - self.style.ERROR( - "Required database tables are missing. Please run migrations first." - ) + self.style.ERROR("Required database tables are missing. Please run migrations first.") ) return @@ -163,17 +149,11 @@ class Command(BaseCommand): # Add sample reviews for testing self.create_reviews() - self.stdout.write( - self.style.SUCCESS("Successfully created comprehensive sample data!") - ) + self.stdout.write(self.style.SUCCESS("Successfully created comprehensive sample data!")) except Exception as e: - self.logger.error( - f"Error during sample data creation: {str(e)}", exc_info=True - ) - self.stdout.write( - self.style.ERROR(f"Failed to create sample data: {str(e)}") - ) + self.logger.error(f"Error during sample data creation: {str(e)}", exc_info=True) + self.stdout.write(self.style.ERROR(f"Failed to create sample data: {str(e)}")) raise def check_required_tables(self): @@ -202,11 +182,7 @@ class Command(BaseCommand): missing_tables.append(model._meta.label) if missing_tables: - self.stdout.write( - self.style.WARNING( - f"Missing tables for models: {', '.join(missing_tables)}" - ) - ) + self.stdout.write(self.style.WARNING(f"Missing tables for models: {', '.join(missing_tables)}")) return False self.stdout.write(self.style.SUCCESS("All required tables exist.")) @@ -357,9 +333,7 @@ class Command(BaseCommand): }" ) except Exception as e: - self.logger.error( - f"Error creating park company {data['name']}: {str(e)}" - ) + self.logger.error(f"Error creating park company {data['name']}: {str(e)}") raise # Create companies in rides app (for manufacturers and designers) @@ -382,9 +356,7 @@ class Command(BaseCommand): }" ) except Exception as e: - self.logger.error( - f"Error creating ride company {data['name']}: {str(e)}" - ) + self.logger.error(f"Error creating ride company {data['name']}: {str(e)}") raise except Exception as e: @@ -512,9 +484,7 @@ class Command(BaseCommand): try: operator = self.park_companies[park_data["operator"]] property_owner = ( - self.park_companies.get(park_data["property_owner"]) - if park_data["property_owner"] - else None + self.park_companies.get(park_data["property_owner"]) if park_data["property_owner"] else None ) park, created = Park.objects.get_or_create( @@ -530,9 +500,7 @@ class Command(BaseCommand): }, ) self.parks[park_data["name"]] = park - self.stdout.write( - f" {'Created' if created else 'Found'} park: {park.name}" - ) + self.stdout.write(f" {'Created' if created else 'Found'} park: {park.name}") # Create location for park if created: @@ -547,9 +515,7 @@ class Command(BaseCommand): postal_code=loc_data["postal_code"], ) # Set coordinates using the helper method - park_location.set_coordinates( - loc_data["latitude"], loc_data["longitude"] - ) + park_location.set_coordinates(loc_data["latitude"], loc_data["longitude"]) park_location.save() except Exception as e: self.logger.error( @@ -560,9 +526,7 @@ class Command(BaseCommand): raise except Exception as e: - self.logger.error( - f"Error creating park {park_data['name']}: {str(e)}" - ) + self.logger.error(f"Error creating park {park_data['name']}: {str(e)}") raise except Exception as e: @@ -633,9 +597,7 @@ class Command(BaseCommand): }" ) except Exception as e: - self.logger.error( - f"Error creating ride model {model_data['name']}: {str(e)}" - ) + self.logger.error(f"Error creating ride model {model_data['name']}: {str(e)}") raise # Create rides @@ -834,9 +796,7 @@ class Command(BaseCommand): for ride_data in rides_data: try: park = self.parks[ride_data["park"]] - manufacturer = self.ride_companies.get( - ride_data.get("manufacturer") - ) + manufacturer = self.ride_companies.get(ride_data.get("manufacturer")) designer = self.ride_companies.get(ride_data.get("designer")) ride_model = self.ride_models.get(ride_data.get("ride_model")) @@ -854,9 +814,7 @@ class Command(BaseCommand): }, ) self.rides[ride_data["name"]] = ride - self.stdout.write( - f" {'Created' if created else 'Found'} ride: {ride.name}" - ) + self.stdout.write(f" {'Created' if created else 'Found'} ride: {ride.name}") # Create roller coaster stats if provided if created and "coaster_stats" in ride_data: @@ -872,9 +830,7 @@ class Command(BaseCommand): raise except Exception as e: - self.logger.error( - f"Error creating ride {ride_data['name']}: {str(e)}" - ) + self.logger.error(f"Error creating ride {ride_data['name']}: {str(e)}") raise except Exception as e: @@ -1011,9 +967,7 @@ class Command(BaseCommand): } in {park.name}" ) except Exception as e: - self.logger.error( - f"Error creating areas for park {area_group['park']}: {str(e)}" - ) + self.logger.error(f"Error creating areas for park {area_group['park']}: {str(e)}") raise except Exception as e: diff --git a/backend/apps/parks/management/commands/test_location.py b/backend/apps/parks/management/commands/test_location.py index 9a945086..c840325c 100644 --- a/backend/apps/parks/management/commands/test_location.py +++ b/backend/apps/parks/management/commands/test_location.py @@ -85,9 +85,7 @@ class Command(BaseCommand): "country": "USA", }, ) - location2.set_coordinates( - 34.4244, -118.5971 - ) # Six Flags Magic Mountain coordinates + location2.set_coordinates(34.4244, -118.5971) # Six Flags Magic Mountain coordinates location2.save() # Test distance calculation @@ -107,9 +105,7 @@ class Command(BaseCommand): # Find parks within 100km of a point # Same as Disneyland search_point = Point(-117.9190, 33.8121, srid=4326) - nearby_locations = ParkLocation.objects.filter( - point__distance_lte=(search_point, D(km=100)) - ) + nearby_locations = ParkLocation.objects.filter(point__distance_lte=(search_point, D(km=100))) self.stdout.write(f" Found {nearby_locations.count()} parks within 100km") for loc in nearby_locations: self.stdout.write(f" - {loc.park.name} in {loc.city}, {loc.state}") diff --git a/backend/apps/parks/management/commands/update_park_counts.py b/backend/apps/parks/management/commands/update_park_counts.py index 82bc925c..cca0d52f 100644 --- a/backend/apps/parks/management/commands/update_park_counts.py +++ b/backend/apps/parks/management/commands/update_park_counts.py @@ -20,11 +20,7 @@ class Command(BaseCommand): total_coasters = park.rides.filter(operating_rides, category="RC").count() # Update park counts - Park.objects.filter(id=park.id).update( - total_rides=total_rides, total_roller_coasters=total_coasters - ) + Park.objects.filter(id=park.id).update(total_rides=total_rides, total_roller_coasters=total_coasters) updated += 1 - self.stdout.write( - self.style.SUCCESS(f"Successfully updated counts for {updated} parks") - ) + self.stdout.write(self.style.SUCCESS(f"Successfully updated counts for {updated} parks")) diff --git a/backend/apps/parks/managers.py b/backend/apps/parks/managers.py index 68dbb084..0f0530bb 100644 --- a/backend/apps/parks/managers.py +++ b/backend/apps/parks/managers.py @@ -30,23 +30,15 @@ class ParkQuerySet(StatusQuerySet, ReviewableQuerySet, LocationQuerySet): distinct=True, ), area_count=Count("areas", distinct=True), - review_count=Count( - "reviews", filter=Q(reviews__is_published=True), distinct=True - ), - average_rating_calculated=Avg( - "reviews__rating", filter=Q(reviews__is_published=True) - ), + review_count=Count("reviews", filter=Q(reviews__is_published=True), distinct=True), + average_rating_calculated=Avg("reviews__rating", filter=Q(reviews__is_published=True)), latest_ride_opening=Max("rides__opening_date"), oldest_ride_opening=Min("rides__opening_date"), ) def optimized_for_list(self): """Optimize for park list display.""" - return ( - self.select_related("operator", "property_owner") - .prefetch_related("location") - .with_complete_stats() - ) + return self.select_related("operator", "property_owner").prefetch_related("location").with_complete_stats() def optimized_for_detail(self): """Optimize for park detail display.""" @@ -59,9 +51,9 @@ class ParkQuerySet(StatusQuerySet, ReviewableQuerySet, LocationQuerySet): "areas", Prefetch( "rides", - queryset=Ride.objects.select_related( - "manufacturer", "designer", "ride_model", "park_area" - ).order_by("name"), + queryset=Ride.objects.select_related("manufacturer", "designer", "ride_model", "park_area").order_by( + "name" + ), ), Prefetch( "reviews", @@ -82,9 +74,7 @@ class ParkQuerySet(StatusQuerySet, ReviewableQuerySet, LocationQuerySet): def with_minimum_coasters(self, *, min_coasters: int = 5): """Filter parks with minimum number of coasters.""" - return self.with_complete_stats().filter( - coaster_count_calculated__gte=min_coasters - ) + return self.with_complete_stats().filter(coaster_count_calculated__gte=min_coasters) def large_parks(self, *, min_acres: float = 100.0): """Filter for large parks.""" @@ -123,16 +113,10 @@ class ParkQuerySet(StatusQuerySet, ReviewableQuerySet, LocationQuerySet): """Optimized search for autocomplete.""" return ( self.filter( - Q(name__icontains=query) - | Q(location__city__icontains=query) - | Q(location__state__icontains=query) + Q(name__icontains=query) | Q(location__city__icontains=query) | Q(location__state__icontains=query) ) .select_related("operator", "location") - .only( - "id", "name", "slug", - "location__city", "location__state", - "operator__name" - )[:limit] + .only("id", "name", "slug", "location__city", "location__state", "operator__name")[:limit] ) def with_location(self): @@ -247,9 +231,7 @@ class ParkReviewManager(BaseManager): return self.get_queryset().for_park(park_id=park_id) def by_rating_range(self, *, min_rating: int = 1, max_rating: int = 10): - return self.get_queryset().by_rating_range( - min_rating=min_rating, max_rating=max_rating - ) + return self.get_queryset().by_rating_range(min_rating=min_rating, max_rating=max_rating) def moderation_required(self): return self.get_queryset().moderation_required() @@ -275,17 +257,12 @@ class CompanyQuerySet(BaseQuerySet): return self.annotate( operated_parks_count=Count("operated_parks", distinct=True), owned_parks_count=Count("owned_parks", distinct=True), - total_parks_involvement=Count("operated_parks", distinct=True) - + Count("owned_parks", distinct=True), + total_parks_involvement=Count("operated_parks", distinct=True) + Count("owned_parks", distinct=True), ) def major_operators(self, *, min_parks: int = 5): """Filter for major park operators.""" - return ( - self.operators() - .with_park_counts() - .filter(operated_parks_count__gte=min_parks) - ) + return self.operators().with_park_counts().filter(operated_parks_count__gte=min_parks) def optimized_for_list(self): """Optimize for company list display.""" @@ -313,7 +290,7 @@ class CompanyManager(BaseManager): self.get_queryset() .manufacturers() .annotate(ride_count=Count("manufactured_rides", distinct=True)) - .only('id', 'name', 'slug', 'roles', 'description') + .only("id", "name", "slug", "roles", "description") .order_by("name") ) @@ -323,7 +300,7 @@ class CompanyManager(BaseManager): self.get_queryset() .filter(roles__contains=["DESIGNER"]) .annotate(ride_count=Count("designed_rides", distinct=True)) - .only('id', 'name', 'slug', 'roles', 'description') + .only("id", "name", "slug", "roles", "description") .order_by("name") ) @@ -333,6 +310,6 @@ class CompanyManager(BaseManager): self.get_queryset() .operators() .with_park_counts() - .only('id', 'name', 'slug', 'roles', 'description') + .only("id", "name", "slug", "roles", "description") .order_by("name") ) diff --git a/backend/apps/parks/migrations/0001_initial.py b/backend/apps/parks/migrations/0001_initial.py index 9b9e34c6..412369ba 100644 --- a/backend/apps/parks/migrations/0001_initial.py +++ b/backend/apps/parks/migrations/0001_initial.py @@ -102,16 +102,12 @@ class Migration(migrations.Migration): ), ( "size_acres", - models.DecimalField( - blank=True, decimal_places=2, max_digits=10, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=10, null=True), ), ("website", models.URLField(blank=True)), ( "average_rating", - models.DecimalField( - blank=True, decimal_places=2, max_digits=3, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=3, null=True), ), ("ride_count", models.IntegerField(blank=True, null=True)), ("coaster_count", models.IntegerField(blank=True, null=True)), @@ -266,16 +262,12 @@ class Migration(migrations.Migration): ), ( "size_acres", - models.DecimalField( - blank=True, decimal_places=2, max_digits=10, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=10, null=True), ), ("website", models.URLField(blank=True)), ( "average_rating", - models.DecimalField( - blank=True, decimal_places=2, max_digits=3, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=3, null=True), ), ("ride_count", models.IntegerField(blank=True, null=True)), ("coaster_count", models.IntegerField(blank=True, null=True)), @@ -678,9 +670,7 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="parklocation", - index=models.Index( - fields=["city", "state"], name="parks_parkl_city_7cc873_idx" - ), + index=models.Index(fields=["city", "state"], name="parks_parkl_city_7cc873_idx"), ), migrations.AlterUniqueTogether( name="parkreview", diff --git a/backend/apps/parks/migrations/0007_companyheadquartersevent_parklocationevent_and_more.py b/backend/apps/parks/migrations/0007_companyheadquartersevent_parklocationevent_and_more.py index 839c05b6..82d30c5a 100644 --- a/backend/apps/parks/migrations/0007_companyheadquartersevent_parklocationevent_and_more.py +++ b/backend/apps/parks/migrations/0007_companyheadquartersevent_parklocationevent_and_more.py @@ -35,9 +35,7 @@ class Migration(migrations.Migration): ), ( "state_province", - models.CharField( - blank=True, help_text="State/Province/Region", max_length=100 - ), + models.CharField(blank=True, help_text="State/Province/Region", max_length=100), ), ( "country", @@ -49,9 +47,7 @@ class Migration(migrations.Migration): ), ( "postal_code", - models.CharField( - blank=True, help_text="ZIP or postal code", max_length=20 - ), + models.CharField(blank=True, help_text="ZIP or postal code", max_length=20), ), ( "mailing_address", diff --git a/backend/apps/parks/migrations/0008_parkphoto_parkphotoevent_and_more.py b/backend/apps/parks/migrations/0008_parkphoto_parkphotoevent_and_more.py index 0b1d2fd0..1eef37f5 100644 --- a/backend/apps/parks/migrations/0008_parkphoto_parkphotoevent_and_more.py +++ b/backend/apps/parks/migrations/0008_parkphoto_parkphotoevent_and_more.py @@ -133,21 +133,15 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="parkphoto", - index=models.Index( - fields=["park", "is_primary"], name="parks_parkp_park_id_eda26e_idx" - ), + index=models.Index(fields=["park", "is_primary"], name="parks_parkp_park_id_eda26e_idx"), ), migrations.AddIndex( model_name="parkphoto", - index=models.Index( - fields=["park", "is_approved"], name="parks_parkp_park_id_5fe576_idx" - ), + index=models.Index(fields=["park", "is_approved"], name="parks_parkp_park_id_5fe576_idx"), ), migrations.AddIndex( model_name="parkphoto", - index=models.Index( - fields=["created_at"], name="parks_parkp_created_033dc3_idx" - ), + index=models.Index(fields=["created_at"], name="parks_parkp_created_033dc3_idx"), ), migrations.AddConstraint( model_name="parkphoto", diff --git a/backend/apps/parks/migrations/0015_populate_hybrid_filtering_fields.py b/backend/apps/parks/migrations/0015_populate_hybrid_filtering_fields.py index 8c6ea11b..2b686d3d 100644 --- a/backend/apps/parks/migrations/0015_populate_hybrid_filtering_fields.py +++ b/backend/apps/parks/migrations/0015_populate_hybrid_filtering_fields.py @@ -11,24 +11,29 @@ def populate_computed_fields(apps, schema_editor): try: # Use raw SQL to update opening_year from opening_date - schema_editor.execute(""" + schema_editor.execute( + """ UPDATE parks_park SET opening_year = EXTRACT(YEAR FROM opening_date) WHERE opening_date IS NOT NULL; - """) + """ + ) # Use raw SQL to populate search_text # This is a simplified version - we'll populate it with just name and description - schema_editor.execute(""" + schema_editor.execute( + """ UPDATE parks_park SET search_text = LOWER( COALESCE(name, '') || ' ' || COALESCE(description, '') ); - """) + """ + ) # Update search_text to include operator names using a join - schema_editor.execute(""" + schema_editor.execute( + """ UPDATE parks_park SET search_text = LOWER( COALESCE(parks_park.name, '') || ' ' || @@ -37,7 +42,8 @@ def populate_computed_fields(apps, schema_editor): ) FROM parks_company WHERE parks_park.operator_id = parks_company.id; - """) + """ + ) finally: # Re-enable pghistory triggers @@ -46,8 +52,8 @@ def populate_computed_fields(apps, schema_editor): def reverse_populate_computed_fields(apps, schema_editor): """Clear computed fields (reverse operation)""" - Park = apps.get_model('parks', 'Park') - Park.objects.update(opening_year=None, search_text='') + Park = apps.get_model("parks", "Park") + Park.objects.update(opening_year=None, search_text="") class Migration(migrations.Migration): diff --git a/backend/apps/parks/migrations/0016_add_hybrid_filtering_indexes.py b/backend/apps/parks/migrations/0016_add_hybrid_filtering_indexes.py index 7b1678b2..4b3859bb 100644 --- a/backend/apps/parks/migrations/0016_add_hybrid_filtering_indexes.py +++ b/backend/apps/parks/migrations/0016_add_hybrid_filtering_indexes.py @@ -13,37 +13,34 @@ class Migration(migrations.Migration): # Composite indexes for common filter combinations migrations.RunSQL( "CREATE INDEX IF NOT EXISTS parks_park_status_park_type_idx ON parks_park (status, park_type);", - reverse_sql="DROP INDEX IF EXISTS parks_park_status_park_type_idx;" + reverse_sql="DROP INDEX IF EXISTS parks_park_status_park_type_idx;", ), migrations.RunSQL( "CREATE INDEX IF NOT EXISTS parks_park_opening_year_status_idx ON parks_park (opening_year, status) WHERE opening_year IS NOT NULL;", - reverse_sql="DROP INDEX IF EXISTS parks_park_opening_year_status_idx;" + reverse_sql="DROP INDEX IF EXISTS parks_park_opening_year_status_idx;", ), migrations.RunSQL( "CREATE INDEX IF NOT EXISTS parks_park_size_rating_idx ON parks_park (size_acres, average_rating) WHERE size_acres IS NOT NULL AND average_rating IS NOT NULL;", - reverse_sql="DROP INDEX IF EXISTS parks_park_size_rating_idx;" + reverse_sql="DROP INDEX IF EXISTS parks_park_size_rating_idx;", ), migrations.RunSQL( "CREATE INDEX IF NOT EXISTS parks_park_ride_coaster_count_idx ON parks_park (ride_count, coaster_count) WHERE ride_count IS NOT NULL AND coaster_count IS NOT NULL;", - reverse_sql="DROP INDEX IF EXISTS parks_park_ride_coaster_count_idx;" + reverse_sql="DROP INDEX IF EXISTS parks_park_ride_coaster_count_idx;", ), - # Full-text search index for search_text field migrations.RunSQL( "CREATE INDEX IF NOT EXISTS parks_park_search_text_gin_idx ON parks_park USING gin(to_tsvector('english', search_text));", - reverse_sql="DROP INDEX IF EXISTS parks_park_search_text_gin_idx;" + reverse_sql="DROP INDEX IF EXISTS parks_park_search_text_gin_idx;", ), - # Trigram index for fuzzy search on search_text migrations.RunSQL( "CREATE EXTENSION IF NOT EXISTS pg_trgm;", - reverse_sql="-- Cannot drop extension as it might be used elsewhere" + reverse_sql="-- Cannot drop extension as it might be used elsewhere", ), migrations.RunSQL( "CREATE INDEX IF NOT EXISTS parks_park_search_text_trgm_idx ON parks_park USING gin(search_text gin_trgm_ops);", - reverse_sql="DROP INDEX IF EXISTS parks_park_search_text_trgm_idx;" + reverse_sql="DROP INDEX IF EXISTS parks_park_search_text_trgm_idx;", ), - # Indexes for location-based filtering (assuming location relationship exists) migrations.RunSQL( """ @@ -51,27 +48,23 @@ class Migration(migrations.Migration): ON parks_parklocation (country, state) WHERE country IS NOT NULL AND state IS NOT NULL; """, - reverse_sql="DROP INDEX IF EXISTS parks_parklocation_country_state_idx;" + reverse_sql="DROP INDEX IF EXISTS parks_parklocation_country_state_idx;", ), - # Index for operator-based filtering migrations.RunSQL( "CREATE INDEX IF NOT EXISTS parks_park_operator_status_idx ON parks_park (operator_id, status);", - reverse_sql="DROP INDEX IF EXISTS parks_park_operator_status_idx;" + reverse_sql="DROP INDEX IF EXISTS parks_park_operator_status_idx;", ), - # Partial indexes for common status filters migrations.RunSQL( "CREATE INDEX IF NOT EXISTS parks_park_operating_parks_idx ON parks_park (name, opening_year) WHERE status IN ('OPERATING', 'CLOSED_TEMP');", - reverse_sql="DROP INDEX IF EXISTS parks_park_operating_parks_idx;" + reverse_sql="DROP INDEX IF EXISTS parks_park_operating_parks_idx;", ), - # Index for ordering by name (already exists but ensuring it's optimized) migrations.RunSQL( "CREATE INDEX IF NOT EXISTS parks_park_name_lower_idx ON parks_park (LOWER(name));", - reverse_sql="DROP INDEX IF EXISTS parks_park_name_lower_idx;" + reverse_sql="DROP INDEX IF EXISTS parks_park_name_lower_idx;", ), - # Covering index for common query patterns migrations.RunSQL( """ @@ -80,6 +73,6 @@ class Migration(migrations.Migration): INCLUDE (name, slug, size_acres, average_rating, ride_count, coaster_count, operator_id) WHERE status IN ('OPERATING', 'CLOSED_TEMP'); """, - reverse_sql="DROP INDEX IF EXISTS parks_park_hybrid_covering_idx;" + reverse_sql="DROP INDEX IF EXISTS parks_park_hybrid_covering_idx;", ), ] diff --git a/backend/apps/parks/migrations/0019_fix_pghistory_timezone.py b/backend/apps/parks/migrations/0019_fix_pghistory_timezone.py index d0eb0fa8..1e0e4615 100644 --- a/backend/apps/parks/migrations/0019_fix_pghistory_timezone.py +++ b/backend/apps/parks/migrations/0019_fix_pghistory_timezone.py @@ -47,6 +47,6 @@ class Migration(migrations.Migration): reverse_sql=""" -- This is irreversible, but we can drop and recreate without timezone DROP FUNCTION IF EXISTS pgtrigger_insert_insert_66883() CASCADE; - """ + """, ), ] diff --git a/backend/apps/parks/migrations/0020_fix_pghistory_update_timezone.py b/backend/apps/parks/migrations/0020_fix_pghistory_update_timezone.py index de8057bc..b042274f 100644 --- a/backend/apps/parks/migrations/0020_fix_pghistory_update_timezone.py +++ b/backend/apps/parks/migrations/0020_fix_pghistory_update_timezone.py @@ -47,6 +47,6 @@ class Migration(migrations.Migration): reverse_sql=""" -- This is irreversible, but we can drop and recreate without timezone DROP FUNCTION IF EXISTS pgtrigger_update_update_19f56() CASCADE; - """ + """, ), ] diff --git a/backend/apps/parks/migrations/0023_add_company_roles_gin_index.py b/backend/apps/parks/migrations/0023_add_company_roles_gin_index.py index 8f87a01b..08a9de82 100644 --- a/backend/apps/parks/migrations/0023_add_company_roles_gin_index.py +++ b/backend/apps/parks/migrations/0023_add_company_roles_gin_index.py @@ -14,7 +14,7 @@ from django.db import migrations class Migration(migrations.Migration): dependencies = [ - ('parks', '0022_alter_company_roles_alter_companyevent_roles'), + ("parks", "0022_alter_company_roles_alter_companyevent_roles"), ] operations = [ diff --git a/backend/apps/parks/migrations/0024_add_timezone_default.py b/backend/apps/parks/migrations/0024_add_timezone_default.py index 9c535dfe..a2372f0e 100644 --- a/backend/apps/parks/migrations/0024_add_timezone_default.py +++ b/backend/apps/parks/migrations/0024_add_timezone_default.py @@ -11,16 +11,16 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('parks', '0023_add_company_roles_gin_index'), + ("parks", "0023_add_company_roles_gin_index"), ] operations = [ migrations.AlterField( - model_name='park', - name='timezone', + model_name="park", + name="timezone", field=models.CharField( blank=True, - default='UTC', + default="UTC", help_text="Timezone identifier for park operations (e.g., 'America/New_York')", max_length=50, ), diff --git a/backend/apps/parks/migrations/0025_alter_company_options_alter_park_options_and_more.py b/backend/apps/parks/migrations/0025_alter_company_options_alter_park_options_and_more.py index 737c0988..657f7bf9 100644 --- a/backend/apps/parks/migrations/0025_alter_company_options_alter_park_options_and_more.py +++ b/backend/apps/parks/migrations/0025_alter_company_options_alter_park_options_and_more.py @@ -61,16 +61,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="company", name="description", - field=models.TextField( - blank=True, help_text="Detailed company description" - ), + field=models.TextField(blank=True, help_text="Detailed company description"), ), migrations.AlterField( model_name="company", name="founded_year", - field=models.PositiveIntegerField( - blank=True, help_text="Year the company was founded", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="Year the company was founded", null=True), ), migrations.AlterField( model_name="company", @@ -80,16 +76,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="company", name="parks_count", - field=models.IntegerField( - default=0, help_text="Number of parks operated (auto-calculated)" - ), + field=models.IntegerField(default=0, help_text="Number of parks operated (auto-calculated)"), ), migrations.AlterField( model_name="company", name="rides_count", - field=models.IntegerField( - default=0, help_text="Number of rides manufactured (auto-calculated)" - ), + field=models.IntegerField(default=0, help_text="Number of rides manufactured (auto-calculated)"), ), migrations.AlterField( model_name="company", @@ -114,9 +106,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="company", name="slug", - field=models.SlugField( - help_text="URL-friendly identifier", max_length=255, unique=True - ), + field=models.SlugField(help_text="URL-friendly identifier", max_length=255, unique=True), ), migrations.AlterField( model_name="company", @@ -126,16 +116,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="companyevent", name="description", - field=models.TextField( - blank=True, help_text="Detailed company description" - ), + field=models.TextField(blank=True, help_text="Detailed company description"), ), migrations.AlterField( model_name="companyevent", name="founded_year", - field=models.PositiveIntegerField( - blank=True, help_text="Year the company was founded", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="Year the company was founded", null=True), ), migrations.AlterField( model_name="companyevent", @@ -145,16 +131,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="companyevent", name="parks_count", - field=models.IntegerField( - default=0, help_text="Number of parks operated (auto-calculated)" - ), + field=models.IntegerField(default=0, help_text="Number of parks operated (auto-calculated)"), ), migrations.AlterField( model_name="companyevent", name="rides_count", - field=models.IntegerField( - default=0, help_text="Number of rides manufactured (auto-calculated)" - ), + field=models.IntegerField(default=0, help_text="Number of rides manufactured (auto-calculated)"), ), migrations.AlterField( model_name="companyevent", @@ -179,9 +161,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="companyevent", name="slug", - field=models.SlugField( - db_index=False, help_text="URL-friendly identifier", max_length=255 - ), + field=models.SlugField(db_index=False, help_text="URL-friendly identifier", max_length=255), ), migrations.AlterField( model_name="companyevent", @@ -229,9 +209,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="park", name="coaster_count", - field=models.IntegerField( - blank=True, help_text="Total coaster count", null=True - ), + field=models.IntegerField(blank=True, help_text="Total coaster count", null=True), ), migrations.AlterField( model_name="park", @@ -251,16 +229,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="park", name="operating_season", - field=models.CharField( - blank=True, help_text="Operating season", max_length=255 - ), + field=models.CharField(blank=True, help_text="Operating season", max_length=255), ), migrations.AlterField( model_name="park", name="ride_count", - field=models.IntegerField( - blank=True, help_text="Total ride count", null=True - ), + field=models.IntegerField(blank=True, help_text="Total ride count", null=True), ), migrations.AlterField( model_name="park", @@ -276,9 +250,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="park", name="slug", - field=models.SlugField( - help_text="URL-friendly identifier", max_length=255, unique=True - ), + field=models.SlugField(help_text="URL-friendly identifier", max_length=255, unique=True), ), migrations.AlterField( model_name="park", @@ -300,16 +272,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="parkarea", name="closing_date", - field=models.DateField( - blank=True, help_text="Date this area closed (if applicable)", null=True - ), + field=models.DateField(blank=True, help_text="Date this area closed (if applicable)", null=True), ), migrations.AlterField( model_name="parkarea", name="description", - field=models.TextField( - blank=True, help_text="Detailed description of the area" - ), + field=models.TextField(blank=True, help_text="Detailed description of the area"), ), migrations.AlterField( model_name="parkarea", @@ -319,9 +287,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="parkarea", name="opening_date", - field=models.DateField( - blank=True, help_text="Date this area opened", null=True - ), + field=models.DateField(blank=True, help_text="Date this area opened", null=True), ), migrations.AlterField( model_name="parkarea", @@ -336,23 +302,17 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="parkarea", name="slug", - field=models.SlugField( - help_text="URL-friendly identifier (unique within park)", max_length=255 - ), + field=models.SlugField(help_text="URL-friendly identifier (unique within park)", max_length=255), ), migrations.AlterField( model_name="parkareaevent", name="closing_date", - field=models.DateField( - blank=True, help_text="Date this area closed (if applicable)", null=True - ), + field=models.DateField(blank=True, help_text="Date this area closed (if applicable)", null=True), ), migrations.AlterField( model_name="parkareaevent", name="description", - field=models.TextField( - blank=True, help_text="Detailed description of the area" - ), + field=models.TextField(blank=True, help_text="Detailed description of the area"), ), migrations.AlterField( model_name="parkareaevent", @@ -362,9 +322,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="parkareaevent", name="opening_date", - field=models.DateField( - blank=True, help_text="Date this area opened", null=True - ), + field=models.DateField(blank=True, help_text="Date this area opened", null=True), ), migrations.AlterField( model_name="parkareaevent", @@ -406,9 +364,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="parkevent", name="coaster_count", - field=models.IntegerField( - blank=True, help_text="Total coaster count", null=True - ), + field=models.IntegerField(blank=True, help_text="Total coaster count", null=True), ), migrations.AlterField( model_name="parkevent", @@ -428,16 +384,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="parkevent", name="operating_season", - field=models.CharField( - blank=True, help_text="Operating season", max_length=255 - ), + field=models.CharField(blank=True, help_text="Operating season", max_length=255), ), migrations.AlterField( model_name="parkevent", name="ride_count", - field=models.IntegerField( - blank=True, help_text="Total ride count", null=True - ), + field=models.IntegerField(blank=True, help_text="Total ride count", null=True), ), migrations.AlterField( model_name="parkevent", @@ -453,9 +405,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="parkevent", name="slug", - field=models.SlugField( - db_index=False, help_text="URL-friendly identifier", max_length=255 - ), + field=models.SlugField(db_index=False, help_text="URL-friendly identifier", max_length=255), ), migrations.AlterField( model_name="parkevent", @@ -496,9 +446,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="parkphoto", name="caption", - field=models.CharField( - blank=True, help_text="Photo caption or description", max_length=255 - ), + field=models.CharField(blank=True, help_text="Photo caption or description", max_length=255), ), migrations.AlterField( model_name="parkphoto", @@ -549,9 +497,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="parkphotoevent", name="caption", - field=models.CharField( - blank=True, help_text="Photo caption or description", max_length=255 - ), + field=models.CharField(blank=True, help_text="Photo caption or description", max_length=255), ), migrations.AlterField( model_name="parkphotoevent", @@ -602,16 +548,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="parkreview", name="is_published", - field=models.BooleanField( - default=True, help_text="Whether this review is publicly visible" - ), + field=models.BooleanField(default=True, help_text="Whether this review is publicly visible"), ), migrations.AlterField( model_name="parkreview", name="moderated_at", - field=models.DateTimeField( - blank=True, help_text="When this review was moderated", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this review was moderated", null=True), ), migrations.AlterField( model_name="parkreview", @@ -628,9 +570,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="parkreview", name="moderation_notes", - field=models.TextField( - blank=True, help_text="Internal notes from moderators" - ), + field=models.TextField(blank=True, help_text="Internal notes from moderators"), ), migrations.AlterField( model_name="parkreview", @@ -681,16 +621,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="parkreviewevent", name="is_published", - field=models.BooleanField( - default=True, help_text="Whether this review is publicly visible" - ), + field=models.BooleanField(default=True, help_text="Whether this review is publicly visible"), ), migrations.AlterField( model_name="parkreviewevent", name="moderated_at", - field=models.DateTimeField( - blank=True, help_text="When this review was moderated", null=True - ), + field=models.DateTimeField(blank=True, help_text="When this review was moderated", null=True), ), migrations.AlterField( model_name="parkreviewevent", @@ -709,9 +645,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="parkreviewevent", name="moderation_notes", - field=models.TextField( - blank=True, help_text="Internal notes from moderators" - ), + field=models.TextField(blank=True, help_text="Internal notes from moderators"), ), migrations.AlterField( model_name="parkreviewevent", diff --git a/backend/apps/parks/models/__init__.py b/backend/apps/parks/models/__init__.py index a9410647..4567d53c 100644 --- a/backend/apps/parks/models/__init__.py +++ b/backend/apps/parks/models/__init__.py @@ -9,7 +9,7 @@ while maintaining backward compatibility through the Company alias. """ # Import choices to trigger registration -from ..choices import * +from ..choices import * # noqa: F403 from .areas import ParkArea from .companies import Company, CompanyHeadquarters from .location import ParkLocation diff --git a/backend/apps/parks/models/areas.py b/backend/apps/parks/models/areas.py index 57df59ab..ab635071 100644 --- a/backend/apps/parks/models/areas.py +++ b/backend/apps/parks/models/areas.py @@ -21,16 +21,10 @@ class ParkArea(TrackedModel): help_text="Park this area belongs to", ) name = models.CharField(max_length=255, help_text="Name of the park area") - slug = models.SlugField( - max_length=255, help_text="URL-friendly identifier (unique within park)" - ) + slug = models.SlugField(max_length=255, help_text="URL-friendly identifier (unique within park)") description = models.TextField(blank=True, help_text="Detailed description of the area") - opening_date = models.DateField( - null=True, blank=True, help_text="Date this area opened" - ) - closing_date = models.DateField( - null=True, blank=True, help_text="Date this area closed (if applicable)" - ) + opening_date = models.DateField(null=True, blank=True, help_text="Date this area opened") + closing_date = models.DateField(null=True, blank=True, help_text="Date this area closed (if applicable)") def save(self, *args, **kwargs): if not self.slug: diff --git a/backend/apps/parks/models/companies.py b/backend/apps/parks/models/companies.py index bc981a51..53616f42 100644 --- a/backend/apps/parks/models/companies.py +++ b/backend/apps/parks/models/companies.py @@ -26,15 +26,9 @@ class Company(TrackedModel): website = models.URLField(blank=True, help_text="Company website URL") # Operator-specific fields - founded_year = models.PositiveIntegerField( - blank=True, null=True, help_text="Year the company was founded" - ) - parks_count = models.IntegerField( - default=0, help_text="Number of parks operated (auto-calculated)" - ) - rides_count = models.IntegerField( - default=0, help_text="Number of rides manufactured (auto-calculated)" - ) + founded_year = models.PositiveIntegerField(blank=True, null=True, help_text="Year the company was founded") + parks_count = models.IntegerField(default=0, help_text="Number of parks operated (auto-calculated)") + rides_count = models.IntegerField(default=0, help_text="Number of rides manufactured (auto-calculated)") def save(self, *args, **kwargs): if not self.slug: @@ -72,9 +66,7 @@ class CompanyHeadquarters(models.Model): blank=True, help_text="Mailing address if publicly available", ) - city = models.CharField( - max_length=100, db_index=True, help_text="Headquarters city" - ) + city = models.CharField(max_length=100, db_index=True, help_text="Headquarters city") state_province = models.CharField( max_length=100, blank=True, @@ -87,9 +79,7 @@ class CompanyHeadquarters(models.Model): db_index=True, help_text="Country where headquarters is located", ) - postal_code = models.CharField( - max_length=20, blank=True, help_text="ZIP or postal code" - ) + postal_code = models.CharField(max_length=20, blank=True, help_text="ZIP or postal code") # Optional mailing address if different or more complete mailing_address = models.TextField( diff --git a/backend/apps/parks/models/location.py b/backend/apps/parks/models/location.py index bae65b10..c684ca6c 100644 --- a/backend/apps/parks/models/location.py +++ b/backend/apps/parks/models/location.py @@ -9,9 +9,7 @@ class ParkLocation(models.Model): Represents the geographic location and address of a park, with PostGIS support. """ - park = models.OneToOneField( - "parks.Park", on_delete=models.CASCADE, related_name="location" - ) + park = models.OneToOneField("parks.Park", on_delete=models.CASCADE, related_name="location") # Spatial Data point = models.PointField( @@ -27,10 +25,7 @@ class ParkLocation(models.Model): state = models.CharField(max_length=100, db_index=True) country = models.CharField(max_length=100, default="USA") continent = models.CharField( - max_length=50, - blank=True, - db_index=True, - help_text="Continent where the park is located" + max_length=50, blank=True, db_index=True, help_text="Continent where the park is located" ) postal_code = models.CharField(max_length=20, blank=True) diff --git a/backend/apps/parks/models/media.py b/backend/apps/parks/models/media.py index 48ae2bfd..562985fe 100644 --- a/backend/apps/parks/models/media.py +++ b/backend/apps/parks/models/media.py @@ -22,9 +22,7 @@ def park_photo_upload_path(instance: models.Model, filename: str) -> str: if park is None: raise ValueError("Park cannot be None") - return MediaService.generate_upload_path( - domain="park", identifier=park.slug, filename=filename - ) + return MediaService.generate_upload_path(domain="park", identifier=park.slug, filename=filename) @pghistory.track() @@ -39,23 +37,15 @@ class ParkPhoto(TrackedModel): ) image = models.ForeignKey( - 'django_cloudflareimages_toolkit.CloudflareImage', + "django_cloudflareimages_toolkit.CloudflareImage", on_delete=models.CASCADE, - help_text="Park photo stored on Cloudflare Images" + help_text="Park photo stored on Cloudflare Images", ) - caption = models.CharField( - max_length=255, blank=True, help_text="Photo caption or description" - ) - alt_text = models.CharField( - max_length=255, blank=True, help_text="Alternative text for accessibility" - ) - is_primary = models.BooleanField( - default=False, help_text="Whether this is the primary photo for the park" - ) - is_approved = models.BooleanField( - default=False, help_text="Whether this photo has been approved by moderators" - ) + caption = models.CharField(max_length=255, blank=True, help_text="Photo caption or description") + alt_text = models.CharField(max_length=255, blank=True, help_text="Alternative text for accessibility") + is_primary = models.BooleanField(default=False, help_text="Whether this is the primary photo for the park") + is_approved = models.BooleanField(default=False, help_text="Whether this photo has been approved by moderators") # Metadata created_at = models.DateTimeField(auto_now_add=True) @@ -100,9 +90,7 @@ class ParkPhoto(TrackedModel): # Set default caption if not provided if not self.caption and self.uploaded_by: - self.caption = MediaService.generate_default_caption( - self.uploaded_by.username - ) + self.caption = MediaService.generate_default_caption(self.uploaded_by.username) # If this is marked as primary, unmark other primary photos for this park if self.is_primary: diff --git a/backend/apps/parks/models/parks.py b/backend/apps/parks/models/parks.py index 1cc9b199..d154dba5 100644 --- a/backend/apps/parks/models/parks.py +++ b/backend/apps/parks/models/parks.py @@ -45,7 +45,7 @@ class Park(StateMachineMixin, TrackedModel): max_length=30, default="THEME_PARK", db_index=True, - help_text="Type/category of the park" + help_text="Type/category of the park", ) # Location relationship - reverse relation from ParkLocation @@ -118,23 +118,18 @@ class Park(StateMachineMixin, TrackedModel): # Computed fields for hybrid filtering opening_year = models.IntegerField( - null=True, - blank=True, - db_index=True, - help_text="Year the park opened (computed from opening_date)" + null=True, blank=True, db_index=True, help_text="Year the park opened (computed from opening_date)" ) search_text = models.TextField( - blank=True, - db_index=True, - help_text="Searchable text combining name, description, location, and operator" + blank=True, db_index=True, help_text="Searchable text combining name, description, location, and operator" ) # Timezone for park operations timezone = models.CharField( max_length=50, - default='UTC', + default="UTC", blank=True, - help_text="Timezone identifier for park operations (e.g., 'America/New_York')" + help_text="Timezone identifier for park operations (e.g., 'America/New_York')", ) class Meta: @@ -171,8 +166,7 @@ class Park(StateMachineMixin, TrackedModel): ), models.CheckConstraint( name="park_coaster_count_non_negative", - check=models.Q(coaster_count__isnull=True) - | models.Q(coaster_count__gte=0), + check=models.Q(coaster_count__isnull=True) | models.Q(coaster_count__gte=0), violation_error_message="Coaster count must be non-negative", ), # Business rule: Coaster count cannot exceed ride count @@ -204,9 +198,7 @@ class Park(StateMachineMixin, TrackedModel): self.transition_to_under_construction(user=user) self.save() - def close_permanently( - self, *, closing_date=None, user: Optional["AbstractBaseUser"] = None - ) -> None: + def close_permanently(self, *, closing_date=None, user: Optional["AbstractBaseUser"] = None) -> None: """Transition park to CLOSED_PERM status.""" self.transition_to_closed_perm(user=user) if closing_date: @@ -279,7 +271,7 @@ class Park(StateMachineMixin, TrackedModel): # Add location information if available try: - if hasattr(self, 'location') and self.location: + if hasattr(self, "location") and self.location: if self.location.city: search_parts.append(self.location.city) if self.location.state: @@ -299,16 +291,14 @@ class Park(StateMachineMixin, TrackedModel): search_parts.append(self.property_owner.name) # Combine all parts into searchable text - self.search_text = ' '.join(filter(None, search_parts)).lower() + self.search_text = " ".join(filter(None, search_parts)).lower() def clean(self): super().clean() if self.operator and "OPERATOR" not in self.operator.roles: raise ValidationError({"operator": "Company must have the OPERATOR role."}) if self.property_owner and "PROPERTY_OWNER" not in self.property_owner.roles: - raise ValidationError( - {"property_owner": "Company must have the PROPERTY_OWNER role."} - ) + raise ValidationError({"property_owner": "Company must have the PROPERTY_OWNER role."}) def get_absolute_url(self) -> str: return reverse("parks:park_detail", kwargs={"slug": self.slug}) @@ -325,7 +315,7 @@ class Park(StateMachineMixin, TrackedModel): """Returns coordinates as a list [latitude, longitude]""" if hasattr(self, "location") and self.location: coords = self.location.coordinates - if coords and isinstance(coords, (tuple, list)): + if coords and isinstance(coords, tuple | list): return list(coords) return None @@ -349,9 +339,7 @@ class Park(StateMachineMixin, TrackedModel): content_type = ContentType.objects.get_for_model(cls) print(f"Searching HistoricalSlug with content_type: {content_type}") historical = ( - HistoricalSlug.objects.filter(content_type=content_type, slug=slug) - .order_by("-created_at") - .first() + HistoricalSlug.objects.filter(content_type=content_type, slug=slug).order_by("-created_at").first() ) if historical: @@ -373,11 +361,7 @@ class Park(StateMachineMixin, TrackedModel): print("Searching pghistory events") event_model = getattr(cls, "event_model", None) if event_model: - historical_event = ( - event_model.objects.filter(slug=slug) - .order_by("-pgh_created_at") - .first() - ) + historical_event = event_model.objects.filter(slug=slug).order_by("-pgh_created_at").first() if historical_event: print( @@ -394,4 +378,4 @@ class Park(StateMachineMixin, TrackedModel): else: print("No pghistory event found") - raise cls.DoesNotExist("No park found with this slug") + raise cls.DoesNotExist("No park found with this slug") from None diff --git a/backend/apps/parks/models/reviews.py b/backend/apps/parks/models/reviews.py index 6450c7e7..50bd7dff 100644 --- a/backend/apps/parks/models/reviews.py +++ b/backend/apps/parks/models/reviews.py @@ -40,12 +40,8 @@ class ParkReview(TrackedModel): updated_at = models.DateTimeField(auto_now=True) # Moderation - is_published = models.BooleanField( - default=True, help_text="Whether this review is publicly visible" - ) - moderation_notes = models.TextField( - blank=True, help_text="Internal notes from moderators" - ) + is_published = models.BooleanField(default=True, help_text="Whether this review is publicly visible") + moderation_notes = models.TextField(blank=True, help_text="Internal notes from moderators") moderated_by = models.ForeignKey( "accounts.User", on_delete=models.SET_NULL, @@ -54,9 +50,7 @@ class ParkReview(TrackedModel): related_name="moderated_park_reviews", help_text="Moderator who reviewed this", ) - moderated_at = models.DateTimeField( - null=True, blank=True, help_text="When this review was moderated" - ) + moderated_at = models.DateTimeField(null=True, blank=True, help_text="When this review was moderated") class Meta(TrackedModel.Meta): verbose_name = "Park Review" @@ -82,10 +76,7 @@ class ParkReview(TrackedModel): name="park_review_moderation_consistency", check=models.Q(moderated_by__isnull=True, moderated_at__isnull=True) | models.Q(moderated_by__isnull=False, moderated_at__isnull=False), - violation_error_message=( - "Moderated reviews must have both moderator and moderation " - "timestamp" - ), + violation_error_message=("Moderated reviews must have both moderator and moderation " "timestamp"), ), ] diff --git a/backend/apps/parks/querysets.py b/backend/apps/parks/querysets.py index d5347c31..2dde7e71 100644 --- a/backend/apps/parks/querysets.py +++ b/backend/apps/parks/querysets.py @@ -10,9 +10,7 @@ def get_base_park_queryset() -> QuerySet[Park]: .prefetch_related("photos", "rides") .annotate( current_ride_count=Count("rides", distinct=True), - current_coaster_count=Count( - "rides", filter=Q(rides__category="RC"), distinct=True - ), + current_coaster_count=Count("rides", filter=Q(rides__category="RC"), distinct=True), ) .order_by("name") ) diff --git a/backend/apps/parks/selectors.py b/backend/apps/parks/selectors.py index 1f95a78a..cd2c0941 100644 --- a/backend/apps/parks/selectors.py +++ b/backend/apps/parks/selectors.py @@ -47,9 +47,7 @@ def park_list_with_stats(*, filters: dict[str, Any] | None = None) -> QuerySet[P queryset = queryset.filter(location__country=filters["country"]) if "search" in filters: search_term = filters["search"] - queryset = queryset.filter( - Q(name__icontains=search_term) | Q(description__icontains=search_term) - ) + queryset = queryset.filter(Q(name__icontains=search_term) | Q(description__icontains=search_term)) return queryset.order_by("name") @@ -74,15 +72,11 @@ def park_detail_optimized(*, slug: str) -> Park: "areas", Prefetch( "rides", - queryset=Ride.objects.select_related( - "manufacturer", "designer", "ride_model" - ), + queryset=Ride.objects.select_related("manufacturer", "designer", "ride_model"), ), Prefetch( "reviews", - queryset=ParkReview.objects.select_related("user").filter( - is_published=True - ), + queryset=ParkReview.objects.select_related("user").filter(is_published=True), ), "photos", ) @@ -90,9 +84,7 @@ def park_detail_optimized(*, slug: str) -> Park: ) -def parks_near_location( - *, point: Point, distance_km: float = 50, limit: int = 10 -) -> QuerySet[Park]: +def parks_near_location(*, point: Point, distance_km: float = 50, limit: int = 10) -> QuerySet[Park]: """ Get parks near a specific geographic location. @@ -176,16 +168,10 @@ def parks_with_recent_reviews(*, days: int = 30) -> QuerySet[Park]: cutoff_date = timezone.now() - timedelta(days=days) return ( - Park.objects.filter( - reviews__created_at__gte=cutoff_date, reviews__is_published=True - ) + Park.objects.filter(reviews__created_at__gte=cutoff_date, reviews__is_published=True) .select_related("operator") .prefetch_related("location") - .annotate( - recent_review_count=Count( - "reviews", filter=Q(reviews__created_at__gte=cutoff_date) - ) - ) + .annotate(recent_review_count=Count("reviews", filter=Q(reviews__created_at__gte=cutoff_date))) .order_by("-recent_review_count") .distinct() ) @@ -204,9 +190,7 @@ def park_search_autocomplete(*, query: str, limit: int = 10) -> QuerySet[Park]: """ return ( Park.objects.filter( - Q(name__icontains=query) - | Q(location__city__icontains=query) - | Q(location__region__icontains=query) + Q(name__icontains=query) | Q(location__city__icontains=query) | Q(location__region__icontains=query) ) .select_related("operator") .prefetch_related("location") diff --git a/backend/apps/parks/services.py b/backend/apps/parks/services.py index 26156824..47d67c51 100644 --- a/backend/apps/parks/services.py +++ b/backend/apps/parks/services.py @@ -212,9 +212,9 @@ class ParkService: ) # Calculate average rating - avg_rating = ParkReview.objects.filter( - park=park, is_published=True - ).aggregate(avg_rating=Avg("rating"))["avg_rating"] + avg_rating = ParkReview.objects.filter(park=park, is_published=True).aggregate(avg_rating=Avg("rating"))[ + "avg_rating" + ] # Update park fields park.ride_count = ride_stats["total_rides"] or 0 diff --git a/backend/apps/parks/services/filter_service.py b/backend/apps/parks/services/filter_service.py index 2e6f6dfa..7f9922a0 100644 --- a/backend/apps/parks/services/filter_service.py +++ b/backend/apps/parks/services/filter_service.py @@ -26,9 +26,7 @@ class ParkFilterService: def __init__(self): self.cache_prefix = "park_filter" - def get_filter_counts( - self, base_queryset: QuerySet | None = None - ) -> dict[str, Any]: + def get_filter_counts(self, base_queryset: QuerySet | None = None) -> dict[str, Any]: """ Get counts for various filter options to show users what's available. @@ -76,9 +74,7 @@ class ParkFilterService: ).count(), } - def _get_top_operators( - self, queryset: QuerySet, limit: int = 10 - ) -> list[dict[str, Any]]: + def _get_top_operators(self, queryset: QuerySet, limit: int = 10) -> list[dict[str, Any]]: """Get the top operators by number of parks.""" return list( queryset.values("operator__name", "operator__id") @@ -87,9 +83,7 @@ class ParkFilterService: .order_by("-park_count")[:limit] ) - def _get_country_counts( - self, queryset: QuerySet, limit: int = 10 - ) -> list[dict[str, Any]]: + def _get_country_counts(self, queryset: QuerySet, limit: int = 10) -> list[dict[str, Any]]: """Get countries with the most parks.""" return list( queryset.filter(location__country__isnull=False) @@ -123,21 +117,18 @@ class ParkFilterService: if len(query) >= 2: # Only search for queries of 2+ characters # Park name suggestions - park_names = Park.objects.filter(name__icontains=query).values_list( - "name", flat=True - )[:5] + park_names = Park.objects.filter(name__icontains=query).values_list("name", flat=True)[:5] suggestions["parks"] = list(park_names) # Operator suggestions - operator_names = Company.objects.filter( - roles__contains=["OPERATOR"], name__icontains=query - ).values_list("name", flat=True)[:5] + operator_names = Company.objects.filter(roles__contains=["OPERATOR"], name__icontains=query).values_list( + "name", flat=True + )[:5] suggestions["operators"] = list(operator_names) # Location suggestions (cities and countries) locations = Park.objects.filter( - Q(location__city__icontains=query) - | Q(location__country__icontains=query) + Q(location__city__icontains=query) | Q(location__country__icontains=query) ).values_list("location__city", "location__country")[:5] location_suggestions = [] @@ -264,14 +255,10 @@ class ParkFilterService: # Apply location filters if filters.get("country_filter"): - queryset = queryset.filter( - location__country__icontains=filters["country_filter"] - ) + queryset = queryset.filter(location__country__icontains=filters["country_filter"]) if filters.get("state_filter"): - queryset = queryset.filter( - location__state__icontains=filters["state_filter"] - ) + queryset = queryset.filter(location__state__icontains=filters["state_filter"]) # Apply ordering if filters.get("ordering"): diff --git a/backend/apps/parks/services/hybrid_loader.py b/backend/apps/parks/services/hybrid_loader.py index 10dc5a37..32ebe406 100644 --- a/backend/apps/parks/services/hybrid_loader.py +++ b/backend/apps/parks/services/hybrid_loader.py @@ -21,8 +21,8 @@ class SmartParkLoader: """ # Cache configuration - CACHE_TIMEOUT = getattr(settings, 'HYBRID_FILTER_CACHE_TIMEOUT', 300) # 5 minutes - CACHE_KEY_PREFIX = 'hybrid_parks' + CACHE_TIMEOUT = getattr(settings, "HYBRID_FILTER_CACHE_TIMEOUT", 300) # 5 minutes + CACHE_KEY_PREFIX = "hybrid_parks" # Progressive loading thresholds INITIAL_LOAD_SIZE = 50 @@ -34,17 +34,22 @@ class SmartParkLoader: def _get_optimized_queryset(self) -> models.QuerySet: """Get optimized base queryset with all necessary prefetches.""" - return Park.objects.select_related( - 'operator', - 'property_owner', - 'banner_image', - 'card_image', - ).prefetch_related( - 'location', # ParkLocation relationship - ).filter( - # Only include operating and temporarily closed parks by default - status__in=['OPERATING', 'CLOSED_TEMP'] - ).order_by('name') + return ( + Park.objects.select_related( + "operator", + "property_owner", + "banner_image", + "card_image", + ) + .prefetch_related( + "location", # ParkLocation relationship + ) + .filter( + # Only include operating and temporarily closed parks by default + status__in=["OPERATING", "CLOSED_TEMP"] + ) + .order_by("name") + ) def get_initial_load(self, filters: dict[str, Any] | None = None) -> dict[str, Any]: """ @@ -56,7 +61,7 @@ class SmartParkLoader: Returns: Dictionary containing parks data and metadata """ - cache_key = self._generate_cache_key('initial', filters) + cache_key = self._generate_cache_key("initial", filters) cached_result = cache.get(cache_key) if cached_result: @@ -74,21 +79,21 @@ class SmartParkLoader: if total_count <= self.MAX_CLIENT_SIDE_RECORDS: # Load all data for client-side filtering parks = list(queryset.all()) - strategy = 'client_side' + strategy = "client_side" has_more = False else: # Load initial batch for server-side pagination - parks = list(queryset[:self.INITIAL_LOAD_SIZE]) - strategy = 'server_side' + parks = list(queryset[: self.INITIAL_LOAD_SIZE]) + strategy = "server_side" has_more = total_count > self.INITIAL_LOAD_SIZE result = { - 'parks': parks, - 'total_count': total_count, - 'strategy': strategy, - 'has_more': has_more, - 'next_offset': len(parks) if has_more else None, - 'filter_metadata': self._get_filter_metadata(queryset), + "parks": parks, + "total_count": total_count, + "strategy": strategy, + "has_more": has_more, + "next_offset": len(parks) if has_more else None, + "filter_metadata": self._get_filter_metadata(queryset), } # Cache the result @@ -96,11 +101,7 @@ class SmartParkLoader: return result - def get_progressive_load( - self, - offset: int, - filters: dict[str, Any] | None = None - ) -> dict[str, Any]: + def get_progressive_load(self, offset: int, filters: dict[str, Any] | None = None) -> dict[str, Any]: """ Get next batch of parks for progressive loading. @@ -111,7 +112,7 @@ class SmartParkLoader: Returns: Dictionary containing parks data and metadata """ - cache_key = self._generate_cache_key(f'progressive_{offset}', filters) + cache_key = self._generate_cache_key(f"progressive_{offset}", filters) cached_result = cache.get(cache_key) if cached_result: @@ -131,10 +132,10 @@ class SmartParkLoader: has_more = end_offset < total_count result = { - 'parks': parks, - 'total_count': total_count, - 'has_more': has_more, - 'next_offset': end_offset if has_more else None, + "parks": parks, + "total_count": total_count, + "has_more": has_more, + "next_offset": end_offset if has_more else None, } # Cache the result @@ -152,7 +153,7 @@ class SmartParkLoader: Returns: Dictionary containing filter metadata """ - cache_key = self._generate_cache_key('metadata', filters) + cache_key = self._generate_cache_key("metadata", filters) cached_result = cache.get(cache_key) if cached_result: @@ -174,72 +175,72 @@ class SmartParkLoader: """Apply filters to the queryset.""" # Status filter - if 'status' in filters and filters['status']: - if isinstance(filters['status'], list): - queryset = queryset.filter(status__in=filters['status']) + if "status" in filters and filters["status"]: + if isinstance(filters["status"], list): + queryset = queryset.filter(status__in=filters["status"]) else: - queryset = queryset.filter(status=filters['status']) + queryset = queryset.filter(status=filters["status"]) # Park type filter - if 'park_type' in filters and filters['park_type']: - if isinstance(filters['park_type'], list): - queryset = queryset.filter(park_type__in=filters['park_type']) + if "park_type" in filters and filters["park_type"]: + if isinstance(filters["park_type"], list): + queryset = queryset.filter(park_type__in=filters["park_type"]) else: - queryset = queryset.filter(park_type=filters['park_type']) + queryset = queryset.filter(park_type=filters["park_type"]) # Country filter - if 'country' in filters and filters['country']: - queryset = queryset.filter(location__country__in=filters['country']) + if "country" in filters and filters["country"]: + queryset = queryset.filter(location__country__in=filters["country"]) # State filter - if 'state' in filters and filters['state']: - queryset = queryset.filter(location__state__in=filters['state']) + if "state" in filters and filters["state"]: + queryset = queryset.filter(location__state__in=filters["state"]) # Opening year range - if 'opening_year_min' in filters and filters['opening_year_min']: - queryset = queryset.filter(opening_year__gte=filters['opening_year_min']) + if "opening_year_min" in filters and filters["opening_year_min"]: + queryset = queryset.filter(opening_year__gte=filters["opening_year_min"]) - if 'opening_year_max' in filters and filters['opening_year_max']: - queryset = queryset.filter(opening_year__lte=filters['opening_year_max']) + if "opening_year_max" in filters and filters["opening_year_max"]: + queryset = queryset.filter(opening_year__lte=filters["opening_year_max"]) # Size range - if 'size_min' in filters and filters['size_min']: - queryset = queryset.filter(size_acres__gte=filters['size_min']) + if "size_min" in filters and filters["size_min"]: + queryset = queryset.filter(size_acres__gte=filters["size_min"]) - if 'size_max' in filters and filters['size_max']: - queryset = queryset.filter(size_acres__lte=filters['size_max']) + if "size_max" in filters and filters["size_max"]: + queryset = queryset.filter(size_acres__lte=filters["size_max"]) # Rating range - if 'rating_min' in filters and filters['rating_min']: - queryset = queryset.filter(average_rating__gte=filters['rating_min']) + if "rating_min" in filters and filters["rating_min"]: + queryset = queryset.filter(average_rating__gte=filters["rating_min"]) - if 'rating_max' in filters and filters['rating_max']: - queryset = queryset.filter(average_rating__lte=filters['rating_max']) + if "rating_max" in filters and filters["rating_max"]: + queryset = queryset.filter(average_rating__lte=filters["rating_max"]) # Ride count range - if 'ride_count_min' in filters and filters['ride_count_min']: - queryset = queryset.filter(ride_count__gte=filters['ride_count_min']) + if "ride_count_min" in filters and filters["ride_count_min"]: + queryset = queryset.filter(ride_count__gte=filters["ride_count_min"]) - if 'ride_count_max' in filters and filters['ride_count_max']: - queryset = queryset.filter(ride_count__lte=filters['ride_count_max']) + if "ride_count_max" in filters and filters["ride_count_max"]: + queryset = queryset.filter(ride_count__lte=filters["ride_count_max"]) # Coaster count range - if 'coaster_count_min' in filters and filters['coaster_count_min']: - queryset = queryset.filter(coaster_count__gte=filters['coaster_count_min']) + if "coaster_count_min" in filters and filters["coaster_count_min"]: + queryset = queryset.filter(coaster_count__gte=filters["coaster_count_min"]) - if 'coaster_count_max' in filters and filters['coaster_count_max']: - queryset = queryset.filter(coaster_count__lte=filters['coaster_count_max']) + if "coaster_count_max" in filters and filters["coaster_count_max"]: + queryset = queryset.filter(coaster_count__lte=filters["coaster_count_max"]) # Operator filter - if 'operator' in filters and filters['operator']: - if isinstance(filters['operator'], list): - queryset = queryset.filter(operator__slug__in=filters['operator']) + if "operator" in filters and filters["operator"]: + if isinstance(filters["operator"], list): + queryset = queryset.filter(operator__slug__in=filters["operator"]) else: - queryset = queryset.filter(operator__slug=filters['operator']) + queryset = queryset.filter(operator__slug=filters["operator"]) # Search query - if 'search' in filters and filters['search']: - search_term = filters['search'].lower() + if "search" in filters and filters["search"]: + search_term = filters["search"].lower() queryset = queryset.filter(search_text__icontains=search_term) return queryset @@ -249,150 +250,125 @@ class SmartParkLoader: # Get distinct values for categorical filters with counts countries_data = list( - queryset.values('location__country') + queryset.values("location__country") .exclude(location__country__isnull=True) - .annotate(count=models.Count('id')) - .order_by('location__country') + .annotate(count=models.Count("id")) + .order_by("location__country") ) states_data = list( - queryset.values('location__state') + queryset.values("location__state") .exclude(location__state__isnull=True) - .annotate(count=models.Count('id')) - .order_by('location__state') + .annotate(count=models.Count("id")) + .order_by("location__state") ) park_types_data = list( - queryset.values('park_type') + queryset.values("park_type") .exclude(park_type__isnull=True) - .annotate(count=models.Count('id')) - .order_by('park_type') + .annotate(count=models.Count("id")) + .order_by("park_type") ) - statuses_data = list( - queryset.values('status') - .annotate(count=models.Count('id')) - .order_by('status') - ) + statuses_data = list(queryset.values("status").annotate(count=models.Count("id")).order_by("status")) operators_data = list( - queryset.select_related('operator') - .values('operator__id', 'operator__name', 'operator__slug') + queryset.select_related("operator") + .values("operator__id", "operator__name", "operator__slug") .exclude(operator__isnull=True) - .annotate(count=models.Count('id')) - .order_by('operator__name') + .annotate(count=models.Count("id")) + .order_by("operator__name") ) # Convert to frontend-expected format with value/label/count countries = [ - { - 'value': item['location__country'], - 'label': item['location__country'], - 'count': item['count'] - } + {"value": item["location__country"], "label": item["location__country"], "count": item["count"]} for item in countries_data ] states = [ - { - 'value': item['location__state'], - 'label': item['location__state'], - 'count': item['count'] - } + {"value": item["location__state"], "label": item["location__state"], "count": item["count"]} for item in states_data ] park_types = [ - { - 'value': item['park_type'], - 'label': item['park_type'], - 'count': item['count'] - } - for item in park_types_data + {"value": item["park_type"], "label": item["park_type"], "count": item["count"]} for item in park_types_data ] statuses = [ - { - 'value': item['status'], - 'label': self._get_status_label(item['status']), - 'count': item['count'] - } + {"value": item["status"], "label": self._get_status_label(item["status"]), "count": item["count"]} for item in statuses_data ] operators = [ - { - 'value': item['operator__slug'], - 'label': item['operator__name'], - 'count': item['count'] - } + {"value": item["operator__slug"], "label": item["operator__name"], "count": item["count"]} for item in operators_data ] # Get ranges for numerical filters aggregates = queryset.aggregate( - opening_year_min=models.Min('opening_year'), - opening_year_max=models.Max('opening_year'), - size_min=models.Min('size_acres'), - size_max=models.Max('size_acres'), - rating_min=models.Min('average_rating'), - rating_max=models.Max('average_rating'), - ride_count_min=models.Min('ride_count'), - ride_count_max=models.Max('ride_count'), - coaster_count_min=models.Min('coaster_count'), - coaster_count_max=models.Max('coaster_count'), + opening_year_min=models.Min("opening_year"), + opening_year_max=models.Max("opening_year"), + size_min=models.Min("size_acres"), + size_max=models.Max("size_acres"), + rating_min=models.Min("average_rating"), + rating_max=models.Max("average_rating"), + ride_count_min=models.Min("ride_count"), + ride_count_max=models.Max("ride_count"), + coaster_count_min=models.Min("coaster_count"), + coaster_count_max=models.Max("coaster_count"), ) return { - 'categorical': { - 'countries': countries, - 'states': states, - 'park_types': park_types, - 'statuses': statuses, - 'operators': operators, + "categorical": { + "countries": countries, + "states": states, + "park_types": park_types, + "statuses": statuses, + "operators": operators, }, - 'ranges': { - 'opening_year': { - 'min': aggregates['opening_year_min'], - 'max': aggregates['opening_year_max'], - 'step': 1, - 'unit': 'year' + "ranges": { + "opening_year": { + "min": aggregates["opening_year_min"], + "max": aggregates["opening_year_max"], + "step": 1, + "unit": "year", }, - 'size_acres': { - 'min': float(aggregates['size_min']) if aggregates['size_min'] else None, - 'max': float(aggregates['size_max']) if aggregates['size_max'] else None, - 'step': 1.0, - 'unit': 'acres' + "size_acres": { + "min": float(aggregates["size_min"]) if aggregates["size_min"] else None, + "max": float(aggregates["size_max"]) if aggregates["size_max"] else None, + "step": 1.0, + "unit": "acres", }, - 'average_rating': { - 'min': float(aggregates['rating_min']) if aggregates['rating_min'] else None, - 'max': float(aggregates['rating_max']) if aggregates['rating_max'] else None, - 'step': 0.1, - 'unit': 'stars' + "average_rating": { + "min": float(aggregates["rating_min"]) if aggregates["rating_min"] else None, + "max": float(aggregates["rating_max"]) if aggregates["rating_max"] else None, + "step": 0.1, + "unit": "stars", }, - 'ride_count': { - 'min': aggregates['ride_count_min'], - 'max': aggregates['ride_count_max'], - 'step': 1, - 'unit': 'rides' + "ride_count": { + "min": aggregates["ride_count_min"], + "max": aggregates["ride_count_max"], + "step": 1, + "unit": "rides", }, - 'coaster_count': { - 'min': aggregates['coaster_count_min'], - 'max': aggregates['coaster_count_max'], - 'step': 1, - 'unit': 'coasters' + "coaster_count": { + "min": aggregates["coaster_count_min"], + "max": aggregates["coaster_count_max"], + "step": 1, + "unit": "coasters", }, }, - 'total_count': queryset.count(), + "total_count": queryset.count(), } def _get_status_label(self, status: str) -> str: """Convert status code to human-readable label.""" status_labels = { - 'OPERATING': 'Operating', - 'CLOSED_TEMP': 'Temporarily Closed', - 'CLOSED_PERM': 'Permanently Closed', - 'UNDER_CONSTRUCTION': 'Under Construction', + "OPERATING": "Operating", + "CLOSED_TEMP": "Temporarily Closed", + "CLOSED_PERM": "Permanently Closed", + "UNDER_CONSTRUCTION": "Under Construction", } if status in status_labels: return status_labels[status] @@ -405,23 +381,23 @@ class SmartParkLoader: if filters: # Create a consistent string representation of filters - filter_str = '_'.join(f"{k}:{v}" for k, v in sorted(filters.items()) if v) + filter_str = "_".join(f"{k}:{v}" for k, v in sorted(filters.items()) if v) key_parts.append(filter_str) - return '_'.join(key_parts) + return "_".join(key_parts) def invalidate_cache(self, filters: dict[str, Any] | None = None) -> None: """Invalidate cached data for the given filters.""" # This is a simplified implementation # In production, you might want to use cache versioning or tags cache_keys = [ - self._generate_cache_key('initial', filters), - self._generate_cache_key('metadata', filters), + self._generate_cache_key("initial", filters), + self._generate_cache_key("metadata", filters), ] # Also invalidate progressive load caches for offset in range(0, 1000, self.PROGRESSIVE_LOAD_SIZE): - cache_keys.append(self._generate_cache_key(f'progressive_{offset}', filters)) + cache_keys.append(self._generate_cache_key(f"progressive_{offset}", filters)) cache.delete_many(cache_keys) diff --git a/backend/apps/parks/services/location_service.py b/backend/apps/parks/services/location_service.py index d7eae540..1f944e50 100644 --- a/backend/apps/parks/services/location_service.py +++ b/backend/apps/parks/services/location_service.py @@ -245,9 +245,7 @@ class ParkLocationService: return park_location @classmethod - def update_park_location( - cls, park_location: ParkLocation, **updates - ) -> ParkLocation: + def update_park_location(cls, park_location: ParkLocation, **updates) -> ParkLocation: """ Update park location with validation. @@ -278,9 +276,7 @@ class ParkLocationService: return park_location @classmethod - def find_nearby_parks( - cls, latitude: float, longitude: float, radius_km: float = 50 - ) -> list[ParkLocation]: + def find_nearby_parks(cls, latitude: float, longitude: float, radius_km: float = 50) -> list[ParkLocation]: """ Find parks near given coordinates using PostGIS. @@ -298,9 +294,7 @@ class ParkLocationService: center_point = Point(longitude, latitude, srid=4326) return list( - ParkLocation.objects.filter( - point__distance_lte=(center_point, Distance(km=radius_km)) - ) + ParkLocation.objects.filter(point__distance_lte=(center_point, Distance(km=radius_km))) .select_related("park", "park__operator") .order_by("point__distance") ) @@ -349,9 +343,7 @@ class ParkLocationService: return park_location @classmethod - def _transform_osm_result( - cls, osm_item: dict[str, Any] - ) -> dict[str, Any] | None: + def _transform_osm_result(cls, osm_item: dict[str, Any]) -> dict[str, Any] | None: """Transform OSM search result to our standard format.""" try: address = osm_item.get("address", {}) @@ -369,12 +361,7 @@ class ParkLocationService: or "" ) - state = ( - address.get("state") - or address.get("province") - or address.get("region") - or "" - ) + state = address.get("state") or address.get("province") or address.get("region") or "" country = address.get("country", "") postal_code = address.get("postcode", "") @@ -432,9 +419,7 @@ class ParkLocationService: return None @classmethod - def _transform_osm_reverse_result( - cls, osm_result: dict[str, Any] - ) -> dict[str, Any]: + def _transform_osm_reverse_result(cls, osm_result: dict[str, Any]) -> dict[str, Any]: """Transform OSM reverse geocoding result to our standard format.""" address = osm_result.get("address", {}) @@ -443,20 +428,9 @@ class ParkLocationService: street_name = address.get("road", "") street_address = f"{street_number} {street_name}".strip() - city = ( - address.get("city") - or address.get("town") - or address.get("village") - or address.get("municipality") - or "" - ) + city = address.get("city") or address.get("town") or address.get("village") or address.get("municipality") or "" - state = ( - address.get("state") - or address.get("province") - or address.get("region") - or "" - ) + state = address.get("state") or address.get("province") or address.get("region") or "" country = address.get("country", "") postal_code = address.get("postcode", "") diff --git a/backend/apps/parks/services/media_service.py b/backend/apps/parks/services/media_service.py index 3582b9fb..10493e86 100644 --- a/backend/apps/parks/services/media_service.py +++ b/backend/apps/parks/services/media_service.py @@ -79,9 +79,7 @@ class ParkMediaService: return photo @staticmethod - def get_park_photos( - park: Park, approved_only: bool = True, primary_first: bool = True - ) -> list[ParkPhoto]: + def get_park_photos(park: Park, approved_only: bool = True, primary_first: bool = True) -> list[ParkPhoto]: """ Get photos for a park. @@ -190,9 +188,7 @@ class ParkMediaService: photo.image.delete(save=False) photo.delete() - logger.info( - f"Photo {photo_id} deleted from park {park_slug} by user {deleted_by.username}" - ) + logger.info(f"Photo {photo_id} deleted from park {park_slug} by user {deleted_by.username}") return True except Exception as e: logger.error(f"Failed to delete photo {photo.pk}: {str(e)}") @@ -238,7 +234,5 @@ class ParkMediaService: if ParkMediaService.approve_photo(photo, approved_by): approved_count += 1 - logger.info( - f"Bulk approved {approved_count} photos by user {approved_by.username}" - ) + logger.info(f"Bulk approved {approved_count} photos by user {approved_by.username}") return approved_count diff --git a/backend/apps/parks/services/park_management.py b/backend/apps/parks/services/park_management.py index e957296f..a09597b9 100644 --- a/backend/apps/parks/services/park_management.py +++ b/backend/apps/parks/services/park_management.py @@ -133,9 +133,7 @@ class ParkService: return park @staticmethod - def delete_park( - *, park_id: int, deleted_by: Optional["AbstractUser"] = None - ) -> bool: + def delete_park(*, park_id: int, deleted_by: Optional["AbstractUser"] = None) -> bool: """ Soft delete a park by setting status to DEMOLISHED. @@ -219,9 +217,9 @@ class ParkService: ) # Calculate average rating - avg_rating = ParkReview.objects.filter( - park=park, is_published=True - ).aggregate(avg_rating=Avg("rating"))["avg_rating"] + avg_rating = ParkReview.objects.filter(park=park, is_published=True).aggregate(avg_rating=Avg("rating"))[ + "avg_rating" + ] # Update park fields park.ride_count = ride_stats["total_rides"] or 0 diff --git a/backend/apps/parks/services/roadtrip.py b/backend/apps/parks/services/roadtrip.py index cefaa599..b6742c59 100644 --- a/backend/apps/parks/services/roadtrip.py +++ b/backend/apps/parks/services/roadtrip.py @@ -148,12 +148,8 @@ class RoadTripService: # Configuration from Django settings self.cache_timeout = getattr(settings, "ROADTRIP_CACHE_TIMEOUT", 3600 * 24) - self.route_cache_timeout = getattr( - settings, "ROADTRIP_ROUTE_CACHE_TIMEOUT", 3600 * 6 - ) - self.user_agent = getattr( - settings, "ROADTRIP_USER_AGENT", "ThrillWiki Road Trip Planner" - ) + self.route_cache_timeout = getattr(settings, "ROADTRIP_ROUTE_CACHE_TIMEOUT", 3600 * 6) + self.user_agent = getattr(settings, "ROADTRIP_USER_AGENT", "ThrillWiki Road Trip Planner") self.request_timeout = getattr(settings, "ROADTRIP_REQUEST_TIMEOUT", 10) self.max_retries = getattr(settings, "ROADTRIP_MAX_RETRIES", 3) self.backoff_factor = getattr(settings, "ROADTRIP_BACKOFF_FACTOR", 2) @@ -179,9 +175,7 @@ class RoadTripService: for attempt in range(self.max_retries): try: - response = self.session.get( - url, params=params, timeout=self.request_timeout - ) + response = self.session.get(url, params=params, timeout=self.request_timeout) response.raise_for_status() return response.json() @@ -192,9 +186,7 @@ class RoadTripService: wait_time = self.backoff_factor**attempt time.sleep(wait_time) else: - raise OSMAPIException( - f"Failed to make request after {self.max_retries} attempts: {e}" - ) + raise OSMAPIException(f"Failed to make request after {self.max_retries} attempts: {e}") from e def geocode_address(self, address: str) -> Coordinates | None: """ @@ -243,9 +235,7 @@ class RoadTripService: self.cache_timeout, ) - logger.info( - f"Geocoded '{address}' to {coords.latitude}, {coords.longitude}" - ) + logger.info(f"Geocoded '{address}' to {coords.latitude}, {coords.longitude}") return coords else: logger.warning(f"No geocoding results for address: {address}") @@ -255,9 +245,7 @@ class RoadTripService: logger.error(f"Geocoding failed for '{address}': {e}") return None - def calculate_route( - self, start_coords: Coordinates, end_coords: Coordinates - ) -> RouteInfo | None: + def calculate_route(self, start_coords: Coordinates, end_coords: Coordinates) -> RouteInfo | None: """ Calculate route between two coordinate points using OSRM. @@ -327,9 +315,7 @@ class RoadTripService: return route_info else: # Fallback to straight-line distance calculation - logger.warning( - "OSRM routing failed, falling back to straight-line distance" - ) + logger.warning("OSRM routing failed, falling back to straight-line distance") return self._calculate_straight_line_route(start_coords, end_coords) except Exception as e: @@ -337,9 +323,7 @@ class RoadTripService: # Fallback to straight-line distance return self._calculate_straight_line_route(start_coords, end_coords) - def _calculate_straight_line_route( - self, start_coords: Coordinates, end_coords: Coordinates - ) -> RouteInfo: + def _calculate_straight_line_route(self, start_coords: Coordinates, end_coords: Coordinates) -> RouteInfo: """ Calculate straight-line distance as fallback when routing fails. """ @@ -356,10 +340,7 @@ class RoadTripService: dlat = lat2 - lat1 dlon = lon2 - lon1 - a = ( - math.sin(dlat / 2) ** 2 - + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2 - ) + a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2 c = 2 * math.asin(math.sqrt(a)) # Earth's radius in kilometers @@ -376,9 +357,7 @@ class RoadTripService: geometry=None, ) - def find_parks_along_route( - self, start_park: "Park", end_park: "Park", max_detour_km: float = 50 - ) -> list["Park"]: + def find_parks_along_route(self, start_park: "Park", end_park: "Park", max_detour_km: float = 50) -> list["Park"]: """ Find parks along a route within specified detour distance. @@ -443,9 +422,7 @@ class RoadTripService: return parks_along_route - def _calculate_detour_distance( - self, start: Coordinates, end: Coordinates, waypoint: Coordinates - ) -> float | None: + def _calculate_detour_distance(self, start: Coordinates, end: Coordinates, waypoint: Coordinates) -> float | None: """ Calculate the detour distance when visiting a waypoint. """ @@ -508,9 +485,7 @@ class RoadTripService: return best_trip - def _optimize_trip_nearest_neighbor( - self, park_list: list["Park"] - ) -> RoadTrip | None: + def _optimize_trip_nearest_neighbor(self, park_list: list["Park"]) -> RoadTrip | None: """ Optimize trip using nearest neighbor heuristic (for larger lists). """ @@ -536,9 +511,7 @@ class RoadTripService: if not park_coords: continue - route = self.calculate_route( - Coordinates(*current_coords), Coordinates(*park_coords) - ) + route = self.calculate_route(Coordinates(*current_coords), Coordinates(*park_coords)) if route and route.distance_km < min_distance: min_distance = route.distance_km @@ -553,9 +526,7 @@ class RoadTripService: return self._create_trip_from_order(ordered_parks) - def _create_trip_from_order( - self, ordered_parks: list["Park"] - ) -> RoadTrip | None: + def _create_trip_from_order(self, ordered_parks: list["Park"]) -> RoadTrip | None: """ Create a RoadTrip object from an ordered list of parks. """ @@ -576,9 +547,7 @@ class RoadTripService: if not from_coords or not to_coords: continue - route = self.calculate_route( - Coordinates(*from_coords), Coordinates(*to_coords) - ) + route = self.calculate_route(Coordinates(*from_coords), Coordinates(*to_coords)) if route: legs.append(TripLeg(from_park=from_park, to_park=to_park, route=route)) @@ -595,9 +564,7 @@ class RoadTripService: total_duration_minutes=total_duration, ) - def get_park_distances( - self, center_park: "Park", radius_km: float = 100 - ) -> list[dict[str, Any]]: + def get_park_distances(self, center_park: "Park", radius_km: float = 100) -> list[dict[str, Any]]: """ Get all parks within radius of a center park with distances. @@ -621,9 +588,7 @@ class RoadTripService: search_distance = Distance(km=radius_km) nearby_parks = ( - Park.objects.filter( - location__point__distance_lte=(center_point, search_distance) - ) + Park.objects.filter(location__point__distance_lte=(center_point, search_distance)) .exclude(id=center_park.id) .select_related("location") ) @@ -635,9 +600,7 @@ class RoadTripService: if not park_coords: continue - route = self.calculate_route( - Coordinates(*center_coords), Coordinates(*park_coords) - ) + route = self.calculate_route(Coordinates(*center_coords), Coordinates(*park_coords)) if route: results.append( @@ -691,9 +654,7 @@ class RoadTripService: if coords: location.set_coordinates(coords.latitude, coords.longitude) location.save() - logger.info( - f"Geocoded park '{park.name}' to {coords.latitude}, {coords.longitude}" - ) + logger.info(f"Geocoded park '{park.name}' to {coords.latitude}, {coords.longitude}") return True return False diff --git a/backend/apps/parks/signals.py b/backend/apps/parks/signals.py index a67d6886..215d09c2 100644 --- a/backend/apps/parks/signals.py +++ b/backend/apps/parks/signals.py @@ -15,6 +15,7 @@ logger = logging.getLogger(__name__) # Computed Field Maintenance Signals # ============================================================================= + def update_park_search_text(park): """ Update park's search_text computed field. @@ -27,17 +28,17 @@ def update_park_search_text(park): try: park._populate_computed_fields() - park.save(update_fields=['search_text']) + park.save(update_fields=["search_text"]) logger.debug(f"Updated search_text for park {park.pk}") except Exception as e: logger.exception(f"Failed to update search_text for park {park.pk}: {e}") # Status values that count as "active" rides for counting purposes -ACTIVE_STATUSES = {'OPERATING', 'SEASONAL', 'UNDER_CONSTRUCTION'} +ACTIVE_STATUSES = {"OPERATING", "SEASONAL", "UNDER_CONSTRUCTION"} # Status values that should decrement ride counts -INACTIVE_STATUSES = {'CLOSED_PERM', 'DEMOLISHED', 'RELOCATED', 'REMOVED'} +INACTIVE_STATUSES = {"CLOSED_PERM", "DEMOLISHED", "RELOCATED", "REMOVED"} def update_park_ride_counts(park, old_status=None, new_status=None): @@ -54,11 +55,11 @@ def update_park_ride_counts(park, old_status=None, new_status=None): return # Get park ID - park_id = park.pk if hasattr(park, 'pk') else park + park_id = park.pk if hasattr(park, "pk") else park try: # Fetch the park if we only have an ID - if not hasattr(park, 'rides'): + if not hasattr(park, "rides"): park = Park.objects.get(id=park_id) # Build the query for active rides @@ -72,14 +73,9 @@ def update_park_ride_counts(park, old_status=None, new_status=None): coaster_count = park.rides.filter(operating_rides, category__in=["RC", "WC"]).count() # Update park counts - Park.objects.filter(id=park_id).update( - ride_count=ride_count, coaster_count=coaster_count - ) + Park.objects.filter(id=park_id).update(ride_count=ride_count, coaster_count=coaster_count) - logger.debug( - f"Updated park {park_id} counts: " - f"ride_count={ride_count}, coaster_count={coaster_count}" - ) + logger.debug(f"Updated park {park_id} counts: " f"ride_count={ride_count}, coaster_count={coaster_count}") except Park.DoesNotExist: logger.warning(f"Park {park_id} does not exist, cannot update counts") @@ -124,14 +120,12 @@ def ride_saved(sender, instance, created, **kwargs): return # Check if status changed using model's tracker if available - if hasattr(instance, 'tracker') and hasattr(instance.tracker, 'has_changed'): - if instance.tracker.has_changed('status'): - old_status = instance.tracker.previous('status') + if hasattr(instance, "tracker") and hasattr(instance.tracker, "has_changed"): + if instance.tracker.has_changed("status"): + old_status = instance.tracker.previous("status") new_status = instance.status if should_update_counts(old_status, new_status): - logger.info( - f"Ride {instance.pk} status changed: {old_status} → {new_status}" - ) + logger.info(f"Ride {instance.pk} status changed: {old_status} → {new_status}") update_park_ride_counts(instance.park, old_status, new_status) else: # Fallback: always update counts on save @@ -151,6 +145,7 @@ def ride_deleted(sender, instance, **kwargs): # FSM transition signal handlers + def handle_ride_status_transition(instance, source, target, user, **kwargs): """ Handle ride status FSM transitions. @@ -165,10 +160,7 @@ def handle_ride_status_transition(instance, source, target, user, **kwargs): user: The user who initiated the transition. """ if should_update_counts(source, target): - logger.info( - f"FSM transition: Ride {instance.pk} {source} → {target} " - f"by {user if user else 'system'}" - ) + logger.info(f"FSM transition: Ride {instance.pk} {source} → {target} " f"by {user if user else 'system'}") update_park_ride_counts(instance.park, source, target) @@ -176,7 +168,8 @@ def handle_ride_status_transition(instance, source, target, user, **kwargs): # Computed Field Maintenance Signal Handlers # ============================================================================= -@receiver(post_save, sender='parks.ParkLocation') + +@receiver(post_save, sender="parks.ParkLocation") def update_park_search_text_on_location_change(sender, instance, **kwargs): """ Update park search_text when location changes. @@ -186,13 +179,13 @@ def update_park_search_text_on_location_change(sender, instance, **kwargs): location information. """ try: - if hasattr(instance, 'park') and instance.park: + if hasattr(instance, "park") and instance.park: update_park_search_text(instance.park) except Exception as e: logger.exception(f"Failed to update park search_text on location change: {e}") -@receiver(post_save, sender='parks.Company') +@receiver(post_save, sender="parks.Company") def update_park_search_text_on_company_change(sender, instance, **kwargs): """ Update park search_text when operator/owner name changes. diff --git a/backend/apps/parks/templatetags/park_tags.py b/backend/apps/parks/templatetags/park_tags.py index 31370d15..fad5d7d1 100644 --- a/backend/apps/parks/templatetags/park_tags.py +++ b/backend/apps/parks/templatetags/park_tags.py @@ -5,48 +5,48 @@ register = template.Library() # Status configuration mapping for parks and rides STATUS_CONFIG = { - 'OPERATING': { - 'label': 'Operating', - 'classes': 'bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200', - 'icon': True, + "OPERATING": { + "label": "Operating", + "classes": "bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200", + "icon": True, }, - 'CLOSED_TEMP': { - 'label': 'Temporarily Closed', - 'classes': 'bg-yellow-100 text-yellow-800 dark:bg-yellow-900 dark:text-yellow-200', - 'icon': True, + "CLOSED_TEMP": { + "label": "Temporarily Closed", + "classes": "bg-yellow-100 text-yellow-800 dark:bg-yellow-900 dark:text-yellow-200", + "icon": True, }, - 'CLOSED_PERM': { - 'label': 'Permanently Closed', - 'classes': 'bg-red-100 text-red-800 dark:bg-red-900 dark:text-red-200', - 'icon': True, + "CLOSED_PERM": { + "label": "Permanently Closed", + "classes": "bg-red-100 text-red-800 dark:bg-red-900 dark:text-red-200", + "icon": True, }, - 'CONSTRUCTION': { - 'label': 'Under Construction', - 'classes': 'bg-orange-100 text-orange-800 dark:bg-orange-900 dark:text-orange-200', - 'icon': True, + "CONSTRUCTION": { + "label": "Under Construction", + "classes": "bg-orange-100 text-orange-800 dark:bg-orange-900 dark:text-orange-200", + "icon": True, }, - 'DEMOLISHED': { - 'label': 'Demolished', - 'classes': 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-300', - 'icon': True, + "DEMOLISHED": { + "label": "Demolished", + "classes": "bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-300", + "icon": True, }, - 'RELOCATED': { - 'label': 'Relocated', - 'classes': 'bg-purple-100 text-purple-800 dark:bg-purple-900 dark:text-purple-200', - 'icon': True, + "RELOCATED": { + "label": "Relocated", + "classes": "bg-purple-100 text-purple-800 dark:bg-purple-900 dark:text-purple-200", + "icon": True, }, - 'SBNO': { - 'label': 'Standing But Not Operating', - 'classes': 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200', - 'icon': True, + "SBNO": { + "label": "Standing But Not Operating", + "classes": "bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200", + "icon": True, }, } # Default config for unknown statuses DEFAULT_STATUS_CONFIG = { - 'label': 'Unknown', - 'classes': 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-300', - 'icon': False, + "label": "Unknown", + "classes": "bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-300", + "icon": False, } diff --git a/backend/apps/parks/tests.py b/backend/apps/parks/tests.py index 216b75b0..516f7f73 100644 --- a/backend/apps/parks/tests.py +++ b/backend/apps/parks/tests.py @@ -31,39 +31,28 @@ class ParkTransitionTests(TestCase): def setUp(self): """Set up test fixtures.""" self.user = User.objects.create_user( - username='testuser', - email='test@example.com', - password='testpass123', - role='USER' + username="testuser", email="test@example.com", password="testpass123", role="USER" ) self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) self.admin = User.objects.create_user( - username='admin', - email='admin@example.com', - password='testpass123', - role='ADMIN' + username="admin", email="admin@example.com", password="testpass123", role="ADMIN" ) # Create operator company self.operator = Company.objects.create( - name='Test Operator', - description='Test operator company', - roles=['OPERATOR'] + name="Test Operator", description="Test operator company", roles=["OPERATOR"] ) - def _create_park(self, status='OPERATING', **kwargs): + def _create_park(self, status="OPERATING", **kwargs): """Helper to create a Park with specified status.""" defaults = { - 'name': 'Test Park', - 'slug': 'test-park', - 'description': 'A test park', - 'operator': self.operator, - 'timezone': 'America/New_York' + "name": "Test Park", + "slug": "test-park", + "description": "A test park", + "operator": self.operator, + "timezone": "America/New_York", } defaults.update(kwargs) return Park.objects.create(status=status, **defaults) @@ -74,25 +63,25 @@ class ParkTransitionTests(TestCase): def test_operating_to_closed_temp_transition(self): """Test transition from OPERATING to CLOSED_TEMP.""" - park = self._create_park(status='OPERATING') - self.assertEqual(park.status, 'OPERATING') + park = self._create_park(status="OPERATING") + self.assertEqual(park.status, "OPERATING") park.transition_to_closed_temp(user=self.user) park.save() park.refresh_from_db() - self.assertEqual(park.status, 'CLOSED_TEMP') + self.assertEqual(park.status, "CLOSED_TEMP") def test_operating_to_closed_perm_transition(self): """Test transition from OPERATING to CLOSED_PERM.""" - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") park.transition_to_closed_perm(user=self.moderator) park.closing_date = date.today() park.save() park.refresh_from_db() - self.assertEqual(park.status, 'CLOSED_PERM') + self.assertEqual(park.status, "CLOSED_PERM") self.assertIsNotNone(park.closing_date) # ------------------------------------------------------------------------- @@ -101,14 +90,14 @@ class ParkTransitionTests(TestCase): def test_under_construction_to_operating_transition(self): """Test transition from UNDER_CONSTRUCTION to OPERATING.""" - park = self._create_park(status='UNDER_CONSTRUCTION') - self.assertEqual(park.status, 'UNDER_CONSTRUCTION') + park = self._create_park(status="UNDER_CONSTRUCTION") + self.assertEqual(park.status, "UNDER_CONSTRUCTION") park.transition_to_operating(user=self.user) park.save() park.refresh_from_db() - self.assertEqual(park.status, 'OPERATING') + self.assertEqual(park.status, "OPERATING") # ------------------------------------------------------------------------- # Closed temp transitions @@ -116,24 +105,24 @@ class ParkTransitionTests(TestCase): def test_closed_temp_to_operating_transition(self): """Test transition from CLOSED_TEMP to OPERATING (reopen).""" - park = self._create_park(status='CLOSED_TEMP') + park = self._create_park(status="CLOSED_TEMP") park.transition_to_operating(user=self.user) park.save() park.refresh_from_db() - self.assertEqual(park.status, 'OPERATING') + self.assertEqual(park.status, "OPERATING") def test_closed_temp_to_closed_perm_transition(self): """Test transition from CLOSED_TEMP to CLOSED_PERM.""" - park = self._create_park(status='CLOSED_TEMP') + park = self._create_park(status="CLOSED_TEMP") park.transition_to_closed_perm(user=self.moderator) park.closing_date = date.today() park.save() park.refresh_from_db() - self.assertEqual(park.status, 'CLOSED_PERM') + self.assertEqual(park.status, "CLOSED_PERM") # ------------------------------------------------------------------------- # Closed perm transitions (to final states) @@ -141,23 +130,23 @@ class ParkTransitionTests(TestCase): def test_closed_perm_to_demolished_transition(self): """Test transition from CLOSED_PERM to DEMOLISHED.""" - park = self._create_park(status='CLOSED_PERM') + park = self._create_park(status="CLOSED_PERM") park.transition_to_demolished(user=self.moderator) park.save() park.refresh_from_db() - self.assertEqual(park.status, 'DEMOLISHED') + self.assertEqual(park.status, "DEMOLISHED") def test_closed_perm_to_relocated_transition(self): """Test transition from CLOSED_PERM to RELOCATED.""" - park = self._create_park(status='CLOSED_PERM') + park = self._create_park(status="CLOSED_PERM") park.transition_to_relocated(user=self.moderator) park.save() park.refresh_from_db() - self.assertEqual(park.status, 'RELOCATED') + self.assertEqual(park.status, "RELOCATED") # ------------------------------------------------------------------------- # Invalid transitions (final states) @@ -165,28 +154,28 @@ class ParkTransitionTests(TestCase): def test_demolished_cannot_transition(self): """Test that DEMOLISHED state cannot transition further.""" - park = self._create_park(status='DEMOLISHED') + park = self._create_park(status="DEMOLISHED") with self.assertRaises(TransitionNotAllowed): park.transition_to_operating(user=self.moderator) def test_relocated_cannot_transition(self): """Test that RELOCATED state cannot transition further.""" - park = self._create_park(status='RELOCATED') + park = self._create_park(status="RELOCATED") with self.assertRaises(TransitionNotAllowed): park.transition_to_operating(user=self.moderator) def test_operating_cannot_directly_demolish(self): """Test that OPERATING cannot directly transition to DEMOLISHED.""" - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") with self.assertRaises(TransitionNotAllowed): park.transition_to_demolished(user=self.moderator) def test_operating_cannot_directly_relocate(self): """Test that OPERATING cannot directly transition to RELOCATED.""" - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") with self.assertRaises(TransitionNotAllowed): park.transition_to_relocated(user=self.moderator) @@ -197,69 +186,69 @@ class ParkTransitionTests(TestCase): def test_reopen_wrapper_method(self): """Test the reopen() wrapper method.""" - park = self._create_park(status='CLOSED_TEMP') + park = self._create_park(status="CLOSED_TEMP") park.reopen(user=self.user) park.refresh_from_db() - self.assertEqual(park.status, 'OPERATING') + self.assertEqual(park.status, "OPERATING") def test_close_temporarily_wrapper_method(self): """Test the close_temporarily() wrapper method.""" - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") park.close_temporarily(user=self.user) park.refresh_from_db() - self.assertEqual(park.status, 'CLOSED_TEMP') + self.assertEqual(park.status, "CLOSED_TEMP") def test_close_permanently_wrapper_method(self): """Test the close_permanently() wrapper method.""" - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") closing = date(2025, 12, 31) park.close_permanently(closing_date=closing, user=self.moderator) park.refresh_from_db() - self.assertEqual(park.status, 'CLOSED_PERM') + self.assertEqual(park.status, "CLOSED_PERM") self.assertEqual(park.closing_date, closing) def test_close_permanently_without_date(self): """Test close_permanently() without closing_date.""" - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") park.close_permanently(user=self.moderator) park.refresh_from_db() - self.assertEqual(park.status, 'CLOSED_PERM') + self.assertEqual(park.status, "CLOSED_PERM") self.assertIsNone(park.closing_date) def test_demolish_wrapper_method(self): """Test the demolish() wrapper method.""" - park = self._create_park(status='CLOSED_PERM') + park = self._create_park(status="CLOSED_PERM") park.demolish(user=self.moderator) park.refresh_from_db() - self.assertEqual(park.status, 'DEMOLISHED') + self.assertEqual(park.status, "DEMOLISHED") def test_relocate_wrapper_method(self): """Test the relocate() wrapper method.""" - park = self._create_park(status='CLOSED_PERM') + park = self._create_park(status="CLOSED_PERM") park.relocate(user=self.moderator) park.refresh_from_db() - self.assertEqual(park.status, 'RELOCATED') + self.assertEqual(park.status, "RELOCATED") def test_start_construction_wrapper_method(self): """Test the start_construction() wrapper method if applicable.""" # This depends on allowed transitions - skip if not allowed try: - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") park.start_construction(user=self.moderator) park.refresh_from_db() - self.assertEqual(park.status, 'UNDER_CONSTRUCTION') + self.assertEqual(park.status, "UNDER_CONSTRUCTION") except TransitionNotAllowed: # If transition from OPERATING to UNDER_CONSTRUCTION is not allowed pass @@ -276,52 +265,44 @@ class ParkTransitionHistoryTests(TestCase): def setUp(self): """Set up test fixtures.""" self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) self.operator = Company.objects.create( - name='Test Operator', - description='Test operator company', - roles=['OPERATOR'] + name="Test Operator", description="Test operator company", roles=["OPERATOR"] ) - def _create_park(self, status='OPERATING'): + def _create_park(self, status="OPERATING"): """Helper to create a Park.""" return Park.objects.create( - name='Test Park', - slug='test-park', - description='A test park', + name="Test Park", + slug="test-park", + description="A test park", operator=self.operator, status=status, - timezone='America/New_York' + timezone="America/New_York", ) def test_transition_creates_state_log(self): """Test that transitions create StateLog entries.""" from django_fsm_log.models import StateLog - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") park.transition_to_closed_temp(user=self.moderator) park.save() park_ct = ContentType.objects.get_for_model(park) - log = StateLog.objects.filter( - content_type=park_ct, - object_id=park.id - ).first() + log = StateLog.objects.filter(content_type=park_ct, object_id=park.id).first() self.assertIsNotNone(log) - self.assertEqual(log.state, 'CLOSED_TEMP') + self.assertEqual(log.state, "CLOSED_TEMP") self.assertEqual(log.by, self.moderator) def test_multiple_transitions_create_multiple_logs(self): """Test that multiple transitions create multiple log entries.""" from django_fsm_log.models import StateLog - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") park_ct = ContentType.objects.get_for_model(park) # First transition @@ -332,29 +313,23 @@ class ParkTransitionHistoryTests(TestCase): park.transition_to_operating(user=self.moderator) park.save() - logs = StateLog.objects.filter( - content_type=park_ct, - object_id=park.id - ).order_by('timestamp') + logs = StateLog.objects.filter(content_type=park_ct, object_id=park.id).order_by("timestamp") self.assertEqual(logs.count(), 2) - self.assertEqual(logs[0].state, 'CLOSED_TEMP') - self.assertEqual(logs[1].state, 'OPERATING') + self.assertEqual(logs[0].state, "CLOSED_TEMP") + self.assertEqual(logs[1].state, "OPERATING") def test_transition_log_includes_user(self): """Test that transition logs include the user who made the change.""" from django_fsm_log.models import StateLog - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") park.transition_to_closed_perm(user=self.moderator) park.save() park_ct = ContentType.objects.get_for_model(park) - log = StateLog.objects.filter( - content_type=park_ct, - object_id=park.id - ).first() + log = StateLog.objects.filter(content_type=park_ct, object_id=park.id).first() self.assertEqual(log.by, self.moderator) @@ -370,24 +345,20 @@ class ParkBusinessLogicTests(TestCase): def setUp(self): """Set up test fixtures.""" self.operator = Company.objects.create( - name='Test Operator', - description='Test operator company', - roles=['OPERATOR'] + name="Test Operator", description="Test operator company", roles=["OPERATOR"] ) self.property_owner = Company.objects.create( - name='Property Owner', - description='Property owner company', - roles=['PROPERTY_OWNER'] + name="Property Owner", description="Property owner company", roles=["PROPERTY_OWNER"] ) def test_park_creates_with_valid_operator(self): """Test park can be created with valid operator.""" park = Park.objects.create( - name='Test Park', - slug='test-park', - description='A test park', + name="Test Park", + slug="test-park", + description="A test park", operator=self.operator, - timezone='America/New_York' + timezone="America/New_York", ) self.assertEqual(park.operator, self.operator) @@ -395,35 +366,32 @@ class ParkBusinessLogicTests(TestCase): def test_park_slug_auto_generated(self): """Test that park slug is auto-generated from name.""" park = Park.objects.create( - name='My Amazing Theme Park', - description='A test park', - operator=self.operator, - timezone='America/New_York' + name="My Amazing Theme Park", description="A test park", operator=self.operator, timezone="America/New_York" ) - self.assertEqual(park.slug, 'my-amazing-theme-park') + self.assertEqual(park.slug, "my-amazing-theme-park") def test_park_url_generated(self): """Test that frontend URL is generated on save.""" park = Park.objects.create( - name='Test Park', - slug='test-park', - description='A test park', + name="Test Park", + slug="test-park", + description="A test park", operator=self.operator, - timezone='America/New_York' + timezone="America/New_York", ) - self.assertIn('test-park', park.url) + self.assertIn("test-park", park.url) def test_opening_year_computed_from_opening_date(self): """Test that opening_year is computed from opening_date.""" park = Park.objects.create( - name='Test Park', - slug='test-park', - description='A test park', + name="Test Park", + slug="test-park", + description="A test park", operator=self.operator, opening_date=date(2020, 6, 15), - timezone='America/New_York' + timezone="America/New_York", ) self.assertEqual(park.opening_year, 2020) @@ -431,26 +399,26 @@ class ParkBusinessLogicTests(TestCase): def test_search_text_populated(self): """Test that search_text is populated on save.""" park = Park.objects.create( - name='Test Park', - slug='test-park', - description='A wonderful theme park', + name="Test Park", + slug="test-park", + description="A wonderful theme park", operator=self.operator, - timezone='America/New_York' + timezone="America/New_York", ) - self.assertIn('test park', park.search_text) - self.assertIn('wonderful theme park', park.search_text) - self.assertIn('test operator', park.search_text) + self.assertIn("test park", park.search_text) + self.assertIn("wonderful theme park", park.search_text) + self.assertIn("test operator", park.search_text) def test_park_with_property_owner(self): """Test park with separate property owner.""" park = Park.objects.create( - name='Test Park', - slug='test-park', - description='A test park', + name="Test Park", + slug="test-park", + description="A test park", operator=self.operator, property_owner=self.property_owner, - timezone='America/New_York' + timezone="America/New_York", ) self.assertEqual(park.operator, self.operator) @@ -468,9 +436,7 @@ class ParkSlugHistoryTests(TestCase): def setUp(self): """Set up test fixtures.""" self.operator = Company.objects.create( - name='Test Operator', - description='Test operator company', - roles=['OPERATOR'] + name="Test Operator", description="Test operator company", roles=["OPERATOR"] ) def test_historical_slug_created_on_name_change(self): @@ -480,25 +446,18 @@ class ParkSlugHistoryTests(TestCase): from apps.core.history import HistoricalSlug park = Park.objects.create( - name='Original Name', - description='A test park', - operator=self.operator, - timezone='America/New_York' + name="Original Name", description="A test park", operator=self.operator, timezone="America/New_York" ) original_slug = park.slug # Change name - park.name = 'New Name' + park.name = "New Name" park.save() # Check historical slug was created park_ct = ContentType.objects.get_for_model(park) - historical = HistoricalSlug.objects.filter( - content_type=park_ct, - object_id=park.id, - slug=original_slug - ).first() + historical = HistoricalSlug.objects.filter(content_type=park_ct, object_id=park.id, slug=original_slug).first() self.assertIsNotNone(historical) self.assertEqual(historical.slug, original_slug) @@ -506,14 +465,14 @@ class ParkSlugHistoryTests(TestCase): def test_get_by_slug_finds_current_slug(self): """Test get_by_slug finds park by current slug.""" park = Park.objects.create( - name='Test Park', - slug='test-park', - description='A test park', + name="Test Park", + slug="test-park", + description="A test park", operator=self.operator, - timezone='America/New_York' + timezone="America/New_York", ) - found_park, is_historical = Park.get_by_slug('test-park') + found_park, is_historical = Park.get_by_slug("test-park") self.assertEqual(found_park, park) self.assertFalse(is_historical) @@ -522,16 +481,13 @@ class ParkSlugHistoryTests(TestCase): """Test get_by_slug finds park by historical slug.""" park = Park.objects.create( - name='Original Name', - description='A test park', - operator=self.operator, - timezone='America/New_York' + name="Original Name", description="A test park", operator=self.operator, timezone="America/New_York" ) original_slug = park.slug # Change name to create historical slug - park.name = 'New Name' + park.name = "New Name" park.save() # Find by historical slug diff --git a/backend/apps/parks/tests/test_park_workflows.py b/backend/apps/parks/tests/test_park_workflows.py index 6b953fad..06031e45 100644 --- a/backend/apps/parks/tests/test_park_workflows.py +++ b/backend/apps/parks/tests/test_park_workflows.py @@ -22,33 +22,24 @@ class ParkOpeningWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.user = User.objects.create_user( - username='park_user', - email='park_user@example.com', - password='testpass123', - role='USER' + username="park_user", email="park_user@example.com", password="testpass123", role="USER" ) cls.moderator = User.objects.create_user( - username='park_mod', - email='park_mod@example.com', - password='testpass123', - role='MODERATOR' + username="park_mod", email="park_mod@example.com", password="testpass123", role="MODERATOR" ) - def _create_park(self, status='OPERATING', **kwargs): + def _create_park(self, status="OPERATING", **kwargs): """Helper to create a park.""" from apps.parks.models import Company, Park - operator = Company.objects.create( - name=f'Operator {status}', - roles=['OPERATOR'] - ) + operator = Company.objects.create(name=f"Operator {status}", roles=["OPERATOR"]) defaults = { - 'name': f'Test Park {status}', - 'slug': f'test-park-{status.lower()}-{timezone.now().timestamp()}', - 'operator': operator, - 'status': status, - 'timezone': 'America/New_York' + "name": f"Test Park {status}", + "slug": f"test-park-{status.lower()}-{timezone.now().timestamp()}", + "operator": operator, + "status": status, + "timezone": "America/New_York", } defaults.update(kwargs) return Park.objects.create(**defaults) @@ -59,16 +50,16 @@ class ParkOpeningWorkflowTests(TestCase): Flow: UNDER_CONSTRUCTION → OPERATING """ - park = self._create_park(status='UNDER_CONSTRUCTION') + park = self._create_park(status="UNDER_CONSTRUCTION") - self.assertEqual(park.status, 'UNDER_CONSTRUCTION') + self.assertEqual(park.status, "UNDER_CONSTRUCTION") # Park opens park.transition_to_operating(user=self.user) park.save() park.refresh_from_db() - self.assertEqual(park.status, 'OPERATING') + self.assertEqual(park.status, "OPERATING") class ParkTemporaryClosureWorkflowTests(TestCase): @@ -77,26 +68,20 @@ class ParkTemporaryClosureWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.user = User.objects.create_user( - username='temp_closure_user', - email='temp_closure@example.com', - password='testpass123', - role='USER' + username="temp_closure_user", email="temp_closure@example.com", password="testpass123", role="USER" ) - def _create_park(self, status='OPERATING', **kwargs): + def _create_park(self, status="OPERATING", **kwargs): from apps.parks.models import Company, Park - operator = Company.objects.create( - name=f'Operator Temp {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + operator = Company.objects.create(name=f"Operator Temp {timezone.now().timestamp()}", roles=["OPERATOR"]) defaults = { - 'name': f'Test Park Temp {timezone.now().timestamp()}', - 'slug': f'test-park-temp-{timezone.now().timestamp()}', - 'operator': operator, - 'status': status, - 'timezone': 'America/New_York' + "name": f"Test Park Temp {timezone.now().timestamp()}", + "slug": f"test-park-temp-{timezone.now().timestamp()}", + "operator": operator, + "status": status, + "timezone": "America/New_York", } defaults.update(kwargs) return Park.objects.create(**defaults) @@ -107,23 +92,23 @@ class ParkTemporaryClosureWorkflowTests(TestCase): Flow: OPERATING → CLOSED_TEMP → OPERATING """ - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") - self.assertEqual(park.status, 'OPERATING') + self.assertEqual(park.status, "OPERATING") # Close temporarily (e.g., off-season) park.transition_to_closed_temp(user=self.user) park.save() park.refresh_from_db() - self.assertEqual(park.status, 'CLOSED_TEMP') + self.assertEqual(park.status, "CLOSED_TEMP") # Reopen park.transition_to_operating(user=self.user) park.save() park.refresh_from_db() - self.assertEqual(park.status, 'OPERATING') + self.assertEqual(park.status, "OPERATING") class ParkPermanentClosureWorkflowTests(TestCase): @@ -132,26 +117,20 @@ class ParkPermanentClosureWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.moderator = User.objects.create_user( - username='perm_mod', - email='perm_mod@example.com', - password='testpass123', - role='MODERATOR' + username="perm_mod", email="perm_mod@example.com", password="testpass123", role="MODERATOR" ) - def _create_park(self, status='OPERATING', **kwargs): + def _create_park(self, status="OPERATING", **kwargs): from apps.parks.models import Company, Park - operator = Company.objects.create( - name=f'Operator Perm {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + operator = Company.objects.create(name=f"Operator Perm {timezone.now().timestamp()}", roles=["OPERATOR"]) defaults = { - 'name': f'Test Park Perm {timezone.now().timestamp()}', - 'slug': f'test-park-perm-{timezone.now().timestamp()}', - 'operator': operator, - 'status': status, - 'timezone': 'America/New_York' + "name": f"Test Park Perm {timezone.now().timestamp()}", + "slug": f"test-park-perm-{timezone.now().timestamp()}", + "operator": operator, + "status": status, + "timezone": "America/New_York", } defaults.update(kwargs) return Park.objects.create(**defaults) @@ -162,7 +141,7 @@ class ParkPermanentClosureWorkflowTests(TestCase): Flow: OPERATING → CLOSED_PERM """ - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") # Close permanently park.transition_to_closed_perm(user=self.moderator) @@ -170,7 +149,7 @@ class ParkPermanentClosureWorkflowTests(TestCase): park.save() park.refresh_from_db() - self.assertEqual(park.status, 'CLOSED_PERM') + self.assertEqual(park.status, "CLOSED_PERM") self.assertIsNotNone(park.closing_date) def test_park_permanent_closure_from_temp(self): @@ -179,7 +158,7 @@ class ParkPermanentClosureWorkflowTests(TestCase): Flow: OPERATING → CLOSED_TEMP → CLOSED_PERM """ - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") # Temporary closure park.transition_to_closed_temp(user=self.moderator) @@ -191,7 +170,7 @@ class ParkPermanentClosureWorkflowTests(TestCase): park.save() park.refresh_from_db() - self.assertEqual(park.status, 'CLOSED_PERM') + self.assertEqual(park.status, "CLOSED_PERM") class ParkDemolitionWorkflowTests(TestCase): @@ -200,26 +179,20 @@ class ParkDemolitionWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.moderator = User.objects.create_user( - username='demo_mod', - email='demo_mod@example.com', - password='testpass123', - role='MODERATOR' + username="demo_mod", email="demo_mod@example.com", password="testpass123", role="MODERATOR" ) - def _create_park(self, status='CLOSED_PERM', **kwargs): + def _create_park(self, status="CLOSED_PERM", **kwargs): from apps.parks.models import Company, Park - operator = Company.objects.create( - name=f'Operator Demo {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + operator = Company.objects.create(name=f"Operator Demo {timezone.now().timestamp()}", roles=["OPERATOR"]) defaults = { - 'name': f'Test Park Demo {timezone.now().timestamp()}', - 'slug': f'test-park-demo-{timezone.now().timestamp()}', - 'operator': operator, - 'status': status, - 'timezone': 'America/New_York' + "name": f"Test Park Demo {timezone.now().timestamp()}", + "slug": f"test-park-demo-{timezone.now().timestamp()}", + "operator": operator, + "status": status, + "timezone": "America/New_York", } defaults.update(kwargs) return Park.objects.create(**defaults) @@ -230,20 +203,20 @@ class ParkDemolitionWorkflowTests(TestCase): Flow: OPERATING → CLOSED_PERM → DEMOLISHED """ - park = self._create_park(status='CLOSED_PERM') + park = self._create_park(status="CLOSED_PERM") # Demolish park.transition_to_demolished(user=self.moderator) park.save() park.refresh_from_db() - self.assertEqual(park.status, 'DEMOLISHED') + self.assertEqual(park.status, "DEMOLISHED") def test_demolished_is_final_state(self): """Test that demolished parks cannot transition further.""" from django_fsm import TransitionNotAllowed - park = self._create_park(status='DEMOLISHED') + park = self._create_park(status="DEMOLISHED") # Cannot transition from demolished with self.assertRaises(TransitionNotAllowed): @@ -256,26 +229,20 @@ class ParkRelocationWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.moderator = User.objects.create_user( - username='reloc_mod', - email='reloc_mod@example.com', - password='testpass123', - role='MODERATOR' + username="reloc_mod", email="reloc_mod@example.com", password="testpass123", role="MODERATOR" ) - def _create_park(self, status='CLOSED_PERM', **kwargs): + def _create_park(self, status="CLOSED_PERM", **kwargs): from apps.parks.models import Company, Park - operator = Company.objects.create( - name=f'Operator Reloc {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + operator = Company.objects.create(name=f"Operator Reloc {timezone.now().timestamp()}", roles=["OPERATOR"]) defaults = { - 'name': f'Test Park Reloc {timezone.now().timestamp()}', - 'slug': f'test-park-reloc-{timezone.now().timestamp()}', - 'operator': operator, - 'status': status, - 'timezone': 'America/New_York' + "name": f"Test Park Reloc {timezone.now().timestamp()}", + "slug": f"test-park-reloc-{timezone.now().timestamp()}", + "operator": operator, + "status": status, + "timezone": "America/New_York", } defaults.update(kwargs) return Park.objects.create(**defaults) @@ -286,20 +253,20 @@ class ParkRelocationWorkflowTests(TestCase): Flow: OPERATING → CLOSED_PERM → RELOCATED """ - park = self._create_park(status='CLOSED_PERM') + park = self._create_park(status="CLOSED_PERM") # Relocate park.transition_to_relocated(user=self.moderator) park.save() park.refresh_from_db() - self.assertEqual(park.status, 'RELOCATED') + self.assertEqual(park.status, "RELOCATED") def test_relocated_is_final_state(self): """Test that relocated parks cannot transition further.""" from django_fsm import TransitionNotAllowed - park = self._create_park(status='RELOCATED') + park = self._create_park(status="RELOCATED") # Cannot transition from relocated with self.assertRaises(TransitionNotAllowed): @@ -312,71 +279,62 @@ class ParkWrapperMethodTests(TestCase): @classmethod def setUpTestData(cls): cls.user = User.objects.create_user( - username='wrapper_user', - email='wrapper@example.com', - password='testpass123', - role='USER' + username="wrapper_user", email="wrapper@example.com", password="testpass123", role="USER" ) cls.moderator = User.objects.create_user( - username='wrapper_mod', - email='wrapper_mod@example.com', - password='testpass123', - role='MODERATOR' + username="wrapper_mod", email="wrapper_mod@example.com", password="testpass123", role="MODERATOR" ) - def _create_park(self, status='OPERATING', **kwargs): + def _create_park(self, status="OPERATING", **kwargs): from apps.parks.models import Company, Park - operator = Company.objects.create( - name=f'Operator Wrapper {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + operator = Company.objects.create(name=f"Operator Wrapper {timezone.now().timestamp()}", roles=["OPERATOR"]) defaults = { - 'name': f'Test Park Wrapper {timezone.now().timestamp()}', - 'slug': f'test-park-wrapper-{timezone.now().timestamp()}', - 'operator': operator, - 'status': status, - 'timezone': 'America/New_York' + "name": f"Test Park Wrapper {timezone.now().timestamp()}", + "slug": f"test-park-wrapper-{timezone.now().timestamp()}", + "operator": operator, + "status": status, + "timezone": "America/New_York", } defaults.update(kwargs) return Park.objects.create(**defaults) def test_close_temporarily_wrapper(self): """Test close_temporarily wrapper method.""" - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") # Use wrapper method if it exists - if hasattr(park, 'close_temporarily'): + if hasattr(park, "close_temporarily"): park.close_temporarily(user=self.user) else: park.transition_to_closed_temp(user=self.user) park.save() park.refresh_from_db() - self.assertEqual(park.status, 'CLOSED_TEMP') + self.assertEqual(park.status, "CLOSED_TEMP") def test_reopen_wrapper(self): """Test reopen wrapper method.""" - park = self._create_park(status='CLOSED_TEMP') + park = self._create_park(status="CLOSED_TEMP") # Use wrapper method if it exists - if hasattr(park, 'reopen'): + if hasattr(park, "reopen"): park.reopen(user=self.user) else: park.transition_to_operating(user=self.user) park.save() park.refresh_from_db() - self.assertEqual(park.status, 'OPERATING') + self.assertEqual(park.status, "OPERATING") def test_close_permanently_wrapper(self): """Test close_permanently wrapper method.""" - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") closing_date = timezone.now().date() # Use wrapper method if it exists - if hasattr(park, 'close_permanently'): + if hasattr(park, "close_permanently"): park.close_permanently(closing_date=closing_date, user=self.moderator) else: park.transition_to_closed_perm(user=self.moderator) @@ -384,35 +342,35 @@ class ParkWrapperMethodTests(TestCase): park.save() park.refresh_from_db() - self.assertEqual(park.status, 'CLOSED_PERM') + self.assertEqual(park.status, "CLOSED_PERM") def test_demolish_wrapper(self): """Test demolish wrapper method.""" - park = self._create_park(status='CLOSED_PERM') + park = self._create_park(status="CLOSED_PERM") # Use wrapper method if it exists - if hasattr(park, 'demolish'): + if hasattr(park, "demolish"): park.demolish(user=self.moderator) else: park.transition_to_demolished(user=self.moderator) park.save() park.refresh_from_db() - self.assertEqual(park.status, 'DEMOLISHED') + self.assertEqual(park.status, "DEMOLISHED") def test_relocate_wrapper(self): """Test relocate wrapper method.""" - park = self._create_park(status='CLOSED_PERM') + park = self._create_park(status="CLOSED_PERM") # Use wrapper method if it exists - if hasattr(park, 'relocate'): + if hasattr(park, "relocate"): park.relocate(user=self.moderator) else: park.transition_to_relocated(user=self.moderator) park.save() park.refresh_from_db() - self.assertEqual(park.status, 'RELOCATED') + self.assertEqual(park.status, "RELOCATED") class ParkStateLogTests(TestCase): @@ -421,32 +379,23 @@ class ParkStateLogTests(TestCase): @classmethod def setUpTestData(cls): cls.user = User.objects.create_user( - username='log_user', - email='log_user@example.com', - password='testpass123', - role='USER' + username="log_user", email="log_user@example.com", password="testpass123", role="USER" ) cls.moderator = User.objects.create_user( - username='log_mod', - email='log_mod@example.com', - password='testpass123', - role='MODERATOR' + username="log_mod", email="log_mod@example.com", password="testpass123", role="MODERATOR" ) - def _create_park(self, status='OPERATING', **kwargs): + def _create_park(self, status="OPERATING", **kwargs): from apps.parks.models import Company, Park - operator = Company.objects.create( - name=f'Operator Log {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + operator = Company.objects.create(name=f"Operator Log {timezone.now().timestamp()}", roles=["OPERATOR"]) defaults = { - 'name': f'Test Park Log {timezone.now().timestamp()}', - 'slug': f'test-park-log-{timezone.now().timestamp()}', - 'operator': operator, - 'status': status, - 'timezone': 'America/New_York' + "name": f"Test Park Log {timezone.now().timestamp()}", + "slug": f"test-park-log-{timezone.now().timestamp()}", + "operator": operator, + "status": status, + "timezone": "America/New_York", } defaults.update(kwargs) return Park.objects.create(**defaults) @@ -456,7 +405,7 @@ class ParkStateLogTests(TestCase): from django.contrib.contenttypes.models import ContentType from django_fsm_log.models import StateLog - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") park_ct = ContentType.objects.get_for_model(park) # Perform transition @@ -464,13 +413,10 @@ class ParkStateLogTests(TestCase): park.save() # Check log was created - log = StateLog.objects.filter( - content_type=park_ct, - object_id=park.id - ).first() + log = StateLog.objects.filter(content_type=park_ct, object_id=park.id).first() self.assertIsNotNone(log, "StateLog entry should be created") - self.assertEqual(log.state, 'CLOSED_TEMP') + self.assertEqual(log.state, "CLOSED_TEMP") self.assertEqual(log.by, self.user) def test_multiple_transitions_logged(self): @@ -478,7 +424,7 @@ class ParkStateLogTests(TestCase): from django.contrib.contenttypes.models import ContentType from django_fsm_log.models import StateLog - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") park_ct = ContentType.objects.get_for_model(park) # First transition: OPERATING -> CLOSED_TEMP @@ -490,15 +436,12 @@ class ParkStateLogTests(TestCase): park.save() # Check multiple logs created - logs = StateLog.objects.filter( - content_type=park_ct, - object_id=park.id - ).order_by('timestamp') + logs = StateLog.objects.filter(content_type=park_ct, object_id=park.id).order_by("timestamp") self.assertEqual(logs.count(), 2, "Should have 2 log entries") - self.assertEqual(logs[0].state, 'CLOSED_TEMP') + self.assertEqual(logs[0].state, "CLOSED_TEMP") self.assertEqual(logs[0].by, self.user) - self.assertEqual(logs[1].state, 'CLOSED_PERM') + self.assertEqual(logs[1].state, "CLOSED_PERM") self.assertEqual(logs[1].by, self.moderator) def test_full_lifecycle_logged(self): @@ -506,7 +449,7 @@ class ParkStateLogTests(TestCase): from django.contrib.contenttypes.models import ContentType from django_fsm_log.models import StateLog - park = self._create_park(status='OPERATING') + park = self._create_park(status="OPERATING") park_ct = ContentType.objects.get_for_model(park) # Full lifecycle: OPERATING -> CLOSED_TEMP -> OPERATING -> CLOSED_PERM -> DEMOLISHED @@ -523,11 +466,8 @@ class ParkStateLogTests(TestCase): park.save() # Check all logs created - logs = StateLog.objects.filter( - content_type=park_ct, - object_id=park.id - ).order_by('timestamp') + logs = StateLog.objects.filter(content_type=park_ct, object_id=park.id).order_by("timestamp") self.assertEqual(logs.count(), 4, "Should have 4 log entries") states = [log.state for log in logs] - self.assertEqual(states, ['CLOSED_TEMP', 'OPERATING', 'CLOSED_PERM', 'DEMOLISHED']) + self.assertEqual(states, ["CLOSED_TEMP", "OPERATING", "CLOSED_PERM", "DEMOLISHED"]) diff --git a/backend/apps/parks/tests/test_query_optimization.py b/backend/apps/parks/tests/test_query_optimization.py index 28c7e3ae..a895efee 100644 --- a/backend/apps/parks/tests/test_query_optimization.py +++ b/backend/apps/parks/tests/test_query_optimization.py @@ -55,9 +55,7 @@ class ParkQueryOptimizationTests(TestCase): # Should be a small number of queries (main query + prefetch) # The exact count depends on prefetch_related configuration self.assertLessEqual( - len(context.captured_queries), - 5, - f"Expected <= 5 queries, got {len(context.captured_queries)}" + len(context.captured_queries), 5, f"Expected <= 5 queries, got {len(context.captured_queries)}" ) def test_optimized_for_detail_query_count(self): @@ -72,9 +70,7 @@ class ParkQueryOptimizationTests(TestCase): # Should be a reasonable number of queries self.assertLessEqual( - len(context.captured_queries), - 10, - f"Expected <= 10 queries, got {len(context.captured_queries)}" + len(context.captured_queries), 10, f"Expected <= 10 queries, got {len(context.captured_queries)}" ) def test_with_location_includes_location(self): @@ -94,10 +90,10 @@ class ParkQueryOptimizationTests(TestCase): if result.exists(): first = result.first() # Should include these fields - self.assertIn('id', first) - self.assertIn('name', first) - self.assertIn('slug', first) - self.assertIn('status', first) + self.assertIn("id", first) + self.assertIn("name", first) + self.assertIn("slug", first) + self.assertIn("status", first) def test_search_autocomplete_limits_results(self): """Verify search_autocomplete respects limit parameter.""" @@ -148,7 +144,7 @@ class CompanyQueryOptimizationTests(TestCase): if result.exists(): first = result.first() # Should have ride_count attribute - self.assertTrue(hasattr(first, 'ride_count')) + self.assertTrue(hasattr(first, "ride_count")) def test_operators_with_park_count_includes_annotation(self): """Verify operators_with_park_count adds park count annotations.""" @@ -156,7 +152,7 @@ class CompanyQueryOptimizationTests(TestCase): if result.exists(): first = result.first() # Should have operated_parks_count attribute - self.assertTrue(hasattr(first, 'operated_parks_count')) + self.assertTrue(hasattr(first, "operated_parks_count")) class ComputedFieldMaintenanceTests(TestCase): diff --git a/backend/apps/parks/tests_disabled/test_filters.py b/backend/apps/parks/tests_disabled/test_filters.py index 02f53be3..672437b3 100644 --- a/backend/apps/parks/tests_disabled/test_filters.py +++ b/backend/apps/parks/tests_disabled/test_filters.py @@ -19,12 +19,8 @@ class ParkFilterTests(TestCase): def setUpTestData(cls): """Set up test data for all filter tests""" # Create operators - cls.operator1 = Company.objects.create( - name="Thrilling Adventures Inc", slug="thrilling-adventures" - ) - cls.operator2 = Company.objects.create( - name="Family Fun Corp", slug="family-fun" - ) + cls.operator1 = Company.objects.create(name="Thrilling Adventures Inc", slug="thrilling-adventures") + cls.operator2 = Company.objects.create(name="Family Fun Corp", slug="family-fun") # Create parks with various attributes for testing all filters cls.park1 = Park.objects.create( diff --git a/backend/apps/parks/tests_disabled/test_models.py b/backend/apps/parks/tests_disabled/test_models.py index cffc0740..fb16c4e3 100644 --- a/backend/apps/parks/tests_disabled/test_models.py +++ b/backend/apps/parks/tests_disabled/test_models.py @@ -89,9 +89,7 @@ class ParkModelTests(TestCase): # Check pghistory records event_model = getattr(Park, "event_model", None) if event_model: - historical_records = event_model.objects.filter( - pgh_obj_id=park.id - ).order_by("-pgh_created_at") + historical_records = event_model.objects.filter(pgh_obj_id=park.id).order_by("-pgh_created_at") print("\nPG History records:") for record in historical_records: print(f"- Event ID: {record.pgh_id}") @@ -104,17 +102,13 @@ class ParkModelTests(TestCase): # Try to find by old slug found_park, is_historical = Park.get_by_slug(original_slug) self.assertEqual(found_park.id, park.id) - print( - f"Found park by old slug: {found_park.slug}, is_historical: {is_historical}" - ) + print(f"Found park by old slug: {found_park.slug}, is_historical: {is_historical}") self.assertTrue(is_historical) # Try current slug found_park, is_historical = Park.get_by_slug(new_slug) self.assertEqual(found_park.id, park.id) - print( - f"Found park by new slug: {found_park.slug}, is_historical: {is_historical}" - ) + print(f"Found park by new slug: {found_park.slug}, is_historical: {is_historical}") self.assertFalse(is_historical) def test_status_color_mapping(self): @@ -141,15 +135,9 @@ class ParkModelTests(TestCase): class ParkAreaModelTests(TestCase): def setUp(self): """Set up test data""" - self.operator = Company.objects.create( - name="Test Company 2", slug="test-company-2" - ) - self.park = Park.objects.create( - name="Test Park", status="OPERATING", operator=self.operator - ) - self.area = ParkArea.objects.create( - park=self.park, name="Test Area", description="A test area" - ) + self.operator = Company.objects.create(name="Test Company 2", slug="test-company-2") + self.park = Park.objects.create(name="Test Park", status="OPERATING", operator=self.operator) + self.area = ParkArea.objects.create(park=self.park, name="Test Area", description="A test area") def test_area_creation(self): """Test basic area creation and fields""" diff --git a/backend/apps/parks/views.py b/backend/apps/parks/views.py index b92cc612..a36820f0 100644 --- a/backend/apps/parks/views.py +++ b/backend/apps/parks/views.py @@ -42,9 +42,7 @@ logger = logging.getLogger(__name__) # Constants PARK_DETAIL_URL = "parks:park_detail" PARK_LIST_ITEM_TEMPLATE = "parks/partials/park_list_item.html" -REQUIRED_FIELDS_ERROR = ( - "Please correct the errors below. Required fields are marked with an asterisk (*)." -) +REQUIRED_FIELDS_ERROR = "Please correct the errors below. Required fields are marked with an asterisk (*)." TRIP_PARKS_TEMPLATE = "parks/partials/trip_parks_list.html" TRIP_SUMMARY_TEMPLATE = "parks/partials/trip_summary.html" SAVED_TRIPS_TEMPLATE = "parks/partials/saved_trips.html" @@ -87,18 +85,10 @@ def normalize_osm_result(result: dict) -> dict: neighborhood = address.get("neighbourhood", "") # Build city from available components - city = ( - address.get("city") - or address.get("town") - or address.get("village") - or address.get("municipality") - or "" - ) + city = address.get("city") or address.get("town") or address.get("village") or address.get("municipality") or "" # Get detailed state/region information - state = ( - address.get("state") or address.get("province") or address.get("region") or "" - ) + state = address.get("state") or address.get("province") or address.get("region") or "" # Get postal code with fallbacks postal_code = address.get("postcode") or address.get("postal_code") or "" @@ -170,9 +160,7 @@ def get_park_areas(request: HttpRequest) -> HttpResponse: park = Park.objects.get(id=park_id) areas = park.areas.all() options = [''] - options.extend( - [f'' for area in areas] - ) + options.extend([f'' for area in areas]) return HttpResponse("\n".join(options)) except Park.DoesNotExist: return HttpResponse('') @@ -201,11 +189,7 @@ def location_search(request: HttpRequest) -> JsonResponse: if response.status_code == 200: results = response.json() normalized_results = [normalize_osm_result(result) for result in results] - valid_results = [ - r - for r in normalized_results - if r["lat"] is not None and r["lon"] is not None - ] + valid_results = [r for r in normalized_results if r["lat"] is not None and r["lon"] is not None] return JsonResponse({"results": valid_results}) return JsonResponse({"results": []}) @@ -226,13 +210,9 @@ def reverse_geocode(request: HttpRequest) -> JsonResponse: lon = lon.quantize(Decimal("0.000001"), rounding=ROUND_DOWN) if lat < -90 or lat > 90: - return JsonResponse( - {"error": "Latitude must be between -90 and 90"}, status=400 - ) + return JsonResponse({"error": "Latitude must be between -90 and 90"}, status=400) if lon < -180 or lon > 180: - return JsonResponse( - {"error": "Longitude must be between -180 and 180"}, status=400 - ) + return JsonResponse({"error": "Longitude must be between -180 and 180"}, status=400) response = requests.get( "https://nominatim.openstreetmap.org/reverse", @@ -306,9 +286,7 @@ class ParkListView(HTMXFilterableMixin, ListView): try: # Initialize filterset if not exists if not hasattr(self, "filterset"): - self.filterset = self.filter_class( - self.request.GET, queryset=self.model.objects.none() - ) + self.filterset = self.filter_class(self.request.GET, queryset=self.model.objects.none()) context = super().get_context_data(**kwargs) @@ -323,20 +301,14 @@ class ParkListView(HTMXFilterableMixin, ListView): "search_query": self.request.GET.get("search", ""), "filter_counts": filter_counts, "popular_filters": popular_filters, - "total_results": ( - context.get("paginator").count - if context.get("paginator") - else 0 - ), + "total_results": (context.get("paginator").count if context.get("paginator") else 0), } ) # Add filter suggestions for search queries search_query = self.request.GET.get("search", "") if search_query: - context["filter_suggestions"] = ( - self.filter_service.get_filter_suggestions(search_query) - ) + context["filter_suggestions"] = self.filter_service.get_filter_suggestions(search_query) return context @@ -353,9 +325,7 @@ class ParkListView(HTMXFilterableMixin, ListView): messages.error(self.request, f"Error applying filters: {str(e)}") # Ensure filterset exists in error case if not hasattr(self, "filterset"): - self.filterset = self.filter_class( - self.request.GET, queryset=self.model.objects.none() - ) + self.filterset = self.filter_class(self.request.GET, queryset=self.model.objects.none()) return { "filter": self.filterset, "error": "Unable to apply filters. Please try adjusting your criteria.", @@ -427,9 +397,7 @@ class ParkListView(HTMXFilterableMixin, ListView): return urlencode(url_params) - def _get_pagination_urls( - self, page_obj, filter_params: dict[str, Any] - ) -> dict[str, str]: + def _get_pagination_urls(self, page_obj, filter_params: dict[str, Any]) -> dict[str, str]: """Generate pagination URLs that preserve filter state.""" base_query = self._build_filter_query_string(filter_params) @@ -476,9 +444,7 @@ def search_parks(request: HttpRequest) -> HttpResponse: # Get current view mode from request current_view_mode = request.GET.get("view_mode", "grid") - park_filter = ParkFilter( - {"search": search_query}, queryset=get_base_park_queryset() - ) + park_filter = ParkFilter({"search": search_query}, queryset=get_base_park_queryset()) parks = park_filter.qs if request.GET.get("quick_search"): @@ -747,10 +713,7 @@ def htmx_optimize_route(request: HttpRequest) -> HttpResponse: rlat1, rlon1, rlat2, rlon2 = map(math.radians, [lat1, lon1, lat2, lon2]) dlat = rlat2 - rlat1 dlon = rlon2 - rlon1 - a = ( - math.sin(dlat / 2) ** 2 - + math.cos(rlat1) * math.cos(rlat2) * math.sin(dlon / 2) ** 2 - ) + a = math.sin(dlat / 2) ** 2 + math.cos(rlat1) * math.cos(rlat2) * math.sin(dlon / 2) ** 2 c = 2 * math.asin(min(1, math.sqrt(a))) miles = 3958.8 * c return miles @@ -762,18 +725,14 @@ def htmx_optimize_route(request: HttpRequest) -> HttpResponse: lat = getattr(loc, "latitude", None) if loc else None lon = getattr(loc, "longitude", None) if loc else None if lat is not None and lon is not None: - waypoints.append( - {"id": p.id, "name": p.name, "latitude": lat, "longitude": lon} - ) + waypoints.append({"id": p.id, "name": p.name, "latitude": lat, "longitude": lon}) # sum straight-line distances between consecutive waypoints for i in range(len(waypoints) - 1): a = waypoints[i] b = waypoints[i + 1] try: - total_miles += haversine_miles( - a["latitude"], a["longitude"], b["latitude"], b["longitude"] - ) + total_miles += haversine_miles(a["latitude"], a["longitude"], b["latitude"], b["longitude"]) except Exception as e: log_exception( logger, @@ -807,9 +766,7 @@ def htmx_optimize_route(request: HttpRequest) -> HttpResponse: "total_rides": sum(getattr(p, "ride_count", 0) or 0 for p in parks), } - html = render_to_string( - TRIP_SUMMARY_TEMPLATE, {"summary": summary}, request=request - ) + html = render_to_string(TRIP_SUMMARY_TEMPLATE, {"summary": summary}, request=request) resp = HttpResponse(html) # Include waypoints payload in HX-Trigger so client can render route on the map resp["HX-Trigger"] = json.dumps({"tripOptimized": {"parks": waypoints}}) @@ -843,9 +800,7 @@ def htmx_save_trip(request: HttpRequest) -> HttpResponse: # attempt to associate parks if the Trip model supports it with contextlib.suppress(Exception): trip.parks.set([p.id for p in parks]) - trips = list( - Trip.objects.filter(owner=request.user).order_by("-created_at")[:10] - ) + trips = list(Trip.objects.filter(owner=request.user).order_by("-created_at")[:10]) except Exception: trips = [] @@ -892,14 +847,10 @@ class ParkCreateView(LoginRequiredMixin, CreateView): def normalize_coordinates(self, form: ParkForm) -> None: if form.cleaned_data.get("latitude"): lat = Decimal(str(form.cleaned_data["latitude"])) - form.cleaned_data["latitude"] = lat.quantize( - Decimal("0.000001"), rounding=ROUND_DOWN - ) + form.cleaned_data["latitude"] = lat.quantize(Decimal("0.000001"), rounding=ROUND_DOWN) if form.cleaned_data.get("longitude"): lon = Decimal(str(form.cleaned_data["longitude"])) - form.cleaned_data["longitude"] = lon.quantize( - Decimal("0.000001"), rounding=ROUND_DOWN - ) + form.cleaned_data["longitude"] = lon.quantize(Decimal("0.000001"), rounding=ROUND_DOWN) def form_valid(self, form: ParkForm) -> HttpResponse: self.normalize_coordinates(form) @@ -942,8 +893,7 @@ class ParkCreateView(LoginRequiredMixin, CreateView): ) messages.success( self.request, - f"Successfully created {self.object.name}. " - f"Added {service_result['uploaded_count']} photo(s).", + f"Successfully created {self.object.name}. " f"Added {service_result['uploaded_count']} photo(s).", ) return HttpResponseRedirect(self.get_success_url()) @@ -960,8 +910,7 @@ class ParkCreateView(LoginRequiredMixin, CreateView): ) messages.success( self.request, - "Your park submission has been sent for review. " - "You will be notified when it is approved.", + "Your park submission has been sent for review. " "You will be notified when it is approved.", ) return HttpResponseRedirect(reverse("parks:park_list")) @@ -1016,14 +965,10 @@ class ParkUpdateView(LoginRequiredMixin, UpdateView): def normalize_coordinates(self, form: ParkForm) -> None: if form.cleaned_data.get("latitude"): lat = Decimal(str(form.cleaned_data["latitude"])) - form.cleaned_data["latitude"] = lat.quantize( - Decimal("0.000001"), rounding=ROUND_DOWN - ) + form.cleaned_data["latitude"] = lat.quantize(Decimal("0.000001"), rounding=ROUND_DOWN) if form.cleaned_data.get("longitude"): lon = Decimal(str(form.cleaned_data["longitude"])) - form.cleaned_data["longitude"] = lon.quantize( - Decimal("0.000001"), rounding=ROUND_DOWN - ) + form.cleaned_data["longitude"] = lon.quantize(Decimal("0.000001"), rounding=ROUND_DOWN) def form_valid(self, form: ParkForm) -> HttpResponse: self.normalize_coordinates(form) @@ -1068,8 +1013,7 @@ class ParkUpdateView(LoginRequiredMixin, UpdateView): ) messages.success( self.request, - f"Successfully updated {self.object.name}. " - f"Added {service_result['uploaded_count']} new photo(s).", + f"Successfully updated {self.object.name}. " f"Added {service_result['uploaded_count']} new photo(s).", ) return HttpResponseRedirect(self.get_success_url()) @@ -1090,9 +1034,7 @@ class ParkUpdateView(LoginRequiredMixin, UpdateView): f"Your changes to {self.object.name} have been sent for review. " "You will be notified when they are approved.", ) - return HttpResponseRedirect( - reverse(PARK_DETAIL_URL, kwargs={"slug": self.object.slug}) - ) + return HttpResponseRedirect(reverse(PARK_DETAIL_URL, kwargs={"slug": self.object.slug})) elif service_result["status"] == "failed": messages.error( @@ -1143,11 +1085,7 @@ class ParkDetailView( def get_queryset(self) -> QuerySet[Park]: return cast( QuerySet[Park], - super() - .get_queryset() - .prefetch_related( - "rides", "rides__manufacturer", "photos", "areas", "location" - ), + super().get_queryset().prefetch_related("rides", "rides__manufacturer", "photos", "areas", "location"), ) def get_context_data(self, **kwargs: Any) -> dict[str, Any]: diff --git a/backend/apps/parks/views_roadtrip.py b/backend/apps/parks/views_roadtrip.py index ff977ae1..b7696b51 100644 --- a/backend/apps/parks/views_roadtrip.py +++ b/backend/apps/parks/views_roadtrip.py @@ -119,9 +119,7 @@ class CreateTripView(RoadTripViewMixin, View): # Get parks parks = list( - Park.objects.filter( - id__in=park_ids, location__isnull=False - ).select_related("location", "operator") + Park.objects.filter(id__in=park_ids, location__isnull=False).select_related("location", "operator") ) if len(parks) != len(park_ids): @@ -159,9 +157,7 @@ class CreateTripView(RoadTripViewMixin, View): { "status": "success", "data": trip_data, - "trip_url": reverse( - "parks:roadtrip_detail", kwargs={"trip_id": "temp"} - ), + "trip_url": reverse("parks:roadtrip_detail", kwargs={"trip_id": "temp"}), } ) @@ -258,12 +254,8 @@ class FindParksAlongRouteView(RoadTripViewMixin, View): # Get start and end parks try: - start_park = Park.objects.select_related("location").get( - id=start_park_id, location__isnull=False - ) - end_park = Park.objects.select_related("location").get( - id=end_park_id, location__isnull=False - ) + start_park = Park.objects.select_related("location").get(id=start_park_id, location__isnull=False) + end_park = Park.objects.select_related("location").get(id=end_park_id, location__isnull=False) except Park.DoesNotExist: return render( request, @@ -272,21 +264,21 @@ class FindParksAlongRouteView(RoadTripViewMixin, View): ) # Find parks along route - parks_along_route = self.roadtrip_service.find_parks_along_route( - start_park, end_park, max_detour_km - ) + parks_along_route = self.roadtrip_service.find_parks_along_route(start_park, end_park, max_detour_km) # Return JSON if requested if request.headers.get("Accept") == "application/json" or request.content_type == "application/json": - return JsonResponse({ - "status": "success", - "data": { - "parks": [self._park_to_dict(p) for p in parks_along_route], - "start_park": self._park_to_dict(start_park), - "end_park": self._park_to_dict(end_park), - "count": len(parks_along_route) + return JsonResponse( + { + "status": "success", + "data": { + "parks": [self._park_to_dict(p) for p in parks_along_route], + "start_park": self._park_to_dict(start_park), + "end_park": self._park_to_dict(end_park), + "count": len(parks_along_route), + }, } - }) + ) return render( request, @@ -375,9 +367,7 @@ class GeocodeAddressView(RoadTripViewMixin, View): "longitude": coordinates.longitude, }, "address": address, - "nearby_parks": [ - loc.to_dict() for loc in map_response.locations[:20] - ], + "nearby_parks": [loc.to_dict() for loc in map_response.locations[:20]], "radius_km": radius_km, }, } @@ -418,12 +408,8 @@ class ParkDistanceCalculatorView(RoadTripViewMixin, View): # Get parks try: - park1 = Park.objects.select_related("location").get( - id=park1_id, location__isnull=False - ) - park2 = Park.objects.select_related("location").get( - id=park2_id, location__isnull=False - ) + park1 = Park.objects.select_related("location").get(id=park1_id, location__isnull=False) + park2 = Park.objects.select_related("location").get(id=park2_id, location__isnull=False) except Park.DoesNotExist: return JsonResponse( { @@ -448,9 +434,7 @@ class ParkDistanceCalculatorView(RoadTripViewMixin, View): from services.roadtrip import Coordinates - route = self.roadtrip_service.calculate_route( - Coordinates(*coords1), Coordinates(*coords2) - ) + route = self.roadtrip_service.calculate_route(Coordinates(*coords1), Coordinates(*coords2)) if not route: return JsonResponse( @@ -471,15 +455,11 @@ class ParkDistanceCalculatorView(RoadTripViewMixin, View): "formatted_duration": route.formatted_duration, "park1": { "name": park1.name, - "formatted_location": getattr( - park1, "formatted_location", "" - ), + "formatted_location": getattr(park1, "formatted_location", ""), }, "park2": { "name": park2.name, - "formatted_location": getattr( - park2, "formatted_location", "" - ), + "formatted_location": getattr(park2, "formatted_location", ""), }, }, } diff --git a/backend/apps/reviews/models.py b/backend/apps/reviews/models.py index 5ef49477..65fcdb79 100644 --- a/backend/apps/reviews/models.py +++ b/backend/apps/reviews/models.py @@ -34,14 +34,8 @@ class Review(TrackedModel): text = models.TextField(blank=True, help_text="Review text (optional)") # Metadata - is_public = models.BooleanField( - default=True, - help_text="Whether this review is visible to others" - ) - helpful_votes = models.PositiveIntegerField( - default=0, - help_text="Number of users who found this helpful" - ) + is_public = models.BooleanField(default=True, help_text="Whether this review is visible to others") + helpful_votes = models.PositiveIntegerField(default=0, help_text="Number of users who found this helpful") class Meta(TrackedModel.Meta): verbose_name = "Review" diff --git a/backend/apps/reviews/signals.py b/backend/apps/reviews/signals.py index dfa7f510..5ba630d8 100644 --- a/backend/apps/reviews/signals.py +++ b/backend/apps/reviews/signals.py @@ -17,16 +17,15 @@ def update_average_rating(sender, instance, **kwargs): return # Check if the content object has an 'average_rating' field - if not hasattr(content_object, 'average_rating'): + if not hasattr(content_object, "average_rating"): return # Calculate new average # We query the Review model filtering by content_type and object_id - avg_rating = Review.objects.filter( - content_type=instance.content_type, - object_id=instance.object_id - ).aggregate(Avg('rating'))['rating__avg'] + avg_rating = Review.objects.filter(content_type=instance.content_type, object_id=instance.object_id).aggregate( + Avg("rating") + )["rating__avg"] # Update field content_object.average_rating = avg_rating or 0 # Default to 0 if no reviews - content_object.save(update_fields=['average_rating']) + content_object.save(update_fields=["average_rating"]) diff --git a/backend/apps/rides/__init__.py b/backend/apps/rides/__init__.py index 00ff67d8..de369588 100644 --- a/backend/apps/rides/__init__.py +++ b/backend/apps/rides/__init__.py @@ -9,4 +9,4 @@ companies, rankings, and search functionality. from . import choices # Ensure choices are registered on app startup -__all__ = ['choices'] +__all__ = ["choices"] diff --git a/backend/apps/rides/admin.py b/backend/apps/rides/admin.py index 367c6db6..e7211344 100644 --- a/backend/apps/rides/admin.py +++ b/backend/apps/rides/admin.py @@ -875,12 +875,8 @@ class RideReviewAdmin(QueryOptimizationMixin, ExportActionMixin, BaseModelAdmin) """Display moderation status with color coding.""" if obj.moderated_by: if obj.is_published: - return format_html( - 'Approved' - ) - return format_html( - 'Rejected' - ) + return format_html('Approved') + return format_html('Rejected') return format_html('Pending') def save_model(self, request, obj, form, change): @@ -987,9 +983,7 @@ class CompanyAdmin( ( "Company Details", { - "fields": ( - "founded_date", - ), + "fields": ("founded_date",), "classes": ("collapse",), "description": "Historical information about the company.", }, @@ -1024,7 +1018,7 @@ class CompanyAdmin( color = colors.get(role, "#6c757d") badges.append( f'{role}' ) return format_html("".join(badges)) diff --git a/backend/apps/rides/apps.py b/backend/apps/rides/apps.py index 0888a8b8..53c81a87 100644 --- a/backend/apps/rides/apps.py +++ b/backend/apps/rides/apps.py @@ -23,9 +23,7 @@ class RidesConfig(AppConfig): from apps.rides.models import Ride # Register FSM transitions for Ride - apply_state_machine( - Ride, field_name="status", choice_group="statuses", domain="rides" - ) + apply_state_machine(Ride, field_name="status", choice_group="statuses", domain="rides") def _register_callbacks(self): """Register FSM transition callbacks for ride models.""" @@ -41,43 +39,19 @@ class RidesConfig(AppConfig): from apps.rides.models import Ride # Cache invalidation for all ride status changes - register_callback( - Ride, 'status', '*', '*', - RideCacheInvalidation() - ) + register_callback(Ride, "status", "*", "*", RideCacheInvalidation()) # API cache invalidation - register_callback( - Ride, 'status', '*', '*', - APICacheInvalidation(include_geo_cache=True) - ) + register_callback(Ride, "status", "*", "*", APICacheInvalidation(include_geo_cache=True)) # Park count updates for status changes that affect active rides - register_callback( - Ride, 'status', '*', 'OPERATING', - ParkCountUpdateCallback() - ) - register_callback( - Ride, 'status', 'OPERATING', '*', - ParkCountUpdateCallback() - ) - register_callback( - Ride, 'status', '*', 'CLOSED_PERM', - ParkCountUpdateCallback() - ) - register_callback( - Ride, 'status', '*', 'DEMOLISHED', - ParkCountUpdateCallback() - ) - register_callback( - Ride, 'status', '*', 'RELOCATED', - ParkCountUpdateCallback() - ) + register_callback(Ride, "status", "*", "OPERATING", ParkCountUpdateCallback()) + register_callback(Ride, "status", "OPERATING", "*", ParkCountUpdateCallback()) + register_callback(Ride, "status", "*", "CLOSED_PERM", ParkCountUpdateCallback()) + register_callback(Ride, "status", "*", "DEMOLISHED", ParkCountUpdateCallback()) + register_callback(Ride, "status", "*", "RELOCATED", ParkCountUpdateCallback()) # Search text update - register_callback( - Ride, 'status', '*', '*', - SearchTextUpdateCallback() - ) + register_callback(Ride, "status", "*", "*", SearchTextUpdateCallback()) logger.debug("Registered ride transition callbacks") diff --git a/backend/apps/rides/choices.py b/backend/apps/rides/choices.py index 69a80308..62b4d15b 100644 --- a/backend/apps/rides/choices.py +++ b/backend/apps/rides/choices.py @@ -14,73 +14,48 @@ RIDE_CATEGORIES = [ value="RC", label="Roller Coaster", description="Thrill rides with tracks featuring hills, loops, and high speeds", - metadata={ - 'color': 'red', - 'icon': 'roller-coaster', - 'css_class': 'bg-red-100 text-red-800', - 'sort_order': 1 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "red", "icon": "roller-coaster", "css_class": "bg-red-100 text-red-800", "sort_order": 1}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="DR", label="Dark Ride", description="Indoor rides with themed environments and storytelling", metadata={ - 'color': 'purple', - 'icon': 'dark-ride', - 'css_class': 'bg-purple-100 text-purple-800', - 'sort_order': 2 + "color": "purple", + "icon": "dark-ride", + "css_class": "bg-purple-100 text-purple-800", + "sort_order": 2, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="FR", label="Flat Ride", description="Rides that move along a generally flat plane with spinning, swinging, or bouncing motions", - metadata={ - 'color': 'blue', - 'icon': 'flat-ride', - 'css_class': 'bg-blue-100 text-blue-800', - 'sort_order': 3 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "blue", "icon": "flat-ride", "css_class": "bg-blue-100 text-blue-800", "sort_order": 3}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="WR", label="Water Ride", description="Rides that incorporate water elements like splashing, floating, or getting wet", - metadata={ - 'color': 'cyan', - 'icon': 'water-ride', - 'css_class': 'bg-cyan-100 text-cyan-800', - 'sort_order': 4 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "cyan", "icon": "water-ride", "css_class": "bg-cyan-100 text-cyan-800", "sort_order": 4}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="TR", label="Transport Ride", description="Rides primarily designed for transportation around the park", - metadata={ - 'color': 'green', - 'icon': 'transport', - 'css_class': 'bg-green-100 text-green-800', - 'sort_order': 5 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "green", "icon": "transport", "css_class": "bg-green-100 text-green-800", "sort_order": 5}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="OT", label="Other", description="Rides that don't fit into standard categories", - metadata={ - 'color': 'gray', - 'icon': 'other', - 'css_class': 'bg-gray-100 text-gray-800', - 'sort_order': 6 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "gray", "icon": "other", "css_class": "bg-gray-100 text-gray-800", "sort_order": 6}, + category=ChoiceCategory.CLASSIFICATION, ), ] @@ -91,140 +66,140 @@ RIDE_STATUSES = [ label="Operating", description="Ride is currently open and operating normally", metadata={ - 'color': 'green', - 'icon': 'check-circle', - 'css_class': 'bg-green-100 text-green-800', - 'sort_order': 1, - 'can_transition_to': [ - 'CLOSED_TEMP', - 'SBNO', - 'CLOSING', + "color": "green", + "icon": "check-circle", + "css_class": "bg-green-100 text-green-800", + "sort_order": 1, + "can_transition_to": [ + "CLOSED_TEMP", + "SBNO", + "CLOSING", ], - 'requires_moderator': False, - 'is_final': False, - 'is_initial': True, + "requires_moderator": False, + "is_final": False, + "is_initial": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="CLOSED_TEMP", label="Temporarily Closed", description="Ride is temporarily closed for maintenance, weather, or other short-term reasons", metadata={ - 'color': 'yellow', - 'icon': 'pause-circle', - 'css_class': 'bg-yellow-100 text-yellow-800', - 'sort_order': 2, - 'can_transition_to': [ - 'SBNO', - 'CLOSING', + "color": "yellow", + "icon": "pause-circle", + "css_class": "bg-yellow-100 text-yellow-800", + "sort_order": 2, + "can_transition_to": [ + "SBNO", + "CLOSING", ], - 'requires_moderator': False, - 'is_final': False, + "requires_moderator": False, + "is_final": False, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="SBNO", label="Standing But Not Operating", description="Ride structure remains but is not currently operating", metadata={ - 'color': 'orange', - 'icon': 'stop-circle', - 'css_class': 'bg-orange-100 text-orange-800', - 'sort_order': 3, - 'can_transition_to': [ - 'CLOSED_PERM', - 'DEMOLISHED', - 'RELOCATED', + "color": "orange", + "icon": "stop-circle", + "css_class": "bg-orange-100 text-orange-800", + "sort_order": 3, + "can_transition_to": [ + "CLOSED_PERM", + "DEMOLISHED", + "RELOCATED", ], - 'requires_moderator': True, - 'is_final': False, + "requires_moderator": True, + "is_final": False, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="CLOSING", label="Closing", description="Ride is scheduled to close permanently", metadata={ - 'color': 'red', - 'icon': 'x-circle', - 'css_class': 'bg-red-100 text-red-800', - 'sort_order': 4, - 'can_transition_to': [ - 'CLOSED_PERM', - 'SBNO', + "color": "red", + "icon": "x-circle", + "css_class": "bg-red-100 text-red-800", + "sort_order": 4, + "can_transition_to": [ + "CLOSED_PERM", + "SBNO", ], - 'requires_moderator': True, - 'is_final': False, + "requires_moderator": True, + "is_final": False, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="CLOSED_PERM", label="Permanently Closed", description="Ride has been permanently closed and will not reopen", metadata={ - 'color': 'red', - 'icon': 'x-circle', - 'css_class': 'bg-red-100 text-red-800', - 'sort_order': 5, - 'can_transition_to': [ - 'DEMOLISHED', - 'RELOCATED', + "color": "red", + "icon": "x-circle", + "css_class": "bg-red-100 text-red-800", + "sort_order": 5, + "can_transition_to": [ + "DEMOLISHED", + "RELOCATED", ], - 'requires_moderator': True, - 'is_final': False, + "requires_moderator": True, + "is_final": False, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="UNDER_CONSTRUCTION", label="Under Construction", description="Ride is currently being built or undergoing major renovation", metadata={ - 'color': 'blue', - 'icon': 'tool', - 'css_class': 'bg-blue-100 text-blue-800', - 'sort_order': 6, - 'can_transition_to': [ - 'OPERATING', + "color": "blue", + "icon": "tool", + "css_class": "bg-blue-100 text-blue-800", + "sort_order": 6, + "can_transition_to": [ + "OPERATING", ], - 'requires_moderator': False, - 'is_final': False, + "requires_moderator": False, + "is_final": False, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="DEMOLISHED", label="Demolished", description="Ride has been completely removed and demolished", metadata={ - 'color': 'gray', - 'icon': 'trash', - 'css_class': 'bg-gray-100 text-gray-800', - 'sort_order': 7, - 'can_transition_to': [], - 'requires_moderator': True, - 'is_final': True, + "color": "gray", + "icon": "trash", + "css_class": "bg-gray-100 text-gray-800", + "sort_order": 7, + "can_transition_to": [], + "requires_moderator": True, + "is_final": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="RELOCATED", label="Relocated", description="Ride has been moved to a different location", metadata={ - 'color': 'purple', - 'icon': 'arrow-right', - 'css_class': 'bg-purple-100 text-purple-800', - 'sort_order': 8, - 'can_transition_to': [], - 'requires_moderator': True, - 'is_final': True, + "color": "purple", + "icon": "arrow-right", + "css_class": "bg-purple-100 text-purple-800", + "sort_order": 8, + "can_transition_to": [], + "requires_moderator": True, + "is_final": True, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), ] @@ -235,24 +210,19 @@ POST_CLOSING_STATUSES = [ label="Standing But Not Operating", description="Ride structure remains but is not operating after closure", metadata={ - 'color': 'orange', - 'icon': 'stop-circle', - 'css_class': 'bg-orange-100 text-orange-800', - 'sort_order': 1 + "color": "orange", + "icon": "stop-circle", + "css_class": "bg-orange-100 text-orange-800", + "sort_order": 1, }, - category=ChoiceCategory.STATUS + category=ChoiceCategory.STATUS, ), RichChoice( value="CLOSED_PERM", label="Permanently Closed", description="Ride has been permanently closed after the closing date", - metadata={ - 'color': 'red', - 'icon': 'x-circle', - 'css_class': 'bg-red-100 text-red-800', - 'sort_order': 2 - }, - category=ChoiceCategory.STATUS + metadata={"color": "red", "icon": "x-circle", "css_class": "bg-red-100 text-red-800", "sort_order": 2}, + category=ChoiceCategory.STATUS, ), ] @@ -262,37 +232,22 @@ TRACK_MATERIALS = [ value="STEEL", label="Steel", description="Modern steel track construction providing smooth rides and complex layouts", - metadata={ - 'color': 'gray', - 'icon': 'steel', - 'css_class': 'bg-gray-100 text-gray-800', - 'sort_order': 1 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "gray", "icon": "steel", "css_class": "bg-gray-100 text-gray-800", "sort_order": 1}, + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="WOOD", label="Wood", description="Traditional wooden track construction providing classic coaster experience", - metadata={ - 'color': 'amber', - 'icon': 'wood', - 'css_class': 'bg-amber-100 text-amber-800', - 'sort_order': 2 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "amber", "icon": "wood", "css_class": "bg-amber-100 text-amber-800", "sort_order": 2}, + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="HYBRID", label="Hybrid", description="Combination of steel and wooden construction elements", - metadata={ - 'color': 'orange', - 'icon': 'hybrid', - 'css_class': 'bg-orange-100 text-orange-800', - 'sort_order': 3 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "orange", "icon": "hybrid", "css_class": "bg-orange-100 text-orange-800", "sort_order": 3}, + category=ChoiceCategory.TECHNICAL, ), ] @@ -302,133 +257,83 @@ COASTER_TYPES = [ value="SITDOWN", label="Sit Down", description="Traditional seated roller coaster with riders sitting upright", - metadata={ - 'color': 'blue', - 'icon': 'sitdown', - 'css_class': 'bg-blue-100 text-blue-800', - 'sort_order': 1 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "blue", "icon": "sitdown", "css_class": "bg-blue-100 text-blue-800", "sort_order": 1}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="INVERTED", label="Inverted", description="Coaster where riders' feet dangle freely below the track", - metadata={ - 'color': 'purple', - 'icon': 'inverted', - 'css_class': 'bg-purple-100 text-purple-800', - 'sort_order': 2 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "purple", "icon": "inverted", "css_class": "bg-purple-100 text-purple-800", "sort_order": 2}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="FLYING", label="Flying", description="Riders lie face-down in a flying position", - metadata={ - 'color': 'sky', - 'icon': 'flying', - 'css_class': 'bg-sky-100 text-sky-800', - 'sort_order': 3 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "sky", "icon": "flying", "css_class": "bg-sky-100 text-sky-800", "sort_order": 3}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="STANDUP", label="Stand Up", description="Riders stand upright during the ride", - metadata={ - 'color': 'green', - 'icon': 'standup', - 'css_class': 'bg-green-100 text-green-800', - 'sort_order': 4 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "green", "icon": "standup", "css_class": "bg-green-100 text-green-800", "sort_order": 4}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="WING", label="Wing", description="Riders sit on either side of the track with nothing above or below", - metadata={ - 'color': 'indigo', - 'icon': 'wing', - 'css_class': 'bg-indigo-100 text-indigo-800', - 'sort_order': 5 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "indigo", "icon": "wing", "css_class": "bg-indigo-100 text-indigo-800", "sort_order": 5}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="DIVE", label="Dive", description="Features a vertical or near-vertical drop as the main element", - metadata={ - 'color': 'red', - 'icon': 'dive', - 'css_class': 'bg-red-100 text-red-800', - 'sort_order': 6 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "red", "icon": "dive", "css_class": "bg-red-100 text-red-800", "sort_order": 6}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="FAMILY", label="Family", description="Designed for riders of all ages with moderate thrills", metadata={ - 'color': 'emerald', - 'icon': 'family', - 'css_class': 'bg-emerald-100 text-emerald-800', - 'sort_order': 7 + "color": "emerald", + "icon": "family", + "css_class": "bg-emerald-100 text-emerald-800", + "sort_order": 7, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="WILD_MOUSE", label="Wild Mouse", description="Compact coaster with sharp turns and sudden drops", - metadata={ - 'color': 'yellow', - 'icon': 'mouse', - 'css_class': 'bg-yellow-100 text-yellow-800', - 'sort_order': 8 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "yellow", "icon": "mouse", "css_class": "bg-yellow-100 text-yellow-800", "sort_order": 8}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="SPINNING", label="Spinning", description="Cars rotate freely during the ride", - metadata={ - 'color': 'pink', - 'icon': 'spinning', - 'css_class': 'bg-pink-100 text-pink-800', - 'sort_order': 9 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "pink", "icon": "spinning", "css_class": "bg-pink-100 text-pink-800", "sort_order": 9}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="FOURTH_DIMENSION", label="4th Dimension", description="Seats rotate independently of the track direction", - metadata={ - 'color': 'violet', - 'icon': '4d', - 'css_class': 'bg-violet-100 text-violet-800', - 'sort_order': 10 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "violet", "icon": "4d", "css_class": "bg-violet-100 text-violet-800", "sort_order": 10}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="OTHER", label="Other", description="Coaster type that doesn't fit standard classifications", - metadata={ - 'color': 'gray', - 'icon': 'other', - 'css_class': 'bg-gray-100 text-gray-800', - 'sort_order': 11 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "gray", "icon": "other", "css_class": "bg-gray-100 text-gray-800", "sort_order": 11}, + category=ChoiceCategory.CLASSIFICATION, ), ] @@ -438,61 +343,36 @@ PROPULSION_SYSTEMS = [ value="CHAIN", label="Chain Lift", description="Traditional chain lift system to pull trains up the lift hill", - metadata={ - 'color': 'gray', - 'icon': 'chain', - 'css_class': 'bg-gray-100 text-gray-800', - 'sort_order': 1 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "gray", "icon": "chain", "css_class": "bg-gray-100 text-gray-800", "sort_order": 1}, + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="LSM", label="LSM Launch", description="Linear Synchronous Motor launch system using magnetic propulsion", - metadata={ - 'color': 'blue', - 'icon': 'lightning', - 'css_class': 'bg-blue-100 text-blue-800', - 'sort_order': 2 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "blue", "icon": "lightning", "css_class": "bg-blue-100 text-blue-800", "sort_order": 2}, + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="HYDRAULIC", label="Hydraulic Launch", description="High-pressure hydraulic launch system for rapid acceleration", - metadata={ - 'color': 'red', - 'icon': 'hydraulic', - 'css_class': 'bg-red-100 text-red-800', - 'sort_order': 3 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "red", "icon": "hydraulic", "css_class": "bg-red-100 text-red-800", "sort_order": 3}, + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="GRAVITY", label="Gravity", description="Uses gravity and momentum without mechanical lift systems", - metadata={ - 'color': 'green', - 'icon': 'gravity', - 'css_class': 'bg-green-100 text-green-800', - 'sort_order': 4 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "green", "icon": "gravity", "css_class": "bg-green-100 text-green-800", "sort_order": 4}, + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="OTHER", label="Other", description="Propulsion system that doesn't fit standard categories", - metadata={ - 'color': 'gray', - 'icon': 'other', - 'css_class': 'bg-gray-100 text-gray-800', - 'sort_order': 5 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "gray", "icon": "other", "css_class": "bg-gray-100 text-gray-800", "sort_order": 5}, + category=ChoiceCategory.TECHNICAL, ), ] @@ -502,61 +382,36 @@ TARGET_MARKETS = [ value="FAMILY", label="Family", description="Designed for families with children, moderate thrills", - metadata={ - 'color': 'green', - 'icon': 'family', - 'css_class': 'bg-green-100 text-green-800', - 'sort_order': 1 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "green", "icon": "family", "css_class": "bg-green-100 text-green-800", "sort_order": 1}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="THRILL", label="Thrill", description="High-intensity rides for thrill seekers", - metadata={ - 'color': 'red', - 'icon': 'thrill', - 'css_class': 'bg-red-100 text-red-800', - 'sort_order': 2 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "red", "icon": "thrill", "css_class": "bg-red-100 text-red-800", "sort_order": 2}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="EXTREME", label="Extreme", description="Maximum intensity rides for extreme thrill seekers", - metadata={ - 'color': 'purple', - 'icon': 'extreme', - 'css_class': 'bg-purple-100 text-purple-800', - 'sort_order': 3 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "purple", "icon": "extreme", "css_class": "bg-purple-100 text-purple-800", "sort_order": 3}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="KIDDIE", label="Kiddie", description="Gentle rides designed specifically for young children", - metadata={ - 'color': 'yellow', - 'icon': 'kiddie', - 'css_class': 'bg-yellow-100 text-yellow-800', - 'sort_order': 4 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "yellow", "icon": "kiddie", "css_class": "bg-yellow-100 text-yellow-800", "sort_order": 4}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="ALL_AGES", label="All Ages", description="Suitable for riders of all ages and thrill preferences", - metadata={ - 'color': 'blue', - 'icon': 'all-ages', - 'css_class': 'bg-blue-100 text-blue-800', - 'sort_order': 5 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "blue", "icon": "all-ages", "css_class": "bg-blue-100 text-blue-800", "sort_order": 5}, + category=ChoiceCategory.CLASSIFICATION, ), ] @@ -566,61 +421,41 @@ PHOTO_TYPES = [ value="PROMOTIONAL", label="Promotional", description="Marketing and promotional photos of the ride model", - metadata={ - 'color': 'blue', - 'icon': 'camera', - 'css_class': 'bg-blue-100 text-blue-800', - 'sort_order': 1 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "blue", "icon": "camera", "css_class": "bg-blue-100 text-blue-800", "sort_order": 1}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="TECHNICAL", label="Technical Drawing", description="Technical drawings and engineering diagrams", - metadata={ - 'color': 'gray', - 'icon': 'blueprint', - 'css_class': 'bg-gray-100 text-gray-800', - 'sort_order': 2 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "gray", "icon": "blueprint", "css_class": "bg-gray-100 text-gray-800", "sort_order": 2}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="INSTALLATION", label="Installation Example", description="Photos of actual installations of this ride model", metadata={ - 'color': 'green', - 'icon': 'installation', - 'css_class': 'bg-green-100 text-green-800', - 'sort_order': 3 + "color": "green", + "icon": "installation", + "css_class": "bg-green-100 text-green-800", + "sort_order": 3, }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="RENDERING", label="3D Rendering", description="Computer-generated 3D renderings of the ride model", - metadata={ - 'color': 'purple', - 'icon': 'cube', - 'css_class': 'bg-purple-100 text-purple-800', - 'sort_order': 4 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "purple", "icon": "cube", "css_class": "bg-purple-100 text-purple-800", "sort_order": 4}, + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="CATALOG", label="Catalog Image", description="Official catalog and brochure images", - metadata={ - 'color': 'orange', - 'icon': 'catalog', - 'css_class': 'bg-orange-100 text-orange-800', - 'sort_order': 5 - }, - category=ChoiceCategory.CLASSIFICATION + metadata={"color": "orange", "icon": "catalog", "css_class": "bg-orange-100 text-orange-800", "sort_order": 5}, + category=ChoiceCategory.CLASSIFICATION, ), ] @@ -630,97 +465,62 @@ SPEC_CATEGORIES = [ value="DIMENSIONS", label="Dimensions", description="Physical dimensions and measurements", - metadata={ - 'color': 'blue', - 'icon': 'ruler', - 'css_class': 'bg-blue-100 text-blue-800', - 'sort_order': 1 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "blue", "icon": "ruler", "css_class": "bg-blue-100 text-blue-800", "sort_order": 1}, + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="PERFORMANCE", label="Performance", description="Performance specifications and capabilities", - metadata={ - 'color': 'red', - 'icon': 'speedometer', - 'css_class': 'bg-red-100 text-red-800', - 'sort_order': 2 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "red", "icon": "speedometer", "css_class": "bg-red-100 text-red-800", "sort_order": 2}, + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="CAPACITY", label="Capacity", description="Rider capacity and throughput specifications", - metadata={ - 'color': 'green', - 'icon': 'users', - 'css_class': 'bg-green-100 text-green-800', - 'sort_order': 3 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "green", "icon": "users", "css_class": "bg-green-100 text-green-800", "sort_order": 3}, + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="SAFETY", label="Safety Features", description="Safety systems and features", - metadata={ - 'color': 'yellow', - 'icon': 'shield', - 'css_class': 'bg-yellow-100 text-yellow-800', - 'sort_order': 4 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "yellow", "icon": "shield", "css_class": "bg-yellow-100 text-yellow-800", "sort_order": 4}, + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="ELECTRICAL", label="Electrical Requirements", description="Power and electrical system requirements", metadata={ - 'color': 'purple', - 'icon': 'lightning', - 'css_class': 'bg-purple-100 text-purple-800', - 'sort_order': 5 + "color": "purple", + "icon": "lightning", + "css_class": "bg-purple-100 text-purple-800", + "sort_order": 5, }, - category=ChoiceCategory.TECHNICAL + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="FOUNDATION", label="Foundation Requirements", description="Foundation and structural requirements", - metadata={ - 'color': 'gray', - 'icon': 'foundation', - 'css_class': 'bg-gray-100 text-gray-800', - 'sort_order': 6 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "gray", "icon": "foundation", "css_class": "bg-gray-100 text-gray-800", "sort_order": 6}, + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="MAINTENANCE", label="Maintenance", description="Maintenance requirements and procedures", - metadata={ - 'color': 'orange', - 'icon': 'wrench', - 'css_class': 'bg-orange-100 text-orange-800', - 'sort_order': 7 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "orange", "icon": "wrench", "css_class": "bg-orange-100 text-orange-800", "sort_order": 7}, + category=ChoiceCategory.TECHNICAL, ), RichChoice( value="OTHER", label="Other", description="Other technical specifications", - metadata={ - 'color': 'gray', - 'icon': 'other', - 'css_class': 'bg-gray-100 text-gray-800', - 'sort_order': 8 - }, - category=ChoiceCategory.TECHNICAL + metadata={"color": "gray", "icon": "other", "css_class": "bg-gray-100 text-gray-800", "sort_order": 8}, + category=ChoiceCategory.TECHNICAL, ), ] @@ -731,30 +531,30 @@ RIDES_COMPANY_ROLES = [ label="Ride Manufacturer", description="Company that designs and builds ride hardware and systems", metadata={ - 'color': 'blue', - 'icon': 'factory', - 'css_class': 'bg-blue-100 text-blue-800', - 'sort_order': 1, - 'domain': 'rides', - 'permissions': ['manage_ride_models', 'view_manufacturing'], - 'url_pattern': '/rides/manufacturers/{slug}/' + "color": "blue", + "icon": "factory", + "css_class": "bg-blue-100 text-blue-800", + "sort_order": 1, + "domain": "rides", + "permissions": ["manage_ride_models", "view_manufacturing"], + "url_pattern": "/rides/manufacturers/{slug}/", }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), RichChoice( value="DESIGNER", label="Ride Designer", description="Company that specializes in ride design, layout, and engineering", metadata={ - 'color': 'purple', - 'icon': 'design', - 'css_class': 'bg-purple-100 text-purple-800', - 'sort_order': 2, - 'domain': 'rides', - 'permissions': ['manage_ride_designs', 'view_design_specs'], - 'url_pattern': '/rides/designers/{slug}/' + "color": "purple", + "icon": "design", + "css_class": "bg-purple-100 text-purple-800", + "sort_order": 2, + "domain": "rides", + "permissions": ["manage_ride_designs", "view_design_specs"], + "url_pattern": "/rides/designers/{slug}/", }, - category=ChoiceCategory.CLASSIFICATION + category=ChoiceCategory.CLASSIFICATION, ), ] @@ -767,7 +567,7 @@ def register_rides_choices(): choices=RIDE_CATEGORIES, domain="rides", description="Ride category classifications", - metadata={'domain': 'rides', 'type': 'category'} + metadata={"domain": "rides", "type": "category"}, ) register_choices( @@ -775,7 +575,7 @@ def register_rides_choices(): choices=RIDE_STATUSES, domain="rides", description="Ride operational status options", - metadata={'domain': 'rides', 'type': 'status'} + metadata={"domain": "rides", "type": "status"}, ) register_choices( @@ -783,7 +583,7 @@ def register_rides_choices(): choices=POST_CLOSING_STATUSES, domain="rides", description="Status options after ride closure", - metadata={'domain': 'rides', 'type': 'post_closing_status'} + metadata={"domain": "rides", "type": "post_closing_status"}, ) register_choices( @@ -791,7 +591,7 @@ def register_rides_choices(): choices=TRACK_MATERIALS, domain="rides", description="Roller coaster track material types", - metadata={'domain': 'rides', 'type': 'track_material', 'applies_to': 'roller_coasters'} + metadata={"domain": "rides", "type": "track_material", "applies_to": "roller_coasters"}, ) register_choices( @@ -799,7 +599,7 @@ def register_rides_choices(): choices=COASTER_TYPES, domain="rides", description="Roller coaster type classifications", - metadata={'domain': 'rides', 'type': 'coaster_type', 'applies_to': 'roller_coasters'} + metadata={"domain": "rides", "type": "coaster_type", "applies_to": "roller_coasters"}, ) register_choices( @@ -807,7 +607,7 @@ def register_rides_choices(): choices=PROPULSION_SYSTEMS, domain="rides", description="Roller coaster propulsion and lift systems", - metadata={'domain': 'rides', 'type': 'propulsion_system', 'applies_to': 'roller_coasters'} + metadata={"domain": "rides", "type": "propulsion_system", "applies_to": "roller_coasters"}, ) register_choices( @@ -815,7 +615,7 @@ def register_rides_choices(): choices=TARGET_MARKETS, domain="rides", description="Target market classifications for ride models", - metadata={'domain': 'rides', 'type': 'target_market', 'applies_to': 'ride_models'} + metadata={"domain": "rides", "type": "target_market", "applies_to": "ride_models"}, ) register_choices( @@ -823,7 +623,7 @@ def register_rides_choices(): choices=PHOTO_TYPES, domain="rides", description="Photo type classifications for ride model images", - metadata={'domain': 'rides', 'type': 'photo_type', 'applies_to': 'ride_model_photos'} + metadata={"domain": "rides", "type": "photo_type", "applies_to": "ride_model_photos"}, ) register_choices( @@ -831,7 +631,7 @@ def register_rides_choices(): choices=SPEC_CATEGORIES, domain="rides", description="Technical specification category classifications", - metadata={'domain': 'rides', 'type': 'spec_category', 'applies_to': 'ride_model_specs'} + metadata={"domain": "rides", "type": "spec_category", "applies_to": "ride_model_specs"}, ) register_choices( @@ -839,7 +639,7 @@ def register_rides_choices(): choices=RIDES_COMPANY_ROLES, domain="rides", description="Company role classifications for rides domain (MANUFACTURER and DESIGNER only)", - metadata={'domain': 'rides', 'type': 'company_role'} + metadata={"domain": "rides", "type": "company_role"}, ) diff --git a/backend/apps/rides/events.py b/backend/apps/rides/events.py index b0bcc805..db9120b5 100644 --- a/backend/apps/rides/events.py +++ b/backend/apps/rides/events.py @@ -1,5 +1,3 @@ - - def get_ride_display_changes(changes: dict) -> dict: """Returns a human-readable version of the ride changes""" field_names = { diff --git a/backend/apps/rides/forms.py b/backend/apps/rides/forms.py index b225a3bb..febeb46a 100644 --- a/backend/apps/rides/forms.py +++ b/backend/apps/rides/forms.py @@ -348,9 +348,7 @@ class RideForm(forms.ModelForm): # editing if self.instance and self.instance.pk: if self.instance.manufacturer: - self.fields["manufacturer_search"].initial = ( - self.instance.manufacturer.name - ) + self.fields["manufacturer_search"].initial = self.instance.manufacturer.name self.fields["manufacturer"].initial = self.instance.manufacturer if self.instance.designer: self.fields["designer_search"].initial = self.instance.designer.name diff --git a/backend/apps/rides/forms/base.py b/backend/apps/rides/forms/base.py index bb64a378..f09fbdf2 100644 --- a/backend/apps/rides/forms/base.py +++ b/backend/apps/rides/forms/base.py @@ -348,9 +348,7 @@ class RideForm(forms.ModelForm): # editing if self.instance and self.instance.pk: if self.instance.manufacturer: - self.fields["manufacturer_search"].initial = ( - self.instance.manufacturer.name - ) + self.fields["manufacturer_search"].initial = self.instance.manufacturer.name self.fields["manufacturer"].initial = self.instance.manufacturer if self.instance.designer: self.fields["designer_search"].initial = self.instance.designer.name diff --git a/backend/apps/rides/forms/search.py b/backend/apps/rides/forms/search.py index 18ae69c2..37f10571 100644 --- a/backend/apps/rides/forms/search.py +++ b/backend/apps/rides/forms/search.py @@ -105,8 +105,8 @@ class BasicInfoForm(BaseFilterForm): status_choices = [(choice.value, choice.label) for choice in get_choices("statuses", "rides")] # Update field choices dynamically - self.fields['category'].choices = category_choices - self.fields['status'].choices = status_choices + self.fields["category"].choices = category_choices + self.fields["status"].choices = status_choices category = forms.MultipleChoiceField( choices=[], # Will be populated in __init__ @@ -123,17 +123,13 @@ class BasicInfoForm(BaseFilterForm): park = forms.ModelMultipleChoiceField( queryset=Park.objects.all(), required=False, - widget=forms.CheckboxSelectMultiple( - attrs={"class": "max-h-48 overflow-y-auto space-y-1"} - ), + widget=forms.CheckboxSelectMultiple(attrs={"class": "max-h-48 overflow-y-auto space-y-1"}), ) park_area = forms.ModelMultipleChoiceField( queryset=ParkArea.objects.all(), required=False, - widget=forms.CheckboxSelectMultiple( - attrs={"class": "max-h-48 overflow-y-auto space-y-1"} - ), + widget=forms.CheckboxSelectMultiple(attrs={"class": "max-h-48 overflow-y-auto space-y-1"}), ) @@ -266,7 +262,7 @@ class NumberRangeField(forms.MultiValueField): def validate(self, value): super().validate(value) - if value and value.get("min") is not None and value.get("max") is not None: + if value and value.get("min") is not None and value.get("max") is not None: # noqa: SIM102 if value["min"] > value["max"]: raise ValidationError("Minimum value must be less than maximum value.") @@ -282,17 +278,13 @@ class HeightSafetyForm(BaseFilterForm): label="Minimum Height (inches)", ) - max_height_range = NumberRangeField( - min_val=0, max_val=84, step=1, required=False, label="Maximum Height (inches)" - ) + max_height_range = NumberRangeField(min_val=0, max_val=84, step=1, required=False, label="Maximum Height (inches)") class PerformanceForm(BaseFilterForm): """Form for performance metric filters.""" - capacity_range = NumberRangeField( - min_val=0, max_val=5000, step=50, required=False, label="Capacity per Hour" - ) + capacity_range = NumberRangeField(min_val=0, max_val=5000, step=50, required=False, label="Capacity per Hour") duration_range = NumberRangeField( min_val=0, @@ -302,9 +294,7 @@ class PerformanceForm(BaseFilterForm): label="Duration (seconds)", ) - rating_range = NumberRangeField( - min_val=0.0, max_val=10.0, step=0.1, required=False, label="Average Rating" - ) + rating_range = NumberRangeField(min_val=0.0, max_val=10.0, step=0.1, required=False, label="Average Rating") class RelationshipsForm(BaseFilterForm): @@ -313,25 +303,19 @@ class RelationshipsForm(BaseFilterForm): manufacturer = forms.ModelMultipleChoiceField( queryset=Company.objects.filter(roles__contains=["MANUFACTURER"]), required=False, - widget=forms.CheckboxSelectMultiple( - attrs={"class": "max-h-48 overflow-y-auto space-y-1"} - ), + widget=forms.CheckboxSelectMultiple(attrs={"class": "max-h-48 overflow-y-auto space-y-1"}), ) designer = forms.ModelMultipleChoiceField( queryset=Company.objects.filter(roles__contains=["DESIGNER"]), required=False, - widget=forms.CheckboxSelectMultiple( - attrs={"class": "max-h-48 overflow-y-auto space-y-1"} - ), + widget=forms.CheckboxSelectMultiple(attrs={"class": "max-h-48 overflow-y-auto space-y-1"}), ) ride_model = forms.ModelMultipleChoiceField( queryset=RideModel.objects.all(), required=False, - widget=forms.CheckboxSelectMultiple( - attrs={"class": "max-h-48 overflow-y-auto space-y-1"} - ), + widget=forms.CheckboxSelectMultiple(attrs={"class": "max-h-48 overflow-y-auto space-y-1"}), ) @@ -347,28 +331,22 @@ class RollerCoasterForm(BaseFilterForm): # Get choices - let exceptions propagate if registry fails track_material_choices = [(choice.value, choice.label) for choice in get_choices("track_materials", "rides")] coaster_type_choices = [(choice.value, choice.label) for choice in get_choices("coaster_types", "rides")] - propulsion_system_choices = [(choice.value, choice.label) for choice in get_choices("propulsion_systems", "rides")] + propulsion_system_choices = [ + (choice.value, choice.label) for choice in get_choices("propulsion_systems", "rides") + ] # Update field choices dynamically - self.fields['track_material'].choices = track_material_choices - self.fields['coaster_type'].choices = coaster_type_choices - self.fields['propulsion_system'].choices = propulsion_system_choices + self.fields["track_material"].choices = track_material_choices + self.fields["coaster_type"].choices = coaster_type_choices + self.fields["propulsion_system"].choices = propulsion_system_choices - height_ft_range = NumberRangeField( - min_val=0, max_val=500, step=1, required=False, label="Height (feet)" - ) + height_ft_range = NumberRangeField(min_val=0, max_val=500, step=1, required=False, label="Height (feet)") - length_ft_range = NumberRangeField( - min_val=0, max_val=10000, step=10, required=False, label="Length (feet)" - ) + length_ft_range = NumberRangeField(min_val=0, max_val=10000, step=10, required=False, label="Length (feet)") - speed_mph_range = NumberRangeField( - min_val=0, max_val=150, step=1, required=False, label="Speed (mph)" - ) + speed_mph_range = NumberRangeField(min_val=0, max_val=150, step=1, required=False, label="Speed (mph)") - inversions_range = NumberRangeField( - min_val=0, max_val=20, step=1, required=False, label="Number of Inversions" - ) + inversions_range = NumberRangeField(min_val=0, max_val=20, step=1, required=False, label="Number of Inversions") track_material = forms.MultipleChoiceField( choices=[], # Will be populated in __init__ @@ -379,17 +357,13 @@ class RollerCoasterForm(BaseFilterForm): coaster_type = forms.MultipleChoiceField( choices=[], # Will be populated in __init__ required=False, - widget=forms.CheckboxSelectMultiple( - attrs={"class": "grid grid-cols-2 gap-2 max-h-48 overflow-y-auto"} - ), + widget=forms.CheckboxSelectMultiple(attrs={"class": "grid grid-cols-2 gap-2 max-h-48 overflow-y-auto"}), ) propulsion_system = forms.MultipleChoiceField( choices=[], # Will be populated in __init__ required=False, - widget=forms.CheckboxSelectMultiple( - attrs={"class": "space-y-2 max-h-48 overflow-y-auto"} - ), + widget=forms.CheckboxSelectMultiple(attrs={"class": "space-y-2 max-h-48 overflow-y-auto"}), ) @@ -408,8 +382,8 @@ class CompanyForm(BaseFilterForm): role_choices = rides_roles + parks_roles # Update field choices dynamically - self.fields['manufacturer_roles'].choices = role_choices - self.fields['designer_roles'].choices = role_choices + self.fields["manufacturer_roles"].choices = role_choices + self.fields["designer_roles"].choices = role_choices manufacturer_roles = forms.MultipleChoiceField( choices=[], # Will be populated in __init__ @@ -423,9 +397,7 @@ class CompanyForm(BaseFilterForm): widget=forms.CheckboxSelectMultiple(attrs={"class": "space-y-2"}), ) - founded_date_range = DateRangeField( - required=False, label="Company Founded Date Range" - ) + founded_date_range = DateRangeField(required=False, label="Company Founded Date Range") class SortingForm(BaseFilterForm): @@ -452,7 +424,7 @@ class SortingForm(BaseFilterForm): ("capacity_desc", "Capacity (Highest)"), ] - self.fields['sort_by'].choices = sort_choices + self.fields["sort_by"].choices = sort_choices sort_by = forms.ChoiceField( choices=[], # Will be populated in __init__ @@ -578,8 +550,7 @@ class MasterFilterForm(BaseFilterForm): { "field": field_name, "value": value, - "label": self.fields[field_name].label - or field_name.replace("_", " ").title(), + "label": self.fields[field_name].label or field_name.replace("_", " ").title(), } ) diff --git a/backend/apps/rides/management/commands/update_ride_rankings.py b/backend/apps/rides/management/commands/update_ride_rankings.py index fd1f587b..b25f472a 100644 --- a/backend/apps/rides/management/commands/update_ride_rankings.py +++ b/backend/apps/rides/management/commands/update_ride_rankings.py @@ -19,11 +19,7 @@ class Command(BaseCommand): category = options.get("category") service = RideRankingService() - self.stdout.write( - self.style.SUCCESS( - f"Starting ride ranking calculation at {timezone.now().isoformat()}" - ) - ) + self.stdout.write(self.style.SUCCESS(f"Starting ride ranking calculation at {timezone.now().isoformat()}")) result = service.update_all_rankings(category=category) diff --git a/backend/apps/rides/managers.py b/backend/apps/rides/managers.py index 0b42dee7..9f12d2ab 100644 --- a/backend/apps/rides/managers.py +++ b/backend/apps/rides/managers.py @@ -3,7 +3,6 @@ Custom managers and QuerySets for Rides models. Optimized queries following Django styleguide patterns. """ - from django.db.models import Count, F, Prefetch, Q from apps.core.managers import ( @@ -35,9 +34,7 @@ class RideQuerySet(StatusQuerySet, ReviewableQuerySet): def family_friendly(self, *, max_height_requirement: int = 42): """Filter for family-friendly rides.""" - return self.filter( - Q(min_height_in__lte=max_height_requirement) | Q(min_height_in__isnull=True) - ) + return self.filter(Q(min_height_in__lte=max_height_requirement) | Q(min_height_in__isnull=True)) def by_park(self, *, park_id: int): """Filter rides by park.""" @@ -54,8 +51,7 @@ class RideQuerySet(StatusQuerySet, ReviewableQuerySet): def with_capacity_info(self): """Add capacity-related annotations.""" return self.annotate( - estimated_daily_capacity=F("capacity_per_hour") - * 10, # Assuming 10 operating hours + estimated_daily_capacity=F("capacity_per_hour") * 10, # Assuming 10 operating hours duration_minutes=F("ride_duration_seconds") / 60.0, ) @@ -65,9 +61,7 @@ class RideQuerySet(StatusQuerySet, ReviewableQuerySet): def optimized_for_list(self): """Optimize for ride list display.""" - return self.select_related( - "park", "park_area", "manufacturer", "designer", "ride_model" - ).with_review_stats() + return self.select_related("park", "park_area", "manufacturer", "designer", "ride_model").with_review_stats() def optimized_for_detail(self): """Optimize for ride detail display.""" @@ -94,9 +88,7 @@ class RideQuerySet(StatusQuerySet, ReviewableQuerySet): def with_coaster_stats(self): """Always prefetch coaster_stats for roller coaster queries.""" - return self.select_related( - "park", "manufacturer", "ride_model" - ).prefetch_related("coaster_stats") + return self.select_related("park", "manufacturer", "ride_model").prefetch_related("coaster_stats") def for_map_display(self): """Optimize for map display.""" @@ -129,14 +121,12 @@ class RideQuerySet(StatusQuerySet, ReviewableQuerySet): if min_height: queryset = queryset.filter( - Q(rollercoaster_stats__height_ft__gte=min_height) - | Q(min_height_in__gte=min_height) + Q(rollercoaster_stats__height_ft__gte=min_height) | Q(min_height_in__gte=min_height) ) if max_height: queryset = queryset.filter( - Q(rollercoaster_stats__height_ft__lte=max_height) - | Q(max_height_in__lte=max_height) + Q(rollercoaster_stats__height_ft__lte=max_height) | Q(max_height_in__lte=max_height) ) if min_speed: @@ -146,10 +136,7 @@ class RideQuerySet(StatusQuerySet, ReviewableQuerySet): if inversions: queryset = queryset.filter(rollercoaster_stats__inversions__gt=0) else: - queryset = queryset.filter( - Q(rollercoaster_stats__inversions=0) - | Q(rollercoaster_stats__isnull=True) - ) + queryset = queryset.filter(Q(rollercoaster_stats__inversions=0) | Q(rollercoaster_stats__isnull=True)) return queryset @@ -167,9 +154,7 @@ class RideManager(StatusManager, ReviewableManager): return self.get_queryset().thrill_rides() def family_friendly(self, *, max_height_requirement: int = 42): - return self.get_queryset().family_friendly( - max_height_requirement=max_height_requirement - ) + return self.get_queryset().family_friendly(max_height_requirement=max_height_requirement) def by_park(self, *, park_id: int): return self.get_queryset().by_park(park_id=park_id) @@ -203,9 +188,7 @@ class RideModelQuerySet(BaseQuerySet): """Add count of rides using this model.""" return self.annotate( ride_count=Count("rides", distinct=True), - operating_rides_count=Count( - "rides", filter=Q(rides__status="OPERATING"), distinct=True - ), + operating_rides_count=Count("rides", filter=Q(rides__status="OPERATING"), distinct=True), ) def popular_models(self, *, min_installations: int = 5): @@ -260,9 +243,7 @@ class RideReviewManager(BaseManager): return self.get_queryset().for_ride(ride_id=ride_id) def by_rating_range(self, *, min_rating: int = 1, max_rating: int = 10): - return self.get_queryset().by_rating_range( - min_rating=min_rating, max_rating=max_rating - ) + return self.get_queryset().by_rating_range(min_rating=min_rating, max_rating=max_rating) class RollerCoasterStatsQuerySet(BaseQuerySet): diff --git a/backend/apps/rides/migrations/0001_initial.py b/backend/apps/rides/migrations/0001_initial.py index 95eb46ed..17c09df0 100644 --- a/backend/apps/rides/migrations/0001_initial.py +++ b/backend/apps/rides/migrations/0001_initial.py @@ -190,9 +190,7 @@ class Migration(migrations.Migration): ), ( "average_rating", - models.DecimalField( - blank=True, decimal_places=2, max_digits=3, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=3, null=True), ), ], options={ @@ -374,21 +372,15 @@ class Migration(migrations.Migration): ), ( "height_ft", - models.DecimalField( - blank=True, decimal_places=2, max_digits=6, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=6, null=True), ), ( "length_ft", - models.DecimalField( - blank=True, decimal_places=2, max_digits=7, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=7, null=True), ), ( "speed_mph", - models.DecimalField( - blank=True, decimal_places=2, max_digits=5, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=5, null=True), ), ("inversions", models.PositiveIntegerField(default=0)), ( @@ -432,9 +424,7 @@ class Migration(migrations.Migration): ), ( "max_drop_height_ft", - models.DecimalField( - blank=True, decimal_places=2, max_digits=6, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=6, null=True), ), ( "launch_type", @@ -692,9 +682,7 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="ridelocation", - index=models.Index( - fields=["park_area"], name="rides_ridel_park_ar_26c90c_idx" - ), + index=models.Index(fields=["park_area"], name="rides_ridel_park_ar_26c90c_idx"), ), migrations.AlterUniqueTogether( name="ridemodel", diff --git a/backend/apps/rides/migrations/0004_rideevent_ridemodelevent_rollercoasterstatsevent_and_more.py b/backend/apps/rides/migrations/0004_rideevent_ridemodelevent_rollercoasterstatsevent_and_more.py index 2a519cbe..ea955b33 100644 --- a/backend/apps/rides/migrations/0004_rideevent_ridemodelevent_rollercoasterstatsevent_and_more.py +++ b/backend/apps/rides/migrations/0004_rideevent_ridemodelevent_rollercoasterstatsevent_and_more.py @@ -89,9 +89,7 @@ class Migration(migrations.Migration): ), ( "average_rating", - models.DecimalField( - blank=True, decimal_places=2, max_digits=3, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=3, null=True), ), ], options={ @@ -140,21 +138,15 @@ class Migration(migrations.Migration): ("id", models.BigIntegerField()), ( "height_ft", - models.DecimalField( - blank=True, decimal_places=2, max_digits=6, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=6, null=True), ), ( "length_ft", - models.DecimalField( - blank=True, decimal_places=2, max_digits=7, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=7, null=True), ), ( "speed_mph", - models.DecimalField( - blank=True, decimal_places=2, max_digits=5, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=5, null=True), ), ("inversions", models.PositiveIntegerField(default=0)), ( @@ -198,9 +190,7 @@ class Migration(migrations.Migration): ), ( "max_drop_height_ft", - models.DecimalField( - blank=True, decimal_places=2, max_digits=6, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=6, null=True), ), ( "launch_type", diff --git a/backend/apps/rides/migrations/0006_add_ride_rankings.py b/backend/apps/rides/migrations/0006_add_ride_rankings.py index 37cdacbc..c724f495 100644 --- a/backend/apps/rides/migrations/0006_add_ride_rankings.py +++ b/backend/apps/rides/migrations/0006_add_ride_rankings.py @@ -220,9 +220,7 @@ class Migration(migrations.Migration): ), ( "rank", - models.PositiveIntegerField( - db_index=True, help_text="Overall rank position (1 = best)" - ), + models.PositiveIntegerField(db_index=True, help_text="Overall rank position (1 = best)"), ), ( "wins", @@ -323,9 +321,7 @@ class Migration(migrations.Migration): ("id", models.BigIntegerField()), ( "rank", - models.PositiveIntegerField( - help_text="Overall rank position (1 = best)" - ), + models.PositiveIntegerField(help_text="Overall rank position (1 = best)"), ), ( "wins", @@ -487,15 +483,11 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="ridepaircomparison", - index=models.Index( - fields=["ride_a", "ride_b"], name="rides_ridep_ride_a__eb0674_idx" - ), + index=models.Index(fields=["ride_a", "ride_b"], name="rides_ridep_ride_a__eb0674_idx"), ), migrations.AddIndex( model_name="ridepaircomparison", - index=models.Index( - fields=["last_calculated"], name="rides_ridep_last_ca_bd9f6c_idx" - ), + index=models.Index(fields=["last_calculated"], name="rides_ridep_last_ca_bd9f6c_idx"), ), migrations.AlterUniqueTogether( name="ridepaircomparison", @@ -551,9 +543,7 @@ class Migration(migrations.Migration): migrations.AddConstraint( model_name="rideranking", constraint=models.CheckConstraint( - condition=models.Q( - ("winning_percentage__gte", 0), ("winning_percentage__lte", 1) - ), + condition=models.Q(("winning_percentage__gte", 0), ("winning_percentage__lte", 1)), name="rideranking_winning_percentage_range", violation_error_message="Winning percentage must be between 0 and 1", ), diff --git a/backend/apps/rides/migrations/0007_ridephoto_ridephotoevent_and_more.py b/backend/apps/rides/migrations/0007_ridephoto_ridephotoevent_and_more.py index ba68d080..37f4c595 100644 --- a/backend/apps/rides/migrations/0007_ridephoto_ridephotoevent_and_more.py +++ b/backend/apps/rides/migrations/0007_ridephoto_ridephotoevent_and_more.py @@ -163,27 +163,19 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="ridephoto", - index=models.Index( - fields=["ride", "is_primary"], name="rides_ridep_ride_id_aa49f1_idx" - ), + index=models.Index(fields=["ride", "is_primary"], name="rides_ridep_ride_id_aa49f1_idx"), ), migrations.AddIndex( model_name="ridephoto", - index=models.Index( - fields=["ride", "is_approved"], name="rides_ridep_ride_id_f1eddc_idx" - ), + index=models.Index(fields=["ride", "is_approved"], name="rides_ridep_ride_id_f1eddc_idx"), ), migrations.AddIndex( model_name="ridephoto", - index=models.Index( - fields=["ride", "photo_type"], name="rides_ridep_ride_id_49e7ec_idx" - ), + index=models.Index(fields=["ride", "photo_type"], name="rides_ridep_ride_id_49e7ec_idx"), ), migrations.AddIndex( model_name="ridephoto", - index=models.Index( - fields=["created_at"], name="rides_ridep_created_106e02_idx" - ), + index=models.Index(fields=["created_at"], name="rides_ridep_created_106e02_idx"), ), migrations.AddConstraint( model_name="ridephoto", diff --git a/backend/apps/rides/migrations/0010_add_comprehensive_ride_model_system.py b/backend/apps/rides/migrations/0010_add_comprehensive_ride_model_system.py index 49b7fb7c..08e09c38 100644 --- a/backend/apps/rides/migrations/0010_add_comprehensive_ride_model_system.py +++ b/backend/apps/rides/migrations/0010_add_comprehensive_ride_model_system.py @@ -147,21 +147,15 @@ class Migration(migrations.Migration): ), ( "spec_name", - models.CharField( - help_text="Name of the specification", max_length=100 - ), + models.CharField(help_text="Name of the specification", max_length=100), ), ( "spec_value", - models.CharField( - help_text="Value of the specification", max_length=255 - ), + models.CharField(help_text="Value of the specification", max_length=255), ), ( "spec_unit", - models.CharField( - blank=True, help_text="Unit of measurement", max_length=20 - ), + models.CharField(blank=True, help_text="Unit of measurement", max_length=20), ), ( "notes", @@ -203,21 +197,15 @@ class Migration(migrations.Migration): ), ( "spec_name", - models.CharField( - help_text="Name of the specification", max_length=100 - ), + models.CharField(help_text="Name of the specification", max_length=100), ), ( "spec_value", - models.CharField( - help_text="Value of the specification", max_length=255 - ), + models.CharField(help_text="Value of the specification", max_length=255), ), ( "spec_unit", - models.CharField( - blank=True, help_text="Unit of measurement", max_length=20 - ), + models.CharField(blank=True, help_text="Unit of measurement", max_length=20), ), ( "notes", @@ -251,33 +239,23 @@ class Migration(migrations.Migration): ), ( "description", - models.TextField( - blank=True, help_text="Description of variant differences" - ), + models.TextField(blank=True, help_text="Description of variant differences"), ), ( "min_height_ft", - models.DecimalField( - blank=True, decimal_places=2, max_digits=6, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=6, null=True), ), ( "max_height_ft", - models.DecimalField( - blank=True, decimal_places=2, max_digits=6, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=6, null=True), ), ( "min_speed_mph", - models.DecimalField( - blank=True, decimal_places=2, max_digits=5, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=5, null=True), ), ( "max_speed_mph", - models.DecimalField( - blank=True, decimal_places=2, max_digits=5, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=5, null=True), ), ( "distinguishing_features", @@ -307,33 +285,23 @@ class Migration(migrations.Migration): ), ( "description", - models.TextField( - blank=True, help_text="Description of variant differences" - ), + models.TextField(blank=True, help_text="Description of variant differences"), ), ( "min_height_ft", - models.DecimalField( - blank=True, decimal_places=2, max_digits=6, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=6, null=True), ), ( "max_height_ft", - models.DecimalField( - blank=True, decimal_places=2, max_digits=6, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=6, null=True), ), ( "min_speed_mph", - models.DecimalField( - blank=True, decimal_places=2, max_digits=5, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=5, null=True), ), ( "max_speed_mph", - models.DecimalField( - blank=True, decimal_places=2, max_digits=5, null=True - ), + models.DecimalField(blank=True, decimal_places=2, max_digits=5, null=True), ), ( "distinguishing_features", @@ -750,9 +718,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="ridemodel", name="description", - field=models.TextField( - blank=True, help_text="Detailed description of the ride model" - ), + field=models.TextField(blank=True, help_text="Detailed description of the ride model"), ), migrations.AlterField( model_name="ridemodel", @@ -794,9 +760,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="ridemodelevent", name="description", - field=models.TextField( - blank=True, help_text="Detailed description of the ride model" - ), + field=models.TextField(blank=True, help_text="Detailed description of the ride model"), ), migrations.AlterField( model_name="ridemodelevent", diff --git a/backend/apps/rides/migrations/0012_make_ride_model_slug_unique.py b/backend/apps/rides/migrations/0012_make_ride_model_slug_unique.py index ec916c59..123dc3d1 100644 --- a/backend/apps/rides/migrations/0012_make_ride_model_slug_unique.py +++ b/backend/apps/rides/migrations/0012_make_ride_model_slug_unique.py @@ -13,8 +13,6 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="ridemodel", name="slug", - field=models.SlugField( - help_text="URL-friendly identifier", max_length=255, unique=True - ), + field=models.SlugField(help_text="URL-friendly identifier", max_length=255, unique=True), ), ] diff --git a/backend/apps/rides/migrations/0014_update_ride_model_slugs_data.py b/backend/apps/rides/migrations/0014_update_ride_model_slugs_data.py index 12d1d39e..49eef36d 100644 --- a/backend/apps/rides/migrations/0014_update_ride_model_slugs_data.py +++ b/backend/apps/rides/migrations/0014_update_ride_model_slugs_data.py @@ -16,9 +16,7 @@ def update_ride_model_slugs(apps, schema_editor): counter = 1 base_slug = new_slug while ( - RideModel.objects.filter( - manufacturer=ride_model.manufacturer, slug=new_slug - ) + RideModel.objects.filter(manufacturer=ride_model.manufacturer, slug=new_slug) .exclude(pk=ride_model.pk) .exists() ): @@ -37,16 +35,12 @@ def reverse_ride_model_slugs(apps, schema_editor): for ride_model in RideModel.objects.all(): # Generate old-style slug with manufacturer + name - old_slug = slugify( - f"{ride_model.manufacturer.name if ride_model.manufacturer else ''} {ride_model.name}" - ) + old_slug = slugify(f"{ride_model.manufacturer.name if ride_model.manufacturer else ''} {ride_model.name}") # Ensure uniqueness globally (old way) counter = 1 base_slug = old_slug - while ( - RideModel.objects.filter(slug=old_slug).exclude(pk=ride_model.pk).exists() - ): + while RideModel.objects.filter(slug=old_slug).exclude(pk=ride_model.pk).exists(): old_slug = f"{base_slug}-{counter}" counter += 1 diff --git a/backend/apps/rides/migrations/0015_remove_company_insert_insert_and_more.py b/backend/apps/rides/migrations/0015_remove_company_insert_insert_and_more.py index b947f8ed..3edea3d7 100644 --- a/backend/apps/rides/migrations/0015_remove_company_insert_insert_and_more.py +++ b/backend/apps/rides/migrations/0015_remove_company_insert_insert_and_more.py @@ -39,16 +39,12 @@ class Migration(migrations.Migration): migrations.AddField( model_name="company", name="url", - field=models.URLField( - blank=True, help_text="Frontend URL for this company" - ), + field=models.URLField(blank=True, help_text="Frontend URL for this company"), ), migrations.AddField( model_name="companyevent", name="url", - field=models.URLField( - blank=True, help_text="Frontend URL for this company" - ), + field=models.URLField(blank=True, help_text="Frontend URL for this company"), ), migrations.AddField( model_name="ride", @@ -63,16 +59,12 @@ class Migration(migrations.Migration): migrations.AddField( model_name="ridemodel", name="url", - field=models.URLField( - blank=True, help_text="Frontend URL for this ride model" - ), + field=models.URLField(blank=True, help_text="Frontend URL for this ride model"), ), migrations.AddField( model_name="ridemodelevent", name="url", - field=models.URLField( - blank=True, help_text="Frontend URL for this ride model" - ), + field=models.URLField(blank=True, help_text="Frontend URL for this ride model"), ), pgtrigger.migrations.AddTrigger( model_name="company", diff --git a/backend/apps/rides/migrations/0016_remove_ride_insert_insert_remove_ride_update_update_and_more.py b/backend/apps/rides/migrations/0016_remove_ride_insert_insert_remove_ride_update_update_and_more.py index 54aa5f2d..34a9a4cf 100644 --- a/backend/apps/rides/migrations/0016_remove_ride_insert_insert_remove_ride_update_update_and_more.py +++ b/backend/apps/rides/migrations/0016_remove_ride_insert_insert_remove_ride_update_update_and_more.py @@ -23,16 +23,12 @@ class Migration(migrations.Migration): migrations.AddField( model_name="ride", name="park_url", - field=models.URLField( - blank=True, help_text="Frontend URL for this ride's park" - ), + field=models.URLField(blank=True, help_text="Frontend URL for this ride's park"), ), migrations.AddField( model_name="rideevent", name="park_url", - field=models.URLField( - blank=True, help_text="Frontend URL for this ride's park" - ), + field=models.URLField(blank=True, help_text="Frontend URL for this ride's park"), ), pgtrigger.migrations.AddTrigger( model_name="ride", diff --git a/backend/apps/rides/migrations/0019_populate_hybrid_filtering_fields.py b/backend/apps/rides/migrations/0019_populate_hybrid_filtering_fields.py index afb3cf09..dfb2d311 100644 --- a/backend/apps/rides/migrations/0019_populate_hybrid_filtering_fields.py +++ b/backend/apps/rides/migrations/0019_populate_hybrid_filtering_fields.py @@ -12,13 +12,15 @@ from django.db import migrations def populate_computed_fields(apps, schema_editor): """Populate computed fields for all existing rides.""" - Ride = apps.get_model('rides', 'Ride') + Ride = apps.get_model("rides", "Ride") # Disable pghistory triggers during bulk operations to avoid performance issues with pghistory.context(disable=True): - rides = list(Ride.objects.all().select_related( - 'park', 'park__location', 'park_area', 'manufacturer', 'designer', 'ride_model' - )) + rides = list( + Ride.objects.all().select_related( + "park", "park__location", "park_area", "manufacturer", "designer", "ride_model" + ) + ) for ride in rides: # Extract opening year from opening_date @@ -39,7 +41,7 @@ def populate_computed_fields(apps, schema_editor): # Park info if ride.park: search_parts.append(ride.park.name) - if hasattr(ride.park, 'location') and ride.park.location: + if hasattr(ride.park, "location") and ride.park.location: if ride.park.location.city: search_parts.append(ride.park.location.city) if ride.park.location.state: @@ -62,7 +64,7 @@ def populate_computed_fields(apps, schema_editor): ("TR", "Transport"), ("OT", "Other"), ] - category_display = dict(category_choices).get(ride.category, '') + category_display = dict(category_choices).get(ride.category, "") if category_display: search_parts.append(category_display) @@ -79,7 +81,7 @@ def populate_computed_fields(apps, schema_editor): ("DEMOLISHED", "Demolished"), ("RELOCATED", "Relocated"), ] - status_display = dict(status_choices).get(ride.status, '') + status_display = dict(status_choices).get(ride.status, "") if status_display: search_parts.append(status_display) @@ -95,24 +97,24 @@ def populate_computed_fields(apps, schema_editor): if ride.ride_model.manufacturer: search_parts.append(ride.ride_model.manufacturer.name) - ride.search_text = ' '.join(filter(None, search_parts)).lower() + ride.search_text = " ".join(filter(None, search_parts)).lower() # Bulk update all rides - Ride.objects.bulk_update(rides, ['opening_year', 'search_text'], batch_size=1000) + Ride.objects.bulk_update(rides, ["opening_year", "search_text"], batch_size=1000) def reverse_populate_computed_fields(apps, schema_editor): """Clear computed fields (reverse operation).""" - Ride = apps.get_model('rides', 'Ride') + Ride = apps.get_model("rides", "Ride") # Disable pghistory triggers during bulk operations with pghistory.context(disable=True): - Ride.objects.all().update(opening_year=None, search_text='') + Ride.objects.all().update(opening_year=None, search_text="") class Migration(migrations.Migration): dependencies = [ - ('rides', '0018_add_hybrid_filtering_fields'), + ("rides", "0018_add_hybrid_filtering_fields"), ] operations = [ diff --git a/backend/apps/rides/migrations/0020_add_hybrid_filtering_indexes.py b/backend/apps/rides/migrations/0020_add_hybrid_filtering_indexes.py index d64ef750..d60d397b 100644 --- a/backend/apps/rides/migrations/0020_add_hybrid_filtering_indexes.py +++ b/backend/apps/rides/migrations/0020_add_hybrid_filtering_indexes.py @@ -19,163 +19,136 @@ from django.db import migrations class Migration(migrations.Migration): dependencies = [ - ('rides', '0019_populate_hybrid_filtering_fields'), + ("rides", "0019_populate_hybrid_filtering_fields"), ] operations = [ # Composite index for park + category filtering (very common) migrations.RunSQL( "CREATE INDEX rides_ride_park_category_idx ON rides_ride (park_id, category) WHERE category != '';", - reverse_sql="DROP INDEX IF EXISTS rides_ride_park_category_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_park_category_idx;", ), - # Composite index for park + status filtering (common) migrations.RunSQL( "CREATE INDEX rides_ride_park_status_idx ON rides_ride (park_id, status);", - reverse_sql="DROP INDEX IF EXISTS rides_ride_park_status_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_park_status_idx;", ), - # Composite index for category + status filtering migrations.RunSQL( "CREATE INDEX rides_ride_category_status_idx ON rides_ride (category, status) WHERE category != '';", - reverse_sql="DROP INDEX IF EXISTS rides_ride_category_status_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_category_status_idx;", ), - # Composite index for manufacturer + category migrations.RunSQL( "CREATE INDEX rides_ride_manufacturer_category_idx ON rides_ride (manufacturer_id, category) WHERE manufacturer_id IS NOT NULL AND category != '';", - reverse_sql="DROP INDEX IF EXISTS rides_ride_manufacturer_category_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_manufacturer_category_idx;", ), - # Composite index for opening year + category (for timeline filtering) migrations.RunSQL( "CREATE INDEX rides_ride_opening_year_category_idx ON rides_ride (opening_year, category) WHERE opening_year IS NOT NULL AND category != '';", - reverse_sql="DROP INDEX IF EXISTS rides_ride_opening_year_category_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_opening_year_category_idx;", ), - # Partial index for operating rides only (most common filter) migrations.RunSQL( "CREATE INDEX rides_ride_operating_only_idx ON rides_ride (park_id, category, opening_year) WHERE status = 'OPERATING';", - reverse_sql="DROP INDEX IF EXISTS rides_ride_operating_only_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_operating_only_idx;", ), - # Partial index for roller coasters only (popular category) migrations.RunSQL( "CREATE INDEX rides_ride_roller_coasters_idx ON rides_ride (park_id, status, opening_year) WHERE category = 'RC';", - reverse_sql="DROP INDEX IF EXISTS rides_ride_roller_coasters_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_roller_coasters_idx;", ), - # Covering index for list views (includes commonly displayed fields) migrations.RunSQL( "CREATE INDEX rides_ride_list_covering_idx ON rides_ride (park_id, category, status) INCLUDE (name, opening_date, average_rating);", - reverse_sql="DROP INDEX IF EXISTS rides_ride_list_covering_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_list_covering_idx;", ), - # GIN index for full-text search on computed search_text field migrations.RunSQL( "CREATE INDEX rides_ride_search_text_gin_idx ON rides_ride USING gin(to_tsvector('english', search_text));", - reverse_sql="DROP INDEX IF EXISTS rides_ride_search_text_gin_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_search_text_gin_idx;", ), - # Trigram index for fuzzy text search migrations.RunSQL( "CREATE INDEX rides_ride_search_text_trgm_idx ON rides_ride USING gin(search_text gin_trgm_ops);", - reverse_sql="DROP INDEX IF EXISTS rides_ride_search_text_trgm_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_search_text_trgm_idx;", ), - # Index for rating-based filtering migrations.RunSQL( "CREATE INDEX rides_ride_rating_idx ON rides_ride (average_rating) WHERE average_rating IS NOT NULL;", - reverse_sql="DROP INDEX IF EXISTS rides_ride_rating_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_rating_idx;", ), - # Index for capacity-based filtering migrations.RunSQL( "CREATE INDEX rides_ride_capacity_idx ON rides_ride (capacity_per_hour) WHERE capacity_per_hour IS NOT NULL;", - reverse_sql="DROP INDEX IF EXISTS rides_ride_capacity_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_capacity_idx;", ), - # Index for height requirement filtering migrations.RunSQL( "CREATE INDEX rides_ride_height_req_idx ON rides_ride (min_height_in, max_height_in) WHERE min_height_in IS NOT NULL OR max_height_in IS NOT NULL;", - reverse_sql="DROP INDEX IF EXISTS rides_ride_height_req_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_height_req_idx;", ), - # Composite index for ride model filtering migrations.RunSQL( "CREATE INDEX rides_ride_model_manufacturer_idx ON rides_ride (ride_model_id, manufacturer_id) WHERE ride_model_id IS NOT NULL;", - reverse_sql="DROP INDEX IF EXISTS rides_ride_model_manufacturer_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_model_manufacturer_idx;", ), - # Index for designer filtering migrations.RunSQL( "CREATE INDEX rides_ride_designer_idx ON rides_ride (designer_id, category) WHERE designer_id IS NOT NULL;", - reverse_sql="DROP INDEX IF EXISTS rides_ride_designer_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_designer_idx;", ), - # Index for park area filtering migrations.RunSQL( "CREATE INDEX rides_ride_park_area_idx ON rides_ride (park_area_id, status) WHERE park_area_id IS NOT NULL;", - reverse_sql="DROP INDEX IF EXISTS rides_ride_park_area_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ride_park_area_idx;", ), - # Roller coaster stats indexes for performance migrations.RunSQL( "CREATE INDEX rides_rollercoasterstats_height_idx ON rides_rollercoasterstats (height_ft) WHERE height_ft IS NOT NULL;", - reverse_sql="DROP INDEX IF EXISTS rides_rollercoasterstats_height_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_rollercoasterstats_height_idx;", ), - migrations.RunSQL( "CREATE INDEX rides_rollercoasterstats_speed_idx ON rides_rollercoasterstats (speed_mph) WHERE speed_mph IS NOT NULL;", - reverse_sql="DROP INDEX IF EXISTS rides_rollercoasterstats_speed_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_rollercoasterstats_speed_idx;", ), - migrations.RunSQL( "CREATE INDEX rides_rollercoasterstats_inversions_idx ON rides_rollercoasterstats (inversions);", - reverse_sql="DROP INDEX IF EXISTS rides_rollercoasterstats_inversions_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_rollercoasterstats_inversions_idx;", ), - migrations.RunSQL( "CREATE INDEX rides_rollercoasterstats_type_material_idx ON rides_rollercoasterstats (roller_coaster_type, track_material);", - reverse_sql="DROP INDEX IF EXISTS rides_rollercoasterstats_type_material_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_rollercoasterstats_type_material_idx;", ), - migrations.RunSQL( "CREATE INDEX rides_rollercoasterstats_launch_type_idx ON rides_rollercoasterstats (launch_type);", - reverse_sql="DROP INDEX IF EXISTS rides_rollercoasterstats_launch_type_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_rollercoasterstats_launch_type_idx;", ), - # Composite index for complex roller coaster filtering migrations.RunSQL( "CREATE INDEX rides_rollercoasterstats_complex_idx ON rides_rollercoasterstats (roller_coaster_type, track_material, launch_type) INCLUDE (height_ft, speed_mph, inversions);", - reverse_sql="DROP INDEX IF EXISTS rides_rollercoasterstats_complex_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_rollercoasterstats_complex_idx;", ), - # Index for ride model filtering and search migrations.RunSQL( "CREATE INDEX rides_ridemodel_manufacturer_category_idx ON rides_ridemodel (manufacturer_id, category) WHERE manufacturer_id IS NOT NULL;", - reverse_sql="DROP INDEX IF EXISTS rides_ridemodel_manufacturer_category_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ridemodel_manufacturer_category_idx;", ), - migrations.RunSQL( "CREATE INDEX rides_ridemodel_name_trgm_idx ON rides_ridemodel USING gin(name gin_trgm_ops);", - reverse_sql="DROP INDEX IF EXISTS rides_ridemodel_name_trgm_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_ridemodel_name_trgm_idx;", ), - # Index for company role-based filtering migrations.RunSQL( "CREATE INDEX rides_company_manufacturer_role_idx ON rides_company USING gin(roles) WHERE 'MANUFACTURER' = ANY(roles);", - reverse_sql="DROP INDEX IF EXISTS rides_company_manufacturer_role_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_company_manufacturer_role_idx;", ), - migrations.RunSQL( "CREATE INDEX rides_company_designer_role_idx ON rides_company USING gin(roles) WHERE 'DESIGNER' = ANY(roles);", - reverse_sql="DROP INDEX IF EXISTS rides_company_designer_role_idx;" + reverse_sql="DROP INDEX IF EXISTS rides_company_designer_role_idx;", ), - # Ensure trigram extension is available for fuzzy search migrations.RunSQL( - "CREATE EXTENSION IF NOT EXISTS pg_trgm;", - reverse_sql="-- Cannot safely drop pg_trgm extension" + "CREATE EXTENSION IF NOT EXISTS pg_trgm;", reverse_sql="-- Cannot safely drop pg_trgm extension" ), ] diff --git a/backend/apps/rides/migrations/0026_convert_unique_together_to_constraints.py b/backend/apps/rides/migrations/0026_convert_unique_together_to_constraints.py index 2e6d094e..a3a1eeef 100644 --- a/backend/apps/rides/migrations/0026_convert_unique_together_to_constraints.py +++ b/backend/apps/rides/migrations/0026_convert_unique_together_to_constraints.py @@ -11,30 +11,30 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('rides', '0025_convert_ride_status_to_fsm'), + ("rides", "0025_convert_ride_status_to_fsm"), ] operations = [ # Remove the old unique_together constraint migrations.AlterUniqueTogether( - name='ridemodel', + name="ridemodel", unique_together=set(), ), # Add new UniqueConstraints with better error messages migrations.AddConstraint( - model_name='ridemodel', + model_name="ridemodel", constraint=models.UniqueConstraint( - fields=['manufacturer', 'name'], - name='ridemodel_manufacturer_name_unique', - violation_error_message='A ride model with this name already exists for this manufacturer' + fields=["manufacturer", "name"], + name="ridemodel_manufacturer_name_unique", + violation_error_message="A ride model with this name already exists for this manufacturer", ), ), migrations.AddConstraint( - model_name='ridemodel', + model_name="ridemodel", constraint=models.UniqueConstraint( - fields=['manufacturer', 'slug'], - name='ridemodel_manufacturer_slug_unique', - violation_error_message='A ride model with this slug already exists for this manufacturer' + fields=["manufacturer", "slug"], + name="ridemodel_manufacturer_slug_unique", + violation_error_message="A ride model with this slug already exists for this manufacturer", ), ), ] diff --git a/backend/apps/rides/migrations/0027_alter_company_options_alter_rankingsnapshot_options_and_more.py b/backend/apps/rides/migrations/0027_alter_company_options_alter_rankingsnapshot_options_and_more.py index 759a2659..606564e0 100644 --- a/backend/apps/rides/migrations/0027_alter_company_options_alter_rankingsnapshot_options_and_more.py +++ b/backend/apps/rides/migrations/0027_alter_company_options_alter_rankingsnapshot_options_and_more.py @@ -98,23 +98,17 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="company", name="coasters_count", - field=models.IntegerField( - default=0, help_text="Number of coasters manufactured (auto-calculated)" - ), + field=models.IntegerField(default=0, help_text="Number of coasters manufactured (auto-calculated)"), ), migrations.AlterField( model_name="company", name="description", - field=models.TextField( - blank=True, help_text="Detailed company description" - ), + field=models.TextField(blank=True, help_text="Detailed company description"), ), migrations.AlterField( model_name="company", name="founded_date", - field=models.DateField( - blank=True, help_text="Date the company was founded", null=True - ), + field=models.DateField(blank=True, help_text="Date the company was founded", null=True), ), migrations.AlterField( model_name="company", @@ -124,9 +118,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="company", name="rides_count", - field=models.IntegerField( - default=0, help_text="Number of rides manufactured (auto-calculated)" - ), + field=models.IntegerField(default=0, help_text="Number of rides manufactured (auto-calculated)"), ), migrations.AlterField( model_name="company", @@ -151,9 +143,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="company", name="slug", - field=models.SlugField( - help_text="URL-friendly identifier", max_length=255, unique=True - ), + field=models.SlugField(help_text="URL-friendly identifier", max_length=255, unique=True), ), migrations.AlterField( model_name="company", @@ -163,23 +153,17 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="companyevent", name="coasters_count", - field=models.IntegerField( - default=0, help_text="Number of coasters manufactured (auto-calculated)" - ), + field=models.IntegerField(default=0, help_text="Number of coasters manufactured (auto-calculated)"), ), migrations.AlterField( model_name="companyevent", name="description", - field=models.TextField( - blank=True, help_text="Detailed company description" - ), + field=models.TextField(blank=True, help_text="Detailed company description"), ), migrations.AlterField( model_name="companyevent", name="founded_date", - field=models.DateField( - blank=True, help_text="Date the company was founded", null=True - ), + field=models.DateField(blank=True, help_text="Date the company was founded", null=True), ), migrations.AlterField( model_name="companyevent", @@ -210,9 +194,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="companyevent", name="rides_count", - field=models.IntegerField( - default=0, help_text="Number of rides manufactured (auto-calculated)" - ), + field=models.IntegerField(default=0, help_text="Number of rides manufactured (auto-calculated)"), ), migrations.AlterField( model_name="companyevent", @@ -237,9 +219,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="companyevent", name="slug", - field=models.SlugField( - db_index=False, help_text="URL-friendly identifier", max_length=255 - ), + field=models.SlugField(db_index=False, help_text="URL-friendly identifier", max_length=255), ), migrations.AlterField( model_name="companyevent", @@ -321,23 +301,17 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="ridemodelphoto", name="caption", - field=models.CharField( - blank=True, help_text="Photo caption or description", max_length=500 - ), + field=models.CharField(blank=True, help_text="Photo caption or description", max_length=500), ), migrations.AlterField( model_name="ridemodelphoto", name="copyright_info", - field=models.CharField( - blank=True, help_text="Copyright information", max_length=255 - ), + field=models.CharField(blank=True, help_text="Copyright information", max_length=255), ), migrations.AlterField( model_name="ridemodelphoto", name="photographer", - field=models.CharField( - blank=True, help_text="Name of the photographer", max_length=255 - ), + field=models.CharField(blank=True, help_text="Name of the photographer", max_length=255), ), migrations.AlterField( model_name="ridemodelphoto", @@ -352,9 +326,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="ridemodelphoto", name="source", - field=models.CharField( - blank=True, help_text="Source of the photo", max_length=255 - ), + field=models.CharField(blank=True, help_text="Source of the photo", max_length=255), ), migrations.AlterField( model_name="ridemodelphotoevent", @@ -368,16 +340,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="ridemodelphotoevent", name="caption", - field=models.CharField( - blank=True, help_text="Photo caption or description", max_length=500 - ), + field=models.CharField(blank=True, help_text="Photo caption or description", max_length=500), ), migrations.AlterField( model_name="ridemodelphotoevent", name="copyright_info", - field=models.CharField( - blank=True, help_text="Copyright information", max_length=255 - ), + field=models.CharField(blank=True, help_text="Copyright information", max_length=255), ), migrations.AlterField( model_name="ridemodelphotoevent", @@ -403,9 +371,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="ridemodelphotoevent", name="photographer", - field=models.CharField( - blank=True, help_text="Name of the photographer", max_length=255 - ), + field=models.CharField(blank=True, help_text="Name of the photographer", max_length=255), ), migrations.AlterField( model_name="ridemodelphotoevent", @@ -422,9 +388,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="ridemodelphotoevent", name="source", - field=models.CharField( - blank=True, help_text="Source of the photo", max_length=255 - ), + field=models.CharField(blank=True, help_text="Source of the photo", max_length=255), ), migrations.AlterField( model_name="ridemodeltechnicalspec", @@ -709,9 +673,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="rollercoasterstats", name="cars_per_train", - field=models.PositiveIntegerField( - blank=True, help_text="Number of cars per train", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="Number of cars per train", null=True), ), migrations.AlterField( model_name="rollercoasterstats", @@ -727,9 +689,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="rollercoasterstats", name="inversions", - field=models.PositiveIntegerField( - default=0, help_text="Number of inversions" - ), + field=models.PositiveIntegerField(default=0, help_text="Number of inversions"), ), migrations.AlterField( model_name="rollercoasterstats", @@ -766,16 +726,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="rollercoasterstats", name="ride_time_seconds", - field=models.PositiveIntegerField( - blank=True, help_text="Duration of the ride in seconds", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="Duration of the ride in seconds", null=True), ), migrations.AlterField( model_name="rollercoasterstats", name="seats_per_car", - field=models.PositiveIntegerField( - blank=True, help_text="Number of seats per car", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="Number of seats per car", null=True), ), migrations.AlterField( model_name="rollercoasterstats", @@ -809,16 +765,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="rollercoasterstats", name="trains_count", - field=models.PositiveIntegerField( - blank=True, help_text="Number of trains", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="Number of trains", null=True), ), migrations.AlterField( model_name="rollercoasterstatsevent", name="cars_per_train", - field=models.PositiveIntegerField( - blank=True, help_text="Number of cars per train", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="Number of cars per train", null=True), ), migrations.AlterField( model_name="rollercoasterstatsevent", @@ -834,9 +786,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="rollercoasterstatsevent", name="inversions", - field=models.PositiveIntegerField( - default=0, help_text="Number of inversions" - ), + field=models.PositiveIntegerField(default=0, help_text="Number of inversions"), ), migrations.AlterField( model_name="rollercoasterstatsevent", @@ -896,16 +846,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="rollercoasterstatsevent", name="ride_time_seconds", - field=models.PositiveIntegerField( - blank=True, help_text="Duration of the ride in seconds", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="Duration of the ride in seconds", null=True), ), migrations.AlterField( model_name="rollercoasterstatsevent", name="seats_per_car", - field=models.PositiveIntegerField( - blank=True, help_text="Number of seats per car", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="Number of seats per car", null=True), ), migrations.AlterField( model_name="rollercoasterstatsevent", @@ -939,8 +885,6 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="rollercoasterstatsevent", name="trains_count", - field=models.PositiveIntegerField( - blank=True, help_text="Number of trains", null=True - ), + field=models.PositiveIntegerField(blank=True, help_text="Number of trains", null=True), ), ] diff --git a/backend/apps/rides/mixins.py b/backend/apps/rides/mixins.py index cf0901b5..a13b08cf 100644 --- a/backend/apps/rides/mixins.py +++ b/backend/apps/rides/mixins.py @@ -34,16 +34,11 @@ class RideFormMixin: Returns: Dictionary with submission results from RideService """ - result = RideService.handle_new_entity_suggestions( - form_data=form.cleaned_data, - submitter=self.request.user - ) + result = RideService.handle_new_entity_suggestions(form_data=form.cleaned_data, submitter=self.request.user) - if result['total_submissions'] > 0: + if result["total_submissions"] > 0: messages.info( - self.request, - f"Created {result['total_submissions']} moderation submission(s) " - "for new entities" + self.request, f"Created {result['total_submissions']} moderation submission(s) " "for new entities" ) return result diff --git a/backend/apps/rides/models/company.py b/backend/apps/rides/models/company.py index 26f6b153..5d29094c 100644 --- a/backend/apps/rides/models/company.py +++ b/backend/apps/rides/models/company.py @@ -24,17 +24,11 @@ class Company(TrackedModel): website = models.URLField(blank=True, help_text="Company website URL") # General company info - founded_date = models.DateField( - null=True, blank=True, help_text="Date the company was founded" - ) + founded_date = models.DateField(null=True, blank=True, help_text="Date the company was founded") # Manufacturer-specific fields - rides_count = models.IntegerField( - default=0, help_text="Number of rides manufactured (auto-calculated)" - ) - coasters_count = models.IntegerField( - default=0, help_text="Number of coasters manufactured (auto-calculated)" - ) + rides_count = models.IntegerField(default=0, help_text="Number of rides manufactured (auto-calculated)") + coasters_count = models.IntegerField(default=0, help_text="Number of coasters manufactured (auto-calculated)") # Frontend URL url = models.URLField(blank=True, help_text="Frontend URL for this company") @@ -50,9 +44,7 @@ class Company(TrackedModel): # CRITICAL: Only MANUFACTURER and DESIGNER are for rides domain # OPERATOR and PROPERTY_OWNER are for parks domain and handled separately if self.roles: - frontend_domain = getattr( - settings, "FRONTEND_DOMAIN", "https://thrillwiki.com" - ) + frontend_domain = getattr(settings, "FRONTEND_DOMAIN", "https://thrillwiki.com") primary_role = self.roles[0] # Use first role as primary if primary_role == "MANUFACTURER": @@ -76,12 +68,9 @@ class Company(TrackedModel): # Check pghistory first try: from django.apps import apps - history_model = apps.get_model('rides', f'{cls.__name__}Event') - history_entry = ( - history_model.objects.filter(slug=slug) - .order_by("-pgh_created_at") - .first() - ) + + history_model = apps.get_model("rides", f"{cls.__name__}Event") + history_entry = history_model.objects.filter(slug=slug).order_by("-pgh_created_at").first() if history_entry: return cls.objects.get(id=history_entry.pgh_obj_id), True except LookupError: @@ -90,12 +79,10 @@ class Company(TrackedModel): # Check manual slug history as fallback try: - historical = HistoricalSlug.objects.get( - content_type__model="company", slug=slug - ) + historical = HistoricalSlug.objects.get(content_type__model="company", slug=slug) return cls.objects.get(pk=historical.object_id), True except (HistoricalSlug.DoesNotExist, cls.DoesNotExist): - raise cls.DoesNotExist("No company found with this slug") + raise cls.DoesNotExist("No company found with this slug") from None class Meta(TrackedModel.Meta): app_label = "rides" diff --git a/backend/apps/rides/models/credits.py b/backend/apps/rides/models/credits.py index b3feb634..74efb723 100644 --- a/backend/apps/rides/models/credits.py +++ b/backend/apps/rides/models/credits.py @@ -27,27 +27,17 @@ class RideCredit(TrackedModel): ) # Credit Details - count = models.PositiveIntegerField( - default=1, help_text="Number of times ridden" - ) + count = models.PositiveIntegerField(default=1, help_text="Number of times ridden") rating = models.IntegerField( null=True, blank=True, validators=[MinValueValidator(1), MaxValueValidator(5)], help_text="Personal rating (1-5)", ) - first_ridden_at = models.DateField( - null=True, blank=True, help_text="Date of first ride" - ) - last_ridden_at = models.DateField( - null=True, blank=True, help_text="Date of most recent ride" - ) - notes = models.TextField( - blank=True, help_text="Personal notes about the experience" - ) - display_order = models.PositiveIntegerField( - default=0, help_text="User-defined display order for drag-drop sorting" - ) + first_ridden_at = models.DateField(null=True, blank=True, help_text="Date of first ride") + last_ridden_at = models.DateField(null=True, blank=True, help_text="Date of most recent ride") + notes = models.TextField(blank=True, help_text="Personal notes about the experience") + display_order = models.PositiveIntegerField(default=0, help_text="User-defined display order for drag-drop sorting") class Meta(TrackedModel.Meta): verbose_name = "Ride Credit" diff --git a/backend/apps/rides/models/location.py b/backend/apps/rides/models/location.py index 19b60d0a..72d0bbab 100644 --- a/backend/apps/rides/models/location.py +++ b/backend/apps/rides/models/location.py @@ -12,9 +12,7 @@ class RideLocation(models.Model): """ # Relationships - ride = models.OneToOneField( - "rides.Ride", on_delete=models.CASCADE, related_name="ride_location" - ) + ride = models.OneToOneField("rides.Ride", on_delete=models.CASCADE, related_name="ride_location") # Optional Spatial Data - keep it simple with single point point = gis_models.PointField( @@ -29,9 +27,7 @@ class RideLocation(models.Model): max_length=100, blank=True, db_index=True, - help_text=( - "Themed area or land within the park (e.g., 'Frontierland', 'Tomorrowland')" - ), + help_text=("Themed area or land within the park (e.g., 'Frontierland', 'Tomorrowland')"), ) # General notes field to match database schema diff --git a/backend/apps/rides/models/media.py b/backend/apps/rides/models/media.py index 5a603c01..6eb6fc0a 100644 --- a/backend/apps/rides/models/media.py +++ b/backend/apps/rides/models/media.py @@ -35,14 +35,12 @@ def ride_photo_upload_path(instance: models.Model, filename: str) -> str: class RidePhoto(TrackedModel): """Photo model specific to rides.""" - ride = models.ForeignKey( - "rides.Ride", on_delete=models.CASCADE, related_name="photos" - ) + ride = models.ForeignKey("rides.Ride", on_delete=models.CASCADE, related_name="photos") image = models.ForeignKey( - 'django_cloudflareimages_toolkit.CloudflareImage', + "django_cloudflareimages_toolkit.CloudflareImage", on_delete=models.CASCADE, - help_text="Ride photo stored on Cloudflare Images" + help_text="Ride photo stored on Cloudflare Images", ) caption = models.CharField(max_length=255, blank=True) @@ -56,7 +54,7 @@ class RidePhoto(TrackedModel): domain="rides", max_length=50, default="exterior", - help_text="Type of photo for categorization and display purposes" + help_text="Type of photo for categorization and display purposes", ) # Metadata @@ -100,9 +98,7 @@ class RidePhoto(TrackedModel): # Set default caption if not provided if not self.caption and self.uploaded_by: - self.caption = MediaService.generate_default_caption( - self.uploaded_by.username - ) + self.caption = MediaService.generate_default_caption(self.uploaded_by.username) # If this is marked as primary, unmark other primary photos for this ride if self.is_primary: diff --git a/backend/apps/rides/models/rankings.py b/backend/apps/rides/models/rankings.py index 0c82a99d..dc357c03 100644 --- a/backend/apps/rides/models/rankings.py +++ b/backend/apps/rides/models/rankings.py @@ -22,17 +22,12 @@ class RideRanking(models.Model): """ ride = models.OneToOneField( - "rides.Ride", on_delete=models.CASCADE, related_name="ranking", - help_text="Ride this ranking entry describes" + "rides.Ride", on_delete=models.CASCADE, related_name="ranking", help_text="Ride this ranking entry describes" ) # Core ranking metrics - rank = models.PositiveIntegerField( - db_index=True, help_text="Overall rank position (1 = best)" - ) - wins = models.PositiveIntegerField( - default=0, help_text="Number of rides this ride beats in pairwise comparisons" - ) + rank = models.PositiveIntegerField(db_index=True, help_text="Overall rank position (1 = best)") + wins = models.PositiveIntegerField(default=0, help_text="Number of rides this ride beats in pairwise comparisons") losses = models.PositiveIntegerField( default=0, help_text="Number of rides that beat this ride in pairwise comparisons", @@ -66,9 +61,7 @@ class RideRanking(models.Model): ) # Metadata - last_calculated = models.DateTimeField( - default=timezone.now, help_text="When this ranking was last calculated" - ) + last_calculated = models.DateTimeField(default=timezone.now, help_text="When this ranking was last calculated") calculation_version = models.CharField( max_length=10, default="1.0", help_text="Algorithm version used for calculation" ) @@ -85,8 +78,7 @@ class RideRanking(models.Model): constraints = [ models.CheckConstraint( name="rideranking_winning_percentage_range", - check=models.Q(winning_percentage__gte=0) - & models.Q(winning_percentage__lte=1), + check=models.Q(winning_percentage__gte=0) & models.Q(winning_percentage__lte=1), violation_error_message="Winning percentage must be between 0 and 1", ), models.CheckConstraint( @@ -115,23 +107,13 @@ class RidePairComparison(models.Model): (users who have rated both rides). It's used to speed up ranking calculations. """ - ride_a = models.ForeignKey( - "rides.Ride", on_delete=models.CASCADE, related_name="comparisons_as_a" - ) - ride_b = models.ForeignKey( - "rides.Ride", on_delete=models.CASCADE, related_name="comparisons_as_b" - ) + ride_a = models.ForeignKey("rides.Ride", on_delete=models.CASCADE, related_name="comparisons_as_a") + ride_b = models.ForeignKey("rides.Ride", on_delete=models.CASCADE, related_name="comparisons_as_b") # Comparison results - ride_a_wins = models.PositiveIntegerField( - default=0, help_text="Number of mutual riders who rated ride_a higher" - ) - ride_b_wins = models.PositiveIntegerField( - default=0, help_text="Number of mutual riders who rated ride_b higher" - ) - ties = models.PositiveIntegerField( - default=0, help_text="Number of mutual riders who rated both rides equally" - ) + ride_a_wins = models.PositiveIntegerField(default=0, help_text="Number of mutual riders who rated ride_a higher") + ride_b_wins = models.PositiveIntegerField(default=0, help_text="Number of mutual riders who rated ride_b higher") + ties = models.PositiveIntegerField(default=0, help_text="Number of mutual riders who rated both rides equally") # Metrics mutual_riders_count = models.PositiveIntegerField( @@ -153,9 +135,7 @@ class RidePairComparison(models.Model): ) # Metadata - last_calculated = models.DateTimeField( - auto_now=True, help_text="When this comparison was last calculated" - ) + last_calculated = models.DateTimeField(auto_now=True, help_text="When this comparison was last calculated") class Meta: verbose_name = "Ride Pair Comparison" @@ -197,14 +177,10 @@ class RankingSnapshot(models.Model): This allows us to show ranking trends and movements. """ - ride = models.ForeignKey( - "rides.Ride", on_delete=models.CASCADE, related_name="ranking_history" - ) + ride = models.ForeignKey("rides.Ride", on_delete=models.CASCADE, related_name="ranking_history") rank = models.PositiveIntegerField() winning_percentage = models.DecimalField(max_digits=5, decimal_places=4) - snapshot_date = models.DateField( - db_index=True, help_text="Date when this ranking snapshot was taken" - ) + snapshot_date = models.DateField(db_index=True, help_text="Date when this ranking snapshot was taken") class Meta: verbose_name = "Ranking Snapshot" diff --git a/backend/apps/rides/models/reviews.py b/backend/apps/rides/models/reviews.py index 2ba98d8b..cc886fba 100644 --- a/backend/apps/rides/models/reviews.py +++ b/backend/apps/rides/models/reviews.py @@ -12,15 +12,9 @@ class RideReview(TrackedModel): A review of a ride. """ - ride = models.ForeignKey( - "rides.Ride", on_delete=models.CASCADE, related_name="reviews" - ) - user = models.ForeignKey( - "accounts.User", on_delete=models.CASCADE, related_name="ride_reviews" - ) - rating = models.PositiveSmallIntegerField( - validators=[MinValueValidator(1), MaxValueValidator(10)] - ) + ride = models.ForeignKey("rides.Ride", on_delete=models.CASCADE, related_name="reviews") + user = models.ForeignKey("accounts.User", on_delete=models.CASCADE, related_name="ride_reviews") + rating = models.PositiveSmallIntegerField(validators=[MinValueValidator(1), MaxValueValidator(10)]) title = models.CharField(max_length=200) content = models.TextField() visit_date = models.DateField() @@ -63,10 +57,7 @@ class RideReview(TrackedModel): name="ride_review_moderation_consistency", check=models.Q(moderated_by__isnull=True, moderated_at__isnull=True) | models.Q(moderated_by__isnull=False, moderated_at__isnull=False), - violation_error_message=( - "Moderated reviews must have both moderator and moderation " - "timestamp" - ), + violation_error_message=("Moderated reviews must have both moderator and moderation " "timestamp"), ), ] diff --git a/backend/apps/rides/models/rides.py b/backend/apps/rides/models/rides.py index 3e90d75f..687071f5 100644 --- a/backend/apps/rides/models/rides.py +++ b/backend/apps/rides/models/rides.py @@ -1,5 +1,4 @@ import contextlib -from typing import TYPE_CHECKING import pghistory from django.contrib.auth.models import AbstractBaseUser @@ -14,10 +13,6 @@ from config.django import base as settings from .company import Company -if TYPE_CHECKING: - from .rides import RollerCoasterStats - - @pghistory.track() class RideModel(TrackedModel): @@ -30,9 +25,7 @@ class RideModel(TrackedModel): """ name = models.CharField(max_length=255, help_text="Name of the ride model") - slug = models.SlugField( - max_length=255, help_text="URL-friendly identifier (unique within manufacturer)" - ) + slug = models.SlugField(max_length=255, help_text="URL-friendly identifier (unique within manufacturer)") manufacturer = models.ForeignKey( Company, on_delete=models.SET_NULL, @@ -42,9 +35,7 @@ class RideModel(TrackedModel): limit_choices_to={"roles__contains": ["MANUFACTURER"]}, help_text="Primary manufacturer of this ride model", ) - description = models.TextField( - blank=True, help_text="Detailed description of the ride model" - ) + description = models.TextField(blank=True, help_text="Detailed description of the ride model") category = RichChoiceField( choice_group="categories", domain="rides", @@ -125,9 +116,7 @@ class RideModel(TrackedModel): blank=True, help_text="Year of last installation of this model (if discontinued)", ) - is_discontinued = models.BooleanField( - default=False, help_text="Whether this model is no longer being manufactured" - ) + is_discontinued = models.BooleanField(default=False, help_text="Whether this model is no longer being manufactured") total_installations = models.PositiveIntegerField( default=0, help_text="Total number of installations worldwide (auto-calculated)" ) @@ -156,9 +145,7 @@ class RideModel(TrackedModel): ) # SEO and metadata - meta_title = models.CharField( - max_length=60, blank=True, help_text="SEO meta title (auto-generated if blank)" - ) + meta_title = models.CharField(max_length=60, blank=True, help_text="SEO meta title (auto-generated if blank)") meta_description = models.CharField( max_length=160, blank=True, @@ -175,25 +162,21 @@ class RideModel(TrackedModel): constraints = [ # Unique constraints (replacing unique_together for better error messages) models.UniqueConstraint( - fields=['manufacturer', 'name'], - name='ridemodel_manufacturer_name_unique', - violation_error_message='A ride model with this name already exists for this manufacturer' + fields=["manufacturer", "name"], + name="ridemodel_manufacturer_name_unique", + violation_error_message="A ride model with this name already exists for this manufacturer", ), models.UniqueConstraint( - fields=['manufacturer', 'slug'], - name='ridemodel_manufacturer_slug_unique', - violation_error_message='A ride model with this slug already exists for this manufacturer' + fields=["manufacturer", "slug"], + name="ridemodel_manufacturer_slug_unique", + violation_error_message="A ride model with this slug already exists for this manufacturer", ), # Height range validation models.CheckConstraint( name="ride_model_height_range_logical", condition=models.Q(typical_height_range_min_ft__isnull=True) | models.Q(typical_height_range_max_ft__isnull=True) - | models.Q( - typical_height_range_min_ft__lte=models.F( - "typical_height_range_max_ft" - ) - ), + | models.Q(typical_height_range_min_ft__lte=models.F("typical_height_range_max_ft")), violation_error_message="Minimum height cannot exceed maximum height", ), # Speed range validation @@ -201,11 +184,7 @@ class RideModel(TrackedModel): name="ride_model_speed_range_logical", condition=models.Q(typical_speed_range_min_mph__isnull=True) | models.Q(typical_speed_range_max_mph__isnull=True) - | models.Q( - typical_speed_range_min_mph__lte=models.F( - "typical_speed_range_max_mph" - ) - ), + | models.Q(typical_speed_range_min_mph__lte=models.F("typical_speed_range_max_mph")), violation_error_message="Minimum speed cannot exceed maximum speed", ), # Capacity range validation @@ -213,11 +192,7 @@ class RideModel(TrackedModel): name="ride_model_capacity_range_logical", condition=models.Q(typical_capacity_range_min__isnull=True) | models.Q(typical_capacity_range_max__isnull=True) - | models.Q( - typical_capacity_range_min__lte=models.F( - "typical_capacity_range_max" - ) - ), + | models.Q(typical_capacity_range_min__lte=models.F("typical_capacity_range_max")), violation_error_message="Minimum capacity cannot exceed maximum capacity", ), # Installation years validation @@ -225,27 +200,19 @@ class RideModel(TrackedModel): name="ride_model_installation_years_logical", condition=models.Q(first_installation_year__isnull=True) | models.Q(last_installation_year__isnull=True) - | models.Q( - first_installation_year__lte=models.F("last_installation_year") - ), + | models.Q(first_installation_year__lte=models.F("last_installation_year")), violation_error_message="First installation year cannot be after last installation year", ), ] def __str__(self) -> str: - return ( - self.name - if not self.manufacturer - else f"{self.manufacturer.name} {self.name}" - ) + return self.name if not self.manufacturer else f"{self.manufacturer.name} {self.name}" def clean(self) -> None: """Validate RideModel business rules.""" super().clean() if self.is_discontinued and not self.last_installation_year: - raise ValidationError({ - 'last_installation_year': 'Discontinued models must have a last installation year' - }) + raise ValidationError({"last_installation_year": "Discontinued models must have a last installation year"}) def save(self, *args, **kwargs) -> None: if not self.slug: @@ -257,11 +224,7 @@ class RideModel(TrackedModel): # Ensure uniqueness within the same manufacturer counter = 1 - while ( - RideModel.objects.filter(manufacturer=self.manufacturer, slug=self.slug) - .exclude(pk=self.pk) - .exists() - ): + while RideModel.objects.filter(manufacturer=self.manufacturer, slug=self.slug).exclude(pk=self.pk).exists(): self.slug = f"{base_slug}-{counter}" counter += 1 @@ -269,16 +232,12 @@ class RideModel(TrackedModel): if not self.meta_title: self.meta_title = str(self)[:60] if not self.meta_description: - desc = ( - f"{self} - {self.description[:100]}" if self.description else str(self) - ) + desc = f"{self} - {self.description[:100]}" if self.description else str(self) self.meta_description = desc[:160] # Generate frontend URL if self.manufacturer: - frontend_domain = getattr( - settings, "FRONTEND_DOMAIN", "https://thrillwiki.com" - ) + frontend_domain = getattr(settings, "FRONTEND_DOMAIN", "https://thrillwiki.com") self.url = f"{frontend_domain}/rides/manufacturers/{self.manufacturer.slug}/{self.slug}/" super().save(*args, **kwargs) @@ -342,9 +301,7 @@ class RideModelVariant(TrackedModel): help_text="Base ride model this variant belongs to", ) name = models.CharField(max_length=255, help_text="Name of this variant") - description = models.TextField( - blank=True, help_text="Description of variant differences" - ) + description = models.TextField(blank=True, help_text="Description of variant differences") # Variant-specific specifications min_height_ft = models.DecimalField( @@ -402,16 +359,12 @@ class RideModelPhoto(TrackedModel): help_text="Ride model this photo belongs to", ) image = models.ForeignKey( - 'django_cloudflareimages_toolkit.CloudflareImage', + "django_cloudflareimages_toolkit.CloudflareImage", on_delete=models.CASCADE, - help_text="Photo of the ride model stored on Cloudflare Images" - ) - caption = models.CharField( - max_length=500, blank=True, help_text="Photo caption or description" - ) - alt_text = models.CharField( - max_length=255, blank=True, help_text="Alternative text for accessibility" + help_text="Photo of the ride model stored on Cloudflare Images", ) + caption = models.CharField(max_length=500, blank=True, help_text="Photo caption or description") + alt_text = models.CharField(max_length=255, blank=True, help_text="Alternative text for accessibility") # Photo metadata photo_type = RichChoiceField( @@ -422,18 +375,12 @@ class RideModelPhoto(TrackedModel): help_text="Type of photo for categorization and display purposes", ) - is_primary = models.BooleanField( - default=False, help_text="Whether this is the primary photo for the ride model" - ) + is_primary = models.BooleanField(default=False, help_text="Whether this is the primary photo for the ride model") # Attribution - photographer = models.CharField( - max_length=255, blank=True, help_text="Name of the photographer" - ) + photographer = models.CharField(max_length=255, blank=True, help_text="Name of the photographer") source = models.CharField(max_length=255, blank=True, help_text="Source of the photo") - copyright_info = models.CharField( - max_length=255, blank=True, help_text="Copyright information" - ) + copyright_info = models.CharField(max_length=255, blank=True, help_text="Copyright information") class Meta(TrackedModel.Meta): verbose_name = "Ride Model Photo" @@ -446,9 +393,9 @@ class RideModelPhoto(TrackedModel): def save(self, *args, **kwargs) -> None: # Ensure only one primary photo per ride model if self.is_primary: - RideModelPhoto.objects.filter( - ride_model=self.ride_model, is_primary=True - ).exclude(pk=self.pk).update(is_primary=False) + RideModelPhoto.objects.filter(ride_model=self.ride_model, is_primary=True).exclude(pk=self.pk).update( + is_primary=False + ) super().save(*args, **kwargs) @@ -474,15 +421,9 @@ class RideModelTechnicalSpec(TrackedModel): ) spec_name = models.CharField(max_length=100, help_text="Name of the specification") - spec_value = models.CharField( - max_length=255, help_text="Value of the specification" - ) - spec_unit = models.CharField( - max_length=20, blank=True, help_text="Unit of measurement" - ) - notes = models.TextField( - blank=True, help_text="Additional notes about this specification" - ) + spec_value = models.CharField(max_length=255, help_text="Value of the specification") + spec_unit = models.CharField(max_length=20, blank=True, help_text="Unit of measurement") + notes = models.TextField(blank=True, help_text="Additional notes about this specification") class Meta(TrackedModel.Meta): verbose_name = "Ride Model Technical Specification" @@ -503,17 +444,15 @@ class Ride(StateMachineMixin, TrackedModel): jobs. Use selectors or annotations for real-time calculations if needed. """ - if TYPE_CHECKING: - coaster_stats: 'RollerCoasterStats' + # Type hint for the reverse relation from RollerCoasterStats + coaster_stats: "RollerCoasterStats" state_field_name = "status" name = models.CharField(max_length=255) slug = models.SlugField(max_length=255) description = models.TextField(blank=True) - park = models.ForeignKey( - "parks.Park", on_delete=models.CASCADE, related_name="rides" - ) + park = models.ForeignKey("parks.Park", on_delete=models.CASCADE, related_name="rides") park_area = models.ForeignKey( "parks.ParkArea", on_delete=models.SET_NULL, @@ -527,7 +466,7 @@ class Ride(StateMachineMixin, TrackedModel): max_length=2, default="", blank=True, - help_text="Ride category classification" + help_text="Ride category classification", ) manufacturer = models.ForeignKey( Company, @@ -558,7 +497,7 @@ class Ride(StateMachineMixin, TrackedModel): domain="rides", max_length=20, default="OPERATING", - help_text="Current operational status of the ride" + help_text="Current operational status of the ride", ) post_closing_status = RichChoiceField( choice_group="post_closing_statuses", @@ -575,9 +514,7 @@ class Ride(StateMachineMixin, TrackedModel): max_height_in = models.PositiveIntegerField(null=True, blank=True) capacity_per_hour = models.PositiveIntegerField(null=True, blank=True) ride_duration_seconds = models.PositiveIntegerField(null=True, blank=True) - average_rating = models.DecimalField( - max_digits=3, decimal_places=2, null=True, blank=True - ) + average_rating = models.DecimalField(max_digits=3, decimal_places=2, null=True, blank=True) # Computed fields for hybrid filtering opening_year = models.IntegerField(null=True, blank=True, db_index=True) @@ -603,9 +540,7 @@ class Ride(StateMachineMixin, TrackedModel): # Frontend URL url = models.URLField(blank=True, help_text="Frontend URL for this ride") - park_url = models.URLField( - blank=True, help_text="Frontend URL for this ride's park" - ) + park_url = models.URLField(blank=True, help_text="Frontend URL for this ride's park") class Meta(TrackedModel.Meta): verbose_name = "Ride" @@ -635,17 +570,13 @@ class Ride(StateMachineMixin, TrackedModel): name="ride_min_height_reasonable", condition=models.Q(min_height_in__isnull=True) | (models.Q(min_height_in__gte=30) & models.Q(min_height_in__lte=90)), - violation_error_message=( - "Minimum height must be between 30 and 90 inches" - ), + violation_error_message=("Minimum height must be between 30 and 90 inches"), ), models.CheckConstraint( name="ride_max_height_reasonable", condition=models.Q(max_height_in__isnull=True) | (models.Q(max_height_in__gte=30) & models.Q(max_height_in__lte=90)), - violation_error_message=( - "Maximum height must be between 30 and 90 inches" - ), + violation_error_message=("Maximum height must be between 30 and 90 inches"), ), # Business rule: Rating must be between 1 and 10 models.CheckConstraint( @@ -657,14 +588,12 @@ class Ride(StateMachineMixin, TrackedModel): # Business rule: Capacity and duration must be positive models.CheckConstraint( name="ride_capacity_positive", - condition=models.Q(capacity_per_hour__isnull=True) - | models.Q(capacity_per_hour__gt=0), + condition=models.Q(capacity_per_hour__isnull=True) | models.Q(capacity_per_hour__gt=0), violation_error_message="Hourly capacity must be positive", ), models.CheckConstraint( name="ride_duration_positive", - condition=models.Q(ride_duration_seconds__isnull=True) - | models.Q(ride_duration_seconds__gt=0), + condition=models.Q(ride_duration_seconds__isnull=True) | models.Q(ride_duration_seconds__gt=0), violation_error_message="Ride duration must be positive", ), ] @@ -699,9 +628,7 @@ class Ride(StateMachineMixin, TrackedModel): from django.core.exceptions import ValidationError if not post_closing_status: - raise ValidationError( - "post_closing_status must be set when entering CLOSING status" - ) + raise ValidationError("post_closing_status must be set when entering CLOSING status") self.transition_to_closing(user=user) self.closing_date = closing_date self.post_closing_status = post_closing_status @@ -770,7 +697,7 @@ class Ride(StateMachineMixin, TrackedModel): self._ensure_unique_slug_in_park() # Handle park area validation when park changes - if park_changed and self.park_area: + if park_changed and self.park_area: # noqa: SIM102 # Check if park_area belongs to the new park if self.park_area.park.id != self.park.id: # Clear park_area if it doesn't belong to the new park @@ -786,9 +713,7 @@ class Ride(StateMachineMixin, TrackedModel): # Generate frontend URLs if self.park: - frontend_domain = getattr( - settings, "FRONTEND_DOMAIN", "https://thrillwiki.com" - ) + frontend_domain = getattr(settings, "FRONTEND_DOMAIN", "https://thrillwiki.com") self.url = f"{frontend_domain}/parks/{self.park.slug}/rides/{self.slug}/" self.park_url = f"{frontend_domain}/parks/{self.park.slug}/" @@ -817,7 +742,7 @@ class Ride(StateMachineMixin, TrackedModel): # Park info if self.park: search_parts.append(self.park.name) - if hasattr(self.park, 'location') and self.park.location: + if hasattr(self.park, "location") and self.park.location: if self.park.location.city: search_parts.append(self.park.location.city) if self.park.location.state: @@ -855,7 +780,7 @@ class Ride(StateMachineMixin, TrackedModel): # Roller coaster stats if available try: - if hasattr(self, 'coaster_stats') and self.coaster_stats: + if hasattr(self, "coaster_stats") and self.coaster_stats: stats = self.coaster_stats if stats.track_type: search_parts.append(stats.track_type) @@ -877,7 +802,7 @@ class Ride(StateMachineMixin, TrackedModel): # Ignore if coaster_stats doesn't exist or has issues pass - self.search_text = ' '.join(filter(None, search_parts)).lower() + self.search_text = " ".join(filter(None, search_parts)).lower() def _ensure_unique_slug_in_park(self) -> None: """Ensure the ride's slug is unique within its park.""" @@ -885,11 +810,7 @@ class Ride(StateMachineMixin, TrackedModel): self.slug = base_slug counter = 1 - while ( - Ride.objects.filter(park=self.park, slug=self.slug) - .exclude(pk=self.pk) - .exists() - ): + while Ride.objects.filter(park=self.park, slug=self.slug).exclude(pk=self.pk).exists(): self.slug = f"{base_slug}-{counter}" counter += 1 @@ -921,26 +842,15 @@ class Ride(StateMachineMixin, TrackedModel): # Return summary of changes changes = { - 'old_park': { - 'id': old_park.id, - 'name': old_park.name, - 'slug': old_park.slug - }, - 'new_park': { - 'id': new_park.id, - 'name': new_park.name, - 'slug': new_park.slug - }, - 'url_changed': old_url != self.url, - 'old_url': old_url, - 'new_url': self.url, - 'park_area_cleared': clear_park_area and old_park_area is not None, - 'old_park_area': { - 'id': old_park_area.id, - 'name': old_park_area.name - } if old_park_area else None, - 'slug_changed': self.slug != slugify(self.name), - 'final_slug': self.slug + "old_park": {"id": old_park.id, "name": old_park.name, "slug": old_park.slug}, + "new_park": {"id": new_park.id, "name": new_park.name, "slug": new_park.slug}, + "url_changed": old_url != self.url, + "old_url": old_url, + "new_url": self.url, + "park_area_cleared": clear_park_area and old_park_area is not None, + "old_park_area": {"id": old_park_area.id, "name": old_park_area.name} if old_park_area else None, + "slug_changed": self.slug != slugify(self.name), + "final_slug": self.slug, } return changes @@ -963,9 +873,9 @@ class Ride(StateMachineMixin, TrackedModel): except cls.DoesNotExist: # Try historical slugs in HistoricalSlug model content_type = ContentType.objects.get_for_model(cls) - historical_query = HistoricalSlug.objects.filter( - content_type=content_type, slug=slug - ).order_by("-created_at") + historical_query = HistoricalSlug.objects.filter(content_type=content_type, slug=slug).order_by( + "-created_at" + ) for historical in historical_query: try: @@ -986,14 +896,13 @@ class Ride(StateMachineMixin, TrackedModel): except cls.DoesNotExist: continue - raise cls.DoesNotExist("No ride found with this slug") + raise cls.DoesNotExist("No ride found with this slug") from None @pghistory.track() class RollerCoasterStats(models.Model): """Model for tracking roller coaster specific statistics""" - ride = models.OneToOneField( Ride, on_delete=models.CASCADE, @@ -1021,22 +930,16 @@ class RollerCoasterStats(models.Model): blank=True, help_text="Maximum speed in mph", ) - inversions = models.PositiveIntegerField( - default=0, help_text="Number of inversions" - ) - ride_time_seconds = models.PositiveIntegerField( - null=True, blank=True, help_text="Duration of the ride in seconds" - ) - track_type = models.CharField( - max_length=255, blank=True, help_text="Type of track (e.g., tubular steel, wooden)" - ) + inversions = models.PositiveIntegerField(default=0, help_text="Number of inversions") + ride_time_seconds = models.PositiveIntegerField(null=True, blank=True, help_text="Duration of the ride in seconds") + track_type = models.CharField(max_length=255, blank=True, help_text="Type of track (e.g., tubular steel, wooden)") track_material = RichChoiceField( choice_group="track_materials", domain="rides", max_length=20, default="STEEL", blank=True, - help_text="Track construction material type" + help_text="Track construction material type", ) roller_coaster_type = RichChoiceField( choice_group="coaster_types", @@ -1044,7 +947,7 @@ class RollerCoasterStats(models.Model): max_length=20, default="SITDOWN", blank=True, - help_text="Roller coaster type classification" + help_text="Roller coaster type classification", ) max_drop_height_ft = models.DecimalField( max_digits=6, @@ -1058,20 +961,12 @@ class RollerCoasterStats(models.Model): domain="rides", max_length=20, default="CHAIN", - help_text="Propulsion or lift system type" - ) - train_style = models.CharField( - max_length=255, blank=True, help_text="Style of train (e.g., floorless, inverted)" - ) - trains_count = models.PositiveIntegerField( - null=True, blank=True, help_text="Number of trains" - ) - cars_per_train = models.PositiveIntegerField( - null=True, blank=True, help_text="Number of cars per train" - ) - seats_per_car = models.PositiveIntegerField( - null=True, blank=True, help_text="Number of seats per car" + help_text="Propulsion or lift system type", ) + train_style = models.CharField(max_length=255, blank=True, help_text="Style of train (e.g., floorless, inverted)") + trains_count = models.PositiveIntegerField(null=True, blank=True, help_text="Number of trains") + cars_per_train = models.PositiveIntegerField(null=True, blank=True, help_text="Number of cars per train") + seats_per_car = models.PositiveIntegerField(null=True, blank=True, help_text="Number of seats per car") class Meta: verbose_name = "Roller Coaster Statistics" diff --git a/backend/apps/rides/selectors.py b/backend/apps/rides/selectors.py index 1e778a31..e40a9164 100644 --- a/backend/apps/rides/selectors.py +++ b/backend/apps/rides/selectors.py @@ -13,9 +13,7 @@ from .choices import RIDE_CATEGORIES from .models import Ride, RideModel, RideReview -def ride_list_for_display( - *, filters: dict[str, Any] | None = None -) -> QuerySet[Ride]: +def ride_list_for_display(*, filters: dict[str, Any] | None = None) -> QuerySet[Ride]: """ Get rides optimized for list display with related data. @@ -85,9 +83,7 @@ def ride_detail_optimized(*, slug: str, park_slug: str) -> Ride: "park__location", Prefetch( "reviews", - queryset=RideReview.objects.select_related("user").filter( - is_published=True - ), + queryset=RideReview.objects.select_related("user").filter(is_published=True), ), "photos", ) @@ -171,9 +167,7 @@ def rides_in_park(*, park_slug: str) -> QuerySet[Ride]: ) -def rides_near_location( - *, point: Point, distance_km: float = 50, limit: int = 10 -) -> QuerySet[Ride]: +def rides_near_location(*, point: Point, distance_km: float = 50, limit: int = 10) -> QuerySet[Ride]: """ Get rides near a specific geographic location. @@ -227,9 +221,7 @@ def ride_search_autocomplete(*, query: str, limit: int = 10) -> QuerySet[Ride]: """ return ( Ride.objects.filter( - Q(name__icontains=query) - | Q(park__name__icontains=query) - | Q(manufacturer__name__icontains=query) + Q(name__icontains=query) | Q(park__name__icontains=query) | Q(manufacturer__name__icontains=query) ) .select_related("park", "manufacturer") .prefetch_related("park__location") @@ -254,16 +246,10 @@ def rides_with_recent_reviews(*, days: int = 30) -> QuerySet[Ride]: cutoff_date = timezone.now() - timedelta(days=days) return ( - Ride.objects.filter( - reviews__created_at__gte=cutoff_date, reviews__is_published=True - ) + Ride.objects.filter(reviews__created_at__gte=cutoff_date, reviews__is_published=True) .select_related("park", "manufacturer") .prefetch_related("park__location") - .annotate( - recent_review_count=Count( - "reviews", filter=Q(reviews__created_at__gte=cutoff_date) - ) - ) + .annotate(recent_review_count=Count("reviews", filter=Q(reviews__created_at__gte=cutoff_date))) .order_by("-recent_review_count") .distinct() ) diff --git a/backend/apps/rides/services/__init__.py b/backend/apps/rides/services/__init__.py index fcae25c3..2bc4fe19 100644 --- a/backend/apps/rides/services/__init__.py +++ b/backend/apps/rides/services/__init__.py @@ -4,4 +4,3 @@ from .location_service import RideLocationService from .media_service import RideMediaService __all__ = ["RideLocationService", "RideMediaService", "RideService"] - diff --git a/backend/apps/rides/services/hybrid_loader.py b/backend/apps/rides/services/hybrid_loader.py index 86147ba5..4f1f6de7 100644 --- a/backend/apps/rides/services/hybrid_loader.py +++ b/backend/apps/rides/services/hybrid_loader.py @@ -95,13 +95,13 @@ class SmartRideLoader: total_count = queryset.count() # Get progressive batch - rides = list(queryset[offset:offset + self.PROGRESSIVE_LOAD_SIZE]) + rides = list(queryset[offset : offset + self.PROGRESSIVE_LOAD_SIZE]) return { - 'rides': self._serialize_rides(rides), - 'total_count': total_count, - 'has_more': len(rides) == self.PROGRESSIVE_LOAD_SIZE, - 'next_offset': offset + len(rides) if len(rides) == self.PROGRESSIVE_LOAD_SIZE else None + "rides": self._serialize_rides(rides), + "total_count": total_count, + "has_more": len(rides) == self.PROGRESSIVE_LOAD_SIZE, + "next_offset": offset + len(rides) if len(rides) == self.PROGRESSIVE_LOAD_SIZE else None, } def get_filter_metadata(self, filters: dict[str, Any] | None = None) -> dict[str, Any]: @@ -148,8 +148,7 @@ class SmartRideLoader: return count - def _get_client_side_data(self, filters: dict[str, Any] | None, - total_count: int) -> dict[str, Any]: + def _get_client_side_data(self, filters: dict[str, Any] | None, total_count: int) -> dict[str, Any]: """Get all data for client-side filtering.""" cache_key = f"{self.cache_prefix}client_side_all" cached_data = cache.get(cache_key) @@ -158,45 +157,46 @@ class SmartRideLoader: from apps.rides.models import Ride # Load all rides with optimized query - queryset = Ride.objects.select_related( - 'park', - 'park__location', - 'park_area', - 'manufacturer', - 'designer', - 'ride_model', - 'ride_model__manufacturer' - ).prefetch_related( - 'coaster_stats' - ).order_by('name') + queryset = ( + Ride.objects.select_related( + "park", + "park__location", + "park_area", + "manufacturer", + "designer", + "ride_model", + "ride_model__manufacturer", + ) + .prefetch_related("coaster_stats") + .order_by("name") + ) rides = list(queryset) cached_data = self._serialize_rides(rides) cache.set(cache_key, cached_data, self.CACHE_TIMEOUT) return { - 'strategy': 'client_side', - 'rides': cached_data, - 'total_count': total_count, - 'has_more': False, - 'filter_metadata': self.get_filter_metadata(filters) + "strategy": "client_side", + "rides": cached_data, + "total_count": total_count, + "has_more": False, + "filter_metadata": self.get_filter_metadata(filters), } - def _get_server_side_data(self, filters: dict[str, Any] | None, - total_count: int) -> dict[str, Any]: + def _get_server_side_data(self, filters: dict[str, Any] | None, total_count: int) -> dict[str, Any]: """Get initial batch for server-side filtering.""" # Build filtered queryset queryset = self._build_filtered_queryset(filters) # Get initial batch - rides = list(queryset[:self.INITIAL_LOAD_SIZE]) + rides = list(queryset[: self.INITIAL_LOAD_SIZE]) return { - 'strategy': 'server_side', - 'rides': self._serialize_rides(rides), - 'total_count': total_count, - 'has_more': len(rides) == self.INITIAL_LOAD_SIZE, - 'next_offset': len(rides) if len(rides) == self.INITIAL_LOAD_SIZE else None + "strategy": "server_side", + "rides": self._serialize_rides(rides), + "total_count": total_count, + "has_more": len(rides) == self.INITIAL_LOAD_SIZE, + "next_offset": len(rides) if len(rides) == self.INITIAL_LOAD_SIZE else None, } def _build_filtered_queryset(self, filters: dict[str, Any] | None): @@ -205,118 +205,110 @@ class SmartRideLoader: # Start with optimized base queryset queryset = Ride.objects.select_related( - 'park', - 'park__location', - 'park_area', - 'manufacturer', - 'designer', - 'ride_model', - 'ride_model__manufacturer' - ).prefetch_related( - 'coaster_stats' - ) + "park", "park__location", "park_area", "manufacturer", "designer", "ride_model", "ride_model__manufacturer" + ).prefetch_related("coaster_stats") if not filters: - return queryset.order_by('name') + return queryset.order_by("name") # Apply filters q_objects = Q() # Text search using computed search_text field - if 'search' in filters and filters['search']: - search_term = filters['search'].lower() + if "search" in filters and filters["search"]: + search_term = filters["search"].lower() q_objects &= Q(search_text__icontains=search_term) # Park filters - if 'park_slug' in filters and filters['park_slug']: - q_objects &= Q(park__slug=filters['park_slug']) + if "park_slug" in filters and filters["park_slug"]: + q_objects &= Q(park__slug=filters["park_slug"]) - if 'park_id' in filters and filters['park_id']: - q_objects &= Q(park_id=filters['park_id']) + if "park_id" in filters and filters["park_id"]: + q_objects &= Q(park_id=filters["park_id"]) # Category filters - if 'category' in filters and filters['category']: - q_objects &= Q(category__in=filters['category']) + if "category" in filters and filters["category"]: + q_objects &= Q(category__in=filters["category"]) # Status filters - if 'status' in filters and filters['status']: - q_objects &= Q(status__in=filters['status']) + if "status" in filters and filters["status"]: + q_objects &= Q(status__in=filters["status"]) # Company filters - if 'manufacturer_ids' in filters and filters['manufacturer_ids']: - q_objects &= Q(manufacturer_id__in=filters['manufacturer_ids']) + if "manufacturer_ids" in filters and filters["manufacturer_ids"]: + q_objects &= Q(manufacturer_id__in=filters["manufacturer_ids"]) - if 'designer_ids' in filters and filters['designer_ids']: - q_objects &= Q(designer_id__in=filters['designer_ids']) + if "designer_ids" in filters and filters["designer_ids"]: + q_objects &= Q(designer_id__in=filters["designer_ids"]) # Ride model filters - if 'ride_model_ids' in filters and filters['ride_model_ids']: - q_objects &= Q(ride_model_id__in=filters['ride_model_ids']) + if "ride_model_ids" in filters and filters["ride_model_ids"]: + q_objects &= Q(ride_model_id__in=filters["ride_model_ids"]) # Opening year filters using computed opening_year field - if 'opening_year' in filters and filters['opening_year']: - q_objects &= Q(opening_year=filters['opening_year']) + if "opening_year" in filters and filters["opening_year"]: + q_objects &= Q(opening_year=filters["opening_year"]) - if 'min_opening_year' in filters and filters['min_opening_year']: - q_objects &= Q(opening_year__gte=filters['min_opening_year']) + if "min_opening_year" in filters and filters["min_opening_year"]: + q_objects &= Q(opening_year__gte=filters["min_opening_year"]) - if 'max_opening_year' in filters and filters['max_opening_year']: - q_objects &= Q(opening_year__lte=filters['max_opening_year']) + if "max_opening_year" in filters and filters["max_opening_year"]: + q_objects &= Q(opening_year__lte=filters["max_opening_year"]) # Rating filters - if 'min_rating' in filters and filters['min_rating']: - q_objects &= Q(average_rating__gte=filters['min_rating']) + if "min_rating" in filters and filters["min_rating"]: + q_objects &= Q(average_rating__gte=filters["min_rating"]) - if 'max_rating' in filters and filters['max_rating']: - q_objects &= Q(average_rating__lte=filters['max_rating']) + if "max_rating" in filters and filters["max_rating"]: + q_objects &= Q(average_rating__lte=filters["max_rating"]) # Height requirement filters - if 'min_height_requirement' in filters and filters['min_height_requirement']: - q_objects &= Q(min_height_in__gte=filters['min_height_requirement']) + if "min_height_requirement" in filters and filters["min_height_requirement"]: + q_objects &= Q(min_height_in__gte=filters["min_height_requirement"]) - if 'max_height_requirement' in filters and filters['max_height_requirement']: - q_objects &= Q(max_height_in__lte=filters['max_height_requirement']) + if "max_height_requirement" in filters and filters["max_height_requirement"]: + q_objects &= Q(max_height_in__lte=filters["max_height_requirement"]) # Capacity filters - if 'min_capacity' in filters and filters['min_capacity']: - q_objects &= Q(capacity_per_hour__gte=filters['min_capacity']) + if "min_capacity" in filters and filters["min_capacity"]: + q_objects &= Q(capacity_per_hour__gte=filters["min_capacity"]) - if 'max_capacity' in filters and filters['max_capacity']: - q_objects &= Q(capacity_per_hour__lte=filters['max_capacity']) + if "max_capacity" in filters and filters["max_capacity"]: + q_objects &= Q(capacity_per_hour__lte=filters["max_capacity"]) # Roller coaster specific filters - if 'roller_coaster_type' in filters and filters['roller_coaster_type']: - q_objects &= Q(coaster_stats__roller_coaster_type__in=filters['roller_coaster_type']) + if "roller_coaster_type" in filters and filters["roller_coaster_type"]: + q_objects &= Q(coaster_stats__roller_coaster_type__in=filters["roller_coaster_type"]) - if 'track_material' in filters and filters['track_material']: - q_objects &= Q(coaster_stats__track_material__in=filters['track_material']) + if "track_material" in filters and filters["track_material"]: + q_objects &= Q(coaster_stats__track_material__in=filters["track_material"]) - if 'propulsion_system' in filters and filters['propulsion_system']: - q_objects &= Q(coaster_stats__propulsion_system__in=filters['propulsion_system']) + if "propulsion_system" in filters and filters["propulsion_system"]: + q_objects &= Q(coaster_stats__propulsion_system__in=filters["propulsion_system"]) # Roller coaster height filters - if 'min_height_ft' in filters and filters['min_height_ft']: - q_objects &= Q(coaster_stats__height_ft__gte=filters['min_height_ft']) + if "min_height_ft" in filters and filters["min_height_ft"]: + q_objects &= Q(coaster_stats__height_ft__gte=filters["min_height_ft"]) - if 'max_height_ft' in filters and filters['max_height_ft']: - q_objects &= Q(coaster_stats__height_ft__lte=filters['max_height_ft']) + if "max_height_ft" in filters and filters["max_height_ft"]: + q_objects &= Q(coaster_stats__height_ft__lte=filters["max_height_ft"]) # Roller coaster speed filters - if 'min_speed_mph' in filters and filters['min_speed_mph']: - q_objects &= Q(coaster_stats__speed_mph__gte=filters['min_speed_mph']) + if "min_speed_mph" in filters and filters["min_speed_mph"]: + q_objects &= Q(coaster_stats__speed_mph__gte=filters["min_speed_mph"]) - if 'max_speed_mph' in filters and filters['max_speed_mph']: - q_objects &= Q(coaster_stats__speed_mph__lte=filters['max_speed_mph']) + if "max_speed_mph" in filters and filters["max_speed_mph"]: + q_objects &= Q(coaster_stats__speed_mph__lte=filters["max_speed_mph"]) # Inversion filters - if 'min_inversions' in filters and filters['min_inversions']: - q_objects &= Q(coaster_stats__inversions__gte=filters['min_inversions']) + if "min_inversions" in filters and filters["min_inversions"]: + q_objects &= Q(coaster_stats__inversions__gte=filters["min_inversions"]) - if 'max_inversions' in filters and filters['max_inversions']: - q_objects &= Q(coaster_stats__inversions__lte=filters['max_inversions']) + if "max_inversions" in filters and filters["max_inversions"]: + q_objects &= Q(coaster_stats__inversions__lte=filters["max_inversions"]) - if 'has_inversions' in filters and filters['has_inversions'] is not None: - if filters['has_inversions']: + if "has_inversions" in filters and filters["has_inversions"] is not None: + if filters["has_inversions"]: q_objects &= Q(coaster_stats__inversions__gt=0) else: q_objects &= Q(coaster_stats__inversions=0) @@ -325,10 +317,12 @@ class SmartRideLoader: queryset = queryset.filter(q_objects) # Apply ordering - ordering = filters.get('ordering', 'name') - if ordering in ['height_ft', '-height_ft', 'speed_mph', '-speed_mph']: + ordering = filters.get("ordering", "name") + if ordering in ["height_ft", "-height_ft", "speed_mph", "-speed_mph"]: # For coaster stats ordering, we need to join and order by the stats - ordering_field = ordering.replace('height_ft', 'coaster_stats__height_ft').replace('speed_mph', 'coaster_stats__speed_mph') + ordering_field = ordering.replace("height_ft", "coaster_stats__height_ft").replace( + "speed_mph", "coaster_stats__speed_mph" + ) queryset = queryset.order_by(ordering_field) else: queryset = queryset.order_by(ordering) @@ -342,99 +336,99 @@ class SmartRideLoader: for ride in rides: # Basic ride data ride_data = { - 'id': ride.id, - 'name': ride.name, - 'slug': ride.slug, - 'description': ride.description, - 'category': ride.category, - 'status': ride.status, - 'opening_date': ride.opening_date.isoformat() if ride.opening_date else None, - 'closing_date': ride.closing_date.isoformat() if ride.closing_date else None, - 'opening_year': ride.opening_year, - 'min_height_in': ride.min_height_in, - 'max_height_in': ride.max_height_in, - 'capacity_per_hour': ride.capacity_per_hour, - 'ride_duration_seconds': ride.ride_duration_seconds, - 'average_rating': float(ride.average_rating) if ride.average_rating else None, - 'url': ride.url, - 'park_url': ride.park_url, - 'created_at': ride.created_at.isoformat(), - 'updated_at': ride.updated_at.isoformat(), + "id": ride.id, + "name": ride.name, + "slug": ride.slug, + "description": ride.description, + "category": ride.category, + "status": ride.status, + "opening_date": ride.opening_date.isoformat() if ride.opening_date else None, + "closing_date": ride.closing_date.isoformat() if ride.closing_date else None, + "opening_year": ride.opening_year, + "min_height_in": ride.min_height_in, + "max_height_in": ride.max_height_in, + "capacity_per_hour": ride.capacity_per_hour, + "ride_duration_seconds": ride.ride_duration_seconds, + "average_rating": float(ride.average_rating) if ride.average_rating else None, + "url": ride.url, + "park_url": ride.park_url, + "created_at": ride.created_at.isoformat(), + "updated_at": ride.updated_at.isoformat(), } # Park data if ride.park: - ride_data['park'] = { - 'id': ride.park.id, - 'name': ride.park.name, - 'slug': ride.park.slug, + ride_data["park"] = { + "id": ride.park.id, + "name": ride.park.name, + "slug": ride.park.slug, } # Park location data - if hasattr(ride.park, 'location') and ride.park.location: - ride_data['park']['location'] = { - 'city': ride.park.location.city, - 'state': ride.park.location.state, - 'country': ride.park.location.country, + if hasattr(ride.park, "location") and ride.park.location: + ride_data["park"]["location"] = { + "city": ride.park.location.city, + "state": ride.park.location.state, + "country": ride.park.location.country, } # Park area data if ride.park_area: - ride_data['park_area'] = { - 'id': ride.park_area.id, - 'name': ride.park_area.name, - 'slug': ride.park_area.slug, + ride_data["park_area"] = { + "id": ride.park_area.id, + "name": ride.park_area.name, + "slug": ride.park_area.slug, } # Company data if ride.manufacturer: - ride_data['manufacturer'] = { - 'id': ride.manufacturer.id, - 'name': ride.manufacturer.name, - 'slug': ride.manufacturer.slug, + ride_data["manufacturer"] = { + "id": ride.manufacturer.id, + "name": ride.manufacturer.name, + "slug": ride.manufacturer.slug, } if ride.designer: - ride_data['designer'] = { - 'id': ride.designer.id, - 'name': ride.designer.name, - 'slug': ride.designer.slug, + ride_data["designer"] = { + "id": ride.designer.id, + "name": ride.designer.name, + "slug": ride.designer.slug, } # Ride model data if ride.ride_model: - ride_data['ride_model'] = { - 'id': ride.ride_model.id, - 'name': ride.ride_model.name, - 'slug': ride.ride_model.slug, - 'category': ride.ride_model.category, + ride_data["ride_model"] = { + "id": ride.ride_model.id, + "name": ride.ride_model.name, + "slug": ride.ride_model.slug, + "category": ride.ride_model.category, } if ride.ride_model.manufacturer: - ride_data['ride_model']['manufacturer'] = { - 'id': ride.ride_model.manufacturer.id, - 'name': ride.ride_model.manufacturer.name, - 'slug': ride.ride_model.manufacturer.slug, + ride_data["ride_model"]["manufacturer"] = { + "id": ride.ride_model.manufacturer.id, + "name": ride.ride_model.manufacturer.name, + "slug": ride.ride_model.manufacturer.slug, } # Roller coaster stats - if hasattr(ride, 'coaster_stats') and ride.coaster_stats: + if hasattr(ride, "coaster_stats") and ride.coaster_stats: stats = ride.coaster_stats - ride_data['coaster_stats'] = { - 'height_ft': float(stats.height_ft) if stats.height_ft else None, - 'length_ft': float(stats.length_ft) if stats.length_ft else None, - 'speed_mph': float(stats.speed_mph) if stats.speed_mph else None, - 'inversions': stats.inversions, - 'ride_time_seconds': stats.ride_time_seconds, - 'track_type': stats.track_type, - 'track_material': stats.track_material, - 'roller_coaster_type': stats.roller_coaster_type, - 'max_drop_height_ft': float(stats.max_drop_height_ft) if stats.max_drop_height_ft else None, - 'propulsion_system': stats.propulsion_system, - 'train_style': stats.train_style, - 'trains_count': stats.trains_count, - 'cars_per_train': stats.cars_per_train, - 'seats_per_car': stats.seats_per_car, + ride_data["coaster_stats"] = { + "height_ft": float(stats.height_ft) if stats.height_ft else None, + "length_ft": float(stats.length_ft) if stats.length_ft else None, + "speed_mph": float(stats.speed_mph) if stats.speed_mph else None, + "inversions": stats.inversions, + "ride_time_seconds": stats.ride_time_seconds, + "track_type": stats.track_type, + "track_material": stats.track_material, + "roller_coaster_type": stats.roller_coaster_type, + "max_drop_height_ft": float(stats.max_drop_height_ft) if stats.max_drop_height_ft else None, + "propulsion_system": stats.propulsion_system, + "train_style": stats.train_style, + "trains_count": stats.trains_count, + "cars_per_train": stats.cars_per_train, + "seats_per_car": stats.seats_per_car, } serialized.append(ride_data) @@ -448,267 +442,250 @@ class SmartRideLoader: from apps.rides.models.rides import RollerCoasterStats # Get unique values from database with counts - parks_data = list(Ride.objects.exclude( - park__isnull=True - ).select_related('park').values( - 'park__id', 'park__name', 'park__slug' - ).annotate(count=models.Count('id')).distinct().order_by('park__name')) + parks_data = list( + Ride.objects.exclude(park__isnull=True) + .select_related("park") + .values("park__id", "park__name", "park__slug") + .annotate(count=models.Count("id")) + .distinct() + .order_by("park__name") + ) - park_areas_data = list(Ride.objects.exclude( - park_area__isnull=True - ).select_related('park_area').values( - 'park_area__id', 'park_area__name', 'park_area__slug' - ).annotate(count=models.Count('id')).distinct().order_by('park_area__name')) + park_areas_data = list( + Ride.objects.exclude(park_area__isnull=True) + .select_related("park_area") + .values("park_area__id", "park_area__name", "park_area__slug") + .annotate(count=models.Count("id")) + .distinct() + .order_by("park_area__name") + ) - manufacturers_data = list(Company.objects.filter( - roles__contains=['MANUFACTURER'] - ).values('id', 'name', 'slug').annotate( - count=models.Count('manufactured_rides') - ).order_by('name')) + manufacturers_data = list( + Company.objects.filter(roles__contains=["MANUFACTURER"]) + .values("id", "name", "slug") + .annotate(count=models.Count("manufactured_rides")) + .order_by("name") + ) - designers_data = list(Company.objects.filter( - roles__contains=['DESIGNER'] - ).values('id', 'name', 'slug').annotate( - count=models.Count('designed_rides') - ).order_by('name')) + designers_data = list( + Company.objects.filter(roles__contains=["DESIGNER"]) + .values("id", "name", "slug") + .annotate(count=models.Count("designed_rides")) + .order_by("name") + ) - ride_models_data = list(RideModel.objects.select_related( - 'manufacturer' - ).values( - 'id', 'name', 'slug', 'manufacturer__name', 'manufacturer__slug', 'category' - ).annotate(count=models.Count('rides')).order_by('manufacturer__name', 'name')) + ride_models_data = list( + RideModel.objects.select_related("manufacturer") + .values("id", "name", "slug", "manufacturer__name", "manufacturer__slug", "category") + .annotate(count=models.Count("rides")) + .order_by("manufacturer__name", "name") + ) # Get categories and statuses with counts - categories_data = list(Ride.objects.values('category').annotate( - count=models.Count('id') - ).order_by('category')) + categories_data = list(Ride.objects.values("category").annotate(count=models.Count("id")).order_by("category")) - statuses_data = list(Ride.objects.values('status').annotate( - count=models.Count('id') - ).order_by('status')) + statuses_data = list(Ride.objects.values("status").annotate(count=models.Count("id")).order_by("status")) # Get roller coaster specific data with counts - rc_types_data = list(RollerCoasterStats.objects.values('roller_coaster_type').annotate( - count=models.Count('ride') - ).exclude(roller_coaster_type__isnull=True).order_by('roller_coaster_type')) + rc_types_data = list( + RollerCoasterStats.objects.values("roller_coaster_type") + .annotate(count=models.Count("ride")) + .exclude(roller_coaster_type__isnull=True) + .order_by("roller_coaster_type") + ) - track_materials_data = list(RollerCoasterStats.objects.values('track_material').annotate( - count=models.Count('ride') - ).exclude(track_material__isnull=True).order_by('track_material')) + track_materials_data = list( + RollerCoasterStats.objects.values("track_material") + .annotate(count=models.Count("ride")) + .exclude(track_material__isnull=True) + .order_by("track_material") + ) - propulsion_systems_data = list(RollerCoasterStats.objects.values('propulsion_system').annotate( - count=models.Count('ride') - ).exclude(propulsion_system__isnull=True).order_by('propulsion_system')) + propulsion_systems_data = list( + RollerCoasterStats.objects.values("propulsion_system") + .annotate(count=models.Count("ride")) + .exclude(propulsion_system__isnull=True) + .order_by("propulsion_system") + ) # Convert to frontend-expected format with value/label/count categories = [ - { - 'value': item['category'], - 'label': self._get_category_label(item['category']), - 'count': item['count'] - } + {"value": item["category"], "label": self._get_category_label(item["category"]), "count": item["count"]} for item in categories_data ] statuses = [ - { - 'value': item['status'], - 'label': self._get_status_label(item['status']), - 'count': item['count'] - } + {"value": item["status"], "label": self._get_status_label(item["status"]), "count": item["count"]} for item in statuses_data ] roller_coaster_types = [ { - 'value': item['roller_coaster_type'], - 'label': self._get_rc_type_label(item['roller_coaster_type']), - 'count': item['count'] + "value": item["roller_coaster_type"], + "label": self._get_rc_type_label(item["roller_coaster_type"]), + "count": item["count"], } for item in rc_types_data ] track_materials = [ { - 'value': item['track_material'], - 'label': self._get_track_material_label(item['track_material']), - 'count': item['count'] + "value": item["track_material"], + "label": self._get_track_material_label(item["track_material"]), + "count": item["count"], } for item in track_materials_data ] propulsion_systems = [ { - 'value': item['propulsion_system'], - 'label': self._get_propulsion_system_label(item['propulsion_system']), - 'count': item['count'] + "value": item["propulsion_system"], + "label": self._get_propulsion_system_label(item["propulsion_system"]), + "count": item["count"], } for item in propulsion_systems_data ] # Convert other data to expected format parks = [ - { - 'value': str(item['park__id']), - 'label': item['park__name'], - 'count': item['count'] - } - for item in parks_data + {"value": str(item["park__id"]), "label": item["park__name"], "count": item["count"]} for item in parks_data ] park_areas = [ - { - 'value': str(item['park_area__id']), - 'label': item['park_area__name'], - 'count': item['count'] - } + {"value": str(item["park_area__id"]), "label": item["park_area__name"], "count": item["count"]} for item in park_areas_data ] manufacturers = [ - { - 'value': str(item['id']), - 'label': item['name'], - 'count': item['count'] - } - for item in manufacturers_data + {"value": str(item["id"]), "label": item["name"], "count": item["count"]} for item in manufacturers_data ] designers = [ - { - 'value': str(item['id']), - 'label': item['name'], - 'count': item['count'] - } - for item in designers_data + {"value": str(item["id"]), "label": item["name"], "count": item["count"]} for item in designers_data ] ride_models = [ - { - 'value': str(item['id']), - 'label': f"{item['manufacturer__name']} {item['name']}", - 'count': item['count'] - } + {"value": str(item["id"]), "label": f"{item['manufacturer__name']} {item['name']}", "count": item["count"]} for item in ride_models_data ] # Calculate ranges from actual data ride_stats = Ride.objects.aggregate( - min_rating=Min('average_rating'), - max_rating=Max('average_rating'), - min_height_req=Min('min_height_in'), - max_height_req=Max('max_height_in'), - min_capacity=Min('capacity_per_hour'), - max_capacity=Max('capacity_per_hour'), - min_duration=Min('ride_duration_seconds'), - max_duration=Max('ride_duration_seconds'), - min_year=Min('opening_year'), - max_year=Max('opening_year'), + min_rating=Min("average_rating"), + max_rating=Max("average_rating"), + min_height_req=Min("min_height_in"), + max_height_req=Max("max_height_in"), + min_capacity=Min("capacity_per_hour"), + max_capacity=Max("capacity_per_hour"), + min_duration=Min("ride_duration_seconds"), + max_duration=Max("ride_duration_seconds"), + min_year=Min("opening_year"), + max_year=Max("opening_year"), ) # Calculate roller coaster specific ranges coaster_stats = RollerCoasterStats.objects.aggregate( - min_height_ft=Min('height_ft'), - max_height_ft=Max('height_ft'), - min_length_ft=Min('length_ft'), - max_length_ft=Max('length_ft'), - min_speed_mph=Min('speed_mph'), - max_speed_mph=Max('speed_mph'), - min_inversions=Min('inversions'), - max_inversions=Max('inversions'), - min_ride_time=Min('ride_time_seconds'), - max_ride_time=Max('ride_time_seconds'), - min_drop_height=Min('max_drop_height_ft'), - max_drop_height=Max('max_drop_height_ft'), - min_trains=Min('trains_count'), - max_trains=Max('trains_count'), - min_cars=Min('cars_per_train'), - max_cars=Max('cars_per_train'), - min_seats=Min('seats_per_car'), - max_seats=Max('seats_per_car'), + min_height_ft=Min("height_ft"), + max_height_ft=Max("height_ft"), + min_length_ft=Min("length_ft"), + max_length_ft=Max("length_ft"), + min_speed_mph=Min("speed_mph"), + max_speed_mph=Max("speed_mph"), + min_inversions=Min("inversions"), + max_inversions=Max("inversions"), + min_ride_time=Min("ride_time_seconds"), + max_ride_time=Max("ride_time_seconds"), + min_drop_height=Min("max_drop_height_ft"), + max_drop_height=Max("max_drop_height_ft"), + min_trains=Min("trains_count"), + max_trains=Max("trains_count"), + min_cars=Min("cars_per_train"), + max_cars=Max("cars_per_train"), + min_seats=Min("seats_per_car"), + max_seats=Max("seats_per_car"), ) return { - 'categorical': { - 'categories': categories, - 'statuses': statuses, - 'roller_coaster_types': roller_coaster_types, - 'track_materials': track_materials, - 'propulsion_systems': propulsion_systems, - 'parks': parks, - 'park_areas': park_areas, - 'manufacturers': manufacturers, - 'designers': designers, - 'ride_models': ride_models, + "categorical": { + "categories": categories, + "statuses": statuses, + "roller_coaster_types": roller_coaster_types, + "track_materials": track_materials, + "propulsion_systems": propulsion_systems, + "parks": parks, + "park_areas": park_areas, + "manufacturers": manufacturers, + "designers": designers, + "ride_models": ride_models, }, - 'ranges': { - 'rating': { - 'min': float(ride_stats['min_rating'] or 1), - 'max': float(ride_stats['max_rating'] or 10), - 'step': 0.1, - 'unit': 'stars' + "ranges": { + "rating": { + "min": float(ride_stats["min_rating"] or 1), + "max": float(ride_stats["max_rating"] or 10), + "step": 0.1, + "unit": "stars", }, - 'height_requirement': { - 'min': ride_stats['min_height_req'] or 30, - 'max': ride_stats['max_height_req'] or 90, - 'step': 1, - 'unit': 'inches' + "height_requirement": { + "min": ride_stats["min_height_req"] or 30, + "max": ride_stats["max_height_req"] or 90, + "step": 1, + "unit": "inches", }, - 'capacity': { - 'min': ride_stats['min_capacity'] or 0, - 'max': ride_stats['max_capacity'] or 5000, - 'step': 50, - 'unit': 'riders/hour' + "capacity": { + "min": ride_stats["min_capacity"] or 0, + "max": ride_stats["max_capacity"] or 5000, + "step": 50, + "unit": "riders/hour", }, - 'ride_duration': { - 'min': ride_stats['min_duration'] or 0, - 'max': ride_stats['max_duration'] or 600, - 'step': 10, - 'unit': 'seconds' + "ride_duration": { + "min": ride_stats["min_duration"] or 0, + "max": ride_stats["max_duration"] or 600, + "step": 10, + "unit": "seconds", }, - 'opening_year': { - 'min': ride_stats['min_year'] or 1800, - 'max': ride_stats['max_year'] or 2030, - 'step': 1, - 'unit': 'year' + "opening_year": { + "min": ride_stats["min_year"] or 1800, + "max": ride_stats["max_year"] or 2030, + "step": 1, + "unit": "year", }, - 'height_ft': { - 'min': float(coaster_stats['min_height_ft'] or 0), - 'max': float(coaster_stats['max_height_ft'] or 500), - 'step': 5, - 'unit': 'feet' + "height_ft": { + "min": float(coaster_stats["min_height_ft"] or 0), + "max": float(coaster_stats["max_height_ft"] or 500), + "step": 5, + "unit": "feet", }, - 'length_ft': { - 'min': float(coaster_stats['min_length_ft'] or 0), - 'max': float(coaster_stats['max_length_ft'] or 10000), - 'step': 100, - 'unit': 'feet' + "length_ft": { + "min": float(coaster_stats["min_length_ft"] or 0), + "max": float(coaster_stats["max_length_ft"] or 10000), + "step": 100, + "unit": "feet", }, - 'speed_mph': { - 'min': float(coaster_stats['min_speed_mph'] or 0), - 'max': float(coaster_stats['max_speed_mph'] or 150), - 'step': 5, - 'unit': 'mph' + "speed_mph": { + "min": float(coaster_stats["min_speed_mph"] or 0), + "max": float(coaster_stats["max_speed_mph"] or 150), + "step": 5, + "unit": "mph", }, - 'inversions': { - 'min': coaster_stats['min_inversions'] or 0, - 'max': coaster_stats['max_inversions'] or 20, - 'step': 1, - 'unit': 'inversions' + "inversions": { + "min": coaster_stats["min_inversions"] or 0, + "max": coaster_stats["max_inversions"] or 20, + "step": 1, + "unit": "inversions", }, }, - 'total_count': Ride.objects.count(), + "total_count": Ride.objects.count(), } def _get_category_label(self, category: str) -> str: """Convert category code to human-readable label.""" category_labels = { - 'RC': 'Roller Coaster', - 'DR': 'Dark Ride', - 'FR': 'Flat Ride', - 'WR': 'Water Ride', - 'TR': 'Transport Ride', - 'OT': 'Other', + "RC": "Roller Coaster", + "DR": "Dark Ride", + "FR": "Flat Ride", + "WR": "Water Ride", + "TR": "Transport Ride", + "OT": "Other", } if category in category_labels: return category_labels[category] @@ -718,14 +695,14 @@ class SmartRideLoader: def _get_status_label(self, status: str) -> str: """Convert status code to human-readable label.""" status_labels = { - 'OPERATING': 'Operating', - 'CLOSED_TEMP': 'Temporarily Closed', - 'SBNO': 'Standing But Not Operating', - 'CLOSING': 'Closing Soon', - 'CLOSED_PERM': 'Permanently Closed', - 'UNDER_CONSTRUCTION': 'Under Construction', - 'DEMOLISHED': 'Demolished', - 'RELOCATED': 'Relocated', + "OPERATING": "Operating", + "CLOSED_TEMP": "Temporarily Closed", + "SBNO": "Standing But Not Operating", + "CLOSING": "Closing Soon", + "CLOSED_PERM": "Permanently Closed", + "UNDER_CONSTRUCTION": "Under Construction", + "DEMOLISHED": "Demolished", + "RELOCATED": "Relocated", } if status in status_labels: return status_labels[status] @@ -735,19 +712,19 @@ class SmartRideLoader: def _get_rc_type_label(self, rc_type: str) -> str: """Convert roller coaster type to human-readable label.""" rc_type_labels = { - 'SITDOWN': 'Sit Down', - 'INVERTED': 'Inverted', - 'SUSPENDED': 'Suspended', - 'FLOORLESS': 'Floorless', - 'FLYING': 'Flying', - 'WING': 'Wing', - 'DIVE': 'Dive', - 'SPINNING': 'Spinning', - 'WILD_MOUSE': 'Wild Mouse', - 'BOBSLED': 'Bobsled', - 'PIPELINE': 'Pipeline', - 'FOURTH_DIMENSION': '4th Dimension', - 'FAMILY': 'Family', + "SITDOWN": "Sit Down", + "INVERTED": "Inverted", + "SUSPENDED": "Suspended", + "FLOORLESS": "Floorless", + "FLYING": "Flying", + "WING": "Wing", + "DIVE": "Dive", + "SPINNING": "Spinning", + "WILD_MOUSE": "Wild Mouse", + "BOBSLED": "Bobsled", + "PIPELINE": "Pipeline", + "FOURTH_DIMENSION": "4th Dimension", + "FAMILY": "Family", } if rc_type in rc_type_labels: return rc_type_labels[rc_type] @@ -757,9 +734,9 @@ class SmartRideLoader: def _get_track_material_label(self, material: str) -> str: """Convert track material to human-readable label.""" material_labels = { - 'STEEL': 'Steel', - 'WOOD': 'Wood', - 'HYBRID': 'Hybrid (Steel/Wood)', + "STEEL": "Steel", + "WOOD": "Wood", + "HYBRID": "Hybrid (Steel/Wood)", } if material in material_labels: return material_labels[material] @@ -769,15 +746,15 @@ class SmartRideLoader: def _get_propulsion_system_label(self, propulsion_system: str) -> str: """Convert propulsion system to human-readable label.""" propulsion_labels = { - 'CHAIN': 'Chain Lift', - 'LSM': 'Linear Synchronous Motor', - 'LIM': 'Linear Induction Motor', - 'HYDRAULIC': 'Hydraulic Launch', - 'PNEUMATIC': 'Pneumatic Launch', - 'CABLE': 'Cable Lift', - 'FLYWHEEL': 'Flywheel Launch', - 'GRAVITY': 'Gravity', - 'NONE': 'No Propulsion System', + "CHAIN": "Chain Lift", + "LSM": "Linear Synchronous Motor", + "LIM": "Linear Induction Motor", + "HYDRAULIC": "Hydraulic Launch", + "PNEUMATIC": "Pneumatic Launch", + "CABLE": "Cable Lift", + "FLYWHEEL": "Flywheel Launch", + "GRAVITY": "Gravity", + "NONE": "No Propulsion System", } if propulsion_system in propulsion_labels: return propulsion_labels[propulsion_system] diff --git a/backend/apps/rides/services/location_service.py b/backend/apps/rides/services/location_service.py index 745cd0bb..a9688b5a 100644 --- a/backend/apps/rides/services/location_service.py +++ b/backend/apps/rides/services/location_service.py @@ -69,9 +69,7 @@ class RideLocationService: return ride_location @classmethod - def update_ride_location( - cls, ride_location: RideLocation, **updates - ) -> RideLocation: + def update_ride_location(cls, ride_location: RideLocation, **updates) -> RideLocation: """ Update ride location with validation. @@ -149,9 +147,7 @@ class RideLocationService: if park: queryset = queryset.filter(ride__park=park) - return list( - queryset.select_related("ride", "ride__park").order_by("point__distance") - ) + return list(queryset.select_related("ride", "ride__park").order_by("point__distance")) @classmethod def get_ride_navigation_info(cls, ride_location: RideLocation) -> dict[str, Any]: @@ -249,9 +245,7 @@ class RideLocationService: # Rough conversion: 1 degree latitude ≈ 111,000 meters # 1 degree longitude varies by latitude, but we'll use a rough approximation lat_offset = offset[0] / 111000 # North offset in degrees - lon_offset = offset[1] / ( - 111000 * abs(park_location.latitude) * 0.01 - ) # East offset + lon_offset = offset[1] / (111000 * abs(park_location.latitude) * 0.01) # East offset estimated_lat = park_location.latitude + lat_offset estimated_lon = park_location.longitude + lon_offset @@ -277,9 +271,7 @@ class RideLocationService: return updated_count # Get all rides in the park that don't have precise coordinates - ride_locations = RideLocation.objects.filter( - ride__park=park, point__isnull=True - ).select_related("ride") + ride_locations = RideLocation.objects.filter(ride__park=park, point__isnull=True).select_related("ride") for ride_location in ride_locations: # Try to search for the specific ride within the park area @@ -312,22 +304,15 @@ class RideLocationService: # Look for results that might be the ride for result in results: display_name = result.get("display_name", "").lower() - if ( - ride_location.ride.name.lower() in display_name - and park.name.lower() in display_name - ): + if ride_location.ride.name.lower() in display_name and park.name.lower() in display_name: # Update the ride location - ride_location.set_coordinates( - float(result["lat"]), float(result["lon"]) - ) + ride_location.set_coordinates(float(result["lat"]), float(result["lon"])) ride_location.save() updated_count += 1 break except Exception as e: - logger.warning( - f"Error updating ride location for {ride_location.ride.name}: {str(e)}" - ) + logger.warning(f"Error updating ride location for {ride_location.ride.name}: {str(e)}") continue return updated_count @@ -346,9 +331,7 @@ class RideLocationService: area_map = {} ride_locations = ( - RideLocation.objects.filter(ride__park=park) - .select_related("ride") - .order_by("park_area", "ride__name") + RideLocation.objects.filter(ride__park=park).select_related("ride").order_by("park_area", "ride__name") ) for ride_location in ride_locations: diff --git a/backend/apps/rides/services/media_service.py b/backend/apps/rides/services/media_service.py index ef61a823..2cebe014 100644 --- a/backend/apps/rides/services/media_service.py +++ b/backend/apps/rides/services/media_service.py @@ -143,11 +143,7 @@ class RideMediaService: Returns: List of RidePhoto instances """ - return list( - ride.photos.filter(photo_type=photo_type, is_approved=True).order_by( - "-created_at" - ) - ) + return list(ride.photos.filter(photo_type=photo_type, is_approved=True).order_by("-created_at")) @staticmethod def set_primary_photo(ride: Ride, photo: RidePhoto) -> bool: @@ -218,9 +214,7 @@ class RideMediaService: photo.image.delete(save=False) photo.delete() - logger.info( - f"Photo {photo_id} deleted from ride {ride_slug} by user {deleted_by.username}" - ) + logger.info(f"Photo {photo_id} deleted from ride {ride_slug} by user {deleted_by.username}") return True except Exception as e: logger.error(f"Failed to delete photo {photo.pk}: {str(e)}") @@ -272,9 +266,7 @@ class RideMediaService: if RideMediaService.approve_photo(photo, approved_by): approved_count += 1 - logger.info( - f"Bulk approved {approved_count} photos by user {approved_by.username}" - ) + logger.info(f"Bulk approved {approved_count} photos by user {approved_by.username}") return approved_count @staticmethod @@ -289,9 +281,7 @@ class RideMediaService: List of construction RidePhoto instances ordered by date taken """ return list( - ride.photos.filter(photo_type="construction", is_approved=True).order_by( - "date_taken", "created_at" - ) + ride.photos.filter(photo_type="construction", is_approved=True).order_by("date_taken", "created_at") ) @staticmethod diff --git a/backend/apps/rides/services/ranking_service.py b/backend/apps/rides/services/ranking_service.py index 9dba45b0..7d277afb 100644 --- a/backend/apps/rides/services/ranking_service.py +++ b/backend/apps/rides/services/ranking_service.py @@ -53,9 +53,7 @@ class RideRankingService: Dictionary with statistics about the ranking calculation """ start_time = timezone.now() - self.logger.info( - f"Starting ranking calculation for category: {category or 'ALL'}" - ) + self.logger.info(f"Starting ranking calculation for category: {category or 'ALL'}") try: with transaction.atomic(): @@ -87,9 +85,7 @@ class RideRankingService: self._cleanup_old_data() duration = (timezone.now() - start_time).total_seconds() - self.logger.info( - f"Ranking calculation completed in {duration:.2f} seconds" - ) + self.logger.info(f"Ranking calculation completed in {duration:.2f} seconds") return { "status": "success", @@ -113,9 +109,7 @@ class RideRankingService: """ queryset = ( Ride.objects.filter(status="OPERATING", reviews__is_published=True) - .annotate( - review_count=Count("reviews", filter=Q(reviews__is_published=True)) - ) + .annotate(review_count=Count("reviews", filter=Q(reviews__is_published=True))) .filter(review_count__gt=0) ) @@ -124,9 +118,7 @@ class RideRankingService: return list(queryset.distinct()) - def _calculate_all_comparisons( - self, rides: list[Ride] - ) -> dict[tuple[int, int], RidePairComparison]: + def _calculate_all_comparisons(self, rides: list[Ride]) -> dict[tuple[int, int], RidePairComparison]: """ Calculate pairwise comparisons for all ride pairs. @@ -146,15 +138,11 @@ class RideRankingService: processed += 1 if processed % 100 == 0: - self.logger.debug( - f"Processed {processed}/{total_pairs} comparisons" - ) + self.logger.debug(f"Processed {processed}/{total_pairs} comparisons") return comparisons - def _calculate_pairwise_comparison( - self, ride_a: Ride, ride_b: Ride - ) -> RidePairComparison | None: + def _calculate_pairwise_comparison(self, ride_a: Ride, ride_b: Ride) -> RidePairComparison | None: """ Calculate the pairwise comparison between two rides. @@ -163,15 +151,11 @@ class RideRankingService: """ # Get mutual riders (users who have rated both rides) ride_a_reviewers = set( - RideReview.objects.filter(ride=ride_a, is_published=True).values_list( - "user_id", flat=True - ) + RideReview.objects.filter(ride=ride_a, is_published=True).values_list("user_id", flat=True) ) ride_b_reviewers = set( - RideReview.objects.filter(ride=ride_b, is_published=True).values_list( - "user_id", flat=True - ) + RideReview.objects.filter(ride=ride_b, is_published=True).values_list("user_id", flat=True) ) mutual_riders = ride_a_reviewers & ride_b_reviewers @@ -183,16 +167,12 @@ class RideRankingService: # Get ratings from mutual riders ride_a_ratings = { review.user_id: review.rating - for review in RideReview.objects.filter( - ride=ride_a, user_id__in=mutual_riders, is_published=True - ) + for review in RideReview.objects.filter(ride=ride_a, user_id__in=mutual_riders, is_published=True) } ride_b_ratings = { review.user_id: review.rating - for review in RideReview.objects.filter( - ride=ride_b, user_id__in=mutual_riders, is_published=True - ) + for review in RideReview.objects.filter(ride=ride_b, user_id__in=mutual_riders, is_published=True) } # Count wins and ties @@ -212,12 +192,8 @@ class RideRankingService: ties += 1 # Calculate average ratings from mutual riders - ride_a_avg = ( - sum(ride_a_ratings.values()) / len(ride_a_ratings) if ride_a_ratings else 0 - ) - ride_b_avg = ( - sum(ride_b_ratings.values()) / len(ride_b_ratings) if ride_b_ratings else 0 - ) + ride_a_avg = sum(ride_a_ratings.values()) / len(ride_a_ratings) if ride_a_ratings else 0 + ride_b_avg = sum(ride_b_ratings.values()) / len(ride_b_ratings) if ride_b_ratings else 0 # Create or update comparison record comparison, created = RidePairComparison.objects.update_or_create( @@ -228,16 +204,8 @@ class RideRankingService: "ride_b_wins": ride_b_wins if ride_a.id < ride_b.id else ride_a_wins, "ties": ties, "mutual_riders_count": len(mutual_riders), - "ride_a_avg_rating": ( - Decimal(str(ride_a_avg)) - if ride_a.id < ride_b.id - else Decimal(str(ride_b_avg)) - ), - "ride_b_avg_rating": ( - Decimal(str(ride_b_avg)) - if ride_a.id < ride_b.id - else Decimal(str(ride_a_avg)) - ), + "ride_a_avg_rating": (Decimal(str(ride_a_avg)) if ride_a.id < ride_b.id else Decimal(str(ride_b_avg))), + "ride_b_avg_rating": (Decimal(str(ride_b_avg)) if ride_a.id < ride_b.id else Decimal(str(ride_a_avg))), }, ) @@ -294,16 +262,12 @@ class RideRankingService: # Calculate winning percentage (ties count as 0.5) total_comparisons = wins + losses + ties if total_comparisons > 0: - winning_percentage = Decimal( - str((wins + 0.5 * ties) / total_comparisons) - ) + winning_percentage = Decimal(str((wins + 0.5 * ties) / total_comparisons)) else: winning_percentage = Decimal("0.5") # Get average rating and reviewer count - ride_stats = RideReview.objects.filter( - ride=ride, is_published=True - ).aggregate( + ride_stats = RideReview.objects.filter(ride=ride, is_published=True).aggregate( avg_rating=Avg("rating"), reviewer_count=Count("user", distinct=True) ) @@ -356,11 +320,7 @@ class RideRankingService: tied_group = [rankings[i]] j = i + 1 - while ( - j < len(rankings) - and rankings[j]["winning_percentage"] - == rankings[i]["winning_percentage"] - ): + while j < len(rankings) and rankings[j]["winning_percentage"] == rankings[i]["winning_percentage"]: tied_group.append(rankings[j]) j += 1 @@ -462,9 +422,7 @@ class RideRankingService: cutoff_date = timezone.now() - timezone.timedelta(days=days_to_keep) # Delete old snapshots - deleted_snapshots = RankingSnapshot.objects.filter( - snapshot_date__lt=cutoff_date.date() - ).delete() + deleted_snapshots = RankingSnapshot.objects.filter(snapshot_date__lt=cutoff_date.date()).delete() if deleted_snapshots[0] > 0: self.logger.info(f"Deleted {deleted_snapshots[0]} old ranking snapshots") @@ -486,9 +444,7 @@ class RideRankingService: ) # Get ranking history - history = RankingSnapshot.objects.filter(ride=ride).order_by( - "-snapshot_date" - )[:30] + history = RankingSnapshot.objects.filter(ride=ride).order_by("-snapshot_date")[:30] return { "current_rank": ranking.rank, @@ -501,32 +457,18 @@ class RideRankingService: "last_calculated": ranking.last_calculated, "head_to_head": [ { - "opponent": ( - comp.ride_b if comp.ride_a_id == ride.id else comp.ride_a - ), + "opponent": (comp.ride_b if comp.ride_a_id == ride.id else comp.ride_a), "result": ( "win" if ( - ( - comp.ride_a_id == ride.id - and comp.ride_a_wins > comp.ride_b_wins - ) - or ( - comp.ride_b_id == ride.id - and comp.ride_b_wins > comp.ride_a_wins - ) + (comp.ride_a_id == ride.id and comp.ride_a_wins > comp.ride_b_wins) + or (comp.ride_b_id == ride.id and comp.ride_b_wins > comp.ride_a_wins) ) else ( "loss" if ( - ( - comp.ride_a_id == ride.id - and comp.ride_a_wins < comp.ride_b_wins - ) - or ( - comp.ride_b_id == ride.id - and comp.ride_b_wins < comp.ride_a_wins - ) + (comp.ride_a_id == ride.id and comp.ride_a_wins < comp.ride_b_wins) + or (comp.ride_b_id == ride.id and comp.ride_b_wins < comp.ride_a_wins) ) else "tie" ) diff --git a/backend/apps/rides/services/search.py b/backend/apps/rides/services/search.py index b8076881..c183b366 100644 --- a/backend/apps/rides/services/search.py +++ b/backend/apps/rides/services/search.py @@ -127,9 +127,7 @@ class RideSearchService: # Apply text search with ranking if filters.get("global_search"): - queryset, search_rank = self._apply_full_text_search( - queryset, filters["global_search"] - ) + queryset, search_rank = self._apply_full_text_search(queryset, filters["global_search"]) search_metadata["search_applied"] = True search_metadata["search_term"] = filters["global_search"] else: @@ -176,9 +174,7 @@ class RideSearchService: "applied_filters": self._get_applied_filters_summary(filters), } - def _apply_full_text_search( - self, queryset, search_term: str - ) -> tuple[models.QuerySet, models.Expression]: + def _apply_full_text_search(self, queryset, search_term: str) -> tuple[models.QuerySet, models.Expression]: """ Apply PostgreSQL full-text search with ranking and fuzzy matching. """ @@ -201,17 +197,14 @@ class RideSearchService: search_query = SearchQuery(search_term, config="english") # Calculate search rank - search_rank = SearchRank( - search_vector, search_query, weights=self.SEARCH_RANK_WEIGHTS - ) + search_rank = SearchRank(search_vector, search_query, weights=self.SEARCH_RANK_WEIGHTS) # Apply trigram similarity for fuzzy matching on name trigram_similarity = TrigramSimilarity("name", search_term) # Combine full-text search with trigram similarity queryset = queryset.annotate(trigram_similarity=trigram_similarity).filter( - Q(search_vector=search_query) - | Q(trigram_similarity__gte=self.TRIGRAM_SIMILARITY_THRESHOLD) + Q(search_vector=search_query) | Q(trigram_similarity__gte=self.TRIGRAM_SIMILARITY_THRESHOLD) ) # Use the greatest of search rank and trigram similarity for final ranking @@ -219,36 +212,22 @@ class RideSearchService: return queryset, final_rank - def _apply_basic_info_filters( - self, queryset, filters: dict[str, Any] - ) -> models.QuerySet: + def _apply_basic_info_filters(self, queryset, filters: dict[str, Any]) -> models.QuerySet: """Apply basic information filters.""" # Category filter (multi-select) if filters.get("category"): - categories = ( - filters["category"] - if isinstance(filters["category"], list) - else [filters["category"]] - ) + categories = filters["category"] if isinstance(filters["category"], list) else [filters["category"]] queryset = queryset.filter(category__in=categories) # Status filter (multi-select) if filters.get("status"): - statuses = ( - filters["status"] - if isinstance(filters["status"], list) - else [filters["status"]] - ) + statuses = filters["status"] if isinstance(filters["status"], list) else [filters["status"]] queryset = queryset.filter(status__in=statuses) # Park filter (multi-select) if filters.get("park"): - parks = ( - filters["park"] - if isinstance(filters["park"], list) - else [filters["park"]] - ) + parks = filters["park"] if isinstance(filters["park"], list) else [filters["park"]] if isinstance(parks[0], str): # If slugs provided queryset = queryset.filter(park__slug__in=parks) else: # If IDs provided @@ -256,11 +235,7 @@ class RideSearchService: # Park area filter (multi-select) if filters.get("park_area"): - areas = ( - filters["park_area"] - if isinstance(filters["park_area"], list) - else [filters["park_area"]] - ) + areas = filters["park_area"] if isinstance(filters["park_area"], list) else [filters["park_area"]] if isinstance(areas[0], str): # If slugs provided queryset = queryset.filter(park_area__slug__in=areas) else: # If IDs provided @@ -297,9 +272,7 @@ class RideSearchService: return queryset - def _apply_height_safety_filters( - self, queryset, filters: dict[str, Any] - ) -> models.QuerySet: + def _apply_height_safety_filters(self, queryset, filters: dict[str, Any]) -> models.QuerySet: """Apply height and safety requirement filters.""" # Minimum height range @@ -320,9 +293,7 @@ class RideSearchService: return queryset - def _apply_performance_filters( - self, queryset, filters: dict[str, Any] - ) -> models.QuerySet: + def _apply_performance_filters(self, queryset, filters: dict[str, Any]) -> models.QuerySet: """Apply performance metric filters.""" # Capacity range @@ -337,13 +308,9 @@ class RideSearchService: if filters.get("duration_range"): duration_range = filters["duration_range"] if duration_range.get("min") is not None: - queryset = queryset.filter( - ride_duration_seconds__gte=duration_range["min"] - ) + queryset = queryset.filter(ride_duration_seconds__gte=duration_range["min"]) if duration_range.get("max") is not None: - queryset = queryset.filter( - ride_duration_seconds__lte=duration_range["max"] - ) + queryset = queryset.filter(ride_duration_seconds__lte=duration_range["max"]) # Rating range if filters.get("rating_range"): @@ -355,17 +322,13 @@ class RideSearchService: return queryset - def _apply_relationship_filters( - self, queryset, filters: dict[str, Any] - ) -> models.QuerySet: + def _apply_relationship_filters(self, queryset, filters: dict[str, Any]) -> models.QuerySet: """Apply relationship filters (manufacturer, designer, ride model).""" # Manufacturer filter (multi-select) if filters.get("manufacturer"): manufacturers = ( - filters["manufacturer"] - if isinstance(filters["manufacturer"], list) - else [filters["manufacturer"]] + filters["manufacturer"] if isinstance(filters["manufacturer"], list) else [filters["manufacturer"]] ) if isinstance(manufacturers[0], str): # If slugs provided queryset = queryset.filter(manufacturer__slug__in=manufacturers) @@ -374,11 +337,7 @@ class RideSearchService: # Designer filter (multi-select) if filters.get("designer"): - designers = ( - filters["designer"] - if isinstance(filters["designer"], list) - else [filters["designer"]] - ) + designers = filters["designer"] if isinstance(filters["designer"], list) else [filters["designer"]] if isinstance(designers[0], str): # If slugs provided queryset = queryset.filter(designer__slug__in=designers) else: # If IDs provided @@ -386,11 +345,7 @@ class RideSearchService: # Ride model filter (multi-select) if filters.get("ride_model"): - models_list = ( - filters["ride_model"] - if isinstance(filters["ride_model"], list) - else [filters["ride_model"]] - ) + models_list = filters["ride_model"] if isinstance(filters["ride_model"], list) else [filters["ride_model"]] if isinstance(models_list[0], str): # If slugs provided queryset = queryset.filter(ride_model__slug__in=models_list) else: # If IDs provided @@ -398,9 +353,7 @@ class RideSearchService: return queryset - def _apply_roller_coaster_filters( - self, queryset, filters: dict[str, Any] - ) -> models.QuerySet: + def _apply_roller_coaster_filters(self, queryset, filters: dict[str, Any]) -> models.QuerySet: """Apply roller coaster specific filters.""" queryset = self._apply_numeric_range_filter( queryset, filters, "height_ft_range", "rollercoasterstats__height_ft" @@ -426,14 +379,8 @@ class RideSearchService: # Coaster type filter (multi-select) if filters.get("coaster_type"): - types = ( - filters["coaster_type"] - if isinstance(filters["coaster_type"], list) - else [filters["coaster_type"]] - ) - queryset = queryset.filter( - rollercoasterstats__roller_coaster_type__in=types - ) + types = filters["coaster_type"] if isinstance(filters["coaster_type"], list) else [filters["coaster_type"]] + queryset = queryset.filter(rollercoasterstats__roller_coaster_type__in=types) # Propulsion system filter (multi-select) if filters.get("propulsion_system"): @@ -457,18 +404,12 @@ class RideSearchService: if filters.get(filter_key): range_filter = filters[filter_key] if range_filter.get("min") is not None: - queryset = queryset.filter( - **{f"{field_name}__gte": range_filter["min"]} - ) + queryset = queryset.filter(**{f"{field_name}__gte": range_filter["min"]}) if range_filter.get("max") is not None: - queryset = queryset.filter( - **{f"{field_name}__lte": range_filter["max"]} - ) + queryset = queryset.filter(**{f"{field_name}__lte": range_filter["max"]}) return queryset - def _apply_company_filters( - self, queryset, filters: dict[str, Any] - ) -> models.QuerySet: + def _apply_company_filters(self, queryset, filters: dict[str, Any]) -> models.QuerySet: """Apply company-related filters.""" # Manufacturer roles filter @@ -518,13 +459,9 @@ class RideSearchService: return queryset.order_by("-search_rank", "name") # Apply the sorting - return queryset.order_by( - sort_field, "name" - ) # Always add name as secondary sort + return queryset.order_by(sort_field, "name") # Always add name as secondary sort - def _add_search_highlights( - self, results: list[Ride], search_term: str - ) -> list[Ride]: + def _add_search_highlights(self, results: list[Ride], search_term: str) -> list[Ride]: """Add search highlights to results using SearchHeadline.""" if not search_term or not results: @@ -601,9 +538,7 @@ class RideSearchService: else: raise ValueError(f"Unknown filter key: {filter_key}") - def get_search_suggestions( - self, query: str, limit: int = 10 - ) -> list[dict[str, Any]]: + def get_search_suggestions(self, query: str, limit: int = 10) -> list[dict[str, Any]]: """ Get search suggestions for autocomplete functionality. """ @@ -686,17 +621,11 @@ class RideSearchService: # Apply context filters to narrow down options if context_filters: temp_filters = context_filters.copy() - temp_filters.pop( - filter_type, None - ) # Remove the filter we're getting options for + temp_filters.pop(filter_type, None) # Remove the filter we're getting options for base_queryset = self._apply_all_filters(base_queryset, temp_filters) if filter_type == "park": - return list( - base_queryset.values("park__name", "park__slug") - .distinct() - .order_by("park__name") - ) + return list(base_queryset.values("park__name", "park__slug").distinct().order_by("park__name")) elif filter_type == "manufacturer": return list( diff --git a/backend/apps/rides/services/status_service.py b/backend/apps/rides/services/status_service.py index 2a299047..80401c38 100644 --- a/backend/apps/rides/services/status_service.py +++ b/backend/apps/rides/services/status_service.py @@ -3,7 +3,6 @@ Services for ride status transitions and management. Following Django styleguide pattern for business logic encapsulation. """ - from django.contrib.auth.models import AbstractBaseUser from django.db import transaction @@ -34,9 +33,7 @@ class RideStatusService: return ride @staticmethod - def close_ride_temporarily( - *, ride_id: int, user: AbstractBaseUser | None = None - ) -> Ride: + def close_ride_temporarily(*, ride_id: int, user: AbstractBaseUser | None = None) -> Ride: """ Temporarily close a ride. @@ -56,9 +53,7 @@ class RideStatusService: return ride @staticmethod - def mark_ride_sbno( - *, ride_id: int, user: AbstractBaseUser | None = None - ) -> Ride: + def mark_ride_sbno(*, ride_id: int, user: AbstractBaseUser | None = None) -> Ride: """ Mark a ride as SBNO (Standing But Not Operating). @@ -111,9 +106,7 @@ class RideStatusService: return ride @staticmethod - def close_ride_permanently( - *, ride_id: int, user: AbstractBaseUser | None = None - ) -> Ride: + def close_ride_permanently(*, ride_id: int, user: AbstractBaseUser | None = None) -> Ride: """ Permanently close a ride. diff --git a/backend/apps/rides/services_core.py b/backend/apps/rides/services_core.py index fb3e7ce1..8cbcbb64 100644 --- a/backend/apps/rides/services_core.py +++ b/backend/apps/rides/services_core.py @@ -139,9 +139,7 @@ class RideService: return ride @staticmethod - def close_ride_temporarily( - *, ride_id: int, user: UserType | None = None - ) -> Ride: + def close_ride_temporarily(*, ride_id: int, user: UserType | None = None) -> Ride: """ Temporarily close a ride. @@ -161,9 +159,7 @@ class RideService: return ride @staticmethod - def mark_ride_sbno( - *, ride_id: int, user: UserType | None = None - ) -> Ride: + def mark_ride_sbno(*, ride_id: int, user: UserType | None = None) -> Ride: """ Mark a ride as SBNO (Standing But Not Operating). @@ -216,9 +212,7 @@ class RideService: return ride @staticmethod - def close_ride_permanently( - *, ride_id: int, user: UserType | None = None - ) -> Ride: + def close_ride_permanently(*, ride_id: int, user: UserType | None = None) -> Ride: """ Permanently close a ride. @@ -258,9 +252,7 @@ class RideService: return ride @staticmethod - def relocate_ride( - *, ride_id: int, new_park_id: int, user: UserType | None = None - ) -> Ride: + def relocate_ride(*, ride_id: int, new_park_id: int, user: UserType | None = None) -> Ride: """ Relocate a ride to a new park. @@ -336,12 +328,7 @@ class RideService: """ from apps.moderation.services import ModerationService - result = { - 'manufacturers': [], - 'designers': [], - 'ride_models': [], - 'total_submissions': 0 - } + result = {"manufacturers": [], "designers": [], "ride_models": [], "total_submissions": 0} # Check for new manufacturer manufacturer_name = form_data.get("manufacturer_search") @@ -354,8 +341,8 @@ class RideService: reason=f"New manufacturer suggested: {manufacturer_name}", ) if submission: - result['manufacturers'].append(submission.id) - result['total_submissions'] += 1 + result["manufacturers"].append(submission.id) + result["total_submissions"] += 1 # Check for new designer designer_name = form_data.get("designer_search") @@ -368,8 +355,8 @@ class RideService: reason=f"New designer suggested: {designer_name}", ) if submission: - result['designers'].append(submission.id) - result['total_submissions'] += 1 + result["designers"].append(submission.id) + result["total_submissions"] += 1 # Check for new ride model ride_model_name = form_data.get("ride_model_search") @@ -386,7 +373,7 @@ class RideService: reason=f"New ride model suggested: {ride_model_name}", ) if submission: - result['ride_models'].append(submission.id) - result['total_submissions'] += 1 + result["ride_models"].append(submission.id) + result["total_submissions"] += 1 return result diff --git a/backend/apps/rides/signals.py b/backend/apps/rides/signals.py index 17c58d2c..9f1fa594 100644 --- a/backend/apps/rides/signals.py +++ b/backend/apps/rides/signals.py @@ -13,6 +13,7 @@ logger = logging.getLogger(__name__) # Computed Field Maintenance # ============================================================================= + def update_ride_search_text(ride): """ Update ride's search_text computed field. @@ -25,7 +26,7 @@ def update_ride_search_text(ride): try: ride._populate_computed_fields() - ride.save(update_fields=['search_text']) + ride.save(update_fields=["search_text"]) logger.debug(f"Updated search_text for ride {ride.pk}") except Exception as e: logger.exception(f"Failed to update search_text for ride {ride.pk}: {e}") @@ -47,22 +48,16 @@ def handle_ride_status(sender, instance, **kwargs): if today >= instance.closing_date and instance.status == "CLOSING": target_status = instance.post_closing_status or "SBNO" - logger.info( - f"Ride {instance.pk} closing date reached, " - f"transitioning to {target_status}" - ) + logger.info(f"Ride {instance.pk} closing date reached, " f"transitioning to {target_status}") # Try to use FSM transition method if available - transition_method_name = f'transition_to_{target_status.lower()}' + transition_method_name = f"transition_to_{target_status.lower()}" if hasattr(instance, transition_method_name): # Check if transition is allowed before attempting - if hasattr(instance, 'can_proceed'): - can_proceed = getattr(instance, f'can_transition_to_{target_status.lower()}', None) + if hasattr(instance, "can_proceed"): + can_proceed = getattr(instance, f"can_transition_to_{target_status.lower()}", None) if can_proceed and callable(can_proceed) and not can_proceed(): - logger.warning( - f"FSM transition to {target_status} not allowed " - f"for ride {instance.pk}" - ) + logger.warning(f"FSM transition to {target_status} not allowed " f"for ride {instance.pk}") # Fall back to direct status change instance.status = target_status instance.status_since = instance.closing_date @@ -72,13 +67,9 @@ def handle_ride_status(sender, instance, **kwargs): method = getattr(instance, transition_method_name) method() instance.status_since = instance.closing_date - logger.info( - f"Applied FSM transition to {target_status} for ride {instance.pk}" - ) + logger.info(f"Applied FSM transition to {target_status} for ride {instance.pk}") except Exception as e: - logger.exception( - f"Failed to apply FSM transition for ride {instance.pk}: {e}" - ) + logger.exception(f"Failed to apply FSM transition for ride {instance.pk}: {e}") # Fall back to direct status change instance.status = target_status instance.status_since = instance.closing_date @@ -101,23 +92,20 @@ def validate_closing_status(sender, instance, **kwargs): if instance.status == "CLOSING": # Ensure post_closing_status is set if not instance.post_closing_status: - logger.warning( - f"Ride {instance.pk} entering CLOSING without post_closing_status set" - ) + logger.warning(f"Ride {instance.pk} entering CLOSING without post_closing_status set") # Default to SBNO if not set instance.post_closing_status = "SBNO" # Ensure closing_date is set if not instance.closing_date: - logger.warning( - f"Ride {instance.pk} entering CLOSING without closing_date set" - ) + logger.warning(f"Ride {instance.pk} entering CLOSING without closing_date set") # Default to today's date instance.closing_date = timezone.now().date() # FSM transition signal handlers + def handle_ride_transition_to_closing(instance, source, target, user, **kwargs): """ Validate transition to CLOSING status. @@ -134,20 +122,15 @@ def handle_ride_transition_to_closing(instance, source, target, user, **kwargs): Returns: True if transition should proceed, False to abort. """ - if target != 'CLOSING': + if target != "CLOSING": return True if not instance.post_closing_status: - logger.error( - f"Cannot transition ride {instance.pk} to CLOSING: " - "post_closing_status not set" - ) + logger.error(f"Cannot transition ride {instance.pk} to CLOSING: " "post_closing_status not set") return False if not instance.closing_date: - logger.warning( - f"Ride {instance.pk} transitioning to CLOSING without closing_date" - ) + logger.warning(f"Ride {instance.pk} transitioning to CLOSING without closing_date") return True @@ -166,45 +149,35 @@ def apply_post_closing_status(instance, user=None): Returns: True if status was applied, False otherwise. """ - if instance.status != 'CLOSING': - logger.debug( - f"Ride {instance.pk} not in CLOSING state, skipping" - ) + if instance.status != "CLOSING": + logger.debug(f"Ride {instance.pk} not in CLOSING state, skipping") return False target_status = instance.post_closing_status if not target_status: - logger.warning( - f"Ride {instance.pk} in CLOSING but no post_closing_status set" - ) + logger.warning(f"Ride {instance.pk} in CLOSING but no post_closing_status set") return False # Try to use FSM transition - transition_method_name = f'transition_to_{target_status.lower()}' + transition_method_name = f"transition_to_{target_status.lower()}" if hasattr(instance, transition_method_name): try: method = getattr(instance, transition_method_name) method(user=user) instance.post_closing_status = None - instance.save(update_fields=['post_closing_status']) - logger.info( - f"Applied post_closing_status {target_status} to ride {instance.pk}" - ) + instance.save(update_fields=["post_closing_status"]) + logger.info(f"Applied post_closing_status {target_status} to ride {instance.pk}") return True except Exception as e: - logger.exception( - f"Failed to apply post_closing_status for ride {instance.pk}: {e}" - ) + logger.exception(f"Failed to apply post_closing_status for ride {instance.pk}: {e}") return False else: # Direct status change instance.status = target_status instance.post_closing_status = None instance.status_since = timezone.now().date() - instance.save(update_fields=['status', 'post_closing_status', 'status_since']) - logger.info( - f"Applied post_closing_status {target_status} to ride {instance.pk} (direct)" - ) + instance.save(update_fields=["status", "post_closing_status", "status_since"]) + logger.info(f"Applied post_closing_status {target_status} to ride {instance.pk} (direct)") return True @@ -212,7 +185,8 @@ def apply_post_closing_status(instance, user=None): # Computed Field Maintenance Signal Handlers # ============================================================================= -@receiver(post_save, sender='parks.Park') + +@receiver(post_save, sender="parks.Park") def update_ride_search_text_on_park_change(sender, instance, **kwargs): """ Update ride search_text when park name or location changes. @@ -227,7 +201,7 @@ def update_ride_search_text_on_park_change(sender, instance, **kwargs): logger.exception(f"Failed to update ride search_text on park change: {e}") -@receiver(post_save, sender='parks.Company') +@receiver(post_save, sender="parks.Company") def update_ride_search_text_on_company_change(sender, instance, **kwargs): """ Update ride search_text when manufacturer/designer name changes. @@ -248,7 +222,7 @@ def update_ride_search_text_on_company_change(sender, instance, **kwargs): logger.exception(f"Failed to update ride search_text on company change: {e}") -@receiver(post_save, sender='rides.RideModel') +@receiver(post_save, sender="rides.RideModel") def update_ride_search_text_on_ride_model_change(sender, instance, **kwargs): """ Update ride search_text when ride model name changes. diff --git a/backend/apps/rides/tasks.py b/backend/apps/rides/tasks.py index f758a21c..24321575 100644 --- a/backend/apps/rides/tasks.py +++ b/backend/apps/rides/tasks.py @@ -36,9 +36,7 @@ def check_overdue_closings() -> dict: # Query rides that need transition today = timezone.now().date() - overdue_rides = Ride.objects.filter( - status="CLOSING", closing_date__lte=today - ).select_for_update() + overdue_rides = Ride.objects.filter(status="CLOSING", closing_date__lte=today).select_for_update() processed = 0 succeeded = 0 @@ -109,9 +107,7 @@ def _get_system_user(): logger.info("Created system user for automated tasks") except Exception as e: # If creation fails, try to get moderator or admin user - logger.warning( - "Failed to create system user, falling back to moderator: %s", str(e) - ) + logger.warning("Failed to create system user, falling back to moderator: %s", str(e)) try: system_user = User.objects.filter(is_staff=True).first() if not system_user: diff --git a/backend/apps/rides/tests.py b/backend/apps/rides/tests.py index d2857f72..94a6d59f 100644 --- a/backend/apps/rides/tests.py +++ b/backend/apps/rides/tests.py @@ -36,53 +36,40 @@ class RideTransitionTests(TestCase): def setUp(self): """Set up test fixtures.""" self.user = User.objects.create_user( - username='testuser', - email='test@example.com', - password='testpass123', - role='USER' + username="testuser", email="test@example.com", password="testpass123", role="USER" ) self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) self.admin = User.objects.create_user( - username='admin', - email='admin@example.com', - password='testpass123', - role='ADMIN' + username="admin", email="admin@example.com", password="testpass123", role="ADMIN" ) # Create operator and park self.operator = ParkCompany.objects.create( - name='Test Operator', - description='Test operator company', - roles=['OPERATOR'] + name="Test Operator", description="Test operator company", roles=["OPERATOR"] ) self.park = Park.objects.create( - name='Test Park', - slug='test-park', - description='A test park', + name="Test Park", + slug="test-park", + description="A test park", operator=self.operator, - timezone='America/New_York' + timezone="America/New_York", ) # Create manufacturer self.manufacturer = Company.objects.create( - name='Test Manufacturer', - description='Test manufacturer company', - roles=['MANUFACTURER'] + name="Test Manufacturer", description="Test manufacturer company", roles=["MANUFACTURER"] ) - def _create_ride(self, status='OPERATING', **kwargs): + def _create_ride(self, status="OPERATING", **kwargs): """Helper to create a Ride with specified status.""" defaults = { - 'name': 'Test Ride', - 'slug': 'test-ride', - 'description': 'A test ride', - 'park': self.park, - 'manufacturer': self.manufacturer + "name": "Test Ride", + "slug": "test-ride", + "description": "A test ride", + "park": self.park, + "manufacturer": self.manufacturer, } defaults.update(kwargs) return Ride.objects.create(status=status, **defaults) @@ -93,34 +80,34 @@ class RideTransitionTests(TestCase): def test_operating_to_closed_temp_transition(self): """Test transition from OPERATING to CLOSED_TEMP.""" - ride = self._create_ride(status='OPERATING') - self.assertEqual(ride.status, 'OPERATING') + ride = self._create_ride(status="OPERATING") + self.assertEqual(ride.status, "OPERATING") ride.transition_to_closed_temp(user=self.user) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSED_TEMP') + self.assertEqual(ride.status, "CLOSED_TEMP") def test_operating_to_sbno_transition(self): """Test transition from OPERATING to SBNO.""" - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") ride.transition_to_sbno(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'SBNO') + self.assertEqual(ride.status, "SBNO") def test_operating_to_closing_transition(self): """Test transition from OPERATING to CLOSING.""" - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") ride.transition_to_closing(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSING') + self.assertEqual(ride.status, "CLOSING") # ------------------------------------------------------------------------- # Under construction transitions @@ -128,14 +115,14 @@ class RideTransitionTests(TestCase): def test_under_construction_to_operating_transition(self): """Test transition from UNDER_CONSTRUCTION to OPERATING.""" - ride = self._create_ride(status='UNDER_CONSTRUCTION') - self.assertEqual(ride.status, 'UNDER_CONSTRUCTION') + ride = self._create_ride(status="UNDER_CONSTRUCTION") + self.assertEqual(ride.status, "UNDER_CONSTRUCTION") ride.transition_to_operating(user=self.user) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'OPERATING') + self.assertEqual(ride.status, "OPERATING") # ------------------------------------------------------------------------- # Closed temp transitions @@ -143,34 +130,34 @@ class RideTransitionTests(TestCase): def test_closed_temp_to_operating_transition(self): """Test transition from CLOSED_TEMP to OPERATING (reopen).""" - ride = self._create_ride(status='CLOSED_TEMP') + ride = self._create_ride(status="CLOSED_TEMP") ride.transition_to_operating(user=self.user) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'OPERATING') + self.assertEqual(ride.status, "OPERATING") def test_closed_temp_to_sbno_transition(self): """Test transition from CLOSED_TEMP to SBNO.""" - ride = self._create_ride(status='CLOSED_TEMP') + ride = self._create_ride(status="CLOSED_TEMP") ride.transition_to_sbno(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'SBNO') + self.assertEqual(ride.status, "SBNO") def test_closed_temp_to_closed_perm_transition(self): """Test transition from CLOSED_TEMP to CLOSED_PERM.""" - ride = self._create_ride(status='CLOSED_TEMP') + ride = self._create_ride(status="CLOSED_TEMP") ride.transition_to_closed_perm(user=self.moderator) ride.closing_date = date.today() ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSED_PERM') + self.assertEqual(ride.status, "CLOSED_PERM") # ------------------------------------------------------------------------- # SBNO transitions @@ -178,23 +165,23 @@ class RideTransitionTests(TestCase): def test_sbno_to_operating_transition(self): """Test transition from SBNO to OPERATING (revival).""" - ride = self._create_ride(status='SBNO') + ride = self._create_ride(status="SBNO") ride.transition_to_operating(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'OPERATING') + self.assertEqual(ride.status, "OPERATING") def test_sbno_to_closed_perm_transition(self): """Test transition from SBNO to CLOSED_PERM.""" - ride = self._create_ride(status='SBNO') + ride = self._create_ride(status="SBNO") ride.transition_to_closed_perm(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSED_PERM') + self.assertEqual(ride.status, "CLOSED_PERM") # ------------------------------------------------------------------------- # Closing transitions @@ -202,23 +189,23 @@ class RideTransitionTests(TestCase): def test_closing_to_closed_perm_transition(self): """Test transition from CLOSING to CLOSED_PERM.""" - ride = self._create_ride(status='CLOSING') + ride = self._create_ride(status="CLOSING") ride.transition_to_closed_perm(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSED_PERM') + self.assertEqual(ride.status, "CLOSED_PERM") def test_closing_to_sbno_transition(self): """Test transition from CLOSING to SBNO.""" - ride = self._create_ride(status='CLOSING') + ride = self._create_ride(status="CLOSING") ride.transition_to_sbno(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'SBNO') + self.assertEqual(ride.status, "SBNO") # ------------------------------------------------------------------------- # Closed perm transitions (to final states) @@ -226,23 +213,23 @@ class RideTransitionTests(TestCase): def test_closed_perm_to_demolished_transition(self): """Test transition from CLOSED_PERM to DEMOLISHED.""" - ride = self._create_ride(status='CLOSED_PERM') + ride = self._create_ride(status="CLOSED_PERM") ride.transition_to_demolished(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'DEMOLISHED') + self.assertEqual(ride.status, "DEMOLISHED") def test_closed_perm_to_relocated_transition(self): """Test transition from CLOSED_PERM to RELOCATED.""" - ride = self._create_ride(status='CLOSED_PERM') + ride = self._create_ride(status="CLOSED_PERM") ride.transition_to_relocated(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'RELOCATED') + self.assertEqual(ride.status, "RELOCATED") # ------------------------------------------------------------------------- # Invalid transitions (final states) @@ -250,28 +237,28 @@ class RideTransitionTests(TestCase): def test_demolished_cannot_transition(self): """Test that DEMOLISHED state cannot transition further.""" - ride = self._create_ride(status='DEMOLISHED') + ride = self._create_ride(status="DEMOLISHED") with self.assertRaises(TransitionNotAllowed): ride.transition_to_operating(user=self.moderator) def test_relocated_cannot_transition(self): """Test that RELOCATED state cannot transition further.""" - ride = self._create_ride(status='RELOCATED') + ride = self._create_ride(status="RELOCATED") with self.assertRaises(TransitionNotAllowed): ride.transition_to_operating(user=self.moderator) def test_operating_cannot_directly_demolish(self): """Test that OPERATING cannot directly transition to DEMOLISHED.""" - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") with self.assertRaises(TransitionNotAllowed): ride.transition_to_demolished(user=self.moderator) def test_operating_cannot_directly_relocate(self): """Test that OPERATING cannot directly transition to RELOCATED.""" - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") with self.assertRaises(TransitionNotAllowed): ride.transition_to_relocated(user=self.moderator) @@ -282,84 +269,76 @@ class RideTransitionTests(TestCase): def test_open_wrapper_method(self): """Test the open() wrapper method.""" - ride = self._create_ride(status='CLOSED_TEMP') + ride = self._create_ride(status="CLOSED_TEMP") ride.open(user=self.user) ride.refresh_from_db() - self.assertEqual(ride.status, 'OPERATING') + self.assertEqual(ride.status, "OPERATING") def test_close_temporarily_wrapper_method(self): """Test the close_temporarily() wrapper method.""" - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") ride.close_temporarily(user=self.user) ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSED_TEMP') + self.assertEqual(ride.status, "CLOSED_TEMP") def test_mark_sbno_wrapper_method(self): """Test the mark_sbno() wrapper method.""" - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") ride.mark_sbno(user=self.moderator) ride.refresh_from_db() - self.assertEqual(ride.status, 'SBNO') + self.assertEqual(ride.status, "SBNO") def test_mark_closing_wrapper_method(self): """Test the mark_closing() wrapper method.""" - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") closing = date(2025, 12, 31) - ride.mark_closing( - closing_date=closing, - post_closing_status='DEMOLISHED', - user=self.moderator - ) + ride.mark_closing(closing_date=closing, post_closing_status="DEMOLISHED", user=self.moderator) ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSING') + self.assertEqual(ride.status, "CLOSING") self.assertEqual(ride.closing_date, closing) - self.assertEqual(ride.post_closing_status, 'DEMOLISHED') + self.assertEqual(ride.post_closing_status, "DEMOLISHED") def test_mark_closing_requires_post_closing_status(self): """Test that mark_closing() requires post_closing_status.""" - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") with self.assertRaises(ValidationError): - ride.mark_closing( - closing_date=date(2025, 12, 31), - post_closing_status='', - user=self.moderator - ) + ride.mark_closing(closing_date=date(2025, 12, 31), post_closing_status="", user=self.moderator) def test_close_permanently_wrapper_method(self): """Test the close_permanently() wrapper method.""" - ride = self._create_ride(status='SBNO') + ride = self._create_ride(status="SBNO") ride.close_permanently(user=self.moderator) ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSED_PERM') + self.assertEqual(ride.status, "CLOSED_PERM") def test_demolish_wrapper_method(self): """Test the demolish() wrapper method.""" - ride = self._create_ride(status='CLOSED_PERM') + ride = self._create_ride(status="CLOSED_PERM") ride.demolish(user=self.moderator) ride.refresh_from_db() - self.assertEqual(ride.status, 'DEMOLISHED') + self.assertEqual(ride.status, "DEMOLISHED") def test_relocate_wrapper_method(self): """Test the relocate() wrapper method.""" - ride = self._create_ride(status='CLOSED_PERM') + ride = self._create_ride(status="CLOSED_PERM") ride.relocate(user=self.moderator) ride.refresh_from_db() - self.assertEqual(ride.status, 'RELOCATED') + self.assertEqual(ride.status, "RELOCATED") # ============================================================================ @@ -373,37 +352,30 @@ class RidePostClosingTests(TestCase): def setUp(self): """Set up test fixtures.""" self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) self.operator = ParkCompany.objects.create( - name='Test Operator', - description='Test operator company', - roles=['OPERATOR'] + name="Test Operator", description="Test operator company", roles=["OPERATOR"] ) self.park = Park.objects.create( - name='Test Park', - slug='test-park', - description='A test park', + name="Test Park", + slug="test-park", + description="A test park", operator=self.operator, - timezone='America/New_York' + timezone="America/New_York", ) self.manufacturer = Company.objects.create( - name='Test Manufacturer', - description='Test manufacturer company', - roles=['MANUFACTURER'] + name="Test Manufacturer", description="Test manufacturer company", roles=["MANUFACTURER"] ) - def _create_ride(self, status='OPERATING', **kwargs): + def _create_ride(self, status="OPERATING", **kwargs): """Helper to create a Ride.""" defaults = { - 'name': 'Test Ride', - 'slug': 'test-ride', - 'description': 'A test ride', - 'park': self.park, - 'manufacturer': self.manufacturer + "name": "Test Ride", + "slug": "test-ride", + "description": "A test ride", + "park": self.park, + "manufacturer": self.manufacturer, } defaults.update(kwargs) return Ride.objects.create(status=status, **defaults) @@ -411,111 +383,85 @@ class RidePostClosingTests(TestCase): def test_apply_post_closing_status_to_demolished(self): """Test apply_post_closing_status transitions to DEMOLISHED.""" yesterday = date.today() - timedelta(days=1) - ride = self._create_ride( - status='CLOSING', - closing_date=yesterday, - post_closing_status='DEMOLISHED' - ) + ride = self._create_ride(status="CLOSING", closing_date=yesterday, post_closing_status="DEMOLISHED") ride.apply_post_closing_status(user=self.moderator) ride.refresh_from_db() - self.assertEqual(ride.status, 'DEMOLISHED') + self.assertEqual(ride.status, "DEMOLISHED") def test_apply_post_closing_status_to_relocated(self): """Test apply_post_closing_status transitions to RELOCATED.""" yesterday = date.today() - timedelta(days=1) - ride = self._create_ride( - status='CLOSING', - closing_date=yesterday, - post_closing_status='RELOCATED' - ) + ride = self._create_ride(status="CLOSING", closing_date=yesterday, post_closing_status="RELOCATED") ride.apply_post_closing_status(user=self.moderator) ride.refresh_from_db() - self.assertEqual(ride.status, 'RELOCATED') + self.assertEqual(ride.status, "RELOCATED") def test_apply_post_closing_status_to_sbno(self): """Test apply_post_closing_status transitions to SBNO.""" yesterday = date.today() - timedelta(days=1) - ride = self._create_ride( - status='CLOSING', - closing_date=yesterday, - post_closing_status='SBNO' - ) + ride = self._create_ride(status="CLOSING", closing_date=yesterday, post_closing_status="SBNO") ride.apply_post_closing_status(user=self.moderator) ride.refresh_from_db() - self.assertEqual(ride.status, 'SBNO') + self.assertEqual(ride.status, "SBNO") def test_apply_post_closing_status_to_closed_perm(self): """Test apply_post_closing_status transitions to CLOSED_PERM.""" yesterday = date.today() - timedelta(days=1) - ride = self._create_ride( - status='CLOSING', - closing_date=yesterday, - post_closing_status='CLOSED_PERM' - ) + ride = self._create_ride(status="CLOSING", closing_date=yesterday, post_closing_status="CLOSED_PERM") ride.apply_post_closing_status(user=self.moderator) ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSED_PERM') + self.assertEqual(ride.status, "CLOSED_PERM") def test_apply_post_closing_status_not_yet_reached(self): """Test apply_post_closing_status does nothing if date not reached.""" tomorrow = date.today() + timedelta(days=1) - ride = self._create_ride( - status='CLOSING', - closing_date=tomorrow, - post_closing_status='DEMOLISHED' - ) + ride = self._create_ride(status="CLOSING", closing_date=tomorrow, post_closing_status="DEMOLISHED") ride.apply_post_closing_status(user=self.moderator) ride.refresh_from_db() # Status should remain CLOSING since date hasn't been reached - self.assertEqual(ride.status, 'CLOSING') + self.assertEqual(ride.status, "CLOSING") def test_apply_post_closing_status_requires_closing_status(self): """Test apply_post_closing_status requires CLOSING status.""" - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") with self.assertRaises(ValidationError) as ctx: ride.apply_post_closing_status(user=self.moderator) - self.assertIn('CLOSING', str(ctx.exception)) + self.assertIn("CLOSING", str(ctx.exception)) def test_apply_post_closing_status_requires_closing_date(self): """Test apply_post_closing_status requires closing_date.""" - ride = self._create_ride( - status='CLOSING', - post_closing_status='DEMOLISHED' - ) + ride = self._create_ride(status="CLOSING", post_closing_status="DEMOLISHED") ride.closing_date = None ride.save() with self.assertRaises(ValidationError) as ctx: ride.apply_post_closing_status(user=self.moderator) - self.assertIn('closing_date', str(ctx.exception)) + self.assertIn("closing_date", str(ctx.exception)) def test_apply_post_closing_status_requires_post_closing_status(self): """Test apply_post_closing_status requires post_closing_status.""" yesterday = date.today() - timedelta(days=1) - ride = self._create_ride( - status='CLOSING', - closing_date=yesterday - ) + ride = self._create_ride(status="CLOSING", closing_date=yesterday) ride.post_closing_status = None ride.save() with self.assertRaises(ValidationError) as ctx: ride.apply_post_closing_status(user=self.moderator) - self.assertIn('post_closing_status', str(ctx.exception)) + self.assertIn("post_closing_status", str(ctx.exception)) # ============================================================================ @@ -529,64 +475,54 @@ class RideTransitionHistoryTests(TestCase): def setUp(self): """Set up test fixtures.""" self.moderator = User.objects.create_user( - username='moderator', - email='moderator@example.com', - password='testpass123', - role='MODERATOR' + username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR" ) self.operator = ParkCompany.objects.create( - name='Test Operator', - description='Test operator company', - roles=['OPERATOR'] + name="Test Operator", description="Test operator company", roles=["OPERATOR"] ) self.park = Park.objects.create( - name='Test Park', - slug='test-park', - description='A test park', + name="Test Park", + slug="test-park", + description="A test park", operator=self.operator, - timezone='America/New_York' + timezone="America/New_York", ) self.manufacturer = Company.objects.create( - name='Test Manufacturer', - description='Test manufacturer company', - roles=['MANUFACTURER'] + name="Test Manufacturer", description="Test manufacturer company", roles=["MANUFACTURER"] ) - def _create_ride(self, status='OPERATING'): + def _create_ride(self, status="OPERATING"): """Helper to create a Ride.""" return Ride.objects.create( - name='Test Ride', - slug='test-ride', - description='A test ride', + name="Test Ride", + slug="test-ride", + description="A test ride", park=self.park, manufacturer=self.manufacturer, - status=status + status=status, ) def test_transition_creates_state_log(self): """Test that transitions create StateLog entries.""" from django_fsm_log.models import StateLog - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") ride.transition_to_closed_temp(user=self.moderator) ride.save() ride_ct = ContentType.objects.get_for_model(ride) - log = StateLog.objects.filter( - content_type=ride_ct, - object_id=ride.id - ).first() + log = StateLog.objects.filter(content_type=ride_ct, object_id=ride.id).first() self.assertIsNotNone(log) - self.assertEqual(log.state, 'CLOSED_TEMP') + self.assertEqual(log.state, "CLOSED_TEMP") self.assertEqual(log.by, self.moderator) def test_multiple_transitions_create_multiple_logs(self): """Test that multiple transitions create multiple log entries.""" from django_fsm_log.models import StateLog - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") ride_ct = ContentType.objects.get_for_model(ride) # First transition @@ -597,29 +533,23 @@ class RideTransitionHistoryTests(TestCase): ride.transition_to_operating(user=self.moderator) ride.save() - logs = StateLog.objects.filter( - content_type=ride_ct, - object_id=ride.id - ).order_by('timestamp') + logs = StateLog.objects.filter(content_type=ride_ct, object_id=ride.id).order_by("timestamp") self.assertEqual(logs.count(), 2) - self.assertEqual(logs[0].state, 'CLOSED_TEMP') - self.assertEqual(logs[1].state, 'OPERATING') + self.assertEqual(logs[0].state, "CLOSED_TEMP") + self.assertEqual(logs[1].state, "OPERATING") def test_transition_log_includes_user(self): """Test that transition logs include the user who made the change.""" from django_fsm_log.models import StateLog - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") ride.transition_to_sbno(user=self.moderator) ride.save() ride_ct = ContentType.objects.get_for_model(ride) - log = StateLog.objects.filter( - content_type=ride_ct, - object_id=ride.id - ).first() + log = StateLog.objects.filter(content_type=ride_ct, object_id=ride.id).first() self.assertEqual(log.by, self.moderator) @@ -628,19 +558,15 @@ class RideTransitionHistoryTests(TestCase): from django_fsm_log.models import StateLog yesterday = date.today() - timedelta(days=1) - ride = self._create_ride(status='CLOSING') + ride = self._create_ride(status="CLOSING") ride.closing_date = yesterday - ride.post_closing_status = 'DEMOLISHED' + ride.post_closing_status = "DEMOLISHED" ride.save() ride.apply_post_closing_status(user=self.moderator) ride_ct = ContentType.objects.get_for_model(ride) - log = StateLog.objects.filter( - content_type=ride_ct, - object_id=ride.id, - state='DEMOLISHED' - ).first() + log = StateLog.objects.filter(content_type=ride_ct, object_id=ride.id, state="DEMOLISHED").first() self.assertIsNotNone(log) self.assertEqual(log.by, self.moderator) @@ -657,31 +583,27 @@ class RideBusinessLogicTests(TestCase): def setUp(self): """Set up test fixtures.""" self.operator = ParkCompany.objects.create( - name='Test Operator', - description='Test operator company', - roles=['OPERATOR'] + name="Test Operator", description="Test operator company", roles=["OPERATOR"] ) self.park = Park.objects.create( - name='Test Park', - slug='test-park', - description='A test park', + name="Test Park", + slug="test-park", + description="A test park", operator=self.operator, - timezone='America/New_York' + timezone="America/New_York", ) self.manufacturer = Company.objects.create( - name='Test Manufacturer', - description='Test manufacturer company', - roles=['MANUFACTURER'] + name="Test Manufacturer", description="Test manufacturer company", roles=["MANUFACTURER"] ) def test_ride_creates_with_valid_park(self): """Test ride can be created with valid park.""" ride = Ride.objects.create( - name='Test Ride', - slug='test-ride', - description='A test ride', + name="Test Ride", + slug="test-ride", + description="A test ride", park=self.park, - manufacturer=self.manufacturer + manufacturer=self.manufacturer, ) self.assertEqual(ride.park, self.park) @@ -689,36 +611,33 @@ class RideBusinessLogicTests(TestCase): def test_ride_slug_auto_generated(self): """Test that ride slug is auto-generated from name.""" ride = Ride.objects.create( - name='My Amazing Roller Coaster', - description='A test ride', - park=self.park, - manufacturer=self.manufacturer + name="My Amazing Roller Coaster", description="A test ride", park=self.park, manufacturer=self.manufacturer ) - self.assertEqual(ride.slug, 'my-amazing-roller-coaster') + self.assertEqual(ride.slug, "my-amazing-roller-coaster") def test_ride_url_generated(self): """Test that frontend URL is generated on save.""" ride = Ride.objects.create( - name='Test Ride', - slug='test-ride', - description='A test ride', + name="Test Ride", + slug="test-ride", + description="A test ride", park=self.park, - manufacturer=self.manufacturer + manufacturer=self.manufacturer, ) - self.assertIn('test-park', ride.url) - self.assertIn('test-ride', ride.url) + self.assertIn("test-park", ride.url) + self.assertIn("test-ride", ride.url) def test_opening_year_computed_from_opening_date(self): """Test that opening_year is computed from opening_date.""" ride = Ride.objects.create( - name='Test Ride', - slug='test-ride', - description='A test ride', + name="Test Ride", + slug="test-ride", + description="A test ride", park=self.park, manufacturer=self.manufacturer, - opening_date=date(2020, 6, 15) + opening_date=date(2020, 6, 15), ) self.assertEqual(ride.opening_year, 2020) @@ -726,38 +645,31 @@ class RideBusinessLogicTests(TestCase): def test_search_text_populated(self): """Test that search_text is populated on save.""" ride = Ride.objects.create( - name='Test Ride', - slug='test-ride', - description='A thrilling roller coaster', + name="Test Ride", + slug="test-ride", + description="A thrilling roller coaster", park=self.park, - manufacturer=self.manufacturer + manufacturer=self.manufacturer, ) - self.assertIn('test ride', ride.search_text) - self.assertIn('thrilling roller coaster', ride.search_text) - self.assertIn('test park', ride.search_text) - self.assertIn('test manufacturer', ride.search_text) + self.assertIn("test ride", ride.search_text) + self.assertIn("thrilling roller coaster", ride.search_text) + self.assertIn("test park", ride.search_text) + self.assertIn("test manufacturer", ride.search_text) def test_ride_slug_unique_within_park(self): """Test that ride slugs are unique within a park.""" Ride.objects.create( - name='Test Ride', - slug='test-ride', - description='First ride', - park=self.park, - manufacturer=self.manufacturer + name="Test Ride", slug="test-ride", description="First ride", park=self.park, manufacturer=self.manufacturer ) # Creating another ride with same name should get different slug ride2 = Ride.objects.create( - name='Test Ride', - description='Second ride', - park=self.park, - manufacturer=self.manufacturer + name="Test Ride", description="Second ride", park=self.park, manufacturer=self.manufacturer ) - self.assertNotEqual(ride2.slug, 'test-ride') - self.assertTrue(ride2.slug.startswith('test-ride')) + self.assertNotEqual(ride2.slug, "test-ride") + self.assertTrue(ride2.slug.startswith("test-ride")) # ============================================================================ @@ -771,55 +683,51 @@ class RideMoveTests(TestCase): def setUp(self): """Set up test fixtures.""" self.operator = ParkCompany.objects.create( - name='Test Operator', - description='Test operator company', - roles=['OPERATOR'] + name="Test Operator", description="Test operator company", roles=["OPERATOR"] ) self.park1 = Park.objects.create( - name='Park One', - slug='park-one', - description='First park', + name="Park One", + slug="park-one", + description="First park", operator=self.operator, - timezone='America/New_York' + timezone="America/New_York", ) self.park2 = Park.objects.create( - name='Park Two', - slug='park-two', - description='Second park', + name="Park Two", + slug="park-two", + description="Second park", operator=self.operator, - timezone='America/Los_Angeles' + timezone="America/Los_Angeles", ) self.manufacturer = Company.objects.create( - name='Test Manufacturer', - description='Test manufacturer company', - roles=['MANUFACTURER'] + name="Test Manufacturer", description="Test manufacturer company", roles=["MANUFACTURER"] ) def test_move_ride_to_different_park(self): """Test moving a ride to a different park.""" ride = Ride.objects.create( - name='Test Ride', - slug='test-ride', - description='A test ride', + name="Test Ride", + slug="test-ride", + description="A test ride", park=self.park1, - manufacturer=self.manufacturer + manufacturer=self.manufacturer, ) changes = ride.move_to_park(self.park2) ride.refresh_from_db() self.assertEqual(ride.park, self.park2) - self.assertEqual(changes['old_park']['id'], self.park1.id) - self.assertEqual(changes['new_park']['id'], self.park2.id) + self.assertEqual(changes["old_park"]["id"], self.park1.id) + self.assertEqual(changes["new_park"]["id"], self.park2.id) def test_move_ride_updates_url(self): """Test that moving a ride updates the URL.""" ride = Ride.objects.create( - name='Test Ride', - slug='test-ride', - description='A test ride', + name="Test Ride", + slug="test-ride", + description="A test ride", park=self.park1, - manufacturer=self.manufacturer + manufacturer=self.manufacturer, ) old_url = ride.url @@ -827,27 +735,27 @@ class RideMoveTests(TestCase): ride.refresh_from_db() self.assertNotEqual(ride.url, old_url) - self.assertIn('park-two', ride.url) - self.assertTrue(changes['url_changed']) + self.assertIn("park-two", ride.url) + self.assertTrue(changes["url_changed"]) def test_move_ride_handles_slug_conflict(self): """Test that moving a ride handles slug conflicts in destination park.""" # Create ride in park1 ride1 = Ride.objects.create( - name='Test Ride', - slug='test-ride', - description='A test ride', + name="Test Ride", + slug="test-ride", + description="A test ride", park=self.park1, - manufacturer=self.manufacturer + manufacturer=self.manufacturer, ) # Create ride with same slug in park2 Ride.objects.create( - name='Test Ride', - slug='test-ride', - description='Another test ride', + name="Test Ride", + slug="test-ride", + description="Another test ride", park=self.park2, - manufacturer=self.manufacturer + manufacturer=self.manufacturer, ) # Move ride1 to park2 @@ -856,8 +764,8 @@ class RideMoveTests(TestCase): ride1.refresh_from_db() self.assertEqual(ride1.park, self.park2) # Slug should have been modified to avoid conflict - self.assertNotEqual(ride1.slug, 'test-ride') - self.assertTrue(changes['slug_changed']) + self.assertNotEqual(ride1.slug, "test-ride") + self.assertTrue(changes["slug_changed"]) # ============================================================================ @@ -871,34 +779,30 @@ class RideSlugHistoryTests(TestCase): def setUp(self): """Set up test fixtures.""" self.operator = ParkCompany.objects.create( - name='Test Operator', - description='Test operator company', - roles=['OPERATOR'] + name="Test Operator", description="Test operator company", roles=["OPERATOR"] ) self.park = Park.objects.create( - name='Test Park', - slug='test-park', - description='A test park', + name="Test Park", + slug="test-park", + description="A test park", operator=self.operator, - timezone='America/New_York' + timezone="America/New_York", ) self.manufacturer = Company.objects.create( - name='Test Manufacturer', - description='Test manufacturer company', - roles=['MANUFACTURER'] + name="Test Manufacturer", description="Test manufacturer company", roles=["MANUFACTURER"] ) def test_get_by_slug_finds_current_slug(self): """Test get_by_slug finds ride by current slug.""" ride = Ride.objects.create( - name='Test Ride', - slug='test-ride', - description='A test ride', + name="Test Ride", + slug="test-ride", + description="A test ride", park=self.park, - manufacturer=self.manufacturer + manufacturer=self.manufacturer, ) - found_ride, is_historical = Ride.get_by_slug('test-ride', park=self.park) + found_ride, is_historical = Ride.get_by_slug("test-ride", park=self.park) self.assertEqual(found_ride, ride) self.assertFalse(is_historical) @@ -906,24 +810,24 @@ class RideSlugHistoryTests(TestCase): def test_get_by_slug_with_park_filter(self): """Test get_by_slug filters by park.""" ride = Ride.objects.create( - name='Test Ride', - slug='test-ride', - description='A test ride', + name="Test Ride", + slug="test-ride", + description="A test ride", park=self.park, - manufacturer=self.manufacturer + manufacturer=self.manufacturer, ) # Should find ride in correct park - found_ride, is_historical = Ride.get_by_slug('test-ride', park=self.park) + found_ride, is_historical = Ride.get_by_slug("test-ride", park=self.park) self.assertEqual(found_ride, ride) # Should not find ride in different park other_park = Park.objects.create( - name='Other Park', - slug='other-park', - description='Another park', + name="Other Park", + slug="other-park", + description="Another park", operator=self.operator, - timezone='America/New_York' + timezone="America/New_York", ) with self.assertRaises(Ride.DoesNotExist): - Ride.get_by_slug('test-ride', park=other_park) + Ride.get_by_slug("test-ride", park=other_park) diff --git a/backend/apps/rides/tests/test_ride_workflows.py b/backend/apps/rides/tests/test_ride_workflows.py index b6024fba..dc2f9648 100644 --- a/backend/apps/rides/tests/test_ride_workflows.py +++ b/backend/apps/rides/tests/test_ride_workflows.py @@ -25,42 +25,33 @@ class RideOpeningWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.user = User.objects.create_user( - username='ride_user', - email='ride_user@example.com', - password='testpass123', - role='USER' + username="ride_user", email="ride_user@example.com", password="testpass123", role="USER" ) - def _create_ride(self, status='OPERATING', **kwargs): + def _create_ride(self, status="OPERATING", **kwargs): """Helper to create a ride with park.""" from apps.parks.models import Company, Park from apps.rides.models import Ride # Create manufacturer - manufacturer = Company.objects.create( - name=f'Manufacturer {timezone.now().timestamp()}', - roles=['MANUFACTURER'] - ) + manufacturer = Company.objects.create(name=f"Manufacturer {timezone.now().timestamp()}", roles=["MANUFACTURER"]) # Create park with operator - operator = Company.objects.create( - name=f'Operator {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + operator = Company.objects.create(name=f"Operator {timezone.now().timestamp()}", roles=["OPERATOR"]) park = Park.objects.create( - name=f'Test Park {timezone.now().timestamp()}', - slug=f'test-park-{timezone.now().timestamp()}', + name=f"Test Park {timezone.now().timestamp()}", + slug=f"test-park-{timezone.now().timestamp()}", operator=operator, - status='OPERATING', - timezone='America/New_York' + status="OPERATING", + timezone="America/New_York", ) defaults = { - 'name': f'Test Ride {timezone.now().timestamp()}', - 'slug': f'test-ride-{timezone.now().timestamp()}', - 'park': park, - 'manufacturer': manufacturer, - 'status': status + "name": f"Test Ride {timezone.now().timestamp()}", + "slug": f"test-ride-{timezone.now().timestamp()}", + "park": park, + "manufacturer": manufacturer, + "status": status, } defaults.update(kwargs) return Ride.objects.create(**defaults) @@ -71,16 +62,16 @@ class RideOpeningWorkflowTests(TestCase): Flow: UNDER_CONSTRUCTION → OPERATING """ - ride = self._create_ride(status='UNDER_CONSTRUCTION') + ride = self._create_ride(status="UNDER_CONSTRUCTION") - self.assertEqual(ride.status, 'UNDER_CONSTRUCTION') + self.assertEqual(ride.status, "UNDER_CONSTRUCTION") # Ride opens ride.transition_to_operating(user=self.user) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'OPERATING') + self.assertEqual(ride.status, "OPERATING") class RideMaintenanceWorkflowTests(TestCase): @@ -89,38 +80,29 @@ class RideMaintenanceWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.user = User.objects.create_user( - username='maint_user', - email='maint@example.com', - password='testpass123', - role='USER' + username="maint_user", email="maint@example.com", password="testpass123", role="USER" ) - def _create_ride(self, status='OPERATING', **kwargs): + def _create_ride(self, status="OPERATING", **kwargs): from apps.parks.models import Company, Park from apps.rides.models import Ride - manufacturer = Company.objects.create( - name=f'Mfr Maint {timezone.now().timestamp()}', - roles=['MANUFACTURER'] - ) - operator = Company.objects.create( - name=f'Op Maint {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + manufacturer = Company.objects.create(name=f"Mfr Maint {timezone.now().timestamp()}", roles=["MANUFACTURER"]) + operator = Company.objects.create(name=f"Op Maint {timezone.now().timestamp()}", roles=["OPERATOR"]) park = Park.objects.create( - name=f'Park Maint {timezone.now().timestamp()}', - slug=f'park-maint-{timezone.now().timestamp()}', + name=f"Park Maint {timezone.now().timestamp()}", + slug=f"park-maint-{timezone.now().timestamp()}", operator=operator, - status='OPERATING', - timezone='America/New_York' + status="OPERATING", + timezone="America/New_York", ) defaults = { - 'name': f'Ride Maint {timezone.now().timestamp()}', - 'slug': f'ride-maint-{timezone.now().timestamp()}', - 'park': park, - 'manufacturer': manufacturer, - 'status': status + "name": f"Ride Maint {timezone.now().timestamp()}", + "slug": f"ride-maint-{timezone.now().timestamp()}", + "park": park, + "manufacturer": manufacturer, + "status": status, } defaults.update(kwargs) return Ride.objects.create(**defaults) @@ -131,21 +113,21 @@ class RideMaintenanceWorkflowTests(TestCase): Flow: OPERATING → CLOSED_TEMP → OPERATING """ - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") # Close for maintenance ride.transition_to_closed_temp(user=self.user) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSED_TEMP') + self.assertEqual(ride.status, "CLOSED_TEMP") # Reopen after maintenance ride.transition_to_operating(user=self.user) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'OPERATING') + self.assertEqual(ride.status, "OPERATING") class RideSBNOWorkflowTests(TestCase): @@ -154,38 +136,29 @@ class RideSBNOWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.moderator = User.objects.create_user( - username='sbno_mod', - email='sbno_mod@example.com', - password='testpass123', - role='MODERATOR' + username="sbno_mod", email="sbno_mod@example.com", password="testpass123", role="MODERATOR" ) - def _create_ride(self, status='OPERATING', **kwargs): + def _create_ride(self, status="OPERATING", **kwargs): from apps.parks.models import Company, Park from apps.rides.models import Ride - manufacturer = Company.objects.create( - name=f'Mfr SBNO {timezone.now().timestamp()}', - roles=['MANUFACTURER'] - ) - operator = Company.objects.create( - name=f'Op SBNO {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + manufacturer = Company.objects.create(name=f"Mfr SBNO {timezone.now().timestamp()}", roles=["MANUFACTURER"]) + operator = Company.objects.create(name=f"Op SBNO {timezone.now().timestamp()}", roles=["OPERATOR"]) park = Park.objects.create( - name=f'Park SBNO {timezone.now().timestamp()}', - slug=f'park-sbno-{timezone.now().timestamp()}', + name=f"Park SBNO {timezone.now().timestamp()}", + slug=f"park-sbno-{timezone.now().timestamp()}", operator=operator, - status='OPERATING', - timezone='America/New_York' + status="OPERATING", + timezone="America/New_York", ) defaults = { - 'name': f'Ride SBNO {timezone.now().timestamp()}', - 'slug': f'ride-sbno-{timezone.now().timestamp()}', - 'park': park, - 'manufacturer': manufacturer, - 'status': status + "name": f"Ride SBNO {timezone.now().timestamp()}", + "slug": f"ride-sbno-{timezone.now().timestamp()}", + "park": park, + "manufacturer": manufacturer, + "status": status, } defaults.update(kwargs) return Ride.objects.create(**defaults) @@ -196,14 +169,14 @@ class RideSBNOWorkflowTests(TestCase): Flow: OPERATING → SBNO """ - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") # Mark as SBNO ride.transition_to_sbno(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'SBNO') + self.assertEqual(ride.status, "SBNO") def test_ride_sbno_from_closed_temp(self): """ @@ -211,14 +184,14 @@ class RideSBNOWorkflowTests(TestCase): Flow: OPERATING → CLOSED_TEMP → SBNO """ - ride = self._create_ride(status='CLOSED_TEMP') + ride = self._create_ride(status="CLOSED_TEMP") # Extended to SBNO ride.transition_to_sbno(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'SBNO') + self.assertEqual(ride.status, "SBNO") def test_ride_revival_from_sbno(self): """ @@ -226,14 +199,14 @@ class RideSBNOWorkflowTests(TestCase): Flow: SBNO → OPERATING """ - ride = self._create_ride(status='SBNO') + ride = self._create_ride(status="SBNO") # Revive the ride ride.transition_to_operating(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'OPERATING') + self.assertEqual(ride.status, "OPERATING") def test_sbno_to_closed_perm(self): """ @@ -241,14 +214,14 @@ class RideSBNOWorkflowTests(TestCase): Flow: SBNO → CLOSED_PERM """ - ride = self._create_ride(status='SBNO') + ride = self._create_ride(status="SBNO") # Confirm permanent closure ride.transition_to_closed_perm(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSED_PERM') + self.assertEqual(ride.status, "CLOSED_PERM") class RideScheduledClosureWorkflowTests(TestCase): @@ -257,38 +230,29 @@ class RideScheduledClosureWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.moderator = User.objects.create_user( - username='closing_mod', - email='closing_mod@example.com', - password='testpass123', - role='MODERATOR' + username="closing_mod", email="closing_mod@example.com", password="testpass123", role="MODERATOR" ) - def _create_ride(self, status='OPERATING', **kwargs): + def _create_ride(self, status="OPERATING", **kwargs): from apps.parks.models import Company, Park from apps.rides.models import Ride - manufacturer = Company.objects.create( - name=f'Mfr Closing {timezone.now().timestamp()}', - roles=['MANUFACTURER'] - ) - operator = Company.objects.create( - name=f'Op Closing {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + manufacturer = Company.objects.create(name=f"Mfr Closing {timezone.now().timestamp()}", roles=["MANUFACTURER"]) + operator = Company.objects.create(name=f"Op Closing {timezone.now().timestamp()}", roles=["OPERATOR"]) park = Park.objects.create( - name=f'Park Closing {timezone.now().timestamp()}', - slug=f'park-closing-{timezone.now().timestamp()}', + name=f"Park Closing {timezone.now().timestamp()}", + slug=f"park-closing-{timezone.now().timestamp()}", operator=operator, - status='OPERATING', - timezone='America/New_York' + status="OPERATING", + timezone="America/New_York", ) defaults = { - 'name': f'Ride Closing {timezone.now().timestamp()}', - 'slug': f'ride-closing-{timezone.now().timestamp()}', - 'park': park, - 'manufacturer': manufacturer, - 'status': status + "name": f"Ride Closing {timezone.now().timestamp()}", + "slug": f"ride-closing-{timezone.now().timestamp()}", + "park": park, + "manufacturer": manufacturer, + "status": status, } defaults.update(kwargs) return Ride.objects.create(**defaults) @@ -299,19 +263,19 @@ class RideScheduledClosureWorkflowTests(TestCase): Flow: OPERATING → CLOSING (with closing_date and post_closing_status) """ - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") closing_date = (timezone.now() + timedelta(days=30)).date() # Mark as closing ride.transition_to_closing(user=self.moderator) ride.closing_date = closing_date - ride.post_closing_status = 'DEMOLISHED' + ride.post_closing_status = "DEMOLISHED" ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSING') + self.assertEqual(ride.status, "CLOSING") self.assertEqual(ride.closing_date, closing_date) - self.assertEqual(ride.post_closing_status, 'DEMOLISHED') + self.assertEqual(ride.post_closing_status, "DEMOLISHED") def test_closing_to_closed_perm(self): """ @@ -319,9 +283,9 @@ class RideScheduledClosureWorkflowTests(TestCase): Flow: CLOSING → CLOSED_PERM """ - ride = self._create_ride(status='CLOSING') + ride = self._create_ride(status="CLOSING") ride.closing_date = timezone.now().date() - ride.post_closing_status = 'CLOSED_PERM' + ride.post_closing_status = "CLOSED_PERM" ride.save() # Transition when closing date reached @@ -329,7 +293,7 @@ class RideScheduledClosureWorkflowTests(TestCase): ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSED_PERM') + self.assertEqual(ride.status, "CLOSED_PERM") def test_closing_to_sbno(self): """ @@ -337,9 +301,9 @@ class RideScheduledClosureWorkflowTests(TestCase): Flow: CLOSING → SBNO """ - ride = self._create_ride(status='CLOSING') + ride = self._create_ride(status="CLOSING") ride.closing_date = timezone.now().date() - ride.post_closing_status = 'SBNO' + ride.post_closing_status = "SBNO" ride.save() # Transition to SBNO @@ -347,7 +311,7 @@ class RideScheduledClosureWorkflowTests(TestCase): ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'SBNO') + self.assertEqual(ride.status, "SBNO") class RideDemolitionWorkflowTests(TestCase): @@ -356,38 +320,29 @@ class RideDemolitionWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.moderator = User.objects.create_user( - username='demo_ride_mod', - email='demo_ride_mod@example.com', - password='testpass123', - role='MODERATOR' + username="demo_ride_mod", email="demo_ride_mod@example.com", password="testpass123", role="MODERATOR" ) - def _create_ride(self, status='CLOSED_PERM', **kwargs): + def _create_ride(self, status="CLOSED_PERM", **kwargs): from apps.parks.models import Company, Park from apps.rides.models import Ride - manufacturer = Company.objects.create( - name=f'Mfr Demo {timezone.now().timestamp()}', - roles=['MANUFACTURER'] - ) - operator = Company.objects.create( - name=f'Op Demo {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + manufacturer = Company.objects.create(name=f"Mfr Demo {timezone.now().timestamp()}", roles=["MANUFACTURER"]) + operator = Company.objects.create(name=f"Op Demo {timezone.now().timestamp()}", roles=["OPERATOR"]) park = Park.objects.create( - name=f'Park Demo {timezone.now().timestamp()}', - slug=f'park-demo-{timezone.now().timestamp()}', + name=f"Park Demo {timezone.now().timestamp()}", + slug=f"park-demo-{timezone.now().timestamp()}", operator=operator, - status='OPERATING', - timezone='America/New_York' + status="OPERATING", + timezone="America/New_York", ) defaults = { - 'name': f'Ride Demo {timezone.now().timestamp()}', - 'slug': f'ride-demo-{timezone.now().timestamp()}', - 'park': park, - 'manufacturer': manufacturer, - 'status': status + "name": f"Ride Demo {timezone.now().timestamp()}", + "slug": f"ride-demo-{timezone.now().timestamp()}", + "park": park, + "manufacturer": manufacturer, + "status": status, } defaults.update(kwargs) return Ride.objects.create(**defaults) @@ -398,20 +353,20 @@ class RideDemolitionWorkflowTests(TestCase): Flow: CLOSED_PERM → DEMOLISHED """ - ride = self._create_ride(status='CLOSED_PERM') + ride = self._create_ride(status="CLOSED_PERM") # Demolish ride.transition_to_demolished(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'DEMOLISHED') + self.assertEqual(ride.status, "DEMOLISHED") def test_demolished_is_final_state(self): """Test that demolished rides cannot transition further.""" from django_fsm import TransitionNotAllowed - ride = self._create_ride(status='DEMOLISHED') + ride = self._create_ride(status="DEMOLISHED") # Cannot transition from demolished with self.assertRaises(TransitionNotAllowed): @@ -424,38 +379,29 @@ class RideRelocationWorkflowTests(TestCase): @classmethod def setUpTestData(cls): cls.moderator = User.objects.create_user( - username='reloc_ride_mod', - email='reloc_ride_mod@example.com', - password='testpass123', - role='MODERATOR' + username="reloc_ride_mod", email="reloc_ride_mod@example.com", password="testpass123", role="MODERATOR" ) - def _create_ride(self, status='CLOSED_PERM', **kwargs): + def _create_ride(self, status="CLOSED_PERM", **kwargs): from apps.parks.models import Company, Park from apps.rides.models import Ride - manufacturer = Company.objects.create( - name=f'Mfr Reloc {timezone.now().timestamp()}', - roles=['MANUFACTURER'] - ) - operator = Company.objects.create( - name=f'Op Reloc {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + manufacturer = Company.objects.create(name=f"Mfr Reloc {timezone.now().timestamp()}", roles=["MANUFACTURER"]) + operator = Company.objects.create(name=f"Op Reloc {timezone.now().timestamp()}", roles=["OPERATOR"]) park = Park.objects.create( - name=f'Park Reloc {timezone.now().timestamp()}', - slug=f'park-reloc-{timezone.now().timestamp()}', + name=f"Park Reloc {timezone.now().timestamp()}", + slug=f"park-reloc-{timezone.now().timestamp()}", operator=operator, - status='OPERATING', - timezone='America/New_York' + status="OPERATING", + timezone="America/New_York", ) defaults = { - 'name': f'Ride Reloc {timezone.now().timestamp()}', - 'slug': f'ride-reloc-{timezone.now().timestamp()}', - 'park': park, - 'manufacturer': manufacturer, - 'status': status + "name": f"Ride Reloc {timezone.now().timestamp()}", + "slug": f"ride-reloc-{timezone.now().timestamp()}", + "park": park, + "manufacturer": manufacturer, + "status": status, } defaults.update(kwargs) return Ride.objects.create(**defaults) @@ -466,20 +412,20 @@ class RideRelocationWorkflowTests(TestCase): Flow: CLOSED_PERM → RELOCATED """ - ride = self._create_ride(status='CLOSED_PERM') + ride = self._create_ride(status="CLOSED_PERM") # Relocate ride.transition_to_relocated(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'RELOCATED') + self.assertEqual(ride.status, "RELOCATED") def test_relocated_is_final_state(self): """Test that relocated rides cannot transition further.""" from django_fsm import TransitionNotAllowed - ride = self._create_ride(status='RELOCATED') + ride = self._create_ride(status="RELOCATED") # Cannot transition from relocated with self.assertRaises(TransitionNotAllowed): @@ -492,145 +438,129 @@ class RideWrapperMethodTests(TestCase): @classmethod def setUpTestData(cls): cls.user = User.objects.create_user( - username='wrapper_ride_user', - email='wrapper_ride@example.com', - password='testpass123', - role='USER' + username="wrapper_ride_user", email="wrapper_ride@example.com", password="testpass123", role="USER" ) cls.moderator = User.objects.create_user( - username='wrapper_ride_mod', - email='wrapper_ride_mod@example.com', - password='testpass123', - role='MODERATOR' + username="wrapper_ride_mod", email="wrapper_ride_mod@example.com", password="testpass123", role="MODERATOR" ) - def _create_ride(self, status='OPERATING', **kwargs): + def _create_ride(self, status="OPERATING", **kwargs): from apps.parks.models import Company, Park from apps.rides.models import Ride - manufacturer = Company.objects.create( - name=f'Mfr Wrapper {timezone.now().timestamp()}', - roles=['MANUFACTURER'] - ) - operator = Company.objects.create( - name=f'Op Wrapper {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + manufacturer = Company.objects.create(name=f"Mfr Wrapper {timezone.now().timestamp()}", roles=["MANUFACTURER"]) + operator = Company.objects.create(name=f"Op Wrapper {timezone.now().timestamp()}", roles=["OPERATOR"]) park = Park.objects.create( - name=f'Park Wrapper {timezone.now().timestamp()}', - slug=f'park-wrapper-{timezone.now().timestamp()}', + name=f"Park Wrapper {timezone.now().timestamp()}", + slug=f"park-wrapper-{timezone.now().timestamp()}", operator=operator, - status='OPERATING', - timezone='America/New_York' + status="OPERATING", + timezone="America/New_York", ) defaults = { - 'name': f'Ride Wrapper {timezone.now().timestamp()}', - 'slug': f'ride-wrapper-{timezone.now().timestamp()}', - 'park': park, - 'manufacturer': manufacturer, - 'status': status + "name": f"Ride Wrapper {timezone.now().timestamp()}", + "slug": f"ride-wrapper-{timezone.now().timestamp()}", + "park": park, + "manufacturer": manufacturer, + "status": status, } defaults.update(kwargs) return Ride.objects.create(**defaults) def test_close_temporarily_wrapper(self): """Test close_temporarily wrapper method.""" - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") - if hasattr(ride, 'close_temporarily'): + if hasattr(ride, "close_temporarily"): ride.close_temporarily(user=self.user) else: ride.transition_to_closed_temp(user=self.user) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSED_TEMP') + self.assertEqual(ride.status, "CLOSED_TEMP") def test_mark_sbno_wrapper(self): """Test mark_sbno wrapper method.""" - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") - if hasattr(ride, 'mark_sbno'): + if hasattr(ride, "mark_sbno"): ride.mark_sbno(user=self.moderator) else: ride.transition_to_sbno(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'SBNO') + self.assertEqual(ride.status, "SBNO") def test_mark_closing_wrapper(self): """Test mark_closing wrapper method.""" - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") closing_date = (timezone.now() + timedelta(days=30)).date() - if hasattr(ride, 'mark_closing'): - ride.mark_closing( - closing_date=closing_date, - post_closing_status='DEMOLISHED', - user=self.moderator - ) + if hasattr(ride, "mark_closing"): + ride.mark_closing(closing_date=closing_date, post_closing_status="DEMOLISHED", user=self.moderator) else: ride.transition_to_closing(user=self.moderator) ride.closing_date = closing_date - ride.post_closing_status = 'DEMOLISHED' + ride.post_closing_status = "DEMOLISHED" ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSING') + self.assertEqual(ride.status, "CLOSING") def test_open_wrapper(self): """Test open wrapper method.""" - ride = self._create_ride(status='CLOSED_TEMP') + ride = self._create_ride(status="CLOSED_TEMP") - if hasattr(ride, 'open'): + if hasattr(ride, "open"): ride.open(user=self.user) else: ride.transition_to_operating(user=self.user) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'OPERATING') + self.assertEqual(ride.status, "OPERATING") def test_close_permanently_wrapper(self): """Test close_permanently wrapper method.""" - ride = self._create_ride(status='SBNO') + ride = self._create_ride(status="SBNO") - if hasattr(ride, 'close_permanently'): + if hasattr(ride, "close_permanently"): ride.close_permanently(user=self.moderator) else: ride.transition_to_closed_perm(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'CLOSED_PERM') + self.assertEqual(ride.status, "CLOSED_PERM") def test_demolish_wrapper(self): """Test demolish wrapper method.""" - ride = self._create_ride(status='CLOSED_PERM') + ride = self._create_ride(status="CLOSED_PERM") - if hasattr(ride, 'demolish'): + if hasattr(ride, "demolish"): ride.demolish(user=self.moderator) else: ride.transition_to_demolished(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'DEMOLISHED') + self.assertEqual(ride.status, "DEMOLISHED") def test_relocate_wrapper(self): """Test relocate wrapper method.""" - ride = self._create_ride(status='CLOSED_PERM') + ride = self._create_ride(status="CLOSED_PERM") - if hasattr(ride, 'relocate'): + if hasattr(ride, "relocate"): ride.relocate(user=self.moderator) else: ride.transition_to_relocated(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'RELOCATED') + self.assertEqual(ride.status, "RELOCATED") class RidePostClosingStatusAutomationTests(TestCase): @@ -639,90 +569,81 @@ class RidePostClosingStatusAutomationTests(TestCase): @classmethod def setUpTestData(cls): cls.moderator = User.objects.create_user( - username='auto_mod', - email='auto_mod@example.com', - password='testpass123', - role='MODERATOR' + username="auto_mod", email="auto_mod@example.com", password="testpass123", role="MODERATOR" ) - def _create_ride(self, status='CLOSING', **kwargs): + def _create_ride(self, status="CLOSING", **kwargs): from apps.parks.models import Company, Park from apps.rides.models import Ride - manufacturer = Company.objects.create( - name=f'Mfr Auto {timezone.now().timestamp()}', - roles=['MANUFACTURER'] - ) - operator = Company.objects.create( - name=f'Op Auto {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + manufacturer = Company.objects.create(name=f"Mfr Auto {timezone.now().timestamp()}", roles=["MANUFACTURER"]) + operator = Company.objects.create(name=f"Op Auto {timezone.now().timestamp()}", roles=["OPERATOR"]) park = Park.objects.create( - name=f'Park Auto {timezone.now().timestamp()}', - slug=f'park-auto-{timezone.now().timestamp()}', + name=f"Park Auto {timezone.now().timestamp()}", + slug=f"park-auto-{timezone.now().timestamp()}", operator=operator, - status='OPERATING', - timezone='America/New_York' + status="OPERATING", + timezone="America/New_York", ) defaults = { - 'name': f'Ride Auto {timezone.now().timestamp()}', - 'slug': f'ride-auto-{timezone.now().timestamp()}', - 'park': park, - 'manufacturer': manufacturer, - 'status': status + "name": f"Ride Auto {timezone.now().timestamp()}", + "slug": f"ride-auto-{timezone.now().timestamp()}", + "park": park, + "manufacturer": manufacturer, + "status": status, } defaults.update(kwargs) return Ride.objects.create(**defaults) def test_apply_post_closing_status_demolished(self): """Test apply_post_closing_status transitions to DEMOLISHED.""" - ride = self._create_ride(status='CLOSING') + ride = self._create_ride(status="CLOSING") ride.closing_date = timezone.now().date() - ride.post_closing_status = 'DEMOLISHED' + ride.post_closing_status = "DEMOLISHED" ride.save() # Apply post-closing status if method exists - if hasattr(ride, 'apply_post_closing_status'): + if hasattr(ride, "apply_post_closing_status"): ride.apply_post_closing_status(user=self.moderator) else: ride.transition_to_demolished(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'DEMOLISHED') + self.assertEqual(ride.status, "DEMOLISHED") def test_apply_post_closing_status_relocated(self): """Test apply_post_closing_status transitions to RELOCATED.""" - ride = self._create_ride(status='CLOSING') + ride = self._create_ride(status="CLOSING") ride.closing_date = timezone.now().date() - ride.post_closing_status = 'RELOCATED' + ride.post_closing_status = "RELOCATED" ride.save() - if hasattr(ride, 'apply_post_closing_status'): + if hasattr(ride, "apply_post_closing_status"): ride.apply_post_closing_status(user=self.moderator) else: ride.transition_to_relocated(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'RELOCATED') + self.assertEqual(ride.status, "RELOCATED") def test_apply_post_closing_status_sbno(self): """Test apply_post_closing_status transitions to SBNO.""" - ride = self._create_ride(status='CLOSING') + ride = self._create_ride(status="CLOSING") ride.closing_date = timezone.now().date() - ride.post_closing_status = 'SBNO' + ride.post_closing_status = "SBNO" ride.save() - if hasattr(ride, 'apply_post_closing_status'): + if hasattr(ride, "apply_post_closing_status"): ride.apply_post_closing_status(user=self.moderator) else: ride.transition_to_sbno(user=self.moderator) ride.save() ride.refresh_from_db() - self.assertEqual(ride.status, 'SBNO') + self.assertEqual(ride.status, "SBNO") class RideStateLogTests(TestCase): @@ -731,44 +652,32 @@ class RideStateLogTests(TestCase): @classmethod def setUpTestData(cls): cls.user = User.objects.create_user( - username='ride_log_user', - email='ride_log_user@example.com', - password='testpass123', - role='USER' + username="ride_log_user", email="ride_log_user@example.com", password="testpass123", role="USER" ) cls.moderator = User.objects.create_user( - username='ride_log_mod', - email='ride_log_mod@example.com', - password='testpass123', - role='MODERATOR' + username="ride_log_mod", email="ride_log_mod@example.com", password="testpass123", role="MODERATOR" ) - def _create_ride(self, status='OPERATING', **kwargs): + def _create_ride(self, status="OPERATING", **kwargs): from apps.parks.models import Company, Park from apps.rides.models import Ride - manufacturer = Company.objects.create( - name=f'Mfr Log {timezone.now().timestamp()}', - roles=['MANUFACTURER'] - ) - operator = Company.objects.create( - name=f'Op Log {timezone.now().timestamp()}', - roles=['OPERATOR'] - ) + manufacturer = Company.objects.create(name=f"Mfr Log {timezone.now().timestamp()}", roles=["MANUFACTURER"]) + operator = Company.objects.create(name=f"Op Log {timezone.now().timestamp()}", roles=["OPERATOR"]) park = Park.objects.create( - name=f'Park Log {timezone.now().timestamp()}', - slug=f'park-log-{timezone.now().timestamp()}', + name=f"Park Log {timezone.now().timestamp()}", + slug=f"park-log-{timezone.now().timestamp()}", operator=operator, - status='OPERATING', - timezone='America/New_York' + status="OPERATING", + timezone="America/New_York", ) defaults = { - 'name': f'Ride Log {timezone.now().timestamp()}', - 'slug': f'ride-log-{timezone.now().timestamp()}', - 'park': park, - 'manufacturer': manufacturer, - 'status': status + "name": f"Ride Log {timezone.now().timestamp()}", + "slug": f"ride-log-{timezone.now().timestamp()}", + "park": park, + "manufacturer": manufacturer, + "status": status, } defaults.update(kwargs) return Ride.objects.create(**defaults) @@ -778,7 +687,7 @@ class RideStateLogTests(TestCase): from django.contrib.contenttypes.models import ContentType from django_fsm_log.models import StateLog - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") ride_ct = ContentType.objects.get_for_model(ride) # Perform transition @@ -786,13 +695,10 @@ class RideStateLogTests(TestCase): ride.save() # Check log was created - log = StateLog.objects.filter( - content_type=ride_ct, - object_id=ride.id - ).first() + log = StateLog.objects.filter(content_type=ride_ct, object_id=ride.id).first() self.assertIsNotNone(log, "StateLog entry should be created") - self.assertEqual(log.state, 'CLOSED_TEMP') + self.assertEqual(log.state, "CLOSED_TEMP") self.assertEqual(log.by, self.user) def test_multiple_transitions_logged(self): @@ -800,7 +706,7 @@ class RideStateLogTests(TestCase): from django.contrib.contenttypes.models import ContentType from django_fsm_log.models import StateLog - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") ride_ct = ContentType.objects.get_for_model(ride) # First transition: OPERATING -> SBNO @@ -812,21 +718,18 @@ class RideStateLogTests(TestCase): ride.save() # Check multiple logs created - logs = StateLog.objects.filter( - content_type=ride_ct, - object_id=ride.id - ).order_by('timestamp') + logs = StateLog.objects.filter(content_type=ride_ct, object_id=ride.id).order_by("timestamp") self.assertEqual(logs.count(), 2, "Should have 2 log entries") - self.assertEqual(logs[0].state, 'SBNO') - self.assertEqual(logs[1].state, 'OPERATING') + self.assertEqual(logs[0].state, "SBNO") + self.assertEqual(logs[1].state, "OPERATING") def test_sbno_revival_workflow_logged(self): """Test that SBNO revival workflow is logged.""" from django.contrib.contenttypes.models import ContentType from django_fsm_log.models import StateLog - ride = self._create_ride(status='SBNO') + ride = self._create_ride(status="SBNO") ride_ct = ContentType.objects.get_for_model(ride) # Revival: SBNO -> OPERATING @@ -834,13 +737,10 @@ class RideStateLogTests(TestCase): ride.save() # Check log was created - log = StateLog.objects.filter( - content_type=ride_ct, - object_id=ride.id - ).first() + log = StateLog.objects.filter(content_type=ride_ct, object_id=ride.id).first() self.assertIsNotNone(log, "StateLog entry should be created") - self.assertEqual(log.state, 'OPERATING') + self.assertEqual(log.state, "OPERATING") self.assertEqual(log.by, self.moderator) def test_full_lifecycle_logged(self): @@ -848,7 +748,7 @@ class RideStateLogTests(TestCase): from django.contrib.contenttypes.models import ContentType from django_fsm_log.models import StateLog - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") ride_ct = ContentType.objects.get_for_model(ride) # Lifecycle: OPERATING -> CLOSED_TEMP -> SBNO -> CLOSED_PERM -> DEMOLISHED @@ -865,37 +765,31 @@ class RideStateLogTests(TestCase): ride.save() # Check all logs created - logs = StateLog.objects.filter( - content_type=ride_ct, - object_id=ride.id - ).order_by('timestamp') + logs = StateLog.objects.filter(content_type=ride_ct, object_id=ride.id).order_by("timestamp") self.assertEqual(logs.count(), 4, "Should have 4 log entries") states = [log.state for log in logs] - self.assertEqual(states, ['CLOSED_TEMP', 'SBNO', 'CLOSED_PERM', 'DEMOLISHED']) + self.assertEqual(states, ["CLOSED_TEMP", "SBNO", "CLOSED_PERM", "DEMOLISHED"]) def test_scheduled_closing_workflow_logged(self): """Test that scheduled closing workflow creates logs.""" from django.contrib.contenttypes.models import ContentType from django_fsm_log.models import StateLog - ride = self._create_ride(status='OPERATING') + ride = self._create_ride(status="OPERATING") ride_ct = ContentType.objects.get_for_model(ride) # Scheduled closing workflow: OPERATING -> CLOSING -> CLOSED_PERM ride.transition_to_closing(user=self.moderator) ride.closing_date = (timezone.now() + timedelta(days=30)).date() - ride.post_closing_status = 'DEMOLISHED' + ride.post_closing_status = "DEMOLISHED" ride.save() ride.transition_to_closed_perm(user=self.moderator) ride.save() - logs = StateLog.objects.filter( - content_type=ride_ct, - object_id=ride.id - ).order_by('timestamp') + logs = StateLog.objects.filter(content_type=ride_ct, object_id=ride.id).order_by("timestamp") self.assertEqual(logs.count(), 2, "Should have 2 log entries") - self.assertEqual(logs[0].state, 'CLOSING') - self.assertEqual(logs[1].state, 'CLOSED_PERM') + self.assertEqual(logs[0].state, "CLOSING") + self.assertEqual(logs[1].state, "CLOSED_PERM") diff --git a/backend/apps/rides/views.py b/backend/apps/rides/views.py index f8970f76..66aef276 100644 --- a/backend/apps/rides/views.py +++ b/backend/apps/rides/views.py @@ -98,9 +98,7 @@ def show_coaster_fields(request: HttpRequest) -> HttpResponse: return render(request, "rides/partials/coaster_fields.html") -def ride_status_actions( - request: HttpRequest, park_slug: str, ride_slug: str -) -> HttpResponse: +def ride_status_actions(request: HttpRequest, park_slug: str, ride_slug: str) -> HttpResponse: """ Return FSM status actions for ride moderators. @@ -131,9 +129,7 @@ def ride_status_actions( ) -def ride_header_badge( - request: HttpRequest, park_slug: str, ride_slug: str -) -> HttpResponse: +def ride_header_badge(request: HttpRequest, park_slug: str, ride_slug: str) -> HttpResponse: """ Return the header status badge partial for a ride. @@ -205,9 +201,7 @@ class RideDetailView(HistoryMixin, DetailView): return context -class RideCreateView( - LoginRequiredMixin, ParkContextRequired, RideFormMixin, CreateView -): +class RideCreateView(LoginRequiredMixin, ParkContextRequired, RideFormMixin, CreateView): """ View for creating a new ride. @@ -389,9 +383,7 @@ class RideListView(ListView): from apps.core.choices.registry import get_choices choices = get_choices("categories", "rides") - context["category_choices"] = [ - (choice.value, choice.label) for choice in choices - ] + context["category_choices"] = [(choice.value, choice.label) for choice in choices] # Add filter summary for display if filter_form.is_valid(): @@ -512,10 +504,7 @@ def get_search_suggestions(request: HttpRequest) -> HttpResponse: if query: # Get common ride names matching_names = ( - Ride.objects.filter(name__icontains=query) - .values("name") - .annotate(count=Count("id")) - .order_by("-count")[:3] + Ride.objects.filter(name__icontains=query).values("name").annotate(count=Count("id")).order_by("-count")[:3] ) for match in matching_names: @@ -663,18 +652,14 @@ class RideRankingsView(ListView): from apps.core.choices.registry import get_choices choices = get_choices("categories", "rides") - context["category_choices"] = [ - (choice.value, choice.label) for choice in choices - ] + context["category_choices"] = [(choice.value, choice.label) for choice in choices] context["selected_category"] = self.request.GET.get("category", "all") context["min_riders"] = self.request.GET.get("min_riders", "") # Add statistics if self.object_list: context["total_ranked"] = RideRanking.objects.count() - context["last_updated"] = ( - self.object_list[0].last_calculated if self.object_list else None - ) + context["last_updated"] = self.object_list[0].last_calculated if self.object_list else None return context @@ -688,9 +673,9 @@ class RideRankingDetailView(DetailView): def get_queryset(self): """Get ride with ranking data.""" - return Ride.objects.select_related( - "park", "manufacturer", "ranking" - ).prefetch_related("comparisons_as_a", "comparisons_as_b", "ranking_history") + return Ride.objects.select_related("park", "manufacturer", "ranking").prefetch_related( + "comparisons_as_a", "comparisons_as_b", "ranking_history" + ) def get_context_data(self, **kwargs): """Add ranking details to context.""" @@ -704,14 +689,10 @@ class RideRankingDetailView(DetailView): context.update(ranking_details) # Get recent movement - recent_snapshots = RankingSnapshot.objects.filter( - ride=self.object - ).order_by("-snapshot_date")[:7] + recent_snapshots = RankingSnapshot.objects.filter(ride=self.object).order_by("-snapshot_date")[:7] if len(recent_snapshots) >= 2: - context["rank_change"] = ( - recent_snapshots[0].rank - recent_snapshots[1].rank - ) + context["rank_change"] = recent_snapshots[0].rank - recent_snapshots[1].rank context["previous_rank"] = recent_snapshots[1].rank else: context["not_ranked"] = True diff --git a/backend/apps/support/models.py b/backend/apps/support/models.py index b2817fb4..7d36861c 100644 --- a/backend/apps/support/models.py +++ b/backend/apps/support/models.py @@ -5,30 +5,30 @@ from apps.core.history import TrackedModel class Ticket(TrackedModel): - STATUS_OPEN = 'open' - STATUS_IN_PROGRESS = 'in_progress' - STATUS_CLOSED = 'closed' + STATUS_OPEN = "open" + STATUS_IN_PROGRESS = "in_progress" + STATUS_CLOSED = "closed" STATUS_CHOICES = [ - (STATUS_OPEN, 'Open'), - (STATUS_IN_PROGRESS, 'In Progress'), - (STATUS_CLOSED, 'Closed'), + (STATUS_OPEN, "Open"), + (STATUS_IN_PROGRESS, "In Progress"), + (STATUS_CLOSED, "Closed"), ] - CATEGORY_GENERAL = 'general' - CATEGORY_BUG = 'bug' - CATEGORY_PARTNERSHIP = 'partnership' - CATEGORY_PRESS = 'press' - CATEGORY_DATA = 'data' - CATEGORY_ACCOUNT = 'account' + CATEGORY_GENERAL = "general" + CATEGORY_BUG = "bug" + CATEGORY_PARTNERSHIP = "partnership" + CATEGORY_PRESS = "press" + CATEGORY_DATA = "data" + CATEGORY_ACCOUNT = "account" CATEGORY_CHOICES = [ - (CATEGORY_GENERAL, 'General Inquiry'), - (CATEGORY_BUG, 'Bug Report'), - (CATEGORY_PARTNERSHIP, 'Partnership'), - (CATEGORY_PRESS, 'Press/Media'), - (CATEGORY_DATA, 'Data Correction'), - (CATEGORY_ACCOUNT, 'Account Issue'), + (CATEGORY_GENERAL, "General Inquiry"), + (CATEGORY_BUG, "Bug Report"), + (CATEGORY_PARTNERSHIP, "Partnership"), + (CATEGORY_PRESS, "Press/Media"), + (CATEGORY_DATA, "Data Correction"), + (CATEGORY_ACCOUNT, "Account Issue"), ] user = models.ForeignKey( @@ -37,7 +37,7 @@ class Ticket(TrackedModel): null=True, blank=True, related_name="tickets", - help_text="User who submitted the ticket (optional)" + help_text="User who submitted the ticket (optional)", ) category = models.CharField( @@ -45,18 +45,13 @@ class Ticket(TrackedModel): choices=CATEGORY_CHOICES, default=CATEGORY_GENERAL, db_index=True, - help_text="Category of the ticket" + help_text="Category of the ticket", ) subject = models.CharField(max_length=255) message = models.TextField() email = models.EmailField(help_text="Contact email", blank=True) - status = models.CharField( - max_length=20, - choices=STATUS_CHOICES, - default=STATUS_OPEN, - db_index=True - ) + status = models.CharField(max_length=20, choices=STATUS_CHOICES, default=STATUS_OPEN, db_index=True) class Meta(TrackedModel.Meta): verbose_name = "Ticket" @@ -71,4 +66,3 @@ class Ticket(TrackedModel): if self.user and not self.email: self.email = self.user.email super().save(*args, **kwargs) - diff --git a/backend/apps/support/serializers.py b/backend/apps/support/serializers.py index 701c621a..9990745d 100644 --- a/backend/apps/support/serializers.py +++ b/backend/apps/support/serializers.py @@ -7,8 +7,8 @@ from .models import Ticket class TicketSerializer(serializers.ModelSerializer): user = UserSerializer(read_only=True) - category_display = serializers.CharField(source='get_category_display', read_only=True) - status_display = serializers.CharField(source='get_status_display', read_only=True) + category_display = serializers.CharField(source="get_category_display", read_only=True) + status_display = serializers.CharField(source="get_status_display", read_only=True) class Meta: model = Ticket @@ -29,8 +29,7 @@ class TicketSerializer(serializers.ModelSerializer): def validate(self, data): # Ensure email is provided if user is anonymous - request = self.context.get('request') - if request and not request.user.is_authenticated and not data.get('email'): + request = self.context.get("request") + if request and not request.user.is_authenticated and not data.get("email"): raise serializers.ValidationError({"email": "Email is required for guests."}) return data - diff --git a/backend/apps/support/views.py b/backend/apps/support/views.py index 94d028da..ff869ce4 100644 --- a/backend/apps/support/views.py +++ b/backend/apps/support/views.py @@ -11,9 +11,10 @@ class TicketViewSet(viewsets.ModelViewSet): Only Staff can LIST/RETRIEVE/UPDATE all. Users can LIST/RETRIEVE their own. """ + queryset = Ticket.objects.all() serializer_class = TicketSerializer - permission_classes = [permissions.AllowAny] # We handle granular perms in get_queryset/perform_create + permission_classes = [permissions.AllowAny] # We handle granular perms in get_queryset/perform_create filter_backends = [DjangoFilterBackend, filters.OrderingFilter] filterset_fields = ["status", "category"] ordering_fields = ["created_at", "status"] @@ -25,7 +26,7 @@ class TicketViewSet(viewsets.ModelViewSet): return Ticket.objects.all() if user.is_authenticated: return Ticket.objects.filter(user=user) - return Ticket.objects.none() # Guests can't list tickets + return Ticket.objects.none() # Guests can't list tickets def perform_create(self, serializer): if self.request.user.is_authenticated: diff --git a/backend/config/django/base.py b/backend/config/django/base.py index 1fc59729..a1083874 100644 --- a/backend/config/django/base.py +++ b/backend/config/django/base.py @@ -42,16 +42,12 @@ DEBUG = config("DEBUG", default=True, cast=bool) # Allowed hosts (comma-separated in .env) ALLOWED_HOSTS = config( - "ALLOWED_HOSTS", - default="localhost,127.0.0.1", - cast=lambda v: [s.strip() for s in v.split(",") if s.strip()] + "ALLOWED_HOSTS", default="localhost,127.0.0.1", cast=lambda v: [s.strip() for s in v.split(",") if s.strip()] ) # CSRF trusted origins (comma-separated in .env) CSRF_TRUSTED_ORIGINS = config( - "CSRF_TRUSTED_ORIGINS", - default="", - cast=lambda v: [s.strip() for s in v.split(",") if s.strip()] + "CSRF_TRUSTED_ORIGINS", default="", cast=lambda v: [s.strip() for s in v.split(",") if s.strip()] ) # ============================================================================= diff --git a/backend/config/django/local.py b/backend/config/django/local.py index f13c838b..f03e34ab 100644 --- a/backend/config/django/local.py +++ b/backend/config/django/local.py @@ -149,10 +149,7 @@ LOGGING = { }, "json": { "()": "pythonjsonlogger.jsonlogger.JsonFormatter", - "format": ( - "%(levelname)s %(asctime)s %(module)s %(process)d " - "%(thread)d %(message)s" - ), + "format": ("%(levelname)s %(asctime)s %(module)s %(process)d " "%(thread)d %(message)s"), }, }, "handlers": { diff --git a/backend/config/django/production.py b/backend/config/django/production.py index 9a0ba147..4281c280 100644 --- a/backend/config/django/production.py +++ b/backend/config/django/production.py @@ -20,16 +20,10 @@ from .base import * # noqa: F401,F403 DEBUG = False # Allowed hosts must be explicitly set in production -ALLOWED_HOSTS = config( - "ALLOWED_HOSTS", - cast=lambda v: [s.strip() for s in v.split(",") if s.strip()] -) +ALLOWED_HOSTS = config("ALLOWED_HOSTS", cast=lambda v: [s.strip() for s in v.split(",") if s.strip()]) # CSRF trusted origins for production -CSRF_TRUSTED_ORIGINS = config( - "CSRF_TRUSTED_ORIGINS", - cast=lambda v: [s.strip() for s in v.split(",") if s.strip()] -) +CSRF_TRUSTED_ORIGINS = config("CSRF_TRUSTED_ORIGINS", cast=lambda v: [s.strip() for s in v.split(",") if s.strip()]) # ============================================================================= # Security Settings for Production @@ -75,9 +69,7 @@ if redis_url: "PARSER_CLASS": "redis.connection.HiredisParser", "CONNECTION_POOL_CLASS": "redis.BlockingConnectionPool", "CONNECTION_POOL_CLASS_KWARGS": { - "max_connections": config( - "REDIS_MAX_CONNECTIONS", default=100, cast=int - ), + "max_connections": config("REDIS_MAX_CONNECTIONS", default=100, cast=int), "timeout": 20, "socket_keepalive": True, "retry_on_timeout": True, @@ -119,9 +111,7 @@ if redis_url: STATICFILES_STORAGE = "whitenoise.storage.CompressedManifestStaticFilesStorage" # Update STORAGES for Django 4.2+ -STORAGES["staticfiles"]["BACKEND"] = ( # noqa: F405 - "whitenoise.storage.CompressedManifestStaticFilesStorage" -) +STORAGES["staticfiles"]["BACKEND"] = "whitenoise.storage.CompressedManifestStaticFilesStorage" # noqa: F405 # ============================================================================= # Production REST Framework Settings @@ -148,8 +138,7 @@ LOGGING = { "json": { "()": "pythonjsonlogger.jsonlogger.JsonFormatter", "format": ( - "%(levelname)s %(asctime)s %(module)s %(process)d " - "%(thread)d %(message)s %(pathname)s %(lineno)d" + "%(levelname)s %(asctime)s %(module)s %(process)d " "%(thread)d %(message)s %(pathname)s %(lineno)d" ), }, "simple": { @@ -257,9 +246,7 @@ if SENTRY_DSN: RedisIntegration(), ], environment=config("SENTRY_ENVIRONMENT", default="production"), - traces_sample_rate=config( - "SENTRY_TRACES_SAMPLE_RATE", default=0.1, cast=float - ), + traces_sample_rate=config("SENTRY_TRACES_SAMPLE_RATE", default=0.1, cast=float), send_default_pii=False, # Don't send PII to Sentry attach_stacktrace=True, ) diff --git a/backend/config/settings/cache.py b/backend/config/settings/cache.py index d4934390..47c16e56 100644 --- a/backend/config/settings/cache.py +++ b/backend/config/settings/cache.py @@ -46,15 +46,13 @@ CACHES = { # Connection pooling for better performance "CONNECTION_POOL_CLASS": "redis.BlockingConnectionPool", "CONNECTION_POOL_CLASS_KWARGS": { - "max_connections": config( - "REDIS_MAX_CONNECTIONS", default=100, cast=int - ), + "max_connections": config("REDIS_MAX_CONNECTIONS", default=100, cast=int), "timeout": config("REDIS_CONNECTION_TIMEOUT", default=20, cast=int), "socket_keepalive": True, "socket_keepalive_options": { - 1: 1, # TCP_KEEPIDLE: Start keepalive after 1s idle - 2: 1, # TCP_KEEPINTVL: Send probes every 1s - 3: 3, # TCP_KEEPCNT: Close after 3 failed probes + 1: 1, # TCP_KEEPIDLE: Start keepalive after 1s idle + 2: 1, # TCP_KEEPINTVL: Send probes every 1s + 3: 3, # TCP_KEEPCNT: Close after 3 failed probes }, "retry_on_timeout": True, "health_check_interval": 30, @@ -62,14 +60,11 @@ CACHES = { # Compress cached data to save memory "COMPRESSOR": "django_redis.compressors.zlib.ZlibCompressor", # Graceful degradation if Redis is unavailable - "IGNORE_EXCEPTIONS": config( - "REDIS_IGNORE_EXCEPTIONS", default=True, cast=bool - ), + "IGNORE_EXCEPTIONS": config("REDIS_IGNORE_EXCEPTIONS", default=True, cast=bool), }, "KEY_PREFIX": config("CACHE_KEY_PREFIX", default="thrillwiki"), "VERSION": 1, }, - # Session cache - separate for security isolation # Uses a different Redis database (db 2) "sessions": { @@ -80,16 +75,13 @@ CACHES = { "PARSER_CLASS": "redis.connection.HiredisParser", "CONNECTION_POOL_CLASS": "redis.BlockingConnectionPool", "CONNECTION_POOL_CLASS_KWARGS": { - "max_connections": config( - "REDIS_SESSIONS_MAX_CONNECTIONS", default=50, cast=int - ), + "max_connections": config("REDIS_SESSIONS_MAX_CONNECTIONS", default=50, cast=int), "timeout": 10, "socket_keepalive": True, }, }, "KEY_PREFIX": "sessions", }, - # API cache - high concurrency for API responses # Uses a different Redis database (db 3) "api": { @@ -100,9 +92,7 @@ CACHES = { "PARSER_CLASS": "redis.connection.HiredisParser", "CONNECTION_POOL_CLASS": "redis.BlockingConnectionPool", "CONNECTION_POOL_CLASS_KWARGS": { - "max_connections": config( - "REDIS_API_MAX_CONNECTIONS", default=100, cast=int - ), + "max_connections": config("REDIS_API_MAX_CONNECTIONS", default=100, cast=int), "timeout": 15, "socket_keepalive": True, "retry_on_timeout": True, @@ -126,14 +116,10 @@ SESSION_CACHE_ALIAS = "sessions" SESSION_COOKIE_AGE = config("SESSION_COOKIE_AGE", default=3600, cast=int) # Update session on each request (sliding expiry) -SESSION_SAVE_EVERY_REQUEST = config( - "SESSION_SAVE_EVERY_REQUEST", default=True, cast=bool -) +SESSION_SAVE_EVERY_REQUEST = config("SESSION_SAVE_EVERY_REQUEST", default=True, cast=bool) # Session persists until cookie expires (not browser close) -SESSION_EXPIRE_AT_BROWSER_CLOSE = config( - "SESSION_EXPIRE_AT_BROWSER_CLOSE", default=False, cast=bool -) +SESSION_EXPIRE_AT_BROWSER_CLOSE = config("SESSION_EXPIRE_AT_BROWSER_CLOSE", default=False, cast=bool) # ============================================================================= # Cache Middleware Settings @@ -141,6 +127,4 @@ SESSION_EXPIRE_AT_BROWSER_CLOSE = config( # For Django's cache middleware (UpdateCacheMiddleware/FetchFromCacheMiddleware) CACHE_MIDDLEWARE_SECONDS = config("CACHE_MIDDLEWARE_SECONDS", default=300, cast=int) -CACHE_MIDDLEWARE_KEY_PREFIX = config( - "CACHE_MIDDLEWARE_KEY_PREFIX", default="thrillwiki" -) +CACHE_MIDDLEWARE_KEY_PREFIX = config("CACHE_MIDDLEWARE_KEY_PREFIX", default="thrillwiki") diff --git a/backend/config/settings/database.py b/backend/config/settings/database.py index a1d28c2e..4346b2b5 100644 --- a/backend/config/settings/database.py +++ b/backend/config/settings/database.py @@ -26,10 +26,7 @@ from decouple import config # ============================================================================= # Parse DATABASE_URL environment variable into Django database settings -DATABASE_URL = config( - "DATABASE_URL", - default="postgis://thrillwiki_user:thrillwiki@localhost:5432/thrillwiki_test_db" -) +DATABASE_URL = config("DATABASE_URL", default="postgis://thrillwiki_user:thrillwiki@localhost:5432/thrillwiki_test_db") # Parse the database URL db_config = dj_database_url.parse(DATABASE_URL) @@ -84,14 +81,8 @@ if "postgis" in DATABASE_URL or "postgresql" in DATABASE_URL: # macOS with Homebrew (default) # Linux: /usr/lib/x86_64-linux-gnu/libgdal.so # Docker: Usually handled by the image -GDAL_LIBRARY_PATH = config( - "GDAL_LIBRARY_PATH", - default="/opt/homebrew/lib/libgdal.dylib" -) -GEOS_LIBRARY_PATH = config( - "GEOS_LIBRARY_PATH", - default="/opt/homebrew/lib/libgeos_c.dylib" -) +GDAL_LIBRARY_PATH = config("GDAL_LIBRARY_PATH", default="/opt/homebrew/lib/libgdal.dylib") +GEOS_LIBRARY_PATH = config("GEOS_LIBRARY_PATH", default="/opt/homebrew/lib/libgeos_c.dylib") # ============================================================================= # Read Replica Configuration (Optional) diff --git a/backend/config/settings/email.py b/backend/config/settings/email.py index d5f37d68..b3e12f49 100644 --- a/backend/config/settings/email.py +++ b/backend/config/settings/email.py @@ -21,10 +21,7 @@ from decouple import config # - ForwardEmail: django_forwardemail.backends.ForwardEmailBackend (production) # - SMTP: django.core.mail.backends.smtp.EmailBackend (custom SMTP) -EMAIL_BACKEND = config( - "EMAIL_BACKEND", - default="django_forwardemail.backends.ForwardEmailBackend" -) +EMAIL_BACKEND = config("EMAIL_BACKEND", default="django_forwardemail.backends.ForwardEmailBackend") # ============================================================================= # ForwardEmail Configuration @@ -32,10 +29,7 @@ EMAIL_BACKEND = config( # ForwardEmail is a privacy-focused email service that supports custom domains # https://forwardemail.net/ -FORWARD_EMAIL_BASE_URL = config( - "FORWARD_EMAIL_BASE_URL", - default="https://api.forwardemail.net" -) +FORWARD_EMAIL_BASE_URL = config("FORWARD_EMAIL_BASE_URL", default="https://api.forwardemail.net") FORWARD_EMAIL_API_KEY = config("FORWARD_EMAIL_API_KEY", default="") FORWARD_EMAIL_DOMAIN = config("FORWARD_EMAIL_DOMAIN", default="") @@ -62,10 +56,7 @@ EMAIL_HOST_PASSWORD = config("EMAIL_HOST_PASSWORD", default="") EMAIL_TIMEOUT = config("EMAIL_TIMEOUT", default=30, cast=int) # Default from email address -DEFAULT_FROM_EMAIL = config( - "DEFAULT_FROM_EMAIL", - default="ThrillWiki " -) +DEFAULT_FROM_EMAIL = config("DEFAULT_FROM_EMAIL", default="ThrillWiki ") # ============================================================================= # Email Subject Prefix diff --git a/backend/config/settings/local.py b/backend/config/settings/local.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/config/settings/logging.py b/backend/config/settings/logging.py index 27001221..1238174c 100644 --- a/backend/config/settings/logging.py +++ b/backend/config/settings/logging.py @@ -46,10 +46,7 @@ LOGGING_FORMATTERS = { # JSON format for production - machine parseable for log aggregation "json": { "()": "pythonjsonlogger.jsonlogger.JsonFormatter", - "format": ( - "%(levelname)s %(asctime)s %(module)s %(process)d " - "%(thread)d %(message)s" - ), + "format": ("%(levelname)s %(asctime)s %(module)s %(process)d " "%(thread)d %(message)s"), }, # Simple format for console output "simple": { diff --git a/backend/config/settings/rest_framework.py b/backend/config/settings/rest_framework.py index 5e524718..95a67970 100644 --- a/backend/config/settings/rest_framework.py +++ b/backend/config/settings/rest_framework.py @@ -82,15 +82,11 @@ REST_FRAMEWORK = { CORS_ALLOW_CREDENTIALS = True # Allow all origins (not recommended for production) -CORS_ALLOW_ALL_ORIGINS = config( - "CORS_ALLOW_ALL_ORIGINS", default=False, cast=bool -) +CORS_ALLOW_ALL_ORIGINS = config("CORS_ALLOW_ALL_ORIGINS", default=False, cast=bool) # Specific allowed origins (comma-separated) CORS_ALLOWED_ORIGINS = config( - "CORS_ALLOWED_ORIGINS", - default="", - cast=lambda v: [s.strip() for s in v.split(",") if s.strip()] + "CORS_ALLOWED_ORIGINS", default="", cast=lambda v: [s.strip() for s in v.split(",") if s.strip()] ) # Allowed HTTP headers for CORS requests @@ -129,33 +125,27 @@ CORS_EXPOSE_HEADERS = [ # API Rate Limiting # ============================================================================= -API_RATE_LIMIT_PER_MINUTE = config( - "API_RATE_LIMIT_PER_MINUTE", default=60, cast=int -) -API_RATE_LIMIT_PER_HOUR = config( - "API_RATE_LIMIT_PER_HOUR", default=1000, cast=int -) +API_RATE_LIMIT_PER_MINUTE = config("API_RATE_LIMIT_PER_MINUTE", default=60, cast=int) +API_RATE_LIMIT_PER_HOUR = config("API_RATE_LIMIT_PER_HOUR", default=1000, cast=int) # ============================================================================= # SimpleJWT Settings # ============================================================================= # JWT token configuration for authentication + # Import SECRET_KEY for signing tokens # This will be set by base.py before this module is imported def get_secret_key(): """Get SECRET_KEY lazily to avoid circular imports.""" return config("SECRET_KEY") + SIMPLE_JWT = { # Token lifetimes # Short access tokens (15 min) provide better security - "ACCESS_TOKEN_LIFETIME": timedelta( - minutes=config("JWT_ACCESS_TOKEN_LIFETIME_MINUTES", default=15, cast=int) - ), - "REFRESH_TOKEN_LIFETIME": timedelta( - days=config("JWT_REFRESH_TOKEN_LIFETIME_DAYS", default=7, cast=int) - ), + "ACCESS_TOKEN_LIFETIME": timedelta(minutes=config("JWT_ACCESS_TOKEN_LIFETIME_MINUTES", default=15, cast=int)), + "REFRESH_TOKEN_LIFETIME": timedelta(days=config("JWT_REFRESH_TOKEN_LIFETIME_DAYS", default=7, cast=int)), # Token rotation and blacklisting # Rotate refresh tokens on each use and blacklist old ones "ROTATE_REFRESH_TOKENS": True, @@ -177,9 +167,7 @@ SIMPLE_JWT = { # User identification "USER_ID_FIELD": "id", "USER_ID_CLAIM": "user_id", - "USER_AUTHENTICATION_RULE": ( - "rest_framework_simplejwt.authentication.default_user_authentication_rule" - ), + "USER_AUTHENTICATION_RULE": ("rest_framework_simplejwt.authentication.default_user_authentication_rule"), # Token classes "AUTH_TOKEN_CLASSES": ("rest_framework_simplejwt.tokens.AccessToken",), "TOKEN_TYPE_CLAIM": "token_type", @@ -211,9 +199,7 @@ REST_AUTH = { # SameSite cookie attribute (Lax is compatible with OAuth flows) "JWT_AUTH_SAMESITE": "Lax", "JWT_AUTH_RETURN_EXPIRATION": True, - "JWT_TOKEN_CLAIMS_SERIALIZER": ( - "rest_framework_simplejwt.serializers.TokenObtainPairSerializer" - ), + "JWT_TOKEN_CLAIMS_SERIALIZER": ("rest_framework_simplejwt.serializers.TokenObtainPairSerializer"), } # ============================================================================= diff --git a/backend/config/settings/secrets.py b/backend/config/settings/secrets.py index 8336680b..634cc195 100644 --- a/backend/config/settings/secrets.py +++ b/backend/config/settings/secrets.py @@ -31,17 +31,13 @@ logger = logging.getLogger("security") # ============================================================================= # Enable secret rotation checking (set to True in production) -SECRET_ROTATION_ENABLED = config( - "SECRET_ROTATION_ENABLED", default=False, cast=bool -) +SECRET_ROTATION_ENABLED = config("SECRET_ROTATION_ENABLED", default=False, cast=bool) # Secret version for tracking rotations SECRET_KEY_VERSION = config("SECRET_KEY_VERSION", default="1") # Secret expiry warning threshold (days before expiry to start warning) -SECRET_EXPIRY_WARNING_DAYS = config( - "SECRET_EXPIRY_WARNING_DAYS", default=30, cast=int -) +SECRET_EXPIRY_WARNING_DAYS = config("SECRET_EXPIRY_WARNING_DAYS", default=30, cast=int) # ============================================================================= # Required Secrets Registry @@ -104,10 +100,7 @@ def validate_secret_strength(name: str, value: str, min_length: int = 10) -> boo return False if len(value) < min_length: - logger.error( - f"Secret '{name}' is too short ({len(value)} chars, " - f"minimum {min_length})" - ) + logger.error(f"Secret '{name}' is too short ({len(value)} chars, " f"minimum {min_length})") return False # Check for placeholder values @@ -123,9 +116,7 @@ def validate_secret_strength(name: str, value: str, min_length: int = 10) -> boo value_lower = value.lower() for pattern in placeholder_patterns: if pattern in value_lower: - logger.warning( - f"Secret '{name}' appears to contain a placeholder value" - ) + logger.warning(f"Secret '{name}' appears to contain a placeholder value") return False return True @@ -148,9 +139,7 @@ def validate_secret_key(secret_key: str) -> bool: bool: True if valid, False otherwise """ if len(secret_key) < 50: - logger.error( - f"SECRET_KEY is too short ({len(secret_key)} chars, minimum 50)" - ) + logger.error(f"SECRET_KEY is too short ({len(secret_key)} chars, minimum 50)") return False has_upper = any(c.isupper() for c in secret_key) @@ -159,10 +148,7 @@ def validate_secret_key(secret_key: str) -> bool: has_special = any(not c.isalnum() for c in secret_key) if not all([has_upper, has_lower, has_digit, has_special]): - logger.warning( - "SECRET_KEY should contain uppercase, lowercase, digits, " - "and special characters" - ) + logger.warning("SECRET_KEY should contain uppercase, lowercase, digits, " "and special characters") # Don't fail, just warn - some generated keys may not have all return True @@ -193,7 +179,7 @@ def get_secret( value = config(name, default=default) except UndefinedValueError: if required: - raise ValueError(f"Required secret '{name}' is not set") + raise ValueError(f"Required secret '{name}' is not set") from None return default if value and min_length > 0 and not validate_secret_strength(name, value, min_length): @@ -231,7 +217,7 @@ def validate_required_secrets(raise_on_error: bool = False) -> list[str]: msg = f"Required secret '{name}' is not set: {rules['description']}" errors.append(msg) if raise_on_error: - raise ValueError(msg) + raise ValueError(msg) from None return errors @@ -257,9 +243,7 @@ def check_secret_expiry() -> list[str]: version = int(SECRET_KEY_VERSION) # If version is very old, suggest rotation if version < 2: - warnings_list.append( - "SECRET_KEY version is old. Consider rotating secrets." - ) + warnings_list.append("SECRET_KEY version is old. Consider rotating secrets.") except ValueError: pass @@ -316,8 +300,7 @@ class EnvironmentSecretProvider(SecretProvider): def set_secret(self, name: str, value: str) -> bool: """Environment variables are read-only at runtime.""" logger.warning( - f"Cannot set secret '{name}' in environment provider. " - "Update your .env file or environment variables." + f"Cannot set secret '{name}' in environment provider. " "Update your .env file or environment variables." ) return False @@ -385,4 +368,4 @@ def run_startup_validation() -> None: raise ValueError("SECRET_KEY does not meet security requirements") except UndefinedValueError: if not debug_mode: - raise ValueError("SECRET_KEY is required in production") + raise ValueError("SECRET_KEY is required in production") from None diff --git a/backend/config/settings/security.py b/backend/config/settings/security.py index 9344993e..db3d42bd 100644 --- a/backend/config/settings/security.py +++ b/backend/config/settings/security.py @@ -35,15 +35,11 @@ TURNSTILE_VERIFY_URL = config( # X-XSS-Protection: Enables browser's built-in XSS filter # Note: Modern browsers are deprecating this in favor of CSP, but it's still # useful for older browsers -SECURE_BROWSER_XSS_FILTER = config( - "SECURE_BROWSER_XSS_FILTER", default=True, cast=bool -) +SECURE_BROWSER_XSS_FILTER = config("SECURE_BROWSER_XSS_FILTER", default=True, cast=bool) # X-Content-Type-Options: Prevents MIME type sniffing attacks # When True, adds "X-Content-Type-Options: nosniff" header -SECURE_CONTENT_TYPE_NOSNIFF = config( - "SECURE_CONTENT_TYPE_NOSNIFF", default=True, cast=bool -) +SECURE_CONTENT_TYPE_NOSNIFF = config("SECURE_CONTENT_TYPE_NOSNIFF", default=True, cast=bool) # X-Frame-Options: Protects against clickjacking attacks # DENY = Never allow framing (most secure) @@ -53,24 +49,18 @@ X_FRAME_OPTIONS = config("X_FRAME_OPTIONS", default="DENY") # Referrer-Policy: Controls how much referrer information is sent # strict-origin-when-cross-origin = Send full URL for same-origin, # only origin for cross-origin, nothing for downgrade -SECURE_REFERRER_POLICY = config( - "SECURE_REFERRER_POLICY", default="strict-origin-when-cross-origin" -) +SECURE_REFERRER_POLICY = config("SECURE_REFERRER_POLICY", default="strict-origin-when-cross-origin") # Cross-Origin-Opener-Policy: Prevents cross-origin attacks via window references # same-origin = Document can only be accessed by windows from same origin -SECURE_CROSS_ORIGIN_OPENER_POLICY = config( - "SECURE_CROSS_ORIGIN_OPENER_POLICY", default="same-origin" -) +SECURE_CROSS_ORIGIN_OPENER_POLICY = config("SECURE_CROSS_ORIGIN_OPENER_POLICY", default="same-origin") # ============================================================================= # HSTS (HTTP Strict Transport Security) Configuration # ============================================================================= # Include subdomains in HSTS policy -SECURE_HSTS_INCLUDE_SUBDOMAINS = config( - "SECURE_HSTS_INCLUDE_SUBDOMAINS", default=True, cast=bool -) +SECURE_HSTS_INCLUDE_SUBDOMAINS = config("SECURE_HSTS_INCLUDE_SUBDOMAINS", default=True, cast=bool) # HSTS max-age in seconds (31536000 = 1 year, recommended minimum) SECURE_HSTS_SECONDS = config("SECURE_HSTS_SECONDS", default=31536000, cast=int) @@ -82,9 +72,7 @@ SECURE_HSTS_PRELOAD = config("SECURE_HSTS_PRELOAD", default=False, cast=bool) # URLs exempt from SSL redirect (e.g., health checks) # Format: comma-separated list of URL patterns SECURE_REDIRECT_EXEMPT = config( - "SECURE_REDIRECT_EXEMPT", - default="", - cast=lambda v: [s.strip() for s in v.split(",") if s.strip()] + "SECURE_REDIRECT_EXEMPT", default="", cast=lambda v: [s.strip() for s in v.split(",") if s.strip()] ) # Redirect all HTTP requests to HTTPS @@ -93,9 +81,7 @@ SECURE_SSL_REDIRECT = config("SECURE_SSL_REDIRECT", default=False, cast=bool) # Header used by proxy to indicate HTTPS # Common values: ('HTTP_X_FORWARDED_PROTO', 'https') _proxy_ssl_header = config("SECURE_PROXY_SSL_HEADER", default="") -SECURE_PROXY_SSL_HEADER = ( - tuple(_proxy_ssl_header.split(",")) if _proxy_ssl_header else None -) +SECURE_PROXY_SSL_HEADER = tuple(_proxy_ssl_header.split(",")) if _proxy_ssl_header else None # ============================================================================= # Session Cookie Security @@ -143,9 +129,7 @@ AUTHENTICATION_BACKENDS = [ AUTH_PASSWORD_VALIDATORS = [ { - "NAME": ( - "django.contrib.auth.password_validation.UserAttributeSimilarityValidator" - ), + "NAME": ("django.contrib.auth.password_validation.UserAttributeSimilarityValidator"), }, { "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", diff --git a/backend/config/settings/storage.py b/backend/config/settings/storage.py index dc33615d..a657c3ca 100644 --- a/backend/config/settings/storage.py +++ b/backend/config/settings/storage.py @@ -37,19 +37,13 @@ STATIC_ROOT = BASE_DIR / "staticfiles" # WhiteNoise serves static files efficiently without a separate web server # Compression quality for Brotli/Gzip (1-100, higher = better but slower) -WHITENOISE_COMPRESSION_QUALITY = config( - "WHITENOISE_COMPRESSION_QUALITY", default=90, cast=int -) +WHITENOISE_COMPRESSION_QUALITY = config("WHITENOISE_COMPRESSION_QUALITY", default=90, cast=int) # Cache max-age for static files (1 year for immutable content) -WHITENOISE_MAX_AGE = config( - "WHITENOISE_MAX_AGE", default=31536000, cast=int -) +WHITENOISE_MAX_AGE = config("WHITENOISE_MAX_AGE", default=31536000, cast=int) # Don't fail on missing manifest entries (graceful degradation) -WHITENOISE_MANIFEST_STRICT = config( - "WHITENOISE_MANIFEST_STRICT", default=False, cast=bool -) +WHITENOISE_MANIFEST_STRICT = config("WHITENOISE_MANIFEST_STRICT", default=False, cast=bool) # Additional MIME types WHITENOISE_MIMETYPES = { @@ -59,11 +53,26 @@ WHITENOISE_MIMETYPES = { # Skip compressing already compressed formats WHITENOISE_SKIP_COMPRESS_EXTENSIONS = [ - "jpg", "jpeg", "png", "gif", "webp", # Images - "zip", "gz", "tgz", "bz2", "tbz", "xz", "br", # Archives - "swf", "flv", # Flash - "woff", "woff2", # Fonts - "mp3", "mp4", "ogg", "webm", # Media + "jpg", + "jpeg", + "png", + "gif", + "webp", # Images + "zip", + "gz", + "tgz", + "bz2", + "tbz", + "xz", + "br", # Archives + "swf", + "flv", # Flash + "woff", + "woff2", # Fonts + "mp3", + "mp4", + "ogg", + "webm", # Media ] # ============================================================================= @@ -103,20 +112,14 @@ STORAGES = { # Maximum size (in bytes) of file to upload into memory (2.5MB) # Files larger than this are written to disk -FILE_UPLOAD_MAX_MEMORY_SIZE = config( - "FILE_UPLOAD_MAX_MEMORY_SIZE", default=2621440, cast=int -) +FILE_UPLOAD_MAX_MEMORY_SIZE = config("FILE_UPLOAD_MAX_MEMORY_SIZE", default=2621440, cast=int) # Maximum size (in bytes) of request data (10MB) # This limits the total size of POST request body -DATA_UPLOAD_MAX_MEMORY_SIZE = config( - "DATA_UPLOAD_MAX_MEMORY_SIZE", default=10485760, cast=int -) +DATA_UPLOAD_MAX_MEMORY_SIZE = config("DATA_UPLOAD_MAX_MEMORY_SIZE", default=10485760, cast=int) # Maximum number of GET/POST parameters (1000) -DATA_UPLOAD_MAX_NUMBER_FIELDS = config( - "DATA_UPLOAD_MAX_NUMBER_FIELDS", default=1000, cast=int -) +DATA_UPLOAD_MAX_NUMBER_FIELDS = config("DATA_UPLOAD_MAX_NUMBER_FIELDS", default=1000, cast=int) # File upload permissions (0o644 = rw-r--r--) FILE_UPLOAD_PERMISSIONS = 0o644 diff --git a/backend/config/settings/third_party.py b/backend/config/settings/third_party.py index ee053b16..e74453e0 100644 --- a/backend/config/settings/third_party.py +++ b/backend/config/settings/third_party.py @@ -33,9 +33,7 @@ ACCOUNT_SIGNUP_FIELDS = ["email*", "username*", "password1*", "password2*"] ACCOUNT_LOGIN_METHODS = {"email", "username"} # Email verification settings -ACCOUNT_EMAIL_VERIFICATION = config( - "ACCOUNT_EMAIL_VERIFICATION", default="mandatory" -) +ACCOUNT_EMAIL_VERIFICATION = config("ACCOUNT_EMAIL_VERIFICATION", default="mandatory") ACCOUNT_EMAIL_REQUIRED = True ACCOUNT_EMAIL_VERIFICATION_SUPPORTS_CHANGE = True ACCOUNT_EMAIL_VERIFICATION_SUPPORTS_RESEND = True @@ -114,12 +112,8 @@ CELERY_BROKER_URL = config("REDIS_URL", default="redis://localhost:6379/1") CELERY_RESULT_BACKEND = config("REDIS_URL", default="redis://localhost:6379/1") # Task settings for test environments -CELERY_TASK_ALWAYS_EAGER = config( - "CELERY_TASK_ALWAYS_EAGER", default=False, cast=bool -) -CELERY_TASK_EAGER_PROPAGATES = config( - "CELERY_TASK_EAGER_PROPAGATES", default=False, cast=bool -) +CELERY_TASK_ALWAYS_EAGER = config("CELERY_TASK_ALWAYS_EAGER", default=False, cast=bool) +CELERY_TASK_EAGER_PROPAGATES = config("CELERY_TASK_EAGER_PROPAGATES", default=False, cast=bool) # ============================================================================= # Health Check Configuration @@ -165,16 +159,10 @@ CLOUDFLARE_IMAGES = { "DEFAULT_VARIANT": config("CLOUDFLARE_IMAGES_DEFAULT_VARIANT", default="public"), "UPLOAD_TIMEOUT": config("CLOUDFLARE_IMAGES_UPLOAD_TIMEOUT", default=300, cast=int), "WEBHOOK_SECRET": config("CLOUDFLARE_IMAGES_WEBHOOK_SECRET", default=""), - "CLEANUP_EXPIRED_HOURS": config( - "CLOUDFLARE_IMAGES_CLEANUP_HOURS", default=24, cast=int - ), - "MAX_FILE_SIZE": config( - "CLOUDFLARE_IMAGES_MAX_FILE_SIZE", default=10 * 1024 * 1024, cast=int - ), + "CLEANUP_EXPIRED_HOURS": config("CLOUDFLARE_IMAGES_CLEANUP_HOURS", default=24, cast=int), + "MAX_FILE_SIZE": config("CLOUDFLARE_IMAGES_MAX_FILE_SIZE", default=10 * 1024 * 1024, cast=int), "ALLOWED_FORMATS": ["jpeg", "png", "gif", "webp"], - "REQUIRE_SIGNED_URLS": config( - "CLOUDFLARE_IMAGES_REQUIRE_SIGNED_URLS", default=False, cast=bool - ), + "REQUIRE_SIGNED_URLS": config("CLOUDFLARE_IMAGES_REQUIRE_SIGNED_URLS", default=False, cast=bool), "DEFAULT_METADATA": {}, } @@ -183,21 +171,13 @@ CLOUDFLARE_IMAGES = { # ============================================================================= # Settings for the road trip planning service using OpenStreetMap -ROADTRIP_CACHE_TIMEOUT = config( - "ROADTRIP_CACHE_TIMEOUT", default=3600 * 24, cast=int -) # 24 hours for geocoding -ROADTRIP_ROUTE_CACHE_TIMEOUT = config( - "ROADTRIP_ROUTE_CACHE_TIMEOUT", default=3600 * 6, cast=int -) # 6 hours for routes +ROADTRIP_CACHE_TIMEOUT = config("ROADTRIP_CACHE_TIMEOUT", default=3600 * 24, cast=int) # 24 hours for geocoding +ROADTRIP_ROUTE_CACHE_TIMEOUT = config("ROADTRIP_ROUTE_CACHE_TIMEOUT", default=3600 * 6, cast=int) # 6 hours for routes ROADTRIP_MAX_REQUESTS_PER_SECOND = config( "ROADTRIP_MAX_REQUESTS_PER_SECOND", default=1, cast=int ) # Respect OSM rate limits -ROADTRIP_USER_AGENT = config( - "ROADTRIP_USER_AGENT", default="ThrillWiki/1.0 (https://thrillwiki.com)" -) -ROADTRIP_REQUEST_TIMEOUT = config( - "ROADTRIP_REQUEST_TIMEOUT", default=10, cast=int -) # seconds +ROADTRIP_USER_AGENT = config("ROADTRIP_USER_AGENT", default="ThrillWiki/1.0 (https://thrillwiki.com)") +ROADTRIP_REQUEST_TIMEOUT = config("ROADTRIP_REQUEST_TIMEOUT", default=10, cast=int) # seconds ROADTRIP_MAX_RETRIES = config("ROADTRIP_MAX_RETRIES", default=3, cast=int) ROADTRIP_BACKOFF_FACTOR = config("ROADTRIP_BACKOFF_FACTOR", default=2, cast=int) @@ -206,9 +186,7 @@ ROADTRIP_BACKOFF_FACTOR = config("ROADTRIP_BACKOFF_FACTOR", default=2, cast=int) # ============================================================================= # django-autocomplete-light settings -AUTOCOMPLETE_BLOCK_UNAUTHENTICATED = config( - "AUTOCOMPLETE_BLOCK_UNAUTHENTICATED", default=False, cast=bool -) +AUTOCOMPLETE_BLOCK_UNAUTHENTICATED = config("AUTOCOMPLETE_BLOCK_UNAUTHENTICATED", default=False, cast=bool) # ============================================================================= # Frontend Configuration @@ -226,7 +204,5 @@ TURNSTILE_SECRET = config("TURNSTILE_SECRET", default="") # Skip Turnstile validation in development if keys not set TURNSTILE_SKIP_VALIDATION = config( - "TURNSTILE_SKIP_VALIDATION", - default=not TURNSTILE_SECRET, # Skip if no secret - cast=bool + "TURNSTILE_SKIP_VALIDATION", default=not TURNSTILE_SECRET, cast=bool # Skip if no secret ) diff --git a/backend/config/settings/validation.py b/backend/config/settings/validation.py index 86a11ca7..6fb040a9 100644 --- a/backend/config/settings/validation.py +++ b/backend/config/settings/validation.py @@ -160,19 +160,13 @@ def validate_email(value: str) -> bool: def validate_type(value: Any, expected_type: type) -> bool: """Validate that a value is of the expected type.""" - if expected_type == bool: + if expected_type is bool: # Special handling for boolean strings - return isinstance(value, bool) or str(value).lower() in ( - "true", "false", "1", "0", "yes", "no" - ) + return isinstance(value, bool) or str(value).lower() in ("true", "false", "1", "0", "yes", "no") return isinstance(value, expected_type) -def validate_range( - value: Any, - min_value: Any | None = None, - max_value: Any | None = None -) -> bool: +def validate_range(value: Any, min_value: Any | None = None, max_value: Any | None = None) -> bool: """Validate that a value is within a specified range.""" if min_value is not None and value < min_value: return False @@ -215,11 +209,11 @@ def validate_variable(name: str, rules: dict) -> list[str]: var_type = rules.get("type", str) default = rules.get("default") - if var_type == bool: + if var_type is bool: value = config(name, default=default, cast=bool) - elif var_type == int: + elif var_type is int: value = config(name, default=default, cast=int) - elif var_type == float: + elif var_type is float: value = config(name, default=default, cast=float) else: value = config(name, default=default) @@ -233,29 +227,21 @@ def validate_variable(name: str, rules: dict) -> list[str]: # Type validation if not validate_type(value, rules.get("type", str)): - errors.append( - f"{name}: Expected type {rules['type'].__name__}, " - f"got {type(value).__name__}" - ) + errors.append(f"{name}: Expected type {rules['type'].__name__}, " f"got {type(value).__name__}") # Length validation (for strings) if isinstance(value, str): min_length = rules.get("min_length", 0) max_length = rules.get("max_length") if not validate_length(value, min_length, max_length): - errors.append( - f"{name}: Length must be between {min_length} and " - f"{max_length or 'unlimited'}" - ) + errors.append(f"{name}: Length must be between {min_length} and " f"{max_length or 'unlimited'}") # Range validation (for numbers) - if isinstance(value, (int, float)): + if isinstance(value, int | float): min_value = rules.get("min_value") max_value = rules.get("max_value") if not validate_range(value, min_value, max_value): - errors.append( - f"{name}: Value must be between {min_value} and {max_value}" - ) + errors.append(f"{name}: Value must be between {min_value} and {max_value}") # Custom validator validator_name = rules.get("validator") @@ -285,13 +271,9 @@ def validate_cross_rules() -> list[str]: try: value = config(var_name, default=None) if value is not None and not check_fn(value): - errors.append( - f"{rule['name']}: {var_name} {message}" - ) + errors.append(f"{rule['name']}: {var_name} {message}") except Exception: - errors.append( - f"{rule['name']}: Could not validate {var_name}" - ) + errors.append(f"{rule['name']}: Could not validate {var_name}") except Exception as e: errors.append(f"Cross-validation error for {rule['name']}: {e}") @@ -343,9 +325,7 @@ def validate_all_settings(raise_on_error: bool = False) -> dict: logger.error(f"Configuration error: {error}") if raise_on_error: - raise ValueError( - f"Configuration validation failed: {result['errors']}" - ) + raise ValueError(f"Configuration validation failed: {result['errors']}") # Log warnings for warning in result["warnings"]: @@ -372,9 +352,7 @@ def run_startup_validation() -> None: for error in result["errors"]: warnings.warn(f"Configuration error: {error}", stacklevel=2) else: - raise ValueError( - "Configuration validation failed. Check logs for details." - ) + raise ValueError("Configuration validation failed. Check logs for details.") # ============================================================================= diff --git a/backend/ensure_admin.py b/backend/ensure_admin.py index 6391f2ae..793a263e 100644 --- a/backend/ensure_admin.py +++ b/backend/ensure_admin.py @@ -7,10 +7,11 @@ sys.path.append(os.path.join(os.path.dirname(__file__))) os.environ.setdefault("DJANGO_SETTINGS_MODULE", "thrillwiki.settings") django.setup() -from django.contrib.auth import get_user_model +from django.contrib.auth import get_user_model # noqa: E402 User = get_user_model() + def ensure_admin(): username = "admin" email = "admin@example.com" @@ -23,12 +24,13 @@ def ensure_admin(): else: print(f"Superuser {username} already exists.") u = User.objects.get(username=username) - if not u.is_staff or not u.is_superuser or u.role != 'ADMIN': + if not u.is_staff or not u.is_superuser or u.role != "ADMIN": u.is_staff = True u.is_superuser = True - u.role = 'ADMIN' + u.role = "ADMIN" u.save() print("Updated existing user to ADMIN/Superuser.") + if __name__ == "__main__": ensure_admin() diff --git a/backend/scripts/benchmark_queries.py b/backend/scripts/benchmark_queries.py index 98f1e451..4f749c44 100644 --- a/backend/scripts/benchmark_queries.py +++ b/backend/scripts/benchmark_queries.py @@ -30,6 +30,7 @@ if not settings.DEBUG: def benchmark(name: str, iterations: int = 5): """Decorator to benchmark a function.""" + def decorator(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs) -> dict[str, Any]: @@ -48,17 +49,19 @@ def benchmark(name: str, iterations: int = 5): query_counts.append(len(context.captured_queries)) return { - 'name': name, - 'avg_time_ms': statistics.mean(times), - 'min_time_ms': min(times), - 'max_time_ms': max(times), - 'std_dev_ms': statistics.stdev(times) if len(times) > 1 else 0, - 'avg_queries': statistics.mean(query_counts), - 'min_queries': min(query_counts), - 'max_queries': max(query_counts), - 'iterations': iterations, + "name": name, + "avg_time_ms": statistics.mean(times), + "min_time_ms": min(times), + "max_time_ms": max(times), + "std_dev_ms": statistics.stdev(times) if len(times) > 1 else 0, + "avg_queries": statistics.mean(query_counts), + "min_queries": min(query_counts), + "max_queries": max(query_counts), + "iterations": iterations, } + return wrapper + return decorator @@ -67,7 +70,9 @@ def print_benchmark_result(result: dict[str, Any]) -> None: print(f"\n{'='*60}") print(f"Benchmark: {result['name']}") print(f"{'='*60}") - print(f" Time (ms): avg={result['avg_time_ms']:.2f}, min={result['min_time_ms']:.2f}, max={result['max_time_ms']:.2f}") + print( + f" Time (ms): avg={result['avg_time_ms']:.2f}, min={result['min_time_ms']:.2f}, max={result['max_time_ms']:.2f}" + ) print(f" Std Dev (ms): {result['std_dev_ms']:.2f}") print(f" Queries: avg={result['avg_queries']:.1f}, min={result['min_queries']}, max={result['max_queries']}") print(f" Iterations: {result['iterations']}") @@ -86,7 +91,7 @@ def run_benchmarks() -> list[dict[str, Any]]: parks = Park.objects.optimized_for_list()[:50] for park in parks: _ = park.operator - _ = park.coaster_count_calculated if hasattr(park, 'coaster_count_calculated') else None + _ = park.coaster_count_calculated if hasattr(park, "coaster_count_calculated") else None return list(parks) results.append(bench_park_list_optimized()) @@ -167,22 +172,22 @@ def run_benchmarks() -> list[dict[str, Any]]: def print_summary(results: list[dict[str, Any]]) -> None: """Print a summary table of all benchmarks.""" - print("\n" + "="*80) + print("\n" + "=" * 80) print("BENCHMARK SUMMARY") - print("="*80) + print("=" * 80) print(f"{'Benchmark':<45} {'Avg Time (ms)':<15} {'Avg Queries':<15}") - print("-"*80) + print("-" * 80) for result in results: print(f"{result['name']:<45} {result['avg_time_ms']:<15.2f} {result['avg_queries']:<15.1f}") - print("="*80) + print("=" * 80) if True: # Always run when executed - print("\n" + "="*80) + print("\n" + "=" * 80) print("THRILLWIKI QUERY PERFORMANCE BENCHMARKS") - print("="*80) + print("=" * 80) print("\nRunning benchmarks...") try: @@ -200,4 +205,5 @@ if True: # Always run when executed except Exception as e: print(f"\nError running benchmarks: {e}") import traceback + traceback.print_exc() diff --git a/backend/stubs/environ.pyi b/backend/stubs/environ.pyi index 1d8b44bd..1df17f37 100644 --- a/backend/stubs/environ.pyi +++ b/backend/stubs/environ.pyi @@ -1,3 +1,4 @@ +# ruff: noqa: B008 """Type stubs for django-environ to fix Pylance type checking issues.""" import builtins diff --git a/backend/templates/base/base.html b/backend/templates/base/base.html index 2604f422..a71f31f1 100644 --- a/backend/templates/base/base.html +++ b/backend/templates/base/base.html @@ -54,7 +54,7 @@ - + @@ -64,7 +64,7 @@ {# Use title block directly #} - {% block page_title %}{% block title %}ThrillWiki{% endblock %}{% endblock %} + {% block title %}ThrillWiki{% endblock %} diff --git a/backend/test_avatar_upload.py b/backend/test_avatar_upload.py index 42830782..432f0d1d 100644 --- a/backend/test_avatar_upload.py +++ b/backend/test_avatar_upload.py @@ -29,13 +29,7 @@ def step1_get_upload_url(): print("Step 1: Requesting upload URL...") url = f"{API_BASE}/cloudflare-images/api/upload-url/" - data = { - "metadata": { - "type": "avatar", - "userId": "7627" # Replace with your user ID - }, - "require_signed_urls": False - } + data = {"metadata": {"type": "avatar", "userId": "7627"}, "require_signed_urls": False} # Replace with your user ID response = requests.post(url, json=data, headers=HEADERS) print(f"Status: {response.status_code}") @@ -54,11 +48,9 @@ def step2_upload_image(upload_url): # Create a simple test image (1x1 pixel PNG) # This is a minimal valid PNG file - png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\tpHYs\x00\x00\x0b\x13\x00\x00\x0b\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\x12IDATx\x9cc```bPPP\x00\x02\xd2\x00\x00\x00\x05\x00\x01\r\n-\xdb\x00\x00\x00\x00IEND\xaeB`\x82' + png_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\tpHYs\x00\x00\x0b\x13\x00\x00\x0b\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\x12IDATx\x9cc```bPPP\x00\x02\xd2\x00\x00\x00\x05\x00\x01\r\n-\xdb\x00\x00\x00\x00IEND\xaeB`\x82" - files = { - 'file': ('test_avatar.png', png_data, 'image/png') - } + files = {"file": ("test_avatar.png", png_data, "image/png")} # Upload to Cloudflare (no auth headers needed for direct upload) response = requests.post(upload_url, files=files) @@ -76,9 +68,7 @@ def step3_save_avatar(cloudflare_id): print("\nStep 3: Saving avatar reference...") url = f"{API_BASE}/accounts/profile/avatar/save/" - data = { - "cloudflare_image_id": cloudflare_id - } + data = {"cloudflare_image_id": cloudflare_id} response = requests.post(url, json=data, headers=HEADERS) print(f"Status: {response.status_code}") diff --git a/backend/tests/accessibility/test_wcag_compliance.py b/backend/tests/accessibility/test_wcag_compliance.py index 0d096f81..e34471e5 100644 --- a/backend/tests/accessibility/test_wcag_compliance.py +++ b/backend/tests/accessibility/test_wcag_compliance.py @@ -33,12 +33,14 @@ try: from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support.ui import WebDriverWait + HAS_SELENIUM = True except ImportError: HAS_SELENIUM = False try: from axe_selenium_python import Axe + HAS_AXE = True except ImportError: HAS_AXE = False @@ -53,7 +55,7 @@ def skip_if_no_browser(): return unittest.skip("Selenium not installed") if not HAS_AXE: return unittest.skip("axe-selenium-python not installed") - if os.environ.get('CI') and not os.environ.get('BROWSER_TESTS'): + if os.environ.get("CI") and not os.environ.get("BROWSER_TESTS"): return unittest.skip("Browser tests disabled in CI") return lambda func: func @@ -73,14 +75,12 @@ class AccessibilityTestMixin: dict: Axe results containing violations and passes """ if url_name: - url = f'{self.live_server_url}{reverse(url_name)}' + url = f"{self.live_server_url}{reverse(url_name)}" elif not url: raise ValueError("Either url_name or url must be provided") self.driver.get(url) - WebDriverWait(self.driver, 10).until( - EC.presence_of_element_located((By.TAG_NAME, "body")) - ) + WebDriverWait(self.driver, 10).until(EC.presence_of_element_located((By.TAG_NAME, "body"))) axe = Axe(self.driver) axe.inject() @@ -96,20 +96,13 @@ class AccessibilityTestMixin: results: Axe audit results page_name: Name of page for error messages """ - critical_violations = [ - v for v in results.get('violations', []) - if v.get('impact') in ('critical', 'serious') - ] + critical_violations = [v for v in results.get("violations", []) if v.get("impact") in ("critical", "serious")] if critical_violations: - violation_details = "\n".join([ - f"- {v['id']}: {v['description']} (impact: {v['impact']})" - for v in critical_violations - ]) - self.fail( - f"Critical accessibility violations found on {page_name}:\n" - f"{violation_details}" + violation_details = "\n".join( + [f"- {v['id']}: {v['description']} (impact: {v['impact']})" for v in critical_violations] ) + self.fail(f"Critical accessibility violations found on {page_name}:\n" f"{violation_details}") def assert_wcag_aa_compliant(self, results, page_name="page"): """ @@ -119,17 +112,13 @@ class AccessibilityTestMixin: results: Axe audit results page_name: Name of page for error messages """ - violations = results.get('violations', []) + violations = results.get("violations", []) if violations: - violation_details = "\n".join([ - f"- {v['id']}: {v['description']} (impact: {v['impact']})" - for v in violations - ]) - self.fail( - f"WCAG 2.1 AA violations found on {page_name}:\n" - f"{violation_details}" + violation_details = "\n".join( + [f"- {v['id']}: {v['description']} (impact: {v['impact']})" for v in violations] ) + self.fail(f"WCAG 2.1 AA violations found on {page_name}:\n" f"{violation_details}") @skip_if_no_browser() @@ -148,11 +137,11 @@ class WCAGComplianceTests(AccessibilityTestMixin, LiveServerTestCase): # Configure Chrome for headless testing chrome_options = Options() - chrome_options.add_argument('--headless') - chrome_options.add_argument('--no-sandbox') - chrome_options.add_argument('--disable-dev-shm-usage') - chrome_options.add_argument('--disable-gpu') - chrome_options.add_argument('--window-size=1920,1080') + chrome_options.add_argument("--headless") + chrome_options.add_argument("--no-sandbox") + chrome_options.add_argument("--disable-dev-shm-usage") + chrome_options.add_argument("--disable-gpu") + chrome_options.add_argument("--window-size=1920,1080") try: cls.driver = webdriver.Chrome(options=chrome_options) @@ -162,38 +151,38 @@ class WCAGComplianceTests(AccessibilityTestMixin, LiveServerTestCase): @classmethod def tearDownClass(cls): - if hasattr(cls, 'driver'): + if hasattr(cls, "driver"): cls.driver.quit() super().tearDownClass() def test_homepage_accessibility(self): """Test homepage WCAG 2.1 AA compliance.""" - results = self.run_axe_audit(url_name='home') + results = self.run_axe_audit(url_name="home") self.assert_no_critical_violations(results, "homepage") def test_park_list_accessibility(self): """Test park list page WCAG 2.1 AA compliance.""" - results = self.run_axe_audit(url_name='parks:park_list') + results = self.run_axe_audit(url_name="parks:park_list") self.assert_no_critical_violations(results, "park list") def test_ride_list_accessibility(self): """Test ride list page WCAG 2.1 AA compliance.""" - results = self.run_axe_audit(url_name='rides:global_ride_list') + results = self.run_axe_audit(url_name="rides:global_ride_list") self.assert_no_critical_violations(results, "ride list") def test_manufacturer_list_accessibility(self): """Test manufacturer list page WCAG 2.1 AA compliance.""" - results = self.run_axe_audit(url_name='rides:manufacturer_list') + results = self.run_axe_audit(url_name="rides:manufacturer_list") self.assert_no_critical_violations(results, "manufacturer list") def test_login_page_accessibility(self): """Test login page WCAG 2.1 AA compliance.""" - results = self.run_axe_audit(url_name='account_login') + results = self.run_axe_audit(url_name="account_login") self.assert_no_critical_violations(results, "login page") def test_signup_page_accessibility(self): """Test signup page WCAG 2.1 AA compliance.""" - results = self.run_axe_audit(url_name='account_signup') + results = self.run_axe_audit(url_name="account_signup") self.assert_no_critical_violations(results, "signup page") @@ -207,77 +196,66 @@ class HTMLAccessibilityTests(TestCase): def test_homepage_has_main_landmark(self): """Verify homepage has a main landmark.""" - response = self.client.get(reverse('home')) - self.assertContains(response, ']*>', content) + + img_tags = re.findall(r"]*>", content) for img in img_tags: - self.assertIn( - 'alt=', - img, - f"Image missing alt attribute: {img[:100]}" - ) + self.assertIn("alt=", img, f"Image missing alt attribute: {img[:100]}") def test_form_fields_have_labels(self): """Verify form fields have associated labels.""" - response = self.client.get(reverse('account_login')) - content = response.content.decode('utf-8') + response = self.client.get(reverse("account_login")) + content = response.content.decode("utf-8") # Find input elements (excluding hidden and submit) import re - inputs = re.findall( - r']*type=["\'](?!hidden|submit)[^"\']*["\'][^>]*>', - content - ) + + inputs = re.findall(r']*type=["\'](?!hidden|submit)[^"\']*["\'][^>]*>', content) for inp in inputs: # Each input should have id attribute for label association - self.assertTrue( - 'id=' in inp or 'aria-label' in inp, - f"Input missing id or aria-label: {inp[:100]}" - ) + self.assertTrue("id=" in inp or "aria-label" in inp, f"Input missing id or aria-label: {inp[:100]}") def test_buttons_are_accessible(self): """Verify buttons have accessible names.""" - response = self.client.get(reverse('home')) - content = response.content.decode('utf-8') + response = self.client.get(reverse("home")) + content = response.content.decode("utf-8") import re + # Find button elements - buttons = re.findall(r']*>.*?', content, re.DOTALL) + buttons = re.findall(r"]*>.*?", content, re.DOTALL) for button in buttons: # Button should have text content or aria-label - has_text = bool(re.search(r'>([^<]+)<', button)) - has_aria = 'aria-label' in button + has_text = bool(re.search(r">([^<]+)<", button)) + has_aria = "aria-label" in button - self.assertTrue( - has_text or has_aria, - f"Button missing accessible name: {button[:100]}" - ) + self.assertTrue(has_text or has_aria, f"Button missing accessible name: {button[:100]}") class KeyboardNavigationTests(TestCase): @@ -289,49 +267,35 @@ class KeyboardNavigationTests(TestCase): def test_interactive_elements_are_focusable(self): """Verify interactive elements don't have tabindex=-1.""" - response = self.client.get(reverse('home')) - content = response.content.decode('utf-8') + response = self.client.get(reverse("home")) + content = response.content.decode("utf-8") # Links and buttons should not have tabindex=-1 (unless intentionally hidden) import re - problematic = re.findall( - r'<(a|button)[^>]*tabindex=["\']?-1["\']?[^>]*>', - content - ) + + problematic = re.findall(r'<(a|button)[^>]*tabindex=["\']?-1["\']?[^>]*>', content) # Filter out elements that are legitimately hidden for elem in problematic: - self.assertIn( - 'aria-hidden', - elem, - f"Interactive element has tabindex=-1 without aria-hidden: {elem}" - ) + self.assertIn("aria-hidden", elem, f"Interactive element has tabindex=-1 without aria-hidden: {elem}") def test_modals_have_escape_handler(self): """Verify modal templates include escape key handling.""" from django.template.loader import get_template - template = get_template('components/modals/modal_inner.html') + template = get_template("components/modals/modal_inner.html") source = template.template.source - self.assertIn( - 'escape', - source.lower(), - "Modal should handle Escape key" - ) + self.assertIn("escape", source.lower(), "Modal should handle Escape key") def test_dropdowns_have_keyboard_support(self): """Verify dropdown menus support keyboard navigation.""" - response = self.client.get(reverse('home')) - content = response.content.decode('utf-8') + response = self.client.get(reverse("home")) + content = response.content.decode("utf-8") # Check for aria-expanded on dropdown triggers - if 'dropdown' in content.lower() or 'menu' in content.lower(): - self.assertIn( - 'aria-expanded', - content, - "Dropdown should have aria-expanded attribute" - ) + if "dropdown" in content.lower() or "menu" in content.lower(): + self.assertIn("aria-expanded", content, "Dropdown should have aria-expanded attribute") class ARIAAttributeTests(TestCase): @@ -343,42 +307,26 @@ class ARIAAttributeTests(TestCase): """Verify modal has role=dialog.""" from django.template.loader import get_template - template = get_template('components/modals/modal_inner.html') + template = get_template("components/modals/modal_inner.html") source = template.template.source - self.assertIn( - 'role="dialog"', - source, - "Modal should have role=dialog" - ) + self.assertIn('role="dialog"', source, "Modal should have role=dialog") def test_modal_has_aria_modal(self): """Verify modal has aria-modal=true.""" from django.template.loader import get_template - template = get_template('components/modals/modal_inner.html') + template = get_template("components/modals/modal_inner.html") source = template.template.source - self.assertIn( - 'aria-modal="true"', - source, - "Modal should have aria-modal=true" - ) + self.assertIn('aria-modal="true"', source, "Modal should have aria-modal=true") def test_breadcrumb_has_navigation_role(self): """Verify breadcrumbs use nav element with aria-label.""" from django.template.loader import get_template - template = get_template('components/navigation/breadcrumbs.html') + template = get_template("components/navigation/breadcrumbs.html") source = template.template.source - self.assertIn( - ' {{ const response = await fetch('/core/fsm/parks/park/{park.pk}/transition/transition_to_operating/', {{ method: 'POST', @@ -58,21 +57,20 @@ class TestInvalidTransitionErrors: hxTrigger: response.headers.get('HX-Trigger') }}; }} - """) + """ + ) # Should return error status (400) if response: - assert response.get('status') in [400, 403] + assert response.get("status") in [400, 403] # Check for error toast in HX-Trigger header - hx_trigger = response.get('hxTrigger') + hx_trigger = response.get("hxTrigger") if hx_trigger: - assert 'showToast' in hx_trigger - assert 'error' in hx_trigger.lower() + assert "showToast" in hx_trigger + assert "error" in hx_trigger.lower() - def test_already_transitioned_shows_error( - self, mod_page: Page, live_server, db - ): + def test_already_transitioned_shows_error(self, mod_page: Page, live_server, db): """Test that trying to approve an already-approved submission shows error.""" from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType @@ -84,8 +82,7 @@ class TestInvalidTransitionErrors: # Create an already-approved submission user, _ = User.objects.get_or_create( - username="testsubmitter2", - defaults={"email": "testsubmitter2@example.com"} + username="testsubmitter2", defaults={"email": "testsubmitter2@example.com"} ) park = Park.objects.first() @@ -101,7 +98,7 @@ class TestInvalidTransitionErrors: submission_type="EDIT", changes={"description": "Already approved"}, reason="Already approved test", - status="APPROVED" # Already approved + status="APPROVED", # Already approved ) try: @@ -109,7 +106,8 @@ class TestInvalidTransitionErrors: mod_page.wait_for_load_state("networkidle") # Try to approve again via direct API call - response = mod_page.evaluate(f""" + response = mod_page.evaluate( + f""" async () => {{ const response = await fetch('/core/fsm/moderation/editsubmission/{submission.pk}/transition/transition_to_approved/', {{ method: 'POST', @@ -125,18 +123,17 @@ class TestInvalidTransitionErrors: hxTrigger: response.headers.get('HX-Trigger') }}; }} - """) + """ + ) # Should return error status if response: - assert response.get('status') in [400, 403] + assert response.get("status") in [400, 403] finally: submission.delete() - def test_nonexistent_transition_shows_error( - self, mod_page: Page, live_server, db - ): + def test_nonexistent_transition_shows_error(self, mod_page: Page, live_server, db): """Test that requesting a non-existent transition shows error.""" from apps.parks.models import Park @@ -148,7 +145,8 @@ class TestInvalidTransitionErrors: mod_page.wait_for_load_state("networkidle") # Try to call a non-existent transition - response = mod_page.evaluate(f""" + response = mod_page.evaluate( + f""" async () => {{ const response = await fetch('/core/fsm/parks/park/{park.pk}/transition/nonexistent_transition/', {{ method: 'POST', @@ -164,19 +162,18 @@ class TestInvalidTransitionErrors: hxTrigger: response.headers.get('HX-Trigger') }}; }} - """) + """ + ) # Should return error status (400 or 404) if response: - assert response.get('status') in [400, 404] + assert response.get("status") in [400, 404] class TestLoadingIndicators: """Tests for loading indicator visibility during transitions.""" - def test_loading_indicator_appears_during_transition( - self, mod_page: Page, live_server, db - ): + def test_loading_indicator_appears_during_transition(self, mod_page: Page, live_server, db): """Verify loading spinner appears during HTMX transition.""" from apps.parks.models import Park @@ -187,19 +184,14 @@ class TestLoadingIndicators: mod_page.goto(f"{live_server.url}/parks/{park.slug}/") mod_page.wait_for_load_state("networkidle") - status_actions = mod_page.locator('[data-park-status-actions]') - close_temp_btn = status_actions.get_by_role( - "button", name="Close Temporarily" - ) + status_actions = mod_page.locator("[data-park-status-actions]") + close_temp_btn = status_actions.get_by_role("button", name="Close Temporarily") if not close_temp_btn.is_visible(): pytest.skip("Close Temporarily button not visible") # Add a route to slow down the request so we can see loading state - mod_page.route("**/core/fsm/**", lambda route: ( - mod_page.wait_for_timeout(500), - route.continue_() - )) + mod_page.route("**/core/fsm/**", lambda route: (mod_page.wait_for_timeout(500), route.continue_())) # Handle confirmation dialog mod_page.on("dialog", lambda dialog: dialog.accept()) @@ -212,19 +204,17 @@ class TestLoadingIndicators: # The loading indicator should appear (may be brief) # We wait a short time for it to appear - try: + try: # noqa: SIM105 expect(loading_indicator.first).to_be_visible(timeout=1000) except Exception: # Loading indicator may have already disappeared if response was fast pass # Wait for transition to complete - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) - def test_button_disabled_during_transition( - self, mod_page: Page, live_server, db - ): + def test_button_disabled_during_transition(self, mod_page: Page, live_server, db): """Test that transition button is disabled during request.""" from apps.parks.models import Park @@ -235,19 +225,14 @@ class TestLoadingIndicators: mod_page.goto(f"{live_server.url}/parks/{park.slug}/") mod_page.wait_for_load_state("networkidle") - status_actions = mod_page.locator('[data-park-status-actions]') - close_temp_btn = status_actions.get_by_role( - "button", name="Close Temporarily" - ) + status_actions = mod_page.locator("[data-park-status-actions]") + close_temp_btn = status_actions.get_by_role("button", name="Close Temporarily") if not close_temp_btn.is_visible(): pytest.skip("Close Temporarily button not visible") # Add a route to slow down the request - mod_page.route("**/core/fsm/**", lambda route: ( - mod_page.wait_for_timeout(1000), - route.continue_() - )) + mod_page.route("**/core/fsm/**", lambda route: (mod_page.wait_for_timeout(1000), route.continue_())) mod_page.on("dialog", lambda dialog: dialog.accept()) @@ -261,9 +246,7 @@ class TestLoadingIndicators: class TestNetworkErrorHandling: """Tests for handling network errors during transitions.""" - def test_network_error_shows_error_toast( - self, mod_page: Page, live_server, db - ): + def test_network_error_shows_error_toast(self, mod_page: Page, live_server, db): """Test that network errors show appropriate error toast.""" from apps.parks.models import Park @@ -277,10 +260,8 @@ class TestNetworkErrorHandling: # Abort network requests to simulate network error mod_page.route("**/core/fsm/**", lambda route: route.abort("failed")) - status_actions = mod_page.locator('[data-park-status-actions]') - close_temp_btn = status_actions.get_by_role( - "button", name="Close Temporarily" - ) + status_actions = mod_page.locator("[data-park-status-actions]") + close_temp_btn = status_actions.get_by_role("button", name="Close Temporarily") if not close_temp_btn.is_visible(): pytest.skip("Close Temporarily button not visible") @@ -295,7 +276,7 @@ class TestNetworkErrorHandling: error_indicator = mod_page.locator('[data-toast].error, .htmx-error, [class*="error"]') # May show as toast or inline error - try: + try: # noqa: SIM105 expect(error_indicator.first).to_be_visible(timeout=5000) except Exception: # Error may be handled differently @@ -305,9 +286,7 @@ class TestNetworkErrorHandling: park.refresh_from_db() assert park.status == "OPERATING" - def test_server_error_shows_user_friendly_message( - self, mod_page: Page, live_server, db - ): + def test_server_error_shows_user_friendly_message(self, mod_page: Page, live_server, db): """Test that server errors show user-friendly messages.""" from apps.parks.models import Park @@ -319,17 +298,18 @@ class TestNetworkErrorHandling: mod_page.wait_for_load_state("networkidle") # Return 500 error to simulate server error - mod_page.route("**/core/fsm/**", lambda route: route.fulfill( - status=500, - headers={"HX-Trigger": '{"showToast": {"message": "An unexpected error occurred", "type": "error"}}'}, - body="" - )) - - status_actions = mod_page.locator('[data-park-status-actions]') - close_temp_btn = status_actions.get_by_role( - "button", name="Close Temporarily" + mod_page.route( + "**/core/fsm/**", + lambda route: route.fulfill( + status=500, + headers={"HX-Trigger": '{"showToast": {"message": "An unexpected error occurred", "type": "error"}}'}, + body="", + ), ) + status_actions = mod_page.locator("[data-park-status-actions]") + close_temp_btn = status_actions.get_by_role("button", name="Close Temporarily") + if not close_temp_btn.is_visible(): pytest.skip("Close Temporarily button not visible") @@ -338,7 +318,7 @@ class TestNetworkErrorHandling: close_temp_btn.click() # Should show user-friendly error message - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Should not show technical error details to user @@ -349,9 +329,7 @@ class TestNetworkErrorHandling: class TestConfirmationDialogs: """Tests for confirmation dialogs on dangerous transitions.""" - def test_confirm_dialog_appears_for_reject_transition( - self, mod_page: Page, live_server, db - ): + def test_confirm_dialog_appears_for_reject_transition(self, mod_page: Page, live_server, db): """Test that confirmation dialog appears for reject transition.""" from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType @@ -362,8 +340,7 @@ class TestConfirmationDialogs: User = get_user_model() user, _ = User.objects.get_or_create( - username="testsubmitter3", - defaults={"email": "testsubmitter3@example.com"} + username="testsubmitter3", defaults={"email": "testsubmitter3@example.com"} ) park = Park.objects.first() @@ -379,7 +356,7 @@ class TestConfirmationDialogs: submission_type="EDIT", changes={"description": "Confirm dialog test"}, reason="Confirm dialog test", - status="PENDING" + status="PENDING", ) dialog_shown = {"shown": False} @@ -395,9 +372,7 @@ class TestConfirmationDialogs: mod_page.on("dialog", handle_dialog) - submission_row = mod_page.locator( - f'[data-submission-id="{submission.pk}"]' - ) + submission_row = mod_page.locator(f'[data-submission-id="{submission.pk}"]') if submission_row.is_visible(): reject_btn = submission_row.get_by_role("button", name="Reject") @@ -413,9 +388,7 @@ class TestConfirmationDialogs: finally: submission.delete() - def test_cancel_confirm_dialog_prevents_transition( - self, mod_page: Page, live_server, db - ): + def test_cancel_confirm_dialog_prevents_transition(self, mod_page: Page, live_server, db): """Test that canceling the confirmation dialog prevents the transition.""" from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType @@ -426,8 +399,7 @@ class TestConfirmationDialogs: User = get_user_model() user, _ = User.objects.get_or_create( - username="testsubmitter4", - defaults={"email": "testsubmitter4@example.com"} + username="testsubmitter4", defaults={"email": "testsubmitter4@example.com"} ) park = Park.objects.first() @@ -443,7 +415,7 @@ class TestConfirmationDialogs: submission_type="EDIT", changes={"description": "Cancel confirm test"}, reason="Cancel confirm test", - status="PENDING" + status="PENDING", ) try: @@ -453,9 +425,7 @@ class TestConfirmationDialogs: # Dismiss (cancel) the dialog mod_page.on("dialog", lambda dialog: dialog.dismiss()) - submission_row = mod_page.locator( - f'[data-submission-id="{submission.pk}"]' - ) + submission_row = mod_page.locator(f'[data-submission-id="{submission.pk}"]') if submission_row.is_visible(): reject_btn = submission_row.get_by_role("button", name="Reject") @@ -472,9 +442,7 @@ class TestConfirmationDialogs: finally: submission.delete() - def test_accept_confirm_dialog_executes_transition( - self, mod_page: Page, live_server, db - ): + def test_accept_confirm_dialog_executes_transition(self, mod_page: Page, live_server, db): """Test that accepting the confirmation dialog executes the transition.""" from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType @@ -485,8 +453,7 @@ class TestConfirmationDialogs: User = get_user_model() user, _ = User.objects.get_or_create( - username="testsubmitter5", - defaults={"email": "testsubmitter5@example.com"} + username="testsubmitter5", defaults={"email": "testsubmitter5@example.com"} ) park = Park.objects.first() @@ -502,7 +469,7 @@ class TestConfirmationDialogs: submission_type="EDIT", changes={"description": "Accept confirm test"}, reason="Accept confirm test", - status="PENDING" + status="PENDING", ) try: @@ -512,9 +479,7 @@ class TestConfirmationDialogs: # Accept the dialog mod_page.on("dialog", lambda dialog: dialog.accept()) - submission_row = mod_page.locator( - f'[data-submission-id="{submission.pk}"]' - ) + submission_row = mod_page.locator(f'[data-submission-id="{submission.pk}"]') if submission_row.is_visible(): reject_btn = submission_row.get_by_role("button", name="Reject") @@ -522,7 +487,7 @@ class TestConfirmationDialogs: reject_btn.click() # Wait for transition to complete - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Verify submission status WAS changed @@ -536,9 +501,7 @@ class TestConfirmationDialogs: class TestValidationErrors: """Tests for validation error handling.""" - def test_validation_error_shows_specific_message( - self, mod_page: Page, live_server, db - ): + def test_validation_error_shows_specific_message(self, mod_page: Page, live_server, db): """Test that validation errors show specific error messages.""" # This test depends on having transitions that require additional data # For example, a transition that requires a reason field @@ -563,9 +526,7 @@ class TestValidationErrors: class TestToastNotificationBehavior: """Tests for toast notification appearance and behavior.""" - def test_success_toast_auto_dismisses( - self, mod_page: Page, live_server, db - ): + def test_success_toast_auto_dismisses(self, mod_page: Page, live_server, db): """Test that success toast auto-dismisses after timeout.""" from apps.parks.models import Park @@ -576,10 +537,8 @@ class TestToastNotificationBehavior: mod_page.goto(f"{live_server.url}/parks/{park.slug}/") mod_page.wait_for_load_state("networkidle") - status_actions = mod_page.locator('[data-park-status-actions]') - close_temp_btn = status_actions.get_by_role( - "button", name="Close Temporarily" - ) + status_actions = mod_page.locator("[data-park-status-actions]") + close_temp_btn = status_actions.get_by_role("button", name="Close Temporarily") if not close_temp_btn.is_visible(): pytest.skip("Close Temporarily button not visible") @@ -589,16 +548,14 @@ class TestToastNotificationBehavior: close_temp_btn.click() # Toast should appear - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Toast should auto-dismiss after timeout (typically 3-5 seconds) # Wait for auto-dismiss expect(toast).not_to_be_visible(timeout=10000) - def test_error_toast_has_correct_styling( - self, mod_page: Page, live_server, db - ): + def test_error_toast_has_correct_styling(self, mod_page: Page, live_server, db): """Test that error toast has correct red/danger styling.""" from apps.parks.models import Park @@ -610,19 +567,18 @@ class TestToastNotificationBehavior: mod_page.wait_for_load_state("networkidle") # Simulate an error response - mod_page.route("**/core/fsm/**", lambda route: route.fulfill( - status=400, - headers={ - "HX-Trigger": '{"showToast": {"message": "Test error message", "type": "error"}}' - }, - body="" - )) - - status_actions = mod_page.locator('[data-park-status-actions]') - close_temp_btn = status_actions.get_by_role( - "button", name="Close Temporarily" + mod_page.route( + "**/core/fsm/**", + lambda route: route.fulfill( + status=400, + headers={"HX-Trigger": '{"showToast": {"message": "Test error message", "type": "error"}}'}, + body="", + ), ) + status_actions = mod_page.locator("[data-park-status-actions]") + close_temp_btn = status_actions.get_by_role("button", name="Close Temporarily") + if not close_temp_btn.is_visible(): pytest.skip("Close Temporarily button not visible") @@ -631,15 +587,13 @@ class TestToastNotificationBehavior: close_temp_btn.click() # Error toast should appear with error styling - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Should have error/danger styling (red) expect(toast).to_have_class(re.compile(r"error|danger|bg-red|text-red")) - def test_success_toast_has_correct_styling( - self, mod_page: Page, live_server, db - ): + def test_success_toast_has_correct_styling(self, mod_page: Page, live_server, db): """Test that success toast has correct green/success styling.""" from apps.parks.models import Park @@ -654,10 +608,8 @@ class TestToastNotificationBehavior: mod_page.goto(f"{live_server.url}/parks/{park.slug}/") mod_page.wait_for_load_state("networkidle") - status_actions = mod_page.locator('[data-park-status-actions]') - close_temp_btn = status_actions.get_by_role( - "button", name="Close Temporarily" - ) + status_actions = mod_page.locator("[data-park-status-actions]") + close_temp_btn = status_actions.get_by_role("button", name="Close Temporarily") if not close_temp_btn.is_visible(): pytest.skip("Close Temporarily button not visible") @@ -667,7 +619,7 @@ class TestToastNotificationBehavior: close_temp_btn.click() # Success toast should appear with success styling - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Should have success styling (green) diff --git a/backend/tests/e2e/test_fsm_permissions.py b/backend/tests/e2e/test_fsm_permissions.py index 74063ac7..2a6599cc 100644 --- a/backend/tests/e2e/test_fsm_permissions.py +++ b/backend/tests/e2e/test_fsm_permissions.py @@ -22,9 +22,7 @@ from playwright.sync_api import Page, expect class TestUnauthenticatedUserPermissions: """Tests for unauthenticated user permission guards.""" - def test_unauthenticated_user_cannot_see_moderation_dashboard( - self, page: Page, live_server - ): + def test_unauthenticated_user_cannot_see_moderation_dashboard(self, page: Page, live_server): """Test that unauthenticated users are redirected from moderation dashboard.""" # Navigate to moderation dashboard without logging in response = page.goto(f"{live_server.url}/moderation/dashboard/") @@ -34,9 +32,7 @@ class TestUnauthenticatedUserPermissions: current_url = page.url assert "login" in current_url or "denied" in current_url or response.status == 403 - def test_unauthenticated_user_cannot_see_transition_buttons( - self, page: Page, live_server, db - ): + def test_unauthenticated_user_cannot_see_transition_buttons(self, page: Page, live_server, db): """Test that unauthenticated users cannot see transition buttons on park detail.""" from apps.parks.models import Park @@ -48,18 +44,14 @@ class TestUnauthenticatedUserPermissions: page.wait_for_load_state("networkidle") # Status action buttons should NOT be visible - status_actions = page.locator('[data-park-status-actions]') + status_actions = page.locator("[data-park-status-actions]") # Either the section doesn't exist or the buttons are not there if status_actions.is_visible(): - close_temp_btn = status_actions.get_by_role( - "button", name="Close Temporarily" - ) + close_temp_btn = status_actions.get_by_role("button", name="Close Temporarily") expect(close_temp_btn).not_to_be_visible() - def test_unauthenticated_direct_post_returns_403( - self, page: Page, live_server, db - ): + def test_unauthenticated_direct_post_returns_403(self, page: Page, live_server, db): """Test that direct POST to FSM endpoint returns 403 for unauthenticated user.""" from apps.parks.models import Park @@ -70,7 +62,7 @@ class TestUnauthenticatedUserPermissions: # Attempt to POST directly to FSM transition endpoint response = page.request.post( f"{live_server.url}/core/fsm/parks/park/{park.pk}/transition/transition_to_closed_temp/", - headers={"HX-Request": "true"} + headers={"HX-Request": "true"}, ) # Should get 403 Forbidden @@ -84,9 +76,7 @@ class TestUnauthenticatedUserPermissions: class TestRegularUserPermissions: """Tests for regular (non-moderator) user permission guards.""" - def test_regular_user_cannot_approve_submission( - self, auth_page: Page, live_server, db - ): + def test_regular_user_cannot_approve_submission(self, auth_page: Page, live_server, db): """Test that regular users cannot approve submissions.""" from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType @@ -114,7 +104,7 @@ class TestRegularUserPermissions: submission_type="EDIT", changes={"description": "Test change"}, reason="Permission test", - status="PENDING" + status="PENDING", ) try: @@ -126,9 +116,7 @@ class TestRegularUserPermissions: # If somehow on dashboard, verify no approve button if "dashboard" in current_url: - submission_row = auth_page.locator( - f'[data-submission-id="{submission.pk}"]' - ) + submission_row = auth_page.locator(f'[data-submission-id="{submission.pk}"]') if submission_row.is_visible(): approve_btn = submission_row.get_by_role("button", name="Approve") expect(approve_btn).not_to_be_visible() @@ -136,7 +124,7 @@ class TestRegularUserPermissions: # Try direct POST - should be denied response = auth_page.request.post( f"{live_server.url}/core/fsm/moderation/editsubmission/{submission.pk}/transition/transition_to_approved/", - headers={"HX-Request": "true"} + headers={"HX-Request": "true"}, ) # Should be denied (403 or 302 redirect) @@ -149,9 +137,7 @@ class TestRegularUserPermissions: finally: submission.delete() - def test_regular_user_cannot_change_park_status( - self, auth_page: Page, live_server, db - ): + def test_regular_user_cannot_change_park_status(self, auth_page: Page, live_server, db): """Test that regular users cannot change park status.""" from apps.parks.models import Park @@ -163,18 +149,16 @@ class TestRegularUserPermissions: auth_page.wait_for_load_state("networkidle") # Status action buttons should NOT be visible to regular user - status_actions = auth_page.locator('[data-park-status-actions]') + status_actions = auth_page.locator("[data-park-status-actions]") if status_actions.is_visible(): - close_temp_btn = status_actions.get_by_role( - "button", name="Close Temporarily" - ) + close_temp_btn = status_actions.get_by_role("button", name="Close Temporarily") expect(close_temp_btn).not_to_be_visible() # Try direct POST - should be denied response = auth_page.request.post( f"{live_server.url}/core/fsm/parks/park/{park.pk}/transition/transition_to_closed_temp/", - headers={"HX-Request": "true"} + headers={"HX-Request": "true"}, ) # Should be denied @@ -184,9 +168,7 @@ class TestRegularUserPermissions: park.refresh_from_db() assert park.status == "OPERATING" - def test_regular_user_cannot_change_ride_status( - self, auth_page: Page, live_server, db - ): + def test_regular_user_cannot_change_ride_status(self, auth_page: Page, live_server, db): """Test that regular users cannot change ride status.""" from apps.rides.models import Ride @@ -194,24 +176,20 @@ class TestRegularUserPermissions: if not ride: pytest.skip("No operating ride available") - auth_page.goto( - f"{live_server.url}/parks/{ride.park.slug}/rides/{ride.slug}/" - ) + auth_page.goto(f"{live_server.url}/parks/{ride.park.slug}/rides/{ride.slug}/") auth_page.wait_for_load_state("networkidle") # Status action buttons should NOT be visible to regular user - status_actions = auth_page.locator('[data-ride-status-actions]') + status_actions = auth_page.locator("[data-ride-status-actions]") if status_actions.is_visible(): - close_temp_btn = status_actions.get_by_role( - "button", name="Close Temporarily" - ) + close_temp_btn = status_actions.get_by_role("button", name="Close Temporarily") expect(close_temp_btn).not_to_be_visible() # Try direct POST - should be denied response = auth_page.request.post( f"{live_server.url}/core/fsm/rides/ride/{ride.pk}/transition/transition_to_closed_temp/", - headers={"HX-Request": "true"} + headers={"HX-Request": "true"}, ) # Should be denied @@ -225,9 +203,7 @@ class TestRegularUserPermissions: class TestModeratorPermissions: """Tests for moderator-specific permission guards.""" - def test_moderator_can_approve_submission( - self, mod_page: Page, live_server, db - ): + def test_moderator_can_approve_submission(self, mod_page: Page, live_server, db): """Test that moderators CAN see and use approve button.""" from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType @@ -240,11 +216,7 @@ class TestModeratorPermissions: # Create a pending submission user = User.objects.filter(username="testuser").first() if not user: - user = User.objects.create_user( - username="testuser", - email="testuser@example.com", - password="testpass123" - ) + user = User.objects.create_user(username="testuser", email="testuser@example.com", password="testpass123") park = Park.objects.first() if not park: @@ -259,7 +231,7 @@ class TestModeratorPermissions: submission_type="EDIT", changes={"description": "Test change for moderator"}, reason="Moderator permission test", - status="PENDING" + status="PENDING", ) try: @@ -267,9 +239,7 @@ class TestModeratorPermissions: mod_page.wait_for_load_state("networkidle") # Moderator should be able to see the submission - submission_row = mod_page.locator( - f'[data-submission-id="{submission.pk}"]' - ) + submission_row = mod_page.locator(f'[data-submission-id="{submission.pk}"]') if submission_row.is_visible(): # Should see approve button @@ -279,9 +249,7 @@ class TestModeratorPermissions: finally: submission.delete() - def test_moderator_can_change_park_status( - self, mod_page: Page, live_server, db - ): + def test_moderator_can_change_park_status(self, mod_page: Page, live_server, db): """Test that moderators CAN see and use park status change buttons.""" from apps.parks.models import Park @@ -293,18 +261,14 @@ class TestModeratorPermissions: mod_page.wait_for_load_state("networkidle") # Status action buttons SHOULD be visible to moderator - status_actions = mod_page.locator('[data-park-status-actions]') + status_actions = mod_page.locator("[data-park-status-actions]") if status_actions.is_visible(): # Should see close temporarily button - close_temp_btn = status_actions.get_by_role( - "button", name="Close Temporarily" - ) + close_temp_btn = status_actions.get_by_role("button", name="Close Temporarily") expect(close_temp_btn).to_be_visible() - def test_moderator_cannot_access_admin_only_transitions( - self, mod_page: Page, live_server, db - ): + def test_moderator_cannot_access_admin_only_transitions(self, mod_page: Page, live_server, db): """Test that moderators CANNOT access admin-only transitions.""" # This test verifies that certain transitions require admin privileges # Specific transitions depend on the FSM configuration @@ -327,22 +291,18 @@ class TestModeratorPermissions: # Check for admin-only buttons (if any are configured) # The specific buttons that should be hidden depend on the FSM configuration - status_actions = mod_page.locator('[data-park-status-actions]') + status_actions = mod_page.locator("[data-park-status-actions]") # If there are admin-only transitions, verify they're hidden # This is a placeholder - actual admin-only transitions depend on configuration - admin_only_btn = status_actions.get_by_role( - "button", name="Force Delete" # Example admin-only action - ) + admin_only_btn = status_actions.get_by_role("button", name="Force Delete") # Example admin-only action expect(admin_only_btn).not_to_be_visible() class TestPermissionDeniedErrorHandling: """Tests for error handling when permission is denied.""" - def test_permission_denied_shows_error_toast( - self, auth_page: Page, live_server, db - ): + def test_permission_denied_shows_error_toast(self, auth_page: Page, live_server, db): """Test that permission denied errors show appropriate toast.""" from apps.parks.models import Park @@ -355,9 +315,12 @@ class TestPermissionDeniedErrorHandling: auth_page.wait_for_load_state("networkidle") # Make the request programmatically with HTMX header - response = auth_page.evaluate(""" + response = auth_page.evaluate( + """ async () => { - const response = await fetch('/core/fsm/parks/park/""" + str(park.pk) + """/transition/transition_to_closed_temp/', { + const response = await fetch('/core/fsm/parks/park/""" + + str(park.pk) + + """/transition/transition_to_closed_temp/', { method: 'POST', headers: { 'HX-Request': 'true', @@ -370,18 +333,17 @@ class TestPermissionDeniedErrorHandling: hxTrigger: response.headers.get('HX-Trigger') }; } - """) + """ + ) # Check if error toast was triggered - if response and response.get('status') in [400, 403]: - hx_trigger = response.get('hxTrigger') + if response and response.get("status") in [400, 403]: + hx_trigger = response.get("hxTrigger") if hx_trigger: - assert 'showToast' in hx_trigger - assert 'error' in hx_trigger.lower() or 'denied' in hx_trigger.lower() + assert "showToast" in hx_trigger + assert "error" in hx_trigger.lower() or "denied" in hx_trigger.lower() - def test_database_state_unchanged_on_permission_denied( - self, auth_page: Page, live_server, db - ): + def test_database_state_unchanged_on_permission_denied(self, auth_page: Page, live_server, db): """Test that database state is unchanged when permission is denied.""" from apps.parks.models import Park @@ -395,9 +357,12 @@ class TestPermissionDeniedErrorHandling: auth_page.goto(f"{live_server.url}/parks/{park.slug}/") auth_page.wait_for_load_state("networkidle") - auth_page.evaluate(""" + auth_page.evaluate( + """ async () => { - await fetch('/core/fsm/parks/park/""" + str(park.pk) + """/transition/transition_to_closed_temp/', { + await fetch('/core/fsm/parks/park/""" + + str(park.pk) + + """/transition/transition_to_closed_temp/', { method: 'POST', headers: { 'HX-Request': 'true', @@ -406,7 +371,8 @@ class TestPermissionDeniedErrorHandling: credentials: 'include' }); } - """) + """ + ) # Verify database state did NOT change park.refresh_from_db() @@ -416,9 +382,7 @@ class TestPermissionDeniedErrorHandling: class TestTransitionButtonVisibility: """Tests for correct transition button visibility based on permissions and state.""" - def test_transition_button_hidden_when_state_invalid( - self, mod_page: Page, live_server, db - ): + def test_transition_button_hidden_when_state_invalid(self, mod_page: Page, live_server, db): """Test that transition buttons are hidden when the current state is invalid.""" from apps.parks.models import Park @@ -430,7 +394,7 @@ class TestTransitionButtonVisibility: mod_page.goto(f"{live_server.url}/parks/{park.slug}/") mod_page.wait_for_load_state("networkidle") - status_actions = mod_page.locator('[data-park-status-actions]') + status_actions = mod_page.locator("[data-park-status-actions]") # Reopen button should NOT be visible for operating park # (can't reopen something that's already operating) @@ -442,9 +406,7 @@ class TestTransitionButtonVisibility: demolish_btn = status_actions.get_by_role("button", name="Mark as Demolished") expect(demolish_btn).not_to_be_visible() - def test_correct_buttons_shown_for_closed_temp_state( - self, mod_page: Page, live_server, db - ): + def test_correct_buttons_shown_for_closed_temp_state(self, mod_page: Page, live_server, db): """Test that correct buttons are shown for temporarily closed state.""" from apps.parks.models import Park @@ -461,21 +423,17 @@ class TestTransitionButtonVisibility: mod_page.goto(f"{live_server.url}/parks/{park.slug}/") mod_page.wait_for_load_state("networkidle") - status_actions = mod_page.locator('[data-park-status-actions]') + status_actions = mod_page.locator("[data-park-status-actions]") # Reopen button SHOULD be visible reopen_btn = status_actions.get_by_role("button", name="Reopen") expect(reopen_btn).to_be_visible() # Close Temporarily should NOT be visible (already closed) - close_temp_btn = status_actions.get_by_role( - "button", name="Close Temporarily" - ) + close_temp_btn = status_actions.get_by_role("button", name="Close Temporarily") expect(close_temp_btn).not_to_be_visible() - def test_correct_buttons_shown_for_closed_perm_state( - self, mod_page: Page, live_server, db - ): + def test_correct_buttons_shown_for_closed_perm_state(self, mod_page: Page, live_server, db): """Test that correct buttons are shown for permanently closed state.""" from apps.parks.models import Park @@ -492,7 +450,7 @@ class TestTransitionButtonVisibility: mod_page.goto(f"{live_server.url}/parks/{park.slug}/") mod_page.wait_for_load_state("networkidle") - status_actions = mod_page.locator('[data-park-status-actions]') + status_actions = mod_page.locator("[data-park-status-actions]") # Demolish/Relocate buttons SHOULD be visible demolish_btn = status_actions.get_by_role("button", name="Mark as Demolished") diff --git a/backend/tests/e2e/test_moderation_fsm.py b/backend/tests/e2e/test_moderation_fsm.py index 66e60892..92fa5238 100644 --- a/backend/tests/e2e/test_moderation_fsm.py +++ b/backend/tests/e2e/test_moderation_fsm.py @@ -30,10 +30,7 @@ def pending_submission(db): User = get_user_model() # Get or create test user - user, _ = User.objects.get_or_create( - username="testsubmitter", - defaults={"email": "testsubmitter@example.com"} - ) + user, _ = User.objects.get_or_create(username="testsubmitter", defaults={"email": "testsubmitter@example.com"}) user.set_password("testpass123") user.save() @@ -51,7 +48,7 @@ def pending_submission(db): submission_type="EDIT", changes={"description": "Updated park description for testing"}, reason="E2E test submission", - status="PENDING" + status="PENDING", ) yield submission @@ -73,8 +70,7 @@ def pending_photo_submission(db): # Get or create test user user, _ = User.objects.get_or_create( - username="testphotosubmitter", - defaults={"email": "testphotosubmitter@example.com"} + username="testphotosubmitter", defaults={"email": "testphotosubmitter@example.com"} ) user.set_password("testpass123") user.save() @@ -89,6 +85,7 @@ def pending_photo_submission(db): # Check if CloudflareImage model exists and has entries try: from django_cloudflareimages_toolkit.models import CloudflareImage + photo = CloudflareImage.objects.first() if not photo: pytest.skip("No CloudflareImage available for testing") @@ -96,12 +93,7 @@ def pending_photo_submission(db): pytest.skip("CloudflareImage not available") submission = PhotoSubmission.objects.create( - user=user, - content_type=content_type, - object_id=park.pk, - photo=photo, - caption="E2E test photo", - status="PENDING" + user=user, content_type=content_type, object_id=park.pk, photo=photo, caption="E2E test photo", status="PENDING" ) yield submission @@ -113,9 +105,7 @@ def pending_photo_submission(db): class TestEditSubmissionTransitions: """Tests for EditSubmission FSM transitions via HTMX.""" - def test_submission_approve_transition_as_moderator( - self, mod_page: Page, pending_submission, live_server - ): + def test_submission_approve_transition_as_moderator(self, mod_page: Page, pending_submission, live_server): """Test approving an EditSubmission as a moderator.""" # Navigate to moderation dashboard mod_page.goto(f"{live_server.url}/moderation/dashboard/") @@ -127,7 +117,7 @@ class TestEditSubmissionTransitions: submission_row = mod_page.locator(f'[data-submission-id="{pending_submission.pk}"]') # Verify initial status is pending - status_badge = submission_row.locator('[data-status-badge]') + status_badge = submission_row.locator("[data-status-badge]") expect(status_badge).to_contain_text("Pending") # Click the approve button @@ -140,7 +130,7 @@ class TestEditSubmissionTransitions: approve_btn.click() # Wait for toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) expect(toast).to_contain_text("approved") @@ -151,9 +141,7 @@ class TestEditSubmissionTransitions: pending_submission.refresh_from_db() assert pending_submission.status == "APPROVED" - def test_submission_reject_transition_as_moderator( - self, mod_page: Page, pending_submission, live_server - ): + def test_submission_reject_transition_as_moderator(self, mod_page: Page, pending_submission, live_server): """Test rejecting an EditSubmission as a moderator.""" mod_page.goto(f"{live_server.url}/moderation/dashboard/") mod_page.wait_for_load_state("networkidle") @@ -161,7 +149,7 @@ class TestEditSubmissionTransitions: submission_row = mod_page.locator(f'[data-submission-id="{pending_submission.pk}"]') # Verify initial status - status_badge = submission_row.locator('[data-status-badge]') + status_badge = submission_row.locator("[data-status-badge]") expect(status_badge).to_contain_text("Pending") # Click reject button @@ -173,7 +161,7 @@ class TestEditSubmissionTransitions: reject_btn.click() # Verify toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) expect(toast).to_contain_text("rejected") @@ -184,9 +172,7 @@ class TestEditSubmissionTransitions: pending_submission.refresh_from_db() assert pending_submission.status == "REJECTED" - def test_submission_escalate_transition_as_moderator( - self, mod_page: Page, pending_submission, live_server - ): + def test_submission_escalate_transition_as_moderator(self, mod_page: Page, pending_submission, live_server): """Test escalating an EditSubmission as a moderator.""" mod_page.goto(f"{live_server.url}/moderation/dashboard/") mod_page.wait_for_load_state("networkidle") @@ -194,7 +180,7 @@ class TestEditSubmissionTransitions: submission_row = mod_page.locator(f'[data-submission-id="{pending_submission.pk}"]') # Verify initial status - status_badge = submission_row.locator('[data-status-badge]') + status_badge = submission_row.locator("[data-status-badge]") expect(status_badge).to_contain_text("Pending") # Click escalate button @@ -206,7 +192,7 @@ class TestEditSubmissionTransitions: escalate_btn.click() # Verify toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) expect(toast).to_contain_text("escalated") @@ -221,9 +207,7 @@ class TestEditSubmissionTransitions: class TestPhotoSubmissionTransitions: """Tests for PhotoSubmission FSM transitions via HTMX.""" - def test_photo_submission_approve_transition( - self, mod_page: Page, pending_photo_submission, live_server - ): + def test_photo_submission_approve_transition(self, mod_page: Page, pending_photo_submission, live_server): """Test approving a PhotoSubmission as a moderator.""" mod_page.goto(f"{live_server.url}/moderation/dashboard/") mod_page.wait_for_load_state("networkidle") @@ -234,9 +218,7 @@ class TestPhotoSubmissionTransitions: photos_tab.click() # Find the photo submission row - submission_row = mod_page.locator( - f'[data-photo-submission-id="{pending_photo_submission.pk}"]' - ) + submission_row = mod_page.locator(f'[data-photo-submission-id="{pending_photo_submission.pk}"]') if not submission_row.is_visible(): pytest.skip("Photo submission not visible in dashboard") @@ -248,7 +230,7 @@ class TestPhotoSubmissionTransitions: approve_btn.click() # Verify toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) expect(toast).to_contain_text("approved") @@ -256,9 +238,7 @@ class TestPhotoSubmissionTransitions: pending_photo_submission.refresh_from_db() assert pending_photo_submission.status == "APPROVED" - def test_photo_submission_reject_transition( - self, mod_page: Page, pending_photo_submission, live_server - ): + def test_photo_submission_reject_transition(self, mod_page: Page, pending_photo_submission, live_server): """Test rejecting a PhotoSubmission as a moderator.""" mod_page.goto(f"{live_server.url}/moderation/dashboard/") mod_page.wait_for_load_state("networkidle") @@ -269,9 +249,7 @@ class TestPhotoSubmissionTransitions: photos_tab.click() # Find the photo submission row - submission_row = mod_page.locator( - f'[data-photo-submission-id="{pending_photo_submission.pk}"]' - ) + submission_row = mod_page.locator(f'[data-photo-submission-id="{pending_photo_submission.pk}"]') if not submission_row.is_visible(): pytest.skip("Photo submission not visible in dashboard") @@ -283,7 +261,7 @@ class TestPhotoSubmissionTransitions: reject_btn.click() # Verify toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) expect(toast).to_contain_text("rejected") @@ -304,10 +282,7 @@ class TestModerationQueueTransitions: User = get_user_model() - user, _ = User.objects.get_or_create( - username="testflagger", - defaults={"email": "testflagger@example.com"} - ) + user, _ = User.objects.get_or_create(username="testflagger", defaults={"email": "testflagger@example.com"}) queue_item = ModerationQueue.objects.create( item_type="CONTENT_REVIEW", @@ -315,16 +290,14 @@ class TestModerationQueueTransitions: priority="MEDIUM", title="E2E Test Queue Item", description="Queue item for E2E testing", - flagged_by=user + flagged_by=user, ) yield queue_item queue_item.delete() - def test_moderation_queue_start_transition( - self, mod_page: Page, pending_queue_item, live_server - ): + def test_moderation_queue_start_transition(self, mod_page: Page, pending_queue_item, live_server): """Test starting work on a ModerationQueue item.""" mod_page.goto(f"{live_server.url}/moderation/queue/") mod_page.wait_for_load_state("networkidle") @@ -340,16 +313,14 @@ class TestModerationQueueTransitions: start_btn.click() # Verify status updated to IN_PROGRESS - status_badge = queue_row.locator('[data-status-badge]') + status_badge = queue_row.locator("[data-status-badge]") expect(status_badge).to_contain_text("In Progress", timeout=5000) # Verify database state pending_queue_item.refresh_from_db() assert pending_queue_item.status == "IN_PROGRESS" - def test_moderation_queue_complete_transition( - self, mod_page: Page, pending_queue_item, live_server - ): + def test_moderation_queue_complete_transition(self, mod_page: Page, pending_queue_item, live_server): """Test completing a ModerationQueue item.""" # First set status to IN_PROGRESS pending_queue_item.status = "IN_PROGRESS" @@ -370,7 +341,7 @@ class TestModerationQueueTransitions: complete_btn.click() # Verify toast and status - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) pending_queue_item.refresh_from_db() @@ -390,8 +361,7 @@ class TestBulkOperationTransitions: User = get_user_model() user, _ = User.objects.get_or_create( - username="testadmin", - defaults={"email": "testadmin@example.com", "is_staff": True} + username="testadmin", defaults={"email": "testadmin@example.com", "is_staff": True} ) operation = BulkOperation.objects.create( @@ -401,24 +371,20 @@ class TestBulkOperationTransitions: description="E2E Test Bulk Operation", parameters={"test": True}, created_by=user, - total_items=10 + total_items=10, ) yield operation operation.delete() - def test_bulk_operation_cancel_transition( - self, mod_page: Page, pending_bulk_operation, live_server - ): + def test_bulk_operation_cancel_transition(self, mod_page: Page, pending_bulk_operation, live_server): """Test canceling a BulkOperation.""" mod_page.goto(f"{live_server.url}/moderation/bulk-operations/") mod_page.wait_for_load_state("networkidle") # Find the operation row - operation_row = mod_page.locator( - f'[data-bulk-operation-id="{pending_bulk_operation.pk}"]' - ) + operation_row = mod_page.locator(f'[data-bulk-operation-id="{pending_bulk_operation.pk}"]') if not operation_row.is_visible(): pytest.skip("Bulk operation not visible") @@ -430,7 +396,7 @@ class TestBulkOperationTransitions: cancel_btn.click() # Verify toast - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) expect(toast).to_contain_text("cancel") @@ -442,16 +408,12 @@ class TestBulkOperationTransitions: class TestTransitionLoadingStates: """Tests for loading indicators during FSM transitions.""" - def test_loading_indicator_appears_during_transition( - self, mod_page: Page, pending_submission, live_server - ): + def test_loading_indicator_appears_during_transition(self, mod_page: Page, pending_submission, live_server): """Verify loading spinner appears during HTMX transition.""" mod_page.goto(f"{live_server.url}/moderation/dashboard/") mod_page.wait_for_load_state("networkidle") - submission_row = mod_page.locator( - f'[data-submission-id="{pending_submission.pk}"]' - ) + submission_row = mod_page.locator(f'[data-submission-id="{pending_submission.pk}"]') # Get approve button and associated loading indicator approve_btn = submission_row.get_by_role("button", name="Approve") @@ -466,8 +428,8 @@ class TestTransitionLoadingStates: # Check for htmx-indicator visibility (may be brief) # The indicator should become visible during the request - submission_row.locator('.htmx-indicator') + submission_row.locator(".htmx-indicator") # Wait for transition to complete - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) diff --git a/backend/tests/e2e/test_park_browsing.py b/backend/tests/e2e/test_park_browsing.py index c2690718..d8174d04 100644 --- a/backend/tests/e2e/test_park_browsing.py +++ b/backend/tests/e2e/test_park_browsing.py @@ -32,9 +32,7 @@ class TestParkListPage: first_park = parks_data[0] expect(page.get_by_text(first_park.name)).to_be_visible() - def test__park_list__click_park__navigates_to_detail( - self, page: Page, live_server, parks_data - ): + def test__park_list__click_park__navigates_to_detail(self, page: Page, live_server, parks_data): """Test clicking a park navigates to detail page.""" page.goto(f"{live_server.url}/parks/") @@ -51,9 +49,7 @@ class TestParkListPage: page.goto(f"{live_server.url}/parks/") # Find search input - search_input = page.locator( - "input[type='search'], input[name='q'], input[placeholder*='search' i]" - ) + search_input = page.locator("input[type='search'], input[name='q'], input[placeholder*='search' i]") if search_input.count() > 0: search_input.first.fill("E2E Test Park 0") @@ -83,9 +79,7 @@ class TestParkDetailPage: page.goto(f"{live_server.url}/parks/{park.slug}/") # Look for rides section/tab - page.locator( - "[data-testid='rides-section'], #rides, [role='tabpanel']" - ) + page.locator("[data-testid='rides-section'], #rides, [role='tabpanel']") # Or a rides tab rides_tab = page.get_by_role("tab", name="Rides") @@ -103,9 +97,7 @@ class TestParkDetailPage: page.goto(f"{live_server.url}/parks/{park.slug}/") # Status badge or indicator should be visible - status_indicator = page.locator( - ".status-badge, [data-testid='status'], .park-status" - ) + status_indicator = page.locator(".status-badge, [data-testid='status'], .park-status") expect(status_indicator.first).to_be_visible() @@ -118,9 +110,7 @@ class TestParkFiltering: page.goto(f"{live_server.url}/parks/") # Find status filter - status_filter = page.locator( - "select[name='status'], [data-testid='status-filter']" - ) + status_filter = page.locator("select[name='status'], [data-testid='status-filter']") if status_filter.count() > 0: status_filter.first.select_option("OPERATING") @@ -135,9 +125,7 @@ class TestParkFiltering: page.goto(f"{live_server.url}/parks/") # Find clear filters button - clear_btn = page.locator( - "[data-testid='clear-filters'], button:has-text('Clear')" - ) + clear_btn = page.locator("[data-testid='clear-filters'], button:has-text('Clear')") if clear_btn.count() > 0: clear_btn.first.click() @@ -164,9 +152,7 @@ class TestParkNavigation: expect(page).to_have_url("**/parks/**") - def test__back_button__returns_to_previous_page( - self, page: Page, live_server, parks_data - ): + def test__back_button__returns_to_previous_page(self, page: Page, live_server, parks_data): """Test browser back button returns to previous page.""" page.goto(f"{live_server.url}/parks/") diff --git a/backend/tests/e2e/test_park_ride_fsm.py b/backend/tests/e2e/test_park_ride_fsm.py index e67b1f43..25125616 100644 --- a/backend/tests/e2e/test_park_ride_fsm.py +++ b/backend/tests/e2e/test_park_ride_fsm.py @@ -26,11 +26,7 @@ def operating_park(db): from tests.factories import ParkFactory # Use factory to create a complete park - park = ParkFactory( - name="E2E Test Park", - slug="e2e-test-park", - status="OPERATING" - ) + park = ParkFactory(name="E2E Test Park", slug="e2e-test-park", status="OPERATING") yield park @@ -42,12 +38,7 @@ def operating_ride(db, operating_park): """Create an operating Ride for testing status transitions.""" from tests.factories import RideFactory - ride = RideFactory( - name="E2E Test Ride", - slug="e2e-test-ride", - park=operating_park, - status="OPERATING" - ) + ride = RideFactory(name="E2E Test Ride", slug="e2e-test-ride", park=operating_park, status="OPERATING") yield ride @@ -55,31 +46,25 @@ def operating_ride(db, operating_park): class TestParkStatusTransitions: """Tests for Park FSM status transitions via HTMX.""" - def test_park_close_temporarily_as_moderator( - self, mod_page: Page, operating_park, live_server - ): + def test_park_close_temporarily_as_moderator(self, mod_page: Page, operating_park, live_server): """Test closing a park temporarily as a moderator.""" mod_page.goto(f"{live_server.url}/parks/{operating_park.slug}/") mod_page.wait_for_load_state("networkidle") # Verify initial status badge shows Operating - status_section = mod_page.locator('[data-park-status-actions]') - status_badge = mod_page.locator('[data-status-badge]') + status_section = mod_page.locator("[data-park-status-actions]") + status_badge = mod_page.locator("[data-status-badge]") expect(status_badge).to_contain_text("Operating") # Find and click "Close Temporarily" button - close_temp_btn = status_section.get_by_role( - "button", name="Close Temporarily" - ) + close_temp_btn = status_section.get_by_role("button", name="Close Temporarily") if not close_temp_btn.is_visible(): # May be in a dropdown menu - actions_dropdown = status_section.locator('[data-actions-dropdown]') + actions_dropdown = status_section.locator("[data-actions-dropdown]") if actions_dropdown.is_visible(): actions_dropdown.click() - close_temp_btn = mod_page.get_by_role( - "button", name="Close Temporarily" - ) + close_temp_btn = mod_page.get_by_role("button", name="Close Temporarily") # Handle confirmation dialog mod_page.on("dialog", lambda dialog: dialog.accept()) @@ -87,7 +72,7 @@ class TestParkStatusTransitions: close_temp_btn.click() # Verify toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Verify status badge updated @@ -97,9 +82,7 @@ class TestParkStatusTransitions: operating_park.refresh_from_db() assert operating_park.status == "CLOSED_TEMP" - def test_park_reopen_from_closed_temp( - self, mod_page: Page, operating_park, live_server - ): + def test_park_reopen_from_closed_temp(self, mod_page: Page, operating_park, live_server): """Test reopening a temporarily closed park.""" # First close the park temporarily operating_park.status = "CLOSED_TEMP" @@ -109,18 +92,18 @@ class TestParkStatusTransitions: mod_page.wait_for_load_state("networkidle") # Verify initial status badge shows Temporarily Closed - status_badge = mod_page.locator('[data-status-badge]') + status_badge = mod_page.locator("[data-status-badge]") expect(status_badge).to_contain_text("Temporarily Closed") # Find and click "Reopen" button - status_section = mod_page.locator('[data-park-status-actions]') + status_section = mod_page.locator("[data-park-status-actions]") reopen_btn = status_section.get_by_role("button", name="Reopen") mod_page.on("dialog", lambda dialog: dialog.accept()) reopen_btn.click() # Verify toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Verify status badge updated @@ -130,34 +113,28 @@ class TestParkStatusTransitions: operating_park.refresh_from_db() assert operating_park.status == "OPERATING" - def test_park_close_permanently_as_moderator( - self, mod_page: Page, operating_park, live_server - ): + def test_park_close_permanently_as_moderator(self, mod_page: Page, operating_park, live_server): """Test closing a park permanently as a moderator.""" mod_page.goto(f"{live_server.url}/parks/{operating_park.slug}/") mod_page.wait_for_load_state("networkidle") - status_section = mod_page.locator('[data-park-status-actions]') - status_badge = mod_page.locator('[data-status-badge]') + status_section = mod_page.locator("[data-park-status-actions]") + status_badge = mod_page.locator("[data-status-badge]") # Find and click "Close Permanently" button - close_perm_btn = status_section.get_by_role( - "button", name="Close Permanently" - ) + close_perm_btn = status_section.get_by_role("button", name="Close Permanently") if not close_perm_btn.is_visible(): - actions_dropdown = status_section.locator('[data-actions-dropdown]') + actions_dropdown = status_section.locator("[data-actions-dropdown]") if actions_dropdown.is_visible(): actions_dropdown.click() - close_perm_btn = mod_page.get_by_role( - "button", name="Close Permanently" - ) + close_perm_btn = mod_page.get_by_role("button", name="Close Permanently") mod_page.on("dialog", lambda dialog: dialog.accept()) close_perm_btn.click() # Verify toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Verify status badge updated @@ -167,9 +144,7 @@ class TestParkStatusTransitions: operating_park.refresh_from_db() assert operating_park.status == "CLOSED_PERM" - def test_park_demolish_from_closed_perm( - self, mod_page: Page, operating_park, live_server - ): + def test_park_demolish_from_closed_perm(self, mod_page: Page, operating_park, live_server): """Test transitioning a permanently closed park to demolished.""" # Set park to permanently closed operating_park.status = "CLOSED_PERM" @@ -179,25 +154,23 @@ class TestParkStatusTransitions: mod_page.goto(f"{live_server.url}/parks/{operating_park.slug}/") mod_page.wait_for_load_state("networkidle") - status_section = mod_page.locator('[data-park-status-actions]') - status_badge = mod_page.locator('[data-status-badge]') + status_section = mod_page.locator("[data-park-status-actions]") + status_badge = mod_page.locator("[data-status-badge]") # Find and click "Mark as Demolished" button demolish_btn = status_section.get_by_role("button", name="Mark as Demolished") if not demolish_btn.is_visible(): - actions_dropdown = status_section.locator('[data-actions-dropdown]') + actions_dropdown = status_section.locator("[data-actions-dropdown]") if actions_dropdown.is_visible(): actions_dropdown.click() - demolish_btn = mod_page.get_by_role( - "button", name="Mark as Demolished" - ) + demolish_btn = mod_page.get_by_role("button", name="Mark as Demolished") mod_page.on("dialog", lambda dialog: dialog.accept()) demolish_btn.click() # Verify toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Verify status badge updated @@ -207,24 +180,18 @@ class TestParkStatusTransitions: operating_park.refresh_from_db() assert operating_park.status == "DEMOLISHED" - def test_park_available_transitions_update( - self, mod_page: Page, operating_park, live_server - ): + def test_park_available_transitions_update(self, mod_page: Page, operating_park, live_server): """Test that available transitions update based on current state.""" mod_page.goto(f"{live_server.url}/parks/{operating_park.slug}/") mod_page.wait_for_load_state("networkidle") - status_section = mod_page.locator('[data-park-status-actions]') + status_section = mod_page.locator("[data-park-status-actions]") # Operating park should have Close Temporarily and Close Permanently - expect( - status_section.get_by_role("button", name="Close Temporarily") - ).to_be_visible() + expect(status_section.get_by_role("button", name="Close Temporarily")).to_be_visible() # Should NOT have Reopen (not applicable for Operating state) - expect( - status_section.get_by_role("button", name="Reopen") - ).not_to_be_visible() + expect(status_section.get_by_role("button", name="Reopen")).not_to_be_visible() # Now close temporarily and verify buttons change operating_park.status = "CLOSED_TEMP" @@ -234,37 +201,29 @@ class TestParkStatusTransitions: mod_page.wait_for_load_state("networkidle") # Now should have Reopen button - expect( - status_section.get_by_role("button", name="Reopen") - ).to_be_visible() + expect(status_section.get_by_role("button", name="Reopen")).to_be_visible() class TestRideStatusTransitions: """Tests for Ride FSM status transitions via HTMX.""" - def test_ride_close_temporarily_as_moderator( - self, mod_page: Page, operating_ride, live_server - ): + def test_ride_close_temporarily_as_moderator(self, mod_page: Page, operating_ride, live_server): """Test closing a ride temporarily as a moderator.""" - mod_page.goto( - f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/" - ) + mod_page.goto(f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/") mod_page.wait_for_load_state("networkidle") - status_section = mod_page.locator('[data-ride-status-actions]') - status_badge = mod_page.locator('[data-status-badge]') + status_section = mod_page.locator("[data-ride-status-actions]") + status_badge = mod_page.locator("[data-status-badge]") expect(status_badge).to_contain_text("Operating") # Find and click "Close Temporarily" button - close_temp_btn = status_section.get_by_role( - "button", name="Close Temporarily" - ) + close_temp_btn = status_section.get_by_role("button", name="Close Temporarily") mod_page.on("dialog", lambda dialog: dialog.accept()) close_temp_btn.click() # Verify toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Verify status badge updated @@ -274,23 +233,19 @@ class TestRideStatusTransitions: operating_ride.refresh_from_db() assert operating_ride.status == "CLOSED_TEMP" - def test_ride_mark_sbno_as_moderator( - self, mod_page: Page, operating_ride, live_server - ): + def test_ride_mark_sbno_as_moderator(self, mod_page: Page, operating_ride, live_server): """Test marking a ride as Standing But Not Operating (SBNO).""" - mod_page.goto( - f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/" - ) + mod_page.goto(f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/") mod_page.wait_for_load_state("networkidle") - status_section = mod_page.locator('[data-ride-status-actions]') - status_badge = mod_page.locator('[data-status-badge]') + status_section = mod_page.locator("[data-ride-status-actions]") + status_badge = mod_page.locator("[data-status-badge]") # Find and click "Mark SBNO" button sbno_btn = status_section.get_by_role("button", name="Mark SBNO") if not sbno_btn.is_visible(): - actions_dropdown = status_section.locator('[data-actions-dropdown]') + actions_dropdown = status_section.locator("[data-actions-dropdown]") if actions_dropdown.is_visible(): actions_dropdown.click() sbno_btn = mod_page.get_by_role("button", name="Mark SBNO") @@ -299,7 +254,7 @@ class TestRideStatusTransitions: sbno_btn.click() # Verify toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Verify status badge updated @@ -309,21 +264,17 @@ class TestRideStatusTransitions: operating_ride.refresh_from_db() assert operating_ride.status == "SBNO" - def test_ride_reopen_from_closed_temp( - self, mod_page: Page, operating_ride, live_server - ): + def test_ride_reopen_from_closed_temp(self, mod_page: Page, operating_ride, live_server): """Test reopening a temporarily closed ride.""" # First close the ride temporarily operating_ride.status = "CLOSED_TEMP" operating_ride.save() - mod_page.goto( - f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/" - ) + mod_page.goto(f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/") mod_page.wait_for_load_state("networkidle") - status_section = mod_page.locator('[data-ride-status-actions]') - status_badge = mod_page.locator('[data-status-badge]') + status_section = mod_page.locator("[data-ride-status-actions]") + status_badge = mod_page.locator("[data-status-badge]") # Find and click "Reopen" button reopen_btn = status_section.get_by_role("button", name="Reopen") @@ -332,7 +283,7 @@ class TestRideStatusTransitions: reopen_btn.click() # Verify toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Verify status badge updated @@ -342,36 +293,28 @@ class TestRideStatusTransitions: operating_ride.refresh_from_db() assert operating_ride.status == "OPERATING" - def test_ride_close_permanently_as_moderator( - self, mod_page: Page, operating_ride, live_server - ): + def test_ride_close_permanently_as_moderator(self, mod_page: Page, operating_ride, live_server): """Test closing a ride permanently as a moderator.""" - mod_page.goto( - f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/" - ) + mod_page.goto(f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/") mod_page.wait_for_load_state("networkidle") - status_section = mod_page.locator('[data-ride-status-actions]') - status_badge = mod_page.locator('[data-status-badge]') + status_section = mod_page.locator("[data-ride-status-actions]") + status_badge = mod_page.locator("[data-status-badge]") # Find and click "Close Permanently" button - close_perm_btn = status_section.get_by_role( - "button", name="Close Permanently" - ) + close_perm_btn = status_section.get_by_role("button", name="Close Permanently") if not close_perm_btn.is_visible(): - actions_dropdown = status_section.locator('[data-actions-dropdown]') + actions_dropdown = status_section.locator("[data-actions-dropdown]") if actions_dropdown.is_visible(): actions_dropdown.click() - close_perm_btn = mod_page.get_by_role( - "button", name="Close Permanently" - ) + close_perm_btn = mod_page.get_by_role("button", name="Close Permanently") mod_page.on("dialog", lambda dialog: dialog.accept()) close_perm_btn.click() # Verify toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Verify status badge updated @@ -381,41 +324,33 @@ class TestRideStatusTransitions: operating_ride.refresh_from_db() assert operating_ride.status == "CLOSED_PERM" - def test_ride_demolish_from_closed_perm( - self, mod_page: Page, operating_ride, live_server - ): + def test_ride_demolish_from_closed_perm(self, mod_page: Page, operating_ride, live_server): """Test transitioning a permanently closed ride to demolished.""" # Set ride to permanently closed operating_ride.status = "CLOSED_PERM" operating_ride.closing_date = date.today() - timedelta(days=365) operating_ride.save() - mod_page.goto( - f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/" - ) + mod_page.goto(f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/") mod_page.wait_for_load_state("networkidle") - status_section = mod_page.locator('[data-ride-status-actions]') - status_badge = mod_page.locator('[data-status-badge]') + status_section = mod_page.locator("[data-ride-status-actions]") + status_badge = mod_page.locator("[data-status-badge]") # Find and click "Mark as Demolished" button - demolish_btn = status_section.get_by_role( - "button", name="Mark as Demolished" - ) + demolish_btn = status_section.get_by_role("button", name="Mark as Demolished") if not demolish_btn.is_visible(): - actions_dropdown = status_section.locator('[data-actions-dropdown]') + actions_dropdown = status_section.locator("[data-actions-dropdown]") if actions_dropdown.is_visible(): actions_dropdown.click() - demolish_btn = mod_page.get_by_role( - "button", name="Mark as Demolished" - ) + demolish_btn = mod_page.get_by_role("button", name="Mark as Demolished") mod_page.on("dialog", lambda dialog: dialog.accept()) demolish_btn.click() # Verify toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Verify status badge updated @@ -425,41 +360,33 @@ class TestRideStatusTransitions: operating_ride.refresh_from_db() assert operating_ride.status == "DEMOLISHED" - def test_ride_relocate_from_closed_perm( - self, mod_page: Page, operating_ride, live_server - ): + def test_ride_relocate_from_closed_perm(self, mod_page: Page, operating_ride, live_server): """Test transitioning a permanently closed ride to relocated.""" # Set ride to permanently closed operating_ride.status = "CLOSED_PERM" operating_ride.closing_date = date.today() - timedelta(days=365) operating_ride.save() - mod_page.goto( - f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/" - ) + mod_page.goto(f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/") mod_page.wait_for_load_state("networkidle") - status_section = mod_page.locator('[data-ride-status-actions]') - status_badge = mod_page.locator('[data-status-badge]') + status_section = mod_page.locator("[data-ride-status-actions]") + status_badge = mod_page.locator("[data-status-badge]") # Find and click "Mark as Relocated" button - relocate_btn = status_section.get_by_role( - "button", name="Mark as Relocated" - ) + relocate_btn = status_section.get_by_role("button", name="Mark as Relocated") if not relocate_btn.is_visible(): - actions_dropdown = status_section.locator('[data-actions-dropdown]') + actions_dropdown = status_section.locator("[data-actions-dropdown]") if actions_dropdown.is_visible(): actions_dropdown.click() - relocate_btn = mod_page.get_by_role( - "button", name="Mark as Relocated" - ) + relocate_btn = mod_page.get_by_role("button", name="Mark as Relocated") mod_page.on("dialog", lambda dialog: dialog.accept()) relocate_btn.click() # Verify toast notification - toast = mod_page.locator('[data-toast]') + toast = mod_page.locator("[data-toast]") expect(toast).to_be_visible(timeout=5000) # Verify status badge updated @@ -473,28 +400,22 @@ class TestRideStatusTransitions: class TestRideClosingWorkflow: """Tests for the special CLOSING status workflow with automatic transitions.""" - def test_ride_set_closing_with_future_date( - self, mod_page: Page, operating_ride, live_server - ): + def test_ride_set_closing_with_future_date(self, mod_page: Page, operating_ride, live_server): """Test setting a ride to CLOSING status with a future closing date.""" - mod_page.goto( - f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/" - ) + mod_page.goto(f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/") mod_page.wait_for_load_state("networkidle") - status_section = mod_page.locator('[data-ride-status-actions]') + status_section = mod_page.locator("[data-ride-status-actions]") # Find and click "Set Closing" button - set_closing_btn = status_section.get_by_role( - "button", name="Set Closing" - ) + set_closing_btn = status_section.get_by_role("button", name="Set Closing") if set_closing_btn.is_visible(): mod_page.on("dialog", lambda dialog: dialog.accept()) set_closing_btn.click() # Verify status badge updated - status_badge = mod_page.locator('[data-status-badge]') + status_badge = mod_page.locator("[data-status-badge]") expect(status_badge).to_contain_text("Closing", timeout=5000) # Verify database state @@ -503,9 +424,7 @@ class TestRideClosingWorkflow: else: pytest.skip("Set Closing button not available") - def test_ride_closing_shows_countdown( - self, mod_page: Page, operating_ride, live_server - ): + def test_ride_closing_shows_countdown(self, mod_page: Page, operating_ride, live_server): """Test that a ride in CLOSING status shows a countdown to closing date.""" # Set ride to CLOSING with future date future_date = date.today() + timedelta(days=30) @@ -513,37 +432,31 @@ class TestRideClosingWorkflow: operating_ride.closing_date = future_date operating_ride.save() - mod_page.goto( - f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/" - ) + mod_page.goto(f"{live_server.url}/parks/{operating_ride.park.slug}/rides/{operating_ride.slug}/") mod_page.wait_for_load_state("networkidle") # Verify closing countdown is displayed - closing_info = mod_page.locator('[data-closing-countdown]') + closing_info = mod_page.locator("[data-closing-countdown]") if closing_info.is_visible(): expect(closing_info).to_contain_text("30") else: # May just show the status badge - status_badge = mod_page.locator('[data-status-badge]') + status_badge = mod_page.locator("[data-status-badge]") expect(status_badge).to_contain_text("Closing") class TestStatusBadgeStyling: """Tests for correct status badge styling based on state.""" - def test_operating_status_badge_style( - self, mod_page: Page, operating_park, live_server - ): + def test_operating_status_badge_style(self, mod_page: Page, operating_park, live_server): """Test that Operating status has correct green styling.""" mod_page.goto(f"{live_server.url}/parks/{operating_park.slug}/") mod_page.wait_for_load_state("networkidle") - status_badge = mod_page.locator('[data-status-badge]') + status_badge = mod_page.locator("[data-status-badge]") expect(status_badge).to_have_class(re.compile(r"bg-green|text-green|success")) - def test_closed_temp_status_badge_style( - self, mod_page: Page, operating_park, live_server - ): + def test_closed_temp_status_badge_style(self, mod_page: Page, operating_park, live_server): """Test that Temporarily Closed status has correct yellow/warning styling.""" operating_park.status = "CLOSED_TEMP" operating_park.save() @@ -551,12 +464,10 @@ class TestStatusBadgeStyling: mod_page.goto(f"{live_server.url}/parks/{operating_park.slug}/") mod_page.wait_for_load_state("networkidle") - status_badge = mod_page.locator('[data-status-badge]') + status_badge = mod_page.locator("[data-status-badge]") expect(status_badge).to_have_class(re.compile(r"bg-yellow|text-yellow|warning")) - def test_closed_perm_status_badge_style( - self, mod_page: Page, operating_park, live_server - ): + def test_closed_perm_status_badge_style(self, mod_page: Page, operating_park, live_server): """Test that Permanently Closed status has correct red/danger styling.""" operating_park.status = "CLOSED_PERM" operating_park.save() @@ -564,12 +475,10 @@ class TestStatusBadgeStyling: mod_page.goto(f"{live_server.url}/parks/{operating_park.slug}/") mod_page.wait_for_load_state("networkidle") - status_badge = mod_page.locator('[data-status-badge]') + status_badge = mod_page.locator("[data-status-badge]") expect(status_badge).to_have_class(re.compile(r"bg-red|text-red|danger")) - def test_demolished_status_badge_style( - self, mod_page: Page, operating_park, live_server - ): + def test_demolished_status_badge_style(self, mod_page: Page, operating_park, live_server): """Test that Demolished status has correct gray styling.""" operating_park.status = "DEMOLISHED" operating_park.save() @@ -577,5 +486,5 @@ class TestStatusBadgeStyling: mod_page.goto(f"{live_server.url}/parks/{operating_park.slug}/") mod_page.wait_for_load_state("networkidle") - status_badge = mod_page.locator('[data-status-badge]') + status_badge = mod_page.locator("[data-status-badge]") expect(status_badge).to_have_class(re.compile(r"bg-gray|text-gray|muted")) diff --git a/backend/tests/e2e/test_review_submission.py b/backend/tests/e2e/test_review_submission.py index 36afd674..37b79fe8 100644 --- a/backend/tests/e2e/test_review_submission.py +++ b/backend/tests/e2e/test_review_submission.py @@ -24,9 +24,7 @@ class TestReviewSubmission: reviews_tab.click() # Click write review button - write_review = auth_page.locator( - "button:has-text('Write Review'), a:has-text('Write Review')" - ) + write_review = auth_page.locator("button:has-text('Write Review'), a:has-text('Write Review')") if write_review.count() > 0: write_review.first.click() @@ -36,9 +34,7 @@ class TestReviewSubmission: expect(auth_page.locator("input[name='title'], textarea[name='title']").first).to_be_visible() expect(auth_page.locator("textarea[name='content'], textarea[name='review']").first).to_be_visible() - def test__review_submission__valid_data__creates_review( - self, auth_page: Page, live_server, parks_data - ): + def test__review_submission__valid_data__creates_review(self, auth_page: Page, live_server, parks_data): """Test submitting a valid review creates it.""" park = parks_data[0] auth_page.goto(f"{live_server.url}/parks/{park.slug}/") @@ -48,9 +44,7 @@ class TestReviewSubmission: if reviews_tab.count() > 0: reviews_tab.click() - write_review = auth_page.locator( - "button:has-text('Write Review'), a:has-text('Write Review')" - ) + write_review = auth_page.locator("button:has-text('Write Review'), a:has-text('Write Review')") if write_review.count() > 0: write_review.first.click() @@ -63,9 +57,7 @@ class TestReviewSubmission: # May be radio buttons or stars auth_page.locator("input[name='rating'][value='5']").click() - auth_page.locator("input[name='title'], textarea[name='title']").first.fill( - "E2E Test Review Title" - ) + auth_page.locator("input[name='title'], textarea[name='title']").first.fill("E2E Test Review Title") auth_page.locator("textarea[name='content'], textarea[name='review']").first.fill( "This is an E2E test review content." ) @@ -75,9 +67,7 @@ class TestReviewSubmission: # Should show success or redirect auth_page.wait_for_timeout(500) - def test__review_submission__missing_rating__shows_error( - self, auth_page: Page, live_server, parks_data - ): + def test__review_submission__missing_rating__shows_error(self, auth_page: Page, live_server, parks_data): """Test submitting review without rating shows error.""" park = parks_data[0] auth_page.goto(f"{live_server.url}/parks/{park.slug}/") @@ -86,20 +76,14 @@ class TestReviewSubmission: if reviews_tab.count() > 0: reviews_tab.click() - write_review = auth_page.locator( - "button:has-text('Write Review'), a:has-text('Write Review')" - ) + write_review = auth_page.locator("button:has-text('Write Review'), a:has-text('Write Review')") if write_review.count() > 0: write_review.first.click() # Fill only title and content, skip rating - auth_page.locator("input[name='title'], textarea[name='title']").first.fill( - "Missing Rating Review" - ) - auth_page.locator("textarea[name='content'], textarea[name='review']").first.fill( - "Review without rating" - ) + auth_page.locator("input[name='title'], textarea[name='title']").first.fill("Missing Rating Review") + auth_page.locator("textarea[name='content'], textarea[name='review']").first.fill("Review without rating") auth_page.get_by_role("button", name="Submit").click() @@ -123,9 +107,7 @@ class TestReviewDisplay: reviews_tab.click() # Reviews should be displayed - reviews_section = page.locator( - "[data-testid='reviews-list'], .reviews-list, .review-item" - ) + reviews_section = page.locator("[data-testid='reviews-list'], .reviews-list, .review-item") if reviews_section.count() > 0: expect(reviews_section.first).to_be_visible() @@ -136,9 +118,7 @@ class TestReviewDisplay: page.goto(f"{page.url}") # Stay on current page after fixture # Rating should be visible (stars, number, etc.) - rating = page.locator( - ".rating, .stars, [data-testid='rating']" - ) + rating = page.locator(".rating, .stars, [data-testid='rating']") if rating.count() > 0: expect(rating.first).to_be_visible() @@ -153,9 +133,7 @@ class TestReviewDisplay: reviews_tab.click() # Author name should be visible in review - author = page.locator( - ".review-author, .author, [data-testid='author']" - ) + author = page.locator(".review-author, .author, [data-testid='author']") if author.count() > 0: expect(author.first).to_be_visible() @@ -170,9 +148,7 @@ class TestReviewEditing: # Navigate to reviews after creating one # Look for edit button on own review - edit_button = auth_page.locator( - "button:has-text('Edit'), a:has-text('Edit Review')" - ) + edit_button = auth_page.locator("button:has-text('Edit'), a:has-text('Edit Review')") if edit_button.count() > 0: expect(edit_button.first).to_be_visible() @@ -180,17 +156,13 @@ class TestReviewEditing: def test__edit_review__updates_content(self, auth_page: Page, live_server, test_review): """Test editing review updates the content.""" # Find and click edit - edit_button = auth_page.locator( - "button:has-text('Edit'), a:has-text('Edit Review')" - ) + edit_button = auth_page.locator("button:has-text('Edit'), a:has-text('Edit Review')") if edit_button.count() > 0: edit_button.first.click() # Update content - content_field = auth_page.locator( - "textarea[name='content'], textarea[name='review']" - ) + content_field = auth_page.locator("textarea[name='content'], textarea[name='review']") content_field.first.fill("Updated review content from E2E test") auth_page.get_by_role("button", name="Save").click() @@ -204,9 +176,7 @@ class TestReviewEditing: class TestReviewModeration: """E2E tests for review moderation.""" - def test__moderator__sees_moderation_actions( - self, mod_page: Page, live_server, parks_data - ): + def test__moderator__sees_moderation_actions(self, mod_page: Page, live_server, parks_data): """Test moderator sees moderation actions on reviews.""" park = parks_data[0] mod_page.goto(f"{live_server.url}/parks/{park.slug}/") @@ -216,9 +186,7 @@ class TestReviewModeration: reviews_tab.click() # Moderator should see moderation buttons - mod_actions = mod_page.locator( - "button:has-text('Remove'), button:has-text('Flag'), [data-testid='mod-action']" - ) + mod_actions = mod_page.locator("button:has-text('Remove'), button:has-text('Flag'), [data-testid='mod-action']") if mod_actions.count() > 0: expect(mod_actions.first).to_be_visible() @@ -259,16 +227,12 @@ class TestReviewVoting: reviews_tab.click() # Look for helpful/upvote buttons - vote_buttons = page.locator( - "button:has-text('Helpful'), button[aria-label*='helpful'], .vote-button" - ) + vote_buttons = page.locator("button:has-text('Helpful'), button[aria-label*='helpful'], .vote-button") if vote_buttons.count() > 0: expect(vote_buttons.first).to_be_visible() - def test__vote__authenticated__registers_vote( - self, auth_page: Page, live_server, parks_data - ): + def test__vote__authenticated__registers_vote(self, auth_page: Page, live_server, parks_data): """Test authenticated user can vote on review.""" park = parks_data[0] auth_page.goto(f"{live_server.url}/parks/{park.slug}/") @@ -277,9 +241,7 @@ class TestReviewVoting: if reviews_tab.count() > 0: reviews_tab.click() - helpful_button = auth_page.locator( - "button:has-text('Helpful'), button[aria-label*='helpful']" - ) + helpful_button = auth_page.locator("button:has-text('Helpful'), button[aria-label*='helpful']") if helpful_button.count() > 0: helpful_button.first.click() @@ -298,34 +260,24 @@ class TestRideReviews: page.goto(f"{live_server.url}/rides/{ride.slug}/") # Reviews section should be present - reviews_section = page.locator( - "[data-testid='reviews'], #reviews, .reviews-section" - ) + reviews_section = page.locator("[data-testid='reviews'], #reviews, .reviews-section") if reviews_section.count() > 0: expect(reviews_section.first).to_be_visible() - def test__ride_review__includes_ride_experience_fields( - self, auth_page: Page, live_server, rides_data - ): + def test__ride_review__includes_ride_experience_fields(self, auth_page: Page, live_server, rides_data): """Test ride review form includes experience fields.""" ride = rides_data[0] auth_page.goto(f"{live_server.url}/rides/{ride.slug}/") - write_review = auth_page.locator( - "button:has-text('Write Review'), a:has-text('Write Review')" - ) + write_review = auth_page.locator("button:has-text('Write Review'), a:has-text('Write Review')") if write_review.count() > 0: write_review.first.click() # Ride-specific fields - intensity_field = auth_page.locator( - "select[name='intensity'], input[name='intensity']" - ) - auth_page.locator( - "input[name='wait_time'], select[name='wait_time']" - ) + intensity_field = auth_page.locator("select[name='intensity'], input[name='intensity']") + auth_page.locator("input[name='wait_time'], select[name='wait_time']") # At least one experience field should be present if intensity_field.count() > 0: @@ -345,9 +297,7 @@ class TestReviewFiltering: if reviews_tab.count() > 0: reviews_tab.click() - sort_select = page.locator( - "select[name='sort'], [data-testid='sort-reviews']" - ) + sort_select = page.locator("select[name='sort'], [data-testid='sort-reviews']") if sort_select.count() > 0: sort_select.first.select_option("date") @@ -362,9 +312,7 @@ class TestReviewFiltering: if reviews_tab.count() > 0: reviews_tab.click() - rating_filter = page.locator( - "select[name='rating'], [data-testid='rating-filter']" - ) + rating_filter = page.locator("select[name='rating'], [data-testid='rating-filter']") if rating_filter.count() > 0: rating_filter.first.select_option("5") diff --git a/backend/tests/e2e/test_user_registration.py b/backend/tests/e2e/test_user_registration.py index 7bd5d980..38020a38 100644 --- a/backend/tests/e2e/test_user_registration.py +++ b/backend/tests/e2e/test_user_registration.py @@ -44,9 +44,7 @@ class TestUserRegistration: # Should redirect to success page or login page.wait_for_url("**/*", timeout=5000) - def test__registration__duplicate_username__shows_error( - self, page: Page, live_server, regular_user - ): + def test__registration__duplicate_username__shows_error(self, page: Page, live_server, regular_user): """Test registration with duplicate username shows error.""" page.goto(f"{live_server.url}/accounts/signup/") @@ -100,9 +98,7 @@ class TestUserLogin: expect(page.get_by_label("Password")).to_be_visible() expect(page.get_by_role("button", name="Sign In")).to_be_visible() - def test__login__valid_credentials__authenticates( - self, page: Page, live_server, regular_user - ): + def test__login__valid_credentials__authenticates(self, page: Page, live_server, regular_user): """Test login with valid credentials authenticates user.""" page.goto(f"{live_server.url}/accounts/login/") @@ -130,9 +126,7 @@ class TestUserLogin: """Test login page has remember me checkbox.""" page.goto(f"{live_server.url}/accounts/login/") - remember_me = page.locator( - "input[name='remember'], input[type='checkbox'][id*='remember']" - ) + remember_me = page.locator("input[name='remember'], input[type='checkbox'][id*='remember']") if remember_me.count() > 0: expect(remember_me.first).to_be_visible() @@ -147,9 +141,7 @@ class TestUserLogout: # User is already logged in via auth_page fixture # Find and click logout button/link - logout = auth_page.locator( - "a[href*='logout'], button:has-text('Log Out'), button:has-text('Sign Out')" - ) + logout = auth_page.locator("a[href*='logout'], button:has-text('Log Out'), button:has-text('Sign Out')") if logout.count() > 0: logout.first.click() @@ -172,14 +164,10 @@ class TestPasswordReset: """Test password reset page displays the form.""" page.goto(f"{live_server.url}/accounts/password/reset/") - email_input = page.locator( - "input[type='email'], input[name='email']" - ) + email_input = page.locator("input[type='email'], input[name='email']") expect(email_input.first).to_be_visible() - def test__password_reset__valid_email__shows_confirmation( - self, page: Page, live_server, regular_user - ): + def test__password_reset__valid_email__shows_confirmation(self, page: Page, live_server, regular_user): """Test password reset with valid email shows confirmation.""" page.goto(f"{live_server.url}/accounts/password/reset/") @@ -192,9 +180,7 @@ class TestPasswordReset: page.wait_for_timeout(500) # Look for success message or confirmation page - success = page.locator( - ".success, .alert-success, [role='alert']" - ) + success = page.locator(".success, .alert-success, [role='alert']") # Or check URL changed to done page if success.count() == 0: @@ -216,9 +202,7 @@ class TestUserProfile: """Test profile page has edit profile link/button.""" auth_page.goto(f"{live_server.url}/accounts/profile/") - edit_link = auth_page.locator( - "a[href*='edit'], button:has-text('Edit')" - ) + edit_link = auth_page.locator("a[href*='edit'], button:has-text('Edit')") if edit_link.count() > 0: expect(edit_link.first).to_be_visible() @@ -228,9 +212,7 @@ class TestUserProfile: auth_page.goto(f"{live_server.url}/accounts/profile/edit/") # Find bio/about field if present - bio_field = auth_page.locator( - "textarea[name='bio'], textarea[name='about']" - ) + bio_field = auth_page.locator("textarea[name='bio'], textarea[name='about']") if bio_field.count() > 0: bio_field.first.fill("Updated bio from E2E test") @@ -245,18 +227,14 @@ class TestUserProfile: class TestProtectedRoutes: """E2E tests for protected route access.""" - def test__protected_route__unauthenticated__redirects_to_login( - self, page: Page, live_server - ): + def test__protected_route__unauthenticated__redirects_to_login(self, page: Page, live_server): """Test accessing protected route redirects to login.""" page.goto(f"{live_server.url}/accounts/profile/") # Should redirect to login expect(page).to_have_url("**/login/**") - def test__protected_route__authenticated__allows_access( - self, auth_page: Page, live_server - ): + def test__protected_route__authenticated__allows_access(self, auth_page: Page, live_server): """Test authenticated user can access protected routes.""" auth_page.goto(f"{live_server.url}/accounts/profile/") @@ -270,9 +248,7 @@ class TestProtectedRoutes: # Should show login or forbidden # Admin login page or 403 - def test__moderator_route__moderator__allows_access( - self, mod_page: Page, live_server - ): + def test__moderator_route__moderator__allows_access(self, mod_page: Page, live_server): """Test moderator can access moderation routes.""" mod_page.goto(f"{live_server.url}/moderation/") diff --git a/backend/tests/factories.py b/backend/tests/factories.py index b97d8dc1..0f436c2b 100644 --- a/backend/tests/factories.py +++ b/backend/tests/factories.py @@ -202,9 +202,7 @@ class RideFactory(DjangoModelFactory): manufacturer = factory.SubFactory(ManufacturerCompanyFactory) designer = factory.SubFactory(DesignerCompanyFactory) ride_model = factory.SubFactory(RideModelFactory) - park_area = factory.SubFactory( - ParkAreaFactory, park=factory.SelfAttribute("..park") - ) + park_area = factory.SubFactory(ParkAreaFactory, park=factory.SelfAttribute("..park")) @factory.post_generation def create_location(obj, create, extracted, **kwargs): @@ -297,9 +295,7 @@ class Traits: """Trait for closed parks.""" return { "status": "CLOSED_PERM", - "closing_date": factory.Faker( - "date_between", start_date="-10y", end_date="today" - ), + "closing_date": factory.Faker("date_between", start_date="-10y", end_date="today"), } @staticmethod @@ -310,11 +306,7 @@ class Traits: @staticmethod def recent_submission(): """Trait for recent submissions.""" - return { - "submitted_at": factory.Faker( - "date_time_between", start_date="-7d", end_date="now" - ) - } + return {"submitted_at": factory.Faker("date_time_between", start_date="-7d", end_date="now")} # Specialized factories for testing scenarios @@ -378,11 +370,13 @@ class CloudflareImageFactory(DjangoModelFactory): @factory.lazy_attribute def expires_at(self): from django.utils import timezone + return timezone.now() + timezone.timedelta(days=365) @factory.lazy_attribute def uploaded_at(self): from django.utils import timezone + return timezone.now() diff --git a/backend/tests/forms/test_park_forms.py b/backend/tests/forms/test_park_forms.py index 680e66ed..b14208de 100644 --- a/backend/tests/forms/test_park_forms.py +++ b/backend/tests/forms/test_park_forms.py @@ -4,7 +4,6 @@ Tests for Park forms. Following Django styleguide pattern: test______ """ - import pytest from django.test import TestCase diff --git a/backend/tests/forms/test_ride_forms.py b/backend/tests/forms/test_ride_forms.py index da34eb96..c448e985 100644 --- a/backend/tests/forms/test_ride_forms.py +++ b/backend/tests/forms/test_ride_forms.py @@ -4,7 +4,6 @@ Tests for Ride forms. Following Django styleguide pattern: test______ """ - import pytest from django.test import TestCase diff --git a/backend/tests/integration/test_fsm_transition_view.py b/backend/tests/integration/test_fsm_transition_view.py index b2991f96..9efe80b4 100644 --- a/backend/tests/integration/test_fsm_transition_view.py +++ b/backend/tests/integration/test_fsm_transition_view.py @@ -28,27 +28,16 @@ class TestFSMTransitionViewHTMX(TestCase): def setUpTestData(cls): """Set up test data for all tests in this class.""" # Create regular user - cls.user = User.objects.create_user( - username="testuser", - email="testuser@example.com", - password="testpass123" - ) + cls.user = User.objects.create_user(username="testuser", email="testuser@example.com", password="testpass123") # Create moderator user cls.moderator = User.objects.create_user( - username="moderator", - email="moderator@example.com", - password="modpass123", - is_staff=True + username="moderator", email="moderator@example.com", password="modpass123", is_staff=True ) # Create admin user cls.admin = User.objects.create_user( - username="admin", - email="admin@example.com", - password="adminpass123", - is_staff=True, - is_superuser=True + username="admin", email="admin@example.com", password="adminpass123", is_staff=True, is_superuser=True ) def setUp(self): @@ -76,7 +65,7 @@ class TestFSMTransitionViewHTMX(TestCase): submission_type="EDIT", changes={"description": "Test change"}, reason="Integration test", - status="PENDING" + status="PENDING", ) url = reverse( @@ -85,15 +74,12 @@ class TestFSMTransitionViewHTMX(TestCase): "app_label": "moderation", "model_name": "editsubmission", "pk": submission.pk, - "transition_name": "transition_to_approved" - } + "transition_name": "transition_to_approved", + }, ) # Make request with HTMX header - response = self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + response = self.client.post(url, HTTP_HX_REQUEST="true") # Should return 200 OK self.assertEqual(response.status_code, 200) @@ -129,7 +115,7 @@ class TestFSMTransitionViewHTMX(TestCase): submission_type="EDIT", changes={"description": "Test change non-htmx"}, reason="Integration test non-htmx", - status="PENDING" + status="PENDING", ) url = reverse( @@ -138,8 +124,8 @@ class TestFSMTransitionViewHTMX(TestCase): "app_label": "moderation", "model_name": "editsubmission", "pk": submission.pk, - "transition_name": "transition_to_approved" - } + "transition_name": "transition_to_approved", + }, ) # Make request WITHOUT HTMX header @@ -177,7 +163,7 @@ class TestFSMTransitionViewHTMX(TestCase): submission_type="EDIT", changes={"description": "Test partial"}, reason="Partial test", - status="PENDING" + status="PENDING", ) url = reverse( @@ -186,14 +172,11 @@ class TestFSMTransitionViewHTMX(TestCase): "app_label": "moderation", "model_name": "editsubmission", "pk": submission.pk, - "transition_name": "transition_to_approved" - } + "transition_name": "transition_to_approved", + }, ) - response = self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + response = self.client.post(url, HTTP_HX_REQUEST="true") # Response should contain HTML (partial template) self.assertIn("text/html", response["Content-Type"]) @@ -217,14 +200,11 @@ class TestFSMTransitionViewHTMX(TestCase): "app_label": "parks", "model_name": "park", "pk": park.pk, - "transition_name": "transition_to_closed_temp" - } + "transition_name": "transition_to_closed_temp", + }, ) - response = self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + response = self.client.post(url, HTTP_HX_REQUEST="true") # Parse HX-Trigger header trigger_data = json.loads(response["HX-Trigger"]) @@ -249,14 +229,11 @@ class TestFSMTransitionViewHTMX(TestCase): "app_label": "nonexistent", "model_name": "fakemodel", "pk": 1, - "transition_name": "fake_transition" - } + "transition_name": "fake_transition", + }, ) - response = self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + response = self.client.post(url, HTTP_HX_REQUEST="true") # Should return 404 self.assertEqual(response.status_code, 404) @@ -282,14 +259,11 @@ class TestFSMTransitionViewHTMX(TestCase): "app_label": "parks", "model_name": "park", "pk": park.pk, - "transition_name": "nonexistent_transition" - } + "transition_name": "nonexistent_transition", + }, ) - response = self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + response = self.client.post(url, HTTP_HX_REQUEST="true") # Should return 400 Bad Request self.assertEqual(response.status_code, 400) @@ -320,7 +294,7 @@ class TestFSMTransitionViewHTMX(TestCase): submission_type="EDIT", changes={"description": "Permission test"}, reason="Permission test", - status="PENDING" + status="PENDING", ) url = reverse( @@ -329,14 +303,11 @@ class TestFSMTransitionViewHTMX(TestCase): "app_label": "moderation", "model_name": "editsubmission", "pk": submission.pk, - "transition_name": "transition_to_approved" - } + "transition_name": "transition_to_approved", + }, ) - response = self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + response = self.client.post(url, HTTP_HX_REQUEST="true") # Should return 400 or 403 (permission denied) self.assertIn(response.status_code, [400, 403]) @@ -355,10 +326,7 @@ class TestFSMTransitionViewParkModel(TestCase): @classmethod def setUpTestData(cls): cls.moderator = User.objects.create_user( - username="mod_park", - email="mod_park@example.com", - password="modpass123", - is_staff=True + username="mod_park", email="mod_park@example.com", password="modpass123", is_staff=True ) def setUp(self): @@ -379,14 +347,11 @@ class TestFSMTransitionViewParkModel(TestCase): "app_label": "parks", "model_name": "park", "pk": park.pk, - "transition_name": "transition_to_closed_temp" - } + "transition_name": "transition_to_closed_temp", + }, ) - response = self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + response = self.client.post(url, HTTP_HX_REQUEST="true") self.assertEqual(response.status_code, 200) @@ -416,14 +381,11 @@ class TestFSMTransitionViewParkModel(TestCase): "app_label": "parks", "model_name": "park", "pk": park.pk, - "transition_name": "transition_to_operating" - } + "transition_name": "transition_to_operating", + }, ) - response = self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + response = self.client.post(url, HTTP_HX_REQUEST="true") self.assertEqual(response.status_code, 200) @@ -445,14 +407,11 @@ class TestFSMTransitionViewParkModel(TestCase): "app_label": "parks", "model_name": "park", "slug": park.slug, - "transition_name": "transition_to_closed_temp" - } + "transition_name": "transition_to_closed_temp", + }, ) - response = self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + response = self.client.post(url, HTTP_HX_REQUEST="true") self.assertEqual(response.status_code, 200) @@ -471,10 +430,7 @@ class TestFSMTransitionViewRideModel(TestCase): @classmethod def setUpTestData(cls): cls.moderator = User.objects.create_user( - username="mod_ride", - email="mod_ride@example.com", - password="modpass123", - is_staff=True + username="mod_ride", email="mod_ride@example.com", password="modpass123", is_staff=True ) def setUp(self): @@ -495,14 +451,11 @@ class TestFSMTransitionViewRideModel(TestCase): "app_label": "rides", "model_name": "ride", "pk": ride.pk, - "transition_name": "transition_to_closed_temp" - } + "transition_name": "transition_to_closed_temp", + }, ) - response = self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + response = self.client.post(url, HTTP_HX_REQUEST="true") self.assertEqual(response.status_code, 200) @@ -524,18 +477,10 @@ class TestFSMTransitionViewRideModel(TestCase): url = reverse( "core:fsm_transition", - kwargs={ - "app_label": "rides", - "model_name": "ride", - "pk": ride.pk, - "transition_name": "transition_to_sbno" - } + kwargs={"app_label": "rides", "model_name": "ride", "pk": ride.pk, "transition_name": "transition_to_sbno"}, ) - response = self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + response = self.client.post(url, HTTP_HX_REQUEST="true") self.assertEqual(response.status_code, 200) @@ -553,17 +498,10 @@ class TestFSMTransitionViewModerationModels(TestCase): @classmethod def setUpTestData(cls): - cls.user = User.objects.create_user( - username="submitter", - email="submitter@example.com", - password="testpass123" - ) + cls.user = User.objects.create_user(username="submitter", email="submitter@example.com", password="testpass123") cls.moderator = User.objects.create_user( - username="mod_moderation", - email="mod_moderation@example.com", - password="modpass123", - is_staff=True + username="mod_moderation", email="mod_moderation@example.com", password="modpass123", is_staff=True ) def setUp(self): @@ -588,7 +526,7 @@ class TestFSMTransitionViewModerationModels(TestCase): submission_type="EDIT", changes={"description": "Approve test"}, reason="Approve test", - status="PENDING" + status="PENDING", ) url = reverse( @@ -597,14 +535,11 @@ class TestFSMTransitionViewModerationModels(TestCase): "app_label": "moderation", "model_name": "editsubmission", "pk": submission.pk, - "transition_name": "transition_to_approved" - } + "transition_name": "transition_to_approved", + }, ) - response = self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + response = self.client.post(url, HTTP_HX_REQUEST="true") self.assertEqual(response.status_code, 200) @@ -633,7 +568,7 @@ class TestFSMTransitionViewModerationModels(TestCase): submission_type="EDIT", changes={"description": "Reject test"}, reason="Reject test", - status="PENDING" + status="PENDING", ) url = reverse( @@ -642,14 +577,11 @@ class TestFSMTransitionViewModerationModels(TestCase): "app_label": "moderation", "model_name": "editsubmission", "pk": submission.pk, - "transition_name": "transition_to_rejected" - } + "transition_name": "transition_to_rejected", + }, ) - response = self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + response = self.client.post(url, HTTP_HX_REQUEST="true") self.assertEqual(response.status_code, 200) @@ -678,7 +610,7 @@ class TestFSMTransitionViewModerationModels(TestCase): submission_type="EDIT", changes={"description": "Escalate test"}, reason="Escalate test", - status="PENDING" + status="PENDING", ) url = reverse( @@ -687,14 +619,11 @@ class TestFSMTransitionViewModerationModels(TestCase): "app_label": "moderation", "model_name": "editsubmission", "pk": submission.pk, - "transition_name": "transition_to_escalated" - } + "transition_name": "transition_to_escalated", + }, ) - response = self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + response = self.client.post(url, HTTP_HX_REQUEST="true") self.assertEqual(response.status_code, 200) @@ -712,16 +641,11 @@ class TestFSMTransitionViewStateLog(TestCase): @classmethod def setUpTestData(cls): cls.user = User.objects.create_user( - username="submitter_log", - email="submitter_log@example.com", - password="testpass123" + username="submitter_log", email="submitter_log@example.com", password="testpass123" ) cls.moderator = User.objects.create_user( - username="mod_log", - email="mod_log@example.com", - password="modpass123", - is_staff=True + username="mod_log", email="mod_log@example.com", password="modpass123", is_staff=True ) def setUp(self): @@ -748,13 +672,12 @@ class TestFSMTransitionViewStateLog(TestCase): submission_type="EDIT", changes={"description": "StateLog test"}, reason="StateLog test", - status="PENDING" + status="PENDING", ) # Count existing StateLog entries initial_log_count = StateLog.objects.filter( - content_type=ContentType.objects.get_for_model(EditSubmission), - object_id=submission.pk + content_type=ContentType.objects.get_for_model(EditSubmission), object_id=submission.pk ).count() url = reverse( @@ -763,28 +686,23 @@ class TestFSMTransitionViewStateLog(TestCase): "app_label": "moderation", "model_name": "editsubmission", "pk": submission.pk, - "transition_name": "transition_to_approved" - } + "transition_name": "transition_to_approved", + }, ) - self.client.post( - url, - HTTP_HX_REQUEST="true" - ) + self.client.post(url, HTTP_HX_REQUEST="true") # Check that a new StateLog entry was created new_log_count = StateLog.objects.filter( - content_type=ContentType.objects.get_for_model(EditSubmission), - object_id=submission.pk + content_type=ContentType.objects.get_for_model(EditSubmission), object_id=submission.pk ).count() self.assertEqual(new_log_count, initial_log_count + 1) # Verify the StateLog entry details latest_log = StateLog.objects.filter( - content_type=ContentType.objects.get_for_model(EditSubmission), - object_id=submission.pk - ).latest('timestamp') + content_type=ContentType.objects.get_for_model(EditSubmission), object_id=submission.pk + ).latest("timestamp") self.assertEqual(latest_log.state, "APPROVED") self.assertEqual(latest_log.by, self.moderator) diff --git a/backend/tests/integration/test_fsm_transition_workflow.py b/backend/tests/integration/test_fsm_transition_workflow.py index 88e86a3a..4518b513 100644 --- a/backend/tests/integration/test_fsm_transition_workflow.py +++ b/backend/tests/integration/test_fsm_transition_workflow.py @@ -9,6 +9,7 @@ from datetime import date, timedelta import pytest from django.test import TestCase +from django_fsm import TransitionNotAllowed from tests.factories import ( ParkAreaFactory, @@ -55,7 +56,7 @@ class TestParkFSMTransitions(TestCase): user = UserFactory() # This should fail - can't reopen permanently closed park - with pytest.raises(Exception): + with pytest.raises((TransitionNotAllowed, ValueError)): park.open(user=user) diff --git a/backend/tests/integration/test_park_creation_workflow.py b/backend/tests/integration/test_park_creation_workflow.py index 6cb72cdc..5e5c3d3e 100644 --- a/backend/tests/integration/test_park_creation_workflow.py +++ b/backend/tests/integration/test_park_creation_workflow.py @@ -138,9 +138,7 @@ class TestParkReviewWorkflow(TestCase): ParkReviewFactory(park=park, user=user2, rating=10, is_published=True) # Calculate average - avg = park.reviews.filter(is_published=True).values_list( - "rating", flat=True - ) + avg = park.reviews.filter(is_published=True).values_list("rating", flat=True) calculated_avg = sum(avg) / len(avg) assert calculated_avg == 9.0 diff --git a/backend/tests/integration/test_photo_upload_workflow.py b/backend/tests/integration/test_photo_upload_workflow.py index 12c91ec8..8a805fa2 100644 --- a/backend/tests/integration/test_photo_upload_workflow.py +++ b/backend/tests/integration/test_photo_upload_workflow.py @@ -31,9 +31,7 @@ class TestParkPhotoUploadWorkflow(TestCase): @patch("apps.parks.services.media_service.MediaService.process_image") @patch("apps.parks.services.media_service.MediaService.generate_default_caption") @patch("apps.parks.services.media_service.MediaService.extract_exif_date") - def test__upload_photo__creates_pending_photo( - self, mock_exif, mock_caption, mock_process, mock_validate - ): + def test__upload_photo__creates_pending_photo(self, mock_exif, mock_caption, mock_process, mock_validate): """Test uploading photo creates a pending photo.""" mock_validate.return_value = (True, None) mock_process.return_value = Mock() diff --git a/backend/tests/managers/test_core_managers.py b/backend/tests/managers/test_core_managers.py index dacbf904..cf3514b2 100644 --- a/backend/tests/managers/test_core_managers.py +++ b/backend/tests/managers/test_core_managers.py @@ -4,7 +4,6 @@ Tests for Core managers and querysets. Following Django styleguide pattern: test______ """ - import pytest from django.test import TestCase @@ -23,6 +22,7 @@ class TestBaseQuerySet(TestCase): """Test active filters by is_active field if present.""" # Using User model which has is_active from django.contrib.auth import get_user_model + User = get_user_model() active_user = User.objects.create_user( @@ -43,6 +43,7 @@ class TestBaseQuerySet(TestCase): # Created just now, should be in recent from apps.parks.models import Park + result = Park.objects.recent(days=30) assert park in result @@ -53,6 +54,7 @@ class TestBaseQuerySet(TestCase): park2 = ParkFactory(name="Kings Island") from apps.parks.models import Park + result = Park.objects.search(query="Cedar") assert park1 in result @@ -64,6 +66,7 @@ class TestBaseQuerySet(TestCase): park2 = ParkFactory() from apps.parks.models import Park + result = Park.objects.search(query="") assert park1 in result @@ -81,6 +84,7 @@ class TestLocationQuerySet(TestCase): # Location is created by factory post_generation from apps.parks.models import Park + # This tests the pattern - actual filtering depends on location setup result = Park.objects.all() @@ -259,11 +263,10 @@ class TestBaseManager(TestCase): def test__active__delegates_to_queryset(self): """Test active method delegates to queryset.""" from django.contrib.auth import get_user_model + User = get_user_model() - user = User.objects.create_user( - username="test", email="test@test.com", password="test", is_active=True - ) + user = User.objects.create_user(username="test", email="test@test.com", password="test", is_active=True) # BaseManager's active method should work result = User.objects.filter(is_active=True) diff --git a/backend/tests/managers/test_park_managers.py b/backend/tests/managers/test_park_managers.py index df55ce95..49f1d6be 100644 --- a/backend/tests/managers/test_park_managers.py +++ b/backend/tests/managers/test_park_managers.py @@ -4,7 +4,6 @@ Tests for Park managers and querysets. Following Django styleguide pattern: test______ """ - import pytest from django.test import TestCase diff --git a/backend/tests/middleware/test_contract_validation_middleware.py b/backend/tests/middleware/test_contract_validation_middleware.py index e0952c58..35e67049 100644 --- a/backend/tests/middleware/test_contract_validation_middleware.py +++ b/backend/tests/middleware/test_contract_validation_middleware.py @@ -93,9 +93,7 @@ class TestContractValidationMiddlewareFilterValidation(TestCase): self.middleware.enabled = True @patch.object(ContractValidationMiddleware, "_log_contract_violation") - def test__validate_filter_metadata__valid_categorical_filters__no_violation( - self, mock_log - ): + def test__validate_filter_metadata__valid_categorical_filters__no_violation(self, mock_log): """Test valid categorical filter format doesn't log violation.""" request = self.factory.get("/api/v1/parks/filter-options/") valid_data = { @@ -118,11 +116,7 @@ class TestContractValidationMiddlewareFilterValidation(TestCase): def test__validate_filter_metadata__string_options__logs_violation(self, mock_log): """Test string filter options logs contract violation.""" request = self.factory.get("/api/v1/parks/filter-options/") - invalid_data = { - "categorical": { - "status": ["OPERATING", "CLOSED"] # Strings instead of objects - } - } + invalid_data = {"categorical": {"status": ["OPERATING", "CLOSED"]}} # Strings instead of objects response = JsonResponse(invalid_data) self.middleware.process_response(request, response) @@ -133,18 +127,10 @@ class TestContractValidationMiddlewareFilterValidation(TestCase): assert any("CATEGORICAL_OPTION_IS_STRING" in arg for arg in call_args) @patch.object(ContractValidationMiddleware, "_log_contract_violation") - def test__validate_filter_metadata__missing_value_property__logs_violation( - self, mock_log - ): + def test__validate_filter_metadata__missing_value_property__logs_violation(self, mock_log): """Test filter option missing 'value' property logs violation.""" request = self.factory.get("/api/v1/parks/filter-options/") - invalid_data = { - "categorical": { - "status": [ - {"label": "Operating", "count": 10} # Missing 'value' - ] - } - } + invalid_data = {"categorical": {"status": [{"label": "Operating", "count": 10}]}} # Missing 'value' response = JsonResponse(invalid_data) self.middleware.process_response(request, response) @@ -154,18 +140,10 @@ class TestContractValidationMiddlewareFilterValidation(TestCase): assert any("MISSING_VALUE_PROPERTY" in arg for arg in call_args) @patch.object(ContractValidationMiddleware, "_log_contract_violation") - def test__validate_filter_metadata__missing_label_property__logs_violation( - self, mock_log - ): + def test__validate_filter_metadata__missing_label_property__logs_violation(self, mock_log): """Test filter option missing 'label' property logs violation.""" request = self.factory.get("/api/v1/parks/filter-options/") - invalid_data = { - "categorical": { - "status": [ - {"value": "OPERATING", "count": 10} # Missing 'label' - ] - } - } + invalid_data = {"categorical": {"status": [{"value": "OPERATING", "count": 10}]}} # Missing 'label' response = JsonResponse(invalid_data) self.middleware.process_response(request, response) @@ -188,11 +166,7 @@ class TestContractValidationMiddlewareRangeValidation(TestCase): def test__validate_range_filter__valid_range__no_violation(self, mock_log): """Test valid range filter format doesn't log violation.""" request = self.factory.get("/api/v1/rides/filter-options/") - valid_data = { - "ranges": { - "height": {"min": 0, "max": 500, "step": 10, "unit": "ft"} - } - } + valid_data = {"ranges": {"height": {"min": 0, "max": 500, "step": 10, "unit": "ft"}}} response = JsonResponse(valid_data) self.middleware.process_response(request, response) @@ -205,11 +179,7 @@ class TestContractValidationMiddlewareRangeValidation(TestCase): def test__validate_range_filter__missing_min_max__logs_violation(self, mock_log): """Test range filter missing min/max logs violation.""" request = self.factory.get("/api/v1/rides/filter-options/") - invalid_data = { - "ranges": { - "height": {"step": 10} # Missing 'min' and 'max' - } - } + invalid_data = {"ranges": {"height": {"step": 10}}} # Missing 'min' and 'max' response = JsonResponse(invalid_data) self.middleware.process_response(request, response) @@ -232,11 +202,7 @@ class TestContractValidationMiddlewareHybridValidation(TestCase): def test__validate_hybrid_response__valid_strategy__no_violation(self, mock_log): """Test valid hybrid response strategy doesn't log violation.""" request = self.factory.get("/api/v1/parks/hybrid/") - valid_data = { - "strategy": "client_side", - "data": [], - "filter_metadata": {} - } + valid_data = {"strategy": "client_side", "data": [], "filter_metadata": {}} response = JsonResponse(valid_data) self.middleware.process_response(request, response) @@ -246,15 +212,10 @@ class TestContractValidationMiddlewareHybridValidation(TestCase): assert "INVALID_STRATEGY_VALUE" not in str(call) @patch.object(ContractValidationMiddleware, "_log_contract_violation") - def test__validate_hybrid_response__invalid_strategy__logs_violation( - self, mock_log - ): + def test__validate_hybrid_response__invalid_strategy__logs_violation(self, mock_log): """Test invalid hybrid strategy logs violation.""" request = self.factory.get("/api/v1/parks/hybrid/") - invalid_data = { - "strategy": "invalid_strategy", # Not 'client_side' or 'server_side' - "data": [] - } + invalid_data = {"strategy": "invalid_strategy", "data": []} # Not 'client_side' or 'server_side' response = JsonResponse(invalid_data) self.middleware.process_response(request, response) @@ -277,12 +238,7 @@ class TestContractValidationMiddlewarePaginationValidation(TestCase): def test__validate_pagination__valid_response__no_violation(self, mock_log): """Test valid pagination response doesn't log violation.""" request = self.factory.get("/api/v1/parks/") - valid_data = { - "count": 10, - "next": None, - "previous": None, - "results": [{"id": 1}, {"id": 2}] - } + valid_data = {"count": 10, "next": None, "previous": None, "results": [{"id": 1}, {"id": 2}]} response = JsonResponse(valid_data) self.middleware.process_response(request, response) @@ -296,10 +252,7 @@ class TestContractValidationMiddlewarePaginationValidation(TestCase): def test__validate_pagination__results_not_array__logs_violation(self, mock_log): """Test pagination with non-array results logs violation.""" request = self.factory.get("/api/v1/parks/") - invalid_data = { - "count": 10, - "results": "not an array" # Should be array - } + invalid_data = {"count": 10, "results": "not an array"} # Should be array response = JsonResponse(invalid_data) self.middleware.process_response(request, response) @@ -346,9 +299,7 @@ class TestContractValidationMiddlewareViolationSuggestions(TestCase): def test__get_violation_suggestion__categorical_string__returns_suggestion(self): """Test get_violation_suggestion returns suggestion for CATEGORICAL_OPTION_IS_STRING.""" - suggestion = self.middleware._get_violation_suggestion( - "CATEGORICAL_OPTION_IS_STRING" - ) + suggestion = self.middleware._get_violation_suggestion("CATEGORICAL_OPTION_IS_STRING") assert "ensure_filter_option_format" in suggestion assert "object arrays" in suggestion diff --git a/backend/tests/serializers/test_account_serializers.py b/backend/tests/serializers/test_account_serializers.py index 627e2be2..f7d42513 100644 --- a/backend/tests/serializers/test_account_serializers.py +++ b/backend/tests/serializers/test_account_serializers.py @@ -470,6 +470,3 @@ class TestUserProfileUpdateInputSerializer(TestCase): """Test user field is read-only for updates.""" extra_kwargs = UserProfileUpdateInputSerializer.Meta.extra_kwargs assert extra_kwargs.get("user", {}).get("read_only") is True - - - diff --git a/backend/tests/serializers/test_park_serializers.py b/backend/tests/serializers/test_park_serializers.py index 40bc023a..b6ab45aa 100644 --- a/backend/tests/serializers/test_park_serializers.py +++ b/backend/tests/serializers/test_park_serializers.py @@ -221,10 +221,7 @@ class TestParkPhotoListOutputSerializer(TestCase): def test__meta__all_fields_read_only(self): """Test all fields are read-only for list serializer.""" - assert ( - ParkPhotoListOutputSerializer.Meta.read_only_fields - == ParkPhotoListOutputSerializer.Meta.fields - ) + assert ParkPhotoListOutputSerializer.Meta.read_only_fields == ParkPhotoListOutputSerializer.Meta.fields class TestParkPhotoApprovalInputSerializer(TestCase): @@ -331,7 +328,7 @@ class TestHybridParkSerializer(TestCase): """Test serializing park without location returns null for location fields.""" park = ParkFactory() # Remove location if it exists - if hasattr(park, 'location') and park.location: + if hasattr(park, "location") and park.location: park.location.delete() serializer = HybridParkSerializer(park) @@ -413,10 +410,7 @@ class TestHybridParkSerializer(TestCase): def test__meta__all_fields_read_only(self): """Test all fields in HybridParkSerializer are read-only.""" - assert ( - HybridParkSerializer.Meta.read_only_fields - == HybridParkSerializer.Meta.fields - ) + assert HybridParkSerializer.Meta.read_only_fields == HybridParkSerializer.Meta.fields @pytest.mark.django_db diff --git a/backend/tests/serializers/test_ride_serializers.py b/backend/tests/serializers/test_ride_serializers.py index 8f6deb94..eeca0820 100644 --- a/backend/tests/serializers/test_ride_serializers.py +++ b/backend/tests/serializers/test_ride_serializers.py @@ -219,10 +219,7 @@ class TestRidePhotoListOutputSerializer(TestCase): def test__meta__all_fields_read_only(self): """Test all fields are read-only for list serializer.""" - assert ( - RidePhotoListOutputSerializer.Meta.read_only_fields - == RidePhotoListOutputSerializer.Meta.fields - ) + assert RidePhotoListOutputSerializer.Meta.read_only_fields == RidePhotoListOutputSerializer.Meta.fields class TestRidePhotoApprovalInputSerializer(TestCase): @@ -477,10 +474,7 @@ class TestHybridRideSerializer(TestCase): def test__meta__all_fields_read_only(self): """Test all fields in HybridRideSerializer are read-only.""" - assert ( - HybridRideSerializer.Meta.read_only_fields - == HybridRideSerializer.Meta.fields - ) + assert HybridRideSerializer.Meta.read_only_fields == HybridRideSerializer.Meta.fields def test__serialize__includes_ride_model_fields(self): """Test serializing includes ride model information.""" diff --git a/backend/tests/services/test_park_media_service.py b/backend/tests/services/test_park_media_service.py index 0da5d1a5..8ff1d35a 100644 --- a/backend/tests/services/test_park_media_service.py +++ b/backend/tests/services/test_park_media_service.py @@ -43,9 +43,7 @@ class TestParkMediaServiceUploadPhoto(TestCase): park = ParkFactory() user = UserFactory() - image_file = SimpleUploadedFile( - "test.jpg", b"fake image content", content_type="image/jpeg" - ) + image_file = SimpleUploadedFile("test.jpg", b"fake image content", content_type="image/jpeg") photo = ParkMediaService.upload_photo( park=park, @@ -70,9 +68,7 @@ class TestParkMediaServiceUploadPhoto(TestCase): park = ParkFactory() user = UserFactory() - image_file = SimpleUploadedFile( - "test.txt", b"not an image", content_type="text/plain" - ) + image_file = SimpleUploadedFile("test.txt", b"not an image", content_type="text/plain") with pytest.raises(ValueError) as exc_info: ParkMediaService.upload_photo( diff --git a/backend/tests/services/test_ride_service.py b/backend/tests/services/test_ride_service.py index 9cb02e88..664ca229 100644 --- a/backend/tests/services/test_ride_service.py +++ b/backend/tests/services/test_ride_service.py @@ -104,7 +104,9 @@ class TestRideServiceCreateRide(TestCase): def test__create_ride__invalid_park__raises_exception(self): """Test create_ride raises exception for invalid park.""" - with pytest.raises(Exception): + from apps.parks.models import Park + + with pytest.raises(Park.DoesNotExist): RideService.create_ride( name="Test Ride", park_id=99999, # Non-existent @@ -274,9 +276,7 @@ class TestRideServiceHandleNewEntitySuggestions(TestCase): """Tests for RideService.handle_new_entity_suggestions.""" @patch("apps.rides.services.ModerationService.create_edit_submission_with_queue") - def test__handle_new_entity_suggestions__new_manufacturer__creates_submission( - self, mock_create_submission - ): + def test__handle_new_entity_suggestions__new_manufacturer__creates_submission(self, mock_create_submission): """Test handle_new_entity_suggestions creates submission for new manufacturer.""" mock_submission = Mock() mock_submission.id = 1 @@ -302,9 +302,7 @@ class TestRideServiceHandleNewEntitySuggestions(TestCase): mock_create_submission.assert_called_once() @patch("apps.rides.services.ModerationService.create_edit_submission_with_queue") - def test__handle_new_entity_suggestions__new_designer__creates_submission( - self, mock_create_submission - ): + def test__handle_new_entity_suggestions__new_designer__creates_submission(self, mock_create_submission): """Test handle_new_entity_suggestions creates submission for new designer.""" mock_submission = Mock() mock_submission.id = 2 @@ -329,9 +327,7 @@ class TestRideServiceHandleNewEntitySuggestions(TestCase): assert 2 in result["designers"] @patch("apps.rides.services.ModerationService.create_edit_submission_with_queue") - def test__handle_new_entity_suggestions__new_ride_model__creates_submission( - self, mock_create_submission - ): + def test__handle_new_entity_suggestions__new_ride_model__creates_submission(self, mock_create_submission): """Test handle_new_entity_suggestions creates submission for new ride model.""" mock_submission = Mock() mock_submission.id = 3 diff --git a/backend/tests/test_factories.py b/backend/tests/test_factories.py index 45bbe1ac..7a4ede89 100644 --- a/backend/tests/test_factories.py +++ b/backend/tests/test_factories.py @@ -196,9 +196,7 @@ class FactoryValidationTestCase(TestCase): from datetime import date # Valid dates - park = ParkFactory.build( - opening_date=date(2020, 1, 1), closing_date=date(2023, 12, 31) - ) + park = ParkFactory.build(opening_date=date(2020, 1, 1), closing_date=date(2023, 12, 31)) # Verify opening is before closing if park.opening_date and park.closing_date: diff --git a/backend/tests/test_parks_api.py b/backend/tests/test_parks_api.py index b2f22a4a..86b76cc7 100644 --- a/backend/tests/test_parks_api.py +++ b/backend/tests/test_parks_api.py @@ -75,9 +75,7 @@ class TestParkListApi(APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) # Should return only operating parks (2 out of 3) - operating_parks = [ - p for p in response.data["data"] if p["status"] == "OPERATING" - ] + operating_parks = [p for p in response.data["data"] if p["status"] == "OPERATING"] self.assertEqual(len(operating_parks), 2) def test__park_list_api__with_search_query__returns_matching_results(self): @@ -315,9 +313,7 @@ class TestParkApiErrorHandling(APITestCase): """Test that malformed JSON returns proper error.""" url = reverse("parks_api:park-list") - response = self.client.post( - url, data='{"invalid": json}', content_type="application/json" - ) + response = self.client.post(url, data='{"invalid": json}', content_type="application/json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.data["status"], "error") @@ -358,9 +354,7 @@ class TestParkApiIntegration(APITestCase): retrieve_response = self.client.get(detail_url) self.assertEqual(retrieve_response.status_code, status.HTTP_200_OK) - self.assertEqual( - retrieve_response.data["data"]["name"], "Integration Test Park" - ) + self.assertEqual(retrieve_response.data["data"]["name"], "Integration Test Park") # 3. Update park update_data = {"description": "Updated integration test description"} diff --git a/backend/tests/test_utils.py b/backend/tests/test_utils.py index 46dd2ae7..8c41317e 100644 --- a/backend/tests/test_utils.py +++ b/backend/tests/test_utils.py @@ -81,9 +81,7 @@ class ApiTestMixin: error_code: Expected error code in response message_contains: String that should be in error message """ - self.assertApiResponse( - response, status_code=status_code, response_status="error" - ) + self.assertApiResponse(response, status_code=status_code, response_status="error") if error_code: self.assertEqual(response.data["error"]["code"], error_code) @@ -289,9 +287,7 @@ class GeographyTestMixin: self.assertGreaterEqual(longitude, -180, "Longitude below valid range") self.assertLessEqual(longitude, 180, "Longitude above valid range") - def assertCoordinateDistance( - self, point1: tuple, point2: tuple, max_distance_km: float - ): + def assertCoordinateDistance(self, point1: tuple, point2: tuple, max_distance_km: float): """ Assert two geographic points are within specified distance. diff --git a/backend/tests/utils/fsm_test_helpers.py b/backend/tests/utils/fsm_test_helpers.py index 3496f2d1..bbe6d75f 100644 --- a/backend/tests/utils/fsm_test_helpers.py +++ b/backend/tests/utils/fsm_test_helpers.py @@ -8,10 +8,17 @@ Reusable utility functions for testing FSM transitions: - Toast notification verification utilities """ +from __future__ import annotations + import json -from typing import Any +from typing import TYPE_CHECKING, Any from django.contrib.auth import get_user_model + +if TYPE_CHECKING: + from apps.moderation.models import EditSubmission, PhotoSubmission + from apps.parks.models import Park + from apps.rides.models import Ride from django.contrib.contenttypes.models import ContentType from django.db.models import Model from django.http import HttpResponse @@ -31,8 +38,8 @@ def create_test_submission( submission_type: str = "EDIT", changes: dict[str, Any] | None = None, reason: str = "Test submission", - **kwargs -) -> "EditSubmission": + **kwargs, +) -> EditSubmission: """ Create a test EditSubmission with the given status. @@ -54,8 +61,7 @@ def create_test_submission( # Get or create user if user is None: user, _ = User.objects.get_or_create( - username="test_submitter", - defaults={"email": "test_submitter@example.com"} + username="test_submitter", defaults={"email": "test_submitter@example.com"} ) user.set_password("testpass123") user.save() @@ -80,18 +86,13 @@ def create_test_submission( changes=changes, reason=reason, status=status, - **kwargs + **kwargs, ) return submission -def create_test_park( - status: str = "OPERATING", - name: str | None = None, - slug: str | None = None, - **kwargs -) -> "Park": +def create_test_park(status: str = "OPERATING", name: str | None = None, slug: str | None = None, **kwargs) -> Park: """ Create a test Park with the given status. @@ -108,29 +109,22 @@ def create_test_park( if name is None: import random + name = f"Test Park {random.randint(1000, 9999)}" if slug is None: from django.utils.text import slugify + slug = slugify(name) - park = ParkFactory( - name=name, - slug=slug, - status=status, - **kwargs - ) + park = ParkFactory(name=name, slug=slug, status=status, **kwargs) return park def create_test_ride( - status: str = "OPERATING", - name: str | None = None, - slug: str | None = None, - park: Model | None = None, - **kwargs -) -> "Ride": + status: str = "OPERATING", name: str | None = None, slug: str | None = None, park: Model | None = None, **kwargs +) -> Ride: """ Create a test Ride with the given status. @@ -148,18 +142,15 @@ def create_test_ride( if name is None: import random + name = f"Test Ride {random.randint(1000, 9999)}" if slug is None: from django.utils.text import slugify + slug = slugify(name) - ride_kwargs = { - "name": name, - "slug": slug, - "status": status, - **kwargs - } + ride_kwargs = {"name": name, "slug": slug, "status": status, **kwargs} if park is not None: ride_kwargs["park"] = park @@ -170,11 +161,8 @@ def create_test_ride( def create_test_photo_submission( - status: str = "PENDING", - user: User | None = None, - park: Model | None = None, - **kwargs -) -> "PhotoSubmission": + status: str = "PENDING", user: User | None = None, park: Model | None = None, **kwargs +) -> PhotoSubmission: """ Create a test PhotoSubmission with the given status. @@ -193,8 +181,7 @@ def create_test_photo_submission( # Get or create user if user is None: user, _ = User.objects.get_or_create( - username="test_photo_submitter", - defaults={"email": "test_photo@example.com"} + username="test_photo_submitter", defaults={"email": "test_photo@example.com"} ) user.set_password("testpass123") user.save() @@ -210,11 +197,12 @@ def create_test_photo_submission( # Get a photo if available try: from django_cloudflareimages_toolkit.models import CloudflareImage + photo = CloudflareImage.objects.first() if not photo: raise ValueError("No CloudflareImage available for testing") except ImportError: - raise ValueError("CloudflareImage not available") + raise ValueError("CloudflareImage not available") from None submission = PhotoSubmission.objects.create( user=user, @@ -223,7 +211,7 @@ def create_test_photo_submission( photo=photo, caption="Test photo submission", status=status, - **kwargs + **kwargs, ) return submission @@ -247,16 +235,11 @@ def assert_status_changed(obj: Model, expected_status: str) -> None: """ obj.refresh_from_db() actual_status = getattr(obj, "status", None) - assert actual_status == expected_status, ( - f"Expected status '{expected_status}', got '{actual_status}'" - ) + assert actual_status == expected_status, f"Expected status '{expected_status}', got '{actual_status}'" def assert_state_log_created( - obj: Model, - transition_name: str, - user: User | None = None, - expected_state: str | None = None + obj: Model, transition_name: str, user: User | None = None, expected_state: str | None = None ) -> None: """ Assert that a StateLog entry was created for a transition. @@ -274,31 +257,20 @@ def assert_state_log_created( content_type = ContentType.objects.get_for_model(obj) - logs = StateLog.objects.filter( - content_type=content_type, - object_id=obj.pk - ).order_by('-timestamp') + logs = StateLog.objects.filter(content_type=content_type, object_id=obj.pk).order_by("-timestamp") assert logs.exists(), "No StateLog entries found for object" latest_log = logs.first() if expected_state is not None: - assert latest_log.state == expected_state, ( - f"Expected state '{expected_state}' in log, got '{latest_log.state}'" - ) + assert latest_log.state == expected_state, f"Expected state '{expected_state}' in log, got '{latest_log.state}'" if user is not None: - assert latest_log.by == user, ( - f"Expected log by user '{user}', got '{latest_log.by}'" - ) + assert latest_log.by == user, f"Expected log by user '{user}', got '{latest_log.by}'" -def assert_toast_triggered( - response: HttpResponse, - message: str | None = None, - toast_type: str = "success" -) -> None: +def assert_toast_triggered(response: HttpResponse, message: str | None = None, toast_type: str = "success") -> None: """ Assert that the response contains an HX-Trigger header with toast data. @@ -316,14 +288,12 @@ def assert_toast_triggered( assert "showToast" in trigger_data, "HX-Trigger missing showToast event" toast_data = trigger_data["showToast"] - assert toast_data.get("type") == toast_type, ( - f"Expected toast type '{toast_type}', got '{toast_data.get('type')}'" - ) + assert toast_data.get("type") == toast_type, f"Expected toast type '{toast_type}', got '{toast_data.get('type')}'" if message is not None: - assert message in toast_data.get("message", ""), ( - f"Expected '{message}' in toast message, got '{toast_data.get('message')}'" - ) + assert message in toast_data.get( + "message", "" + ), f"Expected '{message}' in toast message, got '{toast_data.get('message')}'" def assert_no_status_change(obj: Model, original_status: str) -> None: @@ -339,9 +309,9 @@ def assert_no_status_change(obj: Model, original_status: str) -> None: """ obj.refresh_from_db() actual_status = getattr(obj, "status", None) - assert actual_status == original_status, ( - f"Status should not have changed from '{original_status}', but is now '{actual_status}'" - ) + assert ( + actual_status == original_status + ), f"Status should not have changed from '{original_status}', but is now '{actual_status}'" # ============================================================================= @@ -349,11 +319,7 @@ def assert_no_status_change(obj: Model, original_status: str) -> None: # ============================================================================= -def wait_for_htmx_swap( - page, - target_selector: str, - timeout: int = 5000 -) -> None: +def wait_for_htmx_swap(page, target_selector: str, timeout: int = 5000) -> None: """ Wait for an HTMX swap to complete on a target element. @@ -370,14 +336,12 @@ def wait_for_htmx_swap( return el && !el.classList.contains('htmx-request'); }} """, - timeout=timeout + timeout=timeout, ) def verify_transition_buttons_visible( - page, - transitions: list[str], - container_selector: str = "[data-status-actions]" + page, transitions: list[str], container_selector: str = "[data-status-actions]" ) -> dict[str, bool]: """ Verify which transition buttons are visible on the page. @@ -447,11 +411,7 @@ def wait_for_toast(page, toast_selector: str = "[data-toast]", timeout: int = 50 return toast -def wait_for_toast_dismiss( - page, - toast_selector: str = "[data-toast]", - timeout: int = 10000 -) -> None: +def wait_for_toast_dismiss(page, toast_selector: str = "[data-toast]", timeout: int = 10000) -> None: """ Wait for a toast notification to be dismissed. @@ -473,6 +433,7 @@ def click_and_confirm(page, button_locator, accept: bool = True) -> None: button_locator: The button locator to click accept: Whether to accept (True) or dismiss (False) the dialog """ + def handle_dialog(dialog): if accept: dialog.accept() @@ -500,11 +461,7 @@ def make_htmx_post(client, url: str, data: dict | None = None) -> HttpResponse: Returns: HttpResponse from the request """ - return client.post( - url, - data=data or {}, - HTTP_HX_REQUEST="true" - ) + return client.post(url, data=data or {}, HTTP_HX_REQUEST="true") def make_htmx_get(client, url: str) -> HttpResponse: @@ -518,19 +475,11 @@ def make_htmx_get(client, url: str) -> HttpResponse: Returns: HttpResponse from the request """ - return client.get( - url, - HTTP_HX_REQUEST="true" - ) + return client.get(url, HTTP_HX_REQUEST="true") def get_fsm_transition_url( - app_label: str, - model_name: str, - pk: int, - transition_name: str, - use_slug: bool = False, - slug: str | None = None + app_label: str, model_name: str, pk: int, transition_name: str, use_slug: bool = False, slug: str | None = None ) -> str: """ Generate the URL for an FSM transition. @@ -553,20 +502,10 @@ def get_fsm_transition_url( raise ValueError("slug is required when use_slug is True") return reverse( "core:fsm_transition_by_slug", - kwargs={ - "app_label": app_label, - "model_name": model_name, - "slug": slug, - "transition_name": transition_name - } + kwargs={"app_label": app_label, "model_name": model_name, "slug": slug, "transition_name": transition_name}, ) else: return reverse( "core:fsm_transition", - kwargs={ - "app_label": app_label, - "model_name": model_name, - "pk": pk, - "transition_name": transition_name - } + kwargs={"app_label": app_label, "model_name": model_name, "pk": pk, "transition_name": transition_name}, ) diff --git a/backend/tests/ux/test_breadcrumbs.py b/backend/tests/ux/test_breadcrumbs.py index a0283792..9f929912 100644 --- a/backend/tests/ux/test_breadcrumbs.py +++ b/backend/tests/ux/test_breadcrumbs.py @@ -142,9 +142,7 @@ class TestBreadcrumbBuilder: def test_schema_positions_auto_assigned(self): """Should auto-assign schema positions.""" builder = BreadcrumbBuilder() - crumbs = ( - builder.add_home().add("Parks", "/parks/").add_current("Test").build() - ) + crumbs = builder.add_home().add("Parks", "/parks/").add_current("Test").build() assert crumbs[0].schema_position == 1 assert crumbs[1].schema_position == 2 diff --git a/backend/tests/ux/test_messages.py b/backend/tests/ux/test_messages.py index 9a615398..c3593b19 100644 --- a/backend/tests/ux/test_messages.py +++ b/backend/tests/ux/test_messages.py @@ -5,7 +5,6 @@ These tests verify that message helper functions generate consistent, user-friendly messages. """ - from apps.core.utils.messages import ( confirm_delete, error_not_found, @@ -131,8 +130,4 @@ class TestConfirmMessages: def test_confirm_delete_warning(self): """Should include warning about irreversibility.""" message = confirm_delete("Park") - assert ( - "cannot be undone" in message.lower() - or "permanent" in message.lower() - or "sure" in message.lower() - ) + assert "cannot be undone" in message.lower() or "permanent" in message.lower() or "sure" in message.lower() diff --git a/backend/thrillwiki/views.py b/backend/thrillwiki/views.py index 4b3ea3c8..5ecd9131 100644 --- a/backend/thrillwiki/views.py +++ b/backend/thrillwiki/views.py @@ -44,18 +44,12 @@ class HomeView(TemplateView): # If not in cache, get them directly and cache them if trending_parks is None: try: - trending_parks = list( - PageView.get_trending_items(Park, hours=24, limit=10) - ) + trending_parks = list(PageView.get_trending_items(Park, hours=24, limit=10)) if trending_parks: - cache.set( - "trending_parks", trending_parks, 3600 - ) # Cache for 1 hour + cache.set("trending_parks", trending_parks, 3600) # Cache for 1 hour else: # Fallback to highest rated parks if no trending data - trending_parks = Park.objects.exclude( - average_rating__isnull=True - ).order_by("-average_rating")[:10] + trending_parks = Park.objects.exclude(average_rating__isnull=True).order_by("-average_rating")[:10] except Exception as e: log_exception( logger, @@ -64,24 +58,16 @@ class HomeView(TemplateView): request=self.request, ) # Fallback to highest rated parks if trending calculation fails - trending_parks = Park.objects.exclude( - average_rating__isnull=True - ).order_by("-average_rating")[:10] + trending_parks = Park.objects.exclude(average_rating__isnull=True).order_by("-average_rating")[:10] if trending_rides is None: try: - trending_rides = list( - PageView.get_trending_items(Ride, hours=24, limit=10) - ) + trending_rides = list(PageView.get_trending_items(Ride, hours=24, limit=10)) if trending_rides: - cache.set( - "trending_rides", trending_rides, 3600 - ) # Cache for 1 hour + cache.set("trending_rides", trending_rides, 3600) # Cache for 1 hour else: # Fallback to highest rated rides if no trending data - trending_rides = Ride.objects.exclude( - average_rating__isnull=True - ).order_by("-average_rating")[:10] + trending_rides = Ride.objects.exclude(average_rating__isnull=True).order_by("-average_rating")[:10] except Exception as e: log_exception( logger, @@ -90,21 +76,15 @@ class HomeView(TemplateView): request=self.request, ) # Fallback to highest rated rides if trending calculation fails - trending_rides = Ride.objects.exclude( - average_rating__isnull=True - ).order_by("-average_rating")[:10] + trending_rides = Ride.objects.exclude(average_rating__isnull=True).order_by("-average_rating")[:10] # Get highest rated items (mix of parks and rides) highest_rated_parks = list( - Park.objects.exclude(average_rating__isnull=True).order_by( - "-average_rating" - )[:20] + Park.objects.exclude(average_rating__isnull=True).order_by("-average_rating")[:20] ) # Get more items to randomly select from highest_rated_rides = list( - Ride.objects.exclude(average_rating__isnull=True).order_by( - "-average_rating" - )[:20] + Ride.objects.exclude(average_rating__isnull=True).order_by("-average_rating")[:20] ) # Get more items to randomly select from # Combine and shuffle highest rated items @@ -114,9 +94,7 @@ class HomeView(TemplateView): # Keep the same context variable names for template compatibility context["popular_parks"] = trending_parks context["popular_rides"] = trending_rides - context["highest_rated"] = all_highest_rated[ - :10 - ] # Take first 10 after shuffling + context["highest_rated"] = all_highest_rated[:10] # Take first 10 after shuffling return context @@ -131,9 +109,7 @@ class SearchView(TemplateView): # Search parks context["parks"] = ( Park.objects.filter( - Q(name__icontains=query) - | Q(location__icontains=query) - | Q(description__icontains=query) + Q(name__icontains=query) | Q(location__icontains=query) | Q(description__icontains=query) ) .select_related("operating_company") .prefetch_related("photos")[:10] @@ -142,9 +118,7 @@ class SearchView(TemplateView): # Search rides context["rides"] = ( Ride.objects.filter( - Q(name__icontains=query) - | Q(description__icontains=query) - | Q(manufacturer__name__icontains=query) + Q(name__icontains=query) | Q(description__icontains=query) | Q(manufacturer__name__icontains=query) ) .select_related("park", "coaster_stats") .prefetch_related("photos")[:10] @@ -163,11 +137,7 @@ class SearchView(TemplateView): "parks_count": len(context["parks"]), "rides_count": len(context["rides"]), "companies_count": len(context["companies"]), - "user_id": ( - self.request.user.id - if self.request.user.is_authenticated - else None - ), + "user_id": (self.request.user.id if self.request.user.is_authenticated else None), }, ) @@ -179,11 +149,7 @@ def environment_and_settings_view(request): env_vars = dict(os.environ) # Get all Django settings as a dictionary - settings_vars = { - setting: getattr(settings, setting) - for setting in dir(settings) - if setting.isupper() - } + settings_vars = {setting: getattr(settings, setting) for setting in dir(settings) if setting.isupper()} return render( request, diff --git a/backend/verify_backend.py b/backend/verify_backend.py index 2457d24c..ca2c4af7 100644 --- a/backend/verify_backend.py +++ b/backend/verify_backend.py @@ -4,17 +4,18 @@ import sys import django # Setup Django environment -sys.path.append('/Volumes/macminissd/Projects/thrillwiki_django_no_react/backend') +sys.path.append("/Volumes/macminissd/Projects/thrillwiki_django_no_react/backend") os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.django.local") django.setup() -from django.contrib.auth import get_user_model -from rest_framework.test import APIClient +from django.contrib.auth import get_user_model # noqa: E402 +from rest_framework.test import APIClient # noqa: E402 -from apps.parks.models import Park +from apps.parks.models import Park # noqa: E402 User = get_user_model() + def run_verification(): print("Starting Backend Verification...") @@ -36,19 +37,16 @@ def run_verification(): # 3. Verify Profile Update (Unit System) # Endpoint: /api/v1/auth/user/ or /api/v1/accounts/me/ (depending on dj-rest-auth) # Let's try updating profile via PATCH /api/v1/auth/user/ - update_data = { - "unit_system": "imperial", - "location": "Test City, TS" - } + update_data = {"unit_system": "imperial", "location": "Test City, TS"} # Note: unit_system expects 'metric', 'imperial'. # Check if 'imperial' is valid key in RichChoiceField. # Assuming it is based on implementation plan. - response = client.patch('/api/v1/accounts/profile/update/', update_data, format='json') + response = client.patch("/api/v1/accounts/profile/update/", update_data, format="json") if response.status_code == 200: print(f"Profile updated successfully: {response.data.get('unit_system')}") - if response.data.get('unit_system') != 'imperial': - print(f"WARNING: unit_system mismatch. Got {response.data.get('unit_system')}") + if response.data.get("unit_system") != "imperial": + print(f"WARNING: unit_system mismatch. Got {response.data.get('unit_system')}") else: print(f"FAILED to update profile: {response.status_code} {response.data}") @@ -56,13 +54,13 @@ def run_verification(): # Create List list_data = { "title": "My Favorite Coasters", - "category": "RC", # Roller Coaster + "category": "RC", # Roller Coaster "description": "Best rides ever", - "is_public": True + "is_public": True, } - response = client.post('/api/v1/lists/lists/', list_data, format='json') + response = client.post("/api/v1/lists/lists/", list_data, format="json") if response.status_code == 201: - list_id = response.data['id'] + list_id = response.data["id"] print(f"UserList created: {list_id} - {response.data['title']}") else: print(f"FAILED to create UserList: {response.status_code} {response.data}") @@ -81,7 +79,7 @@ def run_verification(): # Alternatively, use specialized endpoint or just test UserList creation for now. # Actually, let's just check if we can GET the list - response = client.get(f'/api/v1/lists/lists/{list_id}/') + response = client.get(f"/api/v1/lists/lists/{list_id}/") if response.status_code == 200: print(f"UserList retrieved: {response.data['title']}") else: @@ -89,5 +87,6 @@ def run_verification(): print("Verification Complete.") + if __name__ == "__main__": run_verification() diff --git a/backend/verify_no_tuple_fallbacks.py b/backend/verify_no_tuple_fallbacks.py index dae7b4ea..b288cf1e 100644 --- a/backend/verify_no_tuple_fallbacks.py +++ b/backend/verify_no_tuple_fallbacks.py @@ -17,16 +17,16 @@ def search_for_tuple_fallbacks(): # Patterns that indicate tuple fallbacks choice_fallback_patterns = [ - r'choices\.get\([^,]+,\s*[^)]+\)', # choices.get(value, fallback) - r'status_.*\.get\([^,]+,\s*[^)]+\)', # status_colors.get(value, fallback) - r'category_.*\.get\([^,]+,\s*[^)]+\)', # category_images.get(value, fallback) - r'sla_hours\.get\([^,]+,\s*[^)]+\)', # sla_hours.get(priority, fallback) - r'get_tuple_choices\(', # get_tuple_choices function - r'from_tuple\(', # from_tuple function - r'convert_tuple_choices\(', # convert_tuple_choices function + r"choices\.get\([^,]+,\s*[^)]+\)", # choices.get(value, fallback) + r"status_.*\.get\([^,]+,\s*[^)]+\)", # status_colors.get(value, fallback) + r"category_.*\.get\([^,]+,\s*[^)]+\)", # category_images.get(value, fallback) + r"sla_hours\.get\([^,]+,\s*[^)]+\)", # sla_hours.get(priority, fallback) + r"get_tuple_choices\(", # get_tuple_choices function + r"from_tuple\(", # from_tuple function + r"convert_tuple_choices\(", # convert_tuple_choices function ] - apps_dir = Path('apps') + apps_dir = Path("apps") if not apps_dir.exists(): print("❌ Error: apps directory not found") return False @@ -34,24 +34,21 @@ def search_for_tuple_fallbacks(): found_fallbacks = [] # Search all Python files in apps directory - for py_file in apps_dir.rglob('*.py'): + for py_file in apps_dir.rglob("*.py"): # Skip migrations (they're supposed to have hardcoded values) - if 'migration' in str(py_file): + if "migration" in str(py_file): continue try: - with open(py_file, encoding='utf-8') as f: + with open(py_file, encoding="utf-8") as f: content = f.read() - for line_num, line in enumerate(content.split('\n'), 1): + for line_num, line in enumerate(content.split("\n"), 1): for pattern in choice_fallback_patterns: if re.search(pattern, line): - found_fallbacks.append({ - 'file': py_file, - 'line': line_num, - 'content': line.strip(), - 'pattern': pattern - }) + found_fallbacks.append( + {"file": py_file, "line": line_num, "content": line.strip(), "pattern": pattern} + ) except Exception as e: print(f"❌ Error reading {py_file}: {e}") continue @@ -66,12 +63,14 @@ def search_for_tuple_fallbacks(): print("✅ NO TUPLE FALLBACKS FOUND - All eliminated!") return True + def verify_tuple_functions_removed(): """Verify that tuple fallback functions have been removed.""" # Check that get_tuple_choices is not importable try: from apps.core.choices.registry import get_tuple_choices # noqa: F401 + print("❌ ERROR: get_tuple_choices function still exists!") return False except ImportError: @@ -80,18 +79,20 @@ def verify_tuple_functions_removed(): # Check that Rich Choice objects work as primary source try: from apps.core.choices.registry import get_choices # noqa: F401 + print("✅ get_choices function (Rich Choice objects) works as primary source") return True except ImportError: print("❌ ERROR: get_choices function missing!") return False + def main(): """Main verification function.""" print("=== TUPLE FALLBACK ELIMINATION VERIFICATION ===\n") # Change to backend directory if needed - if 'backend' not in os.getcwd(): + if "backend" not in os.getcwd(): backend_dir = Path(__file__).parent os.chdir(backend_dir) print(f"Changed directory to: {os.getcwd()}") @@ -110,5 +111,6 @@ def main(): print("❌ FAILURE: Tuple fallbacks still exist!") return 1 + if __name__ == "__main__": sys.exit(main())