Compare commits

..

3 Commits

5 changed files with 148 additions and 106 deletions

View File

@@ -1,64 +1,95 @@
from django.conf import settings from django.conf import settings
from allauth.account.adapter import DefaultAccountAdapter from django.http import HttpRequest
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter from typing import Optional, Any, Dict, Literal, TYPE_CHECKING, cast
from allauth.account.adapter import DefaultAccountAdapter # type: ignore[import]
from allauth.account.models import EmailConfirmation, EmailAddress # type: ignore[import]
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter # type: ignore[import]
from allauth.socialaccount.models import SocialLogin # type: ignore[import]
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.contrib.sites.shortcuts import get_current_site from django.contrib.sites.shortcuts import get_current_site
if TYPE_CHECKING:
from django.contrib.auth.models import AbstractUser
User = get_user_model() User = get_user_model()
class CustomAccountAdapter(DefaultAccountAdapter): class CustomAccountAdapter(DefaultAccountAdapter):
def is_open_for_signup(self, request): def is_open_for_signup(self, request: HttpRequest) -> Literal[True]:
""" """
Whether to allow sign ups. Whether to allow sign ups.
""" """
return True return True
def get_email_confirmation_url(self, request, emailconfirmation): def get_email_confirmation_url(self, request: HttpRequest, emailconfirmation: EmailConfirmation) -> str:
""" """
Constructs the email confirmation (activation) url. Constructs the email confirmation (activation) url.
""" """
get_current_site(request) get_current_site(request)
return f"{settings.LOGIN_REDIRECT_URL}verify-email?key={emailconfirmation.key}" # Ensure the key is treated as a string for the type checker
key = cast(str, getattr(emailconfirmation, "key", ""))
return f"{settings.LOGIN_REDIRECT_URL}verify-email?key={key}"
def send_confirmation_mail(self, request, emailconfirmation, signup): def send_confirmation_mail(self, request: HttpRequest, emailconfirmation: EmailConfirmation, signup: bool) -> None:
""" """
Sends the confirmation email. Sends the confirmation email.
""" """
current_site = get_current_site(request) current_site = get_current_site(request)
activate_url = self.get_email_confirmation_url(request, emailconfirmation) activate_url = self.get_email_confirmation_url(request, emailconfirmation)
ctx = { # Cast key to str for typing consistency and template context
"user": emailconfirmation.email_address.user, key = cast(str, getattr(emailconfirmation, "key", ""))
"activate_url": activate_url,
"current_site": current_site, # Determine template early
"key": emailconfirmation.key,
}
if signup: if signup:
email_template = "account/email/email_confirmation_signup" email_template = "account/email/email_confirmation_signup"
else: else:
email_template = "account/email/email_confirmation" email_template = "account/email/email_confirmation"
self.send_mail(email_template, emailconfirmation.email_address.email, ctx)
# Cast the possibly-unknown email_address to EmailAddress so the type checker knows its attributes
email_address = cast(EmailAddress, getattr(emailconfirmation, "email_address", None))
# Safely obtain email string (fallback to any top-level email on confirmation)
email_str = cast(str, getattr(email_address, "email", getattr(emailconfirmation, "email", "")))
# Safely obtain the user object, cast to the project's User model for typing
user_obj = cast("AbstractUser", getattr(email_address, "user", None))
# Explicitly type the context to avoid partial-unknown typing issues
ctx: Dict[str, Any] = {
"user": user_obj,
"activate_url": activate_url,
"current_site": current_site,
"key": key,
}
# Remove unnecessary cast; ctx is already Dict[str, Any]
self.send_mail(email_template, email_str, ctx) # type: ignore
class CustomSocialAccountAdapter(DefaultSocialAccountAdapter): class CustomSocialAccountAdapter(DefaultSocialAccountAdapter):
def is_open_for_signup(self, request, sociallogin): def is_open_for_signup(self, request: HttpRequest, sociallogin: SocialLogin) -> Literal[True]:
""" """
Whether to allow social account sign ups. Whether to allow social account sign ups.
""" """
return True return True
def populate_user(self, request, sociallogin, data): def populate_user(
self, request: HttpRequest, sociallogin: SocialLogin, data: Dict[str, Any]
) -> "AbstractUser": # type: ignore[override]
""" """
Hook that can be used to further populate the user instance. Hook that can be used to further populate the user instance.
""" """
user = super().populate_user(request, sociallogin, data) user = super().populate_user(request, sociallogin, data) # type: ignore
if sociallogin.account.provider == "discord": if getattr(sociallogin.account, "provider", None) == "discord": # type: ignore
user.discord_id = sociallogin.account.uid user.discord_id = getattr(sociallogin.account, "uid", None) # type: ignore
return user return cast("AbstractUser", user) # Ensure return type is explicit
def save_user(self, request, sociallogin, form=None): def save_user(
self, request: HttpRequest, sociallogin: SocialLogin, form: Optional[Any] = None
) -> "AbstractUser": # type: ignore[override]
""" """
Save the newly signed up social login. Save the newly signed up social login.
""" """
user = super().save_user(request, sociallogin, form) user = super().save_user(request, sociallogin, form) # type: ignore
return user if user is None:
raise ValueError("User creation failed")
return cast("AbstractUser", user) # Ensure return type is explicit

