first commit

This commit is contained in:
pacnpal
2024-10-28 17:09:57 -04:00
commit 1339baec59
9993 changed files with 1182741 additions and 0 deletions

View File

@@ -0,0 +1,218 @@
from django.core.exceptions import ImproperlyConfigured
from django.core.validators import validate_email
from allauth.account import app_settings as account_app_settings
from allauth.account.adapter import get_adapter as get_account_adapter
from allauth.account.forms import (
AddEmailForm,
BaseSignupForm,
ConfirmLoginCodeForm,
ReauthenticateForm,
RequestLoginCodeForm,
ResetPasswordForm,
UserTokenForm,
)
from allauth.account.internal import flows
from allauth.account.models import (
EmailAddress,
Login,
get_emailconfirmation_model,
)
from allauth.core import context
from allauth.headless.adapter import get_adapter
from allauth.headless.internal.restkit import inputs
class SignupInput(BaseSignupForm, inputs.Input):
password = inputs.CharField()
def clean_password(self):
password = self.cleaned_data["password"]
return get_account_adapter().clean_password(password)
class LoginInput(inputs.Input):
username = inputs.CharField(required=False)
email = inputs.EmailField(required=False)
password = inputs.CharField()
def clean(self):
cleaned_data = super().clean()
username = None
email = None
if (
account_app_settings.AUTHENTICATION_METHOD
== account_app_settings.AuthenticationMethod.USERNAME
):
username = cleaned_data.get("username")
missing_field = "username"
elif (
account_app_settings.AUTHENTICATION_METHOD
== account_app_settings.AuthenticationMethod.EMAIL
):
email = cleaned_data.get("email")
missing_field = "email"
elif (
account_app_settings.AUTHENTICATION_METHOD
== account_app_settings.AuthenticationMethod.USERNAME_EMAIL
):
username = cleaned_data.get("username")
email = cleaned_data.get("email")
missing_field = "email"
if email and username:
raise get_adapter().validation_error("email_or_username")
else:
raise ImproperlyConfigured(account_app_settings.AUTHENTICATION_METHOD)
if not email and not username:
if not self.errors.get(missing_field):
self.add_error(
missing_field, get_adapter().validation_error("required")
)
password = cleaned_data.get("password")
if password and (username or email):
credentials = {"password": password}
if email:
credentials["email"] = email
auth_method = account_app_settings.AuthenticationMethod.EMAIL
else:
credentials["username"] = username
auth_method = account_app_settings.AuthenticationMethod.USERNAME
user = get_account_adapter().authenticate(context.request, **credentials)
if user:
self.login = Login(user=user, email=credentials.get("email"))
if flows.login.is_login_rate_limited(context.request, self.login):
raise get_account_adapter().validation_error(
"too_many_login_attempts"
)
else:
error_code = "%s_password_mismatch" % auth_method.value
self.add_error(
"password", get_account_adapter().validation_error(error_code)
)
return cleaned_data
class VerifyEmailInput(inputs.Input):
key = inputs.CharField()
def clean_key(self):
key = self.cleaned_data["key"]
model = get_emailconfirmation_model()
confirmation = model.from_key(key)
valid = confirmation and not confirmation.key_expired()
if not valid:
raise get_account_adapter().validation_error(
"incorrect_code"
if account_app_settings.EMAIL_VERIFICATION_BY_CODE_ENABLED
else "invalid_or_expired_key"
)
if valid and not confirmation.email_address.can_set_verified():
raise get_account_adapter().validation_error("email_taken")
return confirmation
class RequestPasswordResetInput(ResetPasswordForm, inputs.Input):
pass
class ResetPasswordKeyInput(inputs.Input):
key = inputs.CharField()
def __init__(self, *args, **kwargs):
self.user = None
super().__init__(*args, **kwargs)
def clean_key(self):
key = self.cleaned_data["key"]
uidb36, _, subkey = key.partition("-")
token_form = UserTokenForm(data={"uidb36": uidb36, "key": subkey})
if not token_form.is_valid():
raise get_account_adapter().validation_error("invalid_password_reset")
self.user = token_form.reset_user
return key
class ResetPasswordInput(ResetPasswordKeyInput):
password = inputs.CharField()
def clean(self):
cleaned_data = super().clean()
password = self.cleaned_data.get("password")
if self.user and password is not None:
get_account_adapter().clean_password(password, user=self.user)
return cleaned_data
class ChangePasswordInput(inputs.Input):
current_password = inputs.CharField(required=False)
new_password = inputs.CharField()
def __init__(self, *args, **kwargs):
self.user = kwargs.pop("user")
super().__init__(*args, **kwargs)
self.fields["current_password"].required = self.user.has_usable_password()
def clean_current_password(self):
current_password = self.cleaned_data["current_password"]
if current_password:
if not self.user.check_password(current_password):
raise get_account_adapter().validation_error("enter_current_password")
return current_password
def clean_new_password(self):
new_password = self.cleaned_data["new_password"]
adapter = get_account_adapter()
return adapter.clean_password(new_password, user=self.user)
class AddEmailInput(AddEmailForm, inputs.Input):
pass
class SelectEmailInput(inputs.Input):
email = inputs.CharField()
def __init__(self, *args, **kwargs):
self.user = kwargs.pop("user")
super().__init__(*args, **kwargs)
def clean_email(self):
email = self.cleaned_data["email"]
validate_email(email)
try:
return EmailAddress.objects.get_for_user(user=self.user, email=email)
except EmailAddress.DoesNotExist:
raise get_adapter().validation_error("unknown_email")
class DeleteEmailInput(SelectEmailInput):
def clean_email(self):
email = super().clean_email()
if not flows.manage_email.can_delete_email(email):
raise get_account_adapter().validation_error("cannot_remove_primary_email")
return email
class MarkAsPrimaryEmailInput(SelectEmailInput):
primary = inputs.BooleanField(required=True)
def clean_email(self):
email = super().clean_email()
if not flows.manage_email.can_mark_as_primary(email):
raise get_account_adapter().validation_error("unverified_primary_email")
return email
class ReauthenticateInput(ReauthenticateForm, inputs.Input):
pass
class RequestLoginCodeInput(RequestLoginCodeForm, inputs.Input):
pass
class ConfirmLoginCodeInput(ConfirmLoginCodeForm, inputs.Input):
pass

View File

@@ -0,0 +1,44 @@
from allauth.headless.adapter import get_adapter
from allauth.headless.base.response import APIResponse
class RequestEmailVerificationResponse(APIResponse):
def __init__(self, request, verification_sent):
super().__init__(request, status=200 if verification_sent else 403)
class VerifyEmailResponse(APIResponse):
def __init__(self, request, verification, stage):
adapter = get_adapter()
data = {
"email": verification.email_address.email,
"user": adapter.serialize_user(verification.email_address.user),
}
meta = {
"is_authenticating": stage is not None,
}
super().__init__(request, data=data, meta=meta)
class EmailAddressesResponse(APIResponse):
def __init__(self, request, email_addresses):
data = [
{
"email": addr.email,
"verified": addr.verified,
"primary": addr.primary,
}
for addr in email_addresses
]
super().__init__(request, data=data)
class RequestPasswordResponse(APIResponse):
pass
class PasswordResetKeyResponse(APIResponse):
def __init__(self, request, user):
adapter = get_adapter()
data = {"user": adapter.serialize_user(user)}
super().__init__(request, data=data)

View File

@@ -0,0 +1,99 @@
from allauth.account.models import EmailAddress
def test_list_email(auth_client, user, headless_reverse):
resp = auth_client.get(
headless_reverse("headless:account:manage_email"),
)
assert len(resp.json()["data"]) == 1
def test_remove_email(auth_client, user, email_factory, headless_reverse):
addr = EmailAddress.objects.create(email=email_factory(), user=user)
assert EmailAddress.objects.filter(user=user).count() == 2
resp = auth_client.delete(
headless_reverse("headless:account:manage_email"),
data={"email": addr.email},
content_type="application/json",
)
assert resp.status_code == 200
assert len(resp.json()["data"]) == 1
assert not EmailAddress.objects.filter(pk=addr.pk).exists()
def test_add_email(auth_client, user, email_factory, headless_reverse):
new_email = email_factory()
resp = auth_client.post(
headless_reverse("headless:account:manage_email"),
data={"email": new_email},
content_type="application/json",
)
assert resp.status_code == 200
assert len(resp.json()["data"]) == 2
assert EmailAddress.objects.filter(email=new_email, verified=False).exists()
def test_change_primary(auth_client, user, email_factory, headless_reverse):
addr = EmailAddress.objects.create(
email=email_factory(), user=user, verified=True, primary=False
)
resp = auth_client.patch(
headless_reverse("headless:account:manage_email"),
data={"email": addr.email, "primary": True},
content_type="application/json",
)
assert resp.status_code == 200
assert len(resp.json()["data"]) == 2
assert EmailAddress.objects.filter(pk=addr.pk, primary=True).exists()
def test_resend_verification(
auth_client, user, email_factory, headless_reverse, mailoutbox
):
addr = EmailAddress.objects.create(email=email_factory(), user=user, verified=False)
resp = auth_client.put(
headless_reverse("headless:account:manage_email"),
data={"email": addr.email},
content_type="application/json",
)
assert resp.status_code == 200
assert len(mailoutbox) == 1
def test_email_rate_limit(
auth_client, user, email_factory, headless_reverse, settings, enable_cache
):
settings.ACCOUNT_RATE_LIMITS = {"manage_email": "1/m/ip"}
for attempt in range(2):
new_email = email_factory()
resp = auth_client.post(
headless_reverse("headless:account:manage_email"),
data={"email": new_email},
content_type="application/json",
)
expected_status = 200 if attempt == 0 else 429
assert resp.status_code == expected_status
assert resp.json()["status"] == expected_status
def test_resend_verification_rate_limit(
auth_client,
user,
email_factory,
headless_reverse,
settings,
enable_cache,
mailoutbox,
):
settings.ACCOUNT_RATE_LIMITS = {"confirm_email": "1/m/ip"}
for attempt in range(2):
addr = EmailAddress.objects.create(
email=email_factory(), user=user, verified=False
)
resp = auth_client.put(
headless_reverse("headless:account:manage_email"),
data={"email": addr.email},
content_type="application/json",
)
assert resp.status_code == 403 if attempt else 200
assert len(mailoutbox) == 1

View File

