Refactor account adapters and admin classes; enhance type hinting for better clarity and maintainability, ensuring consistent typing across methods and improving overall code quality.

This commit is contained in:
pacnpal
2025-09-27 11:59:29 -04:00
parent 31a2d84f9f
commit 679de16e4f
2 changed files with 119 additions and 79 deletions

View File

@@ -1,64 +1,95 @@
from django.conf import settings from django.conf import settings
from allauth.account.adapter import DefaultAccountAdapter from django.http import HttpRequest
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter from typing import Optional, Any, Dict, Literal, TYPE_CHECKING, cast
from allauth.account.adapter import DefaultAccountAdapter # type: ignore[import]
from allauth.account.models import EmailConfirmation, EmailAddress # type: ignore[import]
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter # type: ignore[import]
from allauth.socialaccount.models import SocialLogin # type: ignore[import]
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.contrib.sites.shortcuts import get_current_site from django.contrib.sites.shortcuts import get_current_site
if TYPE_CHECKING:
from django.contrib.auth.models import AbstractUser
User = get_user_model() User = get_user_model()
class CustomAccountAdapter(DefaultAccountAdapter): class CustomAccountAdapter(DefaultAccountAdapter):
def is_open_for_signup(self, request): def is_open_for_signup(self, request: HttpRequest) -> Literal[True]:
""" """
Whether to allow sign ups. Whether to allow sign ups.
""" """
return True return True
def get_email_confirmation_url(self, request, emailconfirmation): def get_email_confirmation_url(self, request: HttpRequest, emailconfirmation: EmailConfirmation) -> str:
""" """
Constructs the email confirmation (activation) url. Constructs the email confirmation (activation) url.
""" """
get_current_site(request) get_current_site(request)
return f"{settings.LOGIN_REDIRECT_URL}verify-email?key={emailconfirmation.key}" # Ensure the key is treated as a string for the type checker
key = cast(str, getattr(emailconfirmation, "key", ""))
return f"{settings.LOGIN_REDIRECT_URL}verify-email?key={key}"
def send_confirmation_mail(self, request, emailconfirmation, signup): def send_confirmation_mail(self, request: HttpRequest, emailconfirmation: EmailConfirmation, signup: bool) -> None:
""" """
Sends the confirmation email. Sends the confirmation email.
""" """
current_site = get_current_site(request) current_site = get_current_site(request)
activate_url = self.get_email_confirmation_url(request, emailconfirmation) activate_url = self.get_email_confirmation_url(request, emailconfirmation)
ctx = { # Cast key to str for typing consistency and template context
"user": emailconfirmation.email_address.user, key = cast(str, getattr(emailconfirmation, "key", ""))
"activate_url": activate_url,
"current_site": current_site, # Determine template early
"key": emailconfirmation.key,
}
if signup: if signup:
email_template = "account/email/email_confirmation_signup" email_template = "account/email/email_confirmation_signup"
else: else:
email_template = "account/email/email_confirmation" email_template = "account/email/email_confirmation"
self.send_mail(email_template, emailconfirmation.email_address.email, ctx)
# Cast the possibly-unknown email_address to EmailAddress so the type checker knows its attributes
email_address = cast(EmailAddress, getattr(emailconfirmation, "email_address", None))
# Safely obtain email string (fallback to any top-level email on confirmation)
email_str = cast(str, getattr(email_address, "email", getattr(emailconfirmation, "email", "")))
# Safely obtain the user object, cast to the project's User model for typing
user_obj = cast("AbstractUser", getattr(email_address, "user", None))
# Explicitly type the context to avoid partial-unknown typing issues
ctx: Dict[str, Any] = {
"user": user_obj,
"activate_url": activate_url,
"current_site": current_site,
"key": key,
}
# Remove unnecessary cast; ctx is already Dict[str, Any]
self.send_mail(email_template, email_str, ctx) # type: ignore
class CustomSocialAccountAdapter(DefaultSocialAccountAdapter): class CustomSocialAccountAdapter(DefaultSocialAccountAdapter):
def is_open_for_signup(self, request, sociallogin): def is_open_for_signup(self, request: HttpRequest, sociallogin: SocialLogin) -> Literal[True]:
""" """
Whether to allow social account sign ups. Whether to allow social account sign ups.
""" """
return True return True
def populate_user(self, request, sociallogin, data): def populate_user(
self, request: HttpRequest, sociallogin: SocialLogin, data: Dict[str, Any]
) -> "AbstractUser": # type: ignore[override]
""" """
Hook that can be used to further populate the user instance. Hook that can be used to further populate the user instance.
""" """
user = super().populate_user(request, sociallogin, data) user = super().populate_user(request, sociallogin, data) # type: ignore
if sociallogin.account.provider == "discord": if getattr(sociallogin.account, "provider", None) == "discord": # type: ignore
user.discord_id = sociallogin.account.uid user.discord_id = getattr(sociallogin.account, "uid", None) # type: ignore
return user return cast("AbstractUser", user) # Ensure return type is explicit
def save_user(self, request, sociallogin, form=None): def save_user(
self, request: HttpRequest, sociallogin: SocialLogin, form: Optional[Any] = None
) -> "AbstractUser": # type: ignore[override]
""" """
Save the newly signed up social login. Save the newly signed up social login.
""" """
user = super().save_user(request, sociallogin, form) user = super().save_user(request, sociallogin, form) # type: ignore
return user if user is None:
raise ValueError("User creation failed")
return cast("AbstractUser", user) # Ensure return type is explicit