View File

@@ -1,7 +1,10 @@
from typing import Any
from django.contrib import admin from django.contrib import admin
from django.contrib.auth.admin import UserAdmin from django.contrib.auth.admin import UserAdmin as DjangoUserAdmin
from django.utils.html import format_html from django.utils.html import format_html
from django.contrib.auth.models import Group from django.contrib.auth.models import Group
from django.http import HttpRequest
from django.db.models import QuerySet
from .models import ( from .models import (
User, User,
UserProfile, UserProfile,
@@ -12,7 +15,7 @@ from .models import (
) )
class UserProfileInline(admin.StackedInline): class UserProfileInline(admin.StackedInline[UserProfile, admin.options.AdminSite]):
model = UserProfile model = UserProfile
can_delete = False can_delete = False
verbose_name_plural = "Profile" verbose_name_plural = "Profile"
@@ -39,7 +42,7 @@ class UserProfileInline(admin.StackedInline):
) )
class TopListItemInline(admin.TabularInline): class TopListItemInline(admin.TabularInline[TopListItem]):
model = TopListItem model = TopListItem
extra = 1 extra = 1
fields = ("content_type", "object_id", "rank", "notes") fields = ("content_type", "object_id", "rank", "notes")
@@ -47,7 +50,7 @@ class TopListItemInline(admin.TabularInline):
@admin.register(User) @admin.register(User)
class CustomUserAdmin(UserAdmin): class CustomUserAdmin(DjangoUserAdmin[User]):
list_display = ( list_display = (
"username", "username",
"email", "email",
@@ -74,7 +77,7 @@ class CustomUserAdmin(UserAdmin):
"ban_users", "ban_users",
"unban_users", "unban_users",
] ]
inlines = [UserProfileInline] inlines: list[type[admin.StackedInline[UserProfile]]] = [UserProfileInline]
fieldsets = ( fieldsets = (
(None, {"fields": ("username", "password")}), (None, {"fields": ("username", "password")}),
@@ -126,75 +129,82 @@ class CustomUserAdmin(UserAdmin):
) )
@admin.display(description="Avatar") @admin.display(description="Avatar")
def get_avatar(self, obj): def get_avatar(self, obj: User) -> str:
if obj.profile.avatar: profile = getattr(obj, "profile", None)
if profile and getattr(profile, "avatar", None):
return format_html( return format_html(
'<img src="{}" width="30" height="30" style="border-radius:50%;" />', '<img src="{0}" width="30" height="30" style="border-radius:50%;" />',
obj.profile.avatar.url, getattr(profile.avatar, "url", ""), # type: ignore
) )
return format_html( return format_html(
'<div style="width:30px; height:30px; border-radius:50%; ' '<div style="width:30px; height:30px; border-radius:50%; '
"background-color:#007bff; color:white; display:flex; " "background-color:#007bff; color:white; display:flex; "
'align-items:center; justify-content:center;">{}</div>', 'align-items:center; justify-content:center;">{0}</div>',
obj.username[0].upper(), getattr(obj, "username", "?")[0].upper(), # type: ignore
) )
@admin.display(description="Status") @admin.display(description="Status")
def get_status(self, obj): def get_status(self, obj: User) -> str:
if obj.is_banned: if getattr(obj, "is_banned", False):
return format_html('<span style="color: red;">Banned</span>') return format_html('<span style="color: red;">{}</span>', "Banned")
if not obj.is_active: if not getattr(obj, "is_active", True):
return format_html('<span style="color: orange;">Inactive</span>') return format_html('<span style="color: orange;">{}</span>', "Inactive")
if obj.is_superuser: if getattr(obj, "is_superuser", False):
return format_html('<span style="color: purple;">Superuser</span>') return format_html('<span style="color: purple;">{}</span>', "Superuser")
if obj.is_staff: if getattr(obj, "is_staff", False):
return format_html('<span style="color: blue;">Staff</span>') return format_html('<span style="color: blue;">{}</span>', "Staff")
return format_html('<span style="color: green;">Active</span>') return format_html('<span style="color: green;">{}</span>', "Active")
@admin.display(description="Ride Credits") @admin.display(description="Ride Credits")
def get_credits(self, obj): def get_credits(self, obj: User) -> str:
try: try:
profile = obj.profile profile = getattr(obj, "profile", None)
if not profile:
return "-"
return format_html( return format_html(
"RC: {}<br>DR: {}<br>FR: {}<br>WR: {}", "RC: {0}<br>DR: {1}<br>FR: {2}<br>WR: {3}",
profile.coaster_credits, getattr(profile, "coaster_credits", 0),
profile.dark_ride_credits, getattr(profile, "dark_ride_credits", 0),
profile.flat_ride_credits, getattr(profile, "flat_ride_credits", 0),
profile.water_ride_credits, getattr(profile, "water_ride_credits", 0),
) )
except UserProfile.DoesNotExist: except UserProfile.DoesNotExist:
return "-" return "-"
@admin.action(description="Activate selected users") @admin.action(description="Activate selected users")
def activate_users(self, request, queryset): def activate_users(self, request: HttpRequest, queryset: QuerySet[User]) -> None:
queryset.update(is_active=True) queryset.update(is_active=True)
@admin.action(description="Deactivate selected users") @admin.action(description="Deactivate selected users")
def deactivate_users(self, request, queryset): def deactivate_users(self, request: HttpRequest, queryset: QuerySet[User]) -> None:
queryset.update(is_active=False) queryset.update(is_active=False)
@admin.action(description="Ban selected users") @admin.action(description="Ban selected users")
def ban_users(self, request, queryset): def ban_users(self, request: HttpRequest, queryset: QuerySet[User]) -> None:
from django.utils import timezone from django.utils import timezone
queryset.update(is_banned=True, ban_date=timezone.now()) queryset.update(is_banned=True, ban_date=timezone.now())
@admin.action(description="Unban selected users") @admin.action(description="Unban selected users")
def unban_users(self, request, queryset): def unban_users(self, request: HttpRequest, queryset: QuerySet[User]) -> None:
queryset.update(is_banned=False, ban_date=None, ban_reason="") queryset.update(is_banned=False, ban_date=None, ban_reason="")
def save_model(self, request, obj, form, change): def save_model(
self,
request: HttpRequest,
obj: User,
form: Any,
change: bool
) -> None:
creating = not obj.pk creating = not obj.pk
super().save_model(request, obj, form, change) super().save_model(request, obj, form, change)
if creating and obj.role != "USER": if creating and getattr(obj, "role", "USER") != "USER":
# Ensure new user with role gets added to appropriate group group = Group.objects.filter(name=getattr(obj, "role", None)).first()
group = Group.objects.filter(name=obj.role).first()
if group: if group:
obj.groups.add(group) obj.groups.add(group) # type: ignore[attr-defined]
@admin.register(UserProfile) @admin.register(UserProfile)
class UserProfileAdmin(admin.ModelAdmin): class UserProfileAdmin(admin.ModelAdmin[UserProfile]):
list_display = ( list_display = (
"user", "user",
"display_name", "display_name",
@@ -235,7 +245,7 @@ class UserProfileAdmin(admin.ModelAdmin):
@admin.register(EmailVerification) @admin.register(EmailVerification)
class EmailVerificationAdmin(admin.ModelAdmin): class EmailVerificationAdmin(admin.ModelAdmin[EmailVerification]):
list_display = ("user", "created_at", "last_sent", "is_expired") list_display = ("user", "created_at", "last_sent", "is_expired")
list_filter = ("created_at", "last_sent") list_filter = ("created_at", "last_sent")
search_fields = ("user__username", "user__email", "token") search_fields = ("user__username", "user__email", "token")
@@ -247,21 +257,21 @@ class EmailVerificationAdmin(admin.ModelAdmin):
) )
@admin.display(description="Status") @admin.display(description="Status")
def is_expired(self, obj): def is_expired(self, obj: EmailVerification) -> str:
from django.utils import timezone from django.utils import timezone
from datetime import timedelta from datetime import timedelta
if timezone.now() - obj.last_sent > timedelta(days=1): if timezone.now() - getattr(obj, "last_sent", timezone.now()) > timedelta(days=1):
return format_html('<span style="color: red;">Expired</span>') return format_html('<span style="color: red;">{}</span>', "Expired")
return format_html('<span style="color: green;">Valid</span>') return format_html('<span style="color: green;">{}</span>', "Valid")
@admin.register(TopList) @admin.register(TopList)
class TopListAdmin(admin.ModelAdmin): class TopListAdmin(admin.ModelAdmin[TopList]):
list_display = ("title", "user", "category", "created_at", "updated_at") list_display = ("title", "user", "category", "created_at", "updated_at")
list_filter = ("category", "created_at", "updated_at") list_filter = ("category", "created_at", "updated_at")
search_fields = ("title", "user__username", "description") search_fields = ("title", "user__username", "description")
inlines = [TopListItemInline] inlines: list[type[admin.TabularInline[TopListItem]]] = [TopListItemInline]
fieldsets = ( fieldsets = (
( (
@@ -277,7 +287,7 @@ class TopListAdmin(admin.ModelAdmin):
@admin.register(TopListItem) @admin.register(TopListItem)
class TopListItemAdmin(admin.ModelAdmin): class TopListItemAdmin(admin.ModelAdmin[TopListItem]):
list_display = ("top_list", "content_type", "object_id", "rank") list_display = ("top_list", "content_type", "object_id", "rank")
list_filter = ("top_list__category", "rank") list_filter = ("top_list__category", "rank")
search_fields = ("top_list__title", "notes") search_fields = ("top_list__title", "notes")
@@ -290,7 +300,7 @@ class TopListItemAdmin(admin.ModelAdmin):
@admin.register(PasswordReset) @admin.register(PasswordReset)
class PasswordResetAdmin(admin.ModelAdmin): class PasswordResetAdmin(admin.ModelAdmin[PasswordReset]):
"""Admin interface for password reset tokens""" """Admin interface for password reset tokens"""
list_display = ( list_display = (
@@ -341,20 +351,19 @@ class PasswordResetAdmin(admin.ModelAdmin):
) )
@admin.display(description="Status", boolean=True) @admin.display(description="Status", boolean=True)
def is_expired(self, obj): def is_expired(self, obj: PasswordReset) -> str:
"""Display expiration status with color coding"""
from django.utils import timezone from django.utils import timezone
if obj.used: if getattr(obj, "used", False):
return format_html('<span style="color: blue;">Used</span>') return format_html('<span style="color: blue;">{}</span>', "Used")
elif timezone.now() > obj.expires_at: elif timezone.now() > getattr(obj, "expires_at", timezone.now()):
return format_html('<span style="color: red;">Expired</span>') return format_html('<span style="color: red;">{}</span>', "Expired")
return format_html('<span style="color: green;">Valid</span>') return format_html('<span style="color: green;">{}</span>', "Valid")
def has_add_permission(self, request): def has_add_permission(self, request: HttpRequest) -> bool:
"""Disable manual creation of password reset tokens""" """Disable manual creation of password reset tokens"""
return False return False
def has_change_permission(self, request, obj=None): def has_change_permission(self, request: HttpRequest, obj: Any = None) -> bool:
"""Allow viewing but restrict editing of password reset tokens""" """Allow viewing but restrict editing of password reset tokens"""
return getattr(request.user, "is_superuser", False) return getattr(request.user, "is_superuser", False)

View File

@@ -7,6 +7,7 @@ from datetime import timedelta
import sys import sys
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import List
from decouple import config from decouple import config
# Suppress django-allauth deprecation warnings for dj_rest_auth compatibility # Suppress django-allauth deprecation warnings for dj_rest_auth compatibility
@@ -19,14 +20,14 @@ warnings.filterwarnings(
# Initialize environment variables with better defaults # Initialize environment variables with better defaults
DEBUG = config("DEBUG", default=True) DEBUG = config("DEBUG", default=True, cast=bool)
SECRET_KEY = config("SECRET_KEY") SECRET_KEY = config("SECRET_KEY")
ALLOWED_HOSTS = config("ALLOWED_HOSTS", default="localhost,127.0.0.1", cast=lambda v: [s.strip() for s in v.split(',') if s.strip()]) ALLOWED_HOSTS = config("ALLOWED_HOSTS", default="localhost,127.0.0.1", cast=lambda v: [s.strip() for s in str(v).split(',') if s.strip()])
DATABASE_URL = config("DATABASE_URL") DATABASE_URL = config("DATABASE_URL")
CACHE_URL = config("CACHE_URL", default="locmem://") CACHE_URL = config("CACHE_URL", default="locmem://")
EMAIL_URL = config("EMAIL_URL", default="console://") EMAIL_URL = config("EMAIL_URL", default="console://")
REDIS_URL = config("REDIS_URL", default="redis://127.0.0.1:6379/1") REDIS_URL = config("REDIS_URL", default="redis://127.0.0.1:6379/1")
CORS_ALLOWED_ORIGINS = config("CORS_ALLOWED_ORIGINS", default="", cast=lambda v: [s.strip() for s in v.split(',') if s.strip()]) CORS_ALLOWED_ORIGINS = config("CORS_ALLOWED_ORIGINS", default="", cast=lambda v: [s.strip() for s in str(v).split(',') if s.strip()])
API_RATE_LIMIT_PER_MINUTE = config("API_RATE_LIMIT_PER_MINUTE", default=60) API_RATE_LIMIT_PER_MINUTE = config("API_RATE_LIMIT_PER_MINUTE", default=60)
API_RATE_LIMIT_PER_HOUR = config("API_RATE_LIMIT_PER_HOUR", default=1000) API_RATE_LIMIT_PER_HOUR = config("API_RATE_LIMIT_PER_HOUR", default=1000)
CACHE_MIDDLEWARE_SECONDS = config("CACHE_MIDDLEWARE_SECONDS", default=300) CACHE_MIDDLEWARE_SECONDS = config("CACHE_MIDDLEWARE_SECONDS", default=300)
@@ -55,7 +56,7 @@ SECRET_KEY = config("SECRET_KEY")
# CSRF trusted origins # CSRF trusted origins
CSRF_TRUSTED_ORIGINS = config( CSRF_TRUSTED_ORIGINS = config(
"CSRF_TRUSTED_ORIGINS", default="", cast=lambda v: [s.strip() for s in v.split(',') if s.strip()] "CSRF_TRUSTED_ORIGINS", default="", cast=lambda v: [s.strip() for s in str(v).split(',') if s.strip()]
) )
# Application definition # Application definition

View File

@@ -76,10 +76,33 @@ dev = [
[tool.pyright] [tool.pyright]
stubPath = "stubs" stubPath = "stubs"
typeCheckingMode = "basic" include = ["."]
exclude = [
"**/node_modules",
"**/__pycache__",
"**/migrations",
"**/.venv",
"**/venv",
"**/.git",
"**/.hg",
"**/.tox",
"**/.nox",
]
typeCheckingMode = "strict"
reportIncompatibleMethodOverride = "error"
reportIncompatibleVariableOverride = "error"
reportGeneralTypeIssues = "error"
reportReturnType = "error"
reportMissingImports = "error"
reportMissingTypeStubs = "warning"
reportUndefinedVariable = "error"
reportUnusedImport = "warning"
reportUnusedVariable = "warning"
pythonVersion = "3.13"
[tool.pylance] [tool.pylance]
stubPath = "stubs" stubPath = "stubs"
[tool.uv.sources] [tool.uv.sources]
python-json-logger = { url = "https://github.com/nhairs/python-json-logger/releases/download/v3.0.0/python_json_logger-3.0.0-py3-none-any.whl" } python-json-logger = { url = "https://github.com/nhairs/python-json-logger/releases/download/v3.0.0/python_json_logger-3.0.0-py3-none-any.whl" }

View File

@@ -1,22 +0,0 @@
{
"include": [
"."
],
"exclude": [
"**/node_modules",
"**/__pycache__",
"**/migrations"
],
"stubPath": "stubs",
"typeCheckingMode": "strict",
"reportIncompatibleMethodOverride": "error",
"reportIncompatibleVariableOverride": "error",
"reportGeneralTypeIssues": "error",
"reportReturnType": "error",
"reportMissingImports": "error",
"reportMissingTypeStubs": "warning",
"reportUndefinedVariable": "error",
"reportUnusedImport": "warning",
"reportUnusedVariable": "warning",
"pythonVersion": "3.13"
}