@@ -0,0 +1,220 @@
import copy
from unittest.mock import ANY
import pytest
@pytest.mark.parametrize(
"has_password,request_data,response_data,status_code",
[
# Wrong current password
(
True,
{"current_password": "wrong", "new_password": "{password_factory}"},
{
"status": 400,
"errors": [
{
"param": "current_password",
"message": "Please type your current password.",
"code": "enter_current_password",
}
],
},
400,
),
# Happy flow, regular password change
(
True,
{
"current_password": "{user_password}",
"new_password": "{password_factory}",
},
{
"status": 200,
"meta": {"is_authenticated": True},
"data": {
"user": ANY,
"methods": [],
},
},
200,
),
# New password does not match constraints
(
True,
{
"current_password": "{user_password}",
"new_password": "a",
},
{
"status": 400,
"errors": [
{
"param": "new_password",
"code": "password_too_short",
"message": "This password is too short. It must contain at least 6 characters.",
}
],
},
400,
),
# New password not empty
(
True,
{
"current_password": "{user_password}",
"new_password": "",
},
{
"status": 400,
"errors": [
{
"param": "new_password",
"code": "required",
"message": "This field is required.",
}
],
},
400,
),
# Current password not blank
(
True,
{
"current_password": "",
"new_password": "{password_factory}",
},
{
"status": 400,
"errors": [
{
"param": "current_password",
"message": "This field is required.",
"code": "required",
}
],
},
400,
),
# Current password missing
(
True,
{
"new_password": "{password_factory}",
},
{
"status": 400,
"errors": [
{
"param": "current_password",
"message": "This field is required.",
"code": "required",
}
],
},
400,
),
# Current password not set, happy flow
(
False,
{
"current_password": "",
"new_password": "{password_factory}",
},
{
"status": 200,
"meta": {"is_authenticated": True},
"data": {
"user": ANY,
"methods": [],
},
},
200,
),
# Current password not set, current_password absent
(
False,
{
"new_password": "{password_factory}",
},
{
"status": 200,
"meta": {"is_authenticated": True},
"data": {
"user": ANY,
"methods": [],
},
},
200,
),
],
)
def test_change_password(
auth_client,
user,
request_data,
response_data,
status_code,
has_password,
user_password,
password_factory,
settings,
mailoutbox,
headless_reverse,
headless_client,
):
request_data = copy.deepcopy(request_data)
response_data = copy.deepcopy(response_data)
settings.ACCOUNT_EMAIL_NOTIFICATIONS = True
if not has_password:
user.set_unusable_password()
user.save(update_fields=["password"])
auth_client.force_login(user)
if request_data.get("current_password") == "{user_password}":
request_data["current_password"] = user_password
if request_data.get("new_password") == "{password_factory}":
request_data["new_password"] = password_factory()
resp = auth_client.post(
headless_reverse("headless:account:change_password"),
data=request_data,
content_type="application/json",
)
assert resp.status_code == status_code
resp_json = resp.json()
if headless_client == "app" and resp.status_code == 200:
response_data["meta"]["session_token"] = ANY
assert resp_json == response_data
user.refresh_from_db()
if resp.status_code == 200:
assert user.check_password(request_data["new_password"])
assert len(mailoutbox) == 1
else:
assert user.check_password(user_password)
assert len(mailoutbox) == 0
def test_change_password_rate_limit(
enable_cache,
auth_client,
user,
user_password,
password_factory,
settings,
headless_reverse,
):
settings.ACCOUNT_RATE_LIMITS = {"change_password": "1/m/ip"}
for attempt in range(2):
new_password = password_factory()
resp = auth_client.post(
headless_reverse("headless:account:change_password"),
data={
"current_password": user_password,
"new_password": new_password,
},
content_type="application/json",
)
user_password = new_password
expected_status = 200 if attempt == 0 else 429
assert resp.status_code == expected_status
assert resp.json()["status"] == expected_status

View File

@@ -0,0 +1,86 @@
from allauth.account.models import (
EmailAddress,
EmailConfirmationHMAC,
get_emailconfirmation_model,
)
from allauth.headless.constants import Flow
def test_verify_email_other_user(auth_client, user, user_factory, headless_reverse):
other_user = user_factory(email_verified=False)
email_address = EmailAddress.objects.get(user=other_user, verified=False)
assert not email_address.verified
key = EmailConfirmationHMAC(email_address).key
resp = auth_client.post(
headless_reverse("headless:account:verify_email"),
data={"key": key},
content_type="application/json",
)
assert resp.status_code == 200
data = resp.json()
# We're still authenticated as the user originally logged in, not the
# other_user.
assert data["data"]["user"]["id"] == user.pk
def test_auth_unverified_email(
client, user_factory, password_factory, settings, headless_reverse
):
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
password = password_factory()
user = user_factory(email_verified=False, password=password)
resp = client.post(
headless_reverse("headless:account:login"),
data={
"email": user.email,
"password": password,
},
content_type="application/json",
)
assert resp.status_code == 401
data = resp.json()
flows = data["data"]["flows"]
assert [f for f in flows if f["id"] == Flow.VERIFY_EMAIL][0]["is_pending"]
emailaddress = EmailAddress.objects.filter(user=user, verified=False).get()
key = get_emailconfirmation_model().create(emailaddress).key
resp = client.post(
headless_reverse("headless:account:verify_email"),
data={"key": key},
content_type="application/json",
)
assert resp.status_code == 200
def test_verify_email_bad_key(
client, settings, password_factory, user_factory, headless_reverse
):
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
password = password_factory()
user = user_factory(email_verified=False, password=password)
resp = client.post(
headless_reverse("headless:account:login"),
data={
"email": user.email,
"password": password,
},
content_type="application/json",
)
assert resp.status_code == 401
resp = client.post(
headless_reverse("headless:account:verify_email"),
data={"key": "bad"},
content_type="application/json",
)
assert resp.status_code == 400
assert resp.json() == {
"status": 400,
"errors": [
{
"code": "invalid_or_expired_key",
"param": "key",
"message": "Invalid or expired key.",
}
],
}

View File

@@ -0,0 +1,45 @@
from allauth.headless.constants import Flow
def test_email_verification_rate_limits(
client,
db,
user_password,
settings,
user_factory,
password_factory,
enable_cache,
headless_reverse,
):
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
settings.ACCOUNT_EMAIL_VERIFICATION_BY_CODE_ENABLED = True
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
settings.ACCOUNT_RATE_LIMITS = {"confirm_email": "1/m/key"}
email = "user@email.org"
user_factory(email=email, email_verified=False, password=user_password)
for attempt in range(2):
resp = client.post(
headless_reverse("headless:account:login"),
data={
"email": email,
"password": user_password,
},
content_type="application/json",
)
if attempt == 0:
assert resp.status_code == 401
flow = [
flow for flow in resp.json()["data"]["flows"] if flow.get("is_pending")
][0]
assert flow["id"] == Flow.VERIFY_EMAIL
else:
assert resp.status_code == 400
assert resp.json() == {
"status": 400,
"errors": [
{
"message": "Too many failed login attempts. Try again later.",
"code": "too_many_login_attempts",
}
],
}

View File

@@ -0,0 +1,188 @@
from unittest.mock import ANY
import pytest
def test_auth_password_input_error(headless_reverse, client):
resp = client.post(
headless_reverse("headless:account:login"),
data={},
content_type="application/json",
)
assert resp.status_code == 400
assert resp.json() == {
"status": 400,
"errors": [
{
"message": "This field is required.",
"code": "required",
"param": "password",
},
{
"message": "This field is required.",
"code": "required",
"param": "username",
},
],
}
def test_auth_password_bad_password(headless_reverse, client, user, settings):
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
resp = client.post(
headless_reverse("headless:account:login"),
data={
"email": user.email,
"password": "wrong",
},
content_type="application/json",
)
assert resp.status_code == 400
assert resp.json() == {
"status": 400,
"errors": [
{
"param": "password",
"message": "The email address and/or password you specified are not correct.",
"code": "email_password_mismatch",
}
],
}
def test_auth_password_success(
client, user, user_password, settings, headless_reverse, headless_client
):
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
login_resp = client.post(
headless_reverse("headless:account:login"),
data={
"email": user.email,
"password": user_password,
},
content_type="application/json",
)
assert login_resp.status_code == 200
session_resp = client.get(
headless_reverse("headless:account:current_session"),
content_type="application/json",
)
assert session_resp.status_code == 200
for resp in [login_resp, session_resp]:
extra_meta = {}
if headless_client == "app" and resp == login_resp:
# The session is created on first login, and hence the token is
# exposed only at that moment.
extra_meta["session_token"] = ANY
assert resp.json() == {
"status": 200,
"data": {
"user": {
"id": user.pk,
"display": str(user),
"email": user.email,
"username": user.username,
"has_usable_password": True,
},
"methods": [
{
"at": ANY,
"email": user.email,
"method": "password",
}
],
},
"meta": {"is_authenticated": True, **extra_meta},
}
@pytest.mark.parametrize("is_active,status_code", [(False, 401), (True, 200)])
def test_auth_password_user_inactive(
client, user, user_password, settings, status_code, is_active, headless_reverse
):
user.is_active = is_active
user.save(update_fields=["is_active"])
resp = client.post(
headless_reverse("headless:account:login"),
data={
"username": user.username,
"password": user_password,
},
content_type="application/json",
)
assert resp.status_code == status_code
def test_login_failed_rate_limit(
client,
user,
settings,
headless_reverse,
headless_client,
enable_cache,
):
settings.ACCOUNT_RATE_LIMITS = {"login_failed": "1/m/ip"}
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
for attempt in range(2):
resp = client.post(
headless_reverse("headless:account:login"),
data={
"email": user.email,
"password": "wrong",
},
content_type="application/json",
)
assert resp.status_code == 400
assert resp.json()["errors"] == [
(
{
"code": "email_password_mismatch",
"message": "The email address and/or password you specified are not correct.",
"param": "password",
}
if attempt == 0
else {
"message": "Too many failed login attempts. Try again later.",
"code": "too_many_login_attempts",
}
)
]
def test_login_rate_limit(
client,
user,
user_password,
settings,
headless_reverse,
headless_client,
enable_cache,
):
settings.ACCOUNT_RATE_LIMITS = {"login": "1/m/ip"}
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
for attempt in range(2):
resp = client.post(
headless_reverse("headless:account:login"),
data={
"email": user.email,
"password": user_password,
},
content_type="application/json",
)
expected_status = 429 if attempt else 200
assert resp.status_code == expected_status
def test_login_already_logged_in(
auth_client, user, user_password, settings, headless_reverse
):
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
resp = auth_client.post(
headless_reverse("headless:account:login"),
data={
"email": user.email,
"password": user_password,
},
content_type="application/json",
)
assert resp.status_code == 409

View File

@@ -0,0 +1,146 @@
import time
from allauth.account.models import EmailAddress
from allauth.headless.constants import Flow
def test_login_by_code(headless_reverse, user, client, mailoutbox):
resp = client.post(
headless_reverse("headless:account:request_login_code"),
data={"email": user.email},
content_type="application/json",
)
assert resp.status_code == 401
data = resp.json()
assert [f for f in data["data"]["flows"] if f["id"] == Flow.LOGIN_BY_CODE][0][
"is_pending"
]
assert len(mailoutbox) == 1
code = [line for line in mailoutbox[0].body.splitlines() if len(line) == 6][0]
resp = client.post(
headless_reverse("headless:account:confirm_login_code"),
data={"code": code},
content_type="application/json",
)
assert resp.status_code == 200
data = resp.json()
assert data["meta"]["is_authenticated"]
def test_login_by_code_rate_limit(
headless_reverse, user, client, mailoutbox, settings, enable_cache
):
settings.ACCOUNT_RATE_LIMITS = {"request_login_code": "1/m/ip"}
for attempt in range(2):
resp = client.post(
headless_reverse("headless:account:request_login_code"),
data={"email": user.email},
content_type="application/json",
)
expected_code = 400 if attempt else 401
assert resp.status_code == expected_code
data = resp.json()
assert data["status"] == expected_code
if expected_code == 400:
assert data["errors"] == [
{
"code": "too_many_login_attempts",
"message": "Too many failed login attempts. Try again later.",
"param": "email",
},
]
def test_login_by_code_max_attemps(headless_reverse, user, client, settings):
settings.ACCOUNT_LOGIN_BY_CODE_MAX_ATTEMPTS = 2
resp = client.post(
headless_reverse("headless:account:request_login_code"),
data={"email": user.email},
content_type="application/json",
)
assert resp.status_code == 401
for i in range(3):
resp = client.post(
headless_reverse("headless:account:confirm_login_code"),
data={"code": "wrong"},
content_type="application/json",
)
session_resp = client.get(
headless_reverse("headless:account:current_session"),
data={"code": "wrong"},
content_type="application/json",
)
assert session_resp.status_code == 401
pending_flows = [
f for f in session_resp.json()["data"]["flows"] if f.get("is_pending")
]
if i >= 1:
assert resp.status_code == 409 if i >= 2 else 400
assert len(pending_flows) == 0
else:
assert resp.status_code == 400
assert len(pending_flows) == 1
def test_login_by_code_required(
client, settings, user_factory, password_factory, headless_reverse, mailoutbox
):
settings.ACCOUNT_LOGIN_BY_CODE_REQUIRED = True
password = password_factory()
user = user_factory(password=password, email_verified=False)
email_address = EmailAddress.objects.get(email=user.email)
assert not email_address.verified
resp = client.post(
headless_reverse("headless:account:login"),
data={
"username": user.username,
"password": password,
},
content_type="application/json",
)
assert resp.status_code == 401
pending_flow = [f for f in resp.json()["data"]["flows"] if f.get("is_pending")][0][
"id"
]
assert pending_flow == Flow.LOGIN_BY_CODE
code = [line for line in mailoutbox[0].body.splitlines() if len(line) == 6][0]
resp = client.post(
headless_reverse("headless:account:confirm_login_code"),
data={"code": code},
content_type="application/json",
)
assert resp.status_code == 200
data = resp.json()
assert data["meta"]["is_authenticated"]
email_address.refresh_from_db()
assert email_address.verified
def test_login_by_code_expired(headless_reverse, user, client, mailoutbox):
resp = client.post(
headless_reverse("headless:account:request_login_code"),
data={"email": user.email},
content_type="application/json",
)
assert resp.status_code == 401
data = resp.json()
assert [f for f in data["data"]["flows"] if f["id"] == Flow.LOGIN_BY_CODE][0][
"is_pending"
]
assert len(mailoutbox) == 1
code = [line for line in mailoutbox[0].body.splitlines() if len(line) == 6][0]
# Expire code
session = client.headless_session()
login = session["account_login"]
login["state"]["login_code"]["at"] = time.time() - 24 * 60 * 60
session["account_login"] = login
session.save()
# Post valid code
resp = client.post(
headless_reverse("headless:account:confirm_login_code"),
data={"code": code},
content_type="application/json",
)
assert resp.status_code == 409

