mirror of
https://github.com/pacnpal/thrillwiki_django_no_react.git
synced 2025-12-22 02:51:08 -05:00
first commit
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,218 @@
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.validators import validate_email
|
||||
|
||||
from allauth.account import app_settings as account_app_settings
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.forms import (
|
||||
AddEmailForm,
|
||||
BaseSignupForm,
|
||||
ConfirmLoginCodeForm,
|
||||
ReauthenticateForm,
|
||||
RequestLoginCodeForm,
|
||||
ResetPasswordForm,
|
||||
UserTokenForm,
|
||||
)
|
||||
from allauth.account.internal import flows
|
||||
from allauth.account.models import (
|
||||
EmailAddress,
|
||||
Login,
|
||||
get_emailconfirmation_model,
|
||||
)
|
||||
from allauth.core import context
|
||||
from allauth.headless.adapter import get_adapter
|
||||
from allauth.headless.internal.restkit import inputs
|
||||
|
||||
|
||||
class SignupInput(BaseSignupForm, inputs.Input):
|
||||
password = inputs.CharField()
|
||||
|
||||
def clean_password(self):
|
||||
password = self.cleaned_data["password"]
|
||||
return get_account_adapter().clean_password(password)
|
||||
|
||||
|
||||
class LoginInput(inputs.Input):
|
||||
username = inputs.CharField(required=False)
|
||||
email = inputs.EmailField(required=False)
|
||||
password = inputs.CharField()
|
||||
|
||||
def clean(self):
|
||||
cleaned_data = super().clean()
|
||||
username = None
|
||||
email = None
|
||||
|
||||
if (
|
||||
account_app_settings.AUTHENTICATION_METHOD
|
||||
== account_app_settings.AuthenticationMethod.USERNAME
|
||||
):
|
||||
username = cleaned_data.get("username")
|
||||
missing_field = "username"
|
||||
elif (
|
||||
account_app_settings.AUTHENTICATION_METHOD
|
||||
== account_app_settings.AuthenticationMethod.EMAIL
|
||||
):
|
||||
email = cleaned_data.get("email")
|
||||
missing_field = "email"
|
||||
elif (
|
||||
account_app_settings.AUTHENTICATION_METHOD
|
||||
== account_app_settings.AuthenticationMethod.USERNAME_EMAIL
|
||||
):
|
||||
username = cleaned_data.get("username")
|
||||
email = cleaned_data.get("email")
|
||||
missing_field = "email"
|
||||
if email and username:
|
||||
raise get_adapter().validation_error("email_or_username")
|
||||
else:
|
||||
raise ImproperlyConfigured(account_app_settings.AUTHENTICATION_METHOD)
|
||||
|
||||
if not email and not username:
|
||||
if not self.errors.get(missing_field):
|
||||
self.add_error(
|
||||
missing_field, get_adapter().validation_error("required")
|
||||
)
|
||||
|
||||
password = cleaned_data.get("password")
|
||||
if password and (username or email):
|
||||
credentials = {"password": password}
|
||||
if email:
|
||||
credentials["email"] = email
|
||||
auth_method = account_app_settings.AuthenticationMethod.EMAIL
|
||||
else:
|
||||
credentials["username"] = username
|
||||
auth_method = account_app_settings.AuthenticationMethod.USERNAME
|
||||
user = get_account_adapter().authenticate(context.request, **credentials)
|
||||
if user:
|
||||
self.login = Login(user=user, email=credentials.get("email"))
|
||||
if flows.login.is_login_rate_limited(context.request, self.login):
|
||||
raise get_account_adapter().validation_error(
|
||||
"too_many_login_attempts"
|
||||
)
|
||||
else:
|
||||
error_code = "%s_password_mismatch" % auth_method.value
|
||||
self.add_error(
|
||||
"password", get_account_adapter().validation_error(error_code)
|
||||
)
|
||||
return cleaned_data
|
||||
|
||||
|
||||
class VerifyEmailInput(inputs.Input):
|
||||
key = inputs.CharField()
|
||||
|
||||
def clean_key(self):
|
||||
key = self.cleaned_data["key"]
|
||||
model = get_emailconfirmation_model()
|
||||
confirmation = model.from_key(key)
|
||||
valid = confirmation and not confirmation.key_expired()
|
||||
if not valid:
|
||||
raise get_account_adapter().validation_error(
|
||||
"incorrect_code"
|
||||
if account_app_settings.EMAIL_VERIFICATION_BY_CODE_ENABLED
|
||||
else "invalid_or_expired_key"
|
||||
)
|
||||
if valid and not confirmation.email_address.can_set_verified():
|
||||
raise get_account_adapter().validation_error("email_taken")
|
||||
return confirmation
|
||||
|
||||
|
||||
class RequestPasswordResetInput(ResetPasswordForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class ResetPasswordKeyInput(inputs.Input):
|
||||
key = inputs.CharField()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.user = None
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def clean_key(self):
|
||||
key = self.cleaned_data["key"]
|
||||
uidb36, _, subkey = key.partition("-")
|
||||
token_form = UserTokenForm(data={"uidb36": uidb36, "key": subkey})
|
||||
if not token_form.is_valid():
|
||||
raise get_account_adapter().validation_error("invalid_password_reset")
|
||||
self.user = token_form.reset_user
|
||||
return key
|
||||
|
||||
|
||||
class ResetPasswordInput(ResetPasswordKeyInput):
|
||||
password = inputs.CharField()
|
||||
|
||||
def clean(self):
|
||||
cleaned_data = super().clean()
|
||||
password = self.cleaned_data.get("password")
|
||||
if self.user and password is not None:
|
||||
get_account_adapter().clean_password(password, user=self.user)
|
||||
return cleaned_data
|
||||
|
||||
|
||||
class ChangePasswordInput(inputs.Input):
|
||||
current_password = inputs.CharField(required=False)
|
||||
new_password = inputs.CharField()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.user = kwargs.pop("user")
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields["current_password"].required = self.user.has_usable_password()
|
||||
|
||||
def clean_current_password(self):
|
||||
current_password = self.cleaned_data["current_password"]
|
||||
if current_password:
|
||||
if not self.user.check_password(current_password):
|
||||
raise get_account_adapter().validation_error("enter_current_password")
|
||||
return current_password
|
||||
|
||||
def clean_new_password(self):
|
||||
new_password = self.cleaned_data["new_password"]
|
||||
adapter = get_account_adapter()
|
||||
return adapter.clean_password(new_password, user=self.user)
|
||||
|
||||
|
||||
class AddEmailInput(AddEmailForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class SelectEmailInput(inputs.Input):
|
||||
email = inputs.CharField()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.user = kwargs.pop("user")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def clean_email(self):
|
||||
email = self.cleaned_data["email"]
|
||||
validate_email(email)
|
||||
try:
|
||||
return EmailAddress.objects.get_for_user(user=self.user, email=email)
|
||||
except EmailAddress.DoesNotExist:
|
||||
raise get_adapter().validation_error("unknown_email")
|
||||
|
||||
|
||||
class DeleteEmailInput(SelectEmailInput):
|
||||
def clean_email(self):
|
||||
email = super().clean_email()
|
||||
if not flows.manage_email.can_delete_email(email):
|
||||
raise get_account_adapter().validation_error("cannot_remove_primary_email")
|
||||
return email
|
||||
|
||||
|
||||
class MarkAsPrimaryEmailInput(SelectEmailInput):
|
||||
primary = inputs.BooleanField(required=True)
|
||||
|
||||
def clean_email(self):
|
||||
email = super().clean_email()
|
||||
if not flows.manage_email.can_mark_as_primary(email):
|
||||
raise get_account_adapter().validation_error("unverified_primary_email")
|
||||
return email
|
||||
|
||||
|
||||
class ReauthenticateInput(ReauthenticateForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class RequestLoginCodeInput(RequestLoginCodeForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class ConfirmLoginCodeInput(ConfirmLoginCodeForm, inputs.Input):
|
||||
pass
|
||||
@@ -0,0 +1,44 @@
|
||||
from allauth.headless.adapter import get_adapter
|
||||
from allauth.headless.base.response import APIResponse
|
||||
|
||||
|
||||
class RequestEmailVerificationResponse(APIResponse):
|
||||
def __init__(self, request, verification_sent):
|
||||
super().__init__(request, status=200 if verification_sent else 403)
|
||||
|
||||
|
||||
class VerifyEmailResponse(APIResponse):
|
||||
def __init__(self, request, verification, stage):
|
||||
adapter = get_adapter()
|
||||
data = {
|
||||
"email": verification.email_address.email,
|
||||
"user": adapter.serialize_user(verification.email_address.user),
|
||||
}
|
||||
meta = {
|
||||
"is_authenticating": stage is not None,
|
||||
}
|
||||
super().__init__(request, data=data, meta=meta)
|
||||
|
||||
|
||||
class EmailAddressesResponse(APIResponse):
|
||||
def __init__(self, request, email_addresses):
|
||||
data = [
|
||||
{
|
||||
"email": addr.email,
|
||||
"verified": addr.verified,
|
||||
"primary": addr.primary,
|
||||
}
|
||||
for addr in email_addresses
|
||||
]
|
||||
super().__init__(request, data=data)
|
||||
|
||||
|
||||
class RequestPasswordResponse(APIResponse):
|
||||
pass
|
||||
|
||||
|
||||
class PasswordResetKeyResponse(APIResponse):
|
||||
def __init__(self, request, user):
|
||||
adapter = get_adapter()
|
||||
data = {"user": adapter.serialize_user(user)}
|
||||
super().__init__(request, data=data)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,99 @@
|
||||
from allauth.account.models import EmailAddress
|
||||
|
||||
|
||||
def test_list_email(auth_client, user, headless_reverse):
|
||||
resp = auth_client.get(
|
||||
headless_reverse("headless:account:manage_email"),
|
||||
)
|
||||
assert len(resp.json()["data"]) == 1
|
||||
|
||||
|
||||
def test_remove_email(auth_client, user, email_factory, headless_reverse):
|
||||
addr = EmailAddress.objects.create(email=email_factory(), user=user)
|
||||
assert EmailAddress.objects.filter(user=user).count() == 2
|
||||
resp = auth_client.delete(
|
||||
headless_reverse("headless:account:manage_email"),
|
||||
data={"email": addr.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["data"]) == 1
|
||||
assert not EmailAddress.objects.filter(pk=addr.pk).exists()
|
||||
|
||||
|
||||
def test_add_email(auth_client, user, email_factory, headless_reverse):
|
||||
new_email = email_factory()
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:manage_email"),
|
||||
data={"email": new_email},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["data"]) == 2
|
||||
assert EmailAddress.objects.filter(email=new_email, verified=False).exists()
|
||||
|
||||
|
||||
def test_change_primary(auth_client, user, email_factory, headless_reverse):
|
||||
addr = EmailAddress.objects.create(
|
||||
email=email_factory(), user=user, verified=True, primary=False
|
||||
)
|
||||
resp = auth_client.patch(
|
||||
headless_reverse("headless:account:manage_email"),
|
||||
data={"email": addr.email, "primary": True},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["data"]) == 2
|
||||
assert EmailAddress.objects.filter(pk=addr.pk, primary=True).exists()
|
||||
|
||||
|
||||
def test_resend_verification(
|
||||
auth_client, user, email_factory, headless_reverse, mailoutbox
|
||||
):
|
||||
addr = EmailAddress.objects.create(email=email_factory(), user=user, verified=False)
|
||||
resp = auth_client.put(
|
||||
headless_reverse("headless:account:manage_email"),
|
||||
data={"email": addr.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(mailoutbox) == 1
|
||||
|
||||
|
||||
def test_email_rate_limit(
|
||||
auth_client, user, email_factory, headless_reverse, settings, enable_cache
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"manage_email": "1/m/ip"}
|
||||
for attempt in range(2):
|
||||
new_email = email_factory()
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:manage_email"),
|
||||
data={"email": new_email},
|
||||
content_type="application/json",
|
||||
)
|
||||
expected_status = 200 if attempt == 0 else 429
|
||||
assert resp.status_code == expected_status
|
||||
assert resp.json()["status"] == expected_status
|
||||
|
||||
|
||||
def test_resend_verification_rate_limit(
|
||||
auth_client,
|
||||
user,
|
||||
email_factory,
|
||||
headless_reverse,
|
||||
settings,
|
||||
enable_cache,
|
||||
mailoutbox,
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"confirm_email": "1/m/ip"}
|
||||
for attempt in range(2):
|
||||
addr = EmailAddress.objects.create(
|
||||
email=email_factory(), user=user, verified=False
|
||||
)
|
||||
resp = auth_client.put(
|
||||
headless_reverse("headless:account:manage_email"),
|
||||
data={"email": addr.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 403 if attempt else 200
|
||||
assert len(mailoutbox) == 1
|
||||
@@ -0,0 +1,220 @@
|
||||
import copy
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"has_password,request_data,response_data,status_code",
|
||||
[
|
||||
# Wrong current password
|
||||
(
|
||||
True,
|
||||
{"current_password": "wrong", "new_password": "{password_factory}"},
|
||||
{
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"param": "current_password",
|
||||
"message": "Please type your current password.",
|
||||
"code": "enter_current_password",
|
||||
}
|
||||
],
|
||||
},
|
||||
400,
|
||||
),
|
||||
# Happy flow, regular password change
|
||||
(
|
||||
True,
|
||||
{
|
||||
"current_password": "{user_password}",
|
||||
"new_password": "{password_factory}",
|
||||
},
|
||||
{
|
||||
"status": 200,
|
||||
"meta": {"is_authenticated": True},
|
||||
"data": {
|
||||
"user": ANY,
|
||||
"methods": [],
|
||||
},
|
||||
},
|
||||
200,
|
||||
),
|
||||
# New password does not match constraints
|
||||
(
|
||||
True,
|
||||
{
|
||||
"current_password": "{user_password}",
|
||||
"new_password": "a",
|
||||
},
|
||||
{
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"param": "new_password",
|
||||
"code": "password_too_short",
|
||||
"message": "This password is too short. It must contain at least 6 characters.",
|
||||
}
|
||||
],
|
||||
},
|
||||
400,
|
||||
),
|
||||
# New password not empty
|
||||
(
|
||||
True,
|
||||
{
|
||||
"current_password": "{user_password}",
|
||||
"new_password": "",
|
||||
},
|
||||
{
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"param": "new_password",
|
||||
"code": "required",
|
||||
"message": "This field is required.",
|
||||
}
|
||||
],
|
||||
},
|
||||
400,
|
||||
),
|
||||
# Current password not blank
|
||||
(
|
||||
True,
|
||||
{
|
||||
"current_password": "",
|
||||
"new_password": "{password_factory}",
|
||||
},
|
||||
{
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"param": "current_password",
|
||||
"message": "This field is required.",
|
||||
"code": "required",
|
||||
}
|
||||
],
|
||||
},
|
||||
400,
|
||||
),
|
||||
# Current password missing
|
||||
(
|
||||
True,
|
||||
{
|
||||
"new_password": "{password_factory}",
|
||||
},
|
||||
{
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"param": "current_password",
|
||||
"message": "This field is required.",
|
||||
"code": "required",
|
||||
}
|
||||
],
|
||||
},
|
||||
400,
|
||||
),
|
||||
# Current password not set, happy flow
|
||||
(
|
||||
False,
|
||||
{
|
||||
"current_password": "",
|
||||
"new_password": "{password_factory}",
|
||||
},
|
||||
{
|
||||
"status": 200,
|
||||
"meta": {"is_authenticated": True},
|
||||
"data": {
|
||||
"user": ANY,
|
||||
"methods": [],
|
||||
},
|
||||
},
|
||||
200,
|
||||
),
|
||||
# Current password not set, current_password absent
|
||||
(
|
||||
False,
|
||||
{
|
||||
"new_password": "{password_factory}",
|
||||
},
|
||||
{
|
||||
"status": 200,
|
||||
"meta": {"is_authenticated": True},
|
||||
"data": {
|
||||
"user": ANY,
|
||||
"methods": [],
|
||||
},
|
||||
},
|
||||
200,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_change_password(
|
||||
auth_client,
|
||||
user,
|
||||
request_data,
|
||||
response_data,
|
||||
status_code,
|
||||
has_password,
|
||||
user_password,
|
||||
password_factory,
|
||||
settings,
|
||||
mailoutbox,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
):
|
||||
request_data = copy.deepcopy(request_data)
|
||||
response_data = copy.deepcopy(response_data)
|
||||
settings.ACCOUNT_EMAIL_NOTIFICATIONS = True
|
||||
if not has_password:
|
||||
user.set_unusable_password()
|
||||
user.save(update_fields=["password"])
|
||||
auth_client.force_login(user)
|
||||
if request_data.get("current_password") == "{user_password}":
|
||||
request_data["current_password"] = user_password
|
||||
if request_data.get("new_password") == "{password_factory}":
|
||||
request_data["new_password"] = password_factory()
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:change_password"),
|
||||
data=request_data,
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == status_code
|
||||
resp_json = resp.json()
|
||||
if headless_client == "app" and resp.status_code == 200:
|
||||
response_data["meta"]["session_token"] = ANY
|
||||
assert resp_json == response_data
|
||||
user.refresh_from_db()
|
||||
if resp.status_code == 200:
|
||||
assert user.check_password(request_data["new_password"])
|
||||
assert len(mailoutbox) == 1
|
||||
else:
|
||||
assert user.check_password(user_password)
|
||||
assert len(mailoutbox) == 0
|
||||
|
||||
|
||||
def test_change_password_rate_limit(
|
||||
enable_cache,
|
||||
auth_client,
|
||||
user,
|
||||
user_password,
|
||||
password_factory,
|
||||
settings,
|
||||
headless_reverse,
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"change_password": "1/m/ip"}
|
||||
for attempt in range(2):
|
||||
new_password = password_factory()
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:change_password"),
|
||||
data={
|
||||
"current_password": user_password,
|
||||
"new_password": new_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
user_password = new_password
|
||||
expected_status = 200 if attempt == 0 else 429
|
||||
assert resp.status_code == expected_status
|
||||
assert resp.json()["status"] == expected_status
|
||||
@@ -0,0 +1,86 @@
|
||||
from allauth.account.models import (
|
||||
EmailAddress,
|
||||
EmailConfirmationHMAC,
|
||||
get_emailconfirmation_model,
|
||||
)
|
||||
from allauth.headless.constants import Flow
|
||||
|
||||
|
||||
def test_verify_email_other_user(auth_client, user, user_factory, headless_reverse):
|
||||
other_user = user_factory(email_verified=False)
|
||||
email_address = EmailAddress.objects.get(user=other_user, verified=False)
|
||||
assert not email_address.verified
|
||||
key = EmailConfirmationHMAC(email_address).key
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:verify_email"),
|
||||
data={"key": key},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
# We're still authenticated as the user originally logged in, not the
|
||||
# other_user.
|
||||
assert data["data"]["user"]["id"] == user.pk
|
||||
|
||||
|
||||
def test_auth_unverified_email(
|
||||
client, user_factory, password_factory, settings, headless_reverse
|
||||
):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
|
||||
password = password_factory()
|
||||
user = user_factory(email_verified=False, password=password)
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
flows = data["data"]["flows"]
|
||||
assert [f for f in flows if f["id"] == Flow.VERIFY_EMAIL][0]["is_pending"]
|
||||
emailaddress = EmailAddress.objects.filter(user=user, verified=False).get()
|
||||
key = get_emailconfirmation_model().create(emailaddress).key
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:verify_email"),
|
||||
data={"key": key},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_verify_email_bad_key(
|
||||
client, settings, password_factory, user_factory, headless_reverse
|
||||
):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
|
||||
password = password_factory()
|
||||
user = user_factory(email_verified=False, password=password)
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:verify_email"),
|
||||
data={"key": "bad"},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"code": "invalid_or_expired_key",
|
||||
"param": "key",
|
||||
"message": "Invalid or expired key.",
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
from allauth.headless.constants import Flow
|
||||
|
||||
|
||||
def test_email_verification_rate_limits(
|
||||
client,
|
||||
db,
|
||||
user_password,
|
||||
settings,
|
||||
user_factory,
|
||||
password_factory,
|
||||
enable_cache,
|
||||
headless_reverse,
|
||||
):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION_BY_CODE_ENABLED = True
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
|
||||
settings.ACCOUNT_RATE_LIMITS = {"confirm_email": "1/m/key"}
|
||||
email = "user@email.org"
|
||||
user_factory(email=email, email_verified=False, password=user_password)
|
||||
for attempt in range(2):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": email,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
if attempt == 0:
|
||||
assert resp.status_code == 401
|
||||
flow = [
|
||||
flow for flow in resp.json()["data"]["flows"] if flow.get("is_pending")
|
||||
][0]
|
||||
assert flow["id"] == Flow.VERIFY_EMAIL
|
||||
else:
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"message": "Too many failed login attempts. Try again later.",
|
||||
"code": "too_many_login_attempts",
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_auth_password_input_error(headless_reverse, client):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"message": "This field is required.",
|
||||
"code": "required",
|
||||
"param": "password",
|
||||
},
|
||||
{
|
||||
"message": "This field is required.",
|
||||
"code": "required",
|
||||
"param": "username",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_auth_password_bad_password(headless_reverse, client, user, settings):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": "wrong",
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"param": "password",
|
||||
"message": "The email address and/or password you specified are not correct.",
|
||||
"code": "email_password_mismatch",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_auth_password_success(
|
||||
client, user, user_password, settings, headless_reverse, headless_client
|
||||
):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
login_resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert login_resp.status_code == 200
|
||||
session_resp = client.get(
|
||||
headless_reverse("headless:account:current_session"),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert session_resp.status_code == 200
|
||||
for resp in [login_resp, session_resp]:
|
||||
extra_meta = {}
|
||||
if headless_client == "app" and resp == login_resp:
|
||||
# The session is created on first login, and hence the token is
|
||||
# exposed only at that moment.
|
||||
extra_meta["session_token"] = ANY
|
||||
assert resp.json() == {
|
||||
"status": 200,
|
||||
"data": {
|
||||
"user": {
|
||||
"id": user.pk,
|
||||
"display": str(user),
|
||||
"email": user.email,
|
||||
"username": user.username,
|
||||
"has_usable_password": True,
|
||||
},
|
||||
"methods": [
|
||||
{
|
||||
"at": ANY,
|
||||
"email": user.email,
|
||||
"method": "password",
|
||||
}
|
||||
],
|
||||
},
|
||||
"meta": {"is_authenticated": True, **extra_meta},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_active,status_code", [(False, 401), (True, 200)])
|
||||
def test_auth_password_user_inactive(
|
||||
client, user, user_password, settings, status_code, is_active, headless_reverse
|
||||
):
|
||||
user.is_active = is_active
|
||||
user.save(update_fields=["is_active"])
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"username": user.username,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == status_code
|
||||
|
||||
|
||||
def test_login_failed_rate_limit(
|
||||
client,
|
||||
user,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
enable_cache,
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"login_failed": "1/m/ip"}
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
for attempt in range(2):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": "wrong",
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json()["errors"] == [
|
||||
(
|
||||
{
|
||||
"code": "email_password_mismatch",
|
||||
"message": "The email address and/or password you specified are not correct.",
|
||||
"param": "password",
|
||||
}
|
||||
if attempt == 0
|
||||
else {
|
||||
"message": "Too many failed login attempts. Try again later.",
|
||||
"code": "too_many_login_attempts",
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_login_rate_limit(
|
||||
client,
|
||||
user,
|
||||
user_password,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
enable_cache,
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"login": "1/m/ip"}
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
for attempt in range(2):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
expected_status = 429 if attempt else 200
|
||||
assert resp.status_code == expected_status
|
||||
|
||||
|
||||
def test_login_already_logged_in(
|
||||
auth_client, user, user_password, settings, headless_reverse
|
||||
):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
@@ -0,0 +1,146 @@
|
||||
import time
|
||||
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.headless.constants import Flow
|
||||
|
||||
|
||||
def test_login_by_code(headless_reverse, user, client, mailoutbox):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:request_login_code"),
|
||||
data={"email": user.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
assert [f for f in data["data"]["flows"] if f["id"] == Flow.LOGIN_BY_CODE][0][
|
||||
"is_pending"
|
||||
]
|
||||
assert len(mailoutbox) == 1
|
||||
code = [line for line in mailoutbox[0].body.splitlines() if len(line) == 6][0]
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:confirm_login_code"),
|
||||
data={"code": code},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["meta"]["is_authenticated"]
|
||||
|
||||
|
||||
def test_login_by_code_rate_limit(
|
||||
headless_reverse, user, client, mailoutbox, settings, enable_cache
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"request_login_code": "1/m/ip"}
|
||||
for attempt in range(2):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:request_login_code"),
|
||||
data={"email": user.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
expected_code = 400 if attempt else 401
|
||||
assert resp.status_code == expected_code
|
||||
data = resp.json()
|
||||
assert data["status"] == expected_code
|
||||
if expected_code == 400:
|
||||
assert data["errors"] == [
|
||||
{
|
||||
"code": "too_many_login_attempts",
|
||||
"message": "Too many failed login attempts. Try again later.",
|
||||
"param": "email",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_login_by_code_max_attemps(headless_reverse, user, client, settings):
|
||||
settings.ACCOUNT_LOGIN_BY_CODE_MAX_ATTEMPTS = 2
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:request_login_code"),
|
||||
data={"email": user.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
for i in range(3):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:confirm_login_code"),
|
||||
data={"code": "wrong"},
|
||||
content_type="application/json",
|
||||
)
|
||||
session_resp = client.get(
|
||||
headless_reverse("headless:account:current_session"),
|
||||
data={"code": "wrong"},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert session_resp.status_code == 401
|
||||
pending_flows = [
|
||||
f for f in session_resp.json()["data"]["flows"] if f.get("is_pending")
|
||||
]
|
||||
if i >= 1:
|
||||
assert resp.status_code == 409 if i >= 2 else 400
|
||||
assert len(pending_flows) == 0
|
||||
else:
|
||||
assert resp.status_code == 400
|
||||
assert len(pending_flows) == 1
|
||||
|
||||
|
||||
def test_login_by_code_required(
|
||||
client, settings, user_factory, password_factory, headless_reverse, mailoutbox
|
||||
):
|
||||
settings.ACCOUNT_LOGIN_BY_CODE_REQUIRED = True
|
||||
password = password_factory()
|
||||
user = user_factory(password=password, email_verified=False)
|
||||
email_address = EmailAddress.objects.get(email=user.email)
|
||||
assert not email_address.verified
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"username": user.username,
|
||||
"password": password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
pending_flow = [f for f in resp.json()["data"]["flows"] if f.get("is_pending")][0][
|
||||
"id"
|
||||
]
|
||||
assert pending_flow == Flow.LOGIN_BY_CODE
|
||||
code = [line for line in mailoutbox[0].body.splitlines() if len(line) == 6][0]
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:confirm_login_code"),
|
||||
data={"code": code},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["meta"]["is_authenticated"]
|
||||
email_address.refresh_from_db()
|
||||
assert email_address.verified
|
||||
|
||||
|
||||
def test_login_by_code_expired(headless_reverse, user, client, mailoutbox):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:request_login_code"),
|
||||
data={"email": user.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
assert [f for f in data["data"]["flows"] if f["id"] == Flow.LOGIN_BY_CODE][0][
|
||||
"is_pending"
|
||||
]
|
||||
assert len(mailoutbox) == 1
|
||||
code = [line for line in mailoutbox[0].body.splitlines() if len(line) == 6][0]
|
||||
|
||||
# Expire code
|
||||
session = client.headless_session()
|
||||
login = session["account_login"]
|
||||
login["state"]["login_code"]["at"] = time.time() - 24 * 60 * 60
|
||||
session["account_login"] = login
|
||||
session.save()
|
||||
|
||||
# Post valid code
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:confirm_login_code"),
|
||||
data={"code": code},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
@@ -0,0 +1,52 @@
|
||||
def test_reauthenticate(
|
||||
auth_client, user, user_password, headless_reverse, headless_client
|
||||
):
|
||||
resp = auth_client.get(
|
||||
headless_reverse("headless:account:current_session"),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
method_count = len(data["data"]["methods"])
|
||||
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:reauthenticate"),
|
||||
data={
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = auth_client.get(
|
||||
headless_reverse("headless:account:current_session"),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["data"]["methods"]) == method_count + 1
|
||||
last_method = data["data"]["methods"][-1]
|
||||
assert last_method["method"] == "password"
|
||||
|
||||
|
||||
def test_reauthenticate_rate_limit(
|
||||
auth_client,
|
||||
user,
|
||||
user_password,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
settings,
|
||||
enable_cache,
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"reauthenticate": "1/m/ip"}
|
||||
for attempt in range(4):
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:reauthenticate"),
|
||||
data={
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
expected_status = 429 if attempt else 200
|
||||
assert resp.status_code == expected_status
|
||||
assert resp.json()["status"] == expected_status
|
||||
@@ -0,0 +1,151 @@
|
||||
from django.urls import reverse
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_password_reset_flow(
|
||||
client, user, mailoutbox, password_factory, settings, headless_reverse
|
||||
):
|
||||
settings.ACCOUNT_EMAIL_NOTIFICATIONS = True
|
||||
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:request_password_reset"),
|
||||
data={
|
||||
"email": user.email,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(mailoutbox) == 1
|
||||
body = mailoutbox[0].body
|
||||
# Extract URL for `password_reset_from_key` view
|
||||
url = body[body.find("/password/reset/") :].split()[0]
|
||||
key = url.split("/")[-2]
|
||||
password = password_factory()
|
||||
|
||||
# Too simple password
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:reset_password"),
|
||||
data={
|
||||
"key": key,
|
||||
"password": "a",
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"code": "password_too_short",
|
||||
"message": "This password is too short. It must contain at least 6 characters.",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
assert len(mailoutbox) == 1
|
||||
|
||||
# Success
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:reset_password"),
|
||||
data={
|
||||
"key": key,
|
||||
"password": password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
user.refresh_from_db()
|
||||
assert user.check_password(password)
|
||||
assert len(mailoutbox) == 2 # The security notification
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", ["get", "post"])
|
||||
def test_password_reset_flow_wrong_key(
|
||||
client, password_factory, headless_reverse, method
|
||||
):
|
||||
password = password_factory()
|
||||
|
||||
if method == "get":
|
||||
resp = client.get(
|
||||
headless_reverse("headless:account:reset_password"),
|
||||
HTTP_X_PASSWORD_RESET_KEY="wrong",
|
||||
)
|
||||
else:
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:reset_password"),
|
||||
data={
|
||||
"key": "wrong",
|
||||
"password": password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"param": "key",
|
||||
"code": "invalid_password_reset",
|
||||
"message": "The password reset token was invalid.",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_password_reset_flow_unknown_user(
|
||||
client, db, mailoutbox, password_factory, settings, headless_reverse
|
||||
):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:request_password_reset"),
|
||||
data={
|
||||
"email": "not@registered.org",
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(mailoutbox) == 1
|
||||
body = mailoutbox[0].body
|
||||
if getattr(settings, "HEADLESS_ONLY", False):
|
||||
assert settings.HEADLESS_FRONTEND_URLS["account_signup"] in body
|
||||
else:
|
||||
assert reverse("account_signup") in body
|
||||
|
||||
|
||||
def test_reset_password_rate_limit(
|
||||
auth_client, user, headless_reverse, settings, enable_cache
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"reset_password": "1/m/ip"}
|
||||
for attempt in range(2):
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:request_password_reset"),
|
||||
data={"email": user.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
expected_status = 200 if attempt == 0 else 429
|
||||
assert resp.status_code == expected_status
|
||||
assert resp.json()["status"] == expected_status
|
||||
|
||||
|
||||
def test_password_reset_key_rate_limit(
|
||||
client,
|
||||
user,
|
||||
settings,
|
||||
headless_reverse,
|
||||
password_reset_key_generator,
|
||||
enable_cache,
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"reset_password_from_key": "1/m/ip"}
|
||||
for attempt in range(2):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:reset_password"),
|
||||
data={
|
||||
"key": password_reset_key_generator(user),
|
||||
"password": "a", # too short
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
expected_status = 429 if attempt else 400
|
||||
assert resp.status_code == expected_status
|
||||
assert resp.json()["status"] == expected_status
|
||||
@@ -0,0 +1,33 @@
|
||||
from django.test.client import Client
|
||||
from django.urls import reverse
|
||||
|
||||
|
||||
def test_app_session_gone(db, user):
|
||||
# intentionally use a vanilla Django test client
|
||||
client = Client()
|
||||
# Force login, creates a Django session
|
||||
client.force_login(user)
|
||||
# That Django session should not play any role.
|
||||
resp = client.get(
|
||||
reverse("headless:app:account:current_session"), HTTP_X_SESSION_TOKEN="gone"
|
||||
)
|
||||
assert resp.status_code == 410
|
||||
|
||||
|
||||
def test_logout(auth_client, headless_reverse):
|
||||
# That Django session should not play any role.
|
||||
resp = auth_client.get(headless_reverse("headless:account:current_session"))
|
||||
assert resp.status_code == 200
|
||||
resp = auth_client.delete(headless_reverse("headless:account:current_session"))
|
||||
assert resp.status_code == 401
|
||||
resp = auth_client.get(headless_reverse("headless:account:current_session"))
|
||||
assert resp.status_code in [401, 410]
|
||||
|
||||
|
||||
def test_logout_no_token(app_client, user):
|
||||
app_client.force_login(user)
|
||||
resp = app_client.get(reverse("headless:app:account:current_session"))
|
||||
assert resp.status_code == 200
|
||||
resp = app_client.delete(reverse("headless:app:account:current_session"))
|
||||
assert resp.status_code == 401
|
||||
assert "session_token" not in resp.json()["meta"]
|
||||
@@ -0,0 +1,190 @@
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from allauth.account.models import EmailAddress, EmailConfirmationHMAC
|
||||
from allauth.headless.constants import Flow
|
||||
|
||||
|
||||
def test_signup(
|
||||
db,
|
||||
client,
|
||||
email_factory,
|
||||
password_factory,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:signup"),
|
||||
data={
|
||||
"username": "wizard",
|
||||
"email": email_factory(),
|
||||
"password": password_factory(),
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert User.objects.filter(username="wizard").exists()
|
||||
|
||||
|
||||
def test_signup_with_email_verification(
|
||||
db,
|
||||
client,
|
||||
email_factory,
|
||||
password_factory,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
):
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
|
||||
settings.ACCOUNT_USERNAME_REQUIRED = False
|
||||
email = email_factory()
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:signup"),
|
||||
data={
|
||||
"email": email,
|
||||
"password": password_factory(),
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert User.objects.filter(email=email).exists()
|
||||
data = resp.json()
|
||||
flow = next((f for f in data["data"]["flows"] if f.get("is_pending")))
|
||||
assert flow["id"] == "verify_email"
|
||||
addr = EmailAddress.objects.get(email=email)
|
||||
key = EmailConfirmationHMAC(addr).key
|
||||
resp = client.get(
|
||||
headless_reverse("headless:account:verify_email"),
|
||||
HTTP_X_EMAIL_VERIFICATION_KEY=key,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {
|
||||
"data": {
|
||||
"email": email,
|
||||
"user": ANY,
|
||||
},
|
||||
"meta": {"is_authenticating": True},
|
||||
"status": 200,
|
||||
}
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:verify_email"),
|
||||
data={"key": key},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["meta"]["is_authenticated"]
|
||||
addr.refresh_from_db()
|
||||
assert addr.verified
|
||||
|
||||
|
||||
def test_signup_prevent_enumeration(
|
||||
db,
|
||||
client,
|
||||
email_factory,
|
||||
password_factory,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
user,
|
||||
mailoutbox,
|
||||
):
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
|
||||
settings.ACCOUNT_USERNAME_REQUIRED = False
|
||||
settings.ACCOUNT_PREVENT_ENUMERATION = True
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:signup"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": password_factory(),
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert len(mailoutbox) == 1
|
||||
assert "an account using that email address already exists" in mailoutbox[0].body
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
assert [f for f in data["data"]["flows"] if f["id"] == Flow.VERIFY_EMAIL][0][
|
||||
"is_pending"
|
||||
]
|
||||
resp = client.get(headless_reverse("headless:account:current_session"))
|
||||
data = resp.json()
|
||||
assert [f for f in data["data"]["flows"] if f["id"] == Flow.VERIFY_EMAIL][0][
|
||||
"is_pending"
|
||||
]
|
||||
|
||||
|
||||
def test_signup_rate_limit(
|
||||
db,
|
||||
client,
|
||||
email_factory,
|
||||
password_factory,
|
||||
settings,
|
||||
headless_reverse,
|
||||
enable_cache,
|
||||
headless_client,
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"signup": "1/m/ip"}
|
||||
for attempt in range(2):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:signup"),
|
||||
data={
|
||||
"username": f"wizard{attempt}",
|
||||
"email": email_factory(),
|
||||
"password": password_factory(),
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
expected_status = 429 if attempt else 200
|
||||
assert resp.status_code == expected_status
|
||||
assert resp.json()["status"] == expected_status
|
||||
|
||||
|
||||
def test_signup_closed(
|
||||
db,
|
||||
client,
|
||||
email_factory,
|
||||
password_factory,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
):
|
||||
with patch(
|
||||
"allauth.account.adapter.DefaultAccountAdapter.is_open_for_signup"
|
||||
) as iofs:
|
||||
iofs.return_value = False
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:signup"),
|
||||
data={
|
||||
"username": "wizard",
|
||||
"email": email_factory(),
|
||||
"password": password_factory(),
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert not User.objects.filter(username="wizard").exists()
|
||||
|
||||
|
||||
def test_signup_while_logged_in(
|
||||
db,
|
||||
auth_client,
|
||||
email_factory,
|
||||
password_factory,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
):
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:signup"),
|
||||
data={
|
||||
"username": "wizard",
|
||||
"email": email_factory(),
|
||||
"password": password_factory(),
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
assert resp.json() == {"status": 409}
|
||||
@@ -0,0 +1,96 @@
|
||||
from django.urls import include, path
|
||||
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.account import app_settings as account_settings
|
||||
from allauth.headless.account import views
|
||||
|
||||
|
||||
def build_urlpatterns(client):
|
||||
account_patterns = []
|
||||
auth_patterns = [
|
||||
path(
|
||||
"session",
|
||||
views.SessionView.as_api_view(client=client),
|
||||
name="current_session",
|
||||
),
|
||||
path(
|
||||
"reauthenticate",
|
||||
views.ReauthenticateView.as_api_view(client=client),
|
||||
name="reauthenticate",
|
||||
),
|
||||
path(
|
||||
"code/confirm",
|
||||
views.ConfirmLoginCodeView.as_api_view(client=client),
|
||||
name="confirm_login_code",
|
||||
),
|
||||
]
|
||||
if not allauth_settings.SOCIALACCOUNT_ONLY:
|
||||
account_patterns.extend(
|
||||
[
|
||||
path(
|
||||
"password/change",
|
||||
views.ChangePasswordView.as_api_view(client=client),
|
||||
name="change_password",
|
||||
),
|
||||
path(
|
||||
"email",
|
||||
views.ManageEmailView.as_api_view(client=client),
|
||||
name="manage_email",
|
||||
),
|
||||
]
|
||||
)
|
||||
auth_patterns.extend(
|
||||
[
|
||||
path(
|
||||
"password/",
|
||||
include(
|
||||
[
|
||||
path(
|
||||
"request",
|
||||
views.RequestPasswordResetView.as_api_view(
|
||||
client=client
|
||||
),
|
||||
name="request_password_reset",
|
||||
),
|
||||
path(
|
||||
"reset",
|
||||
views.ResetPasswordView.as_api_view(client=client),
|
||||
name="reset_password",
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
path(
|
||||
"login",
|
||||
views.LoginView.as_api_view(client=client),
|
||||
name="login",
|
||||
),
|
||||
path(
|
||||
"signup",
|
||||
views.SignupView.as_api_view(client=client),
|
||||
name="signup",
|
||||
),
|
||||
path(
|
||||
"email/verify",
|
||||
views.VerifyEmailView.as_api_view(client=client),
|
||||
name="verify_email",
|
||||
),
|
||||
]
|
||||
)
|
||||
if account_settings.LOGIN_BY_CODE_ENABLED:
|
||||
auth_patterns.extend(
|
||||
[
|
||||
path(
|
||||
"code/request",
|
||||
views.RequestLoginCodeView.as_api_view(client=client),
|
||||
name="request_login_code",
|
||||
),
|
||||
]
|
||||
)
|
||||
return [
|
||||
path("auth/", include(auth_patterns)),
|
||||
path(
|
||||
"account/",
|
||||
include(account_patterns),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,256 @@
|
||||
from django.utils.decorators import method_decorator
|
||||
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.internal import flows
|
||||
from allauth.account.internal.flows import password_change, password_reset
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.account.stages import EmailVerificationStage, LoginStageController
|
||||
from allauth.account.utils import send_email_confirmation
|
||||
from allauth.core import ratelimit
|
||||
from allauth.core.exceptions import ImmediateHttpResponse
|
||||
from allauth.decorators import rate_limit
|
||||
from allauth.headless.account import response
|
||||
from allauth.headless.account.inputs import (
|
||||
AddEmailInput,
|
||||
ChangePasswordInput,
|
||||
ConfirmLoginCodeInput,
|
||||
DeleteEmailInput,
|
||||
LoginInput,
|
||||
MarkAsPrimaryEmailInput,
|
||||
ReauthenticateInput,
|
||||
RequestLoginCodeInput,
|
||||
RequestPasswordResetInput,
|
||||
ResetPasswordInput,
|
||||
ResetPasswordKeyInput,
|
||||
SelectEmailInput,
|
||||
SignupInput,
|
||||
VerifyEmailInput,
|
||||
)
|
||||
from allauth.headless.base.response import (
|
||||
APIResponse,
|
||||
AuthenticationResponse,
|
||||
ConflictResponse,
|
||||
ForbiddenResponse,
|
||||
)
|
||||
from allauth.headless.base.views import APIView, AuthenticatedAPIView
|
||||
from allauth.headless.internal import authkit
|
||||
from allauth.headless.internal.restkit.response import ErrorResponse
|
||||
|
||||
|
||||
class RequestLoginCodeView(APIView):
|
||||
input_class = RequestLoginCodeInput
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
flows.login_by_code.request_login_code(
|
||||
self.request, self.input.cleaned_data["email"]
|
||||
)
|
||||
return AuthenticationResponse(self.request)
|
||||
|
||||
|
||||
class ConfirmLoginCodeView(APIView):
|
||||
input_class = ConfirmLoginCodeInput
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
auth_status = authkit.AuthenticationStatus(request)
|
||||
self.stage = auth_status.get_pending_stage()
|
||||
if not self.stage:
|
||||
return ConflictResponse(request)
|
||||
self.user, self.pending_login = flows.login_by_code.get_pending_login(
|
||||
request, self.stage.login, peek=True
|
||||
)
|
||||
if not self.pending_login:
|
||||
return ConflictResponse(request)
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
flows.login_by_code.perform_login_by_code(self.request, self.stage, None)
|
||||
return AuthenticationResponse(request)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
kwargs = super().get_input_kwargs()
|
||||
kwargs["code"] = (
|
||||
self.pending_login.get("code", "") if self.pending_login else ""
|
||||
)
|
||||
return kwargs
|
||||
|
||||
def handle_invalid_input(self, input):
|
||||
flows.login_by_code.record_invalid_attempt(self.request, self.stage.login)
|
||||
return super().handle_invalid_input(input)
|
||||
|
||||
|
||||
@method_decorator(rate_limit(action="login"), name="handle")
|
||||
class LoginView(APIView):
|
||||
input_class = LoginInput
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
if request.user.is_authenticated:
|
||||
return ConflictResponse(request)
|
||||
credentials = self.input.cleaned_data
|
||||
flows.login.perform_password_login(request, credentials, self.input.login)
|
||||
return AuthenticationResponse(self.request)
|
||||
|
||||
|
||||
@method_decorator(rate_limit(action="signup"), name="handle")
|
||||
class SignupView(APIView):
|
||||
input_class = {"POST": SignupInput}
|
||||
by_passkey = False
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
if request.user.is_authenticated:
|
||||
return ConflictResponse(request)
|
||||
if not get_account_adapter().is_open_for_signup(request):
|
||||
return ForbiddenResponse(request)
|
||||
user, resp = self.input.try_save(request)
|
||||
if not resp:
|
||||
try:
|
||||
flows.signup.complete_signup(
|
||||
request, user=user, by_passkey=self.by_passkey
|
||||
)
|
||||
except ImmediateHttpResponse:
|
||||
pass
|
||||
return AuthenticationResponse(request)
|
||||
|
||||
|
||||
class SessionView(APIView):
|
||||
def get(self, request, *args, **kwargs):
|
||||
return AuthenticationResponse(request)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
adapter = get_account_adapter()
|
||||
adapter.logout(request)
|
||||
return AuthenticationResponse(request)
|
||||
|
||||
|
||||
class VerifyEmailView(APIView):
|
||||
input_class = VerifyEmailInput
|
||||
|
||||
def handle(self, request, *args, **kwargs):
|
||||
self.stage = LoginStageController.enter(request, EmailVerificationStage.key)
|
||||
return super().handle(request, *args, **kwargs)
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
key = request.headers.get("x-email-verification-key", "")
|
||||
input = self.input_class({"key": key})
|
||||
if not input.is_valid():
|
||||
return ErrorResponse(request, input=input)
|
||||
verification = input.cleaned_data["key"]
|
||||
return response.VerifyEmailResponse(request, verification, stage=self.stage)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
confirmation = self.input.cleaned_data["key"]
|
||||
email_address = confirmation.confirm(request)
|
||||
if not email_address:
|
||||
# Should not happen, VerifyInputInput should have verified all
|
||||
# preconditions.
|
||||
return APIResponse(status=500)
|
||||
if self.stage:
|
||||
# Verifying email as part of login/signup flow, so emit a
|
||||
# authentication status response.
|
||||
self.stage.exit()
|
||||
return AuthenticationResponse(self.request)
|
||||
|
||||
|
||||
class RequestPasswordResetView(APIView):
|
||||
input_class = RequestPasswordResetInput
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
r429 = ratelimit.consume_or_429(
|
||||
self.request,
|
||||
action="reset_password",
|
||||
key=self.input.cleaned_data["email"].lower(),
|
||||
)
|
||||
if r429:
|
||||
return r429
|
||||
self.input.save(request)
|
||||
return response.RequestPasswordResponse(request)
|
||||
|
||||
|
||||
@method_decorator(rate_limit(action="reset_password_from_key"), name="handle")
|
||||
class ResetPasswordView(APIView):
|
||||
input_class = ResetPasswordInput
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
key = request.headers.get("X-Password-Reset-Key", "")
|
||||
input = ResetPasswordKeyInput({"key": key})
|
||||
if not input.is_valid():
|
||||
return ErrorResponse(request, input=input)
|
||||
return response.PasswordResetKeyResponse(request, input.user)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
flows.password_reset.reset_password(
|
||||
self.input.user, self.input.cleaned_data["password"]
|
||||
)
|
||||
password_reset.finalize_password_reset(request, self.input.user)
|
||||
return AuthenticationResponse(self.request)
|
||||
|
||||
|
||||
@method_decorator(rate_limit(action="change_password"), name="handle")
|
||||
class ChangePasswordView(AuthenticatedAPIView):
|
||||
input_class = ChangePasswordInput
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
password_change.change_password(
|
||||
self.request.user, self.input.cleaned_data["new_password"]
|
||||
)
|
||||
is_set = not self.input.cleaned_data.get("current_password")
|
||||
if is_set:
|
||||
password_change.finalize_password_set(request, request.user)
|
||||
else:
|
||||
password_change.finalize_password_change(request, request.user)
|
||||
return AuthenticationResponse(request)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
|
||||
@method_decorator(rate_limit(action="manage_email"), name="handle")
|
||||
class ManageEmailView(AuthenticatedAPIView):
|
||||
input_class = {
|
||||
"POST": AddEmailInput,
|
||||
"PUT": SelectEmailInput,
|
||||
"DELETE": DeleteEmailInput,
|
||||
"PATCH": MarkAsPrimaryEmailInput,
|
||||
}
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self._respond_email_list()
|
||||
|
||||
def _respond_email_list(self):
|
||||
addrs = EmailAddress.objects.filter(user=self.request.user)
|
||||
return response.EmailAddressesResponse(self.request, addrs)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
flows.manage_email.add_email(request, self.input)
|
||||
return self._respond_email_list()
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
addr = self.input.cleaned_data["email"]
|
||||
flows.manage_email.delete_email(request, addr)
|
||||
return self._respond_email_list()
|
||||
|
||||
def patch(self, request, *args, **kwargs):
|
||||
addr = self.input.cleaned_data["email"]
|
||||
flows.manage_email.mark_as_primary(request, addr)
|
||||
return self._respond_email_list()
|
||||
|
||||
def put(self, request, *args, **kwargs):
|
||||
addr = self.input.cleaned_data["email"]
|
||||
sent = send_email_confirmation(request, request.user, email=addr.email)
|
||||
return response.RequestEmailVerificationResponse(
|
||||
request, verification_sent=sent
|
||||
)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
|
||||
@method_decorator(rate_limit(action="reauthenticate"), name="handle")
|
||||
class ReauthenticateView(AuthenticatedAPIView):
|
||||
input_class = ReauthenticateInput
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
flows.reauthentication.reauthenticate_by_password(self.request)
|
||||
return AuthenticationResponse(self.request)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
@@ -0,0 +1,54 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from django.forms.fields import Field
|
||||
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.account.utils import user_display, user_username
|
||||
from allauth.core.internal.adapter import BaseAdapter
|
||||
from allauth.headless import app_settings
|
||||
from allauth.utils import import_attribute
|
||||
|
||||
|
||||
class DefaultHeadlessAdapter(BaseAdapter):
|
||||
"""The adapter class allows you to override various functionality of the
|
||||
``allauth.headless`` app. To do so, point ``settings.HEADLESS_ADAPTER`` to your own
|
||||
class that derives from ``DefaultHeadlessAdapter`` and override the behavior by
|
||||
altering the implementation of the methods according to your own need.
|
||||
"""
|
||||
|
||||
error_messages = {
|
||||
# For the following error messages i18n is not an issue as these should not be
|
||||
# showing up in a UI.
|
||||
"account_not_found": "Unknown account.",
|
||||
"client_id_required": "`client_id` required.",
|
||||
"email_or_username": "Pass only one of email or username, not both.",
|
||||
"invalid_token": "Invalid token.",
|
||||
"token_authentication_not_supported": "Provider does not support token authentication.",
|
||||
"token_required": "`id_token` and/or `access_token` required.",
|
||||
"required": Field.default_error_messages["required"],
|
||||
"unknown_email": "Unknown email address.",
|
||||
"invalid_url": "Invalid URL.",
|
||||
}
|
||||
|
||||
def serialize_user(self, user) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns the basic user data. Note that this data is also exposed in
|
||||
partly authenticated scenario's (e.g. password reset, email
|
||||
verification).
|
||||
"""
|
||||
ret = {
|
||||
"id": user.pk,
|
||||
"display": user_display(user),
|
||||
"has_usable_password": user.has_usable_password(),
|
||||
}
|
||||
email = EmailAddress.objects.get_primary_email(user)
|
||||
if email:
|
||||
ret["email"] = email
|
||||
username = user_username(user)
|
||||
if username:
|
||||
ret["username"] = username
|
||||
return ret
|
||||
|
||||
|
||||
def get_adapter():
|
||||
return import_attribute(app_settings.ADAPTER)()
|
||||
@@ -0,0 +1,36 @@
|
||||
class AppSettings:
|
||||
def __init__(self, prefix):
|
||||
self.prefix = prefix
|
||||
|
||||
def _setting(self, name, dflt):
|
||||
from allauth.utils import get_setting
|
||||
|
||||
return get_setting(self.prefix + name, dflt)
|
||||
|
||||
@property
|
||||
def ADAPTER(self):
|
||||
return self._setting(
|
||||
"ADAPTER", "allauth.headless.adapter.DefaultHeadlessAdapter"
|
||||
)
|
||||
|
||||
@property
|
||||
def TOKEN_STRATEGY(self):
|
||||
from allauth.utils import import_attribute
|
||||
|
||||
path = self._setting(
|
||||
"TOKEN_STRATEGY", "allauth.headless.tokens.sessions.SessionTokenStrategy"
|
||||
)
|
||||
cls = import_attribute(path)
|
||||
return cls()
|
||||
|
||||
@property
|
||||
def FRONTEND_URLS(self):
|
||||
return self._setting("FRONTEND_URLS", {})
|
||||
|
||||
|
||||
_app_settings = AppSettings("HEADLESS_")
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
# See https://peps.python.org/pep-0562/
|
||||
return getattr(_app_settings, name)
|
||||
@@ -0,0 +1,7 @@
|
||||
from django.apps import AppConfig
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class HeadlessConfig(AppConfig):
|
||||
name = "allauth.headless"
|
||||
verbose_name = _("Headless")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,155 @@
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.account import app_settings as account_settings
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.authentication import get_authentication_records
|
||||
from allauth.account.internal import flows
|
||||
from allauth.account.internal.stagekit import LOGIN_SESSION_KEY
|
||||
from allauth.headless.adapter import get_adapter
|
||||
from allauth.headless.constants import Flow
|
||||
from allauth.headless.internal import authkit
|
||||
from allauth.headless.internal.restkit.response import APIResponse
|
||||
from allauth.mfa import app_settings as mfa_settings
|
||||
|
||||
|
||||
class BaseAuthenticationResponse(APIResponse):
|
||||
def __init__(self, request, user=None, status=None):
|
||||
data = {}
|
||||
if user and user.is_authenticated:
|
||||
adapter = get_adapter()
|
||||
data["user"] = adapter.serialize_user(user)
|
||||
data["methods"] = get_authentication_records(request)
|
||||
status = status or 200
|
||||
else:
|
||||
status = status or 401
|
||||
if status != 200:
|
||||
data["flows"] = self._get_flows(request, user)
|
||||
meta = {
|
||||
"is_authenticated": user and user.is_authenticated,
|
||||
}
|
||||
super().__init__(
|
||||
request,
|
||||
data=data,
|
||||
meta=meta,
|
||||
status=status,
|
||||
)
|
||||
|
||||
def _get_flows(self, request, user):
|
||||
auth_status = authkit.AuthenticationStatus(request)
|
||||
ret = []
|
||||
if user and user.is_authenticated:
|
||||
ret.extend(flows.reauthentication.get_reauthentication_flows(user))
|
||||
else:
|
||||
if not allauth_settings.SOCIALACCOUNT_ONLY:
|
||||
ret.append({"id": Flow.LOGIN})
|
||||
if account_settings.LOGIN_BY_CODE_ENABLED:
|
||||
ret.append({"id": Flow.LOGIN_BY_CODE})
|
||||
if (
|
||||
get_account_adapter().is_open_for_signup(request)
|
||||
and not allauth_settings.SOCIALACCOUNT_ONLY
|
||||
):
|
||||
ret.append({"id": Flow.SIGNUP})
|
||||
if allauth_settings.SOCIALACCOUNT_ENABLED:
|
||||
from allauth.headless.socialaccount.response import (
|
||||
provider_flows,
|
||||
)
|
||||
|
||||
ret.extend(provider_flows(request))
|
||||
if allauth_settings.MFA_ENABLED:
|
||||
if mfa_settings.PASSKEY_LOGIN_ENABLED:
|
||||
ret.append({"id": Flow.MFA_LOGIN_WEBAUTHN})
|
||||
stage_key = None
|
||||
stage = auth_status.get_pending_stage()
|
||||
if stage:
|
||||
stage_key = stage.key
|
||||
else:
|
||||
lsk = request.session.get(LOGIN_SESSION_KEY)
|
||||
if isinstance(lsk, str):
|
||||
stage_key = lsk
|
||||
if stage_key:
|
||||
pending_flow = {"id": stage_key, "is_pending": True}
|
||||
if stage and stage_key == Flow.MFA_AUTHENTICATE:
|
||||
self._enrich_mfa_flow(stage, pending_flow)
|
||||
self._upsert_pending_flow(ret, pending_flow)
|
||||
return ret
|
||||
|
||||
def _upsert_pending_flow(self, flows, pending_flow):
|
||||
flow = next((flow for flow in flows if flow["id"] == pending_flow["id"]), None)
|
||||
if flow:
|
||||
flow.update(pending_flow)
|
||||
else:
|
||||
flows.append(pending_flow)
|
||||
|
||||
def _enrich_mfa_flow(self, stage, flow: dict) -> None:
|
||||
from allauth.mfa.adapter import get_adapter as get_mfa_adapter
|
||||
from allauth.mfa.models import Authenticator
|
||||
|
||||
adapter = get_mfa_adapter()
|
||||
types = []
|
||||
for typ in Authenticator.Type:
|
||||
if adapter.is_mfa_enabled(stage.login.user, types=[typ]):
|
||||
types.append(typ)
|
||||
flow["types"] = types
|
||||
|
||||
|
||||
class AuthenticationResponse(BaseAuthenticationResponse):
|
||||
def __init__(self, request):
|
||||
super().__init__(request, user=request.user)
|
||||
|
||||
|
||||
class ReauthenticationResponse(BaseAuthenticationResponse):
|
||||
def __init__(self, request):
|
||||
super().__init__(request, user=request.user, status=401)
|
||||
|
||||
|
||||
class UnauthorizedResponse(BaseAuthenticationResponse):
|
||||
def __init__(self, request, status=401):
|
||||
super().__init__(request, user=None, status=status)
|
||||
|
||||
|
||||
class ForbiddenResponse(APIResponse):
|
||||
def __init__(self, request):
|
||||
super().__init__(request, status=403)
|
||||
|
||||
|
||||
class ConflictResponse(APIResponse):
|
||||
def __init__(self, request):
|
||||
super().__init__(request, status=409)
|
||||
|
||||
|
||||
def get_config_data(request):
|
||||
data = {
|
||||
"authentication_method": account_settings.AUTHENTICATION_METHOD,
|
||||
"is_open_for_signup": get_account_adapter().is_open_for_signup(request),
|
||||
"email_verification_by_code_enabled": account_settings.EMAIL_VERIFICATION_BY_CODE_ENABLED,
|
||||
"login_by_code_enabled": account_settings.LOGIN_BY_CODE_ENABLED,
|
||||
}
|
||||
return {"account": data}
|
||||
|
||||
|
||||
class ConfigResponse(APIResponse):
|
||||
def __init__(self, request):
|
||||
data = get_config_data(request)
|
||||
if allauth_settings.SOCIALACCOUNT_ENABLED:
|
||||
from allauth.headless.socialaccount.response import (
|
||||
get_config_data as get_socialaccount_config_data,
|
||||
)
|
||||
|
||||
data.update(get_socialaccount_config_data(request))
|
||||
if allauth_settings.MFA_ENABLED:
|
||||
from allauth.headless.mfa.response import (
|
||||
get_config_data as get_mfa_config_data,
|
||||
)
|
||||
|
||||
data.update(get_mfa_config_data(request))
|
||||
if allauth_settings.USERSESSIONS_ENABLED:
|
||||
from allauth.headless.usersessions.response import (
|
||||
get_config_data as get_usersessions_config_data,
|
||||
)
|
||||
|
||||
data.update(get_usersessions_config_data(request))
|
||||
return super().__init__(request, data=data)
|
||||
|
||||
|
||||
class RateLimitResponse(APIResponse):
|
||||
def __init__(self, request):
|
||||
super().__init__(request, status=429)
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,10 @@
|
||||
def test_config(db, client, headless_reverse):
|
||||
resp = client.get(headless_reverse("headless:config"))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert set(data["data"].keys()) == {
|
||||
"account",
|
||||
"mfa",
|
||||
"socialaccount",
|
||||
"usersessions",
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
from django.urls import path
|
||||
|
||||
from allauth.headless.base import views
|
||||
|
||||
|
||||
def build_urlpatterns(client):
|
||||
return [
|
||||
path(
|
||||
"config",
|
||||
views.ConfigView.as_api_view(client=client),
|
||||
name="config",
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,62 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
from django.utils.decorators import classonlymethod
|
||||
|
||||
from allauth.account.stages import LoginStage, LoginStageController
|
||||
from allauth.core.exceptions import ReauthenticationRequired
|
||||
from allauth.headless.base import response
|
||||
from allauth.headless.constants import Client
|
||||
from allauth.headless.internal import decorators
|
||||
from allauth.headless.internal.restkit.views import RESTView
|
||||
|
||||
|
||||
class APIView(RESTView):
|
||||
client = None
|
||||
|
||||
@classonlymethod
|
||||
def as_api_view(cls, **initkwargs):
|
||||
view_func = cls.as_view(**initkwargs)
|
||||
if initkwargs["client"] == Client.APP:
|
||||
view_func = decorators.app_view(view_func)
|
||||
else:
|
||||
view_func = decorators.browser_view(view_func)
|
||||
return view_func
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
try:
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
except ReauthenticationRequired:
|
||||
return response.ReauthenticationResponse(self.request)
|
||||
|
||||
|
||||
class AuthenticationStageAPIView(APIView):
|
||||
stage_class: Optional[Type[LoginStage]] = None
|
||||
|
||||
def handle(self, request, *args, **kwargs):
|
||||
self.stage = LoginStageController.enter(request, self.stage_class.key)
|
||||
if not self.stage:
|
||||
return response.UnauthorizedResponse(request)
|
||||
return super().handle(request, *args, **kwargs)
|
||||
|
||||
def respond_stage_error(self):
|
||||
return response.UnauthorizedResponse(self.request)
|
||||
|
||||
def respond_next_stage(self):
|
||||
self.stage.exit()
|
||||
return response.AuthenticationResponse(self.request)
|
||||
|
||||
|
||||
class AuthenticatedAPIView(APIView):
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
if not request.user.is_authenticated:
|
||||
return response.AuthenticationResponse(request)
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
|
||||
class ConfigView(APIView):
|
||||
def get(self, request, *args, **kwargs):
|
||||
"""
|
||||
The frontend queries (GET) this endpoint, expecting to receive
|
||||
either a 401 if no user is authenticated, or user information.
|
||||
"""
|
||||
return response.ConfigResponse(request)
|
||||
@@ -0,0 +1,63 @@
|
||||
from django.test.client import Client
|
||||
from django.urls import reverse
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(params=["app", "browser"])
|
||||
def headless_client(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def headless_reverse(headless_client):
|
||||
def rev(viewname, **kwargs):
|
||||
viewname = viewname.replace("headless:", f"headless:{headless_client}:")
|
||||
return reverse(viewname, **kwargs)
|
||||
|
||||
return rev
|
||||
|
||||
|
||||
class AppClient(Client):
|
||||
session_token = None
|
||||
|
||||
def generic(self, *args, **kwargs):
|
||||
if self.session_token:
|
||||
kwargs["HTTP_X_SESSION_TOKEN"] = self.session_token
|
||||
resp = super().generic(*args, **kwargs)
|
||||
if resp["content-type"] == "application/json":
|
||||
data = resp.json()
|
||||
session_token = data.get("meta", {}).get("session_token")
|
||||
if session_token:
|
||||
self.session_token = session_token
|
||||
return resp
|
||||
|
||||
def force_login(self, user):
|
||||
ret = super().force_login(user)
|
||||
self.session_token = self.session.session_key
|
||||
return ret
|
||||
|
||||
def headless_session(self):
|
||||
from allauth.headless.internal import sessionkit
|
||||
|
||||
return sessionkit.session_store(self.session_token)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_client():
|
||||
return AppClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(headless_client):
|
||||
if headless_client == "browser":
|
||||
client = Client()
|
||||
client.headless_session = lambda: client.session
|
||||
return client
|
||||
return AppClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_client(client, user):
|
||||
client.force_login(user)
|
||||
return client
|
||||
@@ -0,0 +1,23 @@
|
||||
from enum import Enum
|
||||
|
||||
from allauth.account.stages import EmailVerificationStage, LoginByCodeStage
|
||||
|
||||
|
||||
class Client(str, Enum):
|
||||
APP = "app"
|
||||
BROWSER = "browser"
|
||||
|
||||
|
||||
class Flow(str, Enum):
|
||||
VERIFY_EMAIL = EmailVerificationStage.key
|
||||
LOGIN = "login"
|
||||
LOGIN_BY_CODE = LoginByCodeStage.key
|
||||
SIGNUP = "signup"
|
||||
PROVIDER_REDIRECT = "provider_redirect"
|
||||
PROVIDER_SIGNUP = "provider_signup"
|
||||
PROVIDER_TOKEN = "provider_token"
|
||||
REAUTHENTICATE = "reauthenticate"
|
||||
MFA_REAUTHENTICATE = "mfa_reauthenticate"
|
||||
MFA_AUTHENTICATE = "mfa_authenticate" # NOTE: Equal to `allauth.mfa.stages.AuthenticationStage.key`
|
||||
MFA_LOGIN_WEBAUTHN = "mfa_login_webauthn"
|
||||
MFA_SIGNUP_WEBAUTHN = "mfa_signup_webauthn"
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,85 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from django.utils.functional import SimpleLazyObject, empty
|
||||
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.account.internal.stagekit import get_pending_stage
|
||||
from allauth.core.exceptions import ImmediateHttpResponse
|
||||
from allauth.headless import app_settings
|
||||
from allauth.headless.constants import Client
|
||||
from allauth.headless.internal import sessionkit
|
||||
|
||||
|
||||
class AuthenticationStatus:
|
||||
def __init__(self, request):
|
||||
self.request = request
|
||||
|
||||
@property
|
||||
def is_authenticated(self):
|
||||
return self.request.user.is_authenticated
|
||||
|
||||
def get_pending_stage(self):
|
||||
return get_pending_stage(self.request)
|
||||
|
||||
@property
|
||||
def has_pending_signup(self):
|
||||
if not allauth_settings.SOCIALACCOUNT_ENABLED:
|
||||
return False
|
||||
from allauth.socialaccount.internal import flows
|
||||
|
||||
return bool(flows.signup.get_pending_signup(self.request))
|
||||
|
||||
|
||||
def purge_request_user_cache(request):
|
||||
for attr in ["_cached_user", "_acached_user"]:
|
||||
if hasattr(request, attr):
|
||||
delattr(request, attr)
|
||||
if isinstance(request.user, SimpleLazyObject):
|
||||
request.user._wrapped = empty
|
||||
|
||||
|
||||
@contextmanager
|
||||
def authentication_context(request):
|
||||
from allauth.headless.base.response import UnauthorizedResponse
|
||||
|
||||
old_user = request.user
|
||||
old_session = request.session
|
||||
try:
|
||||
request.session = sessionkit.new_session()
|
||||
purge_request_user_cache(request)
|
||||
|
||||
strategy = app_settings.TOKEN_STRATEGY
|
||||
session_token = strategy.get_session_token(request)
|
||||
if session_token:
|
||||
session = strategy.lookup_session(session_token)
|
||||
if not session:
|
||||
raise ImmediateHttpResponse(UnauthorizedResponse(request, status=410))
|
||||
request.session = session
|
||||
purge_request_user_cache(request)
|
||||
request.allauth.headless._pre_user = request.user
|
||||
# request.user is lazy -- force evaluation
|
||||
request.allauth.headless._pre_user.pk
|
||||
yield
|
||||
finally:
|
||||
if request.session.modified and not request.session.is_empty():
|
||||
request.session.save()
|
||||
request.user = old_user
|
||||
request.session = old_session
|
||||
# e.g. logging in calls csrf `rotate_token()` -- this prevents setting a new CSRF cookie.
|
||||
request.META["CSRF_COOKIE_NEEDS_UPDATE"] = False
|
||||
|
||||
|
||||
def expose_access_token(request) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Determines if a new access token needs to be exposed.
|
||||
"""
|
||||
if request.allauth.headless.client != Client.APP:
|
||||
return None
|
||||
if not request.user.is_authenticated:
|
||||
return None
|
||||
pre_user = request.allauth.headless._pre_user
|
||||
if pre_user.is_authenticated and pre_user.pk == request.user.pk:
|
||||
return None
|
||||
strategy = app_settings.TOKEN_STRATEGY
|
||||
return strategy.create_access_token_payload(request)
|
||||
@@ -0,0 +1,54 @@
|
||||
from functools import wraps
|
||||
from types import SimpleNamespace
|
||||
|
||||
from django.middleware.csrf import get_token
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
|
||||
from allauth.account.internal.decorators import login_not_required
|
||||
from allauth.headless.constants import Client
|
||||
from allauth.headless.internal import authkit
|
||||
|
||||
|
||||
def mark_request_as_headless(request, client):
|
||||
request.allauth.headless = SimpleNamespace()
|
||||
request.allauth.headless.client = client
|
||||
|
||||
|
||||
def app_view(
|
||||
function=None,
|
||||
):
|
||||
def decorator(view_func):
|
||||
@login_not_required
|
||||
@wraps(view_func)
|
||||
def _wrapper_view(request, *args, **kwargs):
|
||||
mark_request_as_headless(request, Client.APP)
|
||||
with authkit.authentication_context(request):
|
||||
return view_func(request, *args, **kwargs)
|
||||
|
||||
return _wrapper_view
|
||||
|
||||
ret = decorator
|
||||
if function:
|
||||
ret = decorator(function)
|
||||
return csrf_exempt(ret)
|
||||
|
||||
|
||||
def browser_view(
|
||||
function=None,
|
||||
):
|
||||
def decorator(view_func):
|
||||
@login_not_required
|
||||
@wraps(view_func)
|
||||
def _wrapper_view(request, *args, **kwargs):
|
||||
mark_request_as_headless(request, Client.BROWSER)
|
||||
# Needed -- so that the CSRF token is set in the response for the
|
||||
# frontend to pick up.
|
||||
get_token(request)
|
||||
return view_func(request, *args, **kwargs)
|
||||
|
||||
return _wrapper_view
|
||||
|
||||
ret = decorator
|
||||
if function:
|
||||
ret = decorator(function)
|
||||
return ret
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,25 @@
|
||||
from django.forms import (
|
||||
BooleanField,
|
||||
CharField,
|
||||
ChoiceField,
|
||||
EmailField,
|
||||
Field,
|
||||
Form,
|
||||
ModelChoiceField,
|
||||
ModelMultipleChoiceField,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Field",
|
||||
"CharField",
|
||||
"ChoiceField",
|
||||
"EmailField",
|
||||
"BooleanField",
|
||||
"ModelMultipleChoiceField",
|
||||
"ModelChoiceField",
|
||||
]
|
||||
|
||||
|
||||
class Input(Form):
|
||||
pass
|
||||
@@ -0,0 +1,55 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from django.forms.utils import ErrorList
|
||||
from django.http import JsonResponse
|
||||
from django.utils.cache import add_never_cache_headers
|
||||
|
||||
from allauth.headless.internal import authkit, sessionkit
|
||||
|
||||
|
||||
class APIResponse(JsonResponse):
|
||||
def __init__(
|
||||
self,
|
||||
request,
|
||||
errors=None,
|
||||
data=None,
|
||||
meta: Optional[Dict] = None,
|
||||
status: int = 200,
|
||||
):
|
||||
d: Dict[str, Any] = {"status": status}
|
||||
if data is not None:
|
||||
d["data"] = data
|
||||
meta = self._add_session_meta(request, meta)
|
||||
if meta is not None:
|
||||
d["meta"] = meta
|
||||
if errors:
|
||||
d["errors"] = errors
|
||||
super().__init__(d, status=status)
|
||||
add_never_cache_headers(self)
|
||||
|
||||
def _add_session_meta(self, request, meta: Optional[Dict]) -> Optional[Dict]:
|
||||
session_token = sessionkit.expose_session_token(request)
|
||||
access_token_payload = authkit.expose_access_token(request)
|
||||
if session_token:
|
||||
meta = meta or {}
|
||||
meta["session_token"] = session_token
|
||||
if access_token_payload:
|
||||
meta = meta or {}
|
||||
meta.update(access_token_payload)
|
||||
return meta
|
||||
|
||||
|
||||
class ErrorResponse(APIResponse):
|
||||
def __init__(self, request, exception=None, input=None, status=400):
|
||||
errors = []
|
||||
if exception is not None:
|
||||
error_datas = ErrorList(exception.error_list).get_json_data()
|
||||
errors.extend(error_datas)
|
||||
if input is not None:
|
||||
for field, error_list in input.errors.items():
|
||||
error_datas = error_list.get_json_data()
|
||||
for error_data in error_datas:
|
||||
if field != "__all__":
|
||||
error_data["param"] = field
|
||||
errors.extend(error_datas)
|
||||
super().__init__(request, status=status, errors=errors)
|
||||
@@ -0,0 +1,53 @@
|
||||
import json
|
||||
from typing import Dict, Optional, Type, Union
|
||||
|
||||
from django.http import HttpResponseBadRequest
|
||||
from django.views.generic import View
|
||||
|
||||
from allauth.core.exceptions import ImmediateHttpResponse
|
||||
from allauth.headless.internal.restkit.inputs import Input
|
||||
from allauth.headless.internal.restkit.response import ErrorResponse
|
||||
|
||||
|
||||
class RESTView(View):
|
||||
input_class: Union[Optional[Dict[str, Type[Input]]], Type[Input]] = None
|
||||
handle_json_input = True
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
return self.handle(request, *args, **kwargs)
|
||||
|
||||
def handle(self, request, *args, **kwargs):
|
||||
if self.handle_json_input and request.method != "GET":
|
||||
self.data = self._parse_json(request)
|
||||
response = self.handle_input(self.data)
|
||||
if response:
|
||||
return response
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {}
|
||||
|
||||
def handle_input(self, data):
|
||||
input_class = self.input_class
|
||||
if isinstance(input_class, dict):
|
||||
input_class = input_class.get(self.request.method)
|
||||
if not input_class:
|
||||
return
|
||||
input_kwargs = self.get_input_kwargs()
|
||||
if data is None:
|
||||
# Make form bound on empty POST
|
||||
data = {}
|
||||
self.input = input_class(data=data, **input_kwargs)
|
||||
if not self.input.is_valid():
|
||||
return self.handle_invalid_input(self.input)
|
||||
|
||||
def handle_invalid_input(self, input):
|
||||
return ErrorResponse(self.request, input=input)
|
||||
|
||||
def _parse_json(self, request):
|
||||
if request.method == "GET" or not request.body:
|
||||
return
|
||||
try:
|
||||
return json.loads(request.body.decode("utf8"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
raise ImmediateHttpResponse(response=HttpResponseBadRequest())
|
||||
@@ -0,0 +1,28 @@
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from allauth.headless import app_settings
|
||||
from allauth.headless.constants import Client
|
||||
|
||||
|
||||
def session_store(session_key=None):
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
return engine.SessionStore(session_key=session_key)
|
||||
|
||||
|
||||
def new_session():
|
||||
return session_store()
|
||||
|
||||
|
||||
def expose_session_token(request):
|
||||
if request.allauth.headless.client != Client.APP:
|
||||
return
|
||||
strategy = app_settings.TOKEN_STRATEGY
|
||||
hdr_token = strategy.get_session_token(request)
|
||||
modified = request.session.modified
|
||||
empty = request.session.is_empty()
|
||||
if modified and not empty:
|
||||
new_token = strategy.create_session_token(request)
|
||||
if not hdr_token or hdr_token != new_token:
|
||||
return new_token
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,28 @@
|
||||
from django.contrib.auth import (
|
||||
BACKEND_SESSION_KEY,
|
||||
HASH_SESSION_KEY,
|
||||
SESSION_KEY,
|
||||
)
|
||||
from django.contrib.auth.middleware import AuthenticationMiddleware
|
||||
from django.contrib.sessions.middleware import SessionMiddleware
|
||||
from django.http import HttpResponse
|
||||
|
||||
from allauth.headless.internal.authkit import purge_request_user_cache
|
||||
|
||||
|
||||
def test_purge_request_user_cache(rf, user):
|
||||
request = rf.get("/")
|
||||
smw = SessionMiddleware(lambda request: HttpResponse())
|
||||
smw(request)
|
||||
amw = AuthenticationMiddleware(lambda request: HttpResponse())
|
||||
amw(request)
|
||||
assert request.user.is_anonymous
|
||||
assert not request.user.pk
|
||||
purge_request_user_cache(request)
|
||||
request.session[SESSION_KEY] = user.pk
|
||||
request.session[BACKEND_SESSION_KEY] = (
|
||||
"allauth.account.auth_backends.AuthenticationBackend"
|
||||
)
|
||||
request.session[HASH_SESSION_KEY] = user.get_session_auth_hash()
|
||||
assert request.user.is_authenticated
|
||||
assert request.user.pk == user.pk
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,74 @@
|
||||
from allauth.account.forms import BaseSignupForm
|
||||
from allauth.headless.internal.restkit import inputs
|
||||
from allauth.mfa.base.forms import AuthenticateForm
|
||||
from allauth.mfa.models import Authenticator
|
||||
from allauth.mfa.recovery_codes.forms import GenerateRecoveryCodesForm
|
||||
from allauth.mfa.totp.forms import ActivateTOTPForm
|
||||
from allauth.mfa.webauthn.forms import (
|
||||
AddWebAuthnForm,
|
||||
AuthenticateWebAuthnForm,
|
||||
LoginWebAuthnForm,
|
||||
ReauthenticateWebAuthnForm,
|
||||
SignupWebAuthnForm,
|
||||
)
|
||||
|
||||
|
||||
class AuthenticateInput(AuthenticateForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class ActivateTOTPInput(ActivateTOTPForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class GenerateRecoveryCodesInput(GenerateRecoveryCodesForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class AddWebAuthnInput(AddWebAuthnForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class CreateWebAuthnInput(SignupWebAuthnForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class UpdateWebAuthnInput(inputs.Input):
|
||||
id = inputs.ModelChoiceField(queryset=Authenticator.objects.none())
|
||||
name = inputs.CharField(required=True, max_length=100)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.user = kwargs.pop("user")
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields["id"].queryset = Authenticator.objects.filter(
|
||||
user=self.user, type=Authenticator.Type.WEBAUTHN
|
||||
)
|
||||
|
||||
|
||||
class DeleteWebAuthnInput(inputs.Input):
|
||||
authenticators = inputs.ModelMultipleChoiceField(
|
||||
queryset=Authenticator.objects.none()
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.user = kwargs.pop("user")
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields["authenticators"].queryset = Authenticator.objects.filter(
|
||||
user=self.user, type=Authenticator.Type.WEBAUTHN
|
||||
)
|
||||
|
||||
|
||||
class ReauthenticateWebAuthnInput(ReauthenticateWebAuthnForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticateWebAuthnInput(AuthenticateWebAuthnForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class LoginWebAuthnInput(LoginWebAuthnForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class SignupWebAuthnInput(BaseSignupForm, inputs.Input):
|
||||
pass
|
||||
@@ -0,0 +1,104 @@
|
||||
from allauth.headless.base.response import APIResponse
|
||||
from allauth.mfa import app_settings as mfa_settings
|
||||
|
||||
|
||||
def get_config_data(request):
|
||||
data = {
|
||||
"mfa": {
|
||||
"supported_types": mfa_settings.SUPPORTED_TYPES,
|
||||
"passkey_login_enabled": mfa_settings.PASSKEY_LOGIN_ENABLED,
|
||||
}
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
def _authenticator_data(authenticator, sensitive=False):
|
||||
data = {
|
||||
"type": authenticator.type,
|
||||
"created_at": authenticator.created_at.timestamp(),
|
||||
"last_used_at": (
|
||||
authenticator.last_used_at.timestamp()
|
||||
if authenticator.last_used_at
|
||||
else None
|
||||
),
|
||||
}
|
||||
if authenticator.type == authenticator.Type.TOTP:
|
||||
pass
|
||||
elif authenticator.type == authenticator.Type.RECOVERY_CODES:
|
||||
wrapped = authenticator.wrap()
|
||||
unused_codes = wrapped.get_unused_codes()
|
||||
data.update(
|
||||
{
|
||||
"total_code_count": len(wrapped.generate_codes()),
|
||||
"unused_code_count": len(unused_codes),
|
||||
}
|
||||
)
|
||||
if sensitive:
|
||||
data["unused_codes"] = unused_codes
|
||||
elif authenticator.type == authenticator.Type.WEBAUTHN:
|
||||
wrapped = authenticator.wrap()
|
||||
data["id"] = authenticator.pk
|
||||
data["name"] = wrapped.name
|
||||
passwordless = wrapped.is_passwordless
|
||||
if passwordless is not None:
|
||||
data["is_passwordless"] = passwordless
|
||||
return data
|
||||
|
||||
|
||||
class AuthenticatorDeletedResponse(APIResponse):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticatorsDeletedResponse(APIResponse):
|
||||
pass
|
||||
|
||||
|
||||
class TOTPNotFoundResponse(APIResponse):
|
||||
def __init__(self, request, secret, totp_url):
|
||||
super().__init__(
|
||||
request,
|
||||
meta={
|
||||
"secret": secret,
|
||||
"totp_url": totp_url,
|
||||
},
|
||||
status=404,
|
||||
)
|
||||
|
||||
|
||||
class TOTPResponse(APIResponse):
|
||||
def __init__(self, request, authenticator):
|
||||
data = _authenticator_data(authenticator)
|
||||
super().__init__(request, data=data)
|
||||
|
||||
|
||||
class AuthenticatorsResponse(APIResponse):
|
||||
def __init__(self, request, authenticators):
|
||||
data = [_authenticator_data(authenticator) for authenticator in authenticators]
|
||||
super().__init__(request, data=data)
|
||||
|
||||
|
||||
class AuthenticatorResponse(APIResponse):
|
||||
def __init__(self, request, authenticator, meta=None):
|
||||
data = _authenticator_data(authenticator)
|
||||
super().__init__(request, data=data, meta=meta)
|
||||
|
||||
|
||||
class RecoveryCodesNotFoundResponse(APIResponse):
|
||||
def __init__(self, request):
|
||||
super().__init__(request, status=404)
|
||||
|
||||
|
||||
class RecoveryCodesResponse(APIResponse):
|
||||
def __init__(self, request, authenticator):
|
||||
data = _authenticator_data(authenticator, sensitive=True)
|
||||
super().__init__(request, data=data)
|
||||
|
||||
|
||||
class AddWebAuthnResponse(APIResponse):
|
||||
def __init__(self, request, registration_data):
|
||||
super().__init__(request, data={"creation_options": registration_data})
|
||||
|
||||
|
||||
class WebAuthnRequestOptionsResponse(APIResponse):
|
||||
def __init__(self, request, request_options):
|
||||
super().__init__(request, data={"request_options": request_options})
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,60 @@
|
||||
from allauth.mfa.models import Authenticator
|
||||
|
||||
|
||||
def test_get_recovery_codes_requires_reauth(
|
||||
auth_client, user_with_recovery_codes, headless_reverse
|
||||
):
|
||||
rc = Authenticator.objects.get(
|
||||
type=Authenticator.Type.RECOVERY_CODES, user=user_with_recovery_codes
|
||||
)
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
assert data["meta"]["is_authenticated"]
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:mfa:reauthenticate"),
|
||||
data={"code": rc.wrap().get_unused_codes()[0]},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_get_recovery_codes(
|
||||
auth_client,
|
||||
user_with_recovery_codes,
|
||||
headless_reverse,
|
||||
reauthentication_bypass,
|
||||
):
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["data"]["type"] == "recovery_codes"
|
||||
assert len(data["data"]["unused_codes"]) == 10
|
||||
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:authenticators"))
|
||||
data = resp.json()
|
||||
assert len(data["data"]) == 2
|
||||
rc = [autor for autor in data["data"] if autor["type"] == "recovery_codes"][0]
|
||||
assert "unused_codes" not in rc
|
||||
|
||||
|
||||
def test_generate_recovery_codes(
|
||||
auth_client,
|
||||
user_with_totp,
|
||||
headless_reverse,
|
||||
reauthentication_bypass,
|
||||
):
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
|
||||
assert resp.status_code == 404
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:mfa:manage_recovery_codes"),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["data"]["type"] == "recovery_codes"
|
||||
assert len(data["data"]["unused_codes"]) == 10
|
||||
@@ -0,0 +1,78 @@
|
||||
import pytest
|
||||
|
||||
from allauth.mfa.models import Authenticator
|
||||
|
||||
|
||||
@pytest.mark.parametrize("email_verified", [False, True])
|
||||
def test_get_totp_not_active(auth_client, user, headless_reverse, email_verified):
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_totp"))
|
||||
if email_verified:
|
||||
assert resp.status_code == 404
|
||||
data = resp.json()
|
||||
assert len(data["meta"]["secret"]) == 32
|
||||
assert len(data["meta"]["totp_url"]) == 145
|
||||
else:
|
||||
assert resp.status_code == 409
|
||||
assert resp.json() == {
|
||||
"status": 409,
|
||||
"errors": [
|
||||
{
|
||||
"message": "You cannot activate two-factor authentication until you have verified your email address.",
|
||||
"code": "unverified_email",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_get_totp(
|
||||
auth_client,
|
||||
user_with_totp,
|
||||
headless_reverse,
|
||||
):
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_totp"))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["data"]["type"] == "totp"
|
||||
assert isinstance(data["data"]["created_at"], float)
|
||||
|
||||
|
||||
def test_deactivate_totp(
|
||||
auth_client,
|
||||
user_with_totp,
|
||||
headless_reverse,
|
||||
reauthentication_bypass,
|
||||
):
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.delete(headless_reverse("headless:mfa:manage_totp"))
|
||||
assert resp.status_code == 200
|
||||
assert not Authenticator.objects.filter(user=user_with_totp).exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("email_verified", [False, True])
|
||||
def test_activate_totp(
|
||||
auth_client,
|
||||
user,
|
||||
headless_reverse,
|
||||
reauthentication_bypass,
|
||||
settings,
|
||||
totp_validation_bypass,
|
||||
email_verified,
|
||||
):
|
||||
with reauthentication_bypass():
|
||||
with totp_validation_bypass():
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:mfa:manage_totp"),
|
||||
data={"code": "42"},
|
||||
content_type="application/json",
|
||||
)
|
||||
if email_verified:
|
||||
assert resp.status_code == 200
|
||||
assert Authenticator.objects.filter(
|
||||
user=user, type=Authenticator.Type.TOTP
|
||||
).exists()
|
||||
data = resp.json()
|
||||
assert data["data"]["type"] == "totp"
|
||||
assert isinstance(data["data"]["created_at"], float)
|
||||
assert data["data"]["last_used_at"] is None
|
||||
else:
|
||||
assert resp.status_code == 400
|
||||
@@ -0,0 +1,115 @@
|
||||
from allauth.account.models import EmailAddress, get_emailconfirmation_model
|
||||
from allauth.headless.constants import Flow
|
||||
|
||||
|
||||
def test_auth_unverified_email_and_mfa(
|
||||
client,
|
||||
user_factory,
|
||||
password_factory,
|
||||
settings,
|
||||
totp_validation_bypass,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
|
||||
password = password_factory()
|
||||
user = user_factory(email_verified=False, password=password, with_totp=True)
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
assert [f for f in data["data"]["flows"] if f["id"] == Flow.VERIFY_EMAIL][0][
|
||||
"is_pending"
|
||||
]
|
||||
emailaddress = EmailAddress.objects.filter(user=user, verified=False).get()
|
||||
key = get_emailconfirmation_model().create(emailaddress).key
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:verify_email"),
|
||||
data={"key": key},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
flows = [
|
||||
{"id": "login"},
|
||||
{"id": "login_by_code"},
|
||||
{"id": "signup"},
|
||||
]
|
||||
if headless_client == "browser":
|
||||
flows.append(
|
||||
{
|
||||
"id": "provider_redirect",
|
||||
"providers": ["dummy", "openid_connect", "openid_connect"],
|
||||
}
|
||||
)
|
||||
flows.append({"id": "provider_token", "providers": ["dummy"]})
|
||||
flows.append({"id": "mfa_login_webauthn"})
|
||||
flows.append(
|
||||
{
|
||||
"id": "mfa_authenticate",
|
||||
"is_pending": True,
|
||||
"types": ["totp"],
|
||||
}
|
||||
)
|
||||
|
||||
assert resp.json() == {
|
||||
"data": {"flows": flows},
|
||||
"meta": {"is_authenticated": False},
|
||||
"status": 401,
|
||||
}
|
||||
resp = client.post(
|
||||
headless_reverse("headless:mfa:authenticate"),
|
||||
data={"code": "bad"},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{"message": "Incorrect code.", "code": "incorrect_code", "param": "code"}
|
||||
],
|
||||
}
|
||||
|
||||
with totp_validation_bypass():
|
||||
resp = client.post(
|
||||
headless_reverse("headless:mfa:authenticate"),
|
||||
data={"code": "bad"},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_dangling_mfa_is_logged_out(
|
||||
client,
|
||||
user_with_totp,
|
||||
password_factory,
|
||||
settings,
|
||||
totp_validation_bypass,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
user_password,
|
||||
):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user_with_totp.email,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
flow = [f for f in data["data"]["flows"] if f["id"] == Flow.MFA_AUTHENTICATE][0]
|
||||
assert flow["is_pending"]
|
||||
assert flow["types"] == ["totp"]
|
||||
resp = client.delete(headless_reverse("headless:account:current_session"))
|
||||
data = resp.json()
|
||||
assert resp.status_code == 401
|
||||
assert all(not f.get("is_pending") for f in data["data"]["flows"])
|
||||
@@ -0,0 +1,223 @@
|
||||
from unittest.mock import ANY
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
import pytest
|
||||
|
||||
from allauth.account.authentication import AUTHENTICATION_METHODS_SESSION_KEY
|
||||
from allauth.headless.constants import Flow
|
||||
from allauth.mfa.models import Authenticator
|
||||
|
||||
|
||||
def test_passkey_login(
|
||||
client, passkey, webauthn_authentication_bypass, headless_reverse
|
||||
):
|
||||
with webauthn_authentication_bypass(passkey) as credential:
|
||||
resp = client.get(headless_reverse("headless:mfa:login_webauthn"))
|
||||
assert "request_options" in resp.json()["data"]
|
||||
resp = client.post(
|
||||
headless_reverse("headless:mfa:login_webauthn"),
|
||||
data={"credential": credential},
|
||||
content_type="application/json",
|
||||
)
|
||||
data = resp.json()
|
||||
assert data["data"]["user"]["id"] == passkey.user_id
|
||||
|
||||
|
||||
def test_passkey_login_get_options(client, headless_client, headless_reverse, db):
|
||||
resp = client.get(headless_reverse("headless:mfa:login_webauthn"))
|
||||
data = resp.json()
|
||||
meta = {}
|
||||
if headless_client == "app":
|
||||
meta = {
|
||||
"meta": {"session_token": ANY},
|
||||
}
|
||||
assert data == {
|
||||
"status": 200,
|
||||
"data": {"request_options": {"publicKey": ANY}},
|
||||
**meta,
|
||||
}
|
||||
|
||||
|
||||
def test_reauthenticate(
|
||||
auth_client,
|
||||
passkey,
|
||||
user_with_recovery_codes,
|
||||
webauthn_authentication_bypass,
|
||||
headless_reverse,
|
||||
):
|
||||
# View recovery codes, confirm webauthn reauthentication is an option
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
|
||||
assert resp.status_code == 401
|
||||
assert Flow.MFA_REAUTHENTICATE in [
|
||||
flow["id"] for flow in resp.json()["data"]["flows"]
|
||||
]
|
||||
|
||||
# Get request options
|
||||
with webauthn_authentication_bypass(passkey):
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:reauthenticate_webauthn"))
|
||||
data = resp.json()
|
||||
assert data["status"] == 200
|
||||
assert data["data"]["request_options"] == ANY
|
||||
|
||||
# Reauthenticate
|
||||
with webauthn_authentication_bypass(passkey) as credential:
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:mfa:reauthenticate_webauthn"),
|
||||
data={"credential": credential},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_update_authenticator(
|
||||
auth_client, headless_reverse, passkey, reauthentication_bypass
|
||||
):
|
||||
data = {"id": passkey.pk, "name": "Renamed!"}
|
||||
resp = auth_client.put(
|
||||
headless_reverse("headless:mfa:manage_webauthn"),
|
||||
data=data,
|
||||
content_type="application/json",
|
||||
)
|
||||
# Reauthentication required
|
||||
assert resp.status_code == 401
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.put(
|
||||
headless_reverse("headless:mfa:manage_webauthn"),
|
||||
data=data,
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
passkey.refresh_from_db()
|
||||
assert passkey.wrap().name == "Renamed!"
|
||||
|
||||
|
||||
def test_delete_authenticator(
|
||||
auth_client, headless_reverse, passkey, reauthentication_bypass
|
||||
):
|
||||
data = {"authenticators": [passkey.pk]}
|
||||
resp = auth_client.delete(
|
||||
headless_reverse("headless:mfa:manage_webauthn"),
|
||||
data=data,
|
||||
content_type="application/json",
|
||||
)
|
||||
# Reauthentication required
|
||||
assert resp.status_code == 401
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.delete(
|
||||
headless_reverse("headless:mfa:manage_webauthn"),
|
||||
data=data,
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert not Authenticator.objects.filter(pk=passkey.pk).exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("email_verified", [False, True])
|
||||
def test_add_authenticator(
|
||||
user,
|
||||
auth_client,
|
||||
headless_reverse,
|
||||
webauthn_registration_bypass,
|
||||
reauthentication_bypass,
|
||||
email_verified,
|
||||
):
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_webauthn"))
|
||||
# Reauthentication required
|
||||
assert resp.status_code == 401 if email_verified else 409
|
||||
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_webauthn"))
|
||||
if email_verified:
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["data"]["creation_options"] == ANY
|
||||
else:
|
||||
assert resp.status_code == 409
|
||||
|
||||
with webauthn_registration_bypass(user, False) as credential:
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:mfa:manage_webauthn"),
|
||||
data={"credential": credential},
|
||||
content_type="application/json",
|
||||
)
|
||||
webauthn_count = Authenticator.objects.filter(
|
||||
type=Authenticator.Type.WEBAUTHN, user=user
|
||||
).count()
|
||||
if email_verified:
|
||||
assert resp.status_code == 200
|
||||
assert webauthn_count == 1
|
||||
else:
|
||||
assert resp.status_code == 409
|
||||
assert webauthn_count == 0
|
||||
|
||||
|
||||
def test_2fa_login(
|
||||
client,
|
||||
user,
|
||||
user_password,
|
||||
passkey,
|
||||
webauthn_authentication_bypass,
|
||||
headless_reverse,
|
||||
):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"username": user.username,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
pending_flows = [f for f in data["data"]["flows"] if f.get("is_pending")]
|
||||
assert len(pending_flows) == 1
|
||||
pending_flow = pending_flows[0]
|
||||
assert pending_flow == {
|
||||
"id": "mfa_authenticate",
|
||||
"is_pending": True,
|
||||
"types": ["webauthn"],
|
||||
}
|
||||
with webauthn_authentication_bypass(passkey) as credential:
|
||||
resp = client.get(headless_reverse("headless:mfa:authenticate_webauthn"))
|
||||
assert "request_options" in resp.json()["data"]
|
||||
resp = client.post(
|
||||
headless_reverse("headless:mfa:authenticate_webauthn"),
|
||||
data={"credential": credential},
|
||||
content_type="application/json",
|
||||
)
|
||||
data = resp.json()
|
||||
assert resp.status_code == 200
|
||||
assert data["data"]["user"]["id"] == passkey.user_id
|
||||
assert client.headless_session()[AUTHENTICATION_METHODS_SESSION_KEY] == [
|
||||
{"method": "password", "at": ANY, "username": passkey.user.username},
|
||||
{"method": "mfa", "at": ANY, "id": ANY, "type": Authenticator.Type.WEBAUTHN},
|
||||
]
|
||||
|
||||
|
||||
def test_passkey_signup(client, db, webauthn_registration_bypass, headless_reverse):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:mfa:signup_webauthn"),
|
||||
data={"email": "pass@key.org", "username": "passkey"},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
flow = [flow for flow in resp.json()["data"]["flows"] if flow.get("is_pending")][0]
|
||||
assert flow["id"] == Flow.MFA_SIGNUP_WEBAUTHN.value
|
||||
resp = client.get(headless_reverse("headless:mfa:signup_webauthn"))
|
||||
data = resp.json()
|
||||
assert "creation_options" in data["data"]
|
||||
|
||||
user = get_user_model().objects.get(email="pass@key.org")
|
||||
with webauthn_registration_bypass(user, True) as credential:
|
||||
resp = client.put(
|
||||
headless_reverse("headless:mfa:signup_webauthn"),
|
||||
data={"name": "Some key", "credential": credential},
|
||||
content_type="application/json",
|
||||
)
|
||||
data = resp.json()
|
||||
assert data["meta"]["is_authenticated"]
|
||||
authenticator = Authenticator.objects.get(user=user)
|
||||
assert authenticator.wrap().name == "Some key"
|
||||
100
venv/lib/python3.11/site-packages/allauth/headless/mfa/urls.py
Normal file
100
venv/lib/python3.11/site-packages/allauth/headless/mfa/urls.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from django.urls import include, path
|
||||
|
||||
from allauth.headless.mfa import views
|
||||
from allauth.mfa import app_settings as mfa_settings
|
||||
|
||||
|
||||
def build_urlpatterns(client):
|
||||
auth_patterns = [
|
||||
path(
|
||||
"2fa/authenticate",
|
||||
views.AuthenticateView.as_api_view(client=client),
|
||||
name="authenticate",
|
||||
),
|
||||
path(
|
||||
"2fa/reauthenticate",
|
||||
views.ReauthenticateView.as_api_view(client=client),
|
||||
name="reauthenticate",
|
||||
),
|
||||
]
|
||||
|
||||
authenticators = []
|
||||
if "totp" in mfa_settings.SUPPORTED_TYPES:
|
||||
authenticators.append(
|
||||
path(
|
||||
"totp",
|
||||
views.ManageTOTPView.as_api_view(client=client),
|
||||
name="manage_totp",
|
||||
)
|
||||
)
|
||||
if "recovery_codes" in mfa_settings.SUPPORTED_TYPES:
|
||||
authenticators.append(
|
||||
path(
|
||||
"recovery-codes",
|
||||
views.ManageRecoveryCodesView.as_api_view(client=client),
|
||||
name="manage_recovery_codes",
|
||||
)
|
||||
)
|
||||
if "webauthn" in mfa_settings.SUPPORTED_TYPES:
|
||||
authenticators.extend(
|
||||
[
|
||||
path(
|
||||
"webauthn",
|
||||
views.ManageWebAuthnView.as_api_view(client=client),
|
||||
name="manage_webauthn",
|
||||
),
|
||||
]
|
||||
)
|
||||
auth_patterns.extend(
|
||||
[
|
||||
path(
|
||||
"webauthn/authenticate",
|
||||
views.AuthenticateWebAuthnView.as_api_view(client=client),
|
||||
name="authenticate_webauthn",
|
||||
),
|
||||
path(
|
||||
"webauthn/reauthenticate",
|
||||
views.ReauthenticateWebAuthnView.as_api_view(client=client),
|
||||
name="reauthenticate_webauthn",
|
||||
),
|
||||
]
|
||||
)
|
||||
if mfa_settings.PASSKEY_LOGIN_ENABLED:
|
||||
auth_patterns.append(
|
||||
path(
|
||||
"webauthn/login",
|
||||
views.LoginWebAuthnView.as_api_view(client=client),
|
||||
name="login_webauthn",
|
||||
)
|
||||
)
|
||||
if mfa_settings.PASSKEY_SIGNUP_ENABLED:
|
||||
auth_patterns.append(
|
||||
path(
|
||||
"webauthn/signup",
|
||||
views.SignupWebAuthnView.as_api_view(client=client),
|
||||
name="signup_webauthn",
|
||||
)
|
||||
)
|
||||
|
||||
return [
|
||||
path(
|
||||
"auth/",
|
||||
include(auth_patterns),
|
||||
),
|
||||
path(
|
||||
"account/",
|
||||
include(
|
||||
[
|
||||
path(
|
||||
"authenticators",
|
||||
views.AuthenticatorsView.as_api_view(client=client),
|
||||
name="authenticators",
|
||||
),
|
||||
path(
|
||||
"authenticators/",
|
||||
include(authenticators),
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
]
|
||||
280
venv/lib/python3.11/site-packages/allauth/headless/mfa/views.py
Normal file
280
venv/lib/python3.11/site-packages/allauth/headless/mfa/views.py
Normal file
@@ -0,0 +1,280 @@
|
||||
from django.core.exceptions import ValidationError
|
||||
|
||||
from allauth.account.internal.stagekit import get_pending_stage
|
||||
from allauth.account.models import Login
|
||||
from allauth.headless.account.views import SignupView
|
||||
from allauth.headless.base.response import (
|
||||
APIResponse,
|
||||
AuthenticationResponse,
|
||||
ConflictResponse,
|
||||
)
|
||||
from allauth.headless.base.views import (
|
||||
APIView,
|
||||
AuthenticatedAPIView,
|
||||
AuthenticationStageAPIView,
|
||||
)
|
||||
from allauth.headless.internal.restkit.response import ErrorResponse
|
||||
from allauth.headless.mfa import response
|
||||
from allauth.headless.mfa.inputs import (
|
||||
ActivateTOTPInput,
|
||||
AddWebAuthnInput,
|
||||
AuthenticateInput,
|
||||
AuthenticateWebAuthnInput,
|
||||
CreateWebAuthnInput,
|
||||
DeleteWebAuthnInput,
|
||||
GenerateRecoveryCodesInput,
|
||||
LoginWebAuthnInput,
|
||||
ReauthenticateWebAuthnInput,
|
||||
SignupWebAuthnInput,
|
||||
UpdateWebAuthnInput,
|
||||
)
|
||||
from allauth.mfa.adapter import DefaultMFAAdapter, get_adapter
|
||||
from allauth.mfa.internal.flows import add
|
||||
from allauth.mfa.models import Authenticator
|
||||
from allauth.mfa.recovery_codes.internal import flows as recovery_codes_flows
|
||||
from allauth.mfa.stages import AuthenticateStage
|
||||
from allauth.mfa.totp.internal import auth as totp_auth, flows as totp_flows
|
||||
from allauth.mfa.webauthn.internal import (
|
||||
auth as webauthn_auth,
|
||||
flows as webauthn_flows,
|
||||
)
|
||||
from allauth.mfa.webauthn.stages import PasskeySignupStage
|
||||
|
||||
|
||||
def _validate_can_add_authenticator(request):
|
||||
try:
|
||||
add.validate_can_add_authenticator(request.user)
|
||||
except ValidationError as e:
|
||||
return ErrorResponse(request, status=409, exception=e)
|
||||
|
||||
|
||||
class AuthenticateView(AuthenticationStageAPIView):
|
||||
input_class = AuthenticateInput
|
||||
stage_class = AuthenticateStage
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
self.input.save()
|
||||
return self.respond_next_stage()
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.stage.login.user}
|
||||
|
||||
|
||||
class ReauthenticateView(AuthenticatedAPIView):
|
||||
input_class = AuthenticateInput
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
self.input.save()
|
||||
return AuthenticationResponse(self.request)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
|
||||
class AuthenticatorsView(AuthenticatedAPIView):
|
||||
def get(self, request, *args, **kwargs):
|
||||
authenticators = Authenticator.objects.filter(user=request.user)
|
||||
return response.AuthenticatorsResponse(request, authenticators)
|
||||
|
||||
|
||||
class ManageTOTPView(AuthenticatedAPIView):
|
||||
input_class = {"POST": ActivateTOTPInput}
|
||||
|
||||
def get(self, request, *args, **kwargs) -> APIResponse:
|
||||
authenticator = self._get_authenticator()
|
||||
if not authenticator:
|
||||
err = _validate_can_add_authenticator(request)
|
||||
if err:
|
||||
return err
|
||||
adapter: DefaultMFAAdapter = get_adapter()
|
||||
secret = totp_auth.get_totp_secret(regenerate=True)
|
||||
totp_url: str = adapter.build_totp_url(request.user, secret)
|
||||
return response.TOTPNotFoundResponse(request, secret, totp_url)
|
||||
return response.TOTPResponse(request, authenticator)
|
||||
|
||||
def _get_authenticator(self):
|
||||
return Authenticator.objects.filter(
|
||||
type=Authenticator.Type.TOTP, user=self.request.user
|
||||
).first()
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
authenticator = totp_flows.activate_totp(request, self.input)[0]
|
||||
return response.TOTPResponse(request, authenticator)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
authenticator = self._get_authenticator()
|
||||
if authenticator:
|
||||
authenticator = totp_flows.deactivate_totp(request, authenticator)
|
||||
return response.AuthenticatorDeletedResponse(request)
|
||||
|
||||
|
||||
class ManageRecoveryCodesView(AuthenticatedAPIView):
|
||||
input_class = GenerateRecoveryCodesInput
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
authenticator = recovery_codes_flows.view_recovery_codes(request)
|
||||
if not authenticator:
|
||||
return response.RecoveryCodesNotFoundResponse(request)
|
||||
return response.RecoveryCodesResponse(request, authenticator)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
authenticator = recovery_codes_flows.generate_recovery_codes(request)
|
||||
return response.RecoveryCodesResponse(request, authenticator)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
|
||||
class ManageWebAuthnView(AuthenticatedAPIView):
|
||||
input_class = {
|
||||
"POST": AddWebAuthnInput,
|
||||
"PUT": UpdateWebAuthnInput,
|
||||
"DELETE": DeleteWebAuthnInput,
|
||||
}
|
||||
|
||||
def handle(self, request, *args, **kwargs):
|
||||
if request.method in ["GET", "POST"]:
|
||||
err = _validate_can_add_authenticator(request)
|
||||
if err:
|
||||
return err
|
||||
return super().handle(request, *args, **kwargs)
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
passwordless = "passwordless" in request.GET
|
||||
creation_options = webauthn_flows.begin_registration(
|
||||
request, request.user, passwordless
|
||||
)
|
||||
return response.AddWebAuthnResponse(request, creation_options)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
auth, rc_auth = webauthn_flows.add_authenticator(
|
||||
request,
|
||||
name=self.input.cleaned_data["name"],
|
||||
credential=self.input.cleaned_data["credential"],
|
||||
)
|
||||
did_generate_recovery_codes = bool(rc_auth)
|
||||
return response.AuthenticatorResponse(
|
||||
request,
|
||||
auth,
|
||||
meta={"recovery_codes_generated": did_generate_recovery_codes},
|
||||
)
|
||||
|
||||
def put(self, request, *args, **kwargs):
|
||||
authenticator = self.input.cleaned_data["id"]
|
||||
webauthn_flows.rename_authenticator(
|
||||
request, authenticator, self.input.cleaned_data["name"]
|
||||
)
|
||||
return response.AuthenticatorResponse(request, authenticator)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
authenticators = self.input.cleaned_data["authenticators"]
|
||||
webauthn_flows.remove_authenticators(request, authenticators)
|
||||
return response.AuthenticatorsDeletedResponse(request)
|
||||
|
||||
|
||||
class ReauthenticateWebAuthnView(AuthenticatedAPIView):
|
||||
input_class = {
|
||||
"POST": ReauthenticateWebAuthnInput,
|
||||
}
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
request_options = webauthn_auth.begin_authentication(request.user)
|
||||
return response.WebAuthnRequestOptionsResponse(request, request_options)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
authenticator = self.input.cleaned_data["credential"]
|
||||
webauthn_flows.reauthenticate(request, authenticator)
|
||||
return AuthenticationResponse(self.request)
|
||||
|
||||
|
||||
class AuthenticateWebAuthnView(AuthenticationStageAPIView):
|
||||
input_class = {
|
||||
"POST": AuthenticateWebAuthnInput,
|
||||
}
|
||||
stage_class = AuthenticateStage
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
request_options = webauthn_auth.begin_authentication(self.stage.login.user)
|
||||
return response.WebAuthnRequestOptionsResponse(request, request_options)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.stage.login.user}
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
self.input.save()
|
||||
return self.respond_next_stage()
|
||||
|
||||
|
||||
class LoginWebAuthnView(APIView):
|
||||
input_class = {
|
||||
"POST": LoginWebAuthnInput,
|
||||
}
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
request_options = webauthn_auth.begin_authentication()
|
||||
return response.WebAuthnRequestOptionsResponse(request, request_options)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
authenticator = self.input.cleaned_data["credential"]
|
||||
redirect_url = None
|
||||
login = Login(user=authenticator.user, redirect_url=redirect_url)
|
||||
webauthn_flows.perform_passwordless_login(request, authenticator, login)
|
||||
return AuthenticationResponse(request)
|
||||
|
||||
|
||||
class SignupWebAuthnView(SignupView):
|
||||
input_class = {
|
||||
"POST": SignupWebAuthnInput,
|
||||
"PUT": CreateWebAuthnInput,
|
||||
}
|
||||
by_passkey = True
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
resp = self._require_stage()
|
||||
if resp:
|
||||
return resp
|
||||
creation_options = webauthn_flows.begin_registration(
|
||||
request, self.stage.login.user, passwordless=True, signup=True
|
||||
)
|
||||
return response.AddWebAuthnResponse(request, creation_options)
|
||||
|
||||
def _prep_stage(self):
|
||||
if hasattr(self, "stage"):
|
||||
return self.stage
|
||||
self.stage = get_pending_stage(self.request)
|
||||
return self.stage
|
||||
|
||||
def _require_stage(self):
|
||||
self._prep_stage()
|
||||
if not self.stage or self.stage.key != PasskeySignupStage.key:
|
||||
return ConflictResponse(self.request)
|
||||
return None
|
||||
|
||||
def get_input_kwargs(self):
|
||||
ret = super().get_input_kwargs()
|
||||
self._prep_stage()
|
||||
if self.stage and self.request.method == "PUT":
|
||||
ret["user"] = self.stage.login.user
|
||||
return ret
|
||||
|
||||
def put(self, request, *args, **kwargs):
|
||||
resp = self._require_stage()
|
||||
if resp:
|
||||
return resp
|
||||
webauthn_flows.signup_authenticator(
|
||||
request,
|
||||
user=self.stage.login.user,
|
||||
name=self.input.cleaned_data["name"],
|
||||
credential=self.input.cleaned_data["credential"],
|
||||
)
|
||||
self.stage.exit()
|
||||
return AuthenticationResponse(request)
|
||||
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user