mirror of
https://github.com/pacnpal/thrillwiki_django_no_react.git
synced 2025-12-22 15:11:09 -05:00
okay fine
This commit is contained in:
@@ -0,0 +1,218 @@
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.validators import validate_email
|
||||
|
||||
from allauth.account import app_settings as account_app_settings
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.forms import (
|
||||
AddEmailForm,
|
||||
BaseSignupForm,
|
||||
ConfirmLoginCodeForm,
|
||||
ReauthenticateForm,
|
||||
RequestLoginCodeForm,
|
||||
ResetPasswordForm,
|
||||
UserTokenForm,
|
||||
)
|
||||
from allauth.account.internal import flows
|
||||
from allauth.account.models import (
|
||||
EmailAddress,
|
||||
Login,
|
||||
get_emailconfirmation_model,
|
||||
)
|
||||
from allauth.core import context
|
||||
from allauth.headless.adapter import get_adapter
|
||||
from allauth.headless.internal.restkit import inputs
|
||||
|
||||
|
||||
class SignupInput(BaseSignupForm, inputs.Input):
|
||||
password = inputs.CharField()
|
||||
|
||||
def clean_password(self):
|
||||
password = self.cleaned_data["password"]
|
||||
return get_account_adapter().clean_password(password)
|
||||
|
||||
|
||||
class LoginInput(inputs.Input):
|
||||
username = inputs.CharField(required=False)
|
||||
email = inputs.EmailField(required=False)
|
||||
password = inputs.CharField()
|
||||
|
||||
def clean(self):
|
||||
cleaned_data = super().clean()
|
||||
username = None
|
||||
email = None
|
||||
|
||||
if (
|
||||
account_app_settings.AUTHENTICATION_METHOD
|
||||
== account_app_settings.AuthenticationMethod.USERNAME
|
||||
):
|
||||
username = cleaned_data.get("username")
|
||||
missing_field = "username"
|
||||
elif (
|
||||
account_app_settings.AUTHENTICATION_METHOD
|
||||
== account_app_settings.AuthenticationMethod.EMAIL
|
||||
):
|
||||
email = cleaned_data.get("email")
|
||||
missing_field = "email"
|
||||
elif (
|
||||
account_app_settings.AUTHENTICATION_METHOD
|
||||
== account_app_settings.AuthenticationMethod.USERNAME_EMAIL
|
||||
):
|
||||
username = cleaned_data.get("username")
|
||||
email = cleaned_data.get("email")
|
||||
missing_field = "email"
|
||||
if email and username:
|
||||
raise get_adapter().validation_error("email_or_username")
|
||||
else:
|
||||
raise ImproperlyConfigured(account_app_settings.AUTHENTICATION_METHOD)
|
||||
|
||||
if not email and not username:
|
||||
if not self.errors.get(missing_field):
|
||||
self.add_error(
|
||||
missing_field, get_adapter().validation_error("required")
|
||||
)
|
||||
|
||||
password = cleaned_data.get("password")
|
||||
if password and (username or email):
|
||||
credentials = {"password": password}
|
||||
if email:
|
||||
credentials["email"] = email
|
||||
auth_method = account_app_settings.AuthenticationMethod.EMAIL
|
||||
else:
|
||||
credentials["username"] = username
|
||||
auth_method = account_app_settings.AuthenticationMethod.USERNAME
|
||||
user = get_account_adapter().authenticate(context.request, **credentials)
|
||||
if user:
|
||||
self.login = Login(user=user, email=credentials.get("email"))
|
||||
if flows.login.is_login_rate_limited(context.request, self.login):
|
||||
raise get_account_adapter().validation_error(
|
||||
"too_many_login_attempts"
|
||||
)
|
||||
else:
|
||||
error_code = "%s_password_mismatch" % auth_method.value
|
||||
self.add_error(
|
||||
"password", get_account_adapter().validation_error(error_code)
|
||||
)
|
||||
return cleaned_data
|
||||
|
||||
|
||||
class VerifyEmailInput(inputs.Input):
|
||||
key = inputs.CharField()
|
||||
|
||||
def clean_key(self):
|
||||
key = self.cleaned_data["key"]
|
||||
model = get_emailconfirmation_model()
|
||||
confirmation = model.from_key(key)
|
||||
valid = confirmation and not confirmation.key_expired()
|
||||
if not valid:
|
||||
raise get_account_adapter().validation_error(
|
||||
"incorrect_code"
|
||||
if account_app_settings.EMAIL_VERIFICATION_BY_CODE_ENABLED
|
||||
else "invalid_or_expired_key"
|
||||
)
|
||||
if valid and not confirmation.email_address.can_set_verified():
|
||||
raise get_account_adapter().validation_error("email_taken")
|
||||
return confirmation
|
||||
|
||||
|
||||
class RequestPasswordResetInput(ResetPasswordForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class ResetPasswordKeyInput(inputs.Input):
|
||||
key = inputs.CharField()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.user = None
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def clean_key(self):
|
||||
key = self.cleaned_data["key"]
|
||||
uidb36, _, subkey = key.partition("-")
|
||||
token_form = UserTokenForm(data={"uidb36": uidb36, "key": subkey})
|
||||
if not token_form.is_valid():
|
||||
raise get_account_adapter().validation_error("invalid_password_reset")
|
||||
self.user = token_form.reset_user
|
||||
return key
|
||||
|
||||
|
||||
class ResetPasswordInput(ResetPasswordKeyInput):
|
||||
password = inputs.CharField()
|
||||
|
||||
def clean(self):
|
||||
cleaned_data = super().clean()
|
||||
password = self.cleaned_data.get("password")
|
||||
if self.user and password is not None:
|
||||
get_account_adapter().clean_password(password, user=self.user)
|
||||
return cleaned_data
|
||||
|
||||
|
||||
class ChangePasswordInput(inputs.Input):
|
||||
current_password = inputs.CharField(required=False)
|
||||
new_password = inputs.CharField()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.user = kwargs.pop("user")
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields["current_password"].required = self.user.has_usable_password()
|
||||
|
||||
def clean_current_password(self):
|
||||
current_password = self.cleaned_data["current_password"]
|
||||
if current_password:
|
||||
if not self.user.check_password(current_password):
|
||||
raise get_account_adapter().validation_error("enter_current_password")
|
||||
return current_password
|
||||
|
||||
def clean_new_password(self):
|
||||
new_password = self.cleaned_data["new_password"]
|
||||
adapter = get_account_adapter()
|
||||
return adapter.clean_password(new_password, user=self.user)
|
||||
|
||||
|
||||
class AddEmailInput(AddEmailForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class SelectEmailInput(inputs.Input):
|
||||
email = inputs.CharField()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.user = kwargs.pop("user")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def clean_email(self):
|
||||
email = self.cleaned_data["email"]
|
||||
validate_email(email)
|
||||
try:
|
||||
return EmailAddress.objects.get_for_user(user=self.user, email=email)
|
||||
except EmailAddress.DoesNotExist:
|
||||
raise get_adapter().validation_error("unknown_email")
|
||||
|
||||
|
||||
class DeleteEmailInput(SelectEmailInput):
|
||||
def clean_email(self):
|
||||
email = super().clean_email()
|
||||
if not flows.manage_email.can_delete_email(email):
|
||||
raise get_account_adapter().validation_error("cannot_remove_primary_email")
|
||||
return email
|
||||
|
||||
|
||||
class MarkAsPrimaryEmailInput(SelectEmailInput):
|
||||
primary = inputs.BooleanField(required=True)
|
||||
|
||||
def clean_email(self):
|
||||
email = super().clean_email()
|
||||
if not flows.manage_email.can_mark_as_primary(email):
|
||||
raise get_account_adapter().validation_error("unverified_primary_email")
|
||||
return email
|
||||
|
||||
|
||||
class ReauthenticateInput(ReauthenticateForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class RequestLoginCodeInput(RequestLoginCodeForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class ConfirmLoginCodeInput(ConfirmLoginCodeForm, inputs.Input):
|
||||
pass
|
||||
@@ -0,0 +1,44 @@
|
||||
from allauth.headless.adapter import get_adapter
|
||||
from allauth.headless.base.response import APIResponse
|
||||
|
||||
|
||||
class RequestEmailVerificationResponse(APIResponse):
|
||||
def __init__(self, request, verification_sent):
|
||||
super().__init__(request, status=200 if verification_sent else 403)
|
||||
|
||||
|
||||
class VerifyEmailResponse(APIResponse):
|
||||
def __init__(self, request, verification, stage):
|
||||
adapter = get_adapter()
|
||||
data = {
|
||||
"email": verification.email_address.email,
|
||||
"user": adapter.serialize_user(verification.email_address.user),
|
||||
}
|
||||
meta = {
|
||||
"is_authenticating": stage is not None,
|
||||
}
|
||||
super().__init__(request, data=data, meta=meta)
|
||||
|
||||
|
||||
class EmailAddressesResponse(APIResponse):
|
||||
def __init__(self, request, email_addresses):
|
||||
data = [
|
||||
{
|
||||
"email": addr.email,
|
||||
"verified": addr.verified,
|
||||
"primary": addr.primary,
|
||||
}
|
||||
for addr in email_addresses
|
||||
]
|
||||
super().__init__(request, data=data)
|
||||
|
||||
|
||||
class RequestPasswordResponse(APIResponse):
|
||||
pass
|
||||
|
||||
|
||||
class PasswordResetKeyResponse(APIResponse):
|
||||
def __init__(self, request, user):
|
||||
adapter = get_adapter()
|
||||
data = {"user": adapter.serialize_user(user)}
|
||||
super().__init__(request, data=data)
|
||||
@@ -0,0 +1,99 @@
|
||||
from allauth.account.models import EmailAddress
|
||||
|
||||
|
||||
def test_list_email(auth_client, user, headless_reverse):
|
||||
resp = auth_client.get(
|
||||
headless_reverse("headless:account:manage_email"),
|
||||
)
|
||||
assert len(resp.json()["data"]) == 1
|
||||
|
||||
|
||||
def test_remove_email(auth_client, user, email_factory, headless_reverse):
|
||||
addr = EmailAddress.objects.create(email=email_factory(), user=user)
|
||||
assert EmailAddress.objects.filter(user=user).count() == 2
|
||||
resp = auth_client.delete(
|
||||
headless_reverse("headless:account:manage_email"),
|
||||
data={"email": addr.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["data"]) == 1
|
||||
assert not EmailAddress.objects.filter(pk=addr.pk).exists()
|
||||
|
||||
|
||||
def test_add_email(auth_client, user, email_factory, headless_reverse):
|
||||
new_email = email_factory()
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:manage_email"),
|
||||
data={"email": new_email},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["data"]) == 2
|
||||
assert EmailAddress.objects.filter(email=new_email, verified=False).exists()
|
||||
|
||||
|
||||
def test_change_primary(auth_client, user, email_factory, headless_reverse):
|
||||
addr = EmailAddress.objects.create(
|
||||
email=email_factory(), user=user, verified=True, primary=False
|
||||
)
|
||||
resp = auth_client.patch(
|
||||
headless_reverse("headless:account:manage_email"),
|
||||
data={"email": addr.email, "primary": True},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["data"]) == 2
|
||||
assert EmailAddress.objects.filter(pk=addr.pk, primary=True).exists()
|
||||
|
||||
|
||||
def test_resend_verification(
|
||||
auth_client, user, email_factory, headless_reverse, mailoutbox
|
||||
):
|
||||
addr = EmailAddress.objects.create(email=email_factory(), user=user, verified=False)
|
||||
resp = auth_client.put(
|
||||
headless_reverse("headless:account:manage_email"),
|
||||
data={"email": addr.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(mailoutbox) == 1
|
||||
|
||||
|
||||
def test_email_rate_limit(
|
||||
auth_client, user, email_factory, headless_reverse, settings, enable_cache
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"manage_email": "1/m/ip"}
|
||||
for attempt in range(2):
|
||||
new_email = email_factory()
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:manage_email"),
|
||||
data={"email": new_email},
|
||||
content_type="application/json",
|
||||
)
|
||||
expected_status = 200 if attempt == 0 else 429
|
||||
assert resp.status_code == expected_status
|
||||
assert resp.json()["status"] == expected_status
|
||||
|
||||
|
||||
def test_resend_verification_rate_limit(
|
||||
auth_client,
|
||||
user,
|
||||
email_factory,
|
||||
headless_reverse,
|
||||
settings,
|
||||
enable_cache,
|
||||
mailoutbox,
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"confirm_email": "1/m/ip"}
|
||||
for attempt in range(2):
|
||||
addr = EmailAddress.objects.create(
|
||||
email=email_factory(), user=user, verified=False
|
||||
)
|
||||
resp = auth_client.put(
|
||||
headless_reverse("headless:account:manage_email"),
|
||||
data={"email": addr.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 403 if attempt else 200
|
||||
assert len(mailoutbox) == 1
|
||||
@@ -0,0 +1,220 @@
|
||||
import copy
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"has_password,request_data,response_data,status_code",
|
||||
[
|
||||
# Wrong current password
|
||||
(
|
||||
True,
|
||||
{"current_password": "wrong", "new_password": "{password_factory}"},
|
||||
{
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"param": "current_password",
|
||||
"message": "Please type your current password.",
|
||||
"code": "enter_current_password",
|
||||
}
|
||||
],
|
||||
},
|
||||
400,
|
||||
),
|
||||
# Happy flow, regular password change
|
||||
(
|
||||
True,
|
||||
{
|
||||
"current_password": "{user_password}",
|
||||
"new_password": "{password_factory}",
|
||||
},
|
||||
{
|
||||
"status": 200,
|
||||
"meta": {"is_authenticated": True},
|
||||
"data": {
|
||||
"user": ANY,
|
||||
"methods": [],
|
||||
},
|
||||
},
|
||||
200,
|
||||
),
|
||||
# New password does not match constraints
|
||||
(
|
||||
True,
|
||||
{
|
||||
"current_password": "{user_password}",
|
||||
"new_password": "a",
|
||||
},
|
||||
{
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"param": "new_password",
|
||||
"code": "password_too_short",
|
||||
"message": "This password is too short. It must contain at least 6 characters.",
|
||||
}
|
||||
],
|
||||
},
|
||||
400,
|
||||
),
|
||||
# New password not empty
|
||||
(
|
||||
True,
|
||||
{
|
||||
"current_password": "{user_password}",
|
||||
"new_password": "",
|
||||
},
|
||||
{
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"param": "new_password",
|
||||
"code": "required",
|
||||
"message": "This field is required.",
|
||||
}
|
||||
],
|
||||
},
|
||||
400,
|
||||
),
|
||||
# Current password not blank
|
||||
(
|
||||
True,
|
||||
{
|
||||
"current_password": "",
|
||||
"new_password": "{password_factory}",
|
||||
},
|
||||
{
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"param": "current_password",
|
||||
"message": "This field is required.",
|
||||
"code": "required",
|
||||
}
|
||||
],
|
||||
},
|
||||
400,
|
||||
),
|
||||
# Current password missing
|
||||
(
|
||||
True,
|
||||
{
|
||||
"new_password": "{password_factory}",
|
||||
},
|
||||
{
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"param": "current_password",
|
||||
"message": "This field is required.",
|
||||
"code": "required",
|
||||
}
|
||||
],
|
||||
},
|
||||
400,
|
||||
),
|
||||
# Current password not set, happy flow
|
||||
(
|
||||
False,
|
||||
{
|
||||
"current_password": "",
|
||||
"new_password": "{password_factory}",
|
||||
},
|
||||
{
|
||||
"status": 200,
|
||||
"meta": {"is_authenticated": True},
|
||||
"data": {
|
||||
"user": ANY,
|
||||
"methods": [],
|
||||
},
|
||||
},
|
||||
200,
|
||||
),
|
||||
# Current password not set, current_password absent
|
||||
(
|
||||
False,
|
||||
{
|
||||
"new_password": "{password_factory}",
|
||||
},
|
||||
{
|
||||
"status": 200,
|
||||
"meta": {"is_authenticated": True},
|
||||
"data": {
|
||||
"user": ANY,
|
||||
"methods": [],
|
||||
},
|
||||
},
|
||||
200,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_change_password(
|
||||
auth_client,
|
||||
user,
|
||||
request_data,
|
||||
response_data,
|
||||
status_code,
|
||||
has_password,
|
||||
user_password,
|
||||
password_factory,
|
||||
settings,
|
||||
mailoutbox,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
):
|
||||
request_data = copy.deepcopy(request_data)
|
||||
response_data = copy.deepcopy(response_data)
|
||||
settings.ACCOUNT_EMAIL_NOTIFICATIONS = True
|
||||
if not has_password:
|
||||
user.set_unusable_password()
|
||||
user.save(update_fields=["password"])
|
||||
auth_client.force_login(user)
|
||||
if request_data.get("current_password") == "{user_password}":
|
||||
request_data["current_password"] = user_password
|
||||
if request_data.get("new_password") == "{password_factory}":
|
||||
request_data["new_password"] = password_factory()
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:change_password"),
|
||||
data=request_data,
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == status_code
|
||||
resp_json = resp.json()
|
||||
if headless_client == "app" and resp.status_code == 200:
|
||||
response_data["meta"]["session_token"] = ANY
|
||||
assert resp_json == response_data
|
||||
user.refresh_from_db()
|
||||
if resp.status_code == 200:
|
||||
assert user.check_password(request_data["new_password"])
|
||||
assert len(mailoutbox) == 1
|
||||
else:
|
||||
assert user.check_password(user_password)
|
||||
assert len(mailoutbox) == 0
|
||||
|
||||
|
||||
def test_change_password_rate_limit(
|
||||
enable_cache,
|
||||
auth_client,
|
||||
user,
|
||||
user_password,
|
||||
password_factory,
|
||||
settings,
|
||||
headless_reverse,
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"change_password": "1/m/ip"}
|
||||
for attempt in range(2):
|
||||
new_password = password_factory()
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:change_password"),
|
||||
data={
|
||||
"current_password": user_password,
|
||||
"new_password": new_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
user_password = new_password
|
||||
expected_status = 200 if attempt == 0 else 429
|
||||
assert resp.status_code == expected_status
|
||||
assert resp.json()["status"] == expected_status
|
||||
@@ -0,0 +1,86 @@
|
||||
from allauth.account.models import (
|
||||
EmailAddress,
|
||||
EmailConfirmationHMAC,
|
||||
get_emailconfirmation_model,
|
||||
)
|
||||
from allauth.headless.constants import Flow
|
||||
|
||||
|
||||
def test_verify_email_other_user(auth_client, user, user_factory, headless_reverse):
|
||||
other_user = user_factory(email_verified=False)
|
||||
email_address = EmailAddress.objects.get(user=other_user, verified=False)
|
||||
assert not email_address.verified
|
||||
key = EmailConfirmationHMAC(email_address).key
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:verify_email"),
|
||||
data={"key": key},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
# We're still authenticated as the user originally logged in, not the
|
||||
# other_user.
|
||||
assert data["data"]["user"]["id"] == user.pk
|
||||
|
||||
|
||||
def test_auth_unverified_email(
|
||||
client, user_factory, password_factory, settings, headless_reverse
|
||||
):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
|
||||
password = password_factory()
|
||||
user = user_factory(email_verified=False, password=password)
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
flows = data["data"]["flows"]
|
||||
assert [f for f in flows if f["id"] == Flow.VERIFY_EMAIL][0]["is_pending"]
|
||||
emailaddress = EmailAddress.objects.filter(user=user, verified=False).get()
|
||||
key = get_emailconfirmation_model().create(emailaddress).key
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:verify_email"),
|
||||
data={"key": key},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_verify_email_bad_key(
|
||||
client, settings, password_factory, user_factory, headless_reverse
|
||||
):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
|
||||
password = password_factory()
|
||||
user = user_factory(email_verified=False, password=password)
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:verify_email"),
|
||||
data={"key": "bad"},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"code": "invalid_or_expired_key",
|
||||
"param": "key",
|
||||
"message": "Invalid or expired key.",
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
from allauth.headless.constants import Flow
|
||||
|
||||
|
||||
def test_email_verification_rate_limits(
|
||||
client,
|
||||
db,
|
||||
user_password,
|
||||
settings,
|
||||
user_factory,
|
||||
password_factory,
|
||||
enable_cache,
|
||||
headless_reverse,
|
||||
):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION_BY_CODE_ENABLED = True
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
|
||||
settings.ACCOUNT_RATE_LIMITS = {"confirm_email": "1/m/key"}
|
||||
email = "user@email.org"
|
||||
user_factory(email=email, email_verified=False, password=user_password)
|
||||
for attempt in range(2):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": email,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
if attempt == 0:
|
||||
assert resp.status_code == 401
|
||||
flow = [
|
||||
flow for flow in resp.json()["data"]["flows"] if flow.get("is_pending")
|
||||
][0]
|
||||
assert flow["id"] == Flow.VERIFY_EMAIL
|
||||
else:
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"message": "Too many failed login attempts. Try again later.",
|
||||
"code": "too_many_login_attempts",
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_auth_password_input_error(headless_reverse, client):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"message": "This field is required.",
|
||||
"code": "required",
|
||||
"param": "password",
|
||||
},
|
||||
{
|
||||
"message": "This field is required.",
|
||||
"code": "required",
|
||||
"param": "username",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_auth_password_bad_password(headless_reverse, client, user, settings):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": "wrong",
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"param": "password",
|
||||
"message": "The email address and/or password you specified are not correct.",
|
||||
"code": "email_password_mismatch",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_auth_password_success(
|
||||
client, user, user_password, settings, headless_reverse, headless_client
|
||||
):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
login_resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert login_resp.status_code == 200
|
||||
session_resp = client.get(
|
||||
headless_reverse("headless:account:current_session"),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert session_resp.status_code == 200
|
||||
for resp in [login_resp, session_resp]:
|
||||
extra_meta = {}
|
||||
if headless_client == "app" and resp == login_resp:
|
||||
# The session is created on first login, and hence the token is
|
||||
# exposed only at that moment.
|
||||
extra_meta["session_token"] = ANY
|
||||
assert resp.json() == {
|
||||
"status": 200,
|
||||
"data": {
|
||||
"user": {
|
||||
"id": user.pk,
|
||||
"display": str(user),
|
||||
"email": user.email,
|
||||
"username": user.username,
|
||||
"has_usable_password": True,
|
||||
},
|
||||
"methods": [
|
||||
{
|
||||
"at": ANY,
|
||||
"email": user.email,
|
||||
"method": "password",
|
||||
}
|
||||
],
|
||||
},
|
||||
"meta": {"is_authenticated": True, **extra_meta},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_active,status_code", [(False, 401), (True, 200)])
|
||||
def test_auth_password_user_inactive(
|
||||
client, user, user_password, settings, status_code, is_active, headless_reverse
|
||||
):
|
||||
user.is_active = is_active
|
||||
user.save(update_fields=["is_active"])
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"username": user.username,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == status_code
|
||||
|
||||
|
||||
def test_login_failed_rate_limit(
|
||||
client,
|
||||
user,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
enable_cache,
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"login_failed": "1/m/ip"}
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
for attempt in range(2):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": "wrong",
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json()["errors"] == [
|
||||
(
|
||||
{
|
||||
"code": "email_password_mismatch",
|
||||
"message": "The email address and/or password you specified are not correct.",
|
||||
"param": "password",
|
||||
}
|
||||
if attempt == 0
|
||||
else {
|
||||
"message": "Too many failed login attempts. Try again later.",
|
||||
"code": "too_many_login_attempts",
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_login_rate_limit(
|
||||
client,
|
||||
user,
|
||||
user_password,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
enable_cache,
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"login": "1/m/ip"}
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
for attempt in range(2):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
expected_status = 429 if attempt else 200
|
||||
assert resp.status_code == expected_status
|
||||
|
||||
|
||||
def test_login_already_logged_in(
|
||||
auth_client, user, user_password, settings, headless_reverse
|
||||
):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
@@ -0,0 +1,146 @@
|
||||
import time
|
||||
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.headless.constants import Flow
|
||||
|
||||
|
||||
def test_login_by_code(headless_reverse, user, client, mailoutbox):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:request_login_code"),
|
||||
data={"email": user.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
assert [f for f in data["data"]["flows"] if f["id"] == Flow.LOGIN_BY_CODE][0][
|
||||
"is_pending"
|
||||
]
|
||||
assert len(mailoutbox) == 1
|
||||
code = [line for line in mailoutbox[0].body.splitlines() if len(line) == 6][0]
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:confirm_login_code"),
|
||||
data={"code": code},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["meta"]["is_authenticated"]
|
||||
|
||||
|
||||
def test_login_by_code_rate_limit(
|
||||
headless_reverse, user, client, mailoutbox, settings, enable_cache
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"request_login_code": "1/m/ip"}
|
||||
for attempt in range(2):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:request_login_code"),
|
||||
data={"email": user.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
expected_code = 400 if attempt else 401
|
||||
assert resp.status_code == expected_code
|
||||
data = resp.json()
|
||||
assert data["status"] == expected_code
|
||||
if expected_code == 400:
|
||||
assert data["errors"] == [
|
||||
{
|
||||
"code": "too_many_login_attempts",
|
||||
"message": "Too many failed login attempts. Try again later.",
|
||||
"param": "email",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_login_by_code_max_attemps(headless_reverse, user, client, settings):
|
||||
settings.ACCOUNT_LOGIN_BY_CODE_MAX_ATTEMPTS = 2
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:request_login_code"),
|
||||
data={"email": user.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
for i in range(3):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:confirm_login_code"),
|
||||
data={"code": "wrong"},
|
||||
content_type="application/json",
|
||||
)
|
||||
session_resp = client.get(
|
||||
headless_reverse("headless:account:current_session"),
|
||||
data={"code": "wrong"},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert session_resp.status_code == 401
|
||||
pending_flows = [
|
||||
f for f in session_resp.json()["data"]["flows"] if f.get("is_pending")
|
||||
]
|
||||
if i >= 1:
|
||||
assert resp.status_code == 409 if i >= 2 else 400
|
||||
assert len(pending_flows) == 0
|
||||
else:
|
||||
assert resp.status_code == 400
|
||||
assert len(pending_flows) == 1
|
||||
|
||||
|
||||
def test_login_by_code_required(
|
||||
client, settings, user_factory, password_factory, headless_reverse, mailoutbox
|
||||
):
|
||||
settings.ACCOUNT_LOGIN_BY_CODE_REQUIRED = True
|
||||
password = password_factory()
|
||||
user = user_factory(password=password, email_verified=False)
|
||||
email_address = EmailAddress.objects.get(email=user.email)
|
||||
assert not email_address.verified
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"username": user.username,
|
||||
"password": password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
pending_flow = [f for f in resp.json()["data"]["flows"] if f.get("is_pending")][0][
|
||||
"id"
|
||||
]
|
||||
assert pending_flow == Flow.LOGIN_BY_CODE
|
||||
code = [line for line in mailoutbox[0].body.splitlines() if len(line) == 6][0]
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:confirm_login_code"),
|
||||
data={"code": code},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["meta"]["is_authenticated"]
|
||||
email_address.refresh_from_db()
|
||||
assert email_address.verified
|
||||
|
||||
|
||||
def test_login_by_code_expired(headless_reverse, user, client, mailoutbox):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:request_login_code"),
|
||||
data={"email": user.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
assert [f for f in data["data"]["flows"] if f["id"] == Flow.LOGIN_BY_CODE][0][
|
||||
"is_pending"
|
||||
]
|
||||
assert len(mailoutbox) == 1
|
||||
code = [line for line in mailoutbox[0].body.splitlines() if len(line) == 6][0]
|
||||
|
||||
# Expire code
|
||||
session = client.headless_session()
|
||||
login = session["account_login"]
|
||||
login["state"]["login_code"]["at"] = time.time() - 24 * 60 * 60
|
||||
session["account_login"] = login
|
||||
session.save()
|
||||
|
||||
# Post valid code
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:confirm_login_code"),
|
||||
data={"code": code},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
@@ -0,0 +1,52 @@
|
||||
def test_reauthenticate(
|
||||
auth_client, user, user_password, headless_reverse, headless_client
|
||||
):
|
||||
resp = auth_client.get(
|
||||
headless_reverse("headless:account:current_session"),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
method_count = len(data["data"]["methods"])
|
||||
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:reauthenticate"),
|
||||
data={
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = auth_client.get(
|
||||
headless_reverse("headless:account:current_session"),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["data"]["methods"]) == method_count + 1
|
||||
last_method = data["data"]["methods"][-1]
|
||||
assert last_method["method"] == "password"
|
||||
|
||||
|
||||
def test_reauthenticate_rate_limit(
|
||||
auth_client,
|
||||
user,
|
||||
user_password,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
settings,
|
||||
enable_cache,
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"reauthenticate": "1/m/ip"}
|
||||
for attempt in range(4):
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:reauthenticate"),
|
||||
data={
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
expected_status = 429 if attempt else 200
|
||||
assert resp.status_code == expected_status
|
||||
assert resp.json()["status"] == expected_status
|
||||
@@ -0,0 +1,151 @@
|
||||
from django.urls import reverse
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_password_reset_flow(
|
||||
client, user, mailoutbox, password_factory, settings, headless_reverse
|
||||
):
|
||||
settings.ACCOUNT_EMAIL_NOTIFICATIONS = True
|
||||
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:request_password_reset"),
|
||||
data={
|
||||
"email": user.email,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(mailoutbox) == 1
|
||||
body = mailoutbox[0].body
|
||||
# Extract URL for `password_reset_from_key` view
|
||||
url = body[body.find("/password/reset/") :].split()[0]
|
||||
key = url.split("/")[-2]
|
||||
password = password_factory()
|
||||
|
||||
# Too simple password
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:reset_password"),
|
||||
data={
|
||||
"key": key,
|
||||
"password": "a",
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"code": "password_too_short",
|
||||
"message": "This password is too short. It must contain at least 6 characters.",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
assert len(mailoutbox) == 1
|
||||
|
||||
# Success
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:reset_password"),
|
||||
data={
|
||||
"key": key,
|
||||
"password": password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
user.refresh_from_db()
|
||||
assert user.check_password(password)
|
||||
assert len(mailoutbox) == 2 # The security notification
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", ["get", "post"])
|
||||
def test_password_reset_flow_wrong_key(
|
||||
client, password_factory, headless_reverse, method
|
||||
):
|
||||
password = password_factory()
|
||||
|
||||
if method == "get":
|
||||
resp = client.get(
|
||||
headless_reverse("headless:account:reset_password"),
|
||||
HTTP_X_PASSWORD_RESET_KEY="wrong",
|
||||
)
|
||||
else:
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:reset_password"),
|
||||
data={
|
||||
"key": "wrong",
|
||||
"password": password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"param": "key",
|
||||
"code": "invalid_password_reset",
|
||||
"message": "The password reset token was invalid.",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_password_reset_flow_unknown_user(
|
||||
client, db, mailoutbox, password_factory, settings, headless_reverse
|
||||
):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:request_password_reset"),
|
||||
data={
|
||||
"email": "not@registered.org",
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(mailoutbox) == 1
|
||||
body = mailoutbox[0].body
|
||||
if getattr(settings, "HEADLESS_ONLY", False):
|
||||
assert settings.HEADLESS_FRONTEND_URLS["account_signup"] in body
|
||||
else:
|
||||
assert reverse("account_signup") in body
|
||||
|
||||
|
||||
def test_reset_password_rate_limit(
|
||||
auth_client, user, headless_reverse, settings, enable_cache
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"reset_password": "1/m/ip"}
|
||||
for attempt in range(2):
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:request_password_reset"),
|
||||
data={"email": user.email},
|
||||
content_type="application/json",
|
||||
)
|
||||
expected_status = 200 if attempt == 0 else 429
|
||||
assert resp.status_code == expected_status
|
||||
assert resp.json()["status"] == expected_status
|
||||
|
||||
|
||||
def test_password_reset_key_rate_limit(
|
||||
client,
|
||||
user,
|
||||
settings,
|
||||
headless_reverse,
|
||||
password_reset_key_generator,
|
||||
enable_cache,
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"reset_password_from_key": "1/m/ip"}
|
||||
for attempt in range(2):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:reset_password"),
|
||||
data={
|
||||
"key": password_reset_key_generator(user),
|
||||
"password": "a", # too short
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
expected_status = 429 if attempt else 400
|
||||
assert resp.status_code == expected_status
|
||||
assert resp.json()["status"] == expected_status
|
||||
@@ -0,0 +1,33 @@
|
||||
from django.test.client import Client
|
||||
from django.urls import reverse
|
||||
|
||||
|
||||
def test_app_session_gone(db, user):
|
||||
# intentionally use a vanilla Django test client
|
||||
client = Client()
|
||||
# Force login, creates a Django session
|
||||
client.force_login(user)
|
||||
# That Django session should not play any role.
|
||||
resp = client.get(
|
||||
reverse("headless:app:account:current_session"), HTTP_X_SESSION_TOKEN="gone"
|
||||
)
|
||||
assert resp.status_code == 410
|
||||
|
||||
|
||||
def test_logout(auth_client, headless_reverse):
|
||||
# That Django session should not play any role.
|
||||
resp = auth_client.get(headless_reverse("headless:account:current_session"))
|
||||
assert resp.status_code == 200
|
||||
resp = auth_client.delete(headless_reverse("headless:account:current_session"))
|
||||
assert resp.status_code == 401
|
||||
resp = auth_client.get(headless_reverse("headless:account:current_session"))
|
||||
assert resp.status_code in [401, 410]
|
||||
|
||||
|
||||
def test_logout_no_token(app_client, user):
|
||||
app_client.force_login(user)
|
||||
resp = app_client.get(reverse("headless:app:account:current_session"))
|
||||
assert resp.status_code == 200
|
||||
resp = app_client.delete(reverse("headless:app:account:current_session"))
|
||||
assert resp.status_code == 401
|
||||
assert "session_token" not in resp.json()["meta"]
|
||||
@@ -0,0 +1,190 @@
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from allauth.account.models import EmailAddress, EmailConfirmationHMAC
|
||||
from allauth.headless.constants import Flow
|
||||
|
||||
|
||||
def test_signup(
|
||||
db,
|
||||
client,
|
||||
email_factory,
|
||||
password_factory,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:signup"),
|
||||
data={
|
||||
"username": "wizard",
|
||||
"email": email_factory(),
|
||||
"password": password_factory(),
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert User.objects.filter(username="wizard").exists()
|
||||
|
||||
|
||||
def test_signup_with_email_verification(
|
||||
db,
|
||||
client,
|
||||
email_factory,
|
||||
password_factory,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
):
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
|
||||
settings.ACCOUNT_USERNAME_REQUIRED = False
|
||||
email = email_factory()
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:signup"),
|
||||
data={
|
||||
"email": email,
|
||||
"password": password_factory(),
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert User.objects.filter(email=email).exists()
|
||||
data = resp.json()
|
||||
flow = next((f for f in data["data"]["flows"] if f.get("is_pending")))
|
||||
assert flow["id"] == "verify_email"
|
||||
addr = EmailAddress.objects.get(email=email)
|
||||
key = EmailConfirmationHMAC(addr).key
|
||||
resp = client.get(
|
||||
headless_reverse("headless:account:verify_email"),
|
||||
HTTP_X_EMAIL_VERIFICATION_KEY=key,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {
|
||||
"data": {
|
||||
"email": email,
|
||||
"user": ANY,
|
||||
},
|
||||
"meta": {"is_authenticating": True},
|
||||
"status": 200,
|
||||
}
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:verify_email"),
|
||||
data={"key": key},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["meta"]["is_authenticated"]
|
||||
addr.refresh_from_db()
|
||||
assert addr.verified
|
||||
|
||||
|
||||
def test_signup_prevent_enumeration(
|
||||
db,
|
||||
client,
|
||||
email_factory,
|
||||
password_factory,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
user,
|
||||
mailoutbox,
|
||||
):
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
|
||||
settings.ACCOUNT_USERNAME_REQUIRED = False
|
||||
settings.ACCOUNT_PREVENT_ENUMERATION = True
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:signup"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": password_factory(),
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert len(mailoutbox) == 1
|
||||
assert "an account using that email address already exists" in mailoutbox[0].body
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
assert [f for f in data["data"]["flows"] if f["id"] == Flow.VERIFY_EMAIL][0][
|
||||
"is_pending"
|
||||
]
|
||||
resp = client.get(headless_reverse("headless:account:current_session"))
|
||||
data = resp.json()
|
||||
assert [f for f in data["data"]["flows"] if f["id"] == Flow.VERIFY_EMAIL][0][
|
||||
"is_pending"
|
||||
]
|
||||
|
||||
|
||||
def test_signup_rate_limit(
|
||||
db,
|
||||
client,
|
||||
email_factory,
|
||||
password_factory,
|
||||
settings,
|
||||
headless_reverse,
|
||||
enable_cache,
|
||||
headless_client,
|
||||
):
|
||||
settings.ACCOUNT_RATE_LIMITS = {"signup": "1/m/ip"}
|
||||
for attempt in range(2):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:signup"),
|
||||
data={
|
||||
"username": f"wizard{attempt}",
|
||||
"email": email_factory(),
|
||||
"password": password_factory(),
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
expected_status = 429 if attempt else 200
|
||||
assert resp.status_code == expected_status
|
||||
assert resp.json()["status"] == expected_status
|
||||
|
||||
|
||||
def test_signup_closed(
|
||||
db,
|
||||
client,
|
||||
email_factory,
|
||||
password_factory,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
):
|
||||
with patch(
|
||||
"allauth.account.adapter.DefaultAccountAdapter.is_open_for_signup"
|
||||
) as iofs:
|
||||
iofs.return_value = False
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:signup"),
|
||||
data={
|
||||
"username": "wizard",
|
||||
"email": email_factory(),
|
||||
"password": password_factory(),
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert not User.objects.filter(username="wizard").exists()
|
||||
|
||||
|
||||
def test_signup_while_logged_in(
|
||||
db,
|
||||
auth_client,
|
||||
email_factory,
|
||||
password_factory,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
):
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:account:signup"),
|
||||
data={
|
||||
"username": "wizard",
|
||||
"email": email_factory(),
|
||||
"password": password_factory(),
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
assert resp.json() == {"status": 409}
|
||||
@@ -0,0 +1,96 @@
|
||||
from django.urls import include, path
|
||||
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.account import app_settings as account_settings
|
||||
from allauth.headless.account import views
|
||||
|
||||
|
||||
def build_urlpatterns(client):
|
||||
account_patterns = []
|
||||
auth_patterns = [
|
||||
path(
|
||||
"session",
|
||||
views.SessionView.as_api_view(client=client),
|
||||
name="current_session",
|
||||
),
|
||||
path(
|
||||
"reauthenticate",
|
||||
views.ReauthenticateView.as_api_view(client=client),
|
||||
name="reauthenticate",
|
||||
),
|
||||
path(
|
||||
"code/confirm",
|
||||
views.ConfirmLoginCodeView.as_api_view(client=client),
|
||||
name="confirm_login_code",
|
||||
),
|
||||
]
|
||||
if not allauth_settings.SOCIALACCOUNT_ONLY:
|
||||
account_patterns.extend(
|
||||
[
|
||||
path(
|
||||
"password/change",
|
||||
views.ChangePasswordView.as_api_view(client=client),
|
||||
name="change_password",
|
||||
),
|
||||
path(
|
||||
"email",
|
||||
views.ManageEmailView.as_api_view(client=client),
|
||||
name="manage_email",
|
||||
),
|
||||
]
|
||||
)
|
||||
auth_patterns.extend(
|
||||
[
|
||||
path(
|
||||
"password/",
|
||||
include(
|
||||
[
|
||||
path(
|
||||
"request",
|
||||
views.RequestPasswordResetView.as_api_view(
|
||||
client=client
|
||||
),
|
||||
name="request_password_reset",
|
||||
),
|
||||
path(
|
||||
"reset",
|
||||
views.ResetPasswordView.as_api_view(client=client),
|
||||
name="reset_password",
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
path(
|
||||
"login",
|
||||
views.LoginView.as_api_view(client=client),
|
||||
name="login",
|
||||
),
|
||||
path(
|
||||
"signup",
|
||||
views.SignupView.as_api_view(client=client),
|
||||
name="signup",
|
||||
),
|
||||
path(
|
||||
"email/verify",
|
||||
views.VerifyEmailView.as_api_view(client=client),
|
||||
name="verify_email",
|
||||
),
|
||||
]
|
||||
)
|
||||
if account_settings.LOGIN_BY_CODE_ENABLED:
|
||||
auth_patterns.extend(
|
||||
[
|
||||
path(
|
||||
"code/request",
|
||||
views.RequestLoginCodeView.as_api_view(client=client),
|
||||
name="request_login_code",
|
||||
),
|
||||
]
|
||||
)
|
||||
return [
|
||||
path("auth/", include(auth_patterns)),
|
||||
path(
|
||||
"account/",
|
||||
include(account_patterns),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,256 @@
|
||||
from django.utils.decorators import method_decorator
|
||||
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.internal import flows
|
||||
from allauth.account.internal.flows import password_change, password_reset
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.account.stages import EmailVerificationStage, LoginStageController
|
||||
from allauth.account.utils import send_email_confirmation
|
||||
from allauth.core import ratelimit
|
||||
from allauth.core.exceptions import ImmediateHttpResponse
|
||||
from allauth.decorators import rate_limit
|
||||
from allauth.headless.account import response
|
||||
from allauth.headless.account.inputs import (
|
||||
AddEmailInput,
|
||||
ChangePasswordInput,
|
||||
ConfirmLoginCodeInput,
|
||||
DeleteEmailInput,
|
||||
LoginInput,
|
||||
MarkAsPrimaryEmailInput,
|
||||
ReauthenticateInput,
|
||||
RequestLoginCodeInput,
|
||||
RequestPasswordResetInput,
|
||||
ResetPasswordInput,
|
||||
ResetPasswordKeyInput,
|
||||
SelectEmailInput,
|
||||
SignupInput,
|
||||
VerifyEmailInput,
|
||||
)
|
||||
from allauth.headless.base.response import (
|
||||
APIResponse,
|
||||
AuthenticationResponse,
|
||||
ConflictResponse,
|
||||
ForbiddenResponse,
|
||||
)
|
||||
from allauth.headless.base.views import APIView, AuthenticatedAPIView
|
||||
from allauth.headless.internal import authkit
|
||||
from allauth.headless.internal.restkit.response import ErrorResponse
|
||||
|
||||
|
||||
class RequestLoginCodeView(APIView):
|
||||
input_class = RequestLoginCodeInput
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
flows.login_by_code.request_login_code(
|
||||
self.request, self.input.cleaned_data["email"]
|
||||
)
|
||||
return AuthenticationResponse(self.request)
|
||||
|
||||
|
||||
class ConfirmLoginCodeView(APIView):
|
||||
input_class = ConfirmLoginCodeInput
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
auth_status = authkit.AuthenticationStatus(request)
|
||||
self.stage = auth_status.get_pending_stage()
|
||||
if not self.stage:
|
||||
return ConflictResponse(request)
|
||||
self.user, self.pending_login = flows.login_by_code.get_pending_login(
|
||||
request, self.stage.login, peek=True
|
||||
)
|
||||
if not self.pending_login:
|
||||
return ConflictResponse(request)
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
flows.login_by_code.perform_login_by_code(self.request, self.stage, None)
|
||||
return AuthenticationResponse(request)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
kwargs = super().get_input_kwargs()
|
||||
kwargs["code"] = (
|
||||
self.pending_login.get("code", "") if self.pending_login else ""
|
||||
)
|
||||
return kwargs
|
||||
|
||||
def handle_invalid_input(self, input):
|
||||
flows.login_by_code.record_invalid_attempt(self.request, self.stage.login)
|
||||
return super().handle_invalid_input(input)
|
||||
|
||||
|
||||
@method_decorator(rate_limit(action="login"), name="handle")
|
||||
class LoginView(APIView):
|
||||
input_class = LoginInput
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
if request.user.is_authenticated:
|
||||
return ConflictResponse(request)
|
||||
credentials = self.input.cleaned_data
|
||||
flows.login.perform_password_login(request, credentials, self.input.login)
|
||||
return AuthenticationResponse(self.request)
|
||||
|
||||
|
||||
@method_decorator(rate_limit(action="signup"), name="handle")
|
||||
class SignupView(APIView):
|
||||
input_class = {"POST": SignupInput}
|
||||
by_passkey = False
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
if request.user.is_authenticated:
|
||||
return ConflictResponse(request)
|
||||
if not get_account_adapter().is_open_for_signup(request):
|
||||
return ForbiddenResponse(request)
|
||||
user, resp = self.input.try_save(request)
|
||||
if not resp:
|
||||
try:
|
||||
flows.signup.complete_signup(
|
||||
request, user=user, by_passkey=self.by_passkey
|
||||
)
|
||||
except ImmediateHttpResponse:
|
||||
pass
|
||||
return AuthenticationResponse(request)
|
||||
|
||||
|
||||
class SessionView(APIView):
|
||||
def get(self, request, *args, **kwargs):
|
||||
return AuthenticationResponse(request)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
adapter = get_account_adapter()
|
||||
adapter.logout(request)
|
||||
return AuthenticationResponse(request)
|
||||
|
||||
|
||||
class VerifyEmailView(APIView):
|
||||
input_class = VerifyEmailInput
|
||||
|
||||
def handle(self, request, *args, **kwargs):
|
||||
self.stage = LoginStageController.enter(request, EmailVerificationStage.key)
|
||||
return super().handle(request, *args, **kwargs)
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
key = request.headers.get("x-email-verification-key", "")
|
||||
input = self.input_class({"key": key})
|
||||
if not input.is_valid():
|
||||
return ErrorResponse(request, input=input)
|
||||
verification = input.cleaned_data["key"]
|
||||
return response.VerifyEmailResponse(request, verification, stage=self.stage)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
confirmation = self.input.cleaned_data["key"]
|
||||
email_address = confirmation.confirm(request)
|
||||
if not email_address:
|
||||
# Should not happen, VerifyInputInput should have verified all
|
||||
# preconditions.
|
||||
return APIResponse(status=500)
|
||||
if self.stage:
|
||||
# Verifying email as part of login/signup flow, so emit a
|
||||
# authentication status response.
|
||||
self.stage.exit()
|
||||
return AuthenticationResponse(self.request)
|
||||
|
||||
|
||||
class RequestPasswordResetView(APIView):
|
||||
input_class = RequestPasswordResetInput
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
r429 = ratelimit.consume_or_429(
|
||||
self.request,
|
||||
action="reset_password",
|
||||
key=self.input.cleaned_data["email"].lower(),
|
||||
)
|
||||
if r429:
|
||||
return r429
|
||||
self.input.save(request)
|
||||
return response.RequestPasswordResponse(request)
|
||||
|
||||
|
||||
@method_decorator(rate_limit(action="reset_password_from_key"), name="handle")
|
||||
class ResetPasswordView(APIView):
|
||||
input_class = ResetPasswordInput
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
key = request.headers.get("X-Password-Reset-Key", "")
|
||||
input = ResetPasswordKeyInput({"key": key})
|
||||
if not input.is_valid():
|
||||
return ErrorResponse(request, input=input)
|
||||
return response.PasswordResetKeyResponse(request, input.user)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
flows.password_reset.reset_password(
|
||||
self.input.user, self.input.cleaned_data["password"]
|
||||
)
|
||||
password_reset.finalize_password_reset(request, self.input.user)
|
||||
return AuthenticationResponse(self.request)
|
||||
|
||||
|
||||
@method_decorator(rate_limit(action="change_password"), name="handle")
|
||||
class ChangePasswordView(AuthenticatedAPIView):
|
||||
input_class = ChangePasswordInput
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
password_change.change_password(
|
||||
self.request.user, self.input.cleaned_data["new_password"]
|
||||
)
|
||||
is_set = not self.input.cleaned_data.get("current_password")
|
||||
if is_set:
|
||||
password_change.finalize_password_set(request, request.user)
|
||||
else:
|
||||
password_change.finalize_password_change(request, request.user)
|
||||
return AuthenticationResponse(request)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
|
||||
@method_decorator(rate_limit(action="manage_email"), name="handle")
|
||||
class ManageEmailView(AuthenticatedAPIView):
|
||||
input_class = {
|
||||
"POST": AddEmailInput,
|
||||
"PUT": SelectEmailInput,
|
||||
"DELETE": DeleteEmailInput,
|
||||
"PATCH": MarkAsPrimaryEmailInput,
|
||||
}
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self._respond_email_list()
|
||||
|
||||
def _respond_email_list(self):
|
||||
addrs = EmailAddress.objects.filter(user=self.request.user)
|
||||
return response.EmailAddressesResponse(self.request, addrs)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
flows.manage_email.add_email(request, self.input)
|
||||
return self._respond_email_list()
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
addr = self.input.cleaned_data["email"]
|
||||
flows.manage_email.delete_email(request, addr)
|
||||
return self._respond_email_list()
|
||||
|
||||
def patch(self, request, *args, **kwargs):
|
||||
addr = self.input.cleaned_data["email"]
|
||||
flows.manage_email.mark_as_primary(request, addr)
|
||||
return self._respond_email_list()
|
||||
|
||||
def put(self, request, *args, **kwargs):
|
||||
addr = self.input.cleaned_data["email"]
|
||||
sent = send_email_confirmation(request, request.user, email=addr.email)
|
||||
return response.RequestEmailVerificationResponse(
|
||||
request, verification_sent=sent
|
||||
)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
|
||||
@method_decorator(rate_limit(action="reauthenticate"), name="handle")
|
||||
class ReauthenticateView(AuthenticatedAPIView):
|
||||
input_class = ReauthenticateInput
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
flows.reauthentication.reauthenticate_by_password(self.request)
|
||||
return AuthenticationResponse(self.request)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
@@ -0,0 +1,54 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from django.forms.fields import Field
|
||||
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.account.utils import user_display, user_username
|
||||
from allauth.core.internal.adapter import BaseAdapter
|
||||
from allauth.headless import app_settings
|
||||
from allauth.utils import import_attribute
|
||||
|
||||
|
||||
class DefaultHeadlessAdapter(BaseAdapter):
|
||||
"""The adapter class allows you to override various functionality of the
|
||||
``allauth.headless`` app. To do so, point ``settings.HEADLESS_ADAPTER`` to your own
|
||||
class that derives from ``DefaultHeadlessAdapter`` and override the behavior by
|
||||
altering the implementation of the methods according to your own need.
|
||||
"""
|
||||
|
||||
error_messages = {
|
||||
# For the following error messages i18n is not an issue as these should not be
|
||||
# showing up in a UI.
|
||||
"account_not_found": "Unknown account.",
|
||||
"client_id_required": "`client_id` required.",
|
||||
"email_or_username": "Pass only one of email or username, not both.",
|
||||
"invalid_token": "Invalid token.",
|
||||
"token_authentication_not_supported": "Provider does not support token authentication.",
|
||||
"token_required": "`id_token` and/or `access_token` required.",
|
||||
"required": Field.default_error_messages["required"],
|
||||
"unknown_email": "Unknown email address.",
|
||||
"invalid_url": "Invalid URL.",
|
||||
}
|
||||
|
||||
def serialize_user(self, user) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns the basic user data. Note that this data is also exposed in
|
||||
partly authenticated scenario's (e.g. password reset, email
|
||||
verification).
|
||||
"""
|
||||
ret = {
|
||||
"id": user.pk,
|
||||
"display": user_display(user),
|
||||
"has_usable_password": user.has_usable_password(),
|
||||
}
|
||||
email = EmailAddress.objects.get_primary_email(user)
|
||||
if email:
|
||||
ret["email"] = email
|
||||
username = user_username(user)
|
||||
if username:
|
||||
ret["username"] = username
|
||||
return ret
|
||||
|
||||
|
||||
def get_adapter():
|
||||
return import_attribute(app_settings.ADAPTER)()
|
||||
@@ -0,0 +1,36 @@
|
||||
class AppSettings:
|
||||
def __init__(self, prefix):
|
||||
self.prefix = prefix
|
||||
|
||||
def _setting(self, name, dflt):
|
||||
from allauth.utils import get_setting
|
||||
|
||||
return get_setting(self.prefix + name, dflt)
|
||||
|
||||
@property
|
||||
def ADAPTER(self):
|
||||
return self._setting(
|
||||
"ADAPTER", "allauth.headless.adapter.DefaultHeadlessAdapter"
|
||||
)
|
||||
|
||||
@property
|
||||
def TOKEN_STRATEGY(self):
|
||||
from allauth.utils import import_attribute
|
||||
|
||||
path = self._setting(
|
||||
"TOKEN_STRATEGY", "allauth.headless.tokens.sessions.SessionTokenStrategy"
|
||||
)
|
||||
cls = import_attribute(path)
|
||||
return cls()
|
||||
|
||||
@property
|
||||
def FRONTEND_URLS(self):
|
||||
return self._setting("FRONTEND_URLS", {})
|
||||
|
||||
|
||||
_app_settings = AppSettings("HEADLESS_")
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
# See https://peps.python.org/pep-0562/
|
||||
return getattr(_app_settings, name)
|
||||
@@ -0,0 +1,7 @@
|
||||
from django.apps import AppConfig
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class HeadlessConfig(AppConfig):
|
||||
name = "allauth.headless"
|
||||
verbose_name = _("Headless")
|
||||
@@ -0,0 +1,155 @@
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.account import app_settings as account_settings
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.authentication import get_authentication_records
|
||||
from allauth.account.internal import flows
|
||||
from allauth.account.internal.stagekit import LOGIN_SESSION_KEY
|
||||
from allauth.headless.adapter import get_adapter
|
||||
from allauth.headless.constants import Flow
|
||||
from allauth.headless.internal import authkit
|
||||
from allauth.headless.internal.restkit.response import APIResponse
|
||||
from allauth.mfa import app_settings as mfa_settings
|
||||
|
||||
|
||||
class BaseAuthenticationResponse(APIResponse):
|
||||
def __init__(self, request, user=None, status=None):
|
||||
data = {}
|
||||
if user and user.is_authenticated:
|
||||
adapter = get_adapter()
|
||||
data["user"] = adapter.serialize_user(user)
|
||||
data["methods"] = get_authentication_records(request)
|
||||
status = status or 200
|
||||
else:
|
||||
status = status or 401
|
||||
if status != 200:
|
||||
data["flows"] = self._get_flows(request, user)
|
||||
meta = {
|
||||
"is_authenticated": user and user.is_authenticated,
|
||||
}
|
||||
super().__init__(
|
||||
request,
|
||||
data=data,
|
||||
meta=meta,
|
||||
status=status,
|
||||
)
|
||||
|
||||
def _get_flows(self, request, user):
|
||||
auth_status = authkit.AuthenticationStatus(request)
|
||||
ret = []
|
||||
if user and user.is_authenticated:
|
||||
ret.extend(flows.reauthentication.get_reauthentication_flows(user))
|
||||
else:
|
||||
if not allauth_settings.SOCIALACCOUNT_ONLY:
|
||||
ret.append({"id": Flow.LOGIN})
|
||||
if account_settings.LOGIN_BY_CODE_ENABLED:
|
||||
ret.append({"id": Flow.LOGIN_BY_CODE})
|
||||
if (
|
||||
get_account_adapter().is_open_for_signup(request)
|
||||
and not allauth_settings.SOCIALACCOUNT_ONLY
|
||||
):
|
||||
ret.append({"id": Flow.SIGNUP})
|
||||
if allauth_settings.SOCIALACCOUNT_ENABLED:
|
||||
from allauth.headless.socialaccount.response import (
|
||||
provider_flows,
|
||||
)
|
||||
|
||||
ret.extend(provider_flows(request))
|
||||
if allauth_settings.MFA_ENABLED:
|
||||
if mfa_settings.PASSKEY_LOGIN_ENABLED:
|
||||
ret.append({"id": Flow.MFA_LOGIN_WEBAUTHN})
|
||||
stage_key = None
|
||||
stage = auth_status.get_pending_stage()
|
||||
if stage:
|
||||
stage_key = stage.key
|
||||
else:
|
||||
lsk = request.session.get(LOGIN_SESSION_KEY)
|
||||
if isinstance(lsk, str):
|
||||
stage_key = lsk
|
||||
if stage_key:
|
||||
pending_flow = {"id": stage_key, "is_pending": True}
|
||||
if stage and stage_key == Flow.MFA_AUTHENTICATE:
|
||||
self._enrich_mfa_flow(stage, pending_flow)
|
||||
self._upsert_pending_flow(ret, pending_flow)
|
||||
return ret
|
||||
|
||||
def _upsert_pending_flow(self, flows, pending_flow):
|
||||
flow = next((flow for flow in flows if flow["id"] == pending_flow["id"]), None)
|
||||
if flow:
|
||||
flow.update(pending_flow)
|
||||
else:
|
||||
flows.append(pending_flow)
|
||||
|
||||
def _enrich_mfa_flow(self, stage, flow: dict) -> None:
|
||||
from allauth.mfa.adapter import get_adapter as get_mfa_adapter
|
||||
from allauth.mfa.models import Authenticator
|
||||
|
||||
adapter = get_mfa_adapter()
|
||||
types = []
|
||||
for typ in Authenticator.Type:
|
||||
if adapter.is_mfa_enabled(stage.login.user, types=[typ]):
|
||||
types.append(typ)
|
||||
flow["types"] = types
|
||||
|
||||
|
||||
class AuthenticationResponse(BaseAuthenticationResponse):
|
||||
def __init__(self, request):
|
||||
super().__init__(request, user=request.user)
|
||||
|
||||
|
||||
class ReauthenticationResponse(BaseAuthenticationResponse):
|
||||
def __init__(self, request):
|
||||
super().__init__(request, user=request.user, status=401)
|
||||
|
||||
|
||||
class UnauthorizedResponse(BaseAuthenticationResponse):
|
||||
def __init__(self, request, status=401):
|
||||
super().__init__(request, user=None, status=status)
|
||||
|
||||
|
||||
class ForbiddenResponse(APIResponse):
|
||||
def __init__(self, request):
|
||||
super().__init__(request, status=403)
|
||||
|
||||
|
||||
class ConflictResponse(APIResponse):
|
||||
def __init__(self, request):
|
||||
super().__init__(request, status=409)
|
||||
|
||||
|
||||
def get_config_data(request):
|
||||
data = {
|
||||
"authentication_method": account_settings.AUTHENTICATION_METHOD,
|
||||
"is_open_for_signup": get_account_adapter().is_open_for_signup(request),
|
||||
"email_verification_by_code_enabled": account_settings.EMAIL_VERIFICATION_BY_CODE_ENABLED,
|
||||
"login_by_code_enabled": account_settings.LOGIN_BY_CODE_ENABLED,
|
||||
}
|
||||
return {"account": data}
|
||||
|
||||
|
||||
class ConfigResponse(APIResponse):
|
||||
def __init__(self, request):
|
||||
data = get_config_data(request)
|
||||
if allauth_settings.SOCIALACCOUNT_ENABLED:
|
||||
from allauth.headless.socialaccount.response import (
|
||||
get_config_data as get_socialaccount_config_data,
|
||||
)
|
||||
|
||||
data.update(get_socialaccount_config_data(request))
|
||||
if allauth_settings.MFA_ENABLED:
|
||||
from allauth.headless.mfa.response import (
|
||||
get_config_data as get_mfa_config_data,
|
||||
)
|
||||
|
||||
data.update(get_mfa_config_data(request))
|
||||
if allauth_settings.USERSESSIONS_ENABLED:
|
||||
from allauth.headless.usersessions.response import (
|
||||
get_config_data as get_usersessions_config_data,
|
||||
)
|
||||
|
||||
data.update(get_usersessions_config_data(request))
|
||||
return super().__init__(request, data=data)
|
||||
|
||||
|
||||
class RateLimitResponse(APIResponse):
|
||||
def __init__(self, request):
|
||||
super().__init__(request, status=429)
|
||||
@@ -0,0 +1,10 @@
|
||||
def test_config(db, client, headless_reverse):
|
||||
resp = client.get(headless_reverse("headless:config"))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert set(data["data"].keys()) == {
|
||||
"account",
|
||||
"mfa",
|
||||
"socialaccount",
|
||||
"usersessions",
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
from django.urls import path
|
||||
|
||||
from allauth.headless.base import views
|
||||
|
||||
|
||||
def build_urlpatterns(client):
|
||||
return [
|
||||
path(
|
||||
"config",
|
||||
views.ConfigView.as_api_view(client=client),
|
||||
name="config",
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,62 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
from django.utils.decorators import classonlymethod
|
||||
|
||||
from allauth.account.stages import LoginStage, LoginStageController
|
||||
from allauth.core.exceptions import ReauthenticationRequired
|
||||
from allauth.headless.base import response
|
||||
from allauth.headless.constants import Client
|
||||
from allauth.headless.internal import decorators
|
||||
from allauth.headless.internal.restkit.views import RESTView
|
||||
|
||||
|
||||
class APIView(RESTView):
|
||||
client = None
|
||||
|
||||
@classonlymethod
|
||||
def as_api_view(cls, **initkwargs):
|
||||
view_func = cls.as_view(**initkwargs)
|
||||
if initkwargs["client"] == Client.APP:
|
||||
view_func = decorators.app_view(view_func)
|
||||
else:
|
||||
view_func = decorators.browser_view(view_func)
|
||||
return view_func
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
try:
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
except ReauthenticationRequired:
|
||||
return response.ReauthenticationResponse(self.request)
|
||||
|
||||
|
||||
class AuthenticationStageAPIView(APIView):
|
||||
stage_class: Optional[Type[LoginStage]] = None
|
||||
|
||||
def handle(self, request, *args, **kwargs):
|
||||
self.stage = LoginStageController.enter(request, self.stage_class.key)
|
||||
if not self.stage:
|
||||
return response.UnauthorizedResponse(request)
|
||||
return super().handle(request, *args, **kwargs)
|
||||
|
||||
def respond_stage_error(self):
|
||||
return response.UnauthorizedResponse(self.request)
|
||||
|
||||
def respond_next_stage(self):
|
||||
self.stage.exit()
|
||||
return response.AuthenticationResponse(self.request)
|
||||
|
||||
|
||||
class AuthenticatedAPIView(APIView):
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
if not request.user.is_authenticated:
|
||||
return response.AuthenticationResponse(request)
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
|
||||
class ConfigView(APIView):
|
||||
def get(self, request, *args, **kwargs):
|
||||
"""
|
||||
The frontend queries (GET) this endpoint, expecting to receive
|
||||
either a 401 if no user is authenticated, or user information.
|
||||
"""
|
||||
return response.ConfigResponse(request)
|
||||
@@ -0,0 +1,63 @@
|
||||
from django.test.client import Client
|
||||
from django.urls import reverse
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(params=["app", "browser"])
|
||||
def headless_client(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def headless_reverse(headless_client):
|
||||
def rev(viewname, **kwargs):
|
||||
viewname = viewname.replace("headless:", f"headless:{headless_client}:")
|
||||
return reverse(viewname, **kwargs)
|
||||
|
||||
return rev
|
||||
|
||||
|
||||
class AppClient(Client):
|
||||
session_token = None
|
||||
|
||||
def generic(self, *args, **kwargs):
|
||||
if self.session_token:
|
||||
kwargs["HTTP_X_SESSION_TOKEN"] = self.session_token
|
||||
resp = super().generic(*args, **kwargs)
|
||||
if resp["content-type"] == "application/json":
|
||||
data = resp.json()
|
||||
session_token = data.get("meta", {}).get("session_token")
|
||||
if session_token:
|
||||
self.session_token = session_token
|
||||
return resp
|
||||
|
||||
def force_login(self, user):
|
||||
ret = super().force_login(user)
|
||||
self.session_token = self.session.session_key
|
||||
return ret
|
||||
|
||||
def headless_session(self):
|
||||
from allauth.headless.internal import sessionkit
|
||||
|
||||
return sessionkit.session_store(self.session_token)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_client():
|
||||
return AppClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(headless_client):
|
||||
if headless_client == "browser":
|
||||
client = Client()
|
||||
client.headless_session = lambda: client.session
|
||||
return client
|
||||
return AppClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_client(client, user):
|
||||
client.force_login(user)
|
||||
return client
|
||||
@@ -0,0 +1,23 @@
|
||||
from enum import Enum
|
||||
|
||||
from allauth.account.stages import EmailVerificationStage, LoginByCodeStage
|
||||
|
||||
|
||||
class Client(str, Enum):
|
||||
APP = "app"
|
||||
BROWSER = "browser"
|
||||
|
||||
|
||||
class Flow(str, Enum):
|
||||
VERIFY_EMAIL = EmailVerificationStage.key
|
||||
LOGIN = "login"
|
||||
LOGIN_BY_CODE = LoginByCodeStage.key
|
||||
SIGNUP = "signup"
|
||||
PROVIDER_REDIRECT = "provider_redirect"
|
||||
PROVIDER_SIGNUP = "provider_signup"
|
||||
PROVIDER_TOKEN = "provider_token"
|
||||
REAUTHENTICATE = "reauthenticate"
|
||||
MFA_REAUTHENTICATE = "mfa_reauthenticate"
|
||||
MFA_AUTHENTICATE = "mfa_authenticate" # NOTE: Equal to `allauth.mfa.stages.AuthenticationStage.key`
|
||||
MFA_LOGIN_WEBAUTHN = "mfa_login_webauthn"
|
||||
MFA_SIGNUP_WEBAUTHN = "mfa_signup_webauthn"
|
||||
@@ -0,0 +1,85 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from django.utils.functional import SimpleLazyObject, empty
|
||||
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.account.internal.stagekit import get_pending_stage
|
||||
from allauth.core.exceptions import ImmediateHttpResponse
|
||||
from allauth.headless import app_settings
|
||||
from allauth.headless.constants import Client
|
||||
from allauth.headless.internal import sessionkit
|
||||
|
||||
|
||||
class AuthenticationStatus:
|
||||
def __init__(self, request):
|
||||
self.request = request
|
||||
|
||||
@property
|
||||
def is_authenticated(self):
|
||||
return self.request.user.is_authenticated
|
||||
|
||||
def get_pending_stage(self):
|
||||
return get_pending_stage(self.request)
|
||||
|
||||
@property
|
||||
def has_pending_signup(self):
|
||||
if not allauth_settings.SOCIALACCOUNT_ENABLED:
|
||||
return False
|
||||
from allauth.socialaccount.internal import flows
|
||||
|
||||
return bool(flows.signup.get_pending_signup(self.request))
|
||||
|
||||
|
||||
def purge_request_user_cache(request):
|
||||
for attr in ["_cached_user", "_acached_user"]:
|
||||
if hasattr(request, attr):
|
||||
delattr(request, attr)
|
||||
if isinstance(request.user, SimpleLazyObject):
|
||||
request.user._wrapped = empty
|
||||
|
||||
|
||||
@contextmanager
|
||||
def authentication_context(request):
|
||||
from allauth.headless.base.response import UnauthorizedResponse
|
||||
|
||||
old_user = request.user
|
||||
old_session = request.session
|
||||
try:
|
||||
request.session = sessionkit.new_session()
|
||||
purge_request_user_cache(request)
|
||||
|
||||
strategy = app_settings.TOKEN_STRATEGY
|
||||
session_token = strategy.get_session_token(request)
|
||||
if session_token:
|
||||
session = strategy.lookup_session(session_token)
|
||||
if not session:
|
||||
raise ImmediateHttpResponse(UnauthorizedResponse(request, status=410))
|
||||
request.session = session
|
||||
purge_request_user_cache(request)
|
||||
request.allauth.headless._pre_user = request.user
|
||||
# request.user is lazy -- force evaluation
|
||||
request.allauth.headless._pre_user.pk
|
||||
yield
|
||||
finally:
|
||||
if request.session.modified and not request.session.is_empty():
|
||||
request.session.save()
|
||||
request.user = old_user
|
||||
request.session = old_session
|
||||
# e.g. logging in calls csrf `rotate_token()` -- this prevents setting a new CSRF cookie.
|
||||
request.META["CSRF_COOKIE_NEEDS_UPDATE"] = False
|
||||
|
||||
|
||||
def expose_access_token(request) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Determines if a new access token needs to be exposed.
|
||||
"""
|
||||
if request.allauth.headless.client != Client.APP:
|
||||
return None
|
||||
if not request.user.is_authenticated:
|
||||
return None
|
||||
pre_user = request.allauth.headless._pre_user
|
||||
if pre_user.is_authenticated and pre_user.pk == request.user.pk:
|
||||
return None
|
||||
strategy = app_settings.TOKEN_STRATEGY
|
||||
return strategy.create_access_token_payload(request)
|
||||
@@ -0,0 +1,54 @@
|
||||
from functools import wraps
|
||||
from types import SimpleNamespace
|
||||
|
||||
from django.middleware.csrf import get_token
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
|
||||
from allauth.account.internal.decorators import login_not_required
|
||||
from allauth.headless.constants import Client
|
||||
from allauth.headless.internal import authkit
|
||||
|
||||
|
||||
def mark_request_as_headless(request, client):
|
||||
request.allauth.headless = SimpleNamespace()
|
||||
request.allauth.headless.client = client
|
||||
|
||||
|
||||
def app_view(
|
||||
function=None,
|
||||
):
|
||||
def decorator(view_func):
|
||||
@login_not_required
|
||||
@wraps(view_func)
|
||||
def _wrapper_view(request, *args, **kwargs):
|
||||
mark_request_as_headless(request, Client.APP)
|
||||
with authkit.authentication_context(request):
|
||||
return view_func(request, *args, **kwargs)
|
||||
|
||||
return _wrapper_view
|
||||
|
||||
ret = decorator
|
||||
if function:
|
||||
ret = decorator(function)
|
||||
return csrf_exempt(ret)
|
||||
|
||||
|
||||
def browser_view(
|
||||
function=None,
|
||||
):
|
||||
def decorator(view_func):
|
||||
@login_not_required
|
||||
@wraps(view_func)
|
||||
def _wrapper_view(request, *args, **kwargs):
|
||||
mark_request_as_headless(request, Client.BROWSER)
|
||||
# Needed -- so that the CSRF token is set in the response for the
|
||||
# frontend to pick up.
|
||||
get_token(request)
|
||||
return view_func(request, *args, **kwargs)
|
||||
|
||||
return _wrapper_view
|
||||
|
||||
ret = decorator
|
||||
if function:
|
||||
ret = decorator(function)
|
||||
return ret
|
||||
@@ -0,0 +1,25 @@
|
||||
from django.forms import (
|
||||
BooleanField,
|
||||
CharField,
|
||||
ChoiceField,
|
||||
EmailField,
|
||||
Field,
|
||||
Form,
|
||||
ModelChoiceField,
|
||||
ModelMultipleChoiceField,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Field",
|
||||
"CharField",
|
||||
"ChoiceField",
|
||||
"EmailField",
|
||||
"BooleanField",
|
||||
"ModelMultipleChoiceField",
|
||||
"ModelChoiceField",
|
||||
]
|
||||
|
||||
|
||||
class Input(Form):
|
||||
pass
|
||||
@@ -0,0 +1,55 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from django.forms.utils import ErrorList
|
||||
from django.http import JsonResponse
|
||||
from django.utils.cache import add_never_cache_headers
|
||||
|
||||
from allauth.headless.internal import authkit, sessionkit
|
||||
|
||||
|
||||
class APIResponse(JsonResponse):
|
||||
def __init__(
|
||||
self,
|
||||
request,
|
||||
errors=None,
|
||||
data=None,
|
||||
meta: Optional[Dict] = None,
|
||||
status: int = 200,
|
||||
):
|
||||
d: Dict[str, Any] = {"status": status}
|
||||
if data is not None:
|
||||
d["data"] = data
|
||||
meta = self._add_session_meta(request, meta)
|
||||
if meta is not None:
|
||||
d["meta"] = meta
|
||||
if errors:
|
||||
d["errors"] = errors
|
||||
super().__init__(d, status=status)
|
||||
add_never_cache_headers(self)
|
||||
|
||||
def _add_session_meta(self, request, meta: Optional[Dict]) -> Optional[Dict]:
|
||||
session_token = sessionkit.expose_session_token(request)
|
||||
access_token_payload = authkit.expose_access_token(request)
|
||||
if session_token:
|
||||
meta = meta or {}
|
||||
meta["session_token"] = session_token
|
||||
if access_token_payload:
|
||||
meta = meta or {}
|
||||
meta.update(access_token_payload)
|
||||
return meta
|
||||
|
||||
|
||||
class ErrorResponse(APIResponse):
|
||||
def __init__(self, request, exception=None, input=None, status=400):
|
||||
errors = []
|
||||
if exception is not None:
|
||||
error_datas = ErrorList(exception.error_list).get_json_data()
|
||||
errors.extend(error_datas)
|
||||
if input is not None:
|
||||
for field, error_list in input.errors.items():
|
||||
error_datas = error_list.get_json_data()
|
||||
for error_data in error_datas:
|
||||
if field != "__all__":
|
||||
error_data["param"] = field
|
||||
errors.extend(error_datas)
|
||||
super().__init__(request, status=status, errors=errors)
|
||||
@@ -0,0 +1,53 @@
|
||||
import json
|
||||
from typing import Dict, Optional, Type, Union
|
||||
|
||||
from django.http import HttpResponseBadRequest
|
||||
from django.views.generic import View
|
||||
|
||||
from allauth.core.exceptions import ImmediateHttpResponse
|
||||
from allauth.headless.internal.restkit.inputs import Input
|
||||
from allauth.headless.internal.restkit.response import ErrorResponse
|
||||
|
||||
|
||||
class RESTView(View):
|
||||
input_class: Union[Optional[Dict[str, Type[Input]]], Type[Input]] = None
|
||||
handle_json_input = True
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
return self.handle(request, *args, **kwargs)
|
||||
|
||||
def handle(self, request, *args, **kwargs):
|
||||
if self.handle_json_input and request.method != "GET":
|
||||
self.data = self._parse_json(request)
|
||||
response = self.handle_input(self.data)
|
||||
if response:
|
||||
return response
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {}
|
||||
|
||||
def handle_input(self, data):
|
||||
input_class = self.input_class
|
||||
if isinstance(input_class, dict):
|
||||
input_class = input_class.get(self.request.method)
|
||||
if not input_class:
|
||||
return
|
||||
input_kwargs = self.get_input_kwargs()
|
||||
if data is None:
|
||||
# Make form bound on empty POST
|
||||
data = {}
|
||||
self.input = input_class(data=data, **input_kwargs)
|
||||
if not self.input.is_valid():
|
||||
return self.handle_invalid_input(self.input)
|
||||
|
||||
def handle_invalid_input(self, input):
|
||||
return ErrorResponse(self.request, input=input)
|
||||
|
||||
def _parse_json(self, request):
|
||||
if request.method == "GET" or not request.body:
|
||||
return
|
||||
try:
|
||||
return json.loads(request.body.decode("utf8"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
raise ImmediateHttpResponse(response=HttpResponseBadRequest())
|
||||
@@ -0,0 +1,28 @@
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from allauth.headless import app_settings
|
||||
from allauth.headless.constants import Client
|
||||
|
||||
|
||||
def session_store(session_key=None):
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
return engine.SessionStore(session_key=session_key)
|
||||
|
||||
|
||||
def new_session():
|
||||
return session_store()
|
||||
|
||||
|
||||
def expose_session_token(request):
|
||||
if request.allauth.headless.client != Client.APP:
|
||||
return
|
||||
strategy = app_settings.TOKEN_STRATEGY
|
||||
hdr_token = strategy.get_session_token(request)
|
||||
modified = request.session.modified
|
||||
empty = request.session.is_empty()
|
||||
if modified and not empty:
|
||||
new_token = strategy.create_session_token(request)
|
||||
if not hdr_token or hdr_token != new_token:
|
||||
return new_token
|
||||
@@ -0,0 +1,28 @@
|
||||
from django.contrib.auth import (
|
||||
BACKEND_SESSION_KEY,
|
||||
HASH_SESSION_KEY,
|
||||
SESSION_KEY,
|
||||
)
|
||||
from django.contrib.auth.middleware import AuthenticationMiddleware
|
||||
from django.contrib.sessions.middleware import SessionMiddleware
|
||||
from django.http import HttpResponse
|
||||
|
||||
from allauth.headless.internal.authkit import purge_request_user_cache
|
||||
|
||||
|
||||
def test_purge_request_user_cache(rf, user):
|
||||
request = rf.get("/")
|
||||
smw = SessionMiddleware(lambda request: HttpResponse())
|
||||
smw(request)
|
||||
amw = AuthenticationMiddleware(lambda request: HttpResponse())
|
||||
amw(request)
|
||||
assert request.user.is_anonymous
|
||||
assert not request.user.pk
|
||||
purge_request_user_cache(request)
|
||||
request.session[SESSION_KEY] = user.pk
|
||||
request.session[BACKEND_SESSION_KEY] = (
|
||||
"allauth.account.auth_backends.AuthenticationBackend"
|
||||
)
|
||||
request.session[HASH_SESSION_KEY] = user.get_session_auth_hash()
|
||||
assert request.user.is_authenticated
|
||||
assert request.user.pk == user.pk
|
||||
@@ -0,0 +1,74 @@
|
||||
from allauth.account.forms import BaseSignupForm
|
||||
from allauth.headless.internal.restkit import inputs
|
||||
from allauth.mfa.base.forms import AuthenticateForm
|
||||
from allauth.mfa.models import Authenticator
|
||||
from allauth.mfa.recovery_codes.forms import GenerateRecoveryCodesForm
|
||||
from allauth.mfa.totp.forms import ActivateTOTPForm
|
||||
from allauth.mfa.webauthn.forms import (
|
||||
AddWebAuthnForm,
|
||||
AuthenticateWebAuthnForm,
|
||||
LoginWebAuthnForm,
|
||||
ReauthenticateWebAuthnForm,
|
||||
SignupWebAuthnForm,
|
||||
)
|
||||
|
||||
|
||||
class AuthenticateInput(AuthenticateForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class ActivateTOTPInput(ActivateTOTPForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class GenerateRecoveryCodesInput(GenerateRecoveryCodesForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class AddWebAuthnInput(AddWebAuthnForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class CreateWebAuthnInput(SignupWebAuthnForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class UpdateWebAuthnInput(inputs.Input):
|
||||
id = inputs.ModelChoiceField(queryset=Authenticator.objects.none())
|
||||
name = inputs.CharField(required=True, max_length=100)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.user = kwargs.pop("user")
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields["id"].queryset = Authenticator.objects.filter(
|
||||
user=self.user, type=Authenticator.Type.WEBAUTHN
|
||||
)
|
||||
|
||||
|
||||
class DeleteWebAuthnInput(inputs.Input):
|
||||
authenticators = inputs.ModelMultipleChoiceField(
|
||||
queryset=Authenticator.objects.none()
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.user = kwargs.pop("user")
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields["authenticators"].queryset = Authenticator.objects.filter(
|
||||
user=self.user, type=Authenticator.Type.WEBAUTHN
|
||||
)
|
||||
|
||||
|
||||
class ReauthenticateWebAuthnInput(ReauthenticateWebAuthnForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticateWebAuthnInput(AuthenticateWebAuthnForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class LoginWebAuthnInput(LoginWebAuthnForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class SignupWebAuthnInput(BaseSignupForm, inputs.Input):
|
||||
pass
|
||||
@@ -0,0 +1,104 @@
|
||||
from allauth.headless.base.response import APIResponse
|
||||
from allauth.mfa import app_settings as mfa_settings
|
||||
|
||||
|
||||
def get_config_data(request):
|
||||
data = {
|
||||
"mfa": {
|
||||
"supported_types": mfa_settings.SUPPORTED_TYPES,
|
||||
"passkey_login_enabled": mfa_settings.PASSKEY_LOGIN_ENABLED,
|
||||
}
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
def _authenticator_data(authenticator, sensitive=False):
|
||||
data = {
|
||||
"type": authenticator.type,
|
||||
"created_at": authenticator.created_at.timestamp(),
|
||||
"last_used_at": (
|
||||
authenticator.last_used_at.timestamp()
|
||||
if authenticator.last_used_at
|
||||
else None
|
||||
),
|
||||
}
|
||||
if authenticator.type == authenticator.Type.TOTP:
|
||||
pass
|
||||
elif authenticator.type == authenticator.Type.RECOVERY_CODES:
|
||||
wrapped = authenticator.wrap()
|
||||
unused_codes = wrapped.get_unused_codes()
|
||||
data.update(
|
||||
{
|
||||
"total_code_count": len(wrapped.generate_codes()),
|
||||
"unused_code_count": len(unused_codes),
|
||||
}
|
||||
)
|
||||
if sensitive:
|
||||
data["unused_codes"] = unused_codes
|
||||
elif authenticator.type == authenticator.Type.WEBAUTHN:
|
||||
wrapped = authenticator.wrap()
|
||||
data["id"] = authenticator.pk
|
||||
data["name"] = wrapped.name
|
||||
passwordless = wrapped.is_passwordless
|
||||
if passwordless is not None:
|
||||
data["is_passwordless"] = passwordless
|
||||
return data
|
||||
|
||||
|
||||
class AuthenticatorDeletedResponse(APIResponse):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticatorsDeletedResponse(APIResponse):
|
||||
pass
|
||||
|
||||
|
||||
class TOTPNotFoundResponse(APIResponse):
|
||||
def __init__(self, request, secret, totp_url):
|
||||
super().__init__(
|
||||
request,
|
||||
meta={
|
||||
"secret": secret,
|
||||
"totp_url": totp_url,
|
||||
},
|
||||
status=404,
|
||||
)
|
||||
|
||||
|
||||
class TOTPResponse(APIResponse):
|
||||
def __init__(self, request, authenticator):
|
||||
data = _authenticator_data(authenticator)
|
||||
super().__init__(request, data=data)
|
||||
|
||||
|
||||
class AuthenticatorsResponse(APIResponse):
|
||||
def __init__(self, request, authenticators):
|
||||
data = [_authenticator_data(authenticator) for authenticator in authenticators]
|
||||
super().__init__(request, data=data)
|
||||
|
||||
|
||||
class AuthenticatorResponse(APIResponse):
|
||||
def __init__(self, request, authenticator, meta=None):
|
||||
data = _authenticator_data(authenticator)
|
||||
super().__init__(request, data=data, meta=meta)
|
||||
|
||||
|
||||
class RecoveryCodesNotFoundResponse(APIResponse):
|
||||
def __init__(self, request):
|
||||
super().__init__(request, status=404)
|
||||
|
||||
|
||||
class RecoveryCodesResponse(APIResponse):
|
||||
def __init__(self, request, authenticator):
|
||||
data = _authenticator_data(authenticator, sensitive=True)
|
||||
super().__init__(request, data=data)
|
||||
|
||||
|
||||
class AddWebAuthnResponse(APIResponse):
|
||||
def __init__(self, request, registration_data):
|
||||
super().__init__(request, data={"creation_options": registration_data})
|
||||
|
||||
|
||||
class WebAuthnRequestOptionsResponse(APIResponse):
|
||||
def __init__(self, request, request_options):
|
||||
super().__init__(request, data={"request_options": request_options})
|
||||
@@ -0,0 +1,60 @@
|
||||
from allauth.mfa.models import Authenticator
|
||||
|
||||
|
||||
def test_get_recovery_codes_requires_reauth(
|
||||
auth_client, user_with_recovery_codes, headless_reverse
|
||||
):
|
||||
rc = Authenticator.objects.get(
|
||||
type=Authenticator.Type.RECOVERY_CODES, user=user_with_recovery_codes
|
||||
)
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
assert data["meta"]["is_authenticated"]
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:mfa:reauthenticate"),
|
||||
data={"code": rc.wrap().get_unused_codes()[0]},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_get_recovery_codes(
|
||||
auth_client,
|
||||
user_with_recovery_codes,
|
||||
headless_reverse,
|
||||
reauthentication_bypass,
|
||||
):
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["data"]["type"] == "recovery_codes"
|
||||
assert len(data["data"]["unused_codes"]) == 10
|
||||
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:authenticators"))
|
||||
data = resp.json()
|
||||
assert len(data["data"]) == 2
|
||||
rc = [autor for autor in data["data"] if autor["type"] == "recovery_codes"][0]
|
||||
assert "unused_codes" not in rc
|
||||
|
||||
|
||||
def test_generate_recovery_codes(
|
||||
auth_client,
|
||||
user_with_totp,
|
||||
headless_reverse,
|
||||
reauthentication_bypass,
|
||||
):
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
|
||||
assert resp.status_code == 404
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:mfa:manage_recovery_codes"),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["data"]["type"] == "recovery_codes"
|
||||
assert len(data["data"]["unused_codes"]) == 10
|
||||
@@ -0,0 +1,78 @@
|
||||
import pytest
|
||||
|
||||
from allauth.mfa.models import Authenticator
|
||||
|
||||
|
||||
@pytest.mark.parametrize("email_verified", [False, True])
|
||||
def test_get_totp_not_active(auth_client, user, headless_reverse, email_verified):
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_totp"))
|
||||
if email_verified:
|
||||
assert resp.status_code == 404
|
||||
data = resp.json()
|
||||
assert len(data["meta"]["secret"]) == 32
|
||||
assert len(data["meta"]["totp_url"]) == 145
|
||||
else:
|
||||
assert resp.status_code == 409
|
||||
assert resp.json() == {
|
||||
"status": 409,
|
||||
"errors": [
|
||||
{
|
||||
"message": "You cannot activate two-factor authentication until you have verified your email address.",
|
||||
"code": "unverified_email",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_get_totp(
|
||||
auth_client,
|
||||
user_with_totp,
|
||||
headless_reverse,
|
||||
):
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_totp"))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["data"]["type"] == "totp"
|
||||
assert isinstance(data["data"]["created_at"], float)
|
||||
|
||||
|
||||
def test_deactivate_totp(
|
||||
auth_client,
|
||||
user_with_totp,
|
||||
headless_reverse,
|
||||
reauthentication_bypass,
|
||||
):
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.delete(headless_reverse("headless:mfa:manage_totp"))
|
||||
assert resp.status_code == 200
|
||||
assert not Authenticator.objects.filter(user=user_with_totp).exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("email_verified", [False, True])
|
||||
def test_activate_totp(
|
||||
auth_client,
|
||||
user,
|
||||
headless_reverse,
|
||||
reauthentication_bypass,
|
||||
settings,
|
||||
totp_validation_bypass,
|
||||
email_verified,
|
||||
):
|
||||
with reauthentication_bypass():
|
||||
with totp_validation_bypass():
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:mfa:manage_totp"),
|
||||
data={"code": "42"},
|
||||
content_type="application/json",
|
||||
)
|
||||
if email_verified:
|
||||
assert resp.status_code == 200
|
||||
assert Authenticator.objects.filter(
|
||||
user=user, type=Authenticator.Type.TOTP
|
||||
).exists()
|
||||
data = resp.json()
|
||||
assert data["data"]["type"] == "totp"
|
||||
assert isinstance(data["data"]["created_at"], float)
|
||||
assert data["data"]["last_used_at"] is None
|
||||
else:
|
||||
assert resp.status_code == 400
|
||||
@@ -0,0 +1,115 @@
|
||||
from allauth.account.models import EmailAddress, get_emailconfirmation_model
|
||||
from allauth.headless.constants import Flow
|
||||
|
||||
|
||||
def test_auth_unverified_email_and_mfa(
|
||||
client,
|
||||
user_factory,
|
||||
password_factory,
|
||||
settings,
|
||||
totp_validation_bypass,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
|
||||
password = password_factory()
|
||||
user = user_factory(email_verified=False, password=password, with_totp=True)
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
assert [f for f in data["data"]["flows"] if f["id"] == Flow.VERIFY_EMAIL][0][
|
||||
"is_pending"
|
||||
]
|
||||
emailaddress = EmailAddress.objects.filter(user=user, verified=False).get()
|
||||
key = get_emailconfirmation_model().create(emailaddress).key
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:verify_email"),
|
||||
data={"key": key},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
flows = [
|
||||
{"id": "login"},
|
||||
{"id": "login_by_code"},
|
||||
{"id": "signup"},
|
||||
]
|
||||
if headless_client == "browser":
|
||||
flows.append(
|
||||
{
|
||||
"id": "provider_redirect",
|
||||
"providers": ["dummy", "openid_connect", "openid_connect"],
|
||||
}
|
||||
)
|
||||
flows.append({"id": "provider_token", "providers": ["dummy"]})
|
||||
flows.append({"id": "mfa_login_webauthn"})
|
||||
flows.append(
|
||||
{
|
||||
"id": "mfa_authenticate",
|
||||
"is_pending": True,
|
||||
"types": ["totp"],
|
||||
}
|
||||
)
|
||||
|
||||
assert resp.json() == {
|
||||
"data": {"flows": flows},
|
||||
"meta": {"is_authenticated": False},
|
||||
"status": 401,
|
||||
}
|
||||
resp = client.post(
|
||||
headless_reverse("headless:mfa:authenticate"),
|
||||
data={"code": "bad"},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{"message": "Incorrect code.", "code": "incorrect_code", "param": "code"}
|
||||
],
|
||||
}
|
||||
|
||||
with totp_validation_bypass():
|
||||
resp = client.post(
|
||||
headless_reverse("headless:mfa:authenticate"),
|
||||
data={"code": "bad"},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_dangling_mfa_is_logged_out(
|
||||
client,
|
||||
user_with_totp,
|
||||
password_factory,
|
||||
settings,
|
||||
totp_validation_bypass,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
user_password,
|
||||
):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user_with_totp.email,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
flow = [f for f in data["data"]["flows"] if f["id"] == Flow.MFA_AUTHENTICATE][0]
|
||||
assert flow["is_pending"]
|
||||
assert flow["types"] == ["totp"]
|
||||
resp = client.delete(headless_reverse("headless:account:current_session"))
|
||||
data = resp.json()
|
||||
assert resp.status_code == 401
|
||||
assert all(not f.get("is_pending") for f in data["data"]["flows"])
|
||||
@@ -0,0 +1,223 @@
|
||||
from unittest.mock import ANY
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
import pytest
|
||||
|
||||
from allauth.account.authentication import AUTHENTICATION_METHODS_SESSION_KEY
|
||||
from allauth.headless.constants import Flow
|
||||
from allauth.mfa.models import Authenticator
|
||||
|
||||
|
||||
def test_passkey_login(
|
||||
client, passkey, webauthn_authentication_bypass, headless_reverse
|
||||
):
|
||||
with webauthn_authentication_bypass(passkey) as credential:
|
||||
resp = client.get(headless_reverse("headless:mfa:login_webauthn"))
|
||||
assert "request_options" in resp.json()["data"]
|
||||
resp = client.post(
|
||||
headless_reverse("headless:mfa:login_webauthn"),
|
||||
data={"credential": credential},
|
||||
content_type="application/json",
|
||||
)
|
||||
data = resp.json()
|
||||
assert data["data"]["user"]["id"] == passkey.user_id
|
||||
|
||||
|
||||
def test_passkey_login_get_options(client, headless_client, headless_reverse, db):
|
||||
resp = client.get(headless_reverse("headless:mfa:login_webauthn"))
|
||||
data = resp.json()
|
||||
meta = {}
|
||||
if headless_client == "app":
|
||||
meta = {
|
||||
"meta": {"session_token": ANY},
|
||||
}
|
||||
assert data == {
|
||||
"status": 200,
|
||||
"data": {"request_options": {"publicKey": ANY}},
|
||||
**meta,
|
||||
}
|
||||
|
||||
|
||||
def test_reauthenticate(
|
||||
auth_client,
|
||||
passkey,
|
||||
user_with_recovery_codes,
|
||||
webauthn_authentication_bypass,
|
||||
headless_reverse,
|
||||
):
|
||||
# View recovery codes, confirm webauthn reauthentication is an option
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
|
||||
assert resp.status_code == 401
|
||||
assert Flow.MFA_REAUTHENTICATE in [
|
||||
flow["id"] for flow in resp.json()["data"]["flows"]
|
||||
]
|
||||
|
||||
# Get request options
|
||||
with webauthn_authentication_bypass(passkey):
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:reauthenticate_webauthn"))
|
||||
data = resp.json()
|
||||
assert data["status"] == 200
|
||||
assert data["data"]["request_options"] == ANY
|
||||
|
||||
# Reauthenticate
|
||||
with webauthn_authentication_bypass(passkey) as credential:
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:mfa:reauthenticate_webauthn"),
|
||||
data={"credential": credential},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_recovery_codes"))
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_update_authenticator(
|
||||
auth_client, headless_reverse, passkey, reauthentication_bypass
|
||||
):
|
||||
data = {"id": passkey.pk, "name": "Renamed!"}
|
||||
resp = auth_client.put(
|
||||
headless_reverse("headless:mfa:manage_webauthn"),
|
||||
data=data,
|
||||
content_type="application/json",
|
||||
)
|
||||
# Reauthentication required
|
||||
assert resp.status_code == 401
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.put(
|
||||
headless_reverse("headless:mfa:manage_webauthn"),
|
||||
data=data,
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
passkey.refresh_from_db()
|
||||
assert passkey.wrap().name == "Renamed!"
|
||||
|
||||
|
||||
def test_delete_authenticator(
|
||||
auth_client, headless_reverse, passkey, reauthentication_bypass
|
||||
):
|
||||
data = {"authenticators": [passkey.pk]}
|
||||
resp = auth_client.delete(
|
||||
headless_reverse("headless:mfa:manage_webauthn"),
|
||||
data=data,
|
||||
content_type="application/json",
|
||||
)
|
||||
# Reauthentication required
|
||||
assert resp.status_code == 401
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.delete(
|
||||
headless_reverse("headless:mfa:manage_webauthn"),
|
||||
data=data,
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert not Authenticator.objects.filter(pk=passkey.pk).exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("email_verified", [False, True])
|
||||
def test_add_authenticator(
|
||||
user,
|
||||
auth_client,
|
||||
headless_reverse,
|
||||
webauthn_registration_bypass,
|
||||
reauthentication_bypass,
|
||||
email_verified,
|
||||
):
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_webauthn"))
|
||||
# Reauthentication required
|
||||
assert resp.status_code == 401 if email_verified else 409
|
||||
|
||||
with reauthentication_bypass():
|
||||
resp = auth_client.get(headless_reverse("headless:mfa:manage_webauthn"))
|
||||
if email_verified:
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["data"]["creation_options"] == ANY
|
||||
else:
|
||||
assert resp.status_code == 409
|
||||
|
||||
with webauthn_registration_bypass(user, False) as credential:
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:mfa:manage_webauthn"),
|
||||
data={"credential": credential},
|
||||
content_type="application/json",
|
||||
)
|
||||
webauthn_count = Authenticator.objects.filter(
|
||||
type=Authenticator.Type.WEBAUTHN, user=user
|
||||
).count()
|
||||
if email_verified:
|
||||
assert resp.status_code == 200
|
||||
assert webauthn_count == 1
|
||||
else:
|
||||
assert resp.status_code == 409
|
||||
assert webauthn_count == 0
|
||||
|
||||
|
||||
def test_2fa_login(
|
||||
client,
|
||||
user,
|
||||
user_password,
|
||||
passkey,
|
||||
webauthn_authentication_bypass,
|
||||
headless_reverse,
|
||||
):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"username": user.username,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
data = resp.json()
|
||||
pending_flows = [f for f in data["data"]["flows"] if f.get("is_pending")]
|
||||
assert len(pending_flows) == 1
|
||||
pending_flow = pending_flows[0]
|
||||
assert pending_flow == {
|
||||
"id": "mfa_authenticate",
|
||||
"is_pending": True,
|
||||
"types": ["webauthn"],
|
||||
}
|
||||
with webauthn_authentication_bypass(passkey) as credential:
|
||||
resp = client.get(headless_reverse("headless:mfa:authenticate_webauthn"))
|
||||
assert "request_options" in resp.json()["data"]
|
||||
resp = client.post(
|
||||
headless_reverse("headless:mfa:authenticate_webauthn"),
|
||||
data={"credential": credential},
|
||||
content_type="application/json",
|
||||
)
|
||||
data = resp.json()
|
||||
assert resp.status_code == 200
|
||||
assert data["data"]["user"]["id"] == passkey.user_id
|
||||
assert client.headless_session()[AUTHENTICATION_METHODS_SESSION_KEY] == [
|
||||
{"method": "password", "at": ANY, "username": passkey.user.username},
|
||||
{"method": "mfa", "at": ANY, "id": ANY, "type": Authenticator.Type.WEBAUTHN},
|
||||
]
|
||||
|
||||
|
||||
def test_passkey_signup(client, db, webauthn_registration_bypass, headless_reverse):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:mfa:signup_webauthn"),
|
||||
data={"email": "pass@key.org", "username": "passkey"},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
flow = [flow for flow in resp.json()["data"]["flows"] if flow.get("is_pending")][0]
|
||||
assert flow["id"] == Flow.MFA_SIGNUP_WEBAUTHN.value
|
||||
resp = client.get(headless_reverse("headless:mfa:signup_webauthn"))
|
||||
data = resp.json()
|
||||
assert "creation_options" in data["data"]
|
||||
|
||||
user = get_user_model().objects.get(email="pass@key.org")
|
||||
with webauthn_registration_bypass(user, True) as credential:
|
||||
resp = client.put(
|
||||
headless_reverse("headless:mfa:signup_webauthn"),
|
||||
data={"name": "Some key", "credential": credential},
|
||||
content_type="application/json",
|
||||
)
|
||||
data = resp.json()
|
||||
assert data["meta"]["is_authenticated"]
|
||||
authenticator = Authenticator.objects.get(user=user)
|
||||
assert authenticator.wrap().name == "Some key"
|
||||
100
.venv/lib/python3.12/site-packages/allauth/headless/mfa/urls.py
Normal file
100
.venv/lib/python3.12/site-packages/allauth/headless/mfa/urls.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from django.urls import include, path
|
||||
|
||||
from allauth.headless.mfa import views
|
||||
from allauth.mfa import app_settings as mfa_settings
|
||||
|
||||
|
||||
def build_urlpatterns(client):
|
||||
auth_patterns = [
|
||||
path(
|
||||
"2fa/authenticate",
|
||||
views.AuthenticateView.as_api_view(client=client),
|
||||
name="authenticate",
|
||||
),
|
||||
path(
|
||||
"2fa/reauthenticate",
|
||||
views.ReauthenticateView.as_api_view(client=client),
|
||||
name="reauthenticate",
|
||||
),
|
||||
]
|
||||
|
||||
authenticators = []
|
||||
if "totp" in mfa_settings.SUPPORTED_TYPES:
|
||||
authenticators.append(
|
||||
path(
|
||||
"totp",
|
||||
views.ManageTOTPView.as_api_view(client=client),
|
||||
name="manage_totp",
|
||||
)
|
||||
)
|
||||
if "recovery_codes" in mfa_settings.SUPPORTED_TYPES:
|
||||
authenticators.append(
|
||||
path(
|
||||
"recovery-codes",
|
||||
views.ManageRecoveryCodesView.as_api_view(client=client),
|
||||
name="manage_recovery_codes",
|
||||
)
|
||||
)
|
||||
if "webauthn" in mfa_settings.SUPPORTED_TYPES:
|
||||
authenticators.extend(
|
||||
[
|
||||
path(
|
||||
"webauthn",
|
||||
views.ManageWebAuthnView.as_api_view(client=client),
|
||||
name="manage_webauthn",
|
||||
),
|
||||
]
|
||||
)
|
||||
auth_patterns.extend(
|
||||
[
|
||||
path(
|
||||
"webauthn/authenticate",
|
||||
views.AuthenticateWebAuthnView.as_api_view(client=client),
|
||||
name="authenticate_webauthn",
|
||||
),
|
||||
path(
|
||||
"webauthn/reauthenticate",
|
||||
views.ReauthenticateWebAuthnView.as_api_view(client=client),
|
||||
name="reauthenticate_webauthn",
|
||||
),
|
||||
]
|
||||
)
|
||||
if mfa_settings.PASSKEY_LOGIN_ENABLED:
|
||||
auth_patterns.append(
|
||||
path(
|
||||
"webauthn/login",
|
||||
views.LoginWebAuthnView.as_api_view(client=client),
|
||||
name="login_webauthn",
|
||||
)
|
||||
)
|
||||
if mfa_settings.PASSKEY_SIGNUP_ENABLED:
|
||||
auth_patterns.append(
|
||||
path(
|
||||
"webauthn/signup",
|
||||
views.SignupWebAuthnView.as_api_view(client=client),
|
||||
name="signup_webauthn",
|
||||
)
|
||||
)
|
||||
|
||||
return [
|
||||
path(
|
||||
"auth/",
|
||||
include(auth_patterns),
|
||||
),
|
||||
path(
|
||||
"account/",
|
||||
include(
|
||||
[
|
||||
path(
|
||||
"authenticators",
|
||||
views.AuthenticatorsView.as_api_view(client=client),
|
||||
name="authenticators",
|
||||
),
|
||||
path(
|
||||
"authenticators/",
|
||||
include(authenticators),
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
]
|
||||
280
.venv/lib/python3.12/site-packages/allauth/headless/mfa/views.py
Normal file
280
.venv/lib/python3.12/site-packages/allauth/headless/mfa/views.py
Normal file
@@ -0,0 +1,280 @@
|
||||
from django.core.exceptions import ValidationError
|
||||
|
||||
from allauth.account.internal.stagekit import get_pending_stage
|
||||
from allauth.account.models import Login
|
||||
from allauth.headless.account.views import SignupView
|
||||
from allauth.headless.base.response import (
|
||||
APIResponse,
|
||||
AuthenticationResponse,
|
||||
ConflictResponse,
|
||||
)
|
||||
from allauth.headless.base.views import (
|
||||
APIView,
|
||||
AuthenticatedAPIView,
|
||||
AuthenticationStageAPIView,
|
||||
)
|
||||
from allauth.headless.internal.restkit.response import ErrorResponse
|
||||
from allauth.headless.mfa import response
|
||||
from allauth.headless.mfa.inputs import (
|
||||
ActivateTOTPInput,
|
||||
AddWebAuthnInput,
|
||||
AuthenticateInput,
|
||||
AuthenticateWebAuthnInput,
|
||||
CreateWebAuthnInput,
|
||||
DeleteWebAuthnInput,
|
||||
GenerateRecoveryCodesInput,
|
||||
LoginWebAuthnInput,
|
||||
ReauthenticateWebAuthnInput,
|
||||
SignupWebAuthnInput,
|
||||
UpdateWebAuthnInput,
|
||||
)
|
||||
from allauth.mfa.adapter import DefaultMFAAdapter, get_adapter
|
||||
from allauth.mfa.internal.flows import add
|
||||
from allauth.mfa.models import Authenticator
|
||||
from allauth.mfa.recovery_codes.internal import flows as recovery_codes_flows
|
||||
from allauth.mfa.stages import AuthenticateStage
|
||||
from allauth.mfa.totp.internal import auth as totp_auth, flows as totp_flows
|
||||
from allauth.mfa.webauthn.internal import (
|
||||
auth as webauthn_auth,
|
||||
flows as webauthn_flows,
|
||||
)
|
||||
from allauth.mfa.webauthn.stages import PasskeySignupStage
|
||||
|
||||
|
||||
def _validate_can_add_authenticator(request):
|
||||
try:
|
||||
add.validate_can_add_authenticator(request.user)
|
||||
except ValidationError as e:
|
||||
return ErrorResponse(request, status=409, exception=e)
|
||||
|
||||
|
||||
class AuthenticateView(AuthenticationStageAPIView):
|
||||
input_class = AuthenticateInput
|
||||
stage_class = AuthenticateStage
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
self.input.save()
|
||||
return self.respond_next_stage()
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.stage.login.user}
|
||||
|
||||
|
||||
class ReauthenticateView(AuthenticatedAPIView):
|
||||
input_class = AuthenticateInput
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
self.input.save()
|
||||
return AuthenticationResponse(self.request)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
|
||||
class AuthenticatorsView(AuthenticatedAPIView):
|
||||
def get(self, request, *args, **kwargs):
|
||||
authenticators = Authenticator.objects.filter(user=request.user)
|
||||
return response.AuthenticatorsResponse(request, authenticators)
|
||||
|
||||
|
||||
class ManageTOTPView(AuthenticatedAPIView):
|
||||
input_class = {"POST": ActivateTOTPInput}
|
||||
|
||||
def get(self, request, *args, **kwargs) -> APIResponse:
|
||||
authenticator = self._get_authenticator()
|
||||
if not authenticator:
|
||||
err = _validate_can_add_authenticator(request)
|
||||
if err:
|
||||
return err
|
||||
adapter: DefaultMFAAdapter = get_adapter()
|
||||
secret = totp_auth.get_totp_secret(regenerate=True)
|
||||
totp_url: str = adapter.build_totp_url(request.user, secret)
|
||||
return response.TOTPNotFoundResponse(request, secret, totp_url)
|
||||
return response.TOTPResponse(request, authenticator)
|
||||
|
||||
def _get_authenticator(self):
|
||||
return Authenticator.objects.filter(
|
||||
type=Authenticator.Type.TOTP, user=self.request.user
|
||||
).first()
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
authenticator = totp_flows.activate_totp(request, self.input)[0]
|
||||
return response.TOTPResponse(request, authenticator)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
authenticator = self._get_authenticator()
|
||||
if authenticator:
|
||||
authenticator = totp_flows.deactivate_totp(request, authenticator)
|
||||
return response.AuthenticatorDeletedResponse(request)
|
||||
|
||||
|
||||
class ManageRecoveryCodesView(AuthenticatedAPIView):
|
||||
input_class = GenerateRecoveryCodesInput
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
authenticator = recovery_codes_flows.view_recovery_codes(request)
|
||||
if not authenticator:
|
||||
return response.RecoveryCodesNotFoundResponse(request)
|
||||
return response.RecoveryCodesResponse(request, authenticator)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
authenticator = recovery_codes_flows.generate_recovery_codes(request)
|
||||
return response.RecoveryCodesResponse(request, authenticator)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
|
||||
class ManageWebAuthnView(AuthenticatedAPIView):
|
||||
input_class = {
|
||||
"POST": AddWebAuthnInput,
|
||||
"PUT": UpdateWebAuthnInput,
|
||||
"DELETE": DeleteWebAuthnInput,
|
||||
}
|
||||
|
||||
def handle(self, request, *args, **kwargs):
|
||||
if request.method in ["GET", "POST"]:
|
||||
err = _validate_can_add_authenticator(request)
|
||||
if err:
|
||||
return err
|
||||
return super().handle(request, *args, **kwargs)
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
passwordless = "passwordless" in request.GET
|
||||
creation_options = webauthn_flows.begin_registration(
|
||||
request, request.user, passwordless
|
||||
)
|
||||
return response.AddWebAuthnResponse(request, creation_options)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
auth, rc_auth = webauthn_flows.add_authenticator(
|
||||
request,
|
||||
name=self.input.cleaned_data["name"],
|
||||
credential=self.input.cleaned_data["credential"],
|
||||
)
|
||||
did_generate_recovery_codes = bool(rc_auth)
|
||||
return response.AuthenticatorResponse(
|
||||
request,
|
||||
auth,
|
||||
meta={"recovery_codes_generated": did_generate_recovery_codes},
|
||||
)
|
||||
|
||||
def put(self, request, *args, **kwargs):
|
||||
authenticator = self.input.cleaned_data["id"]
|
||||
webauthn_flows.rename_authenticator(
|
||||
request, authenticator, self.input.cleaned_data["name"]
|
||||
)
|
||||
return response.AuthenticatorResponse(request, authenticator)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
authenticators = self.input.cleaned_data["authenticators"]
|
||||
webauthn_flows.remove_authenticators(request, authenticators)
|
||||
return response.AuthenticatorsDeletedResponse(request)
|
||||
|
||||
|
||||
class ReauthenticateWebAuthnView(AuthenticatedAPIView):
|
||||
input_class = {
|
||||
"POST": ReauthenticateWebAuthnInput,
|
||||
}
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
request_options = webauthn_auth.begin_authentication(request.user)
|
||||
return response.WebAuthnRequestOptionsResponse(request, request_options)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
authenticator = self.input.cleaned_data["credential"]
|
||||
webauthn_flows.reauthenticate(request, authenticator)
|
||||
return AuthenticationResponse(self.request)
|
||||
|
||||
|
||||
class AuthenticateWebAuthnView(AuthenticationStageAPIView):
|
||||
input_class = {
|
||||
"POST": AuthenticateWebAuthnInput,
|
||||
}
|
||||
stage_class = AuthenticateStage
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
request_options = webauthn_auth.begin_authentication(self.stage.login.user)
|
||||
return response.WebAuthnRequestOptionsResponse(request, request_options)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.stage.login.user}
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
self.input.save()
|
||||
return self.respond_next_stage()
|
||||
|
||||
|
||||
class LoginWebAuthnView(APIView):
|
||||
input_class = {
|
||||
"POST": LoginWebAuthnInput,
|
||||
}
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
request_options = webauthn_auth.begin_authentication()
|
||||
return response.WebAuthnRequestOptionsResponse(request, request_options)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
authenticator = self.input.cleaned_data["credential"]
|
||||
redirect_url = None
|
||||
login = Login(user=authenticator.user, redirect_url=redirect_url)
|
||||
webauthn_flows.perform_passwordless_login(request, authenticator, login)
|
||||
return AuthenticationResponse(request)
|
||||
|
||||
|
||||
class SignupWebAuthnView(SignupView):
|
||||
input_class = {
|
||||
"POST": SignupWebAuthnInput,
|
||||
"PUT": CreateWebAuthnInput,
|
||||
}
|
||||
by_passkey = True
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
resp = self._require_stage()
|
||||
if resp:
|
||||
return resp
|
||||
creation_options = webauthn_flows.begin_registration(
|
||||
request, self.stage.login.user, passwordless=True, signup=True
|
||||
)
|
||||
return response.AddWebAuthnResponse(request, creation_options)
|
||||
|
||||
def _prep_stage(self):
|
||||
if hasattr(self, "stage"):
|
||||
return self.stage
|
||||
self.stage = get_pending_stage(self.request)
|
||||
return self.stage
|
||||
|
||||
def _require_stage(self):
|
||||
self._prep_stage()
|
||||
if not self.stage or self.stage.key != PasskeySignupStage.key:
|
||||
return ConflictResponse(self.request)
|
||||
return None
|
||||
|
||||
def get_input_kwargs(self):
|
||||
ret = super().get_input_kwargs()
|
||||
self._prep_stage()
|
||||
if self.stage and self.request.method == "PUT":
|
||||
ret["user"] = self.stage.login.user
|
||||
return ret
|
||||
|
||||
def put(self, request, *args, **kwargs):
|
||||
resp = self._require_stage()
|
||||
if resp:
|
||||
return resp
|
||||
webauthn_flows.signup_authenticator(
|
||||
request,
|
||||
user=self.stage.login.user,
|
||||
name=self.input.cleaned_data["name"],
|
||||
credential=self.input.cleaned_data["credential"],
|
||||
)
|
||||
self.stage.exit()
|
||||
return AuthenticationResponse(request)
|
||||
@@ -0,0 +1,33 @@
|
||||
from django import forms
|
||||
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.core import context
|
||||
from allauth.headless.adapter import get_adapter
|
||||
from allauth.socialaccount.adapter import (
|
||||
get_adapter as get_socialaccount_adapter,
|
||||
)
|
||||
from allauth.socialaccount.providers.base.constants import AuthProcess
|
||||
|
||||
|
||||
class RedirectToProviderForm(forms.Form):
|
||||
provider = forms.CharField()
|
||||
callback_url = forms.CharField()
|
||||
process = forms.ChoiceField(
|
||||
choices=[
|
||||
(AuthProcess.LOGIN, AuthProcess.LOGIN),
|
||||
(AuthProcess.CONNECT, AuthProcess.CONNECT),
|
||||
]
|
||||
)
|
||||
|
||||
def clean_callback_url(self):
|
||||
url = self.cleaned_data["callback_url"]
|
||||
if not get_account_adapter().is_safe_url(url):
|
||||
raise get_adapter().validation_error("invalid_url")
|
||||
return url
|
||||
|
||||
def clean_provider(self):
|
||||
provider_id = self.cleaned_data["provider"]
|
||||
provider = get_socialaccount_adapter().get_provider(
|
||||
context.request, provider_id
|
||||
)
|
||||
return provider
|
||||
@@ -0,0 +1,116 @@
|
||||
from django.core.exceptions import ValidationError
|
||||
|
||||
from allauth.core import context
|
||||
from allauth.headless.adapter import get_adapter
|
||||
from allauth.headless.internal.restkit import inputs
|
||||
from allauth.socialaccount.adapter import (
|
||||
get_adapter as get_socialaccount_adapter,
|
||||
)
|
||||
from allauth.socialaccount.forms import SignupForm
|
||||
from allauth.socialaccount.models import SocialAccount, SocialApp
|
||||
from allauth.socialaccount.providers import registry
|
||||
from allauth.socialaccount.providers.base.constants import AuthProcess
|
||||
|
||||
|
||||
class SignupInput(SignupForm, inputs.Input):
|
||||
pass
|
||||
|
||||
|
||||
class DeleteProviderAccountInput(inputs.Input):
|
||||
provider = inputs.CharField()
|
||||
account = inputs.CharField()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.user = kwargs.pop("user")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def clean(self):
|
||||
cleaned_data = super().clean()
|
||||
uid = cleaned_data.get("account")
|
||||
provider_id = cleaned_data.get("provider")
|
||||
if uid and provider_id:
|
||||
accounts = SocialAccount.objects.filter(user=self.user)
|
||||
account = accounts.filter(
|
||||
uid=uid,
|
||||
provider=provider_id,
|
||||
).first()
|
||||
if not account:
|
||||
raise get_adapter().validation_error("account_not_found")
|
||||
get_socialaccount_adapter().validate_disconnect(account, accounts)
|
||||
self.cleaned_data["account"] = account
|
||||
return cleaned_data
|
||||
|
||||
|
||||
class ProviderTokenInput(inputs.Input):
|
||||
provider = inputs.CharField()
|
||||
process = inputs.ChoiceField(
|
||||
choices=[
|
||||
(AuthProcess.LOGIN, AuthProcess.LOGIN),
|
||||
(AuthProcess.CONNECT, AuthProcess.CONNECT),
|
||||
]
|
||||
)
|
||||
token = inputs.Field()
|
||||
|
||||
def clean(self):
|
||||
cleaned_data = super().clean()
|
||||
token = self.data.get("token")
|
||||
adapter = get_adapter()
|
||||
if not isinstance(token, dict):
|
||||
self.add_error("token", adapter.validation_error("invalid_token"))
|
||||
token = None
|
||||
|
||||
provider_id = cleaned_data.get("provider")
|
||||
provider = None
|
||||
if provider_id and token:
|
||||
provider_class = registry.get_class(provider_id)
|
||||
# If `provider_id` is a sub provider ID we won't find it by class.
|
||||
client_id_required = provider_class is None or provider_class.uses_apps
|
||||
client_id = token.get("client_id")
|
||||
if client_id_required and not isinstance(client_id, str):
|
||||
self.add_error("token", adapter.validation_error("client_id_required"))
|
||||
else:
|
||||
try:
|
||||
provider = get_socialaccount_adapter().get_provider(
|
||||
context.request, provider_id, client_id=client_id
|
||||
)
|
||||
except SocialApp.DoesNotExist:
|
||||
self.add_error("token", adapter.validation_error("invalid_token"))
|
||||
else:
|
||||
if not provider.supports_token_authentication:
|
||||
self.add_error(
|
||||
"provider",
|
||||
adapter.validation_error(
|
||||
"token_authentication_not_supported"
|
||||
),
|
||||
)
|
||||
elif (
|
||||
provider.uses_apps
|
||||
and client_id
|
||||
and provider.app.client_id != client_id
|
||||
):
|
||||
self.add_error(
|
||||
"token", adapter.validation_error("client_id_mismatch")
|
||||
)
|
||||
else:
|
||||
id_token = token.get("id_token")
|
||||
access_token = token.get("access_token")
|
||||
if (
|
||||
(id_token is not None and not isinstance(id_token, str))
|
||||
or (
|
||||
access_token is not None
|
||||
and not isinstance(access_token, str)
|
||||
)
|
||||
or (not id_token and not access_token)
|
||||
):
|
||||
self.add_error(
|
||||
"token", adapter.validation_error("token_required")
|
||||
)
|
||||
if not self.errors:
|
||||
cleaned_data["provider"] = provider
|
||||
try:
|
||||
login = provider.verify_token(context.request, token)
|
||||
login.state["process"] = cleaned_data["process"]
|
||||
cleaned_data["sociallogin"] = login
|
||||
except ValidationError as e:
|
||||
self.add_error("token", e)
|
||||
return cleaned_data
|
||||
@@ -0,0 +1,97 @@
|
||||
from django.core.exceptions import PermissionDenied, ValidationError
|
||||
from django.http import HttpResponseRedirect
|
||||
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.core.exceptions import (
|
||||
ImmediateHttpResponse,
|
||||
ReauthenticationRequired,
|
||||
SignupClosedException,
|
||||
)
|
||||
from allauth.core.internal import httpkit
|
||||
from allauth.headless.internal.authkit import AuthenticationStatus
|
||||
from allauth.socialaccount.internal import flows, statekit
|
||||
from allauth.socialaccount.providers.base.constants import (
|
||||
AuthError,
|
||||
AuthProcess,
|
||||
)
|
||||
|
||||
|
||||
def on_authentication_error(
|
||||
request,
|
||||
provider,
|
||||
error=None,
|
||||
exception=None,
|
||||
extra_context=None,
|
||||
state_id=None,
|
||||
):
|
||||
"""
|
||||
Called at a time when it is not clear whether or not this is a headless flow.
|
||||
"""
|
||||
state = None
|
||||
if extra_context:
|
||||
state = extra_context.get("state")
|
||||
if state is None:
|
||||
state_id = extra_context.get("state_id")
|
||||
if state_id:
|
||||
state = statekit.unstash_state(request, state_id)
|
||||
params = {"error": error}
|
||||
if state is not None:
|
||||
headless = state.get("headless")
|
||||
next_url = state.get("next")
|
||||
params["error_process"] = state["process"]
|
||||
else:
|
||||
headless = allauth_settings.HEADLESS_ONLY
|
||||
next_url = None
|
||||
params["error_process"] = AuthProcess.LOGIN
|
||||
if not headless:
|
||||
return
|
||||
if not next_url:
|
||||
next_url = httpkit.get_frontend_url(request, "socialaccount_login_error") or "/"
|
||||
next_url = httpkit.add_query_params(next_url, params)
|
||||
raise ImmediateHttpResponse(HttpResponseRedirect(next_url))
|
||||
|
||||
|
||||
def complete_token_login(request, sociallogin):
|
||||
flows.login.complete_login(request, sociallogin, raises=True)
|
||||
|
||||
|
||||
def complete_login(request, sociallogin):
|
||||
"""
|
||||
Called when `sociallogin.is_headless`.
|
||||
"""
|
||||
error = None
|
||||
try:
|
||||
flows.login.complete_login(request, sociallogin, raises=True)
|
||||
except ReauthenticationRequired:
|
||||
error = "reauthentication_required"
|
||||
except SignupClosedException:
|
||||
error = "signup_closed"
|
||||
except PermissionDenied:
|
||||
error = "permission_denied"
|
||||
except ValidationError as e:
|
||||
error = e.code
|
||||
else:
|
||||
# At this stage, we're either:
|
||||
# 1) logged in (or in of the login pipeline stages, such as email verification)
|
||||
# 2) auto signed up -- a pipeline stage, so see 1)
|
||||
# 3) performing a social signup
|
||||
|
||||
# 4) Stopped, due to not being open-for-signup
|
||||
# It would be good to refactor the above into a more generic social login
|
||||
# pipeline with clear stages, but for now the /auth endpoint properly responds
|
||||
status = AuthenticationStatus(request)
|
||||
if all(
|
||||
[
|
||||
not status.is_authenticated,
|
||||
not status.has_pending_signup,
|
||||
not status.get_pending_stage(),
|
||||
]
|
||||
):
|
||||
error = AuthError.UNKNOWN
|
||||
next_url = sociallogin.state["next"]
|
||||
if error:
|
||||
next_url = httpkit.add_query_params(
|
||||
next_url,
|
||||
{"error": error, "error_process": sociallogin.state["process"]},
|
||||
)
|
||||
return HttpResponseRedirect(next_url)
|
||||
@@ -0,0 +1,97 @@
|
||||
from allauth.headless.base.response import APIResponse
|
||||
from allauth.headless.constants import Client, Flow
|
||||
from allauth.socialaccount.adapter import (
|
||||
get_adapter as get_socialaccount_adapter,
|
||||
)
|
||||
from allauth.socialaccount.internal.flows import signup
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
def _provider_data(request, provider):
|
||||
ret = {"id": provider.sub_id, "name": provider.name, "flows": []}
|
||||
if provider.supports_redirect:
|
||||
ret["flows"].append(Flow.PROVIDER_REDIRECT)
|
||||
if provider.supports_token_authentication:
|
||||
ret["flows"].append(Flow.PROVIDER_TOKEN)
|
||||
if isinstance(provider, OAuth2Provider):
|
||||
ret["client_id"] = provider.app.client_id
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def provider_flows(request):
|
||||
flows = []
|
||||
providers = _list_supported_providers(request)
|
||||
if providers:
|
||||
redirect_providers = [p.id for p in providers if p.supports_redirect]
|
||||
token_providers = [p.id for p in providers if p.supports_token_authentication]
|
||||
if redirect_providers and request.allauth.headless.client == Client.BROWSER:
|
||||
flows.append(
|
||||
{
|
||||
"id": Flow.PROVIDER_REDIRECT,
|
||||
"providers": redirect_providers,
|
||||
}
|
||||
)
|
||||
if token_providers:
|
||||
flows.append(
|
||||
{
|
||||
"id": Flow.PROVIDER_TOKEN,
|
||||
"providers": token_providers,
|
||||
}
|
||||
)
|
||||
sociallogin = signup.get_pending_signup(request)
|
||||
if sociallogin:
|
||||
flows.append(_signup_flow(request, sociallogin))
|
||||
return flows
|
||||
|
||||
|
||||
def _signup_flow(request, sociallogin):
|
||||
provider = sociallogin.account.get_provider()
|
||||
flow = {
|
||||
"id": Flow.PROVIDER_SIGNUP,
|
||||
"provider": _provider_data(request, provider),
|
||||
"is_pending": True,
|
||||
}
|
||||
return flow
|
||||
|
||||
|
||||
def _is_provider_supported(provider, client):
|
||||
if client == Client.APP:
|
||||
return provider.supports_token_authentication
|
||||
elif client == Client.BROWSER:
|
||||
return provider.supports_redirect
|
||||
return False
|
||||
|
||||
|
||||
def _list_supported_providers(request):
|
||||
adapter = get_socialaccount_adapter()
|
||||
providers = adapter.list_providers(request)
|
||||
providers = [
|
||||
p
|
||||
for p in providers
|
||||
if _is_provider_supported(p, request.allauth.headless.client)
|
||||
]
|
||||
return providers
|
||||
|
||||
|
||||
def get_config_data(request):
|
||||
entries = []
|
||||
data = {"socialaccount": {"providers": entries}}
|
||||
providers = _list_supported_providers(request)
|
||||
providers = sorted(providers, key=lambda p: p.name)
|
||||
for provider in providers:
|
||||
entries.append(_provider_data(request, provider))
|
||||
return data
|
||||
|
||||
|
||||
class SocialAccountsResponse(APIResponse):
|
||||
def __init__(self, request, accounts):
|
||||
data = [
|
||||
{
|
||||
"uid": account.uid,
|
||||
"provider": _provider_data(request, account.get_provider()),
|
||||
"display": account.get_provider_account().to_str(),
|
||||
}
|
||||
for account in accounts
|
||||
]
|
||||
super().__init__(request, data=data)
|
||||
@@ -0,0 +1,37 @@
|
||||
import pytest
|
||||
|
||||
from allauth.headless.socialaccount.inputs import ProviderTokenInput
|
||||
|
||||
|
||||
@pytest.mark.parametrize("client_id", ["client1", "client2"])
|
||||
def test_provider_token_multiple_apps(settings, db, client_id):
|
||||
gsettings = {
|
||||
"APPS": [
|
||||
{"client_id": "client1", "secret": "secret"},
|
||||
{"client_id": "client2", "secret": "secret"},
|
||||
]
|
||||
}
|
||||
settings.SOCIALACCOUNT_PROVIDERS = {"google": gsettings}
|
||||
|
||||
inp = ProviderTokenInput(
|
||||
{
|
||||
"provider": "google",
|
||||
"process": "login",
|
||||
"token": {"client_id": client_id, "id_token": "it", "access_token": "at"},
|
||||
}
|
||||
)
|
||||
assert not inp.is_valid()
|
||||
assert inp.cleaned_data["provider"].app.client_id == client_id
|
||||
assert inp.errors == {"token": ["Invalid token."]}
|
||||
|
||||
|
||||
def test_provider_token_client_id_required(settings, db):
|
||||
inp = ProviderTokenInput(
|
||||
{
|
||||
"provider": "google",
|
||||
"process": "login",
|
||||
"token": {"id_token": "it", "access_token": "at"},
|
||||
}
|
||||
)
|
||||
assert not inp.is_valid()
|
||||
assert inp.errors == {"token": ["`client_id` required."]}
|
||||
@@ -0,0 +1,410 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.urls import reverse
|
||||
|
||||
from pytest_django.asserts import assertTemplateUsed
|
||||
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.socialaccount.models import SocialAccount
|
||||
from allauth.socialaccount.providers.base.constants import AuthProcess
|
||||
|
||||
|
||||
def test_bad_redirect(client, headless_reverse, db, settings):
|
||||
settings.HEADLESS_ONLY = False
|
||||
resp = client.post(
|
||||
headless_reverse("headless:socialaccount:redirect_to_provider"),
|
||||
data={
|
||||
"provider": "dummy",
|
||||
"callback_url": "https://unsafe.org/hack",
|
||||
"process": AuthProcess.LOGIN,
|
||||
},
|
||||
)
|
||||
assertTemplateUsed(resp, "socialaccount/authentication_error.html")
|
||||
|
||||
|
||||
def test_valid_redirect(client, headless_reverse, db):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:socialaccount:redirect_to_provider"),
|
||||
data={
|
||||
"provider": "dummy",
|
||||
"callback_url": "/",
|
||||
"process": AuthProcess.LOGIN,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 302
|
||||
|
||||
|
||||
def test_manage_providers(auth_client, user, headless_reverse, provider_id):
|
||||
account_to_del = SocialAccount.objects.create(
|
||||
user=user, provider=provider_id, uid="p123"
|
||||
)
|
||||
account_to_keep = SocialAccount.objects.create(
|
||||
user=user, provider=provider_id, uid="p456"
|
||||
)
|
||||
resp = auth_client.get(
|
||||
headless_reverse("headless:socialaccount:manage_providers"),
|
||||
)
|
||||
data = resp.json()
|
||||
assert data["status"] == 200
|
||||
assert len(data["data"]) == 2
|
||||
resp = auth_client.delete(
|
||||
headless_reverse("headless:socialaccount:manage_providers"),
|
||||
data={"provider": account_to_del.provider, "account": account_to_del.uid},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {
|
||||
"status": 200,
|
||||
"data": [
|
||||
{
|
||||
"display": "Unittest Server",
|
||||
"provider": {
|
||||
"client_id": "Unittest client_id",
|
||||
"flows": ["provider_redirect"],
|
||||
"id": provider_id,
|
||||
"name": "Unittest Server",
|
||||
},
|
||||
"uid": "p456",
|
||||
}
|
||||
],
|
||||
}
|
||||
assert not SocialAccount.objects.filter(pk=account_to_del.pk).exists()
|
||||
assert SocialAccount.objects.filter(pk=account_to_keep.pk).exists()
|
||||
|
||||
|
||||
def test_disconnect_bad_request(auth_client, user, headless_reverse, provider_id):
|
||||
resp = auth_client.delete(
|
||||
headless_reverse("headless:socialaccount:manage_providers"),
|
||||
data={"provider": provider_id, "account": "unknown"},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [{"code": "account_not_found", "message": "Unknown account."}],
|
||||
}
|
||||
|
||||
|
||||
def test_valid_token(client, headless_reverse, db):
|
||||
id_token = json.dumps(
|
||||
{
|
||||
"id": 123,
|
||||
"email": "a@b.com",
|
||||
"email_verified": True,
|
||||
}
|
||||
)
|
||||
resp = client.post(
|
||||
headless_reverse("headless:socialaccount:provider_token"),
|
||||
data={
|
||||
"provider": "dummy",
|
||||
"token": {
|
||||
"id_token": id_token,
|
||||
},
|
||||
"process": AuthProcess.LOGIN,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert EmailAddress.objects.filter(email="a@b.com", verified=True).exists()
|
||||
|
||||
|
||||
def test_invalid_token(client, headless_reverse, db, google_provider_settings):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:socialaccount:provider_token"),
|
||||
data={
|
||||
"provider": "google",
|
||||
"token": {
|
||||
"id_token": "dummy",
|
||||
"client_id": google_provider_settings["APPS"][0]["client_id"],
|
||||
},
|
||||
"process": AuthProcess.LOGIN,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
data = resp.json()
|
||||
assert data == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{"message": "Invalid token.", "code": "invalid_token", "param": "token"}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_auth_error_no_headless_request(client, db, google_provider_settings, settings):
|
||||
"""Authentication errors use the regular "Third-Party Login Failure"
|
||||
template if headless is not used.
|
||||
"""
|
||||
settings.HEADLESS_ONLY = False
|
||||
resp = client.get(reverse("google_callback"))
|
||||
assertTemplateUsed(resp, "socialaccount/authentication_error.html")
|
||||
|
||||
|
||||
def test_auth_error_headless_request(
|
||||
client, db, google_provider_settings, sociallogin_setup_state
|
||||
):
|
||||
"""Authentication errors redirect to the next URL with ?error params for
|
||||
headless requests.
|
||||
"""
|
||||
state = sociallogin_setup_state(client, headless=True, next="/foo")
|
||||
resp = client.get(reverse("google_callback") + f"?state={state}")
|
||||
assert resp["location"] == "/foo?error=unknown&error_process=login"
|
||||
|
||||
|
||||
def test_auth_error_no_headless_state_request_headless_only(
|
||||
settings, client, db, google_provider_settings
|
||||
):
|
||||
"""Authentication errors redirect to a fallback error URL for headless-only,
|
||||
in case no next can be recovered from the state.
|
||||
"""
|
||||
settings.HEADLESS_ONLY = True
|
||||
settings.HEADLESS_FRONTEND_URLS = {"socialaccount_login_error": "/3rdparty/failure"}
|
||||
resp = client.get(reverse("google_callback"))
|
||||
assert (
|
||||
resp["location"]
|
||||
== "http://testserver/3rdparty/failure?error=unknown&error_process=login"
|
||||
)
|
||||
|
||||
|
||||
def test_auth_error_headless_state_request_headless_only(
|
||||
settings, client, db, google_provider_settings, sociallogin_setup_state
|
||||
):
|
||||
"""Authentication errors redirect to a fallback error URL for headless-only,
|
||||
in case no next can be recovered from the state.
|
||||
"""
|
||||
state = sociallogin_setup_state(client, headless=True, next="/foo")
|
||||
settings.HEADLESS_ONLY = True
|
||||
settings.HEADLESS_FRONTEND_URLS = {"socialaccount_login_error": "/3rdparty/failure"}
|
||||
resp = client.get(reverse("google_callback") + f"?state={state}")
|
||||
assert resp["location"] == "/foo?error=unknown&error_process=login"
|
||||
|
||||
|
||||
def test_token_signup_closed(client, headless_reverse, db):
|
||||
id_token = json.dumps(
|
||||
{
|
||||
"id": 123,
|
||||
"email": "a@b.com",
|
||||
"email_verified": True,
|
||||
}
|
||||
)
|
||||
with patch(
|
||||
"allauth.socialaccount.adapter.DefaultSocialAccountAdapter.is_open_for_signup"
|
||||
) as iofs:
|
||||
iofs.return_value = False
|
||||
resp = client.post(
|
||||
headless_reverse("headless:socialaccount:provider_token"),
|
||||
data={
|
||||
"provider": "dummy",
|
||||
"token": {
|
||||
"id_token": id_token,
|
||||
},
|
||||
"process": AuthProcess.LOGIN,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert not EmailAddress.objects.filter(email="a@b.com", verified=True).exists()
|
||||
|
||||
|
||||
def test_provider_signup(client, headless_reverse, db, settings):
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
|
||||
settings.ACCOUNT_EMAIL_REQUIRED = True
|
||||
settings.ACCOUNT_USERNAME_REQUIRED = False
|
||||
id_token = json.dumps(
|
||||
{
|
||||
"id": 123,
|
||||
}
|
||||
)
|
||||
resp = client.post(
|
||||
headless_reverse("headless:socialaccount:provider_token"),
|
||||
data={
|
||||
"provider": "dummy",
|
||||
"token": {
|
||||
"id_token": id_token,
|
||||
},
|
||||
"process": AuthProcess.LOGIN,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
pending_flow = [f for f in resp.json()["data"]["flows"] if f.get("is_pending")][0]
|
||||
assert pending_flow["id"] == "provider_signup"
|
||||
resp = client.post(
|
||||
headless_reverse("headless:socialaccount:provider_signup"),
|
||||
data={
|
||||
"email": "a@b.com",
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
pending_flow = [f for f in resp.json()["data"]["flows"] if f.get("is_pending")][0]
|
||||
assert pending_flow["id"] == "verify_email"
|
||||
assert EmailAddress.objects.filter(email="a@b.com").exists()
|
||||
|
||||
|
||||
def test_signup_closed(client, headless_reverse, db, settings):
|
||||
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
|
||||
settings.ACCOUNT_EMAIL_REQUIRED = True
|
||||
settings.ACCOUNT_USERNAME_REQUIRED = False
|
||||
id_token = json.dumps(
|
||||
{
|
||||
"id": 123,
|
||||
}
|
||||
)
|
||||
resp = client.post(
|
||||
headless_reverse("headless:socialaccount:provider_token"),
|
||||
data={
|
||||
"provider": "dummy",
|
||||
"token": {
|
||||
"id_token": id_token,
|
||||
},
|
||||
"process": AuthProcess.LOGIN,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
pending_flow = [f for f in resp.json()["data"]["flows"] if f.get("is_pending")][0]
|
||||
assert pending_flow["id"] == "provider_signup"
|
||||
with patch(
|
||||
"allauth.socialaccount.adapter.DefaultSocialAccountAdapter.is_open_for_signup"
|
||||
) as iofs:
|
||||
iofs.return_value = False
|
||||
resp = client.post(
|
||||
headless_reverse("headless:socialaccount:provider_signup"),
|
||||
data={
|
||||
"email": "a@b.com",
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
def test_connect(user, auth_client, sociallogin_setup_state, headless_reverse, db):
|
||||
state = sociallogin_setup_state(
|
||||
auth_client, process="connect", next="/foo", headless=True
|
||||
)
|
||||
resp = auth_client.post(
|
||||
reverse("dummy_authenticate") + f"?state={state}",
|
||||
data={
|
||||
"id": 123,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 302
|
||||
assert resp["location"] == "/foo"
|
||||
assert SocialAccount.objects.filter(user=user, provider="dummy", uid="123").exists()
|
||||
|
||||
|
||||
def test_connect_reauthentication_required(
|
||||
user, auth_client, sociallogin_setup_state, headless_reverse, db, settings
|
||||
):
|
||||
settings.ACCOUNT_REAUTHENTICATION_REQUIRED = True
|
||||
|
||||
state = sociallogin_setup_state(
|
||||
auth_client, process="connect", next="/foo", headless=True
|
||||
)
|
||||
resp = auth_client.post(
|
||||
reverse("dummy_authenticate") + f"?state={state}",
|
||||
data={
|
||||
"id": 123,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 302
|
||||
assert (
|
||||
resp["location"] == "/foo?error=reauthentication_required&error_process=connect"
|
||||
)
|
||||
|
||||
|
||||
def test_connect_already_connected(
|
||||
user, user_factory, auth_client, sociallogin_setup_state, headless_reverse, db
|
||||
):
|
||||
# The other user already connected the account.
|
||||
other_user = user_factory()
|
||||
SocialAccount.objects.create(user=other_user, uid="123", provider="dummy")
|
||||
# Then, this user tries to connect...
|
||||
state = sociallogin_setup_state(
|
||||
auth_client, process=AuthProcess.CONNECT, next="/foo", headless=True
|
||||
)
|
||||
resp = auth_client.post(
|
||||
reverse("dummy_authenticate") + f"?state={state}",
|
||||
data={
|
||||
"id": 123,
|
||||
},
|
||||
)
|
||||
# We're redirected, and an error code is shown.
|
||||
assert resp.status_code == 302
|
||||
assert resp["location"] == "/foo?error=connected_other&error_process=connect"
|
||||
assert not SocialAccount.objects.filter(
|
||||
user=user, provider="dummy", uid="123"
|
||||
).exists()
|
||||
|
||||
|
||||
def test_token_connect(user, auth_client, headless_reverse, db):
|
||||
id_token = json.dumps(
|
||||
{
|
||||
"id": 123,
|
||||
"email": "a@b.com",
|
||||
"email_verified": True,
|
||||
}
|
||||
)
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:socialaccount:provider_token"),
|
||||
data={
|
||||
"provider": "dummy",
|
||||
"token": {
|
||||
"id_token": id_token,
|
||||
},
|
||||
"process": AuthProcess.CONNECT,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert SocialAccount.objects.filter(uid="123", user=user).exists()
|
||||
|
||||
|
||||
def test_token_connect_already_connected(
|
||||
user, auth_client, headless_reverse, db, user_factory
|
||||
):
|
||||
# The other user already connected the account.
|
||||
other_user = user_factory()
|
||||
SocialAccount.objects.create(user=other_user, uid="123", provider="dummy")
|
||||
id_token = json.dumps(
|
||||
{
|
||||
"id": 123,
|
||||
"email": "a@b.com",
|
||||
"email_verified": True,
|
||||
}
|
||||
)
|
||||
resp = auth_client.post(
|
||||
headless_reverse("headless:socialaccount:provider_token"),
|
||||
data={
|
||||
"provider": "dummy",
|
||||
"token": {
|
||||
"id_token": id_token,
|
||||
},
|
||||
"process": AuthProcess.CONNECT,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert not SocialAccount.objects.filter(uid="123", user=user).exists()
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {
|
||||
"status": 400,
|
||||
"errors": [
|
||||
{
|
||||
"code": "connected_other",
|
||||
"message": "The third-party account is already connected to a different account.",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_provider_signup_not_pending(client, headless_reverse, db, settings):
|
||||
resp = client.post(
|
||||
headless_reverse("headless:socialaccount:provider_signup"),
|
||||
data={
|
||||
"email": "a@b.com",
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
@@ -0,0 +1,51 @@
|
||||
from django.urls import include, path
|
||||
|
||||
from allauth.headless.socialaccount import views
|
||||
|
||||
|
||||
def build_urlpatterns(client):
|
||||
return [
|
||||
path(
|
||||
"account/",
|
||||
include(
|
||||
[
|
||||
path(
|
||||
"providers",
|
||||
views.ManageProvidersView.as_api_view(client=client),
|
||||
name="manage_providers",
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
path(
|
||||
"auth/",
|
||||
include(
|
||||
[
|
||||
path(
|
||||
"provider/",
|
||||
include(
|
||||
[
|
||||
path(
|
||||
"signup",
|
||||
views.ProviderSignupView.as_api_view(client=client),
|
||||
name="provider_signup",
|
||||
),
|
||||
path(
|
||||
"redirect",
|
||||
views.RedirectToProviderView.as_api_view(
|
||||
client=client
|
||||
),
|
||||
name="redirect_to_provider",
|
||||
),
|
||||
path(
|
||||
"token",
|
||||
views.ProviderTokenView.as_api_view(client=client),
|
||||
name="provider_token",
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,102 @@
|
||||
from django.core.exceptions import ValidationError
|
||||
|
||||
from allauth.core.exceptions import SignupClosedException
|
||||
from allauth.headless.base.response import (
|
||||
AuthenticationResponse,
|
||||
ConflictResponse,
|
||||
ForbiddenResponse,
|
||||
)
|
||||
from allauth.headless.base.views import APIView, AuthenticatedAPIView
|
||||
from allauth.headless.internal.restkit.response import ErrorResponse
|
||||
from allauth.headless.socialaccount.forms import RedirectToProviderForm
|
||||
from allauth.headless.socialaccount.inputs import (
|
||||
DeleteProviderAccountInput,
|
||||
ProviderTokenInput,
|
||||
SignupInput,
|
||||
)
|
||||
from allauth.headless.socialaccount.internal import complete_token_login
|
||||
from allauth.headless.socialaccount.response import SocialAccountsResponse
|
||||
from allauth.socialaccount.adapter import (
|
||||
get_adapter as get_socialaccount_adapter,
|
||||
)
|
||||
from allauth.socialaccount.helpers import render_authentication_error
|
||||
from allauth.socialaccount.internal import flows
|
||||
from allauth.socialaccount.models import SocialAccount
|
||||
|
||||
|
||||
class ProviderSignupView(APIView):
|
||||
input_class = SignupInput
|
||||
|
||||
def handle(self, request, *args, **kwargs):
|
||||
self.sociallogin = flows.signup.get_pending_signup(self.request)
|
||||
if not self.sociallogin:
|
||||
return ConflictResponse(request)
|
||||
if not get_socialaccount_adapter().is_open_for_signup(
|
||||
request, self.sociallogin
|
||||
):
|
||||
return ForbiddenResponse(request)
|
||||
return super().handle(request, *args, **kwargs)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
flows.signup.signup_by_form(self.request, self.sociallogin, self.input)
|
||||
return AuthenticationResponse(request)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"sociallogin": self.sociallogin}
|
||||
|
||||
|
||||
class RedirectToProviderView(APIView):
|
||||
handle_json_input = False
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
form = RedirectToProviderForm(request.POST)
|
||||
if not form.is_valid():
|
||||
return render_authentication_error(
|
||||
request,
|
||||
provider=request.POST.get("provider"),
|
||||
exception=ValidationError(form.errors),
|
||||
)
|
||||
provider = form.cleaned_data["provider"]
|
||||
next_url = form.cleaned_data["callback_url"]
|
||||
process = form.cleaned_data["process"]
|
||||
return provider.redirect(
|
||||
request,
|
||||
process,
|
||||
next_url=next_url,
|
||||
headless=True,
|
||||
)
|
||||
|
||||
|
||||
class ManageProvidersView(AuthenticatedAPIView):
|
||||
input_class = {
|
||||
"DELETE": DeleteProviderAccountInput,
|
||||
}
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.respond_provider_accounts(request)
|
||||
|
||||
@classmethod
|
||||
def respond_provider_accounts(self, request):
|
||||
accounts = SocialAccount.objects.filter(user=request.user)
|
||||
return SocialAccountsResponse(request, accounts)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
flows.connect.disconnect(request, self.input.cleaned_data["account"])
|
||||
return self.respond_provider_accounts(request)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
|
||||
|
||||
class ProviderTokenView(APIView):
|
||||
input_class = ProviderTokenInput
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
sociallogin = self.input.cleaned_data["sociallogin"]
|
||||
try:
|
||||
complete_token_login(request, sociallogin)
|
||||
except ValidationError as e:
|
||||
return ErrorResponse(self.request, exception=e)
|
||||
except SignupClosedException:
|
||||
return ForbiddenResponse(self.request)
|
||||
return AuthenticationResponse(self.request)
|
||||
@@ -0,0 +1,33 @@
|
||||
from allauth.headless.tokens.sessions import SessionTokenStrategy
|
||||
|
||||
|
||||
class DummyAccessTokenStrategy(SessionTokenStrategy):
|
||||
def create_access_token(self, request):
|
||||
return f"at-user-{request.user.pk}"
|
||||
|
||||
|
||||
def test_access_token(
|
||||
client,
|
||||
user,
|
||||
user_password,
|
||||
settings,
|
||||
headless_reverse,
|
||||
headless_client,
|
||||
):
|
||||
settings.HEADLESS_TOKEN_STRATEGY = (
|
||||
"allauth.headless.tests.test_tokens.DummyAccessTokenStrategy"
|
||||
)
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"username": user.username,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
data = resp.json()
|
||||
assert data["status"] == 200
|
||||
if headless_client == "app":
|
||||
assert data["meta"]["access_token"] == f"at-user-{user.pk}"
|
||||
else:
|
||||
assert "access_token" not in data["meta"]
|
||||
@@ -0,0 +1,60 @@
|
||||
import abc
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from django.contrib.sessions.backends.base import SessionBase
|
||||
from django.http import HttpRequest
|
||||
|
||||
|
||||
class AbstractTokenStrategy(abc.ABC):
|
||||
def get_session_token(self, request: HttpRequest) -> Optional[str]:
|
||||
"""
|
||||
Returns the session token, if any.
|
||||
"""
|
||||
token = request.headers.get("x-session-token")
|
||||
return token
|
||||
|
||||
def create_access_token_payload(
|
||||
self, request: HttpRequest
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
After authenticating, this method is called to create the access
|
||||
token response payload, exposing the access token and possibly other
|
||||
information such as a ``refresh_token`` and ``expires_in``.
|
||||
"""
|
||||
at = self.create_access_token(request)
|
||||
if not at:
|
||||
return None
|
||||
return {"access_token": at}
|
||||
|
||||
def create_access_token(self, request: HttpRequest) -> Optional[str]:
|
||||
"""Create an access token.
|
||||
|
||||
While session tokens are required to handle the authentication process,
|
||||
depending on your requirements, a different type of token may be needed
|
||||
once authenticated.
|
||||
|
||||
For example, your app likely needs access to other APIs as well. These
|
||||
APIs may even be implemented using different technologies, in which case
|
||||
having a stateless token, possibly a JWT encoding the user ID, might be
|
||||
a good fit.
|
||||
|
||||
We make no assumptions in this regard. If you need access tokens, you
|
||||
will have to implement a token strategy that returns an access token
|
||||
here.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_session_token(self, request: HttpRequest) -> str:
|
||||
"""
|
||||
Create a session token for the `request.session`.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def lookup_session(self, session_token: str) -> Optional[SessionBase]:
|
||||
"""
|
||||
Looks up the Django session given the session token. Returns `None`
|
||||
if the session does not / no longer exist.
|
||||
"""
|
||||
...
|
||||
@@ -0,0 +1,22 @@
|
||||
import typing
|
||||
|
||||
from django.contrib.sessions.backends.base import SessionBase
|
||||
from django.http import HttpRequest
|
||||
|
||||
from allauth.headless.internal import sessionkit
|
||||
from allauth.headless.tokens.base import AbstractTokenStrategy
|
||||
|
||||
|
||||
class SessionTokenStrategy(AbstractTokenStrategy):
|
||||
def create_session_token(self, request: HttpRequest) -> str:
|
||||
if not request.session.session_key:
|
||||
request.session.save()
|
||||
key = request.session.session_key
|
||||
assert isinstance(key, str) # We did save.
|
||||
return key
|
||||
|
||||
def lookup_session(self, session_token: str) -> typing.Optional[SessionBase]:
|
||||
session_key = session_token
|
||||
if sessionkit.session_store().exists(session_key):
|
||||
return sessionkit.session_store(session_key)
|
||||
return None
|
||||
76
.venv/lib/python3.12/site-packages/allauth/headless/urls.py
Normal file
76
.venv/lib/python3.12/site-packages/allauth/headless/urls.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from django.urls import include, path
|
||||
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.headless.account import urls as account_urls
|
||||
from allauth.headless.base import urls as base_urls
|
||||
from allauth.headless.constants import Client
|
||||
|
||||
|
||||
def build_urlpatterns(client):
|
||||
patterns = []
|
||||
patterns.extend(base_urls.build_urlpatterns(client))
|
||||
patterns.append(
|
||||
path(
|
||||
"",
|
||||
include(
|
||||
(account_urls.build_urlpatterns(client), "headless"),
|
||||
namespace="account",
|
||||
),
|
||||
)
|
||||
)
|
||||
if allauth_settings.SOCIALACCOUNT_ENABLED:
|
||||
from allauth.headless.socialaccount import urls as socialaccount_urls
|
||||
|
||||
patterns.append(
|
||||
path(
|
||||
"",
|
||||
include(
|
||||
(socialaccount_urls.build_urlpatterns(client), "headless"),
|
||||
namespace="socialaccount",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if allauth_settings.MFA_ENABLED:
|
||||
from allauth.headless.mfa import urls as mfa_urls
|
||||
|
||||
patterns.append(
|
||||
path(
|
||||
"",
|
||||
include(
|
||||
(mfa_urls.build_urlpatterns(client), "headless"),
|
||||
namespace="mfa",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if allauth_settings.USERSESSIONS_ENABLED:
|
||||
from allauth.headless.usersessions import urls as usersessions_urls
|
||||
|
||||
patterns.append(
|
||||
path(
|
||||
"",
|
||||
include(
|
||||
(usersessions_urls.build_urlpatterns(client), "headless"),
|
||||
namespace="usersessions",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return [path("v1/", include(patterns))]
|
||||
|
||||
|
||||
app_name = "headless"
|
||||
urlpatterns = [
|
||||
path(
|
||||
"browser/",
|
||||
include(
|
||||
(build_urlpatterns(Client.BROWSER), "headless"),
|
||||
namespace="browser",
|
||||
),
|
||||
),
|
||||
path(
|
||||
"app/",
|
||||
include((build_urlpatterns(Client.APP), "headless"), namespace="app"),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,11 @@
|
||||
from allauth.headless.internal.restkit import inputs
|
||||
from allauth.usersessions.models import UserSession
|
||||
|
||||
|
||||
class SelectSessionsInput(inputs.Input):
|
||||
sessions = inputs.ModelMultipleChoiceField(queryset=UserSession.objects.none())
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.user = kwargs.pop("user")
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields["sessions"].queryset = UserSession.objects.filter(user=self.user)
|
||||
@@ -0,0 +1,24 @@
|
||||
from allauth.headless.base.response import APIResponse
|
||||
from allauth.usersessions import app_settings
|
||||
|
||||
|
||||
class SessionsResponse(APIResponse):
|
||||
def __init__(self, request, sessions):
|
||||
super().__init__(request, data=[self._session_data(s) for s in sessions])
|
||||
|
||||
def _session_data(self, session):
|
||||
data = {
|
||||
"user_agent": session.user_agent,
|
||||
"ip": session.ip,
|
||||
"created_at": session.created_at.timestamp(),
|
||||
"is_current": session.is_current(),
|
||||
"id": session.pk,
|
||||
}
|
||||
if app_settings.TRACK_ACTIVITY:
|
||||
data["last_seen_at"] = session.last_seen_at.timestamp()
|
||||
return data
|
||||
|
||||
|
||||
def get_config_data(request):
|
||||
data = {"usersessions": {"track_activity": app_settings.TRACK_ACTIVITY}}
|
||||
return data
|
||||
@@ -0,0 +1,27 @@
|
||||
from allauth.usersessions.models import UserSession
|
||||
|
||||
|
||||
def test_flow(client, user, user_password, headless_reverse, settings):
|
||||
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
|
||||
resp = client.post(
|
||||
headless_reverse("headless:account:login"),
|
||||
data={
|
||||
"email": user.email,
|
||||
"password": user_password,
|
||||
},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
resp = client.get(headless_reverse("headless:usersessions:sessions"))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["data"]) == 1
|
||||
session_pk = data["data"][0]["id"]
|
||||
assert UserSession.objects.filter(pk=session_pk).exists()
|
||||
resp = client.delete(
|
||||
headless_reverse("headless:usersessions:sessions"),
|
||||
data={"sessions": [session_pk]},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert not UserSession.objects.filter(pk=session_pk).exists()
|
||||
@@ -0,0 +1,20 @@
|
||||
from django.urls import include, path
|
||||
|
||||
from allauth.headless.usersessions import views
|
||||
|
||||
|
||||
def build_urlpatterns(client):
|
||||
return [
|
||||
path(
|
||||
"auth/",
|
||||
include(
|
||||
[
|
||||
path(
|
||||
"sessions",
|
||||
views.SessionsView.as_api_view(client=client),
|
||||
name="sessions",
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
]
|
||||
@@ -0,0 +1,27 @@
|
||||
from allauth.headless.base.response import AuthenticationResponse
|
||||
from allauth.headless.base.views import AuthenticatedAPIView
|
||||
from allauth.headless.usersessions.inputs import SelectSessionsInput
|
||||
from allauth.headless.usersessions.response import SessionsResponse
|
||||
from allauth.usersessions.internal import flows
|
||||
from allauth.usersessions.models import UserSession
|
||||
|
||||
|
||||
class SessionsView(AuthenticatedAPIView):
|
||||
input_class = {"DELETE": SelectSessionsInput}
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
sessions = self.input.cleaned_data["sessions"]
|
||||
flows.sessions.end_sessions(request, sessions)
|
||||
if self.request.user.is_authenticated:
|
||||
return self._respond_session_list()
|
||||
return AuthenticationResponse(request)
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self._respond_session_list()
|
||||
|
||||
def _respond_session_list(self):
|
||||
sessions = UserSession.objects.purge_and_list(self.request.user)
|
||||
return SessionsResponse(self.request, sessions)
|
||||
|
||||
def get_input_kwargs(self):
|
||||
return {"user": self.request.user}
|
||||
Reference in New Issue
Block a user