View File

@@ -0,0 +1,52 @@
def test_reauthenticate(
auth_client, user, user_password, headless_reverse, headless_client
):
resp = auth_client.get(
headless_reverse("headless:account:current_session"),
content_type="application/json",
)
assert resp.status_code == 200
data = resp.json()
method_count = len(data["data"]["methods"])
resp = auth_client.post(
headless_reverse("headless:account:reauthenticate"),
data={
"password": user_password,
},
content_type="application/json",
)
assert resp.status_code == 200
resp = auth_client.get(
headless_reverse("headless:account:current_session"),
content_type="application/json",
)
assert resp.status_code == 200
data = resp.json()
assert len(data["data"]["methods"]) == method_count + 1
last_method = data["data"]["methods"][-1]
assert last_method["method"] == "password"
def test_reauthenticate_rate_limit(
auth_client,
user,
user_password,
headless_reverse,
headless_client,
settings,
enable_cache,
):
settings.ACCOUNT_RATE_LIMITS = {"reauthenticate": "1/m/ip"}
for attempt in range(4):
resp = auth_client.post(
headless_reverse("headless:account:reauthenticate"),
data={
"password": user_password,
},
content_type="application/json",
)
expected_status = 429 if attempt else 200
assert resp.status_code == expected_status
assert resp.json()["status"] == expected_status

View File

@@ -0,0 +1,151 @@
from django.urls import reverse
import pytest
def test_password_reset_flow(
client, user, mailoutbox, password_factory, settings, headless_reverse
):
settings.ACCOUNT_EMAIL_NOTIFICATIONS = True
resp = client.post(
headless_reverse("headless:account:request_password_reset"),
data={
"email": user.email,
},
content_type="application/json",
)
assert resp.status_code == 200
assert len(mailoutbox) == 1
body = mailoutbox[0].body
# Extract URL for `password_reset_from_key` view
url = body[body.find("/password/reset/") :].split()[0]
key = url.split("/")[-2]
password = password_factory()
# Too simple password
resp = client.post(
headless_reverse("headless:account:reset_password"),
data={
"key": key,
"password": "a",
},
content_type="application/json",
)
assert resp.status_code == 400
assert resp.json() == {
"status": 400,
"errors": [
{
"code": "password_too_short",
"message": "This password is too short. It must contain at least 6 characters.",
}
],
}
assert len(mailoutbox) == 1
# Success
resp = client.post(
headless_reverse("headless:account:reset_password"),
data={
"key": key,
"password": password,
},
content_type="application/json",
)
assert resp.status_code == 401
user.refresh_from_db()
assert user.check_password(password)
assert len(mailoutbox) == 2 # The security notification
@pytest.mark.parametrize("method", ["get", "post"])
def test_password_reset_flow_wrong_key(
client, password_factory, headless_reverse, method
):
password = password_factory()
if method == "get":
resp = client.get(
headless_reverse("headless:account:reset_password"),
HTTP_X_PASSWORD_RESET_KEY="wrong",
)
else:
resp = client.post(
headless_reverse("headless:account:reset_password"),
data={
"key": "wrong",
"password": password,
},
content_type="application/json",
)
assert resp.status_code == 400
assert resp.json() == {
"status": 400,
"errors": [
{
"param": "key",
"code": "invalid_password_reset",
"message": "The password reset token was invalid.",
}
],
}
def test_password_reset_flow_unknown_user(
client, db, mailoutbox, password_factory, settings, headless_reverse
):
resp = client.post(
headless_reverse("headless:account:request_password_reset"),
data={
"email": "not@registered.org",
},
content_type="application/json",
)
assert resp.status_code == 200
assert len(mailoutbox) == 1
body = mailoutbox[0].body
if getattr(settings, "HEADLESS_ONLY", False):
assert settings.HEADLESS_FRONTEND_URLS["account_signup"] in body
else:
assert reverse("account_signup") in body
def test_reset_password_rate_limit(
auth_client, user, headless_reverse, settings, enable_cache
):
settings.ACCOUNT_RATE_LIMITS = {"reset_password": "1/m/ip"}
for attempt in range(2):
resp = auth_client.post(
headless_reverse("headless:account:request_password_reset"),
data={"email": user.email},
content_type="application/json",
)
expected_status = 200 if attempt == 0 else 429
assert resp.status_code == expected_status
assert resp.json()["status"] == expected_status
def test_password_reset_key_rate_limit(
client,
user,
settings,
headless_reverse,
password_reset_key_generator,
enable_cache,
):
settings.ACCOUNT_RATE_LIMITS = {"reset_password_from_key": "1/m/ip"}
for attempt in range(2):
resp = client.post(
headless_reverse("headless:account:reset_password"),
data={
"key": password_reset_key_generator(user),
"password": "a", # too short
},
content_type="application/json",
)
expected_status = 429 if attempt else 400
assert resp.status_code == expected_status
assert resp.json()["status"] == expected_status

View File

@@ -0,0 +1,33 @@
from django.test.client import Client
from django.urls import reverse
def test_app_session_gone(db, user):
# intentionally use a vanilla Django test client
client = Client()
# Force login, creates a Django session
client.force_login(user)
# That Django session should not play any role.
resp = client.get(
reverse("headless:app:account:current_session"), HTTP_X_SESSION_TOKEN="gone"
)
assert resp.status_code == 410
def test_logout(auth_client, headless_reverse):
# That Django session should not play any role.
resp = auth_client.get(headless_reverse("headless:account:current_session"))
assert resp.status_code == 200
resp = auth_client.delete(headless_reverse("headless:account:current_session"))
assert resp.status_code == 401
resp = auth_client.get(headless_reverse("headless:account:current_session"))
assert resp.status_code in [401, 410]
def test_logout_no_token(app_client, user):
app_client.force_login(user)
resp = app_client.get(reverse("headless:app:account:current_session"))
assert resp.status_code == 200
resp = app_client.delete(reverse("headless:app:account:current_session"))
assert resp.status_code == 401
assert "session_token" not in resp.json()["meta"]

View File

@@ -0,0 +1,190 @@
from unittest.mock import ANY, patch
from django.contrib.auth.models import User
from allauth.account.models import EmailAddress, EmailConfirmationHMAC
from allauth.headless.constants import Flow
def test_signup(
db,
client,
email_factory,
password_factory,
settings,
headless_reverse,
headless_client,
):
resp = client.post(
headless_reverse("headless:account:signup"),
data={
"username": "wizard",
"email": email_factory(),
"password": password_factory(),
},
content_type="application/json",
)
assert resp.status_code == 200
assert User.objects.filter(username="wizard").exists()
def test_signup_with_email_verification(
db,
client,
email_factory,
password_factory,
settings,
headless_reverse,
headless_client,
):
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
settings.ACCOUNT_USERNAME_REQUIRED = False
email = email_factory()
resp = client.post(
headless_reverse("headless:account:signup"),
data={
"email": email,
"password": password_factory(),
},
content_type="application/json",
)
assert resp.status_code == 401
assert User.objects.filter(email=email).exists()
data = resp.json()
flow = next((f for f in data["data"]["flows"] if f.get("is_pending")))
assert flow["id"] == "verify_email"
addr = EmailAddress.objects.get(email=email)
key = EmailConfirmationHMAC(addr).key
resp = client.get(
headless_reverse("headless:account:verify_email"),
HTTP_X_EMAIL_VERIFICATION_KEY=key,
)
assert resp.status_code == 200
assert resp.json() == {
"data": {
"email": email,
"user": ANY,
},
"meta": {"is_authenticating": True},
"status": 200,
}
resp = client.post(
headless_reverse("headless:account:verify_email"),
data={"key": key},
content_type="application/json",
)
assert resp.status_code == 200
data = resp.json()
assert data["meta"]["is_authenticated"]
addr.refresh_from_db()
assert addr.verified
def test_signup_prevent_enumeration(
db,
client,
email_factory,
password_factory,
settings,
headless_reverse,
headless_client,
user,
mailoutbox,
):
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
settings.ACCOUNT_USERNAME_REQUIRED = False
settings.ACCOUNT_PREVENT_ENUMERATION = True
resp = client.post(
headless_reverse("headless:account:signup"),
data={
"email": user.email,
"password": password_factory(),
},
content_type="application/json",
)
assert len(mailoutbox) == 1
assert "an account using that email address already exists" in mailoutbox[0].body
assert resp.status_code == 401
data = resp.json()
assert [f for f in data["data"]["flows"] if f["id"] == Flow.VERIFY_EMAIL][0][
"is_pending"
]
resp = client.get(headless_reverse("headless:account:current_session"))
data = resp.json()
assert [f for f in data["data"]["flows"] if f["id"] == Flow.VERIFY_EMAIL][0][
"is_pending"
]
def test_signup_rate_limit(
db,
client,
email_factory,
password_factory,
settings,
headless_reverse,
enable_cache,
headless_client,
):
settings.ACCOUNT_RATE_LIMITS = {"signup": "1/m/ip"}
for attempt in range(2):
resp = client.post(
headless_reverse("headless:account:signup"),
data={
"username": f"wizard{attempt}",
"email": email_factory(),
"password": password_factory(),
},
content_type="application/json",
)
expected_status = 429 if attempt else 200
assert resp.status_code == expected_status
assert resp.json()["status"] == expected_status
def test_signup_closed(
db,
client,
email_factory,
password_factory,
settings,
headless_reverse,
headless_client,
):
with patch(
"allauth.account.adapter.DefaultAccountAdapter.is_open_for_signup"
) as iofs:
iofs.return_value = False
resp = client.post(
headless_reverse("headless:account:signup"),
data={
"username": "wizard",
"email": email_factory(),
"password": password_factory(),
},
content_type="application/json",
)
assert resp.status_code == 403
assert not User.objects.filter(username="wizard").exists()
def test_signup_while_logged_in(
db,
auth_client,
email_factory,
password_factory,
settings,
headless_reverse,
headless_client,
):
resp = auth_client.post(
headless_reverse("headless:account:signup"),
data={
"username": "wizard",
"email": email_factory(),
"password": password_factory(),
},
content_type="application/json",
)
assert resp.status_code == 409
assert resp.json() == {"status": 409}

View File