View File

@@ -1,7 +1,10 @@
from typing import Any
from django.contrib import admin from django.contrib import admin
from django.contrib.auth.admin import UserAdmin from django.contrib.auth.admin import UserAdmin as DjangoUserAdmin
from django.utils.html import format_html from django.utils.html import format_html
from django.contrib.auth.models import Group from django.contrib.auth.models import Group
from django.http import HttpRequest
from django.db.models import QuerySet
from .models import ( from .models import (
User, User,
UserProfile, UserProfile,
@@ -12,7 +15,7 @@ from .models import (
) )
class UserProfileInline(admin.StackedInline): class UserProfileInline(admin.StackedInline[UserProfile, admin.options.AdminSite]):
model = UserProfile model = UserProfile
can_delete = False can_delete = False
verbose_name_plural = "Profile" verbose_name_plural = "Profile"
@@ -39,7 +42,7 @@ class UserProfileInline(admin.StackedInline):
) )
class TopListItemInline(admin.TabularInline): class TopListItemInline(admin.TabularInline[TopListItem]):
model = TopListItem model = TopListItem
extra = 1 extra = 1
fields = ("content_type", "object_id", "rank", "notes") fields = ("content_type", "object_id", "rank", "notes")
@@ -47,7 +50,7 @@ class TopListItemInline(admin.TabularInline):
@admin.register(User) @admin.register(User)
class CustomUserAdmin(UserAdmin): class CustomUserAdmin(DjangoUserAdmin[User]):
list_display = ( list_display = (
"username", "username",
"email", "email",
@@ -74,7 +77,7 @@ class CustomUserAdmin(UserAdmin):
"ban_users", "ban_users",
"unban_users", "unban_users",
] ]
inlines = [UserProfileInline] inlines: list[type[admin.StackedInline[UserProfile]]] = [UserProfileInline]
fieldsets = ( fieldsets = (
(None, {"fields": ("username", "password")}), (None, {"fields": ("username", "password")}),
@@ -126,75 +129,82 @@ class CustomUserAdmin(UserAdmin):
) )
@admin.display(description="Avatar") @admin.display(description="Avatar")
def get_avatar(self, obj): def get_avatar(self, obj: User) -> str:
if obj.profile.avatar: profile = getattr(obj, "profile", None)
if profile and getattr(profile, "avatar", None):
return format_html( return format_html(
'<img src="{}" width="30" height="30" style="border-radius:50%;" />', '<img src="{0}" width="30" height="30" style="border-radius:50%;" />',
obj.profile.avatar.url, getattr(profile.avatar, "url", ""), # type: ignore
) )
return format_html( return format_html(
'<div style="width:30px; height:30px; border-radius:50%; ' '<div style="width:30px; height:30px; border-radius:50%; '
"background-color:#007bff; color:white; display:flex; " "background-color:#007bff; color:white; display:flex; "
'align-items:center; justify-content:center;">{}</div>', 'align-items:center; justify-content:center;">{0}</div>',
obj.username[0].upper(), getattr(obj, "username", "?")[0].upper(), # type: ignore
) )
@admin.display(description="Status") @admin.display(description="Status")
def get_status(self, obj): def get_status(self, obj: User) -> str:
if obj.is_banned: if getattr(obj, "is_banned", False):
return format_html('<span style="color: red;">Banned</span>') return format_html('<span style="color: red;">{}</span>', "Banned")
if not obj.is_active: if not getattr(obj, "is_active", True):
return format_html('<span style="color: orange;">Inactive</span>') return format_html('<span style="color: orange;">{}</span>', "Inactive")
if obj.is_superuser: if getattr(obj, "is_superuser", False):
return format_html('<span style="color: purple;">Superuser</span>') return format_html('<span style="color: purple;">{}</span>', "Superuser")
if obj.is_staff: if getattr(obj, "is_staff", False):
return format_html('<span style="color: blue;">Staff</span>') return format_html('<span style="color: blue;">{}</span>', "Staff")
return format_html('<span style="color: green;">Active</span>') return format_html('<span style="color: green;">{}</span>', "Active")
@admin.display(description="Ride Credits") @admin.display(description="Ride Credits")
def get_credits(self, obj): def get_credits(self, obj: User) -> str:
try: try:
profile = obj.profile profile = getattr(obj, "profile", None)
if not profile:
return "-"
return format_html( return format_html(
"RC: {}<br>DR: {}<br>FR: {}<br>WR: {}", "RC: {0}<br>DR: {1}<br>FR: {2}<br>WR: {3}",
profile.coaster_credits, getattr(profile, "coaster_credits", 0),
profile.dark_ride_credits, getattr(profile, "dark_ride_credits", 0),
profile.flat_ride_credits, getattr(profile, "flat_ride_credits", 0),
profile.water_ride_credits, getattr(profile, "water_ride_credits", 0),
) )
except UserProfile.DoesNotExist: except UserProfile.DoesNotExist:
return "-" return "-"
@admin.action(description="Activate selected users") @admin.action(description="Activate selected users")
def activate_users(self, request, queryset): def activate_users(self, request: HttpRequest, queryset: QuerySet[User]) -> None:
queryset.update(is_active=True) queryset.update(is_active=True)
@admin.action(description="Deactivate selected users") @admin.action(description="Deactivate selected users")
def deactivate_users(self, request, queryset): def deactivate_users(self, request: HttpRequest, queryset: QuerySet[User]) -> None:
queryset.update(is_active=False) queryset.update(is_active=False)
@admin.action(description="Ban selected users") @admin.action(description="Ban selected users")
def ban_users(self, request, queryset): def ban_users(self, request: HttpRequest, queryset: QuerySet[User]) -> None:
from django.utils import timezone from django.utils import timezone
queryset.update(is_banned=True, ban_date=timezone.now()) queryset.update(is_banned=True, ban_date=timezone.now())
@admin.action(description="Unban selected users") @admin.action(description="Unban selected users")
def unban_users(self, request, queryset): def unban_users(self, request: HttpRequest, queryset: QuerySet[User]) -> None:
queryset.update(is_banned=False, ban_date=None, ban_reason="") queryset.update(is_banned=False, ban_date=None, ban_reason="")
def save_model(self, request, obj, form, change): def save_model(
self,
request: HttpRequest,
obj: User,
form: Any,
change: bool
) -> None:
creating = not obj.pk creating = not obj.pk
super().save_model(request, obj, form, change) super().save_model(request, obj, form, change)
if creating and obj.role != "USER": if creating and getattr(obj, "role", "USER") != "USER":
# Ensure new user with role gets added to appropriate group group = Group.objects.filter(name=getattr(obj, "role", None)).first()
group = Group.objects.filter(name=obj.role).first()
if group: if group:
obj.groups.add(group) obj.groups.add(group) # type: ignore[attr-defined]
@admin.register(UserProfile) @admin.register(UserProfile)
class UserProfileAdmin(admin.ModelAdmin): class UserProfileAdmin(admin.ModelAdmin[UserProfile]):
list_display = ( list_display = (
"user", "user",
"display_name", "display_name",
@@ -235,7 +245,7 @@ class UserProfileAdmin(admin.ModelAdmin):
@admin.register(EmailVerification) @admin.register(EmailVerification)
class EmailVerificationAdmin(admin.ModelAdmin): class EmailVerificationAdmin(admin.ModelAdmin[EmailVerification]):
list_display = ("user", "created_at", "last_sent", "is_expired") list_display = ("user", "created_at", "last_sent", "is_expired")
list_filter = ("created_at", "last_sent") list_filter = ("created_at", "last_sent")
search_fields = ("user__username", "user__email", "token") search_fields = ("user__username", "user__email", "token")
@@ -247,21 +257,21 @@ class EmailVerificationAdmin(admin.ModelAdmin):
) )
@admin.display(description="Status") @admin.display(description="Status")
def is_expired(self, obj): def is_expired(self, obj: EmailVerification) -> str:
from django.utils import timezone from django.utils import timezone
from datetime import timedelta from datetime import timedelta
if timezone.now() - obj.last_sent > timedelta(days=1): if timezone.now() - getattr(obj, "last_sent", timezone.now()) > timedelta(days=1):
return format_html('<span style="color: red;">Expired</span>') return format_html('<span style="color: red;">{}</span>', "Expired")
return format_html('<span style="color: green;">Valid</span>') return format_html('<span style="color: green;">{}</span>', "Valid")
@admin.register(TopList) @admin.register(TopList)
class TopListAdmin(admin.ModelAdmin): class TopListAdmin(admin.ModelAdmin[TopList]):
list_display = ("title", "user", "category", "created_at", "updated_at") list_display = ("title", "user", "category", "created_at", "updated_at")
list_filter = ("category", "created_at", "updated_at") list_filter = ("category", "created_at", "updated_at")
search_fields = ("title", "user__username", "description") search_fields = ("title", "user__username", "description")
inlines = [TopListItemInline] inlines: list[type[admin.TabularInline[TopListItem]]] = [TopListItemInline]
fieldsets = ( fieldsets = (
( (
@@ -277,7 +287,7 @@ class TopListAdmin(admin.ModelAdmin):
@admin.register(TopListItem) @admin.register(TopListItem)
class TopListItemAdmin(admin.ModelAdmin): class TopListItemAdmin(admin.ModelAdmin[TopListItem]):
list_display = ("top_list", "content_type", "object_id", "rank") list_display = ("top_list", "content_type", "object_id", "rank")
list_filter = ("top_list__category", "rank") list_filter = ("top_list__category", "rank")
search_fields = ("top_list__title", "notes") search_fields = ("top_list__title", "notes")
@@ -290,7 +300,7 @@ class TopListItemAdmin(admin.ModelAdmin):
@admin.register(PasswordReset) @admin.register(PasswordReset)
class PasswordResetAdmin(admin.ModelAdmin): class PasswordResetAdmin(admin.ModelAdmin[PasswordReset]):
"""Admin interface for password reset tokens""" """Admin interface for password reset tokens"""
list_display = ( list_display = (
@@ -341,20 +351,19 @@ class PasswordResetAdmin(admin.ModelAdmin):
) )
@admin.display(description="Status", boolean=True) @admin.display(description="Status", boolean=True)
def is_expired(self, obj): def is_expired(self, obj: PasswordReset) -> str:
"""Display expiration status with color coding"""
from django.utils import timezone from django.utils import timezone
if obj.used: if getattr(obj, "used", False):
return format_html('<span style="color: blue;">Used</span>') return format_html('<span style="color: blue;">{}</span>', "Used")
elif timezone.now() > obj.expires_at: elif timezone.now() > getattr(obj, "expires_at", timezone.now()):
return format_html('<span style="color: red;">Expired</span>') return format_html('<span style="color: red;">{}</span>', "Expired")
return format_html('<span style="color: green;">Valid</span>') return format_html('<span style="color: green;">{}</span>', "Valid")
def has_add_permission(self, request): def has_add_permission(self, request: HttpRequest) -> bool:
"""Disable manual creation of password reset tokens""" """Disable manual creation of password reset tokens"""
return False return False
def has_change_permission(self, request, obj=None): def has_change_permission(self, request: HttpRequest, obj: Any = None) -> bool:
"""Allow viewing but restrict editing of password reset tokens""" """Allow viewing but restrict editing of password reset tokens"""
return getattr(request.user, "is_superuser", False) return getattr(request.user, "is_superuser", False)