mirror of
https://github.com/pacnpal/thrillwiki_django_no_react.git
synced 2025-12-24 17:31:09 -05:00
okay fine
This commit is contained in:
@@ -0,0 +1,378 @@
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
from django.core.exceptions import (
|
||||
ImproperlyConfigured,
|
||||
MultipleObjectsReturned,
|
||||
)
|
||||
from django.db.models import Q
|
||||
from django.urls import reverse
|
||||
from django.utils.crypto import get_random_string
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.utils import user_email, user_field, user_username
|
||||
from allauth.core.internal.adapter import BaseAdapter
|
||||
from allauth.utils import (
|
||||
deserialize_instance,
|
||||
import_attribute,
|
||||
serialize_instance,
|
||||
valid_email_or_none,
|
||||
)
|
||||
|
||||
from . import app_settings
|
||||
|
||||
|
||||
class DefaultSocialAccountAdapter(BaseAdapter):
|
||||
"""The adapter class allows you to override various functionality of the
|
||||
``allauth.socialaccount`` app. To do so, point ``settings.SOCIALACCOUNT_ADAPTER`` to
|
||||
your own class that derives from ``DefaultSocialAccountAdapter`` and override the
|
||||
behavior by altering the implementation of the methods according to your own
|
||||
needs.
|
||||
"""
|
||||
|
||||
error_messages = {
|
||||
"email_taken": _(
|
||||
"An account already exists with this email address."
|
||||
" Please sign in to that account first, then connect"
|
||||
" your %s account."
|
||||
),
|
||||
"invalid_token": _("Invalid token."),
|
||||
"no_password": _("Your account has no password set up."),
|
||||
"no_verified_email": _("Your account has no verified email address."),
|
||||
"disconnect_last": _(
|
||||
"You cannot disconnect your last remaining third-party account."
|
||||
),
|
||||
"connected_other": _(
|
||||
"The third-party account is already connected to a different account."
|
||||
),
|
||||
}
|
||||
|
||||
def pre_social_login(self, request, sociallogin):
|
||||
"""
|
||||
Invoked just after a user successfully authenticates via a
|
||||
social provider, but before the login is actually processed
|
||||
(and before the pre_social_login signal is emitted).
|
||||
|
||||
You can use this hook to intervene, e.g. abort the login by
|
||||
raising an ImmediateHttpResponse
|
||||
|
||||
Why both an adapter hook and the signal? Intervening in
|
||||
e.g. the flow from within a signal handler is bad -- multiple
|
||||
handlers may be active and are executed in undetermined order.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_authentication_error(
|
||||
self,
|
||||
request,
|
||||
provider,
|
||||
error=None,
|
||||
exception=None,
|
||||
extra_context=None,
|
||||
):
|
||||
"""
|
||||
Invoked when there is an error in the authentication cycle. In this
|
||||
case, pre_social_login will not be reached.
|
||||
|
||||
You can use this hook to intervene, e.g. redirect to an
|
||||
educational flow by raising an ImmediateHttpResponse.
|
||||
"""
|
||||
if hasattr(self, "authentication_error"):
|
||||
warnings.warn(
|
||||
"adapter.authentication_error() is deprecated, use adapter.on_authentication_error()"
|
||||
)
|
||||
|
||||
self.authentication_error(
|
||||
request,
|
||||
provider.id,
|
||||
error=error,
|
||||
exception=exception,
|
||||
extra_context=extra_context,
|
||||
)
|
||||
|
||||
def new_user(self, request, sociallogin):
|
||||
"""
|
||||
Instantiates a new User instance.
|
||||
"""
|
||||
return get_account_adapter().new_user(request)
|
||||
|
||||
def save_user(self, request, sociallogin, form=None):
|
||||
"""
|
||||
Saves a newly signed up social login. In case of auto-signup,
|
||||
the signup form is not available.
|
||||
"""
|
||||
u = sociallogin.user
|
||||
u.set_unusable_password()
|
||||
if form:
|
||||
get_account_adapter().save_user(request, u, form)
|
||||
else:
|
||||
get_account_adapter().populate_username(request, u)
|
||||
sociallogin.save(request)
|
||||
return u
|
||||
|
||||
def populate_user(self, request, sociallogin, data):
|
||||
"""
|
||||
Hook that can be used to further populate the user instance.
|
||||
|
||||
For convenience, we populate several common fields.
|
||||
|
||||
Note that the user instance being populated represents a
|
||||
suggested User instance that represents the social user that is
|
||||
in the process of being logged in.
|
||||
|
||||
The User instance need not be completely valid and conflict
|
||||
free. For example, verifying whether or not the username
|
||||
already exists, is not a responsibility.
|
||||
"""
|
||||
username = data.get("username")
|
||||
first_name = data.get("first_name")
|
||||
last_name = data.get("last_name")
|
||||
email = data.get("email")
|
||||
name = data.get("name")
|
||||
user = sociallogin.user
|
||||
user_username(user, username or "")
|
||||
user_email(user, valid_email_or_none(email) or "")
|
||||
name_parts = (name or "").partition(" ")
|
||||
user_field(user, "first_name", first_name or name_parts[0])
|
||||
user_field(user, "last_name", last_name or name_parts[2])
|
||||
return user
|
||||
|
||||
def get_connect_redirect_url(self, request, socialaccount):
|
||||
"""
|
||||
Returns the default URL to redirect to after successfully
|
||||
connecting a social account.
|
||||
"""
|
||||
url = reverse("socialaccount_connections")
|
||||
return url
|
||||
|
||||
def validate_disconnect(self, account, accounts) -> None:
|
||||
"""
|
||||
Validate whether or not the socialaccount account can be
|
||||
safely disconnected.
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_auto_signup_allowed(self, request, sociallogin):
|
||||
# If email is specified, check for duplicate and if so, no auto signup.
|
||||
auto_signup = app_settings.AUTO_SIGNUP
|
||||
return auto_signup
|
||||
|
||||
def is_open_for_signup(self, request, sociallogin):
|
||||
"""
|
||||
Checks whether or not the site is open for signups.
|
||||
|
||||
Next to simply returning True/False you can also intervene the
|
||||
regular flow by raising an ImmediateHttpResponse
|
||||
"""
|
||||
return get_account_adapter(request).is_open_for_signup(request)
|
||||
|
||||
def get_signup_form_initial_data(self, sociallogin):
|
||||
user = sociallogin.user
|
||||
initial = {
|
||||
"email": user_email(user) or "",
|
||||
"username": user_username(user) or "",
|
||||
"first_name": user_field(user, "first_name") or "",
|
||||
"last_name": user_field(user, "last_name") or "",
|
||||
}
|
||||
return initial
|
||||
|
||||
def deserialize_instance(self, model, data):
|
||||
return deserialize_instance(model, data)
|
||||
|
||||
def serialize_instance(self, instance):
|
||||
return serialize_instance(instance)
|
||||
|
||||
def list_providers(self, request):
|
||||
from allauth.socialaccount.providers import registry
|
||||
|
||||
ret = []
|
||||
provider_classes = registry.get_class_list()
|
||||
apps = self.list_apps(request)
|
||||
apps_map = {}
|
||||
for app in apps:
|
||||
apps_map.setdefault(app.provider, []).append(app)
|
||||
for provider_class in provider_classes:
|
||||
provider_apps = apps_map.get(provider_class.id, [])
|
||||
if not provider_apps:
|
||||
if provider_class.uses_apps:
|
||||
continue
|
||||
provider_apps = [None]
|
||||
for app in provider_apps:
|
||||
provider = provider_class(request=request, app=app)
|
||||
ret.append(provider)
|
||||
return ret
|
||||
|
||||
def get_provider(self, request, provider, client_id=None):
|
||||
"""Looks up a `provider`, supporting subproviders by looking up by
|
||||
`provider_id`.
|
||||
"""
|
||||
from allauth.socialaccount.providers import registry
|
||||
|
||||
provider_class = registry.get_class(provider)
|
||||
if provider_class is None or provider_class.uses_apps:
|
||||
app = self.get_app(request, provider=provider, client_id=client_id)
|
||||
if not provider_class:
|
||||
# In this case, the `provider` argument passed was a
|
||||
# `provider_id`.
|
||||
provider_class = registry.get_class(app.provider)
|
||||
if not provider_class:
|
||||
raise ImproperlyConfigured(f"unknown provider: {app.provider}")
|
||||
return provider_class(request, app=app)
|
||||
elif provider_class:
|
||||
assert not provider_class.uses_apps
|
||||
return provider_class(request, app=None)
|
||||
else:
|
||||
raise ImproperlyConfigured(f"unknown provider: {app.provider}")
|
||||
|
||||
def list_apps(self, request, provider=None, client_id=None):
|
||||
"""SocialApp's can be setup in the database, or, via
|
||||
`settings.SOCIALACCOUNT_PROVIDERS`. This methods returns a uniform list
|
||||
of all known apps matching the specified criteria, and blends both
|
||||
(db/settings) sources of data.
|
||||
"""
|
||||
# NOTE: Avoid loading models at top due to registry boot...
|
||||
from allauth.socialaccount.models import SocialApp
|
||||
|
||||
# Map provider to the list of apps.
|
||||
provider_to_apps = {}
|
||||
|
||||
# First, populate it with the DB backed apps.
|
||||
if request:
|
||||
db_apps = SocialApp.objects.on_site(request)
|
||||
else:
|
||||
db_apps = SocialApp.objects.all()
|
||||
if provider:
|
||||
db_apps = db_apps.filter(Q(provider=provider) | Q(provider_id=provider))
|
||||
if client_id:
|
||||
db_apps = db_apps.filter(client_id=client_id)
|
||||
for app in db_apps:
|
||||
apps = provider_to_apps.setdefault(app.provider, [])
|
||||
apps.append(app)
|
||||
|
||||
# Then, extend it with the settings backed apps.
|
||||
for p, pcfg in app_settings.PROVIDERS.items():
|
||||
app_configs = pcfg.get("APPS")
|
||||
if app_configs is None:
|
||||
app_config = pcfg.get("APP")
|
||||
if app_config is None:
|
||||
continue
|
||||
app_configs = [app_config]
|
||||
|
||||
apps = provider_to_apps.setdefault(p, [])
|
||||
for config in app_configs:
|
||||
app = SocialApp(provider=p)
|
||||
for field in [
|
||||
"name",
|
||||
"provider_id",
|
||||
"client_id",
|
||||
"secret",
|
||||
"key",
|
||||
"settings",
|
||||
]:
|
||||
if field in config:
|
||||
setattr(app, field, config[field])
|
||||
if "certificate_key" in config:
|
||||
warnings.warn("'certificate_key' should be moved into app.settings")
|
||||
app.settings["certificate_key"] = config["certificate_key"]
|
||||
if client_id and app.client_id != client_id:
|
||||
continue
|
||||
if (
|
||||
provider
|
||||
and app.provider_id != provider
|
||||
and app.provider != provider
|
||||
):
|
||||
continue
|
||||
apps.append(app)
|
||||
|
||||
# Flatten the list of apps.
|
||||
apps = []
|
||||
for provider_apps in provider_to_apps.values():
|
||||
apps.extend(provider_apps)
|
||||
return apps
|
||||
|
||||
def get_app(self, request, provider, client_id=None):
|
||||
from allauth.socialaccount.models import SocialApp
|
||||
|
||||
apps = self.list_apps(request, provider=provider, client_id=client_id)
|
||||
if len(apps) > 1:
|
||||
visible_apps = [app for app in apps if not app.settings.get("hidden")]
|
||||
if len(visible_apps) != 1:
|
||||
raise MultipleObjectsReturned
|
||||
apps = visible_apps
|
||||
elif len(apps) == 0:
|
||||
raise SocialApp.DoesNotExist()
|
||||
return apps[0]
|
||||
|
||||
def send_notification_mail(self, *args, **kwargs):
|
||||
return get_account_adapter().send_notification_mail(*args, **kwargs)
|
||||
|
||||
def get_requests_session(self):
|
||||
import requests
|
||||
|
||||
session = requests.Session()
|
||||
session.request = functools.partial(
|
||||
session.request, timeout=app_settings.REQUESTS_TIMEOUT
|
||||
)
|
||||
return session
|
||||
|
||||
def is_email_verified(self, provider, email):
|
||||
"""
|
||||
Returns ``True`` iff the given email encountered during a social
|
||||
login for the given provider is to be assumed verified.
|
||||
|
||||
This can be configured with a ``"verified_email"`` key in the provider
|
||||
app settings, or a ``"VERIFIED_EMAIL"`` in the global provider settings
|
||||
(``SOCIALACCOUNT_PROVIDERS``). Both can be set to ``False`` or
|
||||
``True``, or, a list of domains to match email addresses against.
|
||||
"""
|
||||
verified_email = None
|
||||
if provider.app:
|
||||
verified_email = provider.app.settings.get("verified_email")
|
||||
if verified_email is None:
|
||||
settings = provider.get_settings()
|
||||
verified_email = settings.get("VERIFIED_EMAIL", False)
|
||||
if isinstance(verified_email, bool):
|
||||
pass
|
||||
elif isinstance(verified_email, list):
|
||||
email_domain = email.partition("@")[2].lower()
|
||||
verified_domains = [d.lower() for d in verified_email]
|
||||
verified_email = email_domain in verified_domains
|
||||
else:
|
||||
raise ImproperlyConfigured("verified_email wrongly configured")
|
||||
return verified_email
|
||||
|
||||
def can_authenticate_by_email(self, login, email):
|
||||
"""
|
||||
Returns ``True`` iff authentication by email is active for this login/email.
|
||||
|
||||
This can be configured with a ``"email_authentication"`` key in the provider
|
||||
app settings, or a ``"VERIFIED_EMAIL"`` in the global provider settings
|
||||
(``SOCIALACCOUNT_PROVIDERS``).
|
||||
"""
|
||||
ret = None
|
||||
provider = login.account.get_provider()
|
||||
if provider.app:
|
||||
ret = provider.app.settings.get("email_authentication")
|
||||
if ret is None:
|
||||
ret = app_settings.EMAIL_AUTHENTICATION or provider.get_settings().get(
|
||||
"EMAIL_AUTHENTICATION", False
|
||||
)
|
||||
return ret
|
||||
|
||||
def generate_state_param(self, state: dict) -> str:
|
||||
"""
|
||||
To preserve certain state before the handshake with the provider
|
||||
takes place, and be able to verify/use that state later on, a `state`
|
||||
parameter is typically passed to the provider. By default, a random
|
||||
string sufficies as the state parameter value is actually just a
|
||||
reference/pointer to the actual state. You can use this adapter method
|
||||
to alter the generation of the `state` parameter.
|
||||
"""
|
||||
from allauth.socialaccount.internal.statekit import STATE_ID_LENGTH
|
||||
|
||||
return get_random_string(STATE_ID_LENGTH)
|
||||
|
||||
|
||||
def get_adapter(request=None):
|
||||
return import_attribute(app_settings.ADAPTER)(request)
|
||||
@@ -0,0 +1,69 @@
|
||||
from typing import List
|
||||
|
||||
from django import forms
|
||||
from django.contrib import admin
|
||||
|
||||
from allauth import app_settings
|
||||
from allauth.account.adapter import get_adapter
|
||||
from allauth.socialaccount import providers
|
||||
from allauth.socialaccount.models import SocialAccount, SocialApp, SocialToken
|
||||
|
||||
|
||||
class SocialAppForm(forms.ModelForm):
|
||||
class Meta:
|
||||
model = SocialApp
|
||||
exclude: List[str] = []
|
||||
widgets = {
|
||||
"client_id": forms.TextInput(attrs={"size": "100"}),
|
||||
"key": forms.TextInput(attrs={"size": "100"}),
|
||||
"secret": forms.TextInput(attrs={"size": "100"}),
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields["provider"] = forms.ChoiceField(
|
||||
choices=providers.registry.as_choices()
|
||||
)
|
||||
|
||||
|
||||
class SocialAppAdmin(admin.ModelAdmin):
|
||||
form = SocialAppForm
|
||||
list_display = (
|
||||
"name",
|
||||
"provider",
|
||||
)
|
||||
filter_horizontal = ("sites",) if app_settings.SITES_ENABLED else ()
|
||||
|
||||
|
||||
class SocialAccountAdmin(admin.ModelAdmin):
|
||||
search_fields = []
|
||||
raw_id_fields = ("user",)
|
||||
list_display = ("user", "uid", "provider")
|
||||
list_filter = ("provider",)
|
||||
|
||||
def get_search_fields(self, request):
|
||||
base_fields = get_adapter().get_user_search_fields()
|
||||
return list(map(lambda a: "user__" + a, base_fields))
|
||||
|
||||
|
||||
class SocialTokenAdmin(admin.ModelAdmin):
|
||||
raw_id_fields = (
|
||||
"app",
|
||||
"account",
|
||||
)
|
||||
list_display = ("app", "account", "truncated_token", "expires_at")
|
||||
list_filter = ("app", "app__provider", "expires_at")
|
||||
|
||||
def truncated_token(self, token):
|
||||
max_chars = 40
|
||||
ret = token.token
|
||||
if len(ret) > max_chars:
|
||||
ret = ret[0:max_chars] + "...(truncated)"
|
||||
return ret
|
||||
|
||||
truncated_token.short_description = "Token" # type: ignore[attr-defined]
|
||||
|
||||
|
||||
admin.site.register(SocialApp, SocialAppAdmin)
|
||||
admin.site.register(SocialToken, SocialTokenAdmin)
|
||||
admin.site.register(SocialAccount, SocialAccountAdmin)
|
||||
@@ -0,0 +1,155 @@
|
||||
class AppSettings:
|
||||
def __init__(self, prefix):
|
||||
self.prefix = prefix
|
||||
|
||||
def _setting(self, name, dflt):
|
||||
from allauth.utils import get_setting
|
||||
|
||||
return get_setting(self.prefix + name, dflt)
|
||||
|
||||
@property
|
||||
def QUERY_EMAIL(self):
|
||||
"""
|
||||
Request email address from 3rd party account provider?
|
||||
E.g. using OpenID AX
|
||||
"""
|
||||
from allauth.account import app_settings as account_settings
|
||||
|
||||
return self._setting("QUERY_EMAIL", account_settings.EMAIL_REQUIRED)
|
||||
|
||||
@property
|
||||
def AUTO_SIGNUP(self):
|
||||
"""
|
||||
Attempt to bypass the signup form by using fields (e.g. username,
|
||||
email) retrieved from the social account provider. If a conflict
|
||||
arises due to a duplicate email signup form will still kick in.
|
||||
"""
|
||||
return self._setting("AUTO_SIGNUP", True)
|
||||
|
||||
@property
|
||||
def PROVIDERS(self):
|
||||
"""
|
||||
Provider specific settings
|
||||
"""
|
||||
ret = self._setting("PROVIDERS", {})
|
||||
oidc = ret.get("openid_connect")
|
||||
if oidc:
|
||||
ret["openid_connect"] = self._migrate_oidc(oidc)
|
||||
return ret
|
||||
|
||||
def _migrate_oidc(self, oidc):
|
||||
servers = oidc.get("SERVERS")
|
||||
if servers is None:
|
||||
return oidc
|
||||
ret = {}
|
||||
apps = []
|
||||
for server in servers:
|
||||
app = dict(**server["APP"])
|
||||
app_settings = {}
|
||||
if "token_auth_method" in server:
|
||||
app_settings["token_auth_method"] = server["token_auth_method"]
|
||||
app_settings["server_url"] = server["server_url"]
|
||||
app.update(
|
||||
{
|
||||
"name": server.get("name", ""),
|
||||
"provider_id": server["id"],
|
||||
"settings": app_settings,
|
||||
}
|
||||
)
|
||||
assert app["provider_id"]
|
||||
apps.append(app)
|
||||
ret["APPS"] = apps
|
||||
return ret
|
||||
|
||||
@property
|
||||
def EMAIL_REQUIRED(self):
|
||||
"""
|
||||
The user is required to hand over an email address when signing up
|
||||
"""
|
||||
from allauth.account import app_settings as account_settings
|
||||
|
||||
return self._setting("EMAIL_REQUIRED", account_settings.EMAIL_REQUIRED)
|
||||
|
||||
@property
|
||||
def EMAIL_VERIFICATION(self):
|
||||
"""
|
||||
See email verification method. When `None`, the default
|
||||
`allauth.account` logic kicks in.
|
||||
"""
|
||||
return self._setting("EMAIL_VERIFICATION", None)
|
||||
|
||||
@property
|
||||
def EMAIL_AUTHENTICATION(self):
|
||||
"""Consider a scenario where a social login occurs, and the social
|
||||
account comes with a verified email address (verified by the account
|
||||
provider), but that email address is already taken by a local user
|
||||
account. Additionally, assume that the local user account does not have
|
||||
any social account connected. Now, if the provider can be fully trusted,
|
||||
you can argue that we should treat this scenario as a login to the
|
||||
existing local user account even if the local account does not already
|
||||
have the social account connected, because -- according to the provider
|
||||
-- the user logging in has ownership of the email address. This is how
|
||||
this scenario is handled when `EMAIL_AUTHENTICATION` is set to
|
||||
`True`. As this implies that an untrustworthy provider can login to any
|
||||
local account by fabricating social account data, this setting defaults
|
||||
to `False`. Only set it to `True` if you are using providers that can be
|
||||
fully trusted.
|
||||
"""
|
||||
return self._setting("EMAIL_AUTHENTICATION", False)
|
||||
|
||||
@property
|
||||
def EMAIL_AUTHENTICATION_AUTO_CONNECT(self):
|
||||
"""In case email authentication is applied, this setting controls
|
||||
whether or not the social account is automatically connected to the
|
||||
local account. In case of ``False`` (the default) the local account
|
||||
remains unchanged during the login. In case of ``True``, the social
|
||||
account for which the email matched, is automatically added to the list
|
||||
of social accounts connected to the local account. As a result, even if
|
||||
the user were to change the email address afterwards, social login
|
||||
would still be possible when using ``True``, but not in case of
|
||||
``False``.
|
||||
"""
|
||||
return self._setting("EMAIL_AUTHENTICATION_AUTO_CONNECT", False)
|
||||
|
||||
@property
|
||||
def ADAPTER(self):
|
||||
return self._setting(
|
||||
"ADAPTER",
|
||||
"allauth.socialaccount.adapter.DefaultSocialAccountAdapter",
|
||||
)
|
||||
|
||||
@property
|
||||
def FORMS(self):
|
||||
return self._setting("FORMS", {})
|
||||
|
||||
@property
|
||||
def LOGIN_ON_GET(self):
|
||||
return self._setting("LOGIN_ON_GET", False)
|
||||
|
||||
@property
|
||||
def STORE_TOKENS(self):
|
||||
return self._setting("STORE_TOKENS", False)
|
||||
|
||||
@property
|
||||
def UID_MAX_LENGTH(self):
|
||||
return 191
|
||||
|
||||
@property
|
||||
def SOCIALACCOUNT_STR(self):
|
||||
return self._setting("SOCIALACCOUNT_STR", None)
|
||||
|
||||
@property
|
||||
def REQUESTS_TIMEOUT(self):
|
||||
return self._setting("REQUESTS_TIMEOUT", 5)
|
||||
|
||||
@property
|
||||
def OPENID_CONNECT_URL_PREFIX(self):
|
||||
return self._setting("OPENID_CONNECT_URL_PREFIX", "oidc")
|
||||
|
||||
|
||||
_app_settings = AppSettings("SOCIALACCOUNT_")
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
# See https://peps.python.org/pep-0562/
|
||||
return getattr(_app_settings, name)
|
||||
@@ -0,0 +1,15 @@
|
||||
from django.apps import AppConfig
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from allauth import app_settings
|
||||
|
||||
|
||||
class SocialAccountConfig(AppConfig):
|
||||
name = "allauth.socialaccount"
|
||||
verbose_name = _("Social Accounts")
|
||||
default_auto_field = app_settings.DEFAULT_AUTO_FIELD or "django.db.models.AutoField"
|
||||
|
||||
def ready(self):
|
||||
from allauth.socialaccount.providers import registry
|
||||
|
||||
registry.load()
|
||||
@@ -0,0 +1,66 @@
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.socialaccount.models import (
|
||||
SocialAccount,
|
||||
SocialLogin,
|
||||
SocialToken,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sociallogin_factory(user_factory):
|
||||
def factory(
|
||||
email=None,
|
||||
username=None,
|
||||
with_email=True,
|
||||
provider="unittest-server",
|
||||
uid="123",
|
||||
email_verified=True,
|
||||
with_token=False,
|
||||
):
|
||||
user = user_factory(
|
||||
username=username, email=email, commit=False, with_email=with_email
|
||||
)
|
||||
account = SocialAccount(provider=provider, uid=uid)
|
||||
sociallogin = SocialLogin(user=user, account=account)
|
||||
if with_email:
|
||||
sociallogin.email_addresses = [
|
||||
EmailAddress(email=user.email, verified=email_verified, primary=True)
|
||||
]
|
||||
if with_token:
|
||||
sociallogin.token = SocialToken(token="123", token_secret="456")
|
||||
return sociallogin
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jwt_decode_bypass():
|
||||
@contextmanager
|
||||
def f(jwt_data):
|
||||
with patch("allauth.socialaccount.internal.jwtkit.verify_and_decode") as m:
|
||||
data = {
|
||||
"iss": "https://accounts.google.com",
|
||||
"aud": "client_id",
|
||||
"sub": "123sub",
|
||||
"hd": "example.com",
|
||||
"email": "raymond@example.com",
|
||||
"email_verified": True,
|
||||
"at_hash": "HK6E_P6Dh8Y93mRNtsDB1Q",
|
||||
"name": "Raymond Penners",
|
||||
"picture": "https://lh5.googleusercontent.com/photo.jpg",
|
||||
"given_name": "Raymond",
|
||||
"family_name": "Penners",
|
||||
"locale": "en",
|
||||
"iat": 123,
|
||||
"exp": 456,
|
||||
}
|
||||
data.update(jwt_data)
|
||||
m.return_value = data
|
||||
yield
|
||||
|
||||
return f
|
||||
@@ -0,0 +1,62 @@
|
||||
from django import forms
|
||||
|
||||
from allauth.account.forms import BaseSignupForm
|
||||
from allauth.socialaccount.internal import flows
|
||||
|
||||
from . import app_settings
|
||||
from .adapter import get_adapter
|
||||
from .models import SocialAccount
|
||||
|
||||
|
||||
class SignupForm(BaseSignupForm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.sociallogin = kwargs.pop("sociallogin")
|
||||
initial = get_adapter().get_signup_form_initial_data(self.sociallogin)
|
||||
kwargs.update(
|
||||
{
|
||||
"initial": initial,
|
||||
"email_required": kwargs.get(
|
||||
"email_required", app_settings.EMAIL_REQUIRED
|
||||
),
|
||||
}
|
||||
)
|
||||
super(SignupForm, self).__init__(*args, **kwargs)
|
||||
|
||||
def save(self, request):
|
||||
adapter = get_adapter()
|
||||
user = adapter.save_user(request, self.sociallogin, form=self)
|
||||
self.custom_signup(request, user)
|
||||
return user
|
||||
|
||||
def validate_unique_email(self, value):
|
||||
try:
|
||||
return super(SignupForm, self).validate_unique_email(value)
|
||||
except forms.ValidationError:
|
||||
raise get_adapter().validation_error(
|
||||
"email_taken", self.sociallogin.account.get_provider().name
|
||||
)
|
||||
|
||||
|
||||
class DisconnectForm(forms.Form):
|
||||
account = forms.ModelChoiceField(
|
||||
queryset=SocialAccount.objects.none(),
|
||||
widget=forms.RadioSelect,
|
||||
required=True,
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.request = kwargs.pop("request")
|
||||
self.accounts = SocialAccount.objects.filter(user=self.request.user)
|
||||
super(DisconnectForm, self).__init__(*args, **kwargs)
|
||||
self.fields["account"].queryset = self.accounts
|
||||
|
||||
def clean(self):
|
||||
cleaned_data = super(DisconnectForm, self).clean()
|
||||
account = cleaned_data.get("account")
|
||||
if account:
|
||||
flows.connect.validate_disconnect(self.request, account)
|
||||
return cleaned_data
|
||||
|
||||
def save(self):
|
||||
account = self.cleaned_data["account"]
|
||||
flows.connect.disconnect(self.request, account)
|
||||
@@ -0,0 +1,74 @@
|
||||
from django.http import HttpResponseRedirect
|
||||
from django.shortcuts import render
|
||||
from django.urls import reverse
|
||||
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.account import app_settings as account_settings
|
||||
from allauth.account.utils import user_display
|
||||
from allauth.core.exceptions import ImmediateHttpResponse
|
||||
from allauth.socialaccount import app_settings
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.internal import flows
|
||||
from allauth.socialaccount.providers.base import AuthError
|
||||
|
||||
|
||||
def render_authentication_error(
|
||||
request,
|
||||
provider,
|
||||
error=AuthError.UNKNOWN,
|
||||
exception=None,
|
||||
extra_context=None,
|
||||
):
|
||||
try:
|
||||
if allauth_settings.HEADLESS_ENABLED:
|
||||
from allauth.headless.socialaccount import internal
|
||||
|
||||
internal.on_authentication_error(
|
||||
request,
|
||||
provider=provider,
|
||||
error=error,
|
||||
exception=exception,
|
||||
extra_context=extra_context,
|
||||
)
|
||||
|
||||
if extra_context is None:
|
||||
extra_context = {}
|
||||
get_adapter().on_authentication_error(
|
||||
request,
|
||||
provider,
|
||||
error=error,
|
||||
exception=exception,
|
||||
extra_context=extra_context,
|
||||
)
|
||||
except ImmediateHttpResponse as e:
|
||||
return e.response
|
||||
if error == AuthError.CANCELLED:
|
||||
return HttpResponseRedirect(reverse("socialaccount_login_cancelled"))
|
||||
context = {
|
||||
"auth_error": {
|
||||
"provider": provider,
|
||||
"code": error,
|
||||
"exception": exception,
|
||||
}
|
||||
}
|
||||
context.update(extra_context)
|
||||
return render(
|
||||
request,
|
||||
"socialaccount/authentication_error." + account_settings.TEMPLATE_EXTENSION,
|
||||
context,
|
||||
)
|
||||
|
||||
|
||||
def complete_social_login(request, sociallogin):
|
||||
if sociallogin.is_headless:
|
||||
from allauth.headless.socialaccount import internal
|
||||
|
||||
return internal.complete_login(request, sociallogin)
|
||||
return flows.login.complete_login(request, sociallogin)
|
||||
|
||||
|
||||
def socialaccount_user_display(socialaccount):
|
||||
func = app_settings.SOCIALACCOUNT_STR
|
||||
if not func:
|
||||
return user_display(socialaccount.user)
|
||||
return func(socialaccount)
|
||||
@@ -0,0 +1,9 @@
|
||||
from allauth.socialaccount.internal.flows import (
|
||||
connect,
|
||||
email_authentication,
|
||||
login,
|
||||
signup,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["connect", "login", "signup", "email_authentication"]
|
||||
@@ -0,0 +1,130 @@
|
||||
from django.contrib import messages
|
||||
from django.core.exceptions import PermissionDenied, ValidationError
|
||||
from django.http import HttpResponseRedirect
|
||||
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.account import app_settings as account_settings
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.internal import flows
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.core.exceptions import ReauthenticationRequired
|
||||
from allauth.socialaccount import signals
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.models import SocialAccount, SocialLogin
|
||||
|
||||
|
||||
def validate_disconnect(request, account):
|
||||
"""
|
||||
Validate whether or not the socialaccount account can be
|
||||
safely disconnected.
|
||||
"""
|
||||
accounts = SocialAccount.objects.filter(user_id=account.user_id)
|
||||
is_last = not accounts.exclude(pk=account.pk).exists()
|
||||
adapter = get_adapter()
|
||||
if is_last:
|
||||
if allauth_settings.SOCIALACCOUNT_ONLY:
|
||||
raise adapter.validation_error("disconnect_last")
|
||||
# No usable password would render the local account unusable
|
||||
if not account.user.has_usable_password():
|
||||
raise adapter.validation_error("no_password")
|
||||
# No email address, no password reset
|
||||
if (
|
||||
account_settings.EMAIL_VERIFICATION
|
||||
== account_settings.EmailVerificationMethod.MANDATORY
|
||||
):
|
||||
if not EmailAddress.objects.filter(
|
||||
user=account.user, verified=True
|
||||
).exists():
|
||||
raise adapter.validation_error("no_verified_email")
|
||||
adapter.validate_disconnect(account, accounts)
|
||||
|
||||
|
||||
def disconnect(request, account):
|
||||
if account_settings.REAUTHENTICATION_REQUIRED:
|
||||
flows.reauthentication.raise_if_reauthentication_required(request)
|
||||
|
||||
get_account_adapter().add_message(
|
||||
request,
|
||||
messages.INFO,
|
||||
"socialaccount/messages/account_disconnected.txt",
|
||||
)
|
||||
provider = account.get_provider()
|
||||
account.delete()
|
||||
signals.social_account_removed.send(
|
||||
sender=SocialAccount, request=request, socialaccount=account
|
||||
)
|
||||
get_adapter().send_notification_mail(
|
||||
"socialaccount/email/account_disconnected",
|
||||
request.user,
|
||||
context={
|
||||
"account": account,
|
||||
"provider": provider,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def resume_connect(request, serialized_state):
|
||||
sociallogin = SocialLogin.deserialize(serialized_state)
|
||||
return connect(request, sociallogin)
|
||||
|
||||
|
||||
def connect(request, sociallogin):
|
||||
try:
|
||||
ok, action, message = do_connect(request, sociallogin)
|
||||
except PermissionDenied:
|
||||
# This should not happen. Simply redirect to the connections
|
||||
# view (which has a login required)
|
||||
connect_redirect_url = get_adapter().get_connect_redirect_url(
|
||||
request, sociallogin.account
|
||||
)
|
||||
return HttpResponseRedirect(connect_redirect_url)
|
||||
except ReauthenticationRequired:
|
||||
return flows.reauthentication.stash_and_reauthenticate(
|
||||
request,
|
||||
sociallogin.serialize(),
|
||||
"allauth.socialaccount.internal.flows.connect.resume_connect",
|
||||
)
|
||||
except ValidationError:
|
||||
ok, action, message = (
|
||||
False,
|
||||
None,
|
||||
"socialaccount/messages/account_connected_other.txt",
|
||||
)
|
||||
level = messages.INFO if ok else messages.ERROR
|
||||
default_next = get_adapter().get_connect_redirect_url(request, sociallogin.account)
|
||||
next_url = sociallogin.get_redirect_url(request) or default_next
|
||||
get_account_adapter(request).add_message(
|
||||
request,
|
||||
level,
|
||||
message,
|
||||
message_context={"sociallogin": sociallogin, "action": action},
|
||||
)
|
||||
return HttpResponseRedirect(next_url)
|
||||
|
||||
|
||||
def do_connect(request, sociallogin):
|
||||
if request.user.is_anonymous:
|
||||
raise PermissionDenied()
|
||||
if account_settings.REAUTHENTICATION_REQUIRED:
|
||||
flows.reauthentication.raise_if_reauthentication_required(request)
|
||||
message = "socialaccount/messages/account_connected.txt"
|
||||
action = None
|
||||
ok = True
|
||||
if sociallogin.is_existing:
|
||||
if sociallogin.user != request.user:
|
||||
# Social account of other user. For now, this scenario
|
||||
# is not supported. Issue is that one cannot simply
|
||||
# remove the social account from the other user, as
|
||||
# that may render the account unusable.
|
||||
raise get_adapter().validation_error("connected_other")
|
||||
elif not sociallogin.account._state.adding:
|
||||
action = "updated"
|
||||
message = "socialaccount/messages/account_connected_updated.txt"
|
||||
else:
|
||||
action = "added"
|
||||
sociallogin.connect(request, request.user)
|
||||
else:
|
||||
# New account, let's connect
|
||||
action = "added"
|
||||
sociallogin.connect(request, request.user)
|
||||
return ok, action, message
|
||||
@@ -0,0 +1,34 @@
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.account.models import EmailAddress
|
||||
|
||||
|
||||
def wipe_password(request, user, email: str):
|
||||
"""
|
||||
Consider a scenario where an attacker signs up for an account using the
|
||||
email address of a victim. Obviously, the email address cannot be
|
||||
verified, yet the attacker -- knowing the password -- can wait until the
|
||||
victim appears. When the victim signs in using email authentication, it
|
||||
is not obvious that the victim is signing into an account that was not
|
||||
created by the victim. As a result, both the attacker and the victim now
|
||||
have access to the account. To prevent this, we wipe the password of the
|
||||
account in case the email address was not verified, effectively locking
|
||||
out the attacker.
|
||||
"""
|
||||
try:
|
||||
address = EmailAddress.objects.get_for_user(user, email)
|
||||
except EmailAddress.DoesNotExist:
|
||||
address = None
|
||||
if address and address.verified:
|
||||
# Verified email address, no reason to worry.
|
||||
return
|
||||
if user.has_usable_password():
|
||||
user.set_unusable_password()
|
||||
user.save(update_fields=["password"])
|
||||
# Also wipe any other sessions (upstream integrators may hook up to the
|
||||
# ending of the sessions to trigger e.g. backchannel logout.
|
||||
if allauth_settings.USERSESSIONS_ENABLED:
|
||||
from allauth.usersessions.internal.flows.sessions import (
|
||||
end_other_sessions,
|
||||
)
|
||||
|
||||
end_other_sessions(request, user)
|
||||
@@ -0,0 +1,97 @@
|
||||
from django.http import HttpResponseRedirect
|
||||
from django.shortcuts import render
|
||||
|
||||
from allauth.account import app_settings as account_settings
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.utils import perform_login
|
||||
from allauth.core.exceptions import (
|
||||
ImmediateHttpResponse,
|
||||
SignupClosedException,
|
||||
)
|
||||
from allauth.socialaccount import app_settings, signals
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.internal.flows.connect import connect, do_connect
|
||||
from allauth.socialaccount.internal.flows.signup import (
|
||||
clear_pending_signup,
|
||||
process_signup,
|
||||
)
|
||||
from allauth.socialaccount.models import SocialLogin
|
||||
from allauth.socialaccount.providers.base import AuthProcess
|
||||
|
||||
|
||||
def _login(request, sociallogin):
|
||||
sociallogin._accept_login(request)
|
||||
record_authentication(request, sociallogin)
|
||||
return perform_login(
|
||||
request,
|
||||
sociallogin.user,
|
||||
email_verification=app_settings.EMAIL_VERIFICATION,
|
||||
redirect_url=sociallogin.get_redirect_url(request),
|
||||
signal_kwargs={"sociallogin": sociallogin},
|
||||
)
|
||||
|
||||
|
||||
def pre_social_login(request, sociallogin):
|
||||
clear_pending_signup(request)
|
||||
assert not sociallogin.is_existing
|
||||
sociallogin.lookup()
|
||||
get_adapter().pre_social_login(request, sociallogin)
|
||||
signals.pre_social_login.send(
|
||||
sender=SocialLogin, request=request, sociallogin=sociallogin
|
||||
)
|
||||
|
||||
|
||||
def complete_login(request, sociallogin, raises=False):
|
||||
try:
|
||||
pre_social_login(request, sociallogin)
|
||||
process = sociallogin.state.get("process")
|
||||
if process == AuthProcess.REDIRECT:
|
||||
return _redirect(request, sociallogin)
|
||||
elif process == AuthProcess.CONNECT:
|
||||
if raises:
|
||||
do_connect(request, sociallogin)
|
||||
else:
|
||||
return connect(request, sociallogin)
|
||||
else:
|
||||
return _authenticate(request, sociallogin)
|
||||
except SignupClosedException:
|
||||
if raises:
|
||||
raise
|
||||
return render(
|
||||
request,
|
||||
"account/signup_closed." + account_settings.TEMPLATE_EXTENSION,
|
||||
)
|
||||
except ImmediateHttpResponse as e:
|
||||
if raises:
|
||||
raise
|
||||
return e.response
|
||||
|
||||
|
||||
def _redirect(request, sociallogin):
|
||||
next_url = sociallogin.get_redirect_url(request) or "/"
|
||||
return HttpResponseRedirect(next_url)
|
||||
|
||||
|
||||
def _authenticate(request, sociallogin):
|
||||
if request.user.is_authenticated:
|
||||
get_account_adapter(request).logout(request)
|
||||
if sociallogin.is_existing:
|
||||
# Login existing user
|
||||
ret = _login(request, sociallogin)
|
||||
else:
|
||||
# New social user
|
||||
ret = process_signup(request, sociallogin)
|
||||
return ret
|
||||
|
||||
|
||||
def record_authentication(request, sociallogin):
|
||||
from allauth.account.internal.flows.login import record_authentication
|
||||
|
||||
record_authentication(
|
||||
request,
|
||||
"socialaccount",
|
||||
**{
|
||||
"provider": sociallogin.account.provider,
|
||||
"uid": sociallogin.account.uid,
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,114 @@
|
||||
from django.forms import ValidationError
|
||||
|
||||
from allauth.account import app_settings as account_settings
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.internal.flows.manage_email import assess_unique_email
|
||||
from allauth.account.internal.flows.signup import (
|
||||
complete_signup,
|
||||
prevent_enumeration,
|
||||
)
|
||||
from allauth.account.utils import user_username
|
||||
from allauth.core.exceptions import SignupClosedException
|
||||
from allauth.core.internal.httpkit import headed_redirect_response
|
||||
from allauth.socialaccount import app_settings
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.models import SocialLogin
|
||||
|
||||
|
||||
def get_pending_signup(request):
|
||||
data = request.session.get("socialaccount_sociallogin")
|
||||
if data:
|
||||
return SocialLogin.deserialize(data)
|
||||
|
||||
|
||||
def redirect_to_signup(request, sociallogin):
|
||||
request.session["socialaccount_sociallogin"] = sociallogin.serialize()
|
||||
return headed_redirect_response("socialaccount_signup")
|
||||
|
||||
|
||||
def clear_pending_signup(request):
|
||||
request.session.pop("socialaccount_sociallogin", None)
|
||||
|
||||
|
||||
def signup_by_form(request, sociallogin, form):
|
||||
clear_pending_signup(request)
|
||||
user, resp = form.try_save(request)
|
||||
if not resp:
|
||||
resp = complete_social_signup(request, sociallogin)
|
||||
return resp
|
||||
|
||||
|
||||
def process_auto_signup(request, sociallogin):
|
||||
auto_signup = get_adapter().is_auto_signup_allowed(request, sociallogin)
|
||||
if not auto_signup:
|
||||
return False, None
|
||||
email = None
|
||||
if sociallogin.email_addresses:
|
||||
email = sociallogin.email_addresses[0].email
|
||||
# Let's check if auto_signup is really possible...
|
||||
if email:
|
||||
assessment = assess_unique_email(email)
|
||||
if assessment is True:
|
||||
# Auto signup is fine.
|
||||
pass
|
||||
elif assessment is False:
|
||||
# Oops, another user already has this address. We cannot simply
|
||||
# connect this social account to the existing user. Reason is
|
||||
# that the email address may not be verified, meaning, the user
|
||||
# may be a hacker that has added your email address to their
|
||||
# account in the hope that you fall in their trap. We cannot
|
||||
# check on 'email_address.verified' either, because
|
||||
# 'email_address' is not guaranteed to be verified.
|
||||
auto_signup = False
|
||||
# TODO: We redirect to signup form -- user will see email
|
||||
# address conflict only after posting whereas we detected it
|
||||
# here already.
|
||||
else:
|
||||
assert assessment is None
|
||||
# Prevent enumeration is properly turned on, meaning, we cannot
|
||||
# show the signup form to allow the user to input another email
|
||||
# address. Instead, we're going to send the user an email that
|
||||
# the account already exists, and on the outside make it appear
|
||||
# as if an email verification mail was sent.
|
||||
resp = prevent_enumeration(request, email)
|
||||
return False, resp
|
||||
elif app_settings.EMAIL_REQUIRED:
|
||||
# Nope, email is required and we don't have it yet...
|
||||
auto_signup = False
|
||||
return auto_signup, None
|
||||
|
||||
|
||||
def process_signup(request, sociallogin):
|
||||
if not get_adapter().is_open_for_signup(request, sociallogin):
|
||||
raise SignupClosedException()
|
||||
auto_signup, resp = process_auto_signup(request, sociallogin)
|
||||
if resp:
|
||||
return resp
|
||||
if not auto_signup:
|
||||
resp = redirect_to_signup(request, sociallogin)
|
||||
else:
|
||||
# Ok, auto signup it is, at least the email address is ok.
|
||||
# We still need to check the username though...
|
||||
if account_settings.USER_MODEL_USERNAME_FIELD:
|
||||
username = user_username(sociallogin.user)
|
||||
try:
|
||||
get_account_adapter(request).clean_username(username)
|
||||
except ValidationError:
|
||||
# This username is no good ...
|
||||
user_username(sociallogin.user, "")
|
||||
# TODO: This part contains a lot of duplication of logic
|
||||
# ("closed" rendering, create user, send email, in active
|
||||
# etc..)
|
||||
get_adapter().save_user(request, sociallogin, form=None)
|
||||
resp = complete_social_signup(request, sociallogin)
|
||||
return resp
|
||||
|
||||
|
||||
def complete_social_signup(request, sociallogin):
|
||||
return complete_signup(
|
||||
request,
|
||||
user=sociallogin.user,
|
||||
email_verification=app_settings.EMAIL_VERIFICATION,
|
||||
redirect_url=sociallogin.get_redirect_url(request),
|
||||
signal_kwargs={"sociallogin": sociallogin},
|
||||
)
|
||||
@@ -0,0 +1,105 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
from django.core.cache import cache
|
||||
|
||||
import jwt
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.x509 import load_pem_x509_certificate
|
||||
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
|
||||
|
||||
|
||||
def lookup_kid_pem_x509_certificate(keys_data, kid):
|
||||
"""
|
||||
Looks up the key given keys data of the form:
|
||||
|
||||
{"<kid>": "-----BEGIN CERTIFICATE-----\nCERTIFICATE"}
|
||||
"""
|
||||
key = keys_data.get(kid)
|
||||
if key:
|
||||
public_key = load_pem_x509_certificate(
|
||||
key.encode("utf8"), default_backend()
|
||||
).public_key()
|
||||
return public_key
|
||||
|
||||
|
||||
def lookup_kid_jwk(keys_data, kid):
|
||||
"""
|
||||
Looks up the key given keys data of the form:
|
||||
|
||||
{
|
||||
"keys": [
|
||||
{
|
||||
"kty": "RSA",
|
||||
"kid": "W6WcOKB",
|
||||
"use": "sig",
|
||||
"alg": "RS256",
|
||||
"n": "2Zc5d0-zk....",
|
||||
"e": "AQAB"
|
||||
}]
|
||||
}
|
||||
"""
|
||||
for d in keys_data["keys"]:
|
||||
if d["kid"] == kid:
|
||||
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(d))
|
||||
return public_key
|
||||
|
||||
|
||||
def fetch_key(credential, keys_url, lookup):
|
||||
header = jwt.get_unverified_header(credential)
|
||||
# {'alg': 'RS256', 'kid': '[AWS-SECRET-REMOVED]', 'typ': 'JWT'}
|
||||
kid = header["kid"]
|
||||
alg = header["alg"]
|
||||
response = get_adapter().get_requests_session().get(keys_url)
|
||||
response.raise_for_status()
|
||||
keys_data = response.json()
|
||||
key = lookup(keys_data, kid)
|
||||
if not key:
|
||||
raise OAuth2Error(f"Invalid 'kid': '{kid}'")
|
||||
return alg, key
|
||||
|
||||
|
||||
def verify_jti(data: dict) -> None:
|
||||
"""
|
||||
Put the JWT token on a blacklist to prevent replay attacks.
|
||||
"""
|
||||
iss = data.get("iss")
|
||||
exp = data.get("exp")
|
||||
jti = data.get("jti")
|
||||
if iss is None or exp is None or jti is None:
|
||||
return
|
||||
timeout = exp - time.time()
|
||||
key = f"jwt:iss={iss},jti={jti}"
|
||||
if not cache.add(key=key, value=True, timeout=timeout):
|
||||
raise OAuth2Error("token already used")
|
||||
|
||||
|
||||
def verify_and_decode(
|
||||
*, credential, keys_url, issuer, audience, lookup_kid, verify_signature=True
|
||||
):
|
||||
try:
|
||||
if verify_signature:
|
||||
alg, key = fetch_key(credential, keys_url, lookup_kid)
|
||||
algorithms = [alg]
|
||||
else:
|
||||
key = ""
|
||||
algorithms = None
|
||||
data = jwt.decode(
|
||||
credential,
|
||||
key=key,
|
||||
options={
|
||||
"verify_signature": verify_signature,
|
||||
"verify_iss": True,
|
||||
"verify_aud": True,
|
||||
"verify_exp": True,
|
||||
},
|
||||
issuer=issuer,
|
||||
audience=audience,
|
||||
algorithms=algorithms,
|
||||
)
|
||||
verify_jti(data)
|
||||
return data
|
||||
except jwt.PyJWTError as e:
|
||||
raise OAuth2Error("Invalid id_token") from e
|
||||
@@ -0,0 +1,69 @@
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
|
||||
|
||||
STATE_ID_LENGTH = 16
|
||||
MAX_STATES = 10
|
||||
STATES_SESSION_KEY = "socialaccount_states"
|
||||
|
||||
|
||||
def get_oldest_state(
|
||||
states: Dict[str, Tuple[Dict[str, Any], float]], rev: bool = False
|
||||
) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
|
||||
oldest_ts = None
|
||||
oldest_id = None
|
||||
oldest = None
|
||||
for state_id, state_ts in states.items():
|
||||
ts = state_ts[1]
|
||||
if oldest_ts is None or (
|
||||
(rev and ts > oldest_ts) or ((not rev) and oldest_ts > ts)
|
||||
):
|
||||
oldest_ts = ts
|
||||
oldest_id = state_id
|
||||
oldest = state_ts[0]
|
||||
return oldest_id, oldest
|
||||
|
||||
|
||||
def gc_states(states: Dict[str, Tuple[Dict[str, Any], float]]):
|
||||
if len(states) > MAX_STATES:
|
||||
oldest_id, oldest = get_oldest_state(states)
|
||||
if oldest_id:
|
||||
del states[oldest_id]
|
||||
|
||||
|
||||
def get_states(request) -> Dict[str, Tuple[Dict[str, Any], float]]:
|
||||
states = request.session.get(STATES_SESSION_KEY)
|
||||
if not isinstance(states, dict):
|
||||
states = {}
|
||||
return states
|
||||
|
||||
|
||||
def stash_state(request, state: Dict[str, Any], state_id: Optional[str] = None) -> str:
|
||||
states = get_states(request)
|
||||
gc_states(states)
|
||||
if state_id is None:
|
||||
state_id = get_adapter().generate_state_param(state)
|
||||
states[state_id] = (state, time.time())
|
||||
request.session[STATES_SESSION_KEY] = states
|
||||
return state_id
|
||||
|
||||
|
||||
def unstash_state(request, state_id: str) -> Optional[Dict[str, Any]]:
|
||||
state: Optional[Dict[str, Any]] = None
|
||||
states = get_states(request)
|
||||
state_ts = states.get(state_id)
|
||||
if state_ts is not None:
|
||||
state = state_ts[0]
|
||||
del states[state_id]
|
||||
request.session[STATES_SESSION_KEY] = states
|
||||
return state
|
||||
|
||||
|
||||
def unstash_last_state(request) -> Optional[Dict[str, Any]]:
|
||||
states = get_states(request)
|
||||
state_id, state = get_oldest_state(states, rev=True)
|
||||
if state_id:
|
||||
unstash_state(request, state_id)
|
||||
return state
|
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from django.utils import timezone
|
||||
|
||||
from allauth.socialaccount.internal.jwtkit import verify_and_decode
|
||||
from allauth.socialaccount.providers.apple.client import jwt_encode
|
||||
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
|
||||
|
||||
|
||||
def test_verify_and_decode(enable_cache):
|
||||
now = timezone.now()
|
||||
payload = {
|
||||
"iss": "https://accounts.google.com",
|
||||
"azp": "client_id",
|
||||
"aud": "client_id",
|
||||
"sub": "108204268033311374519",
|
||||
"hd": "example.com",
|
||||
"locale": "en",
|
||||
"iat": now,
|
||||
"jti": "[AWS-SECRET-REMOVED]",
|
||||
"exp": now + timedelta(hours=1),
|
||||
}
|
||||
id_token = jwt_encode(payload, "secret")
|
||||
for attempt in range(2):
|
||||
try:
|
||||
verify_and_decode(
|
||||
credential=id_token,
|
||||
keys_url="/",
|
||||
issuer=payload["iss"],
|
||||
audience=payload["aud"],
|
||||
lookup_kid=False,
|
||||
verify_signature=False,
|
||||
)
|
||||
assert attempt == 0
|
||||
except OAuth2Error:
|
||||
assert attempt == 1
|
||||
@@ -0,0 +1,48 @@
|
||||
from allauth.socialaccount.internal import statekit
|
||||
|
||||
|
||||
def test_get_oldest_state():
|
||||
states = {
|
||||
"new": [{"id": "new"}, 300],
|
||||
"mid": [{"id": "mid"}, 200],
|
||||
"old": [{"id": "old"}, 100],
|
||||
}
|
||||
state_id, state = statekit.get_oldest_state(states)
|
||||
assert state_id == "old"
|
||||
assert state["id"] == "old"
|
||||
|
||||
|
||||
def test_get_oldest_state_empty():
|
||||
state_id, state = statekit.get_oldest_state({})
|
||||
assert state_id is None
|
||||
assert state is None
|
||||
|
||||
|
||||
def test_gc_states():
|
||||
states = {}
|
||||
for i in range(statekit.MAX_STATES + 1):
|
||||
states[f"state-{i}"] = [{"i": i}, 1000 + i]
|
||||
assert len(states) == statekit.MAX_STATES + 1
|
||||
statekit.gc_states(states)
|
||||
assert len(states) == statekit.MAX_STATES
|
||||
assert "state-0" not in states
|
||||
|
||||
|
||||
def test_stashing(rf):
|
||||
request = rf.get("/")
|
||||
request.session = {}
|
||||
state_id = statekit.stash_state(request, {"foo": "bar"})
|
||||
state2_id = statekit.stash_state(request, {"foo2": "bar2"})
|
||||
state3_id = statekit.stash_state(request, {"foo3": "bar3"})
|
||||
state = statekit.unstash_last_state(request)
|
||||
assert state == {"foo3": "bar3"}
|
||||
state = statekit.unstash_state(request, state3_id)
|
||||
assert state is None
|
||||
state = statekit.unstash_state(request, state2_id)
|
||||
assert state == {"foo2": "bar2"}
|
||||
state = statekit.unstash_state(request, state2_id)
|
||||
assert state is None
|
||||
state = statekit.unstash_state(request, state_id)
|
||||
assert state == {"foo": "bar"}
|
||||
state = statekit.unstash_state(request, state_id)
|
||||
assert state is None
|
||||
@@ -0,0 +1,192 @@
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
from allauth import app_settings
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = (
|
||||
[
|
||||
("sites", "0001_initial"),
|
||||
]
|
||||
if app_settings.SITES_ENABLED
|
||||
else []
|
||||
) + [
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="SocialAccount",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.AutoField(
|
||||
verbose_name="ID",
|
||||
serialize=False,
|
||||
auto_created=True,
|
||||
primary_key=True,
|
||||
),
|
||||
),
|
||||
(
|
||||
"provider",
|
||||
models.CharField(
|
||||
max_length=30,
|
||||
verbose_name="provider",
|
||||
),
|
||||
),
|
||||
(
|
||||
"uid",
|
||||
models.CharField(
|
||||
max_length=getattr(
|
||||
settings, "SOCIALACCOUNT_UID_MAX_LENGTH", 191
|
||||
),
|
||||
verbose_name="uid",
|
||||
),
|
||||
),
|
||||
(
|
||||
"last_login",
|
||||
models.DateTimeField(auto_now=True, verbose_name="last login"),
|
||||
),
|
||||
(
|
||||
"date_joined",
|
||||
models.DateTimeField(auto_now_add=True, verbose_name="date joined"),
|
||||
),
|
||||
(
|
||||
"extra_data",
|
||||
models.TextField(default="{}", verbose_name="extra data"),
|
||||
),
|
||||
(
|
||||
"user",
|
||||
models.ForeignKey(
|
||||
to=settings.AUTH_USER_MODEL, on_delete=models.CASCADE
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "social account",
|
||||
"verbose_name_plural": "social accounts",
|
||||
},
|
||||
bases=(models.Model,),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="SocialApp",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.AutoField(
|
||||
verbose_name="ID",
|
||||
serialize=False,
|
||||
auto_created=True,
|
||||
primary_key=True,
|
||||
),
|
||||
),
|
||||
(
|
||||
"provider",
|
||||
models.CharField(
|
||||
max_length=30,
|
||||
verbose_name="provider",
|
||||
),
|
||||
),
|
||||
("name", models.CharField(max_length=40, verbose_name="name")),
|
||||
(
|
||||
"client_id",
|
||||
models.CharField(
|
||||
help_text="App ID, or consumer key",
|
||||
max_length=100,
|
||||
verbose_name="client id",
|
||||
),
|
||||
),
|
||||
(
|
||||
"secret",
|
||||
models.CharField(
|
||||
help_text="API secret, client secret, or consumer secret",
|
||||
max_length=100,
|
||||
verbose_name="secret key",
|
||||
),
|
||||
),
|
||||
(
|
||||
"key",
|
||||
models.CharField(
|
||||
help_text="Key",
|
||||
max_length=100,
|
||||
verbose_name="key",
|
||||
blank=True,
|
||||
),
|
||||
),
|
||||
]
|
||||
+ (
|
||||
[
|
||||
("sites", models.ManyToManyField(to="sites.Site", blank=True)),
|
||||
]
|
||||
if app_settings.SITES_ENABLED
|
||||
else []
|
||||
),
|
||||
options={
|
||||
"verbose_name": "social application",
|
||||
"verbose_name_plural": "social applications",
|
||||
},
|
||||
bases=(models.Model,),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="SocialToken",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.AutoField(
|
||||
verbose_name="ID",
|
||||
serialize=False,
|
||||
auto_created=True,
|
||||
primary_key=True,
|
||||
),
|
||||
),
|
||||
(
|
||||
"token",
|
||||
models.TextField(
|
||||
help_text='"oauth_token" (OAuth1) or access token (OAuth2)',
|
||||
verbose_name="token",
|
||||
),
|
||||
),
|
||||
(
|
||||
"token_secret",
|
||||
models.TextField(
|
||||
help_text='"oauth_token_secret" (OAuth1) or refresh token (OAuth2)',
|
||||
verbose_name="token secret",
|
||||
blank=True,
|
||||
),
|
||||
),
|
||||
(
|
||||
"expires_at",
|
||||
models.DateTimeField(
|
||||
null=True, verbose_name="expires at", blank=True
|
||||
),
|
||||
),
|
||||
(
|
||||
"account",
|
||||
models.ForeignKey(
|
||||
to="socialaccount.SocialAccount",
|
||||
on_delete=models.CASCADE,
|
||||
),
|
||||
),
|
||||
(
|
||||
"app",
|
||||
models.ForeignKey(
|
||||
to="socialaccount.SocialApp", on_delete=models.CASCADE
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "social application token",
|
||||
"verbose_name_plural": "social application tokens",
|
||||
},
|
||||
bases=(models.Model,),
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="socialtoken",
|
||||
unique_together=set([("app", "account")]),
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="socialaccount",
|
||||
unique_together=set([("provider", "uid")]),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,45 @@
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("socialaccount", "0001_initial"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="socialaccount",
|
||||
name="uid",
|
||||
field=models.CharField(
|
||||
max_length=getattr(settings, "SOCIALACCOUNT_UID_MAX_LENGTH", 191),
|
||||
verbose_name="uid",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="socialapp",
|
||||
name="client_id",
|
||||
field=models.CharField(
|
||||
help_text="App ID, or consumer key",
|
||||
max_length=191,
|
||||
verbose_name="client id",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="socialapp",
|
||||
name="key",
|
||||
field=models.CharField(
|
||||
help_text="Key", max_length=191, verbose_name="key", blank=True
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="socialapp",
|
||||
name="secret",
|
||||
field=models.CharField(
|
||||
help_text="API secret, client secret, or consumer secret",
|
||||
max_length=191,
|
||||
verbose_name="secret key",
|
||||
blank=True,
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,16 @@
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("socialaccount", "0002_token_max_lengths"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="socialaccount",
|
||||
name="extra_data",
|
||||
field=models.TextField(default="{}", verbose_name="extra data"),
|
||||
preserve_default=True,
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,29 @@
|
||||
# Generated by Django 3.2.19 on 2023-06-30 13:16
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("socialaccount", "0003_extra_data_default_dict"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="socialapp",
|
||||
name="provider_id",
|
||||
field=models.CharField(
|
||||
blank=True, max_length=200, verbose_name="provider ID"
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="socialapp",
|
||||
name="settings",
|
||||
field=models.JSONField(blank=True, default=dict),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="socialaccount",
|
||||
name="provider",
|
||||
field=models.CharField(max_length=200, verbose_name="provider"),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,23 @@
|
||||
# Generated by Django 3.2.20 on 2023-09-03 19:46
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("socialaccount", "0004_app_provider_id_settings"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="socialtoken",
|
||||
name="app",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
to="socialaccount.socialapp",
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,17 @@
|
||||
# Generated by Django 3.2.20 on 2023-10-11 09:23
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("socialaccount", "0005_socialtoken_nullable_app"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="socialaccount",
|
||||
name="extra_data",
|
||||
field=models.JSONField(default=dict, verbose_name="extra data"),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,411 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import authenticate, get_user_model
|
||||
from django.contrib.sites.shortcuts import get_current_site
|
||||
from django.core.exceptions import ImproperlyConfigured, PermissionDenied
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
import allauth.app_settings
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.account.utils import (
|
||||
filter_users_by_email,
|
||||
get_next_redirect_url,
|
||||
setup_user_email,
|
||||
)
|
||||
from allauth.core import context
|
||||
from allauth.socialaccount import app_settings, providers, signals
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.internal import statekit
|
||||
from allauth.utils import get_request_param
|
||||
|
||||
|
||||
if not allauth_settings.SOCIALACCOUNT_ENABLED:
|
||||
raise ImproperlyConfigured(
|
||||
"allauth.socialaccount not installed, yet its models are imported."
|
||||
)
|
||||
|
||||
|
||||
class SocialAppManager(models.Manager):
|
||||
def on_site(self, request):
|
||||
if allauth.app_settings.SITES_ENABLED:
|
||||
site = get_current_site(request)
|
||||
return self.filter(sites__id=site.id)
|
||||
return self.all()
|
||||
|
||||
|
||||
class SocialApp(models.Model):
|
||||
objects = SocialAppManager()
|
||||
|
||||
# The provider type, e.g. "google", "telegram", "saml".
|
||||
provider = models.CharField(
|
||||
verbose_name=_("provider"),
|
||||
max_length=30,
|
||||
)
|
||||
# For providers that support subproviders, such as OpenID Connect and SAML,
|
||||
# this ID identifies that instance. SocialAccount's originating from app
|
||||
# will have their `provider` field set to the `provider_id` if available,
|
||||
# else `provider`.
|
||||
provider_id = models.CharField(
|
||||
verbose_name=_("provider ID"),
|
||||
max_length=200,
|
||||
blank=True,
|
||||
)
|
||||
name = models.CharField(verbose_name=_("name"), max_length=40)
|
||||
client_id = models.CharField(
|
||||
verbose_name=_("client id"),
|
||||
max_length=191,
|
||||
help_text=_("App ID, or consumer key"),
|
||||
)
|
||||
secret = models.CharField(
|
||||
verbose_name=_("secret key"),
|
||||
max_length=191,
|
||||
blank=True,
|
||||
help_text=_("API secret, client secret, or consumer secret"),
|
||||
)
|
||||
key = models.CharField(
|
||||
verbose_name=_("key"), max_length=191, blank=True, help_text=_("Key")
|
||||
)
|
||||
settings = models.JSONField(default=dict, blank=True)
|
||||
|
||||
if allauth.app_settings.SITES_ENABLED:
|
||||
# Most apps can be used across multiple domains, therefore we use
|
||||
# a ManyToManyField. Note that Facebook requires an app per domain
|
||||
# (unless the domains share a common base name).
|
||||
# blank=True allows for disabling apps without removing them
|
||||
sites = models.ManyToManyField("sites.Site", blank=True) # type: ignore[var-annotated]
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("social application")
|
||||
verbose_name_plural = _("social applications")
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def get_provider(self, request):
|
||||
provider_class = providers.registry.get_class(self.provider)
|
||||
return provider_class(request=request, app=self)
|
||||
|
||||
|
||||
class SocialAccount(models.Model):
|
||||
user = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE)
|
||||
# Given a `SocialApp` from which this account originates, this field equals
|
||||
# the app's `app.provider_id` if available, `app.provider` otherwise.
|
||||
provider = models.CharField(
|
||||
verbose_name=_("provider"),
|
||||
max_length=200,
|
||||
)
|
||||
# Just in case you're wondering if an OpenID identity URL is going
|
||||
# to fit in a 'uid':
|
||||
#
|
||||
# Ideally, URLField(max_length=1024, unique=True) would be used
|
||||
# for identity. However, MySQL has a max_length limitation of 191
|
||||
# for URLField (in case of utf8mb4). How about
|
||||
# models.TextField(unique=True) then? Well, that won't work
|
||||
# either for MySQL due to another bug[1]. So the only way out
|
||||
# would be to drop the unique constraint, or switch to shorter
|
||||
# identity URLs. Opted for the latter, as [2] suggests that
|
||||
# identity URLs are supposed to be short anyway, at least for the
|
||||
# old spec.
|
||||
#
|
||||
# [1] http://code.djangoproject.com/ticket/2495.
|
||||
# [2] http://openid.net/specs/openid-authentication-1_1.html#limits
|
||||
|
||||
uid = models.CharField(
|
||||
verbose_name=_("uid"), max_length=app_settings.UID_MAX_LENGTH
|
||||
)
|
||||
last_login = models.DateTimeField(verbose_name=_("last login"), auto_now=True)
|
||||
date_joined = models.DateTimeField(verbose_name=_("date joined"), auto_now_add=True)
|
||||
extra_data = models.JSONField(verbose_name=_("extra data"), default=dict)
|
||||
|
||||
class Meta:
|
||||
unique_together = ("provider", "uid")
|
||||
verbose_name = _("social account")
|
||||
verbose_name_plural = _("social accounts")
|
||||
|
||||
def authenticate(self):
|
||||
return authenticate(account=self)
|
||||
|
||||
def __str__(self):
|
||||
from .helpers import socialaccount_user_display
|
||||
|
||||
return socialaccount_user_display(self)
|
||||
|
||||
def get_profile_url(self):
|
||||
return self.get_provider_account().get_profile_url()
|
||||
|
||||
def get_avatar_url(self):
|
||||
return self.get_provider_account().get_avatar_url()
|
||||
|
||||
def get_provider(self, request=None):
|
||||
provider = getattr(self, "_provider", None)
|
||||
if provider:
|
||||
return provider
|
||||
adapter = get_adapter()
|
||||
provider = self._provider = adapter.get_provider(
|
||||
request or context.request, provider=self.provider
|
||||
)
|
||||
return provider
|
||||
|
||||
def get_provider_account(self):
|
||||
return self.get_provider().wrap_account(self)
|
||||
|
||||
|
||||
class SocialToken(models.Model):
|
||||
app = models.ForeignKey(SocialApp, on_delete=models.SET_NULL, blank=True, null=True)
|
||||
account = models.ForeignKey(SocialAccount, on_delete=models.CASCADE)
|
||||
token = models.TextField(
|
||||
verbose_name=_("token"),
|
||||
help_text=_('"oauth_token" (OAuth1) or access token (OAuth2)'),
|
||||
)
|
||||
token_secret = models.TextField(
|
||||
blank=True,
|
||||
verbose_name=_("token secret"),
|
||||
help_text=_('"oauth_token_secret" (OAuth1) or refresh token (OAuth2)'),
|
||||
)
|
||||
expires_at = models.DateTimeField(
|
||||
blank=True, null=True, verbose_name=_("expires at")
|
||||
)
|
||||
|
||||
class Meta:
|
||||
unique_together = ("app", "account")
|
||||
verbose_name = _("social application token")
|
||||
verbose_name_plural = _("social application tokens")
|
||||
|
||||
def __str__(self):
|
||||
return "%s (%s)" % (self._meta.verbose_name, self.pk)
|
||||
|
||||
|
||||
class SocialLogin:
|
||||
"""
|
||||
Represents a social user that is in the process of being logged
|
||||
in. This consists of the following information:
|
||||
|
||||
`account` (`SocialAccount` instance): The social account being
|
||||
logged in. Providers are not responsible for checking whether or
|
||||
not an account already exists or not. Therefore, a provider
|
||||
typically creates a new (unsaved) `SocialAccount` instance. The
|
||||
`User` instance pointed to by the account (`account.user`) may be
|
||||
prefilled by the provider for use as a starting point later on
|
||||
during the signup process.
|
||||
|
||||
`token` (`SocialToken` instance): An optional access token token
|
||||
that results from performing a successful authentication
|
||||
handshake.
|
||||
|
||||
`state` (`dict`): The state to be preserved during the
|
||||
authentication handshake. Note that this state may end up in the
|
||||
url -- do not put any secrets in here. It currently only contains
|
||||
the url to redirect to after login.
|
||||
|
||||
`email_addresses` (list of `EmailAddress`): Optional list of
|
||||
email addresses retrieved from the provider.
|
||||
"""
|
||||
|
||||
account: SocialAccount
|
||||
token: Optional[SocialToken]
|
||||
email_addresses: List[EmailAddress]
|
||||
state: Dict
|
||||
_did_authenticate_by_email: Optional[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user=None,
|
||||
account: Optional[SocialAccount] = None,
|
||||
token: Optional[SocialToken] = None,
|
||||
email_addresses: Optional[List[EmailAddress]] = None,
|
||||
):
|
||||
if token:
|
||||
assert token.account is None or token.account == account
|
||||
self.token = token
|
||||
self.user = user
|
||||
if account:
|
||||
self.account = account
|
||||
self.email_addresses = email_addresses if email_addresses else []
|
||||
self.state = {}
|
||||
|
||||
def connect(self, request, user) -> None:
|
||||
self.user = user
|
||||
self.save(request, connect=True)
|
||||
signals.social_account_added.send(
|
||||
sender=SocialLogin, request=request, sociallogin=self
|
||||
)
|
||||
|
||||
get_adapter().send_notification_mail(
|
||||
"socialaccount/email/account_connected",
|
||||
self.user,
|
||||
context={
|
||||
"account": self.account,
|
||||
"provider": self.account.get_provider(),
|
||||
},
|
||||
)
|
||||
|
||||
@property
|
||||
def is_headless(self) -> bool:
|
||||
return bool(self.state.get("headless"))
|
||||
|
||||
def serialize(self) -> Dict[str, Any]:
|
||||
serialize_instance = get_adapter().serialize_instance
|
||||
ret = dict(
|
||||
account=serialize_instance(self.account),
|
||||
user=serialize_instance(self.user),
|
||||
state=self.state,
|
||||
email_addresses=[serialize_instance(ea) for ea in self.email_addresses],
|
||||
)
|
||||
if self.token:
|
||||
ret["token"] = serialize_instance(self.token)
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, data: Dict[str, Any]) -> "SocialLogin":
|
||||
deserialize_instance = get_adapter().deserialize_instance
|
||||
account = deserialize_instance(SocialAccount, data["account"])
|
||||
user = deserialize_instance(get_user_model(), data["user"])
|
||||
if "token" in data:
|
||||
token = deserialize_instance(SocialToken, data["token"])
|
||||
else:
|
||||
token = None
|
||||
email_addresses = []
|
||||
for ea in data["email_addresses"]:
|
||||
email_address = deserialize_instance(EmailAddress, ea)
|
||||
email_addresses.append(email_address)
|
||||
ret = cls()
|
||||
ret.token = token
|
||||
ret.account = account
|
||||
ret.user = user
|
||||
ret.email_addresses = email_addresses
|
||||
ret.state = data["state"]
|
||||
return ret
|
||||
|
||||
def save(self, request, connect: bool = False) -> None:
|
||||
"""
|
||||
Saves a new account. Note that while the account is new,
|
||||
the user may be an existing one (when connecting accounts)
|
||||
"""
|
||||
user = self.user
|
||||
user.save()
|
||||
self.account.user = user
|
||||
self.account.save()
|
||||
if app_settings.STORE_TOKENS and self.token:
|
||||
self.token.account = self.account
|
||||
self.token.save()
|
||||
if connect:
|
||||
# TODO: Add any new email addresses automatically?
|
||||
pass
|
||||
else:
|
||||
setup_user_email(request, user, self.email_addresses)
|
||||
|
||||
@property
|
||||
def is_existing(self) -> bool:
|
||||
"""When `False`, this social login represents a temporary account, not
|
||||
yet backed by a database record.
|
||||
"""
|
||||
if self.user.pk is None:
|
||||
return False
|
||||
return get_user_model().objects.filter(pk=self.user.pk).exists()
|
||||
|
||||
def lookup(self) -> None:
|
||||
"""Look up the existing local user account to which this social login
|
||||
points, if any.
|
||||
"""
|
||||
self._did_authenticate_by_email = None
|
||||
if not self._lookup_by_socialaccount():
|
||||
self._lookup_by_email()
|
||||
|
||||
def _lookup_by_socialaccount(self) -> bool:
|
||||
assert not self.is_existing
|
||||
try:
|
||||
a = SocialAccount.objects.get(
|
||||
provider=self.account.provider, uid=self.account.uid
|
||||
)
|
||||
# Update account
|
||||
a.extra_data = self.account.extra_data
|
||||
self.account = a
|
||||
self.user = self.account.user
|
||||
a.save()
|
||||
signals.social_account_updated.send(
|
||||
sender=SocialLogin, request=context.request, sociallogin=self
|
||||
)
|
||||
self._store_token()
|
||||
return True
|
||||
except SocialAccount.DoesNotExist:
|
||||
return False
|
||||
|
||||
def _store_token(self) -> None:
|
||||
# Update token
|
||||
if not app_settings.STORE_TOKENS or not self.token:
|
||||
return
|
||||
assert not self.token.pk
|
||||
app = self.token.app
|
||||
if app and not app.pk:
|
||||
# If the app is not stored in the db, leave the FK empty.
|
||||
app = None
|
||||
try:
|
||||
t = SocialToken.objects.get(account=self.account, app=app)
|
||||
t.token = self.token.token
|
||||
if self.token.token_secret:
|
||||
# only update the refresh token if we got one
|
||||
# many oauth2 providers do not resend the refresh token
|
||||
t.token_secret = self.token.token_secret
|
||||
t.expires_at = self.token.expires_at
|
||||
t.save()
|
||||
self.token = t
|
||||
except SocialToken.DoesNotExist:
|
||||
self.token.account = self.account
|
||||
self.token.app = app
|
||||
self.token.save()
|
||||
|
||||
def _lookup_by_email(self) -> None:
|
||||
emails = [e.email for e in self.email_addresses if e.verified]
|
||||
for email in emails:
|
||||
if not get_adapter().can_authenticate_by_email(self, email):
|
||||
continue
|
||||
users = filter_users_by_email(email, prefer_verified=True)
|
||||
if users:
|
||||
self.user = users[0]
|
||||
self._did_authenticate_by_email = email
|
||||
return
|
||||
|
||||
def _accept_login(self, request) -> None:
|
||||
from allauth.socialaccount.internal.flows.email_authentication import (
|
||||
wipe_password,
|
||||
)
|
||||
|
||||
if self._did_authenticate_by_email:
|
||||
wipe_password(request, self.user, self._did_authenticate_by_email)
|
||||
if app_settings.EMAIL_AUTHENTICATION_AUTO_CONNECT:
|
||||
self.connect(context.request, self.user)
|
||||
|
||||
def get_redirect_url(self, request) -> Optional[str]:
|
||||
url = self.state.get("next")
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def state_from_request(cls, request) -> Dict[str, Any]:
|
||||
"""
|
||||
TODO: Deprecated! To be integrated with provider.redirect()
|
||||
"""
|
||||
state = {}
|
||||
next_url = get_next_redirect_url(request)
|
||||
if next_url:
|
||||
state["next"] = next_url
|
||||
state["process"] = get_request_param(request, "process", "login")
|
||||
state["scope"] = get_request_param(request, "scope", "")
|
||||
state["auth_params"] = get_request_param(request, "auth_params", "")
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
def stash_state(cls, request, state: Optional[Dict[str, Any]] = None) -> str:
|
||||
if state is None:
|
||||
# Only for providers that don't support redirect() yet.
|
||||
state = cls.state_from_request(request)
|
||||
return statekit.stash_state(request, state)
|
||||
|
||||
@classmethod
|
||||
def unstash_state(cls, request) -> Optional[Dict[str, Any]]:
|
||||
state = statekit.unstash_last_state(request)
|
||||
if state is None:
|
||||
raise PermissionDenied()
|
||||
return state
|
||||
@@ -0,0 +1,57 @@
|
||||
import importlib
|
||||
from collections import OrderedDict
|
||||
|
||||
from django.apps import apps
|
||||
from django.conf import settings
|
||||
|
||||
from allauth.utils import import_attribute
|
||||
|
||||
|
||||
class ProviderRegistry:
|
||||
def __init__(self):
|
||||
self.provider_map = OrderedDict()
|
||||
self.loaded = False
|
||||
|
||||
def get_class_list(self):
|
||||
self.load()
|
||||
return list(self.provider_map.values())
|
||||
|
||||
def register(self, cls):
|
||||
self.provider_map[cls.id] = cls
|
||||
|
||||
def get_class(self, id):
|
||||
return self.provider_map.get(id)
|
||||
|
||||
def as_choices(self):
|
||||
self.load()
|
||||
for provider_cls in self.provider_map.values():
|
||||
yield (provider_cls.id, provider_cls.name)
|
||||
|
||||
def load(self):
|
||||
# TODO: Providers register with the provider registry when
|
||||
# loaded. Here, we build the URLs for all registered providers. So, we
|
||||
# really need to be sure all providers did register, which is why we're
|
||||
# forcefully importing the `provider` modules here. The overall
|
||||
# mechanism is way to magical and depends on the import order et al, so
|
||||
# all of this really needs to be revisited.
|
||||
if not self.loaded:
|
||||
for app_config in apps.get_app_configs():
|
||||
try:
|
||||
module_name = app_config.name + ".provider"
|
||||
provider_module = importlib.import_module(module_name)
|
||||
except ImportError as e:
|
||||
if e.name != module_name:
|
||||
raise
|
||||
else:
|
||||
provider_settings = getattr(settings, "SOCIALACCOUNT_PROVIDERS", {})
|
||||
for cls in getattr(provider_module, "provider_classes", []):
|
||||
provider_class = provider_settings.get(cls.id, {}).get(
|
||||
"provider_class"
|
||||
)
|
||||
if provider_class:
|
||||
cls = import_attribute(provider_class)
|
||||
self.register(cls)
|
||||
self.loaded = True
|
||||
|
||||
|
||||
registry = ProviderRegistry()
|
||||
@@ -0,0 +1,37 @@
|
||||
from allauth.socialaccount.providers.agave.views import AgaveAdapter
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class AgaveAccount(ProviderAccount):
|
||||
def get_profile_url(self):
|
||||
return self.account.extra_data.get("web_url", "dflt")
|
||||
|
||||
def get_avatar_url(self):
|
||||
return self.account.extra_data.get("avatar_url", "dflt")
|
||||
|
||||
|
||||
class AgaveProvider(OAuth2Provider):
|
||||
id = "agave"
|
||||
name = "Agave"
|
||||
account_class = AgaveAccount
|
||||
oauth2_adapter_class = AgaveAdapter
|
||||
|
||||
def extract_uid(self, data):
|
||||
return str(data.get("create_time"))
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
return dict(
|
||||
email=data.get("email"),
|
||||
username=data.get("username", ""),
|
||||
name=(
|
||||
(data.get("first_name", "") + " " + data.get("last_name", "")).strip()
|
||||
),
|
||||
)
|
||||
|
||||
def get_default_scope(self):
|
||||
scope = ["PRODUCTION"]
|
||||
return scope
|
||||
|
||||
|
||||
provider_classes = [AgaveProvider]
|
||||
@@ -0,0 +1,34 @@
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
from .provider import AgaveProvider
|
||||
|
||||
|
||||
class AgaveTests(OAuth2TestsMixin, TestCase):
|
||||
provider_id = AgaveProvider.id
|
||||
|
||||
def get_mocked_response(self):
|
||||
return MockedResponse(
|
||||
200,
|
||||
"""
|
||||
{
|
||||
"status": "success",
|
||||
"message": "User details retrieved successfully.",
|
||||
"version": "2.0.0-SNAPSHOT-rc3fad",
|
||||
"result": {
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"full_name": "John Doe",
|
||||
"email": "jon@doe.edu",
|
||||
"phone": "",
|
||||
"mobile_phone": "",
|
||||
"status": "Active",
|
||||
"create_time": "20180322043812Z",
|
||||
"username": "jdoe"
|
||||
}
|
||||
}
|
||||
""",
|
||||
)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "jdoe"
|
||||
@@ -0,0 +1,5 @@
|
||||
from allauth.socialaccount.providers.agave.provider import AgaveProvider
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(AgaveProvider)
|
||||
@@ -0,0 +1,41 @@
|
||||
from allauth.socialaccount import app_settings
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
|
||||
|
||||
class AgaveAdapter(OAuth2Adapter):
|
||||
provider_id = "agave"
|
||||
|
||||
settings = app_settings.PROVIDERS.get(provider_id, {})
|
||||
provider_base_url = settings.get("API_URL", "https://public.agaveapi.co")
|
||||
|
||||
access_token_url = "{0}/token".format(provider_base_url)
|
||||
authorize_url = "{0}/authorize".format(provider_base_url)
|
||||
profile_url = "{0}/profiles/v2/me".format(provider_base_url)
|
||||
|
||||
def complete_login(self, request, app, token, response):
|
||||
extra_data = (
|
||||
get_adapter()
|
||||
.get_requests_session()
|
||||
.get(
|
||||
self.profile_url,
|
||||
params={"access_token": token.token},
|
||||
headers={
|
||||
"Authorization": "Bearer " + token.token,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
user_profile = (
|
||||
extra_data.json()["result"] if "result" in extra_data.json() else {}
|
||||
)
|
||||
|
||||
return self.get_provider().sociallogin_from_response(request, user_profile)
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(AgaveAdapter)
|
||||
oauth2_callback = OAuth2CallbackView.adapter_view(AgaveAdapter)
|
||||
@@ -0,0 +1,34 @@
|
||||
from allauth.socialaccount.providers.amazon.views import AmazonOAuth2Adapter
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class AmazonAccount(ProviderAccount):
|
||||
pass
|
||||
|
||||
|
||||
class AmazonProvider(OAuth2Provider):
|
||||
id = "amazon"
|
||||
name = "Amazon"
|
||||
account_class = AmazonAccount
|
||||
oauth2_adapter_class = AmazonOAuth2Adapter
|
||||
|
||||
def get_default_scope(self):
|
||||
return ["profile"]
|
||||
|
||||
def extract_uid(self, data):
|
||||
return str(data["user_id"])
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
# Hackish way of splitting the fullname.
|
||||
# Assumes no middlenames.
|
||||
name = data.get("name", "")
|
||||
first_name, last_name = name, ""
|
||||
if name and " " in name:
|
||||
first_name, last_name = name.split(" ", 1)
|
||||
return dict(
|
||||
email=data.get("email", ""), last_name=last_name, first_name=first_name
|
||||
)
|
||||
|
||||
|
||||
provider_classes = [AmazonProvider]
|
||||
@@ -0,0 +1,24 @@
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
from .provider import AmazonProvider
|
||||
|
||||
|
||||
class AmazonTests(OAuth2TestsMixin, TestCase):
|
||||
provider_id = AmazonProvider.id
|
||||
|
||||
def get_mocked_response(self):
|
||||
return MockedResponse(
|
||||
200,
|
||||
"""
|
||||
{
|
||||
"Profile":{
|
||||
"CustomerId":"amzn1.account.K2LI23KL2LK2",
|
||||
"Name":"John Doe",
|
||||
"PrimaryEmail":"johndoe@example.com"
|
||||
}
|
||||
}""",
|
||||
)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "johndoe@example.com"
|
||||
@@ -0,0 +1,6 @@
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
from .provider import AmazonProvider
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(AmazonProvider)
|
||||
@@ -0,0 +1,33 @@
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
|
||||
|
||||
class AmazonOAuth2Adapter(OAuth2Adapter):
|
||||
provider_id = "amazon"
|
||||
access_token_url = "https://api.amazon.com/auth/o2/token"
|
||||
authorize_url = "http://www.amazon.com/ap/oa"
|
||||
profile_url = "https://api.amazon.com/user/profile"
|
||||
|
||||
def complete_login(self, request, app, token, **kwargs):
|
||||
response = (
|
||||
get_adapter()
|
||||
.get_requests_session()
|
||||
.get(self.profile_url, params={"access_token": token.token})
|
||||
)
|
||||
response.raise_for_status()
|
||||
extra_data = response.json()
|
||||
if "Profile" in extra_data:
|
||||
extra_data = {
|
||||
"user_id": extra_data["Profile"]["CustomerId"],
|
||||
"name": extra_data["Profile"]["Name"],
|
||||
"email": extra_data["Profile"]["PrimaryEmail"],
|
||||
}
|
||||
return self.get_provider().sociallogin_from_response(request, extra_data)
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(AmazonOAuth2Adapter)
|
||||
oauth2_callback = OAuth2CallbackView.adapter_view(AmazonOAuth2Adapter)
|
||||
@@ -0,0 +1,69 @@
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.socialaccount.providers.amazon_cognito.utils import (
|
||||
convert_to_python_bool_if_value_is_json_string_bool,
|
||||
)
|
||||
from allauth.socialaccount.providers.amazon_cognito.views import (
|
||||
AmazonCognitoOAuth2Adapter,
|
||||
)
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class AmazonCognitoAccount(ProviderAccount):
|
||||
def get_avatar_url(self):
|
||||
return self.account.extra_data.get("picture")
|
||||
|
||||
def get_profile_url(self):
|
||||
return self.account.extra_data.get("profile")
|
||||
|
||||
|
||||
class AmazonCognitoProvider(OAuth2Provider):
|
||||
id = "amazon_cognito"
|
||||
name = "Amazon Cognito"
|
||||
account_class = AmazonCognitoAccount
|
||||
oauth2_adapter_class = AmazonCognitoOAuth2Adapter
|
||||
|
||||
def extract_uid(self, data):
|
||||
return str(data["sub"])
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
return {
|
||||
"email": data.get("email"),
|
||||
"first_name": data.get("given_name"),
|
||||
"last_name": data.get("family_name"),
|
||||
}
|
||||
|
||||
def get_default_scope(self):
|
||||
return ["openid", "profile", "email"]
|
||||
|
||||
def extract_email_addresses(self, data):
|
||||
email = data.get("email")
|
||||
verified = convert_to_python_bool_if_value_is_json_string_bool(
|
||||
data.get("email_verified", False)
|
||||
)
|
||||
|
||||
return (
|
||||
[EmailAddress(email=email, verified=verified, primary=True)]
|
||||
if email
|
||||
else []
|
||||
)
|
||||
|
||||
def extract_extra_data(self, data):
|
||||
ret = dict(data)
|
||||
phone_number_verified = data.get("phone_number_verified")
|
||||
if phone_number_verified is not None:
|
||||
ret["phone_number_verified"] = (
|
||||
convert_to_python_bool_if_value_is_json_string_bool(
|
||||
"phone_number_verified"
|
||||
)
|
||||
)
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def get_slug(cls):
|
||||
# IMPORTANT: Amazon Cognito does not support `_` characters
|
||||
# as part of their redirect URI.
|
||||
return super(AmazonCognitoProvider, cls).get_slug().replace("_", "-")
|
||||
|
||||
|
||||
provider_classes = [AmazonCognitoProvider]
|
||||
@@ -0,0 +1,89 @@
|
||||
import json
|
||||
|
||||
from django.test import override_settings
|
||||
|
||||
import pytest
|
||||
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.socialaccount.models import SocialAccount
|
||||
from allauth.socialaccount.providers.amazon_cognito.provider import (
|
||||
AmazonCognitoProvider,
|
||||
)
|
||||
from allauth.socialaccount.providers.amazon_cognito.utils import (
|
||||
convert_to_python_bool_if_value_is_json_string_bool,
|
||||
)
|
||||
from allauth.socialaccount.providers.amazon_cognito.views import (
|
||||
AmazonCognitoOAuth2Adapter,
|
||||
)
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
|
||||
def _get_mocked_claims():
|
||||
return {
|
||||
"sub": "[HEROKU-API-KEY-REMOVED]",
|
||||
"given_name": "John",
|
||||
"family_name": "Doe",
|
||||
"email": "jdoe@example.com",
|
||||
"username": "johndoe",
|
||||
}
|
||||
|
||||
|
||||
@override_settings(
|
||||
SOCIALACCOUNT_PROVIDERS={
|
||||
"amazon_cognito": {"DOMAIN": "https://domain.auth.us-east-1.amazoncognito.com"}
|
||||
}
|
||||
)
|
||||
class AmazonCognitoTestCase(OAuth2TestsMixin, TestCase):
|
||||
provider_id = AmazonCognitoProvider.id
|
||||
|
||||
def get_mocked_response(self):
|
||||
mocked_payload = json.dumps(_get_mocked_claims())
|
||||
return MockedResponse(status_code=200, content=mocked_payload)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "johndoe"
|
||||
|
||||
@override_settings(SOCIALACCOUNT_PROVIDERS={"amazon_cognito": {}})
|
||||
def test_oauth2_adapter_raises_if_domain_settings_is_missing(
|
||||
self,
|
||||
):
|
||||
mocked_response = self.get_mocked_response()
|
||||
|
||||
with self.assertRaises(
|
||||
ValueError,
|
||||
msg=AmazonCognitoOAuth2Adapter.DOMAIN_KEY_MISSING_ERROR,
|
||||
):
|
||||
self.login(mocked_response)
|
||||
|
||||
def test_saves_email_as_verified_if_email_is_verified_in_cognito(
|
||||
self,
|
||||
):
|
||||
mocked_claims = _get_mocked_claims()
|
||||
mocked_claims["email_verified"] = True
|
||||
mocked_payload = json.dumps(mocked_claims)
|
||||
mocked_response = MockedResponse(status_code=200, content=mocked_payload)
|
||||
|
||||
self.login(mocked_response)
|
||||
|
||||
user_id = SocialAccount.objects.get(uid=mocked_claims["sub"]).user_id
|
||||
email_address = EmailAddress.objects.get(user_id=user_id)
|
||||
|
||||
self.assertEqual(email_address.email, mocked_claims["email"])
|
||||
self.assertTrue(email_address.verified)
|
||||
|
||||
def test_provider_slug_replaces_underscores_with_hyphens(self):
|
||||
self.assertTrue("_" not in self.provider.get_slug())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input,output",
|
||||
[
|
||||
(True, True),
|
||||
("true", True),
|
||||
("false", False),
|
||||
(False, False),
|
||||
],
|
||||
)
|
||||
def test_convert_bool(input, output):
|
||||
assert convert_to_python_bool_if_value_is_json_string_bool(input) == output
|
||||
@@ -0,0 +1,7 @@
|
||||
from allauth.socialaccount.providers.amazon_cognito.provider import (
|
||||
AmazonCognitoProvider,
|
||||
)
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(AmazonCognitoProvider)
|
||||
@@ -0,0 +1,7 @@
|
||||
def convert_to_python_bool_if_value_is_json_string_bool(s):
|
||||
if s == "true":
|
||||
return True
|
||||
elif s == "false":
|
||||
return False
|
||||
|
||||
return s
|
||||
@@ -0,0 +1,56 @@
|
||||
from allauth.socialaccount import app_settings
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.models import SocialToken
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
|
||||
|
||||
class AmazonCognitoOAuth2Adapter(OAuth2Adapter):
|
||||
provider_id = "amazon_cognito"
|
||||
|
||||
DOMAIN_KEY_MISSING_ERROR = (
|
||||
'"DOMAIN" key is missing in Amazon Cognito configuration.'
|
||||
)
|
||||
|
||||
@property
|
||||
def settings(self):
|
||||
return app_settings.PROVIDERS.get(self.provider_id, {})
|
||||
|
||||
@property
|
||||
def domain(self):
|
||||
domain = self.settings.get("DOMAIN")
|
||||
|
||||
if domain is None:
|
||||
raise ValueError(self.DOMAIN_KEY_MISSING_ERROR)
|
||||
|
||||
return domain
|
||||
|
||||
@property
|
||||
def access_token_url(self):
|
||||
return "{}/oauth2/token".format(self.domain)
|
||||
|
||||
@property
|
||||
def authorize_url(self):
|
||||
return "{}/oauth2/authorize".format(self.domain)
|
||||
|
||||
@property
|
||||
def profile_url(self):
|
||||
return "{}/oauth2/userInfo".format(self.domain)
|
||||
|
||||
def complete_login(self, request, app, token: SocialToken, **kwargs):
|
||||
headers = {
|
||||
"Authorization": "Bearer {}".format(token.token),
|
||||
}
|
||||
extra_data = (
|
||||
get_adapter().get_requests_session().get(self.profile_url, headers=headers)
|
||||
)
|
||||
extra_data.raise_for_status()
|
||||
|
||||
return self.get_provider().sociallogin_from_response(request, extra_data.json())
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(AmazonCognitoOAuth2Adapter)
|
||||
oauth2_callback = OAuth2CallbackView.adapter_view(AmazonCognitoOAuth2Adapter)
|
||||
@@ -0,0 +1,33 @@
|
||||
from allauth.socialaccount.providers.angellist.views import (
|
||||
AngelListOAuth2Adapter,
|
||||
)
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class AngelListAccount(ProviderAccount):
|
||||
def get_profile_url(self):
|
||||
return self.account.extra_data.get("angellist_url")
|
||||
|
||||
def get_avatar_url(self):
|
||||
return self.account.extra_data.get("image")
|
||||
|
||||
|
||||
class AngelListProvider(OAuth2Provider):
|
||||
id = "angellist"
|
||||
name = "AngelList"
|
||||
account_class = AngelListAccount
|
||||
oauth2_adapter_class = AngelListOAuth2Adapter
|
||||
|
||||
def extract_uid(self, data):
|
||||
return str(data["id"])
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
return dict(
|
||||
email=data.get("email"),
|
||||
username=data.get("angellist_url").split("/")[-1],
|
||||
name=data.get("name"),
|
||||
)
|
||||
|
||||
|
||||
provider_classes = [AngelListProvider]
|
||||
@@ -0,0 +1,28 @@
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
from .provider import AngelListProvider
|
||||
|
||||
|
||||
class AngelListTests(OAuth2TestsMixin, TestCase):
|
||||
provider_id = AngelListProvider.id
|
||||
|
||||
def get_mocked_response(self):
|
||||
return MockedResponse(
|
||||
200,
|
||||
"""
|
||||
{"name":"pennersr","id":424732,"bio":"","follower_count":0,
|
||||
"angellist_url":"https://angel.co/dsxtst",
|
||||
"image":"https://angel.co/images/shared/nopic.png",
|
||||
"email":"raymond.penners@example.com","blog_url":null,
|
||||
"online_bio_url":null,"twitter_url":"https://twitter.com/dsxtst",
|
||||
"facebook_url":null,"linkedin_url":null,"aboutme_url":null,
|
||||
"github_url":null,"dribbble_url":null,"behance_url":null,
|
||||
"what_ive_built":null,"locations":[],"roles":[],"skills":[],
|
||||
"investor":false,"scopes":["message","talent","dealflow","comment",
|
||||
"email"]}
|
||||
""",
|
||||
)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "raymond.penners@example.com"
|
||||
@@ -0,0 +1,6 @@
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
from .provider import AngelListProvider
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(AngelListProvider)
|
||||
@@ -0,0 +1,27 @@
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
|
||||
|
||||
class AngelListOAuth2Adapter(OAuth2Adapter):
|
||||
provider_id = "angellist"
|
||||
access_token_url = "https://angel.co/api/oauth/token/"
|
||||
authorize_url = "https://angel.co/api/oauth/authorize/"
|
||||
profile_url = "https://api.angel.co/1/me/"
|
||||
supports_state = False
|
||||
|
||||
def complete_login(self, request, app, token, **kwargs):
|
||||
resp = (
|
||||
get_adapter()
|
||||
.get_requests_session()
|
||||
.get(self.profile_url, params={"access_token": token.token})
|
||||
)
|
||||
extra_data = resp.json()
|
||||
return self.get_provider().sociallogin_from_response(request, extra_data)
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(AngelListOAuth2Adapter)
|
||||
oauth2_callback = OAuth2CallbackView.adapter_view(AngelListOAuth2Adapter)
|
||||
@@ -0,0 +1,8 @@
|
||||
from allauth.socialaccount.sessions import LoginSession
|
||||
|
||||
|
||||
APPLE_SESSION_COOKIE_NAME = "apple-login-session"
|
||||
|
||||
|
||||
def get_apple_session(request):
|
||||
return LoginSession(request, "apple_login_session", APPLE_SESSION_COOKIE_NAME)
|
||||
@@ -0,0 +1,101 @@
|
||||
import time
|
||||
from urllib.parse import parse_qsl, quote, urlencode
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
||||
import jwt
|
||||
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.oauth2.client import (
|
||||
OAuth2Client,
|
||||
OAuth2Error,
|
||||
)
|
||||
|
||||
|
||||
def jwt_encode(*args, **kwargs):
|
||||
resp = jwt.encode(*args, **kwargs)
|
||||
if isinstance(resp, bytes):
|
||||
# For PyJWT <2
|
||||
resp = resp.decode("utf-8")
|
||||
return resp
|
||||
|
||||
|
||||
class Scope:
|
||||
EMAIL = "email"
|
||||
NAME = "name"
|
||||
|
||||
|
||||
class AppleOAuth2Client(OAuth2Client):
|
||||
"""
|
||||
Custom client because `Sign In With Apple`:
|
||||
* requires `response_mode` field in redirect_url
|
||||
* requires special `client_secret` as JWT
|
||||
"""
|
||||
|
||||
def generate_client_secret(self):
|
||||
"""Create a JWT signed with an apple provided private key"""
|
||||
now = int(time.time())
|
||||
app = get_adapter(self.request).get_app(self.request, "apple")
|
||||
if not app.key:
|
||||
raise ImproperlyConfigured("Apple 'key' missing")
|
||||
certificate_key = app.settings.get("certificate_key")
|
||||
if not certificate_key:
|
||||
raise ImproperlyConfigured("Apple 'certificate_key' missing")
|
||||
claims = {
|
||||
"iss": app.key,
|
||||
"aud": "https://appleid.apple.com",
|
||||
"sub": self.get_client_id(),
|
||||
"iat": now,
|
||||
"exp": now + 60 * 60,
|
||||
}
|
||||
headers = {"kid": self.consumer_secret, "alg": "ES256"}
|
||||
client_secret = jwt_encode(
|
||||
payload=claims, key=certificate_key, algorithm="ES256", headers=headers
|
||||
)
|
||||
return client_secret
|
||||
|
||||
def get_client_id(self):
|
||||
"""We support multiple client_ids, but use the first one for api calls"""
|
||||
return self.consumer_key.split(",")[0]
|
||||
|
||||
def get_access_token(self, code, pkce_code_verifier=None):
|
||||
url = self.access_token_url
|
||||
client_secret = self.generate_client_secret()
|
||||
data = {
|
||||
"client_id": self.get_client_id(),
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": self.callback_url,
|
||||
"client_secret": client_secret,
|
||||
}
|
||||
if pkce_code_verifier:
|
||||
data["code_verifier"] = pkce_code_verifier
|
||||
self._strip_empty_keys(data)
|
||||
resp = (
|
||||
get_adapter()
|
||||
.get_requests_session()
|
||||
.request(self.access_token_method, url, data=data, headers=self.headers)
|
||||
)
|
||||
access_token = None
|
||||
if resp.status_code in [200, 201]:
|
||||
try:
|
||||
access_token = resp.json()
|
||||
except ValueError:
|
||||
access_token = dict(parse_qsl(resp.text))
|
||||
if not access_token or "access_token" not in access_token:
|
||||
raise OAuth2Error("Error retrieving access token: %s" % resp.content)
|
||||
return access_token
|
||||
|
||||
def get_redirect_url(self, authorization_url, scope, extra_params):
|
||||
scope = self.scope_delimiter.join(set(scope))
|
||||
params = {
|
||||
"client_id": self.get_client_id(),
|
||||
"redirect_uri": self.callback_url,
|
||||
"response_mode": "form_post",
|
||||
"scope": scope,
|
||||
"response_type": "code id_token",
|
||||
}
|
||||
if self.state:
|
||||
params["state"] = self.state
|
||||
params.update(extra_params)
|
||||
return "%s?%s" % (authorization_url, urlencode(params, quote_via=quote))
|
||||
@@ -0,0 +1,92 @@
|
||||
import requests
|
||||
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.app_settings import QUERY_EMAIL
|
||||
from allauth.socialaccount.providers.apple.views import AppleOAuth2Adapter
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class AppleAccount(ProviderAccount):
|
||||
def to_str(self):
|
||||
email = self.account.extra_data.get("email")
|
||||
if email and not email.lower().endswith("@privaterelay.appleid.com"):
|
||||
return email
|
||||
|
||||
name = self.account.extra_data.get("name") or {}
|
||||
if name.get("firstName") or name.get("lastName"):
|
||||
full_name = f"{name['firstName'] or ''} {name['lastName'] or ''}"
|
||||
full_name = full_name.strip()
|
||||
if full_name:
|
||||
return full_name
|
||||
|
||||
return super().to_str()
|
||||
|
||||
|
||||
class AppleProvider(OAuth2Provider):
|
||||
id = "apple"
|
||||
name = "Apple"
|
||||
account_class = AppleAccount
|
||||
oauth2_adapter_class = AppleOAuth2Adapter
|
||||
supports_token_authentication = True
|
||||
|
||||
def extract_uid(self, data):
|
||||
return str(data["sub"])
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
fields = {"email": data.get("email")}
|
||||
|
||||
# If the name was provided
|
||||
name = data.get("name")
|
||||
if name:
|
||||
fields["first_name"] = name.get("firstName", "")
|
||||
fields["last_name"] = name.get("lastName", "")
|
||||
|
||||
return fields
|
||||
|
||||
def extract_email_addresses(self, data):
|
||||
ret = []
|
||||
email = data.get("email")
|
||||
verified = data.get("email_verified")
|
||||
if isinstance(verified, str):
|
||||
verified = verified.lower() == "true"
|
||||
if email:
|
||||
ret.append(
|
||||
EmailAddress(
|
||||
email=email,
|
||||
verified=verified,
|
||||
primary=True,
|
||||
)
|
||||
)
|
||||
return ret
|
||||
|
||||
def get_default_scope(self):
|
||||
scopes = ["name"]
|
||||
if QUERY_EMAIL:
|
||||
scopes.append("email")
|
||||
return scopes
|
||||
|
||||
def verify_token(self, request, token):
|
||||
from allauth.socialaccount.providers.apple.views import (
|
||||
AppleOAuth2Adapter,
|
||||
)
|
||||
|
||||
id_token = token.get("id_token")
|
||||
if not id_token:
|
||||
raise get_adapter().validation_error("invalid_token")
|
||||
try:
|
||||
identity_data = AppleOAuth2Adapter.get_verified_identity_data(
|
||||
self, id_token
|
||||
)
|
||||
except (OAuth2Error, requests.RequestException) as e:
|
||||
raise get_adapter().validation_error("invalid_token") from e
|
||||
login = self.sociallogin_from_response(request, identity_data)
|
||||
return login
|
||||
|
||||
def get_auds(self):
|
||||
return [aud.strip() for aud in self.app.client_id.split(",")]
|
||||
|
||||
|
||||
provider_classes = [AppleProvider]
|
||||
@@ -0,0 +1,264 @@
|
||||
import json
|
||||
import time
|
||||
from importlib import import_module
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from django.conf import settings
|
||||
from django.test.utils import override_settings
|
||||
from django.urls import reverse
|
||||
from django.utils.http import urlencode
|
||||
|
||||
import jwt
|
||||
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase, mocked_response
|
||||
|
||||
from .apple_session import APPLE_SESSION_COOKIE_NAME
|
||||
from .client import jwt_encode
|
||||
from .provider import AppleProvider
|
||||
|
||||
|
||||
# Generated on https://mkjwk.org/, used to sign and verify the apple id_token
|
||||
TESTING_JWT_KEYSET = {
|
||||
"p": (
|
||||
"4ADzS5jKx_kdQihyOocVS0Qwwo7m0f7Ow56EadySJ-cmnwoHHF3AxgRaq-h-KwybSphv"
|
||||
"dc-X7NbS79-b9dumHKyt1MeVLAsDZD1a-uQCEneY1g9LsQkscNr7OggcpvMg5UUFwv6A"
|
||||
"kavu8cB0iyhNdha5_AWX27K5lNebvpaXEJ8"
|
||||
),
|
||||
"kty": "RSA",
|
||||
"q": (
|
||||
"yy5UvMjrvZyO1Os_nxXIugCa3NyWOkC8oMppPvr1Bl5AnF_xwXN2n9ozPd9Nb3Q3n-om"
|
||||
"NgLayyUxhwIjWDlI67Vbx-ESuff8ZEBKuTK0Gdmr4C_QU_j0gvvNMNJweSPxDdRmIUgO"
|
||||
"njTVNWmdqFTZs43jXAT4J519rgveNLAkGNE"
|
||||
),
|
||||
"d": (
|
||||
"riPuGIDde88WS03CVbo_mZ9toFWPyTxvuz8VInJ9S1ZxULo-hQWDBohWGYwvg8cgfXck"
|
||||
"cqWt5OBqNvPYdLgwb84uVi2JeEHmhcQSc_x0zfRTau5HVE2KdR-gWxQjPWoaBHeDVqwo"
|
||||
"PKaU2XYxa-[AWS-SECRET-REMOVED]EHwyWXJbTpoar7AARW"
|
||||
"kz76qtngDkk4t9gk_Q0L1y1qf1GeWiAL7xWb-bdptma4-1ui-R2219-1ONEZ41v_jsIS"
|
||||
"_[AWS-SECRET-REMOVED]sXjaIwkdItbDmL1jFUgefwfO91Y"
|
||||
"YQ"
|
||||
),
|
||||
"e": "AQAB",
|
||||
"use": "sig",
|
||||
"kid": "testkey",
|
||||
"qi": (
|
||||
"R0Hu4YmpHzw3SKWGYuAcAo6B97-[AWS-SECRET-REMOVED]r"
|
||||
"[AWS-SECRET-REMOVED]t1c4tTotFDdw8WFptDOw4ow31Tml"
|
||||
"BPExLqzzGjJeQSNULB1bExuuhYMWx6wBXo8"
|
||||
),
|
||||
"dp": (
|
||||
"WBaHlnbjZ3hDVTzqjrGIYizSr-_[AWS-SECRET-REMOVED]G"
|
||||
"wuF78RsZoFLi1fAmhqgxQ7eopcU-[AWS-SECRET-REMOVED]"
|
||||
"szhVoqP4MLEMpR-Sy9S3PyItcKbJDE3T4ik"
|
||||
),
|
||||
"alg": "RS256",
|
||||
"dq": (
|
||||
"[AWS-SECRET-REMOVED]zRk2vCXbiOY7Qttad8ptLEUgfytV"
|
||||
"SsNtGvMsoQsZWRak8nHnhGJ4s0QzB1OK7sdNgU_cL1HV-VxSSPaHhdJBrJEcrzggDPEB"
|
||||
"KYfDHU6Iz34d1nvjBxoWE8rfqJsGbCW4xxE"
|
||||
),
|
||||
"n": (
|
||||
"sclLPioUv4VOcOZWAKoRhcvwIH2jOhoHhSI_Cj5c5zSp7qaK8jCU6T7-GObsgrhpty-k"
|
||||
"26ZuqRdgu9d-62WO8OBGt1e0wxbTh128-nTTrOESHUlV_K1wpJmXOxNpJiybcgzZNbAm"
|
||||
"ACmsHfxZvN9bt7gKPXxf3-_zFAf12PbYMrOionAJ1N_4HxL7fz3xkr5C87Av06QNilIC"
|
||||
"-mA-4n9Eqw_R2DYNpE3RYMdWtwKqBwJC8qs3677RpG9vcc-yZ_97pEiytd2FBJ8uoTwH"
|
||||
"[AWS-SECRET-REMOVED]79LrsfOzrXF4enkfKJYI40-uwT95"
|
||||
"zw"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# Mocked version of the test data from https://appleid.apple.com/auth/keys
|
||||
KEY_SERVER_RESP_JSON = json.dumps(
|
||||
{
|
||||
"keys": [
|
||||
{
|
||||
"kty": TESTING_JWT_KEYSET["kty"],
|
||||
"kid": TESTING_JWT_KEYSET["kid"],
|
||||
"use": TESTING_JWT_KEYSET["use"],
|
||||
"alg": TESTING_JWT_KEYSET["alg"],
|
||||
"n": TESTING_JWT_KEYSET["n"],
|
||||
"e": TESTING_JWT_KEYSET["e"],
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def sign_id_token(payload):
|
||||
"""
|
||||
Sign a payload as apple normally would for the id_token.
|
||||
"""
|
||||
signing_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(TESTING_JWT_KEYSET))
|
||||
return jwt_encode(
|
||||
payload,
|
||||
signing_key,
|
||||
algorithm="RS256",
|
||||
headers={"kid": TESTING_JWT_KEYSET["kid"]},
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
SOCIALACCOUNT_STORE_TOKENS=False,
|
||||
SOCIALACCOUNT_PROVIDERS={
|
||||
"apple": {
|
||||
"APP": {
|
||||
"client_id": "app123id",
|
||||
"key": "apple",
|
||||
"secret": "dummy",
|
||||
"settings": {
|
||||
"certificate_key": """[PRIVATE-KEY-REMOVED]
|
||||
[AWS-SECRET-REMOVED]awIBAQQg2+Eybl8ojH4wB30C
|
||||
[AWS-SECRET-REMOVED]Q+EpNgQQyQVs/F27dkq3gvAI
|
||||
[AWS-SECRET-REMOVED]Ru3XGyqy3mdb8gMy
|
||||
-----END PRIVATE KEY-----
|
||||
""",
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
class AppleTests(OAuth2TestsMixin, TestCase):
|
||||
provider_id = AppleProvider.id
|
||||
|
||||
def get_apple_id_token_payload(self):
|
||||
now = int(time.time())
|
||||
return {
|
||||
"iss": "https://appleid.apple.com",
|
||||
"aud": "app123id", # Matches `setup_app`
|
||||
"exp": now + 60 * 60,
|
||||
"iat": now,
|
||||
"sub": "000313.c9720f41e9434e18987a.1218",
|
||||
"at_hash": "CkaUPjk4MJinaAq6Z0tGUA",
|
||||
"email": "test@privaterelay.appleid.com",
|
||||
"email_verified": "true",
|
||||
"is_private_email": "true",
|
||||
"auth_time": 1234345345, # not converted automatically by pyjwt
|
||||
}
|
||||
|
||||
def test_verify_token(self):
|
||||
id_token = sign_id_token(self.get_apple_id_token_payload())
|
||||
with mocked_response(self.get_mocked_response()):
|
||||
sociallogin = self.provider.verify_token(None, {"id_token": id_token})
|
||||
assert sociallogin.user.email == "test@privaterelay.appleid.com"
|
||||
|
||||
def get_login_response_json(self, with_refresh_token=True):
|
||||
"""
|
||||
`with_refresh_token` is not optional for apple, so it's ignored.
|
||||
"""
|
||||
id_token = sign_id_token(self.get_apple_id_token_payload())
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"access_token": "testac", # Matches OAuth2TestsMixin value
|
||||
"expires_in": 3600,
|
||||
"id_token": id_token,
|
||||
"refresh_token": "testrt", # Matches OAuth2TestsMixin value
|
||||
"token_type": "Bearer",
|
||||
}
|
||||
)
|
||||
|
||||
def get_mocked_response(self):
|
||||
"""
|
||||
Apple is unusual in that the `id_token` contains all the user info
|
||||
so no profile info request is made. However, it does need the
|
||||
public key verification, so this mocked response is the public
|
||||
key request in order to verify the authenticity of the id_token.
|
||||
"""
|
||||
return MockedResponse(
|
||||
200, KEY_SERVER_RESP_JSON, {"content-type": "application/json"}
|
||||
)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "A B"
|
||||
|
||||
def get_complete_parameters(self, auth_request_params):
|
||||
"""
|
||||
Add apple specific response parameters which they include in the
|
||||
form_post response.
|
||||
|
||||
https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/incorporating_sign_in_with_apple_into_other_platforms
|
||||
"""
|
||||
params = super().get_complete_parameters(auth_request_params)
|
||||
params.update(
|
||||
{
|
||||
"id_token": sign_id_token(self.get_apple_id_token_payload()),
|
||||
"user": json.dumps(
|
||||
{
|
||||
"email": "private@appleid.apple.com",
|
||||
"name": {
|
||||
"firstName": "A",
|
||||
"lastName": "B",
|
||||
},
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
def login(self, resp_mock, process="login", with_refresh_token=True):
|
||||
resp = self.client.post(
|
||||
reverse(self.provider.id + "_login")
|
||||
+ "?"
|
||||
+ urlencode(dict(process=process))
|
||||
)
|
||||
p = urlparse(resp["location"])
|
||||
q = parse_qs(p.query)
|
||||
complete_url = reverse(self.provider.id + "_callback")
|
||||
self.assertGreater(q["redirect_uri"][0].find(complete_url), 0)
|
||||
response_json = self.get_login_response_json(
|
||||
with_refresh_token=with_refresh_token
|
||||
)
|
||||
with mocked_response(
|
||||
MockedResponse(200, response_json, {"content-type": "application/json"}),
|
||||
resp_mock,
|
||||
):
|
||||
resp = self.client.post(
|
||||
complete_url,
|
||||
data=self.get_complete_parameters(q),
|
||||
)
|
||||
assert reverse("apple_finish_callback") in resp.url
|
||||
|
||||
# Follow the redirect
|
||||
resp = self.client.get(resp.url)
|
||||
|
||||
return resp
|
||||
|
||||
def test_authentication_error(self):
|
||||
"""Override base test because apple posts errors"""
|
||||
resp = self.client.post(
|
||||
reverse(self.provider.id + "_callback"),
|
||||
data={"error": "misc", "state": "testingstate123"},
|
||||
)
|
||||
assert reverse("apple_finish_callback") in resp.url
|
||||
# Follow the redirect
|
||||
resp = self.client.get(resp.url)
|
||||
|
||||
self.assertTemplateUsed(
|
||||
resp,
|
||||
"socialaccount/authentication_error.%s"
|
||||
% getattr(settings, "ACCOUNT_TEMPLATE_EXTENSION", "html"),
|
||||
)
|
||||
|
||||
def test_apple_finish(self):
|
||||
resp = self.login(self.get_mocked_response())
|
||||
|
||||
# Check request generating the response
|
||||
finish_url = reverse("apple_finish_callback")
|
||||
self.assertEqual(resp.request["PATH_INFO"], finish_url)
|
||||
self.assertTrue("state" in resp.request["QUERY_STRING"])
|
||||
self.assertTrue("code" in resp.request["QUERY_STRING"])
|
||||
|
||||
# Check have cookie containing apple session
|
||||
self.assertTrue(APPLE_SESSION_COOKIE_NAME in self.client.cookies)
|
||||
|
||||
# Session should have been cleared
|
||||
apple_session_cookie = self.client.cookies.get(APPLE_SESSION_COOKIE_NAME)
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
SessionStore = engine.SessionStore
|
||||
apple_login_session = SessionStore(apple_session_cookie.value)
|
||||
self.assertEqual(len(apple_login_session.keys()), 0)
|
||||
|
||||
# Check cookie path was correctly set
|
||||
self.assertEqual(apple_session_cookie.get("path"), finish_url)
|
||||
@@ -0,0 +1,16 @@
|
||||
from django.urls import path
|
||||
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
from .provider import AppleProvider
|
||||
from .views import oauth2_finish_login
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(AppleProvider)
|
||||
urlpatterns += [
|
||||
path(
|
||||
AppleProvider.get_slug() + "/login/callback/finish/",
|
||||
oauth2_finish_login,
|
||||
name="apple_finish_callback",
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,148 @@
|
||||
import json
|
||||
from datetime import timedelta
|
||||
|
||||
from django.http import HttpResponseNotAllowed, HttpResponseRedirect
|
||||
from django.urls import reverse
|
||||
from django.utils import timezone
|
||||
from django.utils.http import urlencode
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
|
||||
from allauth.account.internal.decorators import login_not_required
|
||||
from allauth.socialaccount.internal import jwtkit
|
||||
from allauth.socialaccount.models import SocialToken
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
from allauth.utils import build_absolute_uri, get_request_param
|
||||
|
||||
from .apple_session import get_apple_session
|
||||
from .client import AppleOAuth2Client
|
||||
|
||||
|
||||
class AppleOAuth2Adapter(OAuth2Adapter):
|
||||
client_class = AppleOAuth2Client
|
||||
provider_id = "apple"
|
||||
access_token_url = "https://appleid.apple.com/auth/token"
|
||||
authorize_url = "https://appleid.apple.com/auth/authorize"
|
||||
public_key_url = "https://appleid.apple.com/auth/keys"
|
||||
|
||||
@classmethod
|
||||
def get_verified_identity_data(cls, provider, id_token):
|
||||
data = jwtkit.verify_and_decode(
|
||||
credential=id_token,
|
||||
keys_url=cls.public_key_url,
|
||||
issuer="https://appleid.apple.com",
|
||||
audience=provider.get_auds(),
|
||||
lookup_kid=jwtkit.lookup_kid_jwk,
|
||||
)
|
||||
return data
|
||||
|
||||
def parse_token(self, data):
|
||||
token = SocialToken(
|
||||
token=data["access_token"],
|
||||
)
|
||||
token.token_secret = data.get("refresh_token", "")
|
||||
|
||||
expires_in = data.get(self.expires_in_key)
|
||||
if expires_in:
|
||||
token.expires_at = timezone.now() + timedelta(seconds=int(expires_in))
|
||||
|
||||
# `user_data` is a big flat dictionary with the parsed JWT claims
|
||||
# access_tokens, and user info from the apple post.
|
||||
identity_data = AppleOAuth2Adapter.get_verified_identity_data(
|
||||
self.get_provider(), data["id_token"]
|
||||
)
|
||||
token.user_data = {**data, **identity_data}
|
||||
|
||||
return token
|
||||
|
||||
def complete_login(self, request, app, token, **kwargs):
|
||||
extra_data = token.user_data
|
||||
login = self.get_provider().sociallogin_from_response(
|
||||
request=request, response=extra_data
|
||||
)
|
||||
login.state["id_token"] = token.user_data
|
||||
|
||||
# We can safely remove the apple login session now
|
||||
# Note: The cookie will remain, but it's set to delete on browser close
|
||||
get_apple_session(request).delete()
|
||||
return login
|
||||
|
||||
def get_user_scope_data(self, request):
|
||||
user_scope_data = request.apple_login_session.get("user", "")
|
||||
try:
|
||||
return json.loads(user_scope_data)
|
||||
except json.JSONDecodeError:
|
||||
# We do not care much about user scope data as it maybe blank
|
||||
# so return blank dictionary instead
|
||||
return {}
|
||||
|
||||
def get_access_token_data(self, request, app, client, pkce_code_verifier=None):
|
||||
"""We need to gather the info from the apple specific login"""
|
||||
apple_session = get_apple_session(request)
|
||||
|
||||
# Exchange `code`
|
||||
code = get_request_param(request, "code")
|
||||
access_token_data = client.get_access_token(
|
||||
code, pkce_code_verifier=pkce_code_verifier
|
||||
)
|
||||
|
||||
id_token = access_token_data.get("id_token", None)
|
||||
# In case of missing id_token in access_token_data
|
||||
if id_token is None:
|
||||
id_token = apple_session.store.get("id_token")
|
||||
|
||||
return {
|
||||
**access_token_data,
|
||||
**self.get_user_scope_data(request),
|
||||
"id_token": id_token,
|
||||
}
|
||||
|
||||
|
||||
@csrf_exempt
|
||||
@login_not_required
|
||||
def apple_post_callback(request, finish_endpoint_name="apple_finish_callback"):
|
||||
"""
|
||||
Apple uses a `form_post` response type, which due to
|
||||
CORS/Samesite-cookie rules means this request cannot access
|
||||
the request since the session cookie is unavailable.
|
||||
|
||||
We work around this by storing the apple response in a
|
||||
separate, temporary session and redirecting to a more normal
|
||||
oauth flow.
|
||||
|
||||
args:
|
||||
finish_endpoint_name (str): The name of a defined URL, which can be
|
||||
overridden in your url configuration if you have more than one
|
||||
callback endpoint.
|
||||
"""
|
||||
if request.method != "POST":
|
||||
return HttpResponseNotAllowed(["POST"])
|
||||
apple_session = get_apple_session(request)
|
||||
|
||||
# Add regular OAuth2 params to the URL - reduces the overrides required
|
||||
keys_to_put_in_url = ["code", "state", "error"]
|
||||
url_params = {}
|
||||
for key in keys_to_put_in_url:
|
||||
value = get_request_param(request, key, "")
|
||||
if value:
|
||||
url_params[key] = value
|
||||
|
||||
# Add other params to the apple_login_session
|
||||
keys_to_save_to_session = ["user", "id_token"]
|
||||
for key in keys_to_save_to_session:
|
||||
apple_session.store[key] = get_request_param(request, key, "")
|
||||
|
||||
url = build_absolute_uri(request, reverse(finish_endpoint_name))
|
||||
response = HttpResponseRedirect(
|
||||
"{url}?{query}".format(url=url, query=urlencode(url_params))
|
||||
)
|
||||
apple_session.save(response)
|
||||
return response
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(AppleOAuth2Adapter)
|
||||
oauth2_callback = apple_post_callback
|
||||
oauth2_finish_login = OAuth2CallbackView.adapter_view(AppleOAuth2Adapter)
|
||||
@@ -0,0 +1,23 @@
|
||||
from allauth.socialaccount.providers.asana.views import AsanaOAuth2Adapter
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class AsanaAccount(ProviderAccount):
|
||||
pass
|
||||
|
||||
|
||||
class AsanaProvider(OAuth2Provider):
|
||||
id = "asana"
|
||||
name = "Asana"
|
||||
account_class = AsanaAccount
|
||||
oauth2_adapter_class = AsanaOAuth2Adapter
|
||||
|
||||
def extract_uid(self, data):
|
||||
return str(data["id"])
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
return dict(email=data.get("email"), name=data.get("name"))
|
||||
|
||||
|
||||
provider_classes = [AsanaProvider]
|
||||
@@ -0,0 +1,20 @@
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
from .provider import AsanaProvider
|
||||
|
||||
|
||||
class AsanaTests(OAuth2TestsMixin, TestCase):
|
||||
provider_id = AsanaProvider.id
|
||||
|
||||
def get_mocked_response(self):
|
||||
return MockedResponse(
|
||||
200,
|
||||
"""
|
||||
{"data": {"photo": null, "workspaces": [{"id": 31337, "name": "example.com"},
|
||||
{"id": 3133777, "name": "Personal Projects"}], "email": "test@example.com",
|
||||
"name": "Test Name", "id": 43748387}}""",
|
||||
)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "test@example.com"
|
||||
@@ -0,0 +1,6 @@
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
from .provider import AsanaProvider
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(AsanaProvider)
|
||||
@@ -0,0 +1,26 @@
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
|
||||
|
||||
class AsanaOAuth2Adapter(OAuth2Adapter):
|
||||
provider_id = "asana"
|
||||
access_token_url = "https://app.asana.com/-/oauth_token"
|
||||
authorize_url = "https://app.asana.com/-/oauth_authorize"
|
||||
profile_url = "https://app.asana.com/api/1.0/users/me"
|
||||
|
||||
def complete_login(self, request, app, token, **kwargs):
|
||||
resp = (
|
||||
get_adapter()
|
||||
.get_requests_session()
|
||||
.get(self.profile_url, params={"access_token": token.token})
|
||||
)
|
||||
extra_data = resp.json()["data"]
|
||||
return self.get_provider().sociallogin_from_response(request, extra_data)
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(AsanaOAuth2Adapter)
|
||||
oauth2_callback = OAuth2CallbackView.adapter_view(AsanaOAuth2Adapter)
|
||||
@@ -0,0 +1,38 @@
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
from .views import AtlassianOAuth2Adapter
|
||||
|
||||
|
||||
class AtlassianAccount(ProviderAccount):
|
||||
def get_profile_url(self):
|
||||
return self.account.extra_data.get("picture")
|
||||
|
||||
|
||||
class AtlassianProvider(OAuth2Provider):
|
||||
id = "atlassian"
|
||||
name = "Atlassian"
|
||||
account_class = AtlassianAccount
|
||||
oauth2_adapter_class = AtlassianOAuth2Adapter
|
||||
|
||||
def extract_uid(self, data):
|
||||
return data["account_id"]
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
return {
|
||||
"email": data.get("email"),
|
||||
"name": data.get("name"),
|
||||
"username": data.get("nickname"),
|
||||
"email_verified": data.get("email_verified"),
|
||||
}
|
||||
|
||||
def get_default_scope(self):
|
||||
return ["read:me"]
|
||||
|
||||
def get_auth_params(self):
|
||||
params = super().get_auth_params()
|
||||
params.update({"audience": "api.atlassian.com", "prompt": "consent"})
|
||||
return params
|
||||
|
||||
|
||||
provider_classes = [AtlassianProvider]
|
||||
@@ -0,0 +1,33 @@
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
from .provider import AtlassianProvider
|
||||
|
||||
|
||||
class AtlassianTests(OAuth2TestsMixin, TestCase):
|
||||
provider_id = AtlassianProvider.id
|
||||
|
||||
def get_mocked_response(self):
|
||||
response_data = """
|
||||
{
|
||||
"account_type": "atlassian",
|
||||
"account_id": "[HEROKU-API-KEY-REMOVED]",
|
||||
"email": "mia@example.com",
|
||||
"email_verified": true,
|
||||
"name": "Mia Krystof",
|
||||
"picture": "https://avatar-management--avatars.us-west-2.prod.public.atl-paas.net/[HEROKU-API-KEY-REMOVED]/1234abcd-9876-54aa-33aa-1234dfsade9487ds",
|
||||
"account_status": "active",
|
||||
"nickname": "mkrystof",
|
||||
"zoneinfo": "Australia/Sydney",
|
||||
"locale": "en-US",
|
||||
"extended_profile": {
|
||||
"job_title": "Designer",
|
||||
"organization": "mia@example.com",
|
||||
"department": "Design team",
|
||||
"location": "Sydney"
|
||||
}
|
||||
}"""
|
||||
return MockedResponse(200, response_data)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "mia@example.com"
|
||||
@@ -0,0 +1,6 @@
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
from .provider import AtlassianProvider
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(AtlassianProvider)
|
||||
@@ -0,0 +1,30 @@
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.models import SocialToken
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
|
||||
|
||||
class AtlassianOAuth2Adapter(OAuth2Adapter):
|
||||
provider_id = "atlassian"
|
||||
access_token_url = "https://api.atlassian.com/oauth/token"
|
||||
authorize_url = "https://auth.atlassian.com/authorize"
|
||||
profile_url = "https://api.atlassian.com/me"
|
||||
|
||||
def complete_login(self, request, app, token: SocialToken, **kwargs):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token.token}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
response = (
|
||||
get_adapter().get_requests_session().get(self.profile_url, headers=headers)
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return self.get_provider().sociallogin_from_response(request, data)
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(AtlassianOAuth2Adapter)
|
||||
oauth2_callback = OAuth2CallbackView.adapter_view(AtlassianOAuth2Adapter)
|
||||
@@ -0,0 +1,31 @@
|
||||
from allauth.socialaccount.providers.auth0.views import Auth0OAuth2Adapter
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class Auth0Account(ProviderAccount):
|
||||
def get_avatar_url(self):
|
||||
return self.account.extra_data.get("picture")
|
||||
|
||||
|
||||
class Auth0Provider(OAuth2Provider):
|
||||
id = "auth0"
|
||||
name = "Auth0"
|
||||
account_class = Auth0Account
|
||||
oauth2_adapter_class = Auth0OAuth2Adapter
|
||||
|
||||
def get_default_scope(self):
|
||||
return ["openid", "profile", "email"]
|
||||
|
||||
def extract_uid(self, data):
|
||||
return str(data["sub"])
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
return dict(
|
||||
email=data.get("email"),
|
||||
username=data.get("username"),
|
||||
name=data.get("name"),
|
||||
)
|
||||
|
||||
|
||||
provider_classes = [Auth0Provider]
|
||||
@@ -0,0 +1,25 @@
|
||||
from allauth.socialaccount.providers.auth0.provider import Auth0Provider
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
|
||||
class Auth0Tests(OAuth2TestsMixin, TestCase):
|
||||
provider_id = Auth0Provider.id
|
||||
|
||||
def get_mocked_response(self):
|
||||
return MockedResponse(
|
||||
200,
|
||||
"""
|
||||
{
|
||||
"picture": "https://secure.gravatar.com/avatar/123",
|
||||
"email": "mr.bob@your.Auth0.server.example.com",
|
||||
"id": 2,
|
||||
"sub": 2,
|
||||
"identities": [],
|
||||
"name": "Mr Bob"
|
||||
}
|
||||
""",
|
||||
)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "mr.bob@your.Auth0.server.example.com"
|
||||
@@ -0,0 +1,5 @@
|
||||
from allauth.socialaccount.providers.auth0.provider import Auth0Provider
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(Auth0Provider)
|
||||
@@ -0,0 +1,31 @@
|
||||
from allauth.socialaccount import app_settings
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
|
||||
|
||||
class Auth0OAuth2Adapter(OAuth2Adapter):
|
||||
provider_id = "auth0"
|
||||
|
||||
settings = app_settings.PROVIDERS.get(provider_id, {})
|
||||
provider_base_url = settings.get("AUTH0_URL")
|
||||
|
||||
access_token_url = "{0}/oauth/token".format(provider_base_url)
|
||||
authorize_url = "{0}/authorize".format(provider_base_url)
|
||||
profile_url = "{0}/userinfo".format(provider_base_url)
|
||||
|
||||
def complete_login(self, request, app, token, response):
|
||||
extra_data = (
|
||||
get_adapter()
|
||||
.get_requests_session()
|
||||
.get(self.profile_url, params={"access_token": token.token})
|
||||
.json()
|
||||
)
|
||||
return self.get_provider().sociallogin_from_response(request, extra_data)
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(Auth0OAuth2Adapter)
|
||||
oauth2_callback = OAuth2CallbackView.adapter_view(Auth0OAuth2Adapter)
|
||||
@@ -0,0 +1,110 @@
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.socialaccount import app_settings
|
||||
from allauth.socialaccount.providers.authentiq.views import (
|
||||
AuthentiqOAuth2Adapter,
|
||||
)
|
||||
from allauth.socialaccount.providers.base import AuthAction, ProviderAccount
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class Scope:
|
||||
NAME = "aq:name"
|
||||
EMAIL = "email"
|
||||
PHONE = "phone"
|
||||
ADDRESS = "address"
|
||||
LOCATION = "aq:location"
|
||||
PUSH = "aq:push"
|
||||
|
||||
|
||||
IDENTITY_CLAIMS = frozenset(
|
||||
[
|
||||
"sub",
|
||||
"name",
|
||||
"given_name",
|
||||
"family_name",
|
||||
"middle_name",
|
||||
"nickname",
|
||||
"preferred_username",
|
||||
"profile",
|
||||
"picture",
|
||||
"website",
|
||||
"email",
|
||||
"email_verified",
|
||||
"gender",
|
||||
"birthdate",
|
||||
"zoneinfo",
|
||||
"locale",
|
||||
"phone_number",
|
||||
"phone_number_verified",
|
||||
"address",
|
||||
"updated_at",
|
||||
"aq:location",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class AuthentiqAccount(ProviderAccount):
|
||||
def get_profile_url(self):
|
||||
return self.account.extra_data.get("profile")
|
||||
|
||||
def get_avatar_url(self):
|
||||
return self.account.extra_data.get("picture")
|
||||
|
||||
|
||||
class AuthentiqProvider(OAuth2Provider):
|
||||
id = "authentiq"
|
||||
name = "Authentiq"
|
||||
account_class = AuthentiqAccount
|
||||
oauth2_adapter_class = AuthentiqOAuth2Adapter
|
||||
|
||||
def get_scope_from_request(self, request):
|
||||
scope = set(super().get_scope_from_request(request))
|
||||
scope.add("openid")
|
||||
|
||||
if Scope.EMAIL in scope:
|
||||
modifiers = ""
|
||||
if app_settings.EMAIL_REQUIRED:
|
||||
modifiers += "r"
|
||||
if app_settings.EMAIL_VERIFICATION:
|
||||
modifiers += "s"
|
||||
if modifiers:
|
||||
scope.add(Scope.EMAIL + "~" + modifiers)
|
||||
scope.remove(Scope.EMAIL)
|
||||
return list(scope)
|
||||
|
||||
def get_default_scope(self):
|
||||
scope = [Scope.NAME, Scope.PUSH]
|
||||
if app_settings.QUERY_EMAIL:
|
||||
scope.append(Scope.EMAIL)
|
||||
return scope
|
||||
|
||||
def get_auth_params_from_request(self, request, action):
|
||||
ret = super().get_auth_params_from_request(request, action)
|
||||
if action == AuthAction.REAUTHENTICATE:
|
||||
ret["prompt"] = "select_account"
|
||||
return ret
|
||||
|
||||
def extract_uid(self, data):
|
||||
return str(data["sub"])
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
return dict(
|
||||
username=data.get("preferred_username", data.get("given_name")),
|
||||
email=data.get("email"),
|
||||
name=data.get("name"),
|
||||
first_name=data.get("given_name"),
|
||||
last_name=data.get("family_name"),
|
||||
)
|
||||
|
||||
def extract_extra_data(self, data):
|
||||
return {k: v for k, v in data.items() if k in IDENTITY_CLAIMS}
|
||||
|
||||
def extract_email_addresses(self, data):
|
||||
ret = []
|
||||
email = data.get("email")
|
||||
if email and data.get("email_verified"):
|
||||
ret.append(EmailAddress(email=email, verified=True, primary=True))
|
||||
return ret
|
||||
|
||||
|
||||
provider_classes = [AuthentiqProvider]
|
||||
@@ -0,0 +1,105 @@
|
||||
import json
|
||||
|
||||
from django.test.client import RequestFactory
|
||||
from django.test.utils import override_settings
|
||||
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
from .provider import AuthentiqProvider
|
||||
from .views import AuthentiqOAuth2Adapter
|
||||
|
||||
|
||||
class AuthentiqTests(OAuth2TestsMixin, TestCase):
|
||||
provider_id = AuthentiqProvider.id
|
||||
|
||||
def get_mocked_response(self):
|
||||
return MockedResponse(
|
||||
200,
|
||||
json.dumps(
|
||||
{
|
||||
"sub": "ZLARGMFT1M",
|
||||
"email": "jane@email.invalid",
|
||||
"email_verified": True,
|
||||
"given_name": "Jane",
|
||||
"family_name": "Doe",
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "jane@email.invalid"
|
||||
|
||||
@override_settings(
|
||||
SOCIALACCOUNT_QUERY_EMAIL=False,
|
||||
)
|
||||
def test_default_scopes_no_email(self):
|
||||
scopes = self.provider.get_default_scope()
|
||||
self.assertIn("aq:name", scopes)
|
||||
self.assertNotIn("email", scopes)
|
||||
|
||||
@override_settings(
|
||||
SOCIALACCOUNT_QUERY_EMAIL=True,
|
||||
)
|
||||
def test_default_scopes_email(self):
|
||||
scopes = self.provider.get_default_scope()
|
||||
self.assertIn("aq:name", scopes)
|
||||
self.assertIn("email", scopes)
|
||||
|
||||
def test_scopes(self):
|
||||
request = RequestFactory().get(AuthentiqOAuth2Adapter.authorize_url)
|
||||
scopes = self.provider.get_scope_from_request(request)
|
||||
self.assertIn("openid", scopes)
|
||||
self.assertIn("aq:name", scopes)
|
||||
|
||||
def test_dynamic_scopes(self):
|
||||
request = RequestFactory().get(
|
||||
AuthentiqOAuth2Adapter.authorize_url, dict(scope="foo")
|
||||
)
|
||||
scopes = self.provider.get_scope_from_request(request)
|
||||
self.assertIn("openid", scopes)
|
||||
self.assertIn("aq:name", scopes)
|
||||
self.assertIn("foo", scopes)
|
||||
|
||||
@override_settings(
|
||||
SOCIALACCOUNT_QUERY_EMAIL=True,
|
||||
SOCIALACCOUNT_EMAIL_REQUIRED=True,
|
||||
SOCIALACCOUNT_EMAIL_VERIFICATION=True,
|
||||
)
|
||||
def test_scopes_required_verified_email(self):
|
||||
request = RequestFactory().get(AuthentiqOAuth2Adapter.authorize_url)
|
||||
scopes = self.provider.get_scope_from_request(request)
|
||||
self.assertIn("email~rs", scopes)
|
||||
self.assertNotIn("email", scopes)
|
||||
|
||||
@override_settings(
|
||||
SOCIALACCOUNT_QUERY_EMAIL=True,
|
||||
SOCIALACCOUNT_EMAIL_REQUIRED=False,
|
||||
SOCIALACCOUNT_EMAIL_VERIFICATION=True,
|
||||
)
|
||||
def test_scopes_optional_verified_email(self):
|
||||
request = RequestFactory().get(AuthentiqOAuth2Adapter.authorize_url)
|
||||
scopes = self.provider.get_scope_from_request(request)
|
||||
self.assertIn("email~s", scopes)
|
||||
self.assertNotIn("email", scopes)
|
||||
|
||||
@override_settings(
|
||||
SOCIALACCOUNT_QUERY_EMAIL=True,
|
||||
SOCIALACCOUNT_EMAIL_REQUIRED=True,
|
||||
SOCIALACCOUNT_EMAIL_VERIFICATION=False,
|
||||
)
|
||||
def test_scopes_required_email(self):
|
||||
request = RequestFactory().get(AuthentiqOAuth2Adapter.authorize_url)
|
||||
scopes = self.provider.get_scope_from_request(request)
|
||||
self.assertIn("email~r", scopes)
|
||||
self.assertNotIn("email", scopes)
|
||||
|
||||
@override_settings(
|
||||
SOCIALACCOUNT_QUERY_EMAIL=True,
|
||||
SOCIALACCOUNT_EMAIL_REQUIRED=False,
|
||||
SOCIALACCOUNT_EMAIL_VERIFICATION=False,
|
||||
)
|
||||
def test_scopes_optional_email(self):
|
||||
request = RequestFactory().get(AuthentiqOAuth2Adapter.authorize_url)
|
||||
scopes = self.provider.get_scope_from_request(request)
|
||||
self.assertIn("email", scopes)
|
||||
@@ -0,0 +1,6 @@
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
from .provider import AuthentiqProvider
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(AuthentiqProvider)
|
||||
@@ -0,0 +1,35 @@
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from allauth.socialaccount import app_settings
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
|
||||
|
||||
class AuthentiqOAuth2Adapter(OAuth2Adapter):
|
||||
provider_id = "authentiq"
|
||||
|
||||
settings = app_settings.PROVIDERS.get(provider_id, {})
|
||||
|
||||
provider_url = settings.get("PROVIDER_URL", "https://connect.authentiq.io/")
|
||||
if not provider_url.endswith("/"):
|
||||
provider_url += "/"
|
||||
|
||||
access_token_url = urljoin(provider_url, "token")
|
||||
authorize_url = urljoin(provider_url, "authorize")
|
||||
profile_url = urljoin(provider_url, "userinfo")
|
||||
|
||||
def complete_login(self, request, app, token, **kwargs):
|
||||
auth = {"Authorization": "Bearer " + token.token}
|
||||
resp = get_adapter().get_requests_session().get(self.profile_url, headers=auth)
|
||||
resp.raise_for_status()
|
||||
extra_data = resp.json()
|
||||
login = self.get_provider().sociallogin_from_response(request, extra_data)
|
||||
return login
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(AuthentiqOAuth2Adapter)
|
||||
oauth2_callback = OAuth2CallbackView.adapter_view(AuthentiqOAuth2Adapter)
|
||||
@@ -0,0 +1,34 @@
|
||||
from allauth.socialaccount.providers.baidu.views import BaiduOAuth2Adapter
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class BaiduAccount(ProviderAccount):
|
||||
def get_profile_url(self):
|
||||
return "http://www.baidu.com/p/" + self.account.extra_data.get("uname")
|
||||
|
||||
def get_avatar_url(self):
|
||||
return (
|
||||
"http://tb.himg.baidu.com/sys/portraitn/item/"
|
||||
+ self.account.extra_data.get("portrait")
|
||||
)
|
||||
|
||||
def to_str(self):
|
||||
dflt = super(BaiduAccount, self).to_str()
|
||||
return self.account.extra_data.get("uname", dflt)
|
||||
|
||||
|
||||
class BaiduProvider(OAuth2Provider):
|
||||
id = "baidu"
|
||||
name = "Baidu"
|
||||
account_class = BaiduAccount
|
||||
oauth2_adapter_class = BaiduOAuth2Adapter
|
||||
|
||||
def extract_uid(self, data):
|
||||
return data["uid"]
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
return dict(username=data.get("uid"), name=data.get("uname"))
|
||||
|
||||
|
||||
provider_classes = [BaiduProvider]
|
||||
@@ -0,0 +1,19 @@
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
from .provider import BaiduProvider
|
||||
|
||||
|
||||
class BaiduTests(OAuth2TestsMixin, TestCase):
|
||||
provider_id = BaiduProvider.id
|
||||
|
||||
def get_mocked_response(self):
|
||||
return MockedResponse(
|
||||
200,
|
||||
"""
|
||||
{"portrait": "78c0e9839de59bbde7859ccf43",
|
||||
"uname": "\u90dd\u56fd\u715c", "uid": "3225892368"}""",
|
||||
)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "\u90dd\u56fd\u715c"
|
||||
@@ -0,0 +1,6 @@
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
from .provider import BaiduProvider
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(BaiduProvider)
|
||||
@@ -0,0 +1,28 @@
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
|
||||
|
||||
class BaiduOAuth2Adapter(OAuth2Adapter):
|
||||
provider_id = "baidu"
|
||||
access_token_url = "https://openapi.baidu.com/oauth/2.0/token"
|
||||
authorize_url = "https://openapi.baidu.com/oauth/2.0/authorize"
|
||||
profile_url = (
|
||||
"https://openapi.baidu.com/rest/2.0/passport/users/getLoggedInUser" # noqa
|
||||
)
|
||||
|
||||
def complete_login(self, request, app, token, **kwargs):
|
||||
resp = (
|
||||
get_adapter()
|
||||
.get_requests_session()
|
||||
.get(self.profile_url, params={"access_token": token.token})
|
||||
)
|
||||
extra_data = resp.json()
|
||||
return self.get_provider().sociallogin_from_response(request, extra_data)
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(BaiduOAuth2Adapter)
|
||||
oauth2_callback = OAuth2CallbackView.adapter_view(BaiduOAuth2Adapter)
|
||||
@@ -0,0 +1,2 @@
|
||||
from .constants import AuthAction, AuthError, AuthProcess # noqa
|
||||
from .provider import Provider, ProviderAccount, ProviderException # noqa
|
||||
@@ -0,0 +1,16 @@
|
||||
class AuthProcess:
|
||||
LOGIN = "login"
|
||||
CONNECT = "connect"
|
||||
REDIRECT = "redirect"
|
||||
|
||||
|
||||
class AuthAction:
|
||||
AUTHENTICATE = "authenticate"
|
||||
REAUTHENTICATE = "reauthenticate"
|
||||
REREQUEST = "rerequest"
|
||||
|
||||
|
||||
class AuthError:
|
||||
UNKNOWN = "unknown"
|
||||
CANCELLED = "cancelled" # Cancelled on request of user
|
||||
DENIED = "denied" # Denied by server
|
||||
@@ -0,0 +1,363 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured, PermissionDenied
|
||||
|
||||
from allauth.account.utils import get_next_redirect_url, get_request_param
|
||||
from allauth.socialaccount import app_settings
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.internal import statekit
|
||||
from allauth.socialaccount.providers.base.constants import AuthProcess
|
||||
|
||||
|
||||
class ProviderException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Provider:
|
||||
name: str # Provided by subclasses
|
||||
id: str # Provided by subclasses
|
||||
slug: Optional[str] = None # Provided by subclasses
|
||||
|
||||
uses_apps = True
|
||||
supports_redirect = False
|
||||
# Indicates whether or not this provider supports logging in by posting an
|
||||
# access/id-token.
|
||||
supports_token_authentication = False
|
||||
|
||||
def __init__(self, request, app=None):
|
||||
self.request = request
|
||||
if self.uses_apps and app is None:
|
||||
raise ValueError("missing: app")
|
||||
self.app = app
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
@classmethod
|
||||
def get_slug(cls):
|
||||
return cls.slug or cls.id
|
||||
|
||||
def get_login_url(self, request, next=None, **kwargs):
|
||||
"""
|
||||
Builds the URL to redirect to when initiating a login for this
|
||||
provider.
|
||||
"""
|
||||
raise NotImplementedError("get_login_url() for " + self.name)
|
||||
|
||||
def redirect_from_request(self, request):
|
||||
kwargs = self.get_redirect_from_request_kwargs(request)
|
||||
return self.redirect(request, **kwargs)
|
||||
|
||||
def get_redirect_from_request_kwargs(self, request):
|
||||
kwargs = {}
|
||||
next_url = get_next_redirect_url(request)
|
||||
if next_url:
|
||||
kwargs["next_url"] = next_url
|
||||
kwargs["process"] = get_request_param(request, "process", AuthProcess.LOGIN)
|
||||
return kwargs
|
||||
|
||||
def redirect(self, request, process, next_url=None, data=None, **kwargs):
|
||||
"""
|
||||
Initiate a redirect to the provider.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def verify_token(self, request, token):
|
||||
"""
|
||||
Verifies the token, returning a `SocialLogin` instance when valid.
|
||||
Raises a `ValidationError` otherwise.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def media_js(self, request):
|
||||
"""
|
||||
Some providers may require extra scripts (e.g. a Facebook connect)
|
||||
"""
|
||||
return ""
|
||||
|
||||
def wrap_account(self, social_account):
|
||||
return self.account_class(social_account)
|
||||
|
||||
def get_settings(self):
|
||||
return app_settings.PROVIDERS.get(self.id, {})
|
||||
|
||||
def sociallogin_from_response(self, request, response):
|
||||
"""
|
||||
Instantiates and populates a `SocialLogin` model based on the data
|
||||
retrieved in `response`. The method does NOT save the model to the
|
||||
DB.
|
||||
|
||||
Data for `SocialLogin` will be extracted from `response` with the
|
||||
help of the `.extract_uid()`, `.extract_extra_data()`,
|
||||
`.extract_common_fields()`, and `.extract_email_addresses()`
|
||||
methods.
|
||||
|
||||
:param request: a Django `HttpRequest` object.
|
||||
:param response: object retrieved via the callback response of the
|
||||
social auth provider.
|
||||
:return: A populated instance of the `SocialLogin` model (unsaved).
|
||||
"""
|
||||
# NOTE: Avoid loading models at top due to registry boot...
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.models import SocialAccount, SocialLogin
|
||||
|
||||
adapter = get_adapter()
|
||||
uid = self.extract_uid(response)
|
||||
if not isinstance(uid, str):
|
||||
raise ValueError(f"uid must be a string: {repr(uid)}")
|
||||
if len(uid) > app_settings.UID_MAX_LENGTH:
|
||||
raise ImproperlyConfigured(
|
||||
f"SOCIALACCOUNT_UID_MAX_LENGTH too small (<{len(uid)})"
|
||||
)
|
||||
if not uid:
|
||||
raise ValueError("uid must be a non-empty string")
|
||||
|
||||
extra_data = self.extract_extra_data(response)
|
||||
common_fields = self.extract_common_fields(response)
|
||||
socialaccount = SocialAccount(
|
||||
extra_data=extra_data,
|
||||
uid=uid,
|
||||
provider=self.sub_id,
|
||||
)
|
||||
email_addresses = self.extract_email_addresses(response)
|
||||
email = self.cleanup_email_addresses(
|
||||
common_fields.get("email"),
|
||||
email_addresses,
|
||||
email_verified=common_fields.get("email_verified"),
|
||||
)
|
||||
if email:
|
||||
common_fields["email"] = email
|
||||
sociallogin = SocialLogin(
|
||||
account=socialaccount, email_addresses=email_addresses
|
||||
)
|
||||
user = sociallogin.user = adapter.new_user(request, sociallogin)
|
||||
user.set_unusable_password()
|
||||
adapter.populate_user(request, sociallogin, common_fields)
|
||||
return sociallogin
|
||||
|
||||
def extract_uid(self, data):
|
||||
"""
|
||||
Extracts the unique user ID from `data`
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"The provider must implement the `extract_uid()` method"
|
||||
)
|
||||
|
||||
def extract_extra_data(self, data):
|
||||
"""
|
||||
Extracts fields from `data` that will be stored in
|
||||
`SocialAccount`'s `extra_data` JSONField, such as email address, first
|
||||
name, last name, and phone number.
|
||||
|
||||
:return: any JSON-serializable Python structure.
|
||||
"""
|
||||
return data
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
"""
|
||||
Extracts fields from `data` that will be used to populate the
|
||||
`User` model in the `SOCIALACCOUNT_ADAPTER`'s `populate_user()`
|
||||
method.
|
||||
|
||||
For example:
|
||||
|
||||
{'first_name': 'John'}
|
||||
|
||||
:return: dictionary of key-value pairs.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def cleanup_email_addresses(
|
||||
self, email: Optional[str], addresses: list, email_verified: bool = False
|
||||
) -> Optional[str]:
|
||||
# Avoid loading models before adapters have been registered.
|
||||
from allauth.account.models import EmailAddress
|
||||
|
||||
# Move user.email over to EmailAddress
|
||||
if email and email.lower() not in [a.email.lower() for a in addresses]:
|
||||
addresses.insert(
|
||||
0,
|
||||
EmailAddress(email=email, verified=bool(email_verified), primary=True),
|
||||
)
|
||||
# Force verified emails
|
||||
adapter = get_adapter()
|
||||
for address in addresses:
|
||||
if adapter.is_email_verified(self, address.email):
|
||||
address.verified = True
|
||||
|
||||
# Sort in order of importance (primary, verified...)
|
||||
addresses.sort(key=lambda a: (a.primary, a.verified, a.email), reverse=True)
|
||||
if not email and addresses:
|
||||
email = addresses[0].email
|
||||
return email
|
||||
|
||||
def extract_email_addresses(self, data):
|
||||
"""
|
||||
For example:
|
||||
|
||||
[EmailAddress(email='john@example.com',
|
||||
verified=True,
|
||||
primary=True)]
|
||||
"""
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def get_package(cls):
|
||||
pkg = getattr(cls, "package", None)
|
||||
if not pkg:
|
||||
pkg = cls.__module__.rpartition(".")[0]
|
||||
return pkg
|
||||
|
||||
def stash_redirect_state(
|
||||
self, request, process, next_url=None, data=None, state_id=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Stashes state, returning a (random) state ID using which the state
|
||||
can be looked up later. Application specific state is stored separately
|
||||
from (core) allauth state such as `process` and `**kwargs`.
|
||||
"""
|
||||
state = {"process": process, "data": data, **kwargs}
|
||||
if next_url:
|
||||
state["next"] = next_url
|
||||
return statekit.stash_state(request, state, state_id=state_id)
|
||||
|
||||
def unstash_redirect_state(self, request, state_id):
|
||||
state = statekit.unstash_state(request, state_id)
|
||||
if state is None:
|
||||
raise PermissionDenied()
|
||||
return state
|
||||
|
||||
@property
|
||||
def sub_id(self) -> str:
|
||||
return (
|
||||
(self.app.provider_id or self.app.provider) if self.uses_apps else self.id
|
||||
)
|
||||
|
||||
|
||||
class ProviderAccount:
|
||||
def __init__(self, social_account):
|
||||
self.account = social_account
|
||||
|
||||
def get_profile_url(self):
|
||||
return None
|
||||
|
||||
def get_avatar_url(self):
|
||||
return None
|
||||
|
||||
def get_brand(self):
|
||||
"""
|
||||
Returns a dict containing an id and name identifying the
|
||||
brand. Useful when displaying logos next to accounts in
|
||||
templates.
|
||||
|
||||
For most providers, these are identical to the provider. For
|
||||
OpenID however, the brand can derived from the OpenID identity
|
||||
url.
|
||||
"""
|
||||
provider = self.account.get_provider()
|
||||
return dict(id=provider.id, name=provider.name)
|
||||
|
||||
def __str__(self):
|
||||
return self.to_str()
|
||||
|
||||
def get_user_data(self) -> Optional[Dict]:
|
||||
"""Typically, the ``extra_data`` directly contains user related keys.
|
||||
For some providers, however, they are nested below a different key. In
|
||||
that case, you can override this method so that the base ``__str__()``
|
||||
will still be able to find the data.
|
||||
"""
|
||||
ret = self.account.extra_data
|
||||
if not isinstance(ret, dict):
|
||||
ret = None
|
||||
return ret
|
||||
|
||||
def to_str(self):
|
||||
"""
|
||||
Returns string representation of this social account. This is the
|
||||
unique identifier of the account, such as its username or its email
|
||||
address. It should be meaningful to human beings, which means a numeric
|
||||
ID number is rarely the appropriate representation here.
|
||||
|
||||
Subclasses are meant to override this method.
|
||||
|
||||
Users will see the string representation of their social accounts in
|
||||
the page rendered by the allauth.socialaccount.views.connections view.
|
||||
|
||||
The following code did not use to work in the past due to py2
|
||||
compatibility:
|
||||
|
||||
class GoogleAccount(ProviderAccount):
|
||||
def __str__(self):
|
||||
dflt = super(GoogleAccount, self).__str__()
|
||||
return self.account.extra_data.get('name', dflt)
|
||||
|
||||
So we have this method `to_str` that can be overridden in a conventional
|
||||
fashion, without having to worry about it.
|
||||
"""
|
||||
user_data = self.get_user_data()
|
||||
if user_data:
|
||||
combi_values = {}
|
||||
tbl = [
|
||||
# Prefer username -- it's the most human recognizable & unique.
|
||||
(
|
||||
None,
|
||||
[
|
||||
"username",
|
||||
"userName",
|
||||
"user_name",
|
||||
"login",
|
||||
"handle",
|
||||
],
|
||||
),
|
||||
# Second best is email
|
||||
(None, ["email", "Email", "mail", "email_address"]),
|
||||
(
|
||||
None,
|
||||
[
|
||||
"name",
|
||||
"display_name",
|
||||
"displayName",
|
||||
"Display_Name",
|
||||
"nickname",
|
||||
],
|
||||
),
|
||||
# Use the full name
|
||||
(None, ["full_name", "fullName"]),
|
||||
# Alternatively, try to assemble a full name ourselves.
|
||||
(
|
||||
"first_name",
|
||||
[
|
||||
"first_name",
|
||||
"firstname",
|
||||
"firstName",
|
||||
"First_Name",
|
||||
"given_name",
|
||||
"givenName",
|
||||
],
|
||||
),
|
||||
(
|
||||
"last_name",
|
||||
[
|
||||
"last_name",
|
||||
"lastname",
|
||||
"lastName",
|
||||
"Last_Name",
|
||||
"family_name",
|
||||
"familyName",
|
||||
"surname",
|
||||
],
|
||||
),
|
||||
]
|
||||
for store_as, variants in tbl:
|
||||
for key in variants:
|
||||
value = user_data.get(key)
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
if value and not store_as:
|
||||
return value
|
||||
combi_values[store_as] = value
|
||||
first_name = combi_values.get("first_name") or ""
|
||||
last_name = combi_values.get("last_name") or ""
|
||||
if first_name or last_name:
|
||||
return f"{first_name} {last_name}".strip()
|
||||
return self.get_brand()["name"]
|
||||
@@ -0,0 +1,16 @@
|
||||
from django.shortcuts import render
|
||||
|
||||
from allauth.account import app_settings as account_app_settings
|
||||
from allauth.socialaccount import app_settings
|
||||
|
||||
|
||||
def respond_to_login_on_get(request, provider):
|
||||
if (not app_settings.LOGIN_ON_GET) and request.method == "GET":
|
||||
return render(
|
||||
request,
|
||||
"socialaccount/login." + account_app_settings.TEMPLATE_EXTENSION,
|
||||
{
|
||||
"provider": provider,
|
||||
"process": request.GET.get("process"),
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,23 @@
|
||||
from django.http import Http404
|
||||
from django.views import View
|
||||
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.base.utils import respond_to_login_on_get
|
||||
|
||||
|
||||
class BaseLoginView(View):
|
||||
provider_id: str # Set in subclasses
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
if allauth_settings.HEADLESS_ONLY:
|
||||
raise Http404
|
||||
provider = self.get_provider()
|
||||
resp = respond_to_login_on_get(request, provider)
|
||||
if resp:
|
||||
return resp
|
||||
return provider.redirect_from_request(request)
|
||||
|
||||
def get_provider(self):
|
||||
provider = get_adapter().get_provider(self.request, self.provider_id)
|
||||
return provider
|
||||
@@ -0,0 +1,42 @@
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.basecamp.views import (
|
||||
BasecampOAuth2Adapter,
|
||||
)
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class BasecampAccount(ProviderAccount):
|
||||
def get_avatar_url(self):
|
||||
return None
|
||||
|
||||
def get_user_data(self):
|
||||
return self.account.extra_data.get("identity", {})
|
||||
|
||||
|
||||
class BasecampProvider(OAuth2Provider):
|
||||
id = "basecamp"
|
||||
name = "Basecamp"
|
||||
account_class = BasecampAccount
|
||||
oauth2_adapter_class = BasecampOAuth2Adapter
|
||||
|
||||
def get_auth_params_from_request(self, request, action):
|
||||
data = super().get_auth_params_from_request(request, action)
|
||||
data["type"] = "web_server"
|
||||
return data
|
||||
|
||||
def extract_uid(self, data):
|
||||
data = data["identity"]
|
||||
return str(data["id"])
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
data = data["identity"]
|
||||
return dict(
|
||||
email=data.get("email_address"),
|
||||
username=data.get("email_address"),
|
||||
first_name=data.get("first_name"),
|
||||
last_name=data.get("last_name"),
|
||||
name="%s %s" % (data.get("first_name"), data.get("last_name")),
|
||||
)
|
||||
|
||||
|
||||
provider_classes = [BasecampProvider]
|
||||
@@ -0,0 +1,46 @@
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
from .provider import BasecampProvider
|
||||
|
||||
|
||||
class BasecampTests(OAuth2TestsMixin, TestCase):
|
||||
provider_id = BasecampProvider.id
|
||||
|
||||
def get_mocked_response(self):
|
||||
return MockedResponse(
|
||||
200,
|
||||
"""
|
||||
{
|
||||
"expires_at": "2012-03-22T16:56:48-05:00",
|
||||
"identity": {
|
||||
"id": 9999999,
|
||||
"first_name": "Jason Fried",
|
||||
"last_name": "Jason Fried",
|
||||
"email_address": "jason@example.com"
|
||||
},
|
||||
"accounts": [
|
||||
{
|
||||
"product": "bcx",
|
||||
"id": 88888888,
|
||||
"name": "Wayne Enterprises, Ltd.",
|
||||
"href": "https://basecamp.com/88888888/api/v1"
|
||||
},
|
||||
{
|
||||
"product": "bcx",
|
||||
"id": 77777777,
|
||||
"name": "Veidt, Inc",
|
||||
"href": "https://basecamp.com/77777777/api/v1"
|
||||
},
|
||||
{
|
||||
"product": "campfire",
|
||||
"id": 44444444,
|
||||
"name": "Acme Shipping Co.",
|
||||
"href": "https://acme4444444.campfirenow.com"
|
||||
}
|
||||
]
|
||||
}""",
|
||||
)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "jason@example.com"
|
||||
@@ -0,0 +1,6 @@
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
from .provider import BasecampProvider
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(BasecampProvider)
|
||||
@@ -0,0 +1,27 @@
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
|
||||
|
||||
class BasecampOAuth2Adapter(OAuth2Adapter):
|
||||
provider_id = "basecamp"
|
||||
access_token_url = (
|
||||
"https://launchpad.37signals.com/authorization/token?type=web_server" # noqa
|
||||
)
|
||||
authorize_url = "https://launchpad.37signals.com/authorization/new"
|
||||
profile_url = "https://launchpad.37signals.com/authorization.json"
|
||||
|
||||
def complete_login(self, request, app, token, **kwargs):
|
||||
headers = {"Authorization": "Bearer {0}".format(token.token)}
|
||||
resp = (
|
||||
get_adapter().get_requests_session().get(self.profile_url, headers=headers)
|
||||
)
|
||||
extra_data = resp.json()
|
||||
return self.get_provider().sociallogin_from_response(request, extra_data)
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(BasecampOAuth2Adapter)
|
||||
oauth2_callback = OAuth2CallbackView.adapter_view(BasecampOAuth2Adapter)
|
||||
@@ -0,0 +1,35 @@
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.battlenet.views import (
|
||||
BattleNetOAuth2Adapter,
|
||||
)
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class BattleNetAccount(ProviderAccount):
|
||||
def to_str(self):
|
||||
battletag = self.account.extra_data.get("battletag")
|
||||
return battletag or super(BattleNetAccount, self).to_str()
|
||||
|
||||
|
||||
class BattleNetProvider(OAuth2Provider):
|
||||
id = "battlenet"
|
||||
name = "Battle.net"
|
||||
account_class = BattleNetAccount
|
||||
oauth2_adapter_class = BattleNetOAuth2Adapter
|
||||
|
||||
def extract_uid(self, data):
|
||||
uid = str(data["id"])
|
||||
if data.get("region") == "cn":
|
||||
# China is on a different account system. UIDs can clash with US.
|
||||
return uid + "-cn"
|
||||
return uid
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
return {"username": data.get("battletag")}
|
||||
|
||||
def get_default_scope(self):
|
||||
# Optional scopes: "sc2.profile", "wow.profile"
|
||||
return []
|
||||
|
||||
|
||||
provider_classes = [BattleNetProvider]
|
||||
@@ -0,0 +1,68 @@
|
||||
import json
|
||||
|
||||
from allauth.socialaccount.models import SocialAccount
|
||||
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
from .provider import BattleNetProvider
|
||||
from .views import _check_errors
|
||||
|
||||
|
||||
class BattleNetTests(OAuth2TestsMixin, TestCase):
|
||||
provider_id = BattleNetProvider.id
|
||||
_uid = 123456789
|
||||
_battletag = "LuckyDragon#1953"
|
||||
|
||||
def get_mocked_response(self):
|
||||
data = {"battletag": self._battletag, "id": self._uid}
|
||||
return MockedResponse(200, json.dumps(data))
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return self._battletag
|
||||
|
||||
def test_valid_response_no_battletag(self):
|
||||
data = {"id": 12345}
|
||||
response = MockedResponse(200, json.dumps(data))
|
||||
self.assertEqual(_check_errors(response), data)
|
||||
|
||||
def test_invalid_data(self):
|
||||
response = MockedResponse(200, json.dumps({}))
|
||||
with self.assertRaises(OAuth2Error):
|
||||
# No id, raises
|
||||
_check_errors(response)
|
||||
|
||||
def test_profile_invalid_response(self):
|
||||
data = {"code": 403, "type": "Forbidden", "detail": "Account Inactive"}
|
||||
response = MockedResponse(401, json.dumps(data))
|
||||
|
||||
with self.assertRaises(OAuth2Error):
|
||||
# no id, 4xx code, raises
|
||||
_check_errors(response)
|
||||
|
||||
def test_error_response(self):
|
||||
body = json.dumps({"error": "invalid_token"})
|
||||
response = MockedResponse(400, body)
|
||||
|
||||
with self.assertRaises(OAuth2Error):
|
||||
# no id, 4xx code, raises
|
||||
_check_errors(response)
|
||||
|
||||
def test_service_not_found(self):
|
||||
response = MockedResponse(596, "<h1>596 Service Not Found</h1>")
|
||||
with self.assertRaises(OAuth2Error):
|
||||
# bad json, 5xx code, raises
|
||||
_check_errors(response)
|
||||
|
||||
def test_invalid_response(self):
|
||||
response = MockedResponse(200, "invalid json data")
|
||||
with self.assertRaises(OAuth2Error):
|
||||
# bad json, raises
|
||||
_check_errors(response)
|
||||
|
||||
def test_extra_data(self):
|
||||
self.login(self.get_mocked_response())
|
||||
account = SocialAccount.objects.get(uid=str(self._uid))
|
||||
self.assertEqual(account.extra_data["battletag"], self._battletag)
|
||||
self.assertEqual(account.extra_data["id"], self._uid)
|
||||
self.assertEqual(account.extra_data["region"], "us")
|
||||
@@ -0,0 +1,6 @@
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
from .provider import BattleNetProvider
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(BattleNetProvider)
|
||||
@@ -0,0 +1,4 @@
|
||||
from django.core.validators import RegexValidator
|
||||
|
||||
|
||||
BattletagUsernameValidator = RegexValidator(r"^[\w.]+#\d+$")
|
||||
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
OAuth2 Adapter for Battle.net
|
||||
|
||||
Resources:
|
||||
|
||||
* Battle.net OAuth2 documentation:
|
||||
https://dev.battle.net/docs/read/oauth
|
||||
* Battle.net API documentation:
|
||||
https://dev.battle.net/io-docs
|
||||
* Original announcement:
|
||||
https://us.battle.net/en/forum/topic/13979297799
|
||||
* The Battle.net API forum:
|
||||
https://us.battle.net/en/forum/15051532/
|
||||
"""
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
|
||||
|
||||
class Region:
|
||||
APAC = "apac"
|
||||
CN = "cn"
|
||||
EU = "eu"
|
||||
KR = "kr"
|
||||
SEA = "sea"
|
||||
TW = "tw"
|
||||
US = "us"
|
||||
|
||||
|
||||
def _check_errors(response):
|
||||
try:
|
||||
data = response.json()
|
||||
except ValueError: # JSONDecodeError on py3
|
||||
raise OAuth2Error("Invalid JSON from Battle.net API: %r" % (response.text))
|
||||
|
||||
if response.status_code >= 400 or "error" in data:
|
||||
# For errors, we expect the following format:
|
||||
# {"error": "error_name", "error_description": "Oops!"}
|
||||
# For example, if the token is not valid, we will get:
|
||||
# {
|
||||
# "error": "invalid_token",
|
||||
# "error_description": "Invalid access token: abcdef123456"
|
||||
# }
|
||||
# For the profile API, this may also look like the following:
|
||||
# {"code": 403, "type": "Forbidden", "detail": "Account Inactive"}
|
||||
error = data.get("error", "") or data.get("type", "")
|
||||
desc = data.get("error_description", "") or data.get("detail", "")
|
||||
|
||||
raise OAuth2Error("Battle.net error: %s (%s)" % (error, desc))
|
||||
|
||||
# The expected output from the API follows this format:
|
||||
# {"id": 12345, "battletag": "Example#12345"}
|
||||
# The battletag is optional.
|
||||
if "id" not in data:
|
||||
# If the id is not present, the output is not usable (no UID)
|
||||
raise OAuth2Error("Invalid data from Battle.net API: %r" % (data))
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class BattleNetOAuth2Adapter(OAuth2Adapter):
|
||||
"""
|
||||
OAuth2 adapter for Battle.net
|
||||
https://dev.battle.net/docs/read/oauth
|
||||
|
||||
Region is set to us by default, but can be overridden with the
|
||||
`region` GET parameter when performing a login.
|
||||
Can be any of eu, us, kr, sea, tw or cn
|
||||
"""
|
||||
|
||||
provider_id = "battlenet"
|
||||
|
||||
valid_regions = (
|
||||
Region.APAC,
|
||||
Region.CN,
|
||||
Region.EU,
|
||||
Region.KR,
|
||||
Region.SEA,
|
||||
Region.TW,
|
||||
Region.US,
|
||||
)
|
||||
|
||||
@property
|
||||
def battlenet_region(self):
|
||||
# Check by URI query parameter first.
|
||||
region = self.request.GET.get("region", "").lower()
|
||||
if region == Region.SEA:
|
||||
# South-East Asia uses the same region as US everywhere
|
||||
return Region.US
|
||||
if region in self.valid_regions:
|
||||
return region
|
||||
|
||||
# Second, check the provider settings.
|
||||
region = (
|
||||
getattr(settings, "SOCIALACCOUNT_PROVIDERS", {})
|
||||
.get("battlenet", {})
|
||||
.get("REGION", "us")
|
||||
)
|
||||
|
||||
if region in self.valid_regions:
|
||||
return region
|
||||
|
||||
return Region.US
|
||||
|
||||
@property
|
||||
def battlenet_base_url(self):
|
||||
region = self.battlenet_region
|
||||
if region == Region.CN:
|
||||
return "https://oauth.battlenet.com.cn"
|
||||
return "https://oauth.battle.net"
|
||||
|
||||
@property
|
||||
def access_token_url(self):
|
||||
return self.battlenet_base_url + "/token"
|
||||
|
||||
@property
|
||||
def authorize_url(self):
|
||||
return self.battlenet_base_url + "/authorize"
|
||||
|
||||
@property
|
||||
def profile_url(self):
|
||||
return self.battlenet_base_url + "/userinfo"
|
||||
|
||||
def complete_login(self, request, app, token, **kwargs):
|
||||
response = (
|
||||
get_adapter()
|
||||
.get_requests_session()
|
||||
.get(
|
||||
self.profile_url,
|
||||
headers={"authorization": "Bearer %s" % (token.token)},
|
||||
)
|
||||
)
|
||||
data = _check_errors(response)
|
||||
|
||||
# Add the region to the data so that we can have it in `extra_data`.
|
||||
data["region"] = self.battlenet_region
|
||||
|
||||
return self.get_provider().sociallogin_from_response(request, data)
|
||||
|
||||
def get_callback_url(self, request, app):
|
||||
r = super(BattleNetOAuth2Adapter, self).get_callback_url(request, app)
|
||||
region = request.GET.get("region", "").lower()
|
||||
# Pass the region down to the callback URL if we specified it
|
||||
if region and region in self.valid_regions:
|
||||
r += "?region=%s" % (region)
|
||||
return r
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(BattleNetOAuth2Adapter)
|
||||
oauth2_callback = OAuth2CallbackView.adapter_view(BattleNetOAuth2Adapter)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user