@@ -0,0 +1,96 @@
from django.urls import include, path
from allauth import app_settings as allauth_settings
from allauth.account import app_settings as account_settings
from allauth.headless.account import views
def build_urlpatterns(client):
account_patterns = []
auth_patterns = [
path(
"session",
views.SessionView.as_api_view(client=client),
name="current_session",
),
path(
"reauthenticate",
views.ReauthenticateView.as_api_view(client=client),
name="reauthenticate",
),
path(
"code/confirm",
views.ConfirmLoginCodeView.as_api_view(client=client),
name="confirm_login_code",
),
]
if not allauth_settings.SOCIALACCOUNT_ONLY:
account_patterns.extend(
[
path(
"password/change",
views.ChangePasswordView.as_api_view(client=client),
name="change_password",
),
path(
"email",
views.ManageEmailView.as_api_view(client=client),
name="manage_email",
),
]
)
auth_patterns.extend(
[
path(
"password/",
include(
[
path(
"request",
views.RequestPasswordResetView.as_api_view(
client=client
),
name="request_password_reset",
),
path(
"reset",
views.ResetPasswordView.as_api_view(client=client),
name="reset_password",
),
]
),
),
path(
"login",
views.LoginView.as_api_view(client=client),
name="login",
),
path(
"signup",
views.SignupView.as_api_view(client=client),
name="signup",
),
path(
"email/verify",
views.VerifyEmailView.as_api_view(client=client),
name="verify_email",
),
]
)
if account_settings.LOGIN_BY_CODE_ENABLED:
auth_patterns.extend(
[
path(
"code/request",
views.RequestLoginCodeView.as_api_view(client=client),
name="request_login_code",
),
]
)
return [
path("auth/", include(auth_patterns)),
path(
"account/",
include(account_patterns),
),
]

View File

@@ -0,0 +1,256 @@
from django.utils.decorators import method_decorator
from allauth.account.adapter import get_adapter as get_account_adapter
from allauth.account.internal import flows
from allauth.account.internal.flows import password_change, password_reset
from allauth.account.models import EmailAddress
from allauth.account.stages import EmailVerificationStage, LoginStageController
from allauth.account.utils import send_email_confirmation
from allauth.core import ratelimit
from allauth.core.exceptions import ImmediateHttpResponse
from allauth.decorators import rate_limit
from allauth.headless.account import response
from allauth.headless.account.inputs import (
AddEmailInput,
ChangePasswordInput,
ConfirmLoginCodeInput,
DeleteEmailInput,
LoginInput,
MarkAsPrimaryEmailInput,
ReauthenticateInput,
RequestLoginCodeInput,
RequestPasswordResetInput,
ResetPasswordInput,
ResetPasswordKeyInput,
SelectEmailInput,
SignupInput,
VerifyEmailInput,
)
from allauth.headless.base.response import (
APIResponse,
AuthenticationResponse,
ConflictResponse,
ForbiddenResponse,
)
from allauth.headless.base.views import APIView, AuthenticatedAPIView
from allauth.headless.internal import authkit
from allauth.headless.internal.restkit.response import ErrorResponse
class RequestLoginCodeView(APIView):
input_class = RequestLoginCodeInput
def post(self, request, *args, **kwargs):
flows.login_by_code.request_login_code(
self.request, self.input.cleaned_data["email"]
)
return AuthenticationResponse(self.request)
class ConfirmLoginCodeView(APIView):
input_class = ConfirmLoginCodeInput
def dispatch(self, request, *args, **kwargs):
auth_status = authkit.AuthenticationStatus(request)
self.stage = auth_status.get_pending_stage()
if not self.stage:
return ConflictResponse(request)
self.user, self.pending_login = flows.login_by_code.get_pending_login(
request, self.stage.login, peek=True
)
if not self.pending_login:
return ConflictResponse(request)
return super().dispatch(request, *args, **kwargs)
def post(self, request, *args, **kwargs):
flows.login_by_code.perform_login_by_code(self.request, self.stage, None)
return AuthenticationResponse(request)
def get_input_kwargs(self):
kwargs = super().get_input_kwargs()
kwargs["code"] = (
self.pending_login.get("code", "") if self.pending_login else ""
)
return kwargs
def handle_invalid_input(self, input):
flows.login_by_code.record_invalid_attempt(self.request, self.stage.login)
return super().handle_invalid_input(input)
@method_decorator(rate_limit(action="login"), name="handle")
class LoginView(APIView):
input_class = LoginInput
def post(self, request, *args, **kwargs):
if request.user.is_authenticated:
return ConflictResponse(request)
credentials = self.input.cleaned_data
flows.login.perform_password_login(request, credentials, self.input.login)
return AuthenticationResponse(self.request)
@method_decorator(rate_limit(action="signup"), name="handle")
class SignupView(APIView):
input_class = {"POST": SignupInput}
by_passkey = False
def post(self, request, *args, **kwargs):
if request.user.is_authenticated:
return ConflictResponse(request)
if not get_account_adapter().is_open_for_signup(request):
return ForbiddenResponse(request)
user, resp = self.input.try_save(request)
if not resp:
try:
flows.signup.complete_signup(
request, user=user, by_passkey=self.by_passkey
)
except ImmediateHttpResponse:
pass
return AuthenticationResponse(request)
class SessionView(APIView):
def get(self, request, *args, **kwargs):
return AuthenticationResponse(request)
def delete(self, request, *args, **kwargs):
adapter = get_account_adapter()
adapter.logout(request)
return AuthenticationResponse(request)
class VerifyEmailView(APIView):
input_class = VerifyEmailInput
def handle(self, request, *args, **kwargs):
self.stage = LoginStageController.enter(request, EmailVerificationStage.key)
return super().handle(request, *args, **kwargs)
def get(self, request, *args, **kwargs):
key = request.headers.get("x-email-verification-key", "")
input = self.input_class({"key": key})
if not input.is_valid():
return ErrorResponse(request, input=input)
verification = input.cleaned_data["key"]
return response.VerifyEmailResponse(request, verification, stage=self.stage)
def post(self, request, *args, **kwargs):
confirmation = self.input.cleaned_data["key"]
email_address = confirmation.confirm(request)
if not email_address:
# Should not happen, VerifyInputInput should have verified all
# preconditions.
return APIResponse(status=500)
if self.stage:
# Verifying email as part of login/signup flow, so emit a
# authentication status response.
self.stage.exit()
return AuthenticationResponse(self.request)
class RequestPasswordResetView(APIView):
input_class = RequestPasswordResetInput
def post(self, request, *args, **kwargs):
r429 = ratelimit.consume_or_429(
self.request,
action="reset_password",
key=self.input.cleaned_data["email"].lower(),
)
if r429:
return r429
self.input.save(request)
return response.RequestPasswordResponse(request)
@method_decorator(rate_limit(action="reset_password_from_key"), name="handle")
class ResetPasswordView(APIView):
input_class = ResetPasswordInput
def get(self, request, *args, **kwargs):
key = request.headers.get("X-Password-Reset-Key", "")
input = ResetPasswordKeyInput({"key": key})
if not input.is_valid():
return ErrorResponse(request, input=input)
return response.PasswordResetKeyResponse(request, input.user)
def post(self, request, *args, **kwargs):
flows.password_reset.reset_password(
self.input.user, self.input.cleaned_data["password"]
)
password_reset.finalize_password_reset(request, self.input.user)
return AuthenticationResponse(self.request)
@method_decorator(rate_limit(action="change_password"), name="handle")
class ChangePasswordView(AuthenticatedAPIView):
input_class = ChangePasswordInput
def post(self, request, *args, **kwargs):
password_change.change_password(
self.request.user, self.input.cleaned_data["new_password"]
)
is_set = not self.input.cleaned_data.get("current_password")
if is_set:
password_change.finalize_password_set(request, request.user)
else:
password_change.finalize_password_change(request, request.user)
return AuthenticationResponse(request)
def get_input_kwargs(self):
return {"user": self.request.user}
@method_decorator(rate_limit(action="manage_email"), name="handle")
class ManageEmailView(AuthenticatedAPIView):
input_class = {
"POST": AddEmailInput,
"PUT": SelectEmailInput,
"DELETE": DeleteEmailInput,
"PATCH": MarkAsPrimaryEmailInput,
}
def get(self, request, *args, **kwargs):
return self._respond_email_list()
def _respond_email_list(self):
addrs = EmailAddress.objects.filter(user=self.request.user)
return response.EmailAddressesResponse(self.request, addrs)
def post(self, request, *args, **kwargs):
flows.manage_email.add_email(request, self.input)
return self._respond_email_list()
def delete(self, request, *args, **kwargs):
addr = self.input.cleaned_data["email"]
flows.manage_email.delete_email(request, addr)
return self._respond_email_list()
def patch(self, request, *args, **kwargs):
addr = self.input.cleaned_data["email"]
flows.manage_email.mark_as_primary(request, addr)
return self._respond_email_list()
def put(self, request, *args, **kwargs):
addr = self.input.cleaned_data["email"]
sent = send_email_confirmation(request, request.user, email=addr.email)
return response.RequestEmailVerificationResponse(
request, verification_sent=sent
)
def get_input_kwargs(self):
return {"user": self.request.user}
@method_decorator(rate_limit(action="reauthenticate"), name="handle")
class ReauthenticateView(AuthenticatedAPIView):
input_class = ReauthenticateInput
def post(self, request, *args, **kwargs):
flows.reauthentication.reauthenticate_by_password(self.request)
return AuthenticationResponse(self.request)
def get_input_kwargs(self):
return {"user": self.request.user}

View File

@@ -0,0 +1,54 @@
from typing import Any, Dict
from django.forms.fields import Field
from allauth.account.models import EmailAddress
from allauth.account.utils import user_display, user_username
from allauth.core.internal.adapter import BaseAdapter
from allauth.headless import app_settings
from allauth.utils import import_attribute
class DefaultHeadlessAdapter(BaseAdapter):
"""The adapter class allows you to override various functionality of the
``allauth.headless`` app. To do so, point ``settings.HEADLESS_ADAPTER`` to your own
class that derives from ``DefaultHeadlessAdapter`` and override the behavior by
altering the implementation of the methods according to your own need.
"""
error_messages = {
# For the following error messages i18n is not an issue as these should not be
# showing up in a UI.
"account_not_found": "Unknown account.",
"client_id_required": "`client_id` required.",
"email_or_username": "Pass only one of email or username, not both.",
"invalid_token": "Invalid token.",
"token_authentication_not_supported": "Provider does not support token authentication.",
"token_required": "`id_token` and/or `access_token` required.",
"required": Field.default_error_messages["required"],
"unknown_email": "Unknown email address.",
"invalid_url": "Invalid URL.",
}
def serialize_user(self, user) -> Dict[str, Any]:
"""
Returns the basic user data. Note that this data is also exposed in
partly authenticated scenario's (e.g. password reset, email
verification).
"""
ret = {
"id": user.pk,
"display": user_display(user),
"has_usable_password": user.has_usable_password(),
}
email = EmailAddress.objects.get_primary_email(user)
if email:
ret["email"] = email
username = user_username(user)
if username:
ret["username"] = username
return ret
def get_adapter():
return import_attribute(app_settings.ADAPTER)()

View File

@@ -0,0 +1,36 @@
class AppSettings:
def __init__(self, prefix):
self.prefix = prefix
def _setting(self, name, dflt):
from allauth.utils import get_setting
return get_setting(self.prefix + name, dflt)
@property
def ADAPTER(self):
return self._setting(
"ADAPTER", "allauth.headless.adapter.DefaultHeadlessAdapter"
)
@property
def TOKEN_STRATEGY(self):
from allauth.utils import import_attribute
path = self._setting(
"TOKEN_STRATEGY", "allauth.headless.tokens.sessions.SessionTokenStrategy"
)
cls = import_attribute(path)
return cls()
@property
def FRONTEND_URLS(self):
return self._setting("FRONTEND_URLS", {})
_app_settings = AppSettings("HEADLESS_")
def __getattr__(name):
# See https://peps.python.org/pep-0562/
return getattr(_app_settings, name)

View File

@@ -0,0 +1,7 @@
from django.apps import AppConfig
from django.utils.translation import gettext_lazy as _
class HeadlessConfig(AppConfig):
name = "allauth.headless"
verbose_name = _("Headless")

View File

@@ -0,0 +1,155 @@
from allauth import app_settings as allauth_settings
from allauth.account import app_settings as account_settings
from allauth.account.adapter import get_adapter as get_account_adapter
from allauth.account.authentication import get_authentication_records
from allauth.account.internal import flows
from allauth.account.internal.stagekit import LOGIN_SESSION_KEY
from allauth.headless.adapter import get_adapter
from allauth.headless.constants import Flow
from allauth.headless.internal import authkit
from allauth.headless.internal.restkit.response import APIResponse
from allauth.mfa import app_settings as mfa_settings
class BaseAuthenticationResponse(APIResponse):
def __init__(self, request, user=None, status=None):
data = {}
if user and user.is_authenticated:
adapter = get_adapter()
data["user"] = adapter.serialize_user(user)
data["methods"] = get_authentication_records(request)
status = status or 200
else:
status = status or 401
if status != 200:
data["flows"] = self._get_flows(request, user)
meta = {
"is_authenticated": user and user.is_authenticated,
}
super().__init__(
request,
data=data,
meta=meta,
status=status,
)
def _get_flows(self, request, user):
auth_status = authkit.AuthenticationStatus(request)
ret = []
if user and user.is_authenticated:
ret.extend(flows.reauthentication.get_reauthentication_flows(user))
else:
if not allauth_settings.SOCIALACCOUNT_ONLY:
ret.append({"id": Flow.LOGIN})
if account_settings.LOGIN_BY_CODE_ENABLED:
ret.append({"id": Flow.LOGIN_BY_CODE})
if (
get_account_adapter().is_open_for_signup(request)
and not allauth_settings.SOCIALACCOUNT_ONLY
):
ret.append({"id": Flow.SIGNUP})
if allauth_settings.SOCIALACCOUNT_ENABLED:
from allauth.headless.socialaccount.response import (
provider_flows,
)
ret.extend(provider_flows(request))
if allauth_settings.MFA_ENABLED:
if mfa_settings.PASSKEY_LOGIN_ENABLED:
ret.append({"id": Flow.MFA_LOGIN_WEBAUTHN})
stage_key = None
stage = auth_status.get_pending_stage()
if stage:
stage_key = stage.key
else:
lsk = request.session.get(LOGIN_SESSION_KEY)
if isinstance(lsk, str):
stage_key = lsk
if stage_key:
pending_flow = {"id": stage_key, "is_pending": True}
if stage and stage_key == Flow.MFA_AUTHENTICATE:
self._enrich_mfa_flow(stage, pending_flow)
self._upsert_pending_flow(ret, pending_flow)
return ret
def _upsert_pending_flow(self, flows, pending_flow):
flow = next((flow for flow in flows if flow["id"] == pending_flow["id"]), None)
if flow:
flow.update(pending_flow)
else:
flows.append(pending_flow)
def _enrich_mfa_flow(self, stage, flow: dict) -> None:
from allauth.mfa.adapter import get_adapter as get_mfa_adapter
from allauth.mfa.models import Authenticator
adapter = get_mfa_adapter()
types = []
for typ in Authenticator.Type:
if adapter.is_mfa_enabled(stage.login.user, types=[typ]):
types.append(typ)
flow["types"] = types
class AuthenticationResponse(BaseAuthenticationResponse):
def __init__(self, request):
super().__init__(request, user=request.user)
class ReauthenticationResponse(BaseAuthenticationResponse):
def __init__(self, request):
super().__init__(request, user=request.user, status=401)
class UnauthorizedResponse(BaseAuthenticationResponse):
def __init__(self, request, status=401):
super().__init__(request, user=None, status=status)
class ForbiddenResponse(APIResponse):
def __init__(self, request):
super().__init__(request, status=403)
class ConflictResponse(APIResponse):
def __init__(self, request):
super().__init__(request, status=409)
def get_config_data(request):
data = {
"authentication_method": account_settings.AUTHENTICATION_METHOD,
"is_open_for_signup": get_account_adapter().is_open_for_signup(request),
"email_verification_by_code_enabled": account_settings.EMAIL_VERIFICATION_BY_CODE_ENABLED,
"login_by_code_enabled": account_settings.LOGIN_BY_CODE_ENABLED,
}
return {"account": data}
class ConfigResponse(APIResponse):
def __init__(self, request):
data = get_config_data(request)
if allauth_settings.SOCIALACCOUNT_ENABLED:
from allauth.headless.socialaccount.response import (
get_config_data as get_socialaccount_config_data,
)
data.update(get_socialaccount_config_data(request))
if allauth_settings.MFA_ENABLED:
from allauth.headless.mfa.response import (
get_config_data as get_mfa_config_data,
)
data.update(get_mfa_config_data(request))
if allauth_settings.USERSESSIONS_ENABLED:
from allauth.headless.usersessions.response import (
get_config_data as get_usersessions_config_data,
)
data.update(get_usersessions_config_data(request))
return super().__init__(request, data=data)
class RateLimitResponse(APIResponse):
def __init__(self, request):
super().__init__(request, status=429)

View File

@@ -0,0 +1,10 @@
def test_config(db, client, headless_reverse):
resp = client.get(headless_reverse("headless:config"))
assert resp.status_code == 200
data = resp.json()
assert set(data["data"].keys()) == {
"account",
"mfa",
"socialaccount",
"usersessions",
}

View File

@@ -0,0 +1,13 @@
from django.urls import path
from allauth.headless.base import views
def build_urlpatterns(client):
return [
path(
"config",
views.ConfigView.as_api_view(client=client),
name="config",
),
]

View File

@@ -0,0 +1,62 @@
from typing import Optional, Type
from django.utils.decorators import classonlymethod
from allauth.account.stages import LoginStage, LoginStageController
from allauth.core.exceptions import ReauthenticationRequired
from allauth.headless.base import response
from allauth.headless.constants import Client
from allauth.headless.internal import decorators
from allauth.headless.internal.restkit.views import RESTView
class APIView(RESTView):
client = None
@classonlymethod
def as_api_view(cls, **initkwargs):
view_func = cls.as_view(**initkwargs)
if initkwargs["client"] == Client.APP:
view_func = decorators.app_view(view_func)
else:
view_func = decorators.browser_view(view_func)
return view_func
def dispatch(self, request, *args, **kwargs):
try:
return super().dispatch(request, *args, **kwargs)
except ReauthenticationRequired:
return response.ReauthenticationResponse(self.request)
class AuthenticationStageAPIView(APIView):
stage_class: Optional[Type[LoginStage]] = None
def handle(self, request, *args, **kwargs):
self.stage = LoginStageController.enter(request, self.stage_class.key)
if not self.stage:
return response.UnauthorizedResponse(request)
return super().handle(request, *args, **kwargs)
def respond_stage_error(self):
return response.UnauthorizedResponse(self.request)
def respond_next_stage(self):
self.stage.exit()
return response.AuthenticationResponse(self.request)
class AuthenticatedAPIView(APIView):
def dispatch(self, request, *args, **kwargs):
if not request.user.is_authenticated:
return response.AuthenticationResponse(request)
return super().dispatch(request, *args, **kwargs)
class ConfigView(APIView):
def get(self, request, *args, **kwargs):
"""
The frontend queries (GET) this endpoint, expecting to receive
either a 401 if no user is authenticated, or user information.
"""
return response.ConfigResponse(request)

View File

@@ -0,0 +1,63 @@
from django.test.client import Client
from django.urls import reverse
import pytest
@pytest.fixture(params=["app", "browser"])
def headless_client(request):
return request.param
@pytest.fixture
def headless_reverse(headless_client):
def rev(viewname, **kwargs):
viewname = viewname.replace("headless:", f"headless:{headless_client}:")
return reverse(viewname, **kwargs)
return rev
class AppClient(Client):
session_token = None
def generic(self, *args, **kwargs):
if self.session_token:
kwargs["HTTP_X_SESSION_TOKEN"] = self.session_token
resp = super().generic(*args, **kwargs)
if resp["content-type"] == "application/json":
data = resp.json()
session_token = data.get("meta", {}).get("session_token")
if session_token:
self.session_token = session_token
return resp
def force_login(self, user):
ret = super().force_login(user)
self.session_token = self.session.session_key
return ret
def headless_session(self):
from allauth.headless.internal import sessionkit
return sessionkit.session_store(self.session_token)
@pytest.fixture
def app_client():
return AppClient()
@pytest.fixture
def client(headless_client):
if headless_client == "browser":
client = Client()
client.headless_session = lambda: client.session
return client
return AppClient()
@pytest.fixture
def auth_client(client, user):
client.force_login(user)
return client

View File

@@ -0,0 +1,23 @@
from enum import Enum
from allauth.account.stages import EmailVerificationStage, LoginByCodeStage
class Client(str, Enum):
APP = "app"
BROWSER = "browser"
class Flow(str, Enum):
VERIFY_EMAIL = EmailVerificationStage.key
LOGIN = "login"
LOGIN_BY_CODE = LoginByCodeStage.key
SIGNUP = "signup"
PROVIDER_REDIRECT = "provider_redirect"
PROVIDER_SIGNUP = "provider_signup"
PROVIDER_TOKEN = "provider_token"
REAUTHENTICATE = "reauthenticate"
MFA_REAUTHENTICATE = "mfa_reauthenticate"
MFA_AUTHENTICATE = "mfa_authenticate" # NOTE: Equal to `allauth.mfa.stages.AuthenticationStage.key`
MFA_LOGIN_WEBAUTHN = "mfa_login_webauthn"
MFA_SIGNUP_WEBAUTHN = "mfa_signup_webauthn"

View File

@@ -0,0 +1,85 @@
from contextlib import contextmanager
from typing import Any, Dict, Optional
from django.utils.functional import SimpleLazyObject, empty
from allauth import app_settings as allauth_settings
from allauth.account.internal.stagekit import get_pending_stage
from allauth.core.exceptions import ImmediateHttpResponse
from allauth.headless import app_settings
from allauth.headless.constants import Client
from allauth.headless.internal import sessionkit
class AuthenticationStatus:
def __init__(self, request):
self.request = request
@property
def is_authenticated(self):
return self.request.user.is_authenticated
def get_pending_stage(self):
return get_pending_stage(self.request)
@property
def has_pending_signup(self):
if not allauth_settings.SOCIALACCOUNT_ENABLED:
return False
from allauth.socialaccount.internal import flows
return bool(flows.signup.get_pending_signup(self.request))
def purge_request_user_cache(request):
for attr in ["_cached_user", "_acached_user"]:
if hasattr(request, attr):
delattr(request, attr)
if isinstance(request.user, SimpleLazyObject):
request.user._wrapped = empty
@contextmanager
def authentication_context(request):
from allauth.headless.base.response import UnauthorizedResponse
old_user = request.user
old_session = request.session
try:
request.session = sessionkit.new_session()
purge_request_user_cache(request)
strategy = app_settings.TOKEN_STRATEGY
session_token = strategy.get_session_token(request)
if session_token:
session = strategy.lookup_session(session_token)
if not session:
raise ImmediateHttpResponse(UnauthorizedResponse(request, status=410))
request.session = session
purge_request_user_cache(request)
request.allauth.headless._pre_user = request.user
# request.user is lazy -- force evaluation
request.allauth.headless._pre_user.pk
yield
finally:
if request.session.modified and not request.session.is_empty():
request.session.save()
request.user = old_user
request.session = old_session
# e.g. logging in calls csrf `rotate_token()` -- this prevents setting a new CSRF cookie.
request.META["CSRF_COOKIE_NEEDS_UPDATE"] = False
def expose_access_token(request) -> Optional[Dict[str, Any]]:
"""
Determines if a new access token needs to be exposed.
"""
if request.allauth.headless.client != Client.APP:
return None
if not request.user.is_authenticated:
return None
pre_user = request.allauth.headless._pre_user
if pre_user.is_authenticated and pre_user.pk == request.user.pk:
return None
strategy = app_settings.TOKEN_STRATEGY
return strategy.create_access_token_payload(request)

View File

@@ -0,0 +1,54 @@
from functools import wraps
from types import SimpleNamespace
from django.middleware.csrf import get_token
from django.views.decorators.csrf import csrf_exempt
from allauth.account.internal.decorators import login_not_required
from allauth.headless.constants import Client
from allauth.headless.internal import authkit
def mark_request_as_headless(request, client):
request.allauth.headless = SimpleNamespace()
request.allauth.headless.client = client
def app_view(
function=None,
):
def decorator(view_func):
@login_not_required
@wraps(view_func)
def _wrapper_view(request, *args, **kwargs):
mark_request_as_headless(request, Client.APP)
with authkit.authentication_context(request):
return view_func(request, *args, **kwargs)
return _wrapper_view
ret = decorator
if function:
ret = decorator(function)
return csrf_exempt(ret)
def browser_view(
function=None,
):
def decorator(view_func):
@login_not_required
@wraps(view_func)
def _wrapper_view(request, *args, **kwargs):
mark_request_as_headless(request, Client.BROWSER)
# Needed -- so that the CSRF token is set in the response for the
# frontend to pick up.
get_token(request)
return view_func(request, *args, **kwargs)
return _wrapper_view
ret = decorator
if function:
ret = decorator(function)
return ret

View File

@@ -0,0 +1,25 @@
from django.forms import (
BooleanField,
CharField,
ChoiceField,
EmailField,
Field,
Form,
ModelChoiceField,
ModelMultipleChoiceField,
)
__all__ = [
"Field",
"CharField",
"ChoiceField",
"EmailField",
"BooleanField",
"ModelMultipleChoiceField",
"ModelChoiceField",
]
class Input(Form):
pass

View File

@@ -0,0 +1,55 @@
from typing import Any, Dict, Optional
from django.forms.utils import ErrorList
from django.http import JsonResponse
from django.utils.cache import add_never_cache_headers
from allauth.headless.internal import authkit, sessionkit
class APIResponse(JsonResponse):
def __init__(
self,
request,
errors=None,
data=None,
meta: Optional[Dict] = None,
status: int = 200,
):
d: Dict[str, Any] = {"status": status}
if data is not None:
d["data"] = data
meta = self._add_session_meta(request, meta)
if meta is not None:
d["meta"] = meta
if errors:
d["errors"] = errors
super().__init__(d, status=status)
add_never_cache_headers(self)
def _add_session_meta(self, request, meta: Optional[Dict]) -> Optional[Dict]:
session_token = sessionkit.expose_session_token(request)
access_token_payload = authkit.expose_access_token(request)
if session_token:
meta = meta or {}
meta["session_token"] = session_token
if access_token_payload:
meta = meta or {}
meta.update(access_token_payload)
return meta
class ErrorResponse(APIResponse):
def __init__(self, request, exception=None, input=None, status=400):
errors = []
if exception is not None:
error_datas = ErrorList(exception.error_list).get_json_data()
errors.extend(error_datas)
if input is not None:
for field, error_list in input.errors.items():
error_datas = error_list.get_json_data()
for error_data in error_datas:
if field != "__all__":
error_data["param"] = field
errors.extend(error_datas)
super().__init__(request, status=status, errors=errors)

View File

@@ -0,0 +1,53 @@
import json
from typing import Dict, Optional, Type, Union
from django.http import HttpResponseBadRequest
from django.views.generic import View
from allauth.core.exceptions import ImmediateHttpResponse
from allauth.headless.internal.restkit.inputs import Input
from allauth.headless.internal.restkit.response import ErrorResponse
class RESTView(View):
input_class: Union[Optional[Dict[str, Type[Input]]], Type[Input]] = None
handle_json_input = True
def dispatch(self, request, *args, **kwargs):
return self.handle(request, *args, **kwargs)
def handle(self, request, *args, **kwargs):
if self.handle_json_input and request.method != "GET":
self.data = self._parse_json(request)
response = self.handle_input(self.data)
if response:
return response
return super().dispatch(request, *args, **kwargs)
def get_input_kwargs(self):
return {}
def handle_input(self, data):
input_class = self.input_class
if isinstance(input_class, dict):
input_class = input_class.get(self.request.method)
if not input_class:
return
input_kwargs = self.get_input_kwargs()
if data is None:
# Make form bound on empty POST
data = {}
self.input = input_class(data=data, **input_kwargs)
if not self.input.is_valid():
return self.handle_invalid_input(self.input)
def handle_invalid_input(self, input):
return ErrorResponse(self.request, input=input)
def _parse_json(self, request):
if request.method == "GET" or not request.body:
return
try:
return json.loads(request.body.decode("utf8"))
except (UnicodeDecodeError, json.JSONDecodeError):
raise ImmediateHttpResponse(response=HttpResponseBadRequest())

View File

@@ -0,0 +1,28 @@
from importlib import import_module
from django.conf import settings
from allauth.headless import app_settings
from allauth.headless.constants import Client
def session_store(session_key=None):
engine = import_module(settings.SESSION_ENGINE)
return engine.SessionStore(session_key=session_key)
def new_session():
return session_store()
def expose_session_token(request):
if request.allauth.headless.client != Client.APP:
return
strategy = app_settings.TOKEN_STRATEGY
hdr_token = strategy.get_session_token(request)
modified = request.session.modified
empty = request.session.is_empty()
if modified and not empty:
new_token = strategy.create_session_token(request)
if not hdr_token or hdr_token != new_token:
return new_token

View File

@@ -0,0 +1,28 @@
from django.contrib.auth import (
BACKEND_SESSION_KEY,
HASH_SESSION_KEY,
SESSION_KEY,
)
from django.contrib.auth.middleware import AuthenticationMiddleware
from django.contrib.sessions.middleware import SessionMiddleware
from django.http import HttpResponse
from allauth.headless.internal.authkit import purge_request_user_cache
def test_purge_request_user_cache(rf, user):
request = rf.get("/")
smw = SessionMiddleware(lambda request: HttpResponse())
smw(request)
amw = AuthenticationMiddleware(lambda request: HttpResponse())
amw(request)
assert request.user.is_anonymous
assert not request.user.pk
purge_request_user_cache(request)
request.session[SESSION_KEY] = user.pk
request.session[BACKEND_SESSION_KEY] = (
"allauth.account.auth_backends.AuthenticationBackend"
)
request.session[HASH_SESSION_KEY] = user.get_session_auth_hash()
assert request.user.is_authenticated
assert request.user.pk == user.pk

View File

@@ -0,0 +1,74 @@
from allauth.account.forms import BaseSignupForm
from allauth.headless.internal.restkit import inputs
from allauth.mfa.base.forms import AuthenticateForm
from allauth.mfa.models import Authenticator
from allauth.mfa.recovery_codes.forms import GenerateRecoveryCodesForm
from allauth.mfa.totp.forms import ActivateTOTPForm
from allauth.mfa.webauthn.forms import (
AddWebAuthnForm,
AuthenticateWebAuthnForm,
LoginWebAuthnForm,
ReauthenticateWebAuthnForm,
SignupWebAuthnForm,
)
class AuthenticateInput(AuthenticateForm, inputs.Input):
pass
class ActivateTOTPInput(ActivateTOTPForm, inputs.Input):
pass
class GenerateRecoveryCodesInput(GenerateRecoveryCodesForm, inputs.Input):
pass
class AddWebAuthnInput(AddWebAuthnForm, inputs.Input):
pass
class CreateWebAuthnInput(SignupWebAuthnForm, inputs.Input):
pass
class UpdateWebAuthnInput(inputs.Input):
id = inputs.ModelChoiceField(queryset=Authenticator.objects.none())
name = inputs.CharField(required=True, max_length=100)
def __init__(self, *args, **kwargs):
self.user = kwargs.pop("user")
super().__init__(*args, **kwargs)
self.fields["id"].queryset = Authenticator.objects.filter(
user=self.user, type=Authenticator.Type.WEBAUTHN
)
class DeleteWebAuthnInput(inputs.Input):
authenticators = inputs.ModelMultipleChoiceField(
queryset=Authenticator.objects.none()
)
def __init__(self, *args, **kwargs):
self.user = kwargs.pop("user")
super().__init__(*args, **kwargs)
self.fields["authenticators"].queryset = Authenticator.objects.filter(
user=self.user, type=Authenticator.Type.WEBAUTHN
)
class ReauthenticateWebAuthnInput(ReauthenticateWebAuthnForm, inputs.Input):
pass
class AuthenticateWebAuthnInput(AuthenticateWebAuthnForm, inputs.Input):
pass
class LoginWebAuthnInput(LoginWebAuthnForm, inputs.Input):
pass
class SignupWebAuthnInput(BaseSignupForm, inputs.Input):
pass

View File

@@ -0,0 +1,104 @@
from allauth.headless.base.response import APIResponse
from allauth.mfa import app_settings as mfa_settings
def get_config_data(request):
data = {
"mfa": {
"supported_types": mfa_settings.SUPPORTED_TYPES,
"passkey_login_enabled": mfa_settings.PASSKEY_LOGIN_ENABLED,
}
}
return data
def _authenticator_data(authenticator, sensitive=False):
data = {
"type": authenticator.type,
"created_at": authenticator.created_at.timestamp(),
"last_used_at": (
authenticator.last_used_at.timestamp()
if authenticator.last_used_at
else None
),
}
if authenticator.type == authenticator.Type.TOTP:
pass
elif authenticator.type == authenticator.Type.RECOVERY_CODES:
wrapped = authenticator.wrap()
unused_codes = wrapped.get_unused_codes()
data.update(
{
"total_code_count": len(wrapped.generate_codes()),
"unused_code_count": len(unused_codes),
}
)
if sensitive:
data["unused_codes"] = unused_codes
elif authenticator.type == authenticator.Type.WEBAUTHN:
wrapped = authenticator.wrap()
data["id"] = authenticator.pk
data["name"] = wrapped.name
passwordless = wrapped.is_passwordless
if passwordless is not None:
data["is_passwordless"] = passwordless
return data
class AuthenticatorDeletedResponse(APIResponse):
pass
class AuthenticatorsDeletedResponse(APIResponse):
pass
class TOTPNotFoundResponse(APIResponse):
def __init__(self, request, secret, totp_url):
super().__init__(
request,
meta={
"secret": secret,
"totp_url": totp_url,
},
status=404,
)
class TOTPResponse(APIResponse):
def __init__(self, request, authenticator):
data = _authenticator_data(authenticator)
super().__init__(request, data=data)
class AuthenticatorsResponse(APIResponse):
def __init__(self, request, authenticators):
data = [_authenticator_data(authenticator) for authenticator in authenticators]
super().__init__(request, data=data)
class AuthenticatorResponse(APIResponse):
def __init__(self, request, authenticator, meta=None):
data = _authenticator_data(authenticator)
super().__init__(request, data=data, meta=meta)
class RecoveryCodesNotFoundResponse(APIResponse):
def __init__(self, request):
super().__init__(request, status=404)
class RecoveryCodesResponse(APIResponse):
def __init__(self, request, authenticator):
data = _authenticator_data(authenticator, sensitive=True)
super().__init__(request, data=data)
class AddWebAuthnResponse(APIResponse):
def __init__(self, request, registration_data):
super().__init__(request, data={"creation_options": registration_data})
class WebAuthnRequestOptionsResponse(APIResponse):
def __init__(self, request, request_options):
super().__init__(request, data={"request_options": request_options})

View File

@@ -0,0 +1,60 @@
from allauth.mfa.models import Authenticator
def test_get_recovery_codes_requires_reauth(
auth_client, user_with_recovery_codes, headless_reverse
):
rc = Authenticator.objects.get(
type=Authenticator.Type.RECOVERY_CODES, user=user_with_recovery_codes
)
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
assert resp.status_code == 401
data = resp.json()
assert data["meta"]["is_authenticated"]
resp = auth_client.post(
headless_reverse("headless:mfa:reauthenticate"),
data={"code": rc.wrap().get_unused_codes()[0]},
content_type="application/json",
)
assert resp.status_code == 200
def test_get_recovery_codes(
auth_client,
user_with_recovery_codes,
headless_reverse,
reauthentication_bypass,
):
with reauthentication_bypass():
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
assert resp.status_code == 200
data = resp.json()
assert data["data"]["type"] == "recovery_codes"
assert len(data["data"]["unused_codes"]) == 10
with reauthentication_bypass():
resp = auth_client.get(headless_reverse("headless:mfa:authenticators"))
data = resp.json()
assert len(data["data"]) == 2
rc = [autor for autor in data["data"] if autor["type"] == "recovery_codes"][0]
assert "unused_codes" not in rc
def test_generate_recovery_codes(
auth_client,
user_with_totp,
headless_reverse,
reauthentication_bypass,
):
with reauthentication_bypass():
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
assert resp.status_code == 404
with reauthentication_bypass():
resp = auth_client.post(
headless_reverse("headless:mfa:manage_recovery_codes"),
content_type="application/json",
)
assert resp.status_code == 200
data = resp.json()
assert data["data"]["type"] == "recovery_codes"
assert len(data["data"]["unused_codes"]) == 10

View File

@@ -0,0 +1,78 @@
import pytest
from allauth.mfa.models import Authenticator
@pytest.mark.parametrize("email_verified", [False, True])
def test_get_totp_not_active(auth_client, user, headless_reverse, email_verified):
resp = auth_client.get(headless_reverse("headless:mfa:manage_totp"))
if email_verified:
assert resp.status_code == 404
data = resp.json()
assert len(data["meta"]["secret"]) == 32
assert len(data["meta"]["totp_url"]) == 145
else:
assert resp.status_code == 409
assert resp.json() == {
"status": 409,
"errors": [
{
"message": "You cannot activate two-factor authentication until you have verified your email address.",
"code": "unverified_email",
}
],
}
def test_get_totp(
auth_client,
user_with_totp,
headless_reverse,
):
resp = auth_client.get(headless_reverse("headless:mfa:manage_totp"))
assert resp.status_code == 200
data = resp.json()
assert data["data"]["type"] == "totp"
assert isinstance(data["data"]["created_at"], float)
def test_deactivate_totp(
auth_client,
user_with_totp,
headless_reverse,
reauthentication_bypass,
):
with reauthentication_bypass():
resp = auth_client.delete(headless_reverse("headless:mfa:manage_totp"))
assert resp.status_code == 200
assert not Authenticator.objects.filter(user=user_with_totp).exists()
@pytest.mark.parametrize("email_verified", [False, True])
def test_activate_totp(
auth_client,
user,
headless_reverse,
reauthentication_bypass,
settings,
totp_validation_bypass,
email_verified,
):
with reauthentication_bypass():
with totp_validation_bypass():
resp = auth_client.post(
headless_reverse("headless:mfa:manage_totp"),
data={"code": "42"},
content_type="application/json",
)
if email_verified:
assert resp.status_code == 200
assert Authenticator.objects.filter(
user=user, type=Authenticator.Type.TOTP
).exists()
data = resp.json()
assert data["data"]["type"] == "totp"
assert isinstance(data["data"]["created_at"], float)
assert data["data"]["last_used_at"] is None
else:
assert resp.status_code == 400

View File

@@ -0,0 +1,115 @@
from allauth.account.models import EmailAddress, get_emailconfirmation_model
from allauth.headless.constants import Flow
def test_auth_unverified_email_and_mfa(
client,
user_factory,
password_factory,
settings,
totp_validation_bypass,
headless_reverse,
headless_client,
):
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
password = password_factory()
user = user_factory(email_verified=False, password=password, with_totp=True)
resp = client.post(
headless_reverse("headless:account:login"),
data={
"email": user.email,
"password": password,
},
content_type="application/json",
)
assert resp.status_code == 401
data = resp.json()
assert [f for f in data["data"]["flows"] if f["id"] == Flow.VERIFY_EMAIL][0][
"is_pending"
]
emailaddress = EmailAddress.objects.filter(user=user, verified=False).get()
key = get_emailconfirmation_model().create(emailaddress).key
resp = client.post(
headless_reverse("headless:account:verify_email"),
data={"key": key},
content_type="application/json",
)
assert resp.status_code == 401
flows = [
{"id": "login"},
{"id": "login_by_code"},
{"id": "signup"},
]
if headless_client == "browser":
flows.append(
{
"id": "provider_redirect",
"providers": ["dummy", "openid_connect", "openid_connect"],
}
)
flows.append({"id": "provider_token", "providers": ["dummy"]})
flows.append({"id": "mfa_login_webauthn"})
flows.append(
{
"id": "mfa_authenticate",
"is_pending": True,
"types": ["totp"],
}
)
assert resp.json() == {
"data": {"flows": flows},
"meta": {"is_authenticated": False},
"status": 401,
}
resp = client.post(
headless_reverse("headless:mfa:authenticate"),
data={"code": "bad"},
content_type="application/json",
)
assert resp.status_code == 400
assert resp.json() == {
"status": 400,
"errors": [
{"message": "Incorrect code.", "code": "incorrect_code", "param": "code"}
],
}
with totp_validation_bypass():
resp = client.post(
headless_reverse("headless:mfa:authenticate"),
data={"code": "bad"},
content_type="application/json",
)
assert resp.status_code == 200
def test_dangling_mfa_is_logged_out(
client,
user_with_totp,
password_factory,
settings,
totp_validation_bypass,
headless_reverse,
headless_client,
user_password,
):
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
resp = client.post(
headless_reverse("headless:account:login"),
data={
"email": user_with_totp.email,
"password": user_password,
},
content_type="application/json",
)
assert resp.status_code == 401
data = resp.json()
flow = [f for f in data["data"]["flows"] if f["id"] == Flow.MFA_AUTHENTICATE][0]
assert flow["is_pending"]
assert flow["types"] == ["totp"]
resp = client.delete(headless_reverse("headless:account:current_session"))
data = resp.json()
assert resp.status_code == 401
assert all(not f.get("is_pending") for f in data["data"]["flows"])

View File

@@ -0,0 +1,223 @@
from unittest.mock import ANY
from django.contrib.auth import get_user_model
import pytest
from allauth.account.authentication import AUTHENTICATION_METHODS_SESSION_KEY
from allauth.headless.constants import Flow
from allauth.mfa.models import Authenticator
def test_passkey_login(
client, passkey, webauthn_authentication_bypass, headless_reverse
):
with webauthn_authentication_bypass(passkey) as credential:
resp = client.get(headless_reverse("headless:mfa:login_webauthn"))
assert "request_options" in resp.json()["data"]
resp = client.post(
headless_reverse("headless:mfa:login_webauthn"),
data={"credential": credential},
content_type="application/json",
)
data = resp.json()
assert data["data"]["user"]["id"] == passkey.user_id
def test_passkey_login_get_options(client, headless_client, headless_reverse, db):
resp = client.get(headless_reverse("headless:mfa:login_webauthn"))
data = resp.json()
meta = {}
if headless_client == "app":
meta = {
"meta": {"session_token": ANY},
}
assert data == {
"status": 200,
"data": {"request_options": {"publicKey": ANY}},
**meta,
}
def test_reauthenticate(
auth_client,
passkey,
user_with_recovery_codes,
webauthn_authentication_bypass,
headless_reverse,
):
# View recovery codes, confirm webauthn reauthentication is an option
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
assert resp.status_code == 401
assert Flow.MFA_REAUTHENTICATE in [
flow["id"] for flow in resp.json()["data"]["flows"]
]
# Get request options
with webauthn_authentication_bypass(passkey):
resp = auth_client.get(headless_reverse("headless:mfa:reauthenticate_webauthn"))
data = resp.json()
assert data["status"] == 200
assert data["data"]["request_options"] == ANY
# Reauthenticate
with webauthn_authentication_bypass(passkey) as credential:
resp = auth_client.post(
headless_reverse("headless:mfa:reauthenticate_webauthn"),
data={"credential": credential},
content_type="application/json",
)
assert resp.status_code == 200
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
assert resp.status_code == 200
def test_update_authenticator(
auth_client, headless_reverse, passkey, reauthentication_bypass
):
data = {"id": passkey.pk, "name": "Renamed!"}
resp = auth_client.put(
headless_reverse("headless:mfa:manage_webauthn"),
data=data,
content_type="application/json",
)
# Reauthentication required
assert resp.status_code == 401
with reauthentication_bypass():
resp = auth_client.put(
headless_reverse("headless:mfa:manage_webauthn"),
data=data,
content_type="application/json",
)
assert resp.status_code == 200
passkey.refresh_from_db()
assert passkey.wrap().name == "Renamed!"
def test_delete_authenticator(
auth_client, headless_reverse, passkey, reauthentication_bypass
):
data = {"authenticators": [passkey.pk]}
resp = auth_client.delete(
headless_reverse("headless:mfa:manage_webauthn"),
data=data,
content_type="application/json",
)
# Reauthentication required
assert resp.status_code == 401
with reauthentication_bypass():
resp = auth_client.delete(
headless_reverse("headless:mfa:manage_webauthn"),
data=data,
content_type="application/json",
)
assert resp.status_code == 200
assert not Authenticator.objects.filter(pk=passkey.pk).exists()
@pytest.mark.parametrize("email_verified", [False, True])
def test_add_authenticator(
user,
auth_client,
headless_reverse,
webauthn_registration_bypass,
reauthentication_bypass,
email_verified,
):
resp = auth_client.get(headless_reverse("headless:mfa:manage_webauthn"))
# Reauthentication required
assert resp.status_code == 401 if email_verified else 409
with reauthentication_bypass():
resp = auth_client.get(headless_reverse("headless:mfa:manage_webauthn"))
if email_verified:
assert resp.status_code == 200
data = resp.json()
assert data["data"]["creation_options"] == ANY
else:
assert resp.status_code == 409
with webauthn_registration_bypass(user, False) as credential:
resp = auth_client.post(
headless_reverse("headless:mfa:manage_webauthn"),
data={"credential": credential},
content_type="application/json",
)
webauthn_count = Authenticator.objects.filter(
type=Authenticator.Type.WEBAUTHN, user=user
).count()
if email_verified:
assert resp.status_code == 200
assert webauthn_count == 1
else:
assert resp.status_code == 409
assert webauthn_count == 0
def test_2fa_login(
client,
user,
user_password,
passkey,
webauthn_authentication_bypass,
headless_reverse,
):
resp = client.post(
headless_reverse("headless:account:login"),
data={
"username": user.username,
"password": user_password,
},
content_type="application/json",
)
assert resp.status_code == 401
data = resp.json()
pending_flows = [f for f in data["data"]["flows"] if f.get("is_pending")]
assert len(pending_flows) == 1
pending_flow = pending_flows[0]
assert pending_flow == {
"id": "mfa_authenticate",
"is_pending": True,
"types": ["webauthn"],
}
with webauthn_authentication_bypass(passkey) as credential:
resp = client.get(headless_reverse("headless:mfa:authenticate_webauthn"))
assert "request_options" in resp.json()["data"]
resp = client.post(
headless_reverse("headless:mfa:authenticate_webauthn"),
data={"credential": credential},
content_type="application/json",
)
data = resp.json()
assert resp.status_code == 200
assert data["data"]["user"]["id"] == passkey.user_id
assert client.headless_session()[AUTHENTICATION_METHODS_SESSION_KEY] == [
{"method": "password", "at": ANY, "username": passkey.user.username},
{"method": "mfa", "at": ANY, "id": ANY, "type": Authenticator.Type.WEBAUTHN},
]
def test_passkey_signup(client, db, webauthn_registration_bypass, headless_reverse):
resp = client.post(
headless_reverse("headless:mfa:signup_webauthn"),
data={"email": "pass@key.org", "username": "passkey"},
content_type="application/json",
)
assert resp.status_code == 401
flow = [flow for flow in resp.json()["data"]["flows"] if flow.get("is_pending")][0]
assert flow["id"] == Flow.MFA_SIGNUP_WEBAUTHN.value
resp = client.get(headless_reverse("headless:mfa:signup_webauthn"))
data = resp.json()
assert "creation_options" in data["data"]
user = get_user_model().objects.get(email="pass@key.org")
with webauthn_registration_bypass(user, True) as credential:
resp = client.put(
headless_reverse("headless:mfa:signup_webauthn"),
data={"name": "Some key", "credential": credential},
content_type="application/json",
)
data = resp.json()
assert data["meta"]["is_authenticated"]
authenticator = Authenticator.objects.get(user=user)
assert authenticator.wrap().name == "Some key"

View File

@@ -0,0 +1,100 @@
from django.urls import include, path
from allauth.headless.mfa import views
from allauth.mfa import app_settings as mfa_settings
def build_urlpatterns(client):
auth_patterns = [
path(
"2fa/authenticate",
views.AuthenticateView.as_api_view(client=client),
name="authenticate",
),
path(
"2fa/reauthenticate",
views.ReauthenticateView.as_api_view(client=client),
name="reauthenticate",
),
]
authenticators = []
if "totp" in mfa_settings.SUPPORTED_TYPES:
authenticators.append(
path(
"totp",
views.ManageTOTPView.as_api_view(client=client),
name="manage_totp",
)
)
if "recovery_codes" in mfa_settings.SUPPORTED_TYPES:
authenticators.append(
path(
"recovery-codes",
views.ManageRecoveryCodesView.as_api_view(client=client),
name="manage_recovery_codes",
)
)
if "webauthn" in mfa_settings.SUPPORTED_TYPES:
authenticators.extend(
[
path(
"webauthn",
views.ManageWebAuthnView.as_api_view(client=client),
name="manage_webauthn",
),
]
)
auth_patterns.extend(
[
path(
"webauthn/authenticate",
views.AuthenticateWebAuthnView.as_api_view(client=client),
name="authenticate_webauthn",
),
path(
"webauthn/reauthenticate",
views.ReauthenticateWebAuthnView.as_api_view(client=client),
name="reauthenticate_webauthn",
),
]
)
if mfa_settings.PASSKEY_LOGIN_ENABLED:
auth_patterns.append(
path(
"webauthn/login",
views.LoginWebAuthnView.as_api_view(client=client),
name="login_webauthn",
)
)
if mfa_settings.PASSKEY_SIGNUP_ENABLED:
auth_patterns.append(
path(
"webauthn/signup",
views.SignupWebAuthnView.as_api_view(client=client),
name="signup_webauthn",
)
)
return [
path(
"auth/",
include(auth_patterns),
),
path(
"account/",
include(
[
path(
"authenticators",
views.AuthenticatorsView.as_api_view(client=client),
name="authenticators",
),
path(
"authenticators/",
include(authenticators),
),
]
),
),
]

View File

@@ -0,0 +1,280 @@
from django.core.exceptions import ValidationError
from allauth.account.internal.stagekit import get_pending_stage
from allauth.account.models import Login
from allauth.headless.account.views import SignupView
from allauth.headless.base.response import (
APIResponse,
AuthenticationResponse,
ConflictResponse,
)
from allauth.headless.base.views import (
APIView,
AuthenticatedAPIView,
AuthenticationStageAPIView,
)
from allauth.headless.internal.restkit.response import ErrorResponse
from allauth.headless.mfa import response
from allauth.headless.mfa.inputs import (
ActivateTOTPInput,
AddWebAuthnInput,
AuthenticateInput,
AuthenticateWebAuthnInput,
CreateWebAuthnInput,
DeleteWebAuthnInput,
GenerateRecoveryCodesInput,
LoginWebAuthnInput,
ReauthenticateWebAuthnInput,
SignupWebAuthnInput,
UpdateWebAuthnInput,
)
from allauth.mfa.adapter import DefaultMFAAdapter, get_adapter
from allauth.mfa.internal.flows import add
from allauth.mfa.models import Authenticator
from allauth.mfa.recovery_codes.internal import flows as recovery_codes_flows
from allauth.mfa.stages import AuthenticateStage
from allauth.mfa.totp.internal import auth as totp_auth, flows as totp_flows
from allauth.mfa.webauthn.internal import (
auth as webauthn_auth,
flows as webauthn_flows,
)
from allauth.mfa.webauthn.stages import PasskeySignupStage
def _validate_can_add_authenticator(request):
try:
add.validate_can_add_authenticator(request.user)
except ValidationError as e:
return ErrorResponse(request, status=409, exception=e)
class AuthenticateView(AuthenticationStageAPIView):
input_class = AuthenticateInput
stage_class = AuthenticateStage
def post(self, request, *args, **kwargs):
self.input.save()
return self.respond_next_stage()
def get_input_kwargs(self):
return {"user": self.stage.login.user}
class ReauthenticateView(AuthenticatedAPIView):
input_class = AuthenticateInput
def post(self, request, *args, **kwargs):
self.input.save()
return AuthenticationResponse(self.request)
def get_input_kwargs(self):
return {"user": self.request.user}
class AuthenticatorsView(AuthenticatedAPIView):
def get(self, request, *args, **kwargs):
authenticators = Authenticator.objects.filter(user=request.user)
return response.AuthenticatorsResponse(request, authenticators)
class ManageTOTPView(AuthenticatedAPIView):
input_class = {"POST": ActivateTOTPInput}
def get(self, request, *args, **kwargs) -> APIResponse:
authenticator = self._get_authenticator()
if not authenticator:
err = _validate_can_add_authenticator(request)
if err:
return err
adapter: DefaultMFAAdapter = get_adapter()
secret = totp_auth.get_totp_secret(regenerate=True)
totp_url: str = adapter.build_totp_url(request.user, secret)
return response.TOTPNotFoundResponse(request, secret, totp_url)
return response.TOTPResponse(request, authenticator)
def _get_authenticator(self):
return Authenticator.objects.filter(
type=Authenticator.Type.TOTP, user=self.request.user
).first()
def get_input_kwargs(self):
return {"user": self.request.user}
def post(self, request, *args, **kwargs):
authenticator = totp_flows.activate_totp(request, self.input)[0]
return response.TOTPResponse(request, authenticator)
def delete(self, request, *args, **kwargs):
authenticator = self._get_authenticator()
if authenticator:
authenticator = totp_flows.deactivate_totp(request, authenticator)
return response.AuthenticatorDeletedResponse(request)
class ManageRecoveryCodesView(AuthenticatedAPIView):
input_class = GenerateRecoveryCodesInput
def get(self, request, *args, **kwargs):
authenticator = recovery_codes_flows.view_recovery_codes(request)
if not authenticator:
return response.RecoveryCodesNotFoundResponse(request)
return response.RecoveryCodesResponse(request, authenticator)
def post(self, request, *args, **kwargs):
authenticator = recovery_codes_flows.generate_recovery_codes(request)
return response.RecoveryCodesResponse(request, authenticator)
def get_input_kwargs(self):
return {"user": self.request.user}
class ManageWebAuthnView(AuthenticatedAPIView):
input_class = {
"POST": AddWebAuthnInput,
"PUT": UpdateWebAuthnInput,
"DELETE": DeleteWebAuthnInput,
}
def handle(self, request, *args, **kwargs):
if request.method in ["GET", "POST"]:
err = _validate_can_add_authenticator(request)
if err:
return err
return super().handle(request, *args, **kwargs)
def get(self, request, *args, **kwargs):
passwordless = "passwordless" in request.GET
creation_options = webauthn_flows.begin_registration(
request, request.user, passwordless
)
return response.AddWebAuthnResponse(request, creation_options)
def get_input_kwargs(self):
return {"user": self.request.user}
def post(self, request, *args, **kwargs):
auth, rc_auth = webauthn_flows.add_authenticator(
request,
name=self.input.cleaned_data["name"],
credential=self.input.cleaned_data["credential"],
)
did_generate_recovery_codes = bool(rc_auth)
return response.AuthenticatorResponse(
request,
auth,
meta={"recovery_codes_generated": did_generate_recovery_codes},
)
def put(self, request, *args, **kwargs):
authenticator = self.input.cleaned_data["id"]
webauthn_flows.rename_authenticator(
request, authenticator, self.input.cleaned_data["name"]
)
return response.AuthenticatorResponse(request, authenticator)
def delete(self, request, *args, **kwargs):
authenticators = self.input.cleaned_data["authenticators"]
webauthn_flows.remove_authenticators(request, authenticators)
return response.AuthenticatorsDeletedResponse(request)
class ReauthenticateWebAuthnView(AuthenticatedAPIView):
input_class = {
"POST": ReauthenticateWebAuthnInput,
}
def get(self, request, *args, **kwargs):
request_options = webauthn_auth.begin_authentication(request.user)
return response.WebAuthnRequestOptionsResponse(request, request_options)
def get_input_kwargs(self):
return {"user": self.request.user}
def post(self, request, *args, **kwargs):
authenticator = self.input.cleaned_data["credential"]
webauthn_flows.reauthenticate(request, authenticator)
return AuthenticationResponse(self.request)
class AuthenticateWebAuthnView(AuthenticationStageAPIView):
input_class = {
"POST": AuthenticateWebAuthnInput,
}
stage_class = AuthenticateStage
def get(self, request, *args, **kwargs):
request_options = webauthn_auth.begin_authentication(self.stage.login.user)
return response.WebAuthnRequestOptionsResponse(request, request_options)
def get_input_kwargs(self):
return {"user": self.stage.login.user}
def post(self, request, *args, **kwargs):
self.input.save()
return self.respond_next_stage()
class LoginWebAuthnView(APIView):
input_class = {
"POST": LoginWebAuthnInput,
}
def get(self, request, *args, **kwargs):
request_options = webauthn_auth.begin_authentication()
return response.WebAuthnRequestOptionsResponse(request, request_options)
def post(self, request, *args, **kwargs):
authenticator = self.input.cleaned_data["credential"]
redirect_url = None
login = Login(user=authenticator.user, redirect_url=redirect_url)
webauthn_flows.perform_passwordless_login(request, authenticator, login)
return AuthenticationResponse(request)
class SignupWebAuthnView(SignupView):
input_class = {
"POST": SignupWebAuthnInput,
"PUT": CreateWebAuthnInput,
}
by_passkey = True
def get(self, request, *args, **kwargs):
resp = self._require_stage()
if resp:
return resp
creation_options = webauthn_flows.begin_registration(
request, self.stage.login.user, passwordless=True, signup=True
)
return response.AddWebAuthnResponse(request, creation_options)
def _prep_stage(self):
if hasattr(self, "stage"):
return self.stage
self.stage = get_pending_stage(self.request)
return self.stage
def _require_stage(self):
self._prep_stage()
if not self.stage or self.stage.key != PasskeySignupStage.key:
return ConflictResponse(self.request)
return None
def get_input_kwargs(self):
ret = super().get_input_kwargs()
self._prep_stage()
if self.stage and self.request.method == "PUT":
ret["user"] = self.stage.login.user
return ret
def put(self, request, *args, **kwargs):
resp = self._require_stage()
if resp:
return resp
webauthn_flows.signup_authenticator(
request,
user=self.stage.login.user,
name=self.input.cleaned_data["name"],
credential=self.input.cleaned_data["credential"],
)
self.stage.exit()
return AuthenticationResponse(request)

Some files were not shown because too many files have changed in this diff Show More