okay fine

This commit is contained in:
pacnpal
2024-11-03 17:47:26 +00:00
parent 01c6004a79
commit 27eb239e97
10020 changed files with 1935769 additions and 2364 deletions

View File

@@ -0,0 +1,378 @@
import functools
import warnings
from django.core.exceptions import (
ImproperlyConfigured,
MultipleObjectsReturned,
)
from django.db.models import Q
from django.urls import reverse
from django.utils.crypto import get_random_string
from django.utils.translation import gettext_lazy as _
from allauth.account.adapter import get_adapter as get_account_adapter
from allauth.account.utils import user_email, user_field, user_username
from allauth.core.internal.adapter import BaseAdapter
from allauth.utils import (
deserialize_instance,
import_attribute,
serialize_instance,
valid_email_or_none,
)
from . import app_settings
class DefaultSocialAccountAdapter(BaseAdapter):
"""The adapter class allows you to override various functionality of the
``allauth.socialaccount`` app. To do so, point ``settings.SOCIALACCOUNT_ADAPTER`` to
your own class that derives from ``DefaultSocialAccountAdapter`` and override the
behavior by altering the implementation of the methods according to your own
needs.
"""
error_messages = {
"email_taken": _(
"An account already exists with this email address."
" Please sign in to that account first, then connect"
" your %s account."
),
"invalid_token": _("Invalid token."),
"no_password": _("Your account has no password set up."),
"no_verified_email": _("Your account has no verified email address."),
"disconnect_last": _(
"You cannot disconnect your last remaining third-party account."
),
"connected_other": _(
"The third-party account is already connected to a different account."
),
}
def pre_social_login(self, request, sociallogin):
"""
Invoked just after a user successfully authenticates via a
social provider, but before the login is actually processed
(and before the pre_social_login signal is emitted).
You can use this hook to intervene, e.g. abort the login by
raising an ImmediateHttpResponse
Why both an adapter hook and the signal? Intervening in
e.g. the flow from within a signal handler is bad -- multiple
handlers may be active and are executed in undetermined order.
"""
pass
def on_authentication_error(
self,
request,
provider,
error=None,
exception=None,
extra_context=None,
):
"""
Invoked when there is an error in the authentication cycle. In this
case, pre_social_login will not be reached.
You can use this hook to intervene, e.g. redirect to an
educational flow by raising an ImmediateHttpResponse.
"""
if hasattr(self, "authentication_error"):
warnings.warn(
"adapter.authentication_error() is deprecated, use adapter.on_authentication_error()"
)
self.authentication_error(
request,
provider.id,
error=error,
exception=exception,
extra_context=extra_context,
)
def new_user(self, request, sociallogin):
"""
Instantiates a new User instance.
"""
return get_account_adapter().new_user(request)
def save_user(self, request, sociallogin, form=None):
"""
Saves a newly signed up social login. In case of auto-signup,
the signup form is not available.
"""
u = sociallogin.user
u.set_unusable_password()
if form:
get_account_adapter().save_user(request, u, form)
else:
get_account_adapter().populate_username(request, u)
sociallogin.save(request)
return u
def populate_user(self, request, sociallogin, data):
"""
Hook that can be used to further populate the user instance.
For convenience, we populate several common fields.
Note that the user instance being populated represents a
suggested User instance that represents the social user that is
in the process of being logged in.
The User instance need not be completely valid and conflict
free. For example, verifying whether or not the username
already exists, is not a responsibility.
"""
username = data.get("username")
first_name = data.get("first_name")
last_name = data.get("last_name")
email = data.get("email")
name = data.get("name")
user = sociallogin.user
user_username(user, username or "")
user_email(user, valid_email_or_none(email) or "")
name_parts = (name or "").partition(" ")
user_field(user, "first_name", first_name or name_parts[0])
user_field(user, "last_name", last_name or name_parts[2])
return user
def get_connect_redirect_url(self, request, socialaccount):
"""
Returns the default URL to redirect to after successfully
connecting a social account.
"""
url = reverse("socialaccount_connections")
return url
def validate_disconnect(self, account, accounts) -> None:
"""
Validate whether or not the socialaccount account can be
safely disconnected.
"""
pass
def is_auto_signup_allowed(self, request, sociallogin):
# If email is specified, check for duplicate and if so, no auto signup.
auto_signup = app_settings.AUTO_SIGNUP
return auto_signup
def is_open_for_signup(self, request, sociallogin):
"""
Checks whether or not the site is open for signups.
Next to simply returning True/False you can also intervene the
regular flow by raising an ImmediateHttpResponse
"""
return get_account_adapter(request).is_open_for_signup(request)
def get_signup_form_initial_data(self, sociallogin):
user = sociallogin.user
initial = {
"email": user_email(user) or "",
"username": user_username(user) or "",
"first_name": user_field(user, "first_name") or "",
"last_name": user_field(user, "last_name") or "",
}
return initial
def deserialize_instance(self, model, data):
return deserialize_instance(model, data)
def serialize_instance(self, instance):
return serialize_instance(instance)
def list_providers(self, request):
from allauth.socialaccount.providers import registry
ret = []
provider_classes = registry.get_class_list()
apps = self.list_apps(request)
apps_map = {}
for app in apps:
apps_map.setdefault(app.provider, []).append(app)
for provider_class in provider_classes:
provider_apps = apps_map.get(provider_class.id, [])
if not provider_apps:
if provider_class.uses_apps:
continue
provider_apps = [None]
for app in provider_apps:
provider = provider_class(request=request, app=app)
ret.append(provider)
return ret
def get_provider(self, request, provider, client_id=None):
"""Looks up a `provider`, supporting subproviders by looking up by
`provider_id`.
"""
from allauth.socialaccount.providers import registry
provider_class = registry.get_class(provider)
if provider_class is None or provider_class.uses_apps:
app = self.get_app(request, provider=provider, client_id=client_id)
if not provider_class:
# In this case, the `provider` argument passed was a
# `provider_id`.
provider_class = registry.get_class(app.provider)
if not provider_class:
raise ImproperlyConfigured(f"unknown provider: {app.provider}")
return provider_class(request, app=app)
elif provider_class:
assert not provider_class.uses_apps
return provider_class(request, app=None)
else:
raise ImproperlyConfigured(f"unknown provider: {app.provider}")
def list_apps(self, request, provider=None, client_id=None):
"""SocialApp's can be setup in the database, or, via
`settings.SOCIALACCOUNT_PROVIDERS`. This methods returns a uniform list
of all known apps matching the specified criteria, and blends both
(db/settings) sources of data.
"""
# NOTE: Avoid loading models at top due to registry boot...
from allauth.socialaccount.models import SocialApp
# Map provider to the list of apps.
provider_to_apps = {}
# First, populate it with the DB backed apps.
if request:
db_apps = SocialApp.objects.on_site(request)
else:
db_apps = SocialApp.objects.all()
if provider:
db_apps = db_apps.filter(Q(provider=provider) | Q(provider_id=provider))
if client_id:
db_apps = db_apps.filter(client_id=client_id)
for app in db_apps:
apps = provider_to_apps.setdefault(app.provider, [])
apps.append(app)
# Then, extend it with the settings backed apps.
for p, pcfg in app_settings.PROVIDERS.items():
app_configs = pcfg.get("APPS")
if app_configs is None:
app_config = pcfg.get("APP")
if app_config is None:
continue
app_configs = [app_config]
apps = provider_to_apps.setdefault(p, [])
for config in app_configs:
app = SocialApp(provider=p)
for field in [
"name",
"provider_id",
"client_id",
"secret",
"key",
"settings",
]:
if field in config:
setattr(app, field, config[field])
if "certificate_key" in config:
warnings.warn("'certificate_key' should be moved into app.settings")
app.settings["certificate_key"] = config["certificate_key"]
if client_id and app.client_id != client_id:
continue
if (
provider
and app.provider_id != provider
and app.provider != provider
):
continue
apps.append(app)
# Flatten the list of apps.
apps = []
for provider_apps in provider_to_apps.values():
apps.extend(provider_apps)
return apps
def get_app(self, request, provider, client_id=None):
from allauth.socialaccount.models import SocialApp
apps = self.list_apps(request, provider=provider, client_id=client_id)
if len(apps) > 1:
visible_apps = [app for app in apps if not app.settings.get("hidden")]
if len(visible_apps) != 1:
raise MultipleObjectsReturned
apps = visible_apps
elif len(apps) == 0:
raise SocialApp.DoesNotExist()
return apps[0]
def send_notification_mail(self, *args, **kwargs):
return get_account_adapter().send_notification_mail(*args, **kwargs)
def get_requests_session(self):
import requests
session = requests.Session()
session.request = functools.partial(
session.request, timeout=app_settings.REQUESTS_TIMEOUT
)
return session
def is_email_verified(self, provider, email):
"""
Returns ``True`` iff the given email encountered during a social
login for the given provider is to be assumed verified.
This can be configured with a ``"verified_email"`` key in the provider
app settings, or a ``"VERIFIED_EMAIL"`` in the global provider settings
(``SOCIALACCOUNT_PROVIDERS``). Both can be set to ``False`` or
``True``, or, a list of domains to match email addresses against.
"""
verified_email = None
if provider.app:
verified_email = provider.app.settings.get("verified_email")
if verified_email is None:
settings = provider.get_settings()
verified_email = settings.get("VERIFIED_EMAIL", False)
if isinstance(verified_email, bool):
pass
elif isinstance(verified_email, list):
email_domain = email.partition("@")[2].lower()
verified_domains = [d.lower() for d in verified_email]
verified_email = email_domain in verified_domains
else:
raise ImproperlyConfigured("verified_email wrongly configured")
return verified_email
def can_authenticate_by_email(self, login, email):
"""
Returns ``True`` iff authentication by email is active for this login/email.
This can be configured with a ``"email_authentication"`` key in the provider
app settings, or a ``"VERIFIED_EMAIL"`` in the global provider settings
(``SOCIALACCOUNT_PROVIDERS``).
"""
ret = None
provider = login.account.get_provider()
if provider.app:
ret = provider.app.settings.get("email_authentication")
if ret is None:
ret = app_settings.EMAIL_AUTHENTICATION or provider.get_settings().get(
"EMAIL_AUTHENTICATION", False
)
return ret
def generate_state_param(self, state: dict) -> str:
"""
To preserve certain state before the handshake with the provider
takes place, and be able to verify/use that state later on, a `state`
parameter is typically passed to the provider. By default, a random
string sufficies as the state parameter value is actually just a
reference/pointer to the actual state. You can use this adapter method
to alter the generation of the `state` parameter.
"""
from allauth.socialaccount.internal.statekit import STATE_ID_LENGTH
return get_random_string(STATE_ID_LENGTH)
def get_adapter(request=None):
return import_attribute(app_settings.ADAPTER)(request)

View File

@@ -0,0 +1,69 @@
from typing import List
from django import forms
from django.contrib import admin
from allauth import app_settings
from allauth.account.adapter import get_adapter
from allauth.socialaccount import providers
from allauth.socialaccount.models import SocialAccount, SocialApp, SocialToken
class SocialAppForm(forms.ModelForm):
class Meta:
model = SocialApp
exclude: List[str] = []
widgets = {
"client_id": forms.TextInput(attrs={"size": "100"}),
"key": forms.TextInput(attrs={"size": "100"}),
"secret": forms.TextInput(attrs={"size": "100"}),
}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields["provider"] = forms.ChoiceField(
choices=providers.registry.as_choices()
)
class SocialAppAdmin(admin.ModelAdmin):
form = SocialAppForm
list_display = (
"name",
"provider",
)
filter_horizontal = ("sites",) if app_settings.SITES_ENABLED else ()
class SocialAccountAdmin(admin.ModelAdmin):
search_fields = []
raw_id_fields = ("user",)
list_display = ("user", "uid", "provider")
list_filter = ("provider",)
def get_search_fields(self, request):
base_fields = get_adapter().get_user_search_fields()
return list(map(lambda a: "user__" + a, base_fields))
class SocialTokenAdmin(admin.ModelAdmin):
raw_id_fields = (
"app",
"account",
)
list_display = ("app", "account", "truncated_token", "expires_at")
list_filter = ("app", "app__provider", "expires_at")
def truncated_token(self, token):
max_chars = 40
ret = token.token
if len(ret) > max_chars:
ret = ret[0:max_chars] + "...(truncated)"
return ret
truncated_token.short_description = "Token" # type: ignore[attr-defined]
admin.site.register(SocialApp, SocialAppAdmin)
admin.site.register(SocialToken, SocialTokenAdmin)
admin.site.register(SocialAccount, SocialAccountAdmin)

View File

@@ -0,0 +1,155 @@
class AppSettings:
def __init__(self, prefix):
self.prefix = prefix
def _setting(self, name, dflt):
from allauth.utils import get_setting
return get_setting(self.prefix + name, dflt)
@property
def QUERY_EMAIL(self):
"""
Request email address from 3rd party account provider?
E.g. using OpenID AX
"""
from allauth.account import app_settings as account_settings
return self._setting("QUERY_EMAIL", account_settings.EMAIL_REQUIRED)
@property
def AUTO_SIGNUP(self):
"""
Attempt to bypass the signup form by using fields (e.g. username,
email) retrieved from the social account provider. If a conflict
arises due to a duplicate email signup form will still kick in.
"""
return self._setting("AUTO_SIGNUP", True)
@property
def PROVIDERS(self):
"""
Provider specific settings
"""
ret = self._setting("PROVIDERS", {})
oidc = ret.get("openid_connect")
if oidc:
ret["openid_connect"] = self._migrate_oidc(oidc)
return ret
def _migrate_oidc(self, oidc):
servers = oidc.get("SERVERS")
if servers is None:
return oidc
ret = {}
apps = []
for server in servers:
app = dict(**server["APP"])
app_settings = {}
if "token_auth_method" in server:
app_settings["token_auth_method"] = server["token_auth_method"]
app_settings["server_url"] = server["server_url"]
app.update(
{
"name": server.get("name", ""),
"provider_id": server["id"],
"settings": app_settings,
}
)
assert app["provider_id"]
apps.append(app)
ret["APPS"] = apps
return ret
@property
def EMAIL_REQUIRED(self):
"""
The user is required to hand over an email address when signing up
"""
from allauth.account import app_settings as account_settings
return self._setting("EMAIL_REQUIRED", account_settings.EMAIL_REQUIRED)
@property
def EMAIL_VERIFICATION(self):
"""
See email verification method. When `None`, the default
`allauth.account` logic kicks in.
"""
return self._setting("EMAIL_VERIFICATION", None)
@property
def EMAIL_AUTHENTICATION(self):
"""Consider a scenario where a social login occurs, and the social
account comes with a verified email address (verified by the account
provider), but that email address is already taken by a local user
account. Additionally, assume that the local user account does not have
any social account connected. Now, if the provider can be fully trusted,
you can argue that we should treat this scenario as a login to the
existing local user account even if the local account does not already
have the social account connected, because -- according to the provider
-- the user logging in has ownership of the email address. This is how
this scenario is handled when `EMAIL_AUTHENTICATION` is set to
`True`. As this implies that an untrustworthy provider can login to any
local account by fabricating social account data, this setting defaults
to `False`. Only set it to `True` if you are using providers that can be
fully trusted.
"""
return self._setting("EMAIL_AUTHENTICATION", False)
@property
def EMAIL_AUTHENTICATION_AUTO_CONNECT(self):
"""In case email authentication is applied, this setting controls
whether or not the social account is automatically connected to the
local account. In case of ``False`` (the default) the local account
remains unchanged during the login. In case of ``True``, the social
account for which the email matched, is automatically added to the list
of social accounts connected to the local account. As a result, even if
the user were to change the email address afterwards, social login
would still be possible when using ``True``, but not in case of
``False``.
"""
return self._setting("EMAIL_AUTHENTICATION_AUTO_CONNECT", False)
@property
def ADAPTER(self):
return self._setting(
"ADAPTER",
"allauth.socialaccount.adapter.DefaultSocialAccountAdapter",
)
@property
def FORMS(self):
return self._setting("FORMS", {})
@property
def LOGIN_ON_GET(self):
return self._setting("LOGIN_ON_GET", False)
@property
def STORE_TOKENS(self):
return self._setting("STORE_TOKENS", False)
@property
def UID_MAX_LENGTH(self):
return 191
@property
def SOCIALACCOUNT_STR(self):
return self._setting("SOCIALACCOUNT_STR", None)
@property
def REQUESTS_TIMEOUT(self):
return self._setting("REQUESTS_TIMEOUT", 5)
@property
def OPENID_CONNECT_URL_PREFIX(self):
return self._setting("OPENID_CONNECT_URL_PREFIX", "oidc")
_app_settings = AppSettings("SOCIALACCOUNT_")
def __getattr__(name):
# See https://peps.python.org/pep-0562/
return getattr(_app_settings, name)

View File

@@ -0,0 +1,15 @@
from django.apps import AppConfig
from django.utils.translation import gettext_lazy as _
from allauth import app_settings
class SocialAccountConfig(AppConfig):
name = "allauth.socialaccount"
verbose_name = _("Social Accounts")
default_auto_field = app_settings.DEFAULT_AUTO_FIELD or "django.db.models.AutoField"
def ready(self):
from allauth.socialaccount.providers import registry
registry.load()

View File

@@ -0,0 +1,66 @@
from contextlib import contextmanager
from unittest.mock import patch
import pytest
from allauth.account.models import EmailAddress
from allauth.socialaccount.models import (
SocialAccount,
SocialLogin,
SocialToken,
)
@pytest.fixture
def sociallogin_factory(user_factory):
def factory(
email=None,
username=None,
with_email=True,
provider="unittest-server",
uid="123",
email_verified=True,
with_token=False,
):
user = user_factory(
username=username, email=email, commit=False, with_email=with_email
)
account = SocialAccount(provider=provider, uid=uid)
sociallogin = SocialLogin(user=user, account=account)
if with_email:
sociallogin.email_addresses = [
EmailAddress(email=user.email, verified=email_verified, primary=True)
]
if with_token:
sociallogin.token = SocialToken(token="123", token_secret="456")
return sociallogin
return factory
@pytest.fixture
def jwt_decode_bypass():
@contextmanager
def f(jwt_data):
with patch("allauth.socialaccount.internal.jwtkit.verify_and_decode") as m:
data = {
"iss": "https://accounts.google.com",
"aud": "client_id",
"sub": "123sub",
"hd": "example.com",
"email": "raymond@example.com",
"email_verified": True,
"at_hash": "HK6E_P6Dh8Y93mRNtsDB1Q",
"name": "Raymond Penners",
"picture": "https://lh5.googleusercontent.com/photo.jpg",
"given_name": "Raymond",
"family_name": "Penners",
"locale": "en",
"iat": 123,
"exp": 456,
}
data.update(jwt_data)
m.return_value = data
yield
return f

View File

@@ -0,0 +1,62 @@
from django import forms
from allauth.account.forms import BaseSignupForm
from allauth.socialaccount.internal import flows
from . import app_settings
from .adapter import get_adapter
from .models import SocialAccount
class SignupForm(BaseSignupForm):
def __init__(self, *args, **kwargs):
self.sociallogin = kwargs.pop("sociallogin")
initial = get_adapter().get_signup_form_initial_data(self.sociallogin)
kwargs.update(
{
"initial": initial,
"email_required": kwargs.get(
"email_required", app_settings.EMAIL_REQUIRED
),
}
)
super(SignupForm, self).__init__(*args, **kwargs)
def save(self, request):
adapter = get_adapter()
user = adapter.save_user(request, self.sociallogin, form=self)
self.custom_signup(request, user)
return user
def validate_unique_email(self, value):
try:
return super(SignupForm, self).validate_unique_email(value)
except forms.ValidationError:
raise get_adapter().validation_error(
"email_taken", self.sociallogin.account.get_provider().name
)
class DisconnectForm(forms.Form):
account = forms.ModelChoiceField(
queryset=SocialAccount.objects.none(),
widget=forms.RadioSelect,
required=True,
)
def __init__(self, *args, **kwargs):
self.request = kwargs.pop("request")
self.accounts = SocialAccount.objects.filter(user=self.request.user)
super(DisconnectForm, self).__init__(*args, **kwargs)
self.fields["account"].queryset = self.accounts
def clean(self):
cleaned_data = super(DisconnectForm, self).clean()
account = cleaned_data.get("account")
if account:
flows.connect.validate_disconnect(self.request, account)
return cleaned_data
def save(self):
account = self.cleaned_data["account"]
flows.connect.disconnect(self.request, account)

View File

@@ -0,0 +1,74 @@
from django.http import HttpResponseRedirect
from django.shortcuts import render
from django.urls import reverse
from allauth import app_settings as allauth_settings
from allauth.account import app_settings as account_settings
from allauth.account.utils import user_display
from allauth.core.exceptions import ImmediateHttpResponse
from allauth.socialaccount import app_settings
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.internal import flows
from allauth.socialaccount.providers.base import AuthError
def render_authentication_error(
request,
provider,
error=AuthError.UNKNOWN,
exception=None,
extra_context=None,
):
try:
if allauth_settings.HEADLESS_ENABLED:
from allauth.headless.socialaccount import internal
internal.on_authentication_error(
request,
provider=provider,
error=error,
exception=exception,
extra_context=extra_context,
)
if extra_context is None:
extra_context = {}
get_adapter().on_authentication_error(
request,
provider,
error=error,
exception=exception,
extra_context=extra_context,
)
except ImmediateHttpResponse as e:
return e.response
if error == AuthError.CANCELLED:
return HttpResponseRedirect(reverse("socialaccount_login_cancelled"))
context = {
"auth_error": {
"provider": provider,
"code": error,
"exception": exception,
}
}
context.update(extra_context)
return render(
request,
"socialaccount/authentication_error." + account_settings.TEMPLATE_EXTENSION,
context,
)
def complete_social_login(request, sociallogin):
if sociallogin.is_headless:
from allauth.headless.socialaccount import internal
return internal.complete_login(request, sociallogin)
return flows.login.complete_login(request, sociallogin)
def socialaccount_user_display(socialaccount):
func = app_settings.SOCIALACCOUNT_STR
if not func:
return user_display(socialaccount.user)
return func(socialaccount)

View File

@@ -0,0 +1,9 @@
from allauth.socialaccount.internal.flows import (
connect,
email_authentication,
login,
signup,
)
__all__ = ["connect", "login", "signup", "email_authentication"]

View File

@@ -0,0 +1,130 @@
from django.contrib import messages
from django.core.exceptions import PermissionDenied, ValidationError
from django.http import HttpResponseRedirect
from allauth import app_settings as allauth_settings
from allauth.account import app_settings as account_settings
from allauth.account.adapter import get_adapter as get_account_adapter
from allauth.account.internal import flows
from allauth.account.models import EmailAddress
from allauth.core.exceptions import ReauthenticationRequired
from allauth.socialaccount import signals
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.models import SocialAccount, SocialLogin
def validate_disconnect(request, account):
"""
Validate whether or not the socialaccount account can be
safely disconnected.
"""
accounts = SocialAccount.objects.filter(user_id=account.user_id)
is_last = not accounts.exclude(pk=account.pk).exists()
adapter = get_adapter()
if is_last:
if allauth_settings.SOCIALACCOUNT_ONLY:
raise adapter.validation_error("disconnect_last")
# No usable password would render the local account unusable
if not account.user.has_usable_password():
raise adapter.validation_error("no_password")
# No email address, no password reset
if (
account_settings.EMAIL_VERIFICATION
== account_settings.EmailVerificationMethod.MANDATORY
):
if not EmailAddress.objects.filter(
user=account.user, verified=True
).exists():
raise adapter.validation_error("no_verified_email")
adapter.validate_disconnect(account, accounts)
def disconnect(request, account):
if account_settings.REAUTHENTICATION_REQUIRED:
flows.reauthentication.raise_if_reauthentication_required(request)
get_account_adapter().add_message(
request,
messages.INFO,
"socialaccount/messages/account_disconnected.txt",
)
provider = account.get_provider()
account.delete()
signals.social_account_removed.send(
sender=SocialAccount, request=request, socialaccount=account
)
get_adapter().send_notification_mail(
"socialaccount/email/account_disconnected",
request.user,
context={
"account": account,
"provider": provider,
},
)
def resume_connect(request, serialized_state):
sociallogin = SocialLogin.deserialize(serialized_state)
return connect(request, sociallogin)
def connect(request, sociallogin):
try:
ok, action, message = do_connect(request, sociallogin)
except PermissionDenied:
# This should not happen. Simply redirect to the connections
# view (which has a login required)
connect_redirect_url = get_adapter().get_connect_redirect_url(
request, sociallogin.account
)
return HttpResponseRedirect(connect_redirect_url)
except ReauthenticationRequired:
return flows.reauthentication.stash_and_reauthenticate(
request,
sociallogin.serialize(),
"allauth.socialaccount.internal.flows.connect.resume_connect",
)
except ValidationError:
ok, action, message = (
False,
None,
"socialaccount/messages/account_connected_other.txt",
)
level = messages.INFO if ok else messages.ERROR
default_next = get_adapter().get_connect_redirect_url(request, sociallogin.account)
next_url = sociallogin.get_redirect_url(request) or default_next
get_account_adapter(request).add_message(
request,
level,
message,
message_context={"sociallogin": sociallogin, "action": action},
)
return HttpResponseRedirect(next_url)
def do_connect(request, sociallogin):
if request.user.is_anonymous:
raise PermissionDenied()
if account_settings.REAUTHENTICATION_REQUIRED:
flows.reauthentication.raise_if_reauthentication_required(request)
message = "socialaccount/messages/account_connected.txt"
action = None
ok = True
if sociallogin.is_existing:
if sociallogin.user != request.user:
# Social account of other user. For now, this scenario
# is not supported. Issue is that one cannot simply
# remove the social account from the other user, as
# that may render the account unusable.
raise get_adapter().validation_error("connected_other")
elif not sociallogin.account._state.adding:
action = "updated"
message = "socialaccount/messages/account_connected_updated.txt"
else:
action = "added"
sociallogin.connect(request, request.user)
else:
# New account, let's connect
action = "added"
sociallogin.connect(request, request.user)
return ok, action, message

View File

@@ -0,0 +1,34 @@
from allauth import app_settings as allauth_settings
from allauth.account.models import EmailAddress
def wipe_password(request, user, email: str):
"""
Consider a scenario where an attacker signs up for an account using the
email address of a victim. Obviously, the email address cannot be
verified, yet the attacker -- knowing the password -- can wait until the
victim appears. When the victim signs in using email authentication, it
is not obvious that the victim is signing into an account that was not
created by the victim. As a result, both the attacker and the victim now
have access to the account. To prevent this, we wipe the password of the
account in case the email address was not verified, effectively locking
out the attacker.
"""
try:
address = EmailAddress.objects.get_for_user(user, email)
except EmailAddress.DoesNotExist:
address = None
if address and address.verified:
# Verified email address, no reason to worry.
return
if user.has_usable_password():
user.set_unusable_password()
user.save(update_fields=["password"])
# Also wipe any other sessions (upstream integrators may hook up to the
# ending of the sessions to trigger e.g. backchannel logout.
if allauth_settings.USERSESSIONS_ENABLED:
from allauth.usersessions.internal.flows.sessions import (
end_other_sessions,
)
end_other_sessions(request, user)

View File

@@ -0,0 +1,97 @@
from django.http import HttpResponseRedirect
from django.shortcuts import render
from allauth.account import app_settings as account_settings
from allauth.account.adapter import get_adapter as get_account_adapter
from allauth.account.utils import perform_login
from allauth.core.exceptions import (
ImmediateHttpResponse,
SignupClosedException,
)
from allauth.socialaccount import app_settings, signals
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.internal.flows.connect import connect, do_connect
from allauth.socialaccount.internal.flows.signup import (
clear_pending_signup,
process_signup,
)
from allauth.socialaccount.models import SocialLogin
from allauth.socialaccount.providers.base import AuthProcess
def _login(request, sociallogin):
sociallogin._accept_login(request)
record_authentication(request, sociallogin)
return perform_login(
request,
sociallogin.user,
email_verification=app_settings.EMAIL_VERIFICATION,
redirect_url=sociallogin.get_redirect_url(request),
signal_kwargs={"sociallogin": sociallogin},
)
def pre_social_login(request, sociallogin):
clear_pending_signup(request)
assert not sociallogin.is_existing
sociallogin.lookup()
get_adapter().pre_social_login(request, sociallogin)
signals.pre_social_login.send(
sender=SocialLogin, request=request, sociallogin=sociallogin
)
def complete_login(request, sociallogin, raises=False):
try:
pre_social_login(request, sociallogin)
process = sociallogin.state.get("process")
if process == AuthProcess.REDIRECT:
return _redirect(request, sociallogin)
elif process == AuthProcess.CONNECT:
if raises:
do_connect(request, sociallogin)
else:
return connect(request, sociallogin)
else:
return _authenticate(request, sociallogin)
except SignupClosedException:
if raises:
raise
return render(
request,
"account/signup_closed." + account_settings.TEMPLATE_EXTENSION,
)
except ImmediateHttpResponse as e:
if raises:
raise
return e.response
def _redirect(request, sociallogin):
next_url = sociallogin.get_redirect_url(request) or "/"
return HttpResponseRedirect(next_url)
def _authenticate(request, sociallogin):
if request.user.is_authenticated:
get_account_adapter(request).logout(request)
if sociallogin.is_existing:
# Login existing user
ret = _login(request, sociallogin)
else:
# New social user
ret = process_signup(request, sociallogin)
return ret
def record_authentication(request, sociallogin):
from allauth.account.internal.flows.login import record_authentication
record_authentication(
request,
"socialaccount",
**{
"provider": sociallogin.account.provider,
"uid": sociallogin.account.uid,
}
)

View File

@@ -0,0 +1,114 @@
from django.forms import ValidationError
from allauth.account import app_settings as account_settings
from allauth.account.adapter import get_adapter as get_account_adapter
from allauth.account.internal.flows.manage_email import assess_unique_email
from allauth.account.internal.flows.signup import (
complete_signup,
prevent_enumeration,
)
from allauth.account.utils import user_username
from allauth.core.exceptions import SignupClosedException
from allauth.core.internal.httpkit import headed_redirect_response
from allauth.socialaccount import app_settings
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.models import SocialLogin
def get_pending_signup(request):
data = request.session.get("socialaccount_sociallogin")
if data:
return SocialLogin.deserialize(data)
def redirect_to_signup(request, sociallogin):
request.session["socialaccount_sociallogin"] = sociallogin.serialize()
return headed_redirect_response("socialaccount_signup")
def clear_pending_signup(request):
request.session.pop("socialaccount_sociallogin", None)
def signup_by_form(request, sociallogin, form):
clear_pending_signup(request)
user, resp = form.try_save(request)
if not resp:
resp = complete_social_signup(request, sociallogin)
return resp
def process_auto_signup(request, sociallogin):
auto_signup = get_adapter().is_auto_signup_allowed(request, sociallogin)
if not auto_signup:
return False, None
email = None
if sociallogin.email_addresses:
email = sociallogin.email_addresses[0].email
# Let's check if auto_signup is really possible...
if email:
assessment = assess_unique_email(email)
if assessment is True:
# Auto signup is fine.
pass
elif assessment is False:
# Oops, another user already has this address. We cannot simply
# connect this social account to the existing user. Reason is
# that the email address may not be verified, meaning, the user
# may be a hacker that has added your email address to their
# account in the hope that you fall in their trap. We cannot
# check on 'email_address.verified' either, because
# 'email_address' is not guaranteed to be verified.
auto_signup = False
# TODO: We redirect to signup form -- user will see email
# address conflict only after posting whereas we detected it
# here already.
else:
assert assessment is None
# Prevent enumeration is properly turned on, meaning, we cannot
# show the signup form to allow the user to input another email
# address. Instead, we're going to send the user an email that
# the account already exists, and on the outside make it appear
# as if an email verification mail was sent.
resp = prevent_enumeration(request, email)
return False, resp
elif app_settings.EMAIL_REQUIRED:
# Nope, email is required and we don't have it yet...
auto_signup = False
return auto_signup, None
def process_signup(request, sociallogin):
if not get_adapter().is_open_for_signup(request, sociallogin):
raise SignupClosedException()
auto_signup, resp = process_auto_signup(request, sociallogin)
if resp:
return resp
if not auto_signup:
resp = redirect_to_signup(request, sociallogin)
else:
# Ok, auto signup it is, at least the email address is ok.
# We still need to check the username though...
if account_settings.USER_MODEL_USERNAME_FIELD:
username = user_username(sociallogin.user)
try:
get_account_adapter(request).clean_username(username)
except ValidationError:
# This username is no good ...
user_username(sociallogin.user, "")
# TODO: This part contains a lot of duplication of logic
# ("closed" rendering, create user, send email, in active
# etc..)
get_adapter().save_user(request, sociallogin, form=None)
resp = complete_social_signup(request, sociallogin)
return resp
def complete_social_signup(request, sociallogin):
return complete_signup(
request,
user=sociallogin.user,
email_verification=app_settings.EMAIL_VERIFICATION,
redirect_url=sociallogin.get_redirect_url(request),
signal_kwargs={"sociallogin": sociallogin},
)

View File

@@ -0,0 +1,105 @@
import json
import time
from django.core.cache import cache
import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.x509 import load_pem_x509_certificate
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
def lookup_kid_pem_x509_certificate(keys_data, kid):
"""
Looks up the key given keys data of the form:
{"<kid>": "-----BEGIN CERTIFICATE-----\nCERTIFICATE"}
"""
key = keys_data.get(kid)
if key:
public_key = load_pem_x509_certificate(
key.encode("utf8"), default_backend()
).public_key()
return public_key
def lookup_kid_jwk(keys_data, kid):
"""
Looks up the key given keys data of the form:
{
"keys": [
{
"kty": "RSA",
"kid": "W6WcOKB",
"use": "sig",
"alg": "RS256",
"n": "2Zc5d0-zk....",
"e": "AQAB"
}]
}
"""
for d in keys_data["keys"]:
if d["kid"] == kid:
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(d))
return public_key
def fetch_key(credential, keys_url, lookup):
header = jwt.get_unverified_header(credential)
# {'alg': 'RS256', 'kid': '[AWS-SECRET-REMOVED]', 'typ': 'JWT'}
kid = header["kid"]
alg = header["alg"]
response = get_adapter().get_requests_session().get(keys_url)
response.raise_for_status()
keys_data = response.json()
key = lookup(keys_data, kid)
if not key:
raise OAuth2Error(f"Invalid 'kid': '{kid}'")
return alg, key
def verify_jti(data: dict) -> None:
"""
Put the JWT token on a blacklist to prevent replay attacks.
"""
iss = data.get("iss")
exp = data.get("exp")
jti = data.get("jti")
if iss is None or exp is None or jti is None:
return
timeout = exp - time.time()
key = f"jwt:iss={iss},jti={jti}"
if not cache.add(key=key, value=True, timeout=timeout):
raise OAuth2Error("token already used")
def verify_and_decode(
*, credential, keys_url, issuer, audience, lookup_kid, verify_signature=True
):
try:
if verify_signature:
alg, key = fetch_key(credential, keys_url, lookup_kid)
algorithms = [alg]
else:
key = ""
algorithms = None
data = jwt.decode(
credential,
key=key,
options={
"verify_signature": verify_signature,
"verify_iss": True,
"verify_aud": True,
"verify_exp": True,
},
issuer=issuer,
audience=audience,
algorithms=algorithms,
)
verify_jti(data)
return data
except jwt.PyJWTError as e:
raise OAuth2Error("Invalid id_token") from e

View File

@@ -0,0 +1,69 @@
import time
from typing import Any, Dict, Optional, Tuple
from allauth.socialaccount.adapter import get_adapter
STATE_ID_LENGTH = 16
MAX_STATES = 10
STATES_SESSION_KEY = "socialaccount_states"
def get_oldest_state(
states: Dict[str, Tuple[Dict[str, Any], float]], rev: bool = False
) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
oldest_ts = None
oldest_id = None
oldest = None
for state_id, state_ts in states.items():
ts = state_ts[1]
if oldest_ts is None or (
(rev and ts > oldest_ts) or ((not rev) and oldest_ts > ts)
):
oldest_ts = ts
oldest_id = state_id
oldest = state_ts[0]
return oldest_id, oldest
def gc_states(states: Dict[str, Tuple[Dict[str, Any], float]]):
if len(states) > MAX_STATES:
oldest_id, oldest = get_oldest_state(states)
if oldest_id:
del states[oldest_id]
def get_states(request) -> Dict[str, Tuple[Dict[str, Any], float]]:
states = request.session.get(STATES_SESSION_KEY)
if not isinstance(states, dict):
states = {}
return states
def stash_state(request, state: Dict[str, Any], state_id: Optional[str] = None) -> str:
states = get_states(request)
gc_states(states)
if state_id is None:
state_id = get_adapter().generate_state_param(state)
states[state_id] = (state, time.time())
request.session[STATES_SESSION_KEY] = states
return state_id
def unstash_state(request, state_id: str) -> Optional[Dict[str, Any]]:
state: Optional[Dict[str, Any]] = None
states = get_states(request)
state_ts = states.get(state_id)
if state_ts is not None:
state = state_ts[0]
del states[state_id]
request.session[STATES_SESSION_KEY] = states
return state
def unstash_last_state(request) -> Optional[Dict[str, Any]]:
states = get_states(request)
state_id, state = get_oldest_state(states, rev=True)
if state_id:
unstash_state(request, state_id)
return state

View File

@@ -0,0 +1,36 @@
from datetime import timedelta
from django.utils import timezone
from allauth.socialaccount.internal.jwtkit import verify_and_decode
from allauth.socialaccount.providers.apple.client import jwt_encode
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
def test_verify_and_decode(enable_cache):
now = timezone.now()
payload = {
"iss": "https://accounts.google.com",
"azp": "client_id",
"aud": "client_id",
"sub": "108204268033311374519",
"hd": "example.com",
"locale": "en",
"iat": now,
"jti": "[AWS-SECRET-REMOVED]",
"exp": now + timedelta(hours=1),
}
id_token = jwt_encode(payload, "secret")
for attempt in range(2):
try:
verify_and_decode(
credential=id_token,
keys_url="/",
issuer=payload["iss"],
audience=payload["aud"],
lookup_kid=False,
verify_signature=False,
)
assert attempt == 0
except OAuth2Error:
assert attempt == 1

View File

@@ -0,0 +1,48 @@
from allauth.socialaccount.internal import statekit
def test_get_oldest_state():
states = {
"new": [{"id": "new"}, 300],
"mid": [{"id": "mid"}, 200],
"old": [{"id": "old"}, 100],
}
state_id, state = statekit.get_oldest_state(states)
assert state_id == "old"
assert state["id"] == "old"
def test_get_oldest_state_empty():
state_id, state = statekit.get_oldest_state({})
assert state_id is None
assert state is None
def test_gc_states():
states = {}
for i in range(statekit.MAX_STATES + 1):
states[f"state-{i}"] = [{"i": i}, 1000 + i]
assert len(states) == statekit.MAX_STATES + 1
statekit.gc_states(states)
assert len(states) == statekit.MAX_STATES
assert "state-0" not in states
def test_stashing(rf):
request = rf.get("/")
request.session = {}
state_id = statekit.stash_state(request, {"foo": "bar"})
state2_id = statekit.stash_state(request, {"foo2": "bar2"})
state3_id = statekit.stash_state(request, {"foo3": "bar3"})
state = statekit.unstash_last_state(request)
assert state == {"foo3": "bar3"}
state = statekit.unstash_state(request, state3_id)
assert state is None
state = statekit.unstash_state(request, state2_id)
assert state == {"foo2": "bar2"}
state = statekit.unstash_state(request, state2_id)
assert state is None
state = statekit.unstash_state(request, state_id)
assert state == {"foo": "bar"}
state = statekit.unstash_state(request, state_id)
assert state is None

View File

@@ -0,0 +1,192 @@
from django.conf import settings
from django.db import migrations, models
from allauth import app_settings
class Migration(migrations.Migration):
dependencies = (
[
("sites", "0001_initial"),
]
if app_settings.SITES_ENABLED
else []
) + [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name="SocialAccount",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
(
"provider",
models.CharField(
max_length=30,
verbose_name="provider",
),
),
(
"uid",
models.CharField(
max_length=getattr(
settings, "SOCIALACCOUNT_UID_MAX_LENGTH", 191
),
verbose_name="uid",
),
),
(
"last_login",
models.DateTimeField(auto_now=True, verbose_name="last login"),
),
(
"date_joined",
models.DateTimeField(auto_now_add=True, verbose_name="date joined"),
),
(
"extra_data",
models.TextField(default="{}", verbose_name="extra data"),
),
(
"user",
models.ForeignKey(
to=settings.AUTH_USER_MODEL, on_delete=models.CASCADE
),
),
],
options={
"verbose_name": "social account",
"verbose_name_plural": "social accounts",
},
bases=(models.Model,),
),
migrations.CreateModel(
name="SocialApp",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
(
"provider",
models.CharField(
max_length=30,
verbose_name="provider",
),
),
("name", models.CharField(max_length=40, verbose_name="name")),
(
"client_id",
models.CharField(
help_text="App ID, or consumer key",
max_length=100,
verbose_name="client id",
),
),
(
"secret",
models.CharField(
help_text="API secret, client secret, or consumer secret",
max_length=100,
verbose_name="secret key",
),
),
(
"key",
models.CharField(
help_text="Key",
max_length=100,
verbose_name="key",
blank=True,
),
),
]
+ (
[
("sites", models.ManyToManyField(to="sites.Site", blank=True)),
]
if app_settings.SITES_ENABLED
else []
),
options={
"verbose_name": "social application",
"verbose_name_plural": "social applications",
},
bases=(models.Model,),
),
migrations.CreateModel(
name="SocialToken",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
(
"token",
models.TextField(
help_text='"oauth_token" (OAuth1) or access token (OAuth2)',
verbose_name="token",
),
),
(
"token_secret",
models.TextField(
help_text='"oauth_token_secret" (OAuth1) or refresh token (OAuth2)',
verbose_name="token secret",
blank=True,
),
),
(
"expires_at",
models.DateTimeField(
null=True, verbose_name="expires at", blank=True
),
),
(
"account",
models.ForeignKey(
to="socialaccount.SocialAccount",
on_delete=models.CASCADE,
),
),
(
"app",
models.ForeignKey(
to="socialaccount.SocialApp", on_delete=models.CASCADE
),
),
],
options={
"verbose_name": "social application token",
"verbose_name_plural": "social application tokens",
},
bases=(models.Model,),
),
migrations.AlterUniqueTogether(
name="socialtoken",
unique_together=set([("app", "account")]),
),
migrations.AlterUniqueTogether(
name="socialaccount",
unique_together=set([("provider", "uid")]),
),
]

View File

@@ -0,0 +1,45 @@
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("socialaccount", "0001_initial"),
]
operations = [
migrations.AlterField(
model_name="socialaccount",
name="uid",
field=models.CharField(
max_length=getattr(settings, "SOCIALACCOUNT_UID_MAX_LENGTH", 191),
verbose_name="uid",
),
),
migrations.AlterField(
model_name="socialapp",
name="client_id",
field=models.CharField(
help_text="App ID, or consumer key",
max_length=191,
verbose_name="client id",
),
),
migrations.AlterField(
model_name="socialapp",
name="key",
field=models.CharField(
help_text="Key", max_length=191, verbose_name="key", blank=True
),
),
migrations.AlterField(
model_name="socialapp",
name="secret",
field=models.CharField(
help_text="API secret, client secret, or consumer secret",
max_length=191,
verbose_name="secret key",
blank=True,
),
),
]

View File

@@ -0,0 +1,16 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("socialaccount", "0002_token_max_lengths"),
]
operations = [
migrations.AlterField(
model_name="socialaccount",
name="extra_data",
field=models.TextField(default="{}", verbose_name="extra data"),
preserve_default=True,
),
]

View File

@@ -0,0 +1,29 @@
# Generated by Django 3.2.19 on 2023-06-30 13:16
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("socialaccount", "0003_extra_data_default_dict"),
]
operations = [
migrations.AddField(
model_name="socialapp",
name="provider_id",
field=models.CharField(
blank=True, max_length=200, verbose_name="provider ID"
),
),
migrations.AddField(
model_name="socialapp",
name="settings",
field=models.JSONField(blank=True, default=dict),
),
migrations.AlterField(
model_name="socialaccount",
name="provider",
field=models.CharField(max_length=200, verbose_name="provider"),
),
]

View File

@@ -0,0 +1,23 @@
# Generated by Django 3.2.20 on 2023-09-03 19:46
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("socialaccount", "0004_app_provider_id_settings"),
]
operations = [
migrations.AlterField(
model_name="socialtoken",
name="app",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="socialaccount.socialapp",
),
),
]

View File

@@ -0,0 +1,17 @@
# Generated by Django 3.2.20 on 2023-10-11 09:23
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("socialaccount", "0005_socialtoken_nullable_app"),
]
operations = [
migrations.AlterField(
model_name="socialaccount",
name="extra_data",
field=models.JSONField(default=dict, verbose_name="extra data"),
),
]

View File

@@ -0,0 +1,411 @@
from typing import Any, Dict, List, Optional
from django.conf import settings
from django.contrib.auth import authenticate, get_user_model
from django.contrib.sites.shortcuts import get_current_site
from django.core.exceptions import ImproperlyConfigured, PermissionDenied
from django.db import models
from django.utils.translation import gettext_lazy as _
import allauth.app_settings
from allauth import app_settings as allauth_settings
from allauth.account.models import EmailAddress
from allauth.account.utils import (
filter_users_by_email,
get_next_redirect_url,
setup_user_email,
)
from allauth.core import context
from allauth.socialaccount import app_settings, providers, signals
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.internal import statekit
from allauth.utils import get_request_param
if not allauth_settings.SOCIALACCOUNT_ENABLED:
raise ImproperlyConfigured(
"allauth.socialaccount not installed, yet its models are imported."
)
class SocialAppManager(models.Manager):
def on_site(self, request):
if allauth.app_settings.SITES_ENABLED:
site = get_current_site(request)
return self.filter(sites__id=site.id)
return self.all()
class SocialApp(models.Model):
objects = SocialAppManager()
# The provider type, e.g. "google", "telegram", "saml".
provider = models.CharField(
verbose_name=_("provider"),
max_length=30,
)
# For providers that support subproviders, such as OpenID Connect and SAML,
# this ID identifies that instance. SocialAccount's originating from app
# will have their `provider` field set to the `provider_id` if available,
# else `provider`.
provider_id = models.CharField(
verbose_name=_("provider ID"),
max_length=200,
blank=True,
)
name = models.CharField(verbose_name=_("name"), max_length=40)
client_id = models.CharField(
verbose_name=_("client id"),
max_length=191,
help_text=_("App ID, or consumer key"),
)
secret = models.CharField(
verbose_name=_("secret key"),
max_length=191,
blank=True,
help_text=_("API secret, client secret, or consumer secret"),
)
key = models.CharField(
verbose_name=_("key"), max_length=191, blank=True, help_text=_("Key")
)
settings = models.JSONField(default=dict, blank=True)
if allauth.app_settings.SITES_ENABLED:
# Most apps can be used across multiple domains, therefore we use
# a ManyToManyField. Note that Facebook requires an app per domain
# (unless the domains share a common base name).
# blank=True allows for disabling apps without removing them
sites = models.ManyToManyField("sites.Site", blank=True) # type: ignore[var-annotated]
class Meta:
verbose_name = _("social application")
verbose_name_plural = _("social applications")
def __str__(self):
return self.name
def get_provider(self, request):
provider_class = providers.registry.get_class(self.provider)
return provider_class(request=request, app=self)
class SocialAccount(models.Model):
user = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE)
# Given a `SocialApp` from which this account originates, this field equals
# the app's `app.provider_id` if available, `app.provider` otherwise.
provider = models.CharField(
verbose_name=_("provider"),
max_length=200,
)
# Just in case you're wondering if an OpenID identity URL is going
# to fit in a 'uid':
#
# Ideally, URLField(max_length=1024, unique=True) would be used
# for identity. However, MySQL has a max_length limitation of 191
# for URLField (in case of utf8mb4). How about
# models.TextField(unique=True) then? Well, that won't work
# either for MySQL due to another bug[1]. So the only way out
# would be to drop the unique constraint, or switch to shorter
# identity URLs. Opted for the latter, as [2] suggests that
# identity URLs are supposed to be short anyway, at least for the
# old spec.
#
# [1] http://code.djangoproject.com/ticket/2495.
# [2] http://openid.net/specs/openid-authentication-1_1.html#limits
uid = models.CharField(
verbose_name=_("uid"), max_length=app_settings.UID_MAX_LENGTH
)
last_login = models.DateTimeField(verbose_name=_("last login"), auto_now=True)
date_joined = models.DateTimeField(verbose_name=_("date joined"), auto_now_add=True)
extra_data = models.JSONField(verbose_name=_("extra data"), default=dict)
class Meta:
unique_together = ("provider", "uid")
verbose_name = _("social account")
verbose_name_plural = _("social accounts")
def authenticate(self):
return authenticate(account=self)
def __str__(self):
from .helpers import socialaccount_user_display
return socialaccount_user_display(self)
def get_profile_url(self):
return self.get_provider_account().get_profile_url()
def get_avatar_url(self):
return self.get_provider_account().get_avatar_url()
def get_provider(self, request=None):
provider = getattr(self, "_provider", None)
if provider:
return provider
adapter = get_adapter()
provider = self._provider = adapter.get_provider(
request or context.request, provider=self.provider
)
return provider
def get_provider_account(self):
return self.get_provider().wrap_account(self)
class SocialToken(models.Model):
app = models.ForeignKey(SocialApp, on_delete=models.SET_NULL, blank=True, null=True)
account = models.ForeignKey(SocialAccount, on_delete=models.CASCADE)
token = models.TextField(
verbose_name=_("token"),
help_text=_('"oauth_token" (OAuth1) or access token (OAuth2)'),
)
token_secret = models.TextField(
blank=True,
verbose_name=_("token secret"),
help_text=_('"oauth_token_secret" (OAuth1) or refresh token (OAuth2)'),
)
expires_at = models.DateTimeField(
blank=True, null=True, verbose_name=_("expires at")
)
class Meta:
unique_together = ("app", "account")
verbose_name = _("social application token")
verbose_name_plural = _("social application tokens")
def __str__(self):
return "%s (%s)" % (self._meta.verbose_name, self.pk)
class SocialLogin:
"""
Represents a social user that is in the process of being logged
in. This consists of the following information:
`account` (`SocialAccount` instance): The social account being
logged in. Providers are not responsible for checking whether or
not an account already exists or not. Therefore, a provider
typically creates a new (unsaved) `SocialAccount` instance. The
`User` instance pointed to by the account (`account.user`) may be
prefilled by the provider for use as a starting point later on
during the signup process.
`token` (`SocialToken` instance): An optional access token token
that results from performing a successful authentication
handshake.
`state` (`dict`): The state to be preserved during the
authentication handshake. Note that this state may end up in the
url -- do not put any secrets in here. It currently only contains
the url to redirect to after login.
`email_addresses` (list of `EmailAddress`): Optional list of
email addresses retrieved from the provider.
"""
account: SocialAccount
token: Optional[SocialToken]
email_addresses: List[EmailAddress]
state: Dict
_did_authenticate_by_email: Optional[str]
def __init__(
self,
user=None,
account: Optional[SocialAccount] = None,
token: Optional[SocialToken] = None,
email_addresses: Optional[List[EmailAddress]] = None,
):
if token:
assert token.account is None or token.account == account
self.token = token
self.user = user
if account:
self.account = account
self.email_addresses = email_addresses if email_addresses else []
self.state = {}
def connect(self, request, user) -> None:
self.user = user
self.save(request, connect=True)
signals.social_account_added.send(
sender=SocialLogin, request=request, sociallogin=self
)
get_adapter().send_notification_mail(
"socialaccount/email/account_connected",
self.user,
context={
"account": self.account,
"provider": self.account.get_provider(),
},
)
@property
def is_headless(self) -> bool:
return bool(self.state.get("headless"))
def serialize(self) -> Dict[str, Any]:
serialize_instance = get_adapter().serialize_instance
ret = dict(
account=serialize_instance(self.account),
user=serialize_instance(self.user),
state=self.state,
email_addresses=[serialize_instance(ea) for ea in self.email_addresses],
)
if self.token:
ret["token"] = serialize_instance(self.token)
return ret
@classmethod
def deserialize(cls, data: Dict[str, Any]) -> "SocialLogin":
deserialize_instance = get_adapter().deserialize_instance
account = deserialize_instance(SocialAccount, data["account"])
user = deserialize_instance(get_user_model(), data["user"])
if "token" in data:
token = deserialize_instance(SocialToken, data["token"])
else:
token = None
email_addresses = []
for ea in data["email_addresses"]:
email_address = deserialize_instance(EmailAddress, ea)
email_addresses.append(email_address)
ret = cls()
ret.token = token
ret.account = account
ret.user = user
ret.email_addresses = email_addresses
ret.state = data["state"]
return ret
def save(self, request, connect: bool = False) -> None:
"""
Saves a new account. Note that while the account is new,
the user may be an existing one (when connecting accounts)
"""
user = self.user
user.save()
self.account.user = user
self.account.save()
if app_settings.STORE_TOKENS and self.token:
self.token.account = self.account
self.token.save()
if connect:
# TODO: Add any new email addresses automatically?
pass
else:
setup_user_email(request, user, self.email_addresses)
@property
def is_existing(self) -> bool:
"""When `False`, this social login represents a temporary account, not
yet backed by a database record.
"""
if self.user.pk is None:
return False
return get_user_model().objects.filter(pk=self.user.pk).exists()
def lookup(self) -> None:
"""Look up the existing local user account to which this social login
points, if any.
"""
self._did_authenticate_by_email = None
if not self._lookup_by_socialaccount():
self._lookup_by_email()
def _lookup_by_socialaccount(self) -> bool:
assert not self.is_existing
try:
a = SocialAccount.objects.get(
provider=self.account.provider, uid=self.account.uid
)
# Update account
a.extra_data = self.account.extra_data
self.account = a
self.user = self.account.user
a.save()
signals.social_account_updated.send(
sender=SocialLogin, request=context.request, sociallogin=self
)
self._store_token()
return True
except SocialAccount.DoesNotExist:
return False
def _store_token(self) -> None:
# Update token
if not app_settings.STORE_TOKENS or not self.token:
return
assert not self.token.pk
app = self.token.app
if app and not app.pk:
# If the app is not stored in the db, leave the FK empty.
app = None
try:
t = SocialToken.objects.get(account=self.account, app=app)
t.token = self.token.token
if self.token.token_secret:
# only update the refresh token if we got one
# many oauth2 providers do not resend the refresh token
t.token_secret = self.token.token_secret
t.expires_at = self.token.expires_at
t.save()
self.token = t
except SocialToken.DoesNotExist:
self.token.account = self.account
self.token.app = app
self.token.save()
def _lookup_by_email(self) -> None:
emails = [e.email for e in self.email_addresses if e.verified]
for email in emails:
if not get_adapter().can_authenticate_by_email(self, email):
continue
users = filter_users_by_email(email, prefer_verified=True)
if users:
self.user = users[0]
self._did_authenticate_by_email = email
return
def _accept_login(self, request) -> None:
from allauth.socialaccount.internal.flows.email_authentication import (
wipe_password,
)
if self._did_authenticate_by_email:
wipe_password(request, self.user, self._did_authenticate_by_email)
if app_settings.EMAIL_AUTHENTICATION_AUTO_CONNECT:
self.connect(context.request, self.user)
def get_redirect_url(self, request) -> Optional[str]:
url = self.state.get("next")
return url
@classmethod
def state_from_request(cls, request) -> Dict[str, Any]:
"""
TODO: Deprecated! To be integrated with provider.redirect()
"""
state = {}
next_url = get_next_redirect_url(request)
if next_url:
state["next"] = next_url
state["process"] = get_request_param(request, "process", "login")
state["scope"] = get_request_param(request, "scope", "")
state["auth_params"] = get_request_param(request, "auth_params", "")
return state
@classmethod
def stash_state(cls, request, state: Optional[Dict[str, Any]] = None) -> str:
if state is None:
# Only for providers that don't support redirect() yet.
state = cls.state_from_request(request)
return statekit.stash_state(request, state)
@classmethod
def unstash_state(cls, request) -> Optional[Dict[str, Any]]:
state = statekit.unstash_last_state(request)
if state is None:
raise PermissionDenied()
return state

View File

@@ -0,0 +1,57 @@
import importlib
from collections import OrderedDict
from django.apps import apps
from django.conf import settings
from allauth.utils import import_attribute
class ProviderRegistry:
def __init__(self):
self.provider_map = OrderedDict()
self.loaded = False
def get_class_list(self):
self.load()
return list(self.provider_map.values())
def register(self, cls):
self.provider_map[cls.id] = cls
def get_class(self, id):
return self.provider_map.get(id)
def as_choices(self):
self.load()
for provider_cls in self.provider_map.values():
yield (provider_cls.id, provider_cls.name)
def load(self):
# TODO: Providers register with the provider registry when
# loaded. Here, we build the URLs for all registered providers. So, we
# really need to be sure all providers did register, which is why we're
# forcefully importing the `provider` modules here. The overall
# mechanism is way to magical and depends on the import order et al, so
# all of this really needs to be revisited.
if not self.loaded:
for app_config in apps.get_app_configs():
try:
module_name = app_config.name + ".provider"
provider_module = importlib.import_module(module_name)
except ImportError as e:
if e.name != module_name:
raise
else:
provider_settings = getattr(settings, "SOCIALACCOUNT_PROVIDERS", {})
for cls in getattr(provider_module, "provider_classes", []):
provider_class = provider_settings.get(cls.id, {}).get(
"provider_class"
)
if provider_class:
cls = import_attribute(provider_class)
self.register(cls)
self.loaded = True
registry = ProviderRegistry()

View File

@@ -0,0 +1,37 @@
from allauth.socialaccount.providers.agave.views import AgaveAdapter
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class AgaveAccount(ProviderAccount):
def get_profile_url(self):
return self.account.extra_data.get("web_url", "dflt")
def get_avatar_url(self):
return self.account.extra_data.get("avatar_url", "dflt")
class AgaveProvider(OAuth2Provider):
id = "agave"
name = "Agave"
account_class = AgaveAccount
oauth2_adapter_class = AgaveAdapter
def extract_uid(self, data):
return str(data.get("create_time"))
def extract_common_fields(self, data):
return dict(
email=data.get("email"),
username=data.get("username", ""),
name=(
(data.get("first_name", "") + " " + data.get("last_name", "")).strip()
),
)
def get_default_scope(self):
scope = ["PRODUCTION"]
return scope
provider_classes = [AgaveProvider]

View File

@@ -0,0 +1,34 @@
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from .provider import AgaveProvider
class AgaveTests(OAuth2TestsMixin, TestCase):
provider_id = AgaveProvider.id
def get_mocked_response(self):
return MockedResponse(
200,
"""
{
"status": "success",
"message": "User details retrieved successfully.",
"version": "2.0.0-SNAPSHOT-rc3fad",
"result": {
"first_name": "John",
"last_name": "Doe",
"full_name": "John Doe",
"email": "jon@doe.edu",
"phone": "",
"mobile_phone": "",
"status": "Active",
"create_time": "20180322043812Z",
"username": "jdoe"
}
}
""",
)
def get_expected_to_str(self):
return "jdoe"

View File

@@ -0,0 +1,5 @@
from allauth.socialaccount.providers.agave.provider import AgaveProvider
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
urlpatterns = default_urlpatterns(AgaveProvider)

View File

@@ -0,0 +1,41 @@
from allauth.socialaccount import app_settings
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
class AgaveAdapter(OAuth2Adapter):
provider_id = "agave"
settings = app_settings.PROVIDERS.get(provider_id, {})
provider_base_url = settings.get("API_URL", "https://public.agaveapi.co")
access_token_url = "{0}/token".format(provider_base_url)
authorize_url = "{0}/authorize".format(provider_base_url)
profile_url = "{0}/profiles/v2/me".format(provider_base_url)
def complete_login(self, request, app, token, response):
extra_data = (
get_adapter()
.get_requests_session()
.get(
self.profile_url,
params={"access_token": token.token},
headers={
"Authorization": "Bearer " + token.token,
},
)
)
user_profile = (
extra_data.json()["result"] if "result" in extra_data.json() else {}
)
return self.get_provider().sociallogin_from_response(request, user_profile)
oauth2_login = OAuth2LoginView.adapter_view(AgaveAdapter)
oauth2_callback = OAuth2CallbackView.adapter_view(AgaveAdapter)

View File

@@ -0,0 +1,34 @@
from allauth.socialaccount.providers.amazon.views import AmazonOAuth2Adapter
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class AmazonAccount(ProviderAccount):
pass
class AmazonProvider(OAuth2Provider):
id = "amazon"
name = "Amazon"
account_class = AmazonAccount
oauth2_adapter_class = AmazonOAuth2Adapter
def get_default_scope(self):
return ["profile"]
def extract_uid(self, data):
return str(data["user_id"])
def extract_common_fields(self, data):
# Hackish way of splitting the fullname.
# Assumes no middlenames.
name = data.get("name", "")
first_name, last_name = name, ""
if name and " " in name:
first_name, last_name = name.split(" ", 1)
return dict(
email=data.get("email", ""), last_name=last_name, first_name=first_name
)
provider_classes = [AmazonProvider]

View File

@@ -0,0 +1,24 @@
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from .provider import AmazonProvider
class AmazonTests(OAuth2TestsMixin, TestCase):
provider_id = AmazonProvider.id
def get_mocked_response(self):
return MockedResponse(
200,
"""
{
"Profile":{
"CustomerId":"amzn1.account.K2LI23KL2LK2",
"Name":"John Doe",
"PrimaryEmail":"johndoe@example.com"
}
}""",
)
def get_expected_to_str(self):
return "johndoe@example.com"

View File

@@ -0,0 +1,6 @@
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
from .provider import AmazonProvider
urlpatterns = default_urlpatterns(AmazonProvider)

View File

@@ -0,0 +1,33 @@
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
class AmazonOAuth2Adapter(OAuth2Adapter):
provider_id = "amazon"
access_token_url = "https://api.amazon.com/auth/o2/token"
authorize_url = "http://www.amazon.com/ap/oa"
profile_url = "https://api.amazon.com/user/profile"
def complete_login(self, request, app, token, **kwargs):
response = (
get_adapter()
.get_requests_session()
.get(self.profile_url, params={"access_token": token.token})
)
response.raise_for_status()
extra_data = response.json()
if "Profile" in extra_data:
extra_data = {
"user_id": extra_data["Profile"]["CustomerId"],
"name": extra_data["Profile"]["Name"],
"email": extra_data["Profile"]["PrimaryEmail"],
}
return self.get_provider().sociallogin_from_response(request, extra_data)
oauth2_login = OAuth2LoginView.adapter_view(AmazonOAuth2Adapter)
oauth2_callback = OAuth2CallbackView.adapter_view(AmazonOAuth2Adapter)

View File

@@ -0,0 +1,69 @@
from allauth.account.models import EmailAddress
from allauth.socialaccount.providers.amazon_cognito.utils import (
convert_to_python_bool_if_value_is_json_string_bool,
)
from allauth.socialaccount.providers.amazon_cognito.views import (
AmazonCognitoOAuth2Adapter,
)
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class AmazonCognitoAccount(ProviderAccount):
def get_avatar_url(self):
return self.account.extra_data.get("picture")
def get_profile_url(self):
return self.account.extra_data.get("profile")
class AmazonCognitoProvider(OAuth2Provider):
id = "amazon_cognito"
name = "Amazon Cognito"
account_class = AmazonCognitoAccount
oauth2_adapter_class = AmazonCognitoOAuth2Adapter
def extract_uid(self, data):
return str(data["sub"])
def extract_common_fields(self, data):
return {
"email": data.get("email"),
"first_name": data.get("given_name"),
"last_name": data.get("family_name"),
}
def get_default_scope(self):
return ["openid", "profile", "email"]
def extract_email_addresses(self, data):
email = data.get("email")
verified = convert_to_python_bool_if_value_is_json_string_bool(
data.get("email_verified", False)
)
return (
[EmailAddress(email=email, verified=verified, primary=True)]
if email
else []
)
def extract_extra_data(self, data):
ret = dict(data)
phone_number_verified = data.get("phone_number_verified")
if phone_number_verified is not None:
ret["phone_number_verified"] = (
convert_to_python_bool_if_value_is_json_string_bool(
"phone_number_verified"
)
)
return ret
@classmethod
def get_slug(cls):
# IMPORTANT: Amazon Cognito does not support `_` characters
# as part of their redirect URI.
return super(AmazonCognitoProvider, cls).get_slug().replace("_", "-")
provider_classes = [AmazonCognitoProvider]

View File

@@ -0,0 +1,89 @@
import json
from django.test import override_settings
import pytest
from allauth.account.models import EmailAddress
from allauth.socialaccount.models import SocialAccount
from allauth.socialaccount.providers.amazon_cognito.provider import (
AmazonCognitoProvider,
)
from allauth.socialaccount.providers.amazon_cognito.utils import (
convert_to_python_bool_if_value_is_json_string_bool,
)
from allauth.socialaccount.providers.amazon_cognito.views import (
AmazonCognitoOAuth2Adapter,
)
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
def _get_mocked_claims():
return {
"sub": "[HEROKU-API-KEY-REMOVED]",
"given_name": "John",
"family_name": "Doe",
"email": "jdoe@example.com",
"username": "johndoe",
}
@override_settings(
SOCIALACCOUNT_PROVIDERS={
"amazon_cognito": {"DOMAIN": "https://domain.auth.us-east-1.amazoncognito.com"}
}
)
class AmazonCognitoTestCase(OAuth2TestsMixin, TestCase):
provider_id = AmazonCognitoProvider.id
def get_mocked_response(self):
mocked_payload = json.dumps(_get_mocked_claims())
return MockedResponse(status_code=200, content=mocked_payload)
def get_expected_to_str(self):
return "johndoe"
@override_settings(SOCIALACCOUNT_PROVIDERS={"amazon_cognito": {}})
def test_oauth2_adapter_raises_if_domain_settings_is_missing(
self,
):
mocked_response = self.get_mocked_response()
with self.assertRaises(
ValueError,
msg=AmazonCognitoOAuth2Adapter.DOMAIN_KEY_MISSING_ERROR,
):
self.login(mocked_response)
def test_saves_email_as_verified_if_email_is_verified_in_cognito(
self,
):
mocked_claims = _get_mocked_claims()
mocked_claims["email_verified"] = True
mocked_payload = json.dumps(mocked_claims)
mocked_response = MockedResponse(status_code=200, content=mocked_payload)
self.login(mocked_response)
user_id = SocialAccount.objects.get(uid=mocked_claims["sub"]).user_id
email_address = EmailAddress.objects.get(user_id=user_id)
self.assertEqual(email_address.email, mocked_claims["email"])
self.assertTrue(email_address.verified)
def test_provider_slug_replaces_underscores_with_hyphens(self):
self.assertTrue("_" not in self.provider.get_slug())
@pytest.mark.parametrize(
"input,output",
[
(True, True),
("true", True),
("false", False),
(False, False),
],
)
def test_convert_bool(input, output):
assert convert_to_python_bool_if_value_is_json_string_bool(input) == output

View File

@@ -0,0 +1,7 @@
from allauth.socialaccount.providers.amazon_cognito.provider import (
AmazonCognitoProvider,
)
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
urlpatterns = default_urlpatterns(AmazonCognitoProvider)

View File

@@ -0,0 +1,7 @@
def convert_to_python_bool_if_value_is_json_string_bool(s):
if s == "true":
return True
elif s == "false":
return False
return s

View File

@@ -0,0 +1,56 @@
from allauth.socialaccount import app_settings
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.models import SocialToken
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
class AmazonCognitoOAuth2Adapter(OAuth2Adapter):
provider_id = "amazon_cognito"
DOMAIN_KEY_MISSING_ERROR = (
'"DOMAIN" key is missing in Amazon Cognito configuration.'
)
@property
def settings(self):
return app_settings.PROVIDERS.get(self.provider_id, {})
@property
def domain(self):
domain = self.settings.get("DOMAIN")
if domain is None:
raise ValueError(self.DOMAIN_KEY_MISSING_ERROR)
return domain
@property
def access_token_url(self):
return "{}/oauth2/token".format(self.domain)
@property
def authorize_url(self):
return "{}/oauth2/authorize".format(self.domain)
@property
def profile_url(self):
return "{}/oauth2/userInfo".format(self.domain)
def complete_login(self, request, app, token: SocialToken, **kwargs):
headers = {
"Authorization": "Bearer {}".format(token.token),
}
extra_data = (
get_adapter().get_requests_session().get(self.profile_url, headers=headers)
)
extra_data.raise_for_status()
return self.get_provider().sociallogin_from_response(request, extra_data.json())
oauth2_login = OAuth2LoginView.adapter_view(AmazonCognitoOAuth2Adapter)
oauth2_callback = OAuth2CallbackView.adapter_view(AmazonCognitoOAuth2Adapter)

View File

@@ -0,0 +1,33 @@
from allauth.socialaccount.providers.angellist.views import (
AngelListOAuth2Adapter,
)
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class AngelListAccount(ProviderAccount):
def get_profile_url(self):
return self.account.extra_data.get("angellist_url")
def get_avatar_url(self):
return self.account.extra_data.get("image")
class AngelListProvider(OAuth2Provider):
id = "angellist"
name = "AngelList"
account_class = AngelListAccount
oauth2_adapter_class = AngelListOAuth2Adapter
def extract_uid(self, data):
return str(data["id"])
def extract_common_fields(self, data):
return dict(
email=data.get("email"),
username=data.get("angellist_url").split("/")[-1],
name=data.get("name"),
)
provider_classes = [AngelListProvider]

View File

@@ -0,0 +1,28 @@
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from .provider import AngelListProvider
class AngelListTests(OAuth2TestsMixin, TestCase):
provider_id = AngelListProvider.id
def get_mocked_response(self):
return MockedResponse(
200,
"""
{"name":"pennersr","id":424732,"bio":"","follower_count":0,
"angellist_url":"https://angel.co/dsxtst",
"image":"https://angel.co/images/shared/nopic.png",
"email":"raymond.penners@example.com","blog_url":null,
"online_bio_url":null,"twitter_url":"https://twitter.com/dsxtst",
"facebook_url":null,"linkedin_url":null,"aboutme_url":null,
"github_url":null,"dribbble_url":null,"behance_url":null,
"what_ive_built":null,"locations":[],"roles":[],"skills":[],
"investor":false,"scopes":["message","talent","dealflow","comment",
"email"]}
""",
)
def get_expected_to_str(self):
return "raymond.penners@example.com"

View File

@@ -0,0 +1,6 @@
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
from .provider import AngelListProvider
urlpatterns = default_urlpatterns(AngelListProvider)

View File

@@ -0,0 +1,27 @@
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
class AngelListOAuth2Adapter(OAuth2Adapter):
provider_id = "angellist"
access_token_url = "https://angel.co/api/oauth/token/"
authorize_url = "https://angel.co/api/oauth/authorize/"
profile_url = "https://api.angel.co/1/me/"
supports_state = False
def complete_login(self, request, app, token, **kwargs):
resp = (
get_adapter()
.get_requests_session()
.get(self.profile_url, params={"access_token": token.token})
)
extra_data = resp.json()
return self.get_provider().sociallogin_from_response(request, extra_data)
oauth2_login = OAuth2LoginView.adapter_view(AngelListOAuth2Adapter)
oauth2_callback = OAuth2CallbackView.adapter_view(AngelListOAuth2Adapter)

View File

@@ -0,0 +1,8 @@
from allauth.socialaccount.sessions import LoginSession
APPLE_SESSION_COOKIE_NAME = "apple-login-session"
def get_apple_session(request):
return LoginSession(request, "apple_login_session", APPLE_SESSION_COOKIE_NAME)

View File

@@ -0,0 +1,101 @@
import time
from urllib.parse import parse_qsl, quote, urlencode
from django.core.exceptions import ImproperlyConfigured
import jwt
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.oauth2.client import (
OAuth2Client,
OAuth2Error,
)
def jwt_encode(*args, **kwargs):
resp = jwt.encode(*args, **kwargs)
if isinstance(resp, bytes):
# For PyJWT <2
resp = resp.decode("utf-8")
return resp
class Scope:
EMAIL = "email"
NAME = "name"
class AppleOAuth2Client(OAuth2Client):
"""
Custom client because `Sign In With Apple`:
* requires `response_mode` field in redirect_url
* requires special `client_secret` as JWT
"""
def generate_client_secret(self):
"""Create a JWT signed with an apple provided private key"""
now = int(time.time())
app = get_adapter(self.request).get_app(self.request, "apple")
if not app.key:
raise ImproperlyConfigured("Apple 'key' missing")
certificate_key = app.settings.get("certificate_key")
if not certificate_key:
raise ImproperlyConfigured("Apple 'certificate_key' missing")
claims = {
"iss": app.key,
"aud": "https://appleid.apple.com",
"sub": self.get_client_id(),
"iat": now,
"exp": now + 60 * 60,
}
headers = {"kid": self.consumer_secret, "alg": "ES256"}
client_secret = jwt_encode(
payload=claims, key=certificate_key, algorithm="ES256", headers=headers
)
return client_secret
def get_client_id(self):
"""We support multiple client_ids, but use the first one for api calls"""
return self.consumer_key.split(",")[0]
def get_access_token(self, code, pkce_code_verifier=None):
url = self.access_token_url
client_secret = self.generate_client_secret()
data = {
"client_id": self.get_client_id(),
"code": code,
"grant_type": "authorization_code",
"redirect_uri": self.callback_url,
"client_secret": client_secret,
}
if pkce_code_verifier:
data["code_verifier"] = pkce_code_verifier
self._strip_empty_keys(data)
resp = (
get_adapter()
.get_requests_session()
.request(self.access_token_method, url, data=data, headers=self.headers)
)
access_token = None
if resp.status_code in [200, 201]:
try:
access_token = resp.json()
except ValueError:
access_token = dict(parse_qsl(resp.text))
if not access_token or "access_token" not in access_token:
raise OAuth2Error("Error retrieving access token: %s" % resp.content)
return access_token
def get_redirect_url(self, authorization_url, scope, extra_params):
scope = self.scope_delimiter.join(set(scope))
params = {
"client_id": self.get_client_id(),
"redirect_uri": self.callback_url,
"response_mode": "form_post",
"scope": scope,
"response_type": "code id_token",
}
if self.state:
params["state"] = self.state
params.update(extra_params)
return "%s?%s" % (authorization_url, urlencode(params, quote_via=quote))

View File

@@ -0,0 +1,92 @@
import requests
from allauth.account.models import EmailAddress
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.app_settings import QUERY_EMAIL
from allauth.socialaccount.providers.apple.views import AppleOAuth2Adapter
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class AppleAccount(ProviderAccount):
def to_str(self):
email = self.account.extra_data.get("email")
if email and not email.lower().endswith("@privaterelay.appleid.com"):
return email
name = self.account.extra_data.get("name") or {}
if name.get("firstName") or name.get("lastName"):
full_name = f"{name['firstName'] or ''} {name['lastName'] or ''}"
full_name = full_name.strip()
if full_name:
return full_name
return super().to_str()
class AppleProvider(OAuth2Provider):
id = "apple"
name = "Apple"
account_class = AppleAccount
oauth2_adapter_class = AppleOAuth2Adapter
supports_token_authentication = True
def extract_uid(self, data):
return str(data["sub"])
def extract_common_fields(self, data):
fields = {"email": data.get("email")}
# If the name was provided
name = data.get("name")
if name:
fields["first_name"] = name.get("firstName", "")
fields["last_name"] = name.get("lastName", "")
return fields
def extract_email_addresses(self, data):
ret = []
email = data.get("email")
verified = data.get("email_verified")
if isinstance(verified, str):
verified = verified.lower() == "true"
if email:
ret.append(
EmailAddress(
email=email,
verified=verified,
primary=True,
)
)
return ret
def get_default_scope(self):
scopes = ["name"]
if QUERY_EMAIL:
scopes.append("email")
return scopes
def verify_token(self, request, token):
from allauth.socialaccount.providers.apple.views import (
AppleOAuth2Adapter,
)
id_token = token.get("id_token")
if not id_token:
raise get_adapter().validation_error("invalid_token")
try:
identity_data = AppleOAuth2Adapter.get_verified_identity_data(
self, id_token
)
except (OAuth2Error, requests.RequestException) as e:
raise get_adapter().validation_error("invalid_token") from e
login = self.sociallogin_from_response(request, identity_data)
return login
def get_auds(self):
return [aud.strip() for aud in self.app.client_id.split(",")]
provider_classes = [AppleProvider]

View File

@@ -0,0 +1,264 @@
import json
import time
from importlib import import_module
from urllib.parse import parse_qs, urlparse
from django.conf import settings
from django.test.utils import override_settings
from django.urls import reverse
from django.utils.http import urlencode
import jwt
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase, mocked_response
from .apple_session import APPLE_SESSION_COOKIE_NAME
from .client import jwt_encode
from .provider import AppleProvider
# Generated on https://mkjwk.org/, used to sign and verify the apple id_token
TESTING_JWT_KEYSET = {
"p": (
"4ADzS5jKx_kdQihyOocVS0Qwwo7m0f7Ow56EadySJ-cmnwoHHF3AxgRaq-h-KwybSphv"
"dc-X7NbS79-b9dumHKyt1MeVLAsDZD1a-uQCEneY1g9LsQkscNr7OggcpvMg5UUFwv6A"
"kavu8cB0iyhNdha5_AWX27K5lNebvpaXEJ8"
),
"kty": "RSA",
"q": (
"yy5UvMjrvZyO1Os_nxXIugCa3NyWOkC8oMppPvr1Bl5AnF_xwXN2n9ozPd9Nb3Q3n-om"
"NgLayyUxhwIjWDlI67Vbx-ESuff8ZEBKuTK0Gdmr4C_QU_j0gvvNMNJweSPxDdRmIUgO"
"njTVNWmdqFTZs43jXAT4J519rgveNLAkGNE"
),
"d": (
"riPuGIDde88WS03CVbo_mZ9toFWPyTxvuz8VInJ9S1ZxULo-hQWDBohWGYwvg8cgfXck"
"cqWt5OBqNvPYdLgwb84uVi2JeEHmhcQSc_x0zfRTau5HVE2KdR-gWxQjPWoaBHeDVqwo"
"PKaU2XYxa-[AWS-SECRET-REMOVED]EHwyWXJbTpoar7AARW"
"kz76qtngDkk4t9gk_Q0L1y1qf1GeWiAL7xWb-bdptma4-1ui-R2219-1ONEZ41v_jsIS"
"_[AWS-SECRET-REMOVED]sXjaIwkdItbDmL1jFUgefwfO91Y"
"YQ"
),
"e": "AQAB",
"use": "sig",
"kid": "testkey",
"qi": (
"R0Hu4YmpHzw3SKWGYuAcAo6B97-[AWS-SECRET-REMOVED]r"
"[AWS-SECRET-REMOVED]t1c4tTotFDdw8WFptDOw4ow31Tml"
"BPExLqzzGjJeQSNULB1bExuuhYMWx6wBXo8"
),
"dp": (
"WBaHlnbjZ3hDVTzqjrGIYizSr-_[AWS-SECRET-REMOVED]G"
"wuF78RsZoFLi1fAmhqgxQ7eopcU-[AWS-SECRET-REMOVED]"
"szhVoqP4MLEMpR-Sy9S3PyItcKbJDE3T4ik"
),
"alg": "RS256",
"dq": (
"[AWS-SECRET-REMOVED]zRk2vCXbiOY7Qttad8ptLEUgfytV"
"SsNtGvMsoQsZWRak8nHnhGJ4s0QzB1OK7sdNgU_cL1HV-VxSSPaHhdJBrJEcrzggDPEB"
"KYfDHU6Iz34d1nvjBxoWE8rfqJsGbCW4xxE"
),
"n": (
"sclLPioUv4VOcOZWAKoRhcvwIH2jOhoHhSI_Cj5c5zSp7qaK8jCU6T7-GObsgrhpty-k"
"26ZuqRdgu9d-62WO8OBGt1e0wxbTh128-nTTrOESHUlV_K1wpJmXOxNpJiybcgzZNbAm"
"ACmsHfxZvN9bt7gKPXxf3-_zFAf12PbYMrOionAJ1N_4HxL7fz3xkr5C87Av06QNilIC"
"-mA-4n9Eqw_R2DYNpE3RYMdWtwKqBwJC8qs3677RpG9vcc-yZ_97pEiytd2FBJ8uoTwH"
"[AWS-SECRET-REMOVED]79LrsfOzrXF4enkfKJYI40-uwT95"
"zw"
),
}
# Mocked version of the test data from https://appleid.apple.com/auth/keys
KEY_SERVER_RESP_JSON = json.dumps(
{
"keys": [
{
"kty": TESTING_JWT_KEYSET["kty"],
"kid": TESTING_JWT_KEYSET["kid"],
"use": TESTING_JWT_KEYSET["use"],
"alg": TESTING_JWT_KEYSET["alg"],
"n": TESTING_JWT_KEYSET["n"],
"e": TESTING_JWT_KEYSET["e"],
}
]
}
)
def sign_id_token(payload):
"""
Sign a payload as apple normally would for the id_token.
"""
signing_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(TESTING_JWT_KEYSET))
return jwt_encode(
payload,
signing_key,
algorithm="RS256",
headers={"kid": TESTING_JWT_KEYSET["kid"]},
)
@override_settings(
SOCIALACCOUNT_STORE_TOKENS=False,
SOCIALACCOUNT_PROVIDERS={
"apple": {
"APP": {
"client_id": "app123id",
"key": "apple",
"secret": "dummy",
"settings": {
"certificate_key": """[PRIVATE-KEY-REMOVED]
[AWS-SECRET-REMOVED]awIBAQQg2+Eybl8ojH4wB30C
[AWS-SECRET-REMOVED]Q+EpNgQQyQVs/F27dkq3gvAI
[AWS-SECRET-REMOVED]Ru3XGyqy3mdb8gMy
-----END PRIVATE KEY-----
""",
},
}
}
},
)
class AppleTests(OAuth2TestsMixin, TestCase):
provider_id = AppleProvider.id
def get_apple_id_token_payload(self):
now = int(time.time())
return {
"iss": "https://appleid.apple.com",
"aud": "app123id", # Matches `setup_app`
"exp": now + 60 * 60,
"iat": now,
"sub": "000313.c9720f41e9434e18987a.1218",
"at_hash": "CkaUPjk4MJinaAq6Z0tGUA",
"email": "test@privaterelay.appleid.com",
"email_verified": "true",
"is_private_email": "true",
"auth_time": 1234345345, # not converted automatically by pyjwt
}
def test_verify_token(self):
id_token = sign_id_token(self.get_apple_id_token_payload())
with mocked_response(self.get_mocked_response()):
sociallogin = self.provider.verify_token(None, {"id_token": id_token})
assert sociallogin.user.email == "test@privaterelay.appleid.com"
def get_login_response_json(self, with_refresh_token=True):
"""
`with_refresh_token` is not optional for apple, so it's ignored.
"""
id_token = sign_id_token(self.get_apple_id_token_payload())
return json.dumps(
{
"access_token": "testac", # Matches OAuth2TestsMixin value
"expires_in": 3600,
"id_token": id_token,
"refresh_token": "testrt", # Matches OAuth2TestsMixin value
"token_type": "Bearer",
}
)
def get_mocked_response(self):
"""
Apple is unusual in that the `id_token` contains all the user info
so no profile info request is made. However, it does need the
public key verification, so this mocked response is the public
key request in order to verify the authenticity of the id_token.
"""
return MockedResponse(
200, KEY_SERVER_RESP_JSON, {"content-type": "application/json"}
)
def get_expected_to_str(self):
return "A B"
def get_complete_parameters(self, auth_request_params):
"""
Add apple specific response parameters which they include in the
form_post response.
https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/incorporating_sign_in_with_apple_into_other_platforms
"""
params = super().get_complete_parameters(auth_request_params)
params.update(
{
"id_token": sign_id_token(self.get_apple_id_token_payload()),
"user": json.dumps(
{
"email": "private@appleid.apple.com",
"name": {
"firstName": "A",
"lastName": "B",
},
}
),
}
)
return params
def login(self, resp_mock, process="login", with_refresh_token=True):
resp = self.client.post(
reverse(self.provider.id + "_login")
+ "?"
+ urlencode(dict(process=process))
)
p = urlparse(resp["location"])
q = parse_qs(p.query)
complete_url = reverse(self.provider.id + "_callback")
self.assertGreater(q["redirect_uri"][0].find(complete_url), 0)
response_json = self.get_login_response_json(
with_refresh_token=with_refresh_token
)
with mocked_response(
MockedResponse(200, response_json, {"content-type": "application/json"}),
resp_mock,
):
resp = self.client.post(
complete_url,
data=self.get_complete_parameters(q),
)
assert reverse("apple_finish_callback") in resp.url
# Follow the redirect
resp = self.client.get(resp.url)
return resp
def test_authentication_error(self):
"""Override base test because apple posts errors"""
resp = self.client.post(
reverse(self.provider.id + "_callback"),
data={"error": "misc", "state": "testingstate123"},
)
assert reverse("apple_finish_callback") in resp.url
# Follow the redirect
resp = self.client.get(resp.url)
self.assertTemplateUsed(
resp,
"socialaccount/authentication_error.%s"
% getattr(settings, "ACCOUNT_TEMPLATE_EXTENSION", "html"),
)
def test_apple_finish(self):
resp = self.login(self.get_mocked_response())
# Check request generating the response
finish_url = reverse("apple_finish_callback")
self.assertEqual(resp.request["PATH_INFO"], finish_url)
self.assertTrue("state" in resp.request["QUERY_STRING"])
self.assertTrue("code" in resp.request["QUERY_STRING"])
# Check have cookie containing apple session
self.assertTrue(APPLE_SESSION_COOKIE_NAME in self.client.cookies)
# Session should have been cleared
apple_session_cookie = self.client.cookies.get(APPLE_SESSION_COOKIE_NAME)
engine = import_module(settings.SESSION_ENGINE)
SessionStore = engine.SessionStore
apple_login_session = SessionStore(apple_session_cookie.value)
self.assertEqual(len(apple_login_session.keys()), 0)
# Check cookie path was correctly set
self.assertEqual(apple_session_cookie.get("path"), finish_url)

View File

@@ -0,0 +1,16 @@
from django.urls import path
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
from .provider import AppleProvider
from .views import oauth2_finish_login
urlpatterns = default_urlpatterns(AppleProvider)
urlpatterns += [
path(
AppleProvider.get_slug() + "/login/callback/finish/",
oauth2_finish_login,
name="apple_finish_callback",
),
]

View File

@@ -0,0 +1,148 @@
import json
from datetime import timedelta
from django.http import HttpResponseNotAllowed, HttpResponseRedirect
from django.urls import reverse
from django.utils import timezone
from django.utils.http import urlencode
from django.views.decorators.csrf import csrf_exempt
from allauth.account.internal.decorators import login_not_required
from allauth.socialaccount.internal import jwtkit
from allauth.socialaccount.models import SocialToken
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
from allauth.utils import build_absolute_uri, get_request_param
from .apple_session import get_apple_session
from .client import AppleOAuth2Client
class AppleOAuth2Adapter(OAuth2Adapter):
client_class = AppleOAuth2Client
provider_id = "apple"
access_token_url = "https://appleid.apple.com/auth/token"
authorize_url = "https://appleid.apple.com/auth/authorize"
public_key_url = "https://appleid.apple.com/auth/keys"
@classmethod
def get_verified_identity_data(cls, provider, id_token):
data = jwtkit.verify_and_decode(
credential=id_token,
keys_url=cls.public_key_url,
issuer="https://appleid.apple.com",
audience=provider.get_auds(),
lookup_kid=jwtkit.lookup_kid_jwk,
)
return data
def parse_token(self, data):
token = SocialToken(
token=data["access_token"],
)
token.token_secret = data.get("refresh_token", "")
expires_in = data.get(self.expires_in_key)
if expires_in:
token.expires_at = timezone.now() + timedelta(seconds=int(expires_in))
# `user_data` is a big flat dictionary with the parsed JWT claims
# access_tokens, and user info from the apple post.
identity_data = AppleOAuth2Adapter.get_verified_identity_data(
self.get_provider(), data["id_token"]
)
token.user_data = {**data, **identity_data}
return token
def complete_login(self, request, app, token, **kwargs):
extra_data = token.user_data
login = self.get_provider().sociallogin_from_response(
request=request, response=extra_data
)
login.state["id_token"] = token.user_data
# We can safely remove the apple login session now
# Note: The cookie will remain, but it's set to delete on browser close
get_apple_session(request).delete()
return login
def get_user_scope_data(self, request):
user_scope_data = request.apple_login_session.get("user", "")
try:
return json.loads(user_scope_data)
except json.JSONDecodeError:
# We do not care much about user scope data as it maybe blank
# so return blank dictionary instead
return {}
def get_access_token_data(self, request, app, client, pkce_code_verifier=None):
"""We need to gather the info from the apple specific login"""
apple_session = get_apple_session(request)
# Exchange `code`
code = get_request_param(request, "code")
access_token_data = client.get_access_token(
code, pkce_code_verifier=pkce_code_verifier
)
id_token = access_token_data.get("id_token", None)
# In case of missing id_token in access_token_data
if id_token is None:
id_token = apple_session.store.get("id_token")
return {
**access_token_data,
**self.get_user_scope_data(request),
"id_token": id_token,
}
@csrf_exempt
@login_not_required
def apple_post_callback(request, finish_endpoint_name="apple_finish_callback"):
"""
Apple uses a `form_post` response type, which due to
CORS/Samesite-cookie rules means this request cannot access
the request since the session cookie is unavailable.
We work around this by storing the apple response in a
separate, temporary session and redirecting to a more normal
oauth flow.
args:
finish_endpoint_name (str): The name of a defined URL, which can be
overridden in your url configuration if you have more than one
callback endpoint.
"""
if request.method != "POST":
return HttpResponseNotAllowed(["POST"])
apple_session = get_apple_session(request)
# Add regular OAuth2 params to the URL - reduces the overrides required
keys_to_put_in_url = ["code", "state", "error"]
url_params = {}
for key in keys_to_put_in_url:
value = get_request_param(request, key, "")
if value:
url_params[key] = value
# Add other params to the apple_login_session
keys_to_save_to_session = ["user", "id_token"]
for key in keys_to_save_to_session:
apple_session.store[key] = get_request_param(request, key, "")
url = build_absolute_uri(request, reverse(finish_endpoint_name))
response = HttpResponseRedirect(
"{url}?{query}".format(url=url, query=urlencode(url_params))
)
apple_session.save(response)
return response
oauth2_login = OAuth2LoginView.adapter_view(AppleOAuth2Adapter)
oauth2_callback = apple_post_callback
oauth2_finish_login = OAuth2CallbackView.adapter_view(AppleOAuth2Adapter)

View File

@@ -0,0 +1,23 @@
from allauth.socialaccount.providers.asana.views import AsanaOAuth2Adapter
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class AsanaAccount(ProviderAccount):
pass
class AsanaProvider(OAuth2Provider):
id = "asana"
name = "Asana"
account_class = AsanaAccount
oauth2_adapter_class = AsanaOAuth2Adapter
def extract_uid(self, data):
return str(data["id"])
def extract_common_fields(self, data):
return dict(email=data.get("email"), name=data.get("name"))
provider_classes = [AsanaProvider]

View File

@@ -0,0 +1,20 @@
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from .provider import AsanaProvider
class AsanaTests(OAuth2TestsMixin, TestCase):
provider_id = AsanaProvider.id
def get_mocked_response(self):
return MockedResponse(
200,
"""
{"data": {"photo": null, "workspaces": [{"id": 31337, "name": "example.com"},
{"id": 3133777, "name": "Personal Projects"}], "email": "test@example.com",
"name": "Test Name", "id": 43748387}}""",
)
def get_expected_to_str(self):
return "test@example.com"

View File

@@ -0,0 +1,6 @@
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
from .provider import AsanaProvider
urlpatterns = default_urlpatterns(AsanaProvider)

View File

@@ -0,0 +1,26 @@
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
class AsanaOAuth2Adapter(OAuth2Adapter):
provider_id = "asana"
access_token_url = "https://app.asana.com/-/oauth_token"
authorize_url = "https://app.asana.com/-/oauth_authorize"
profile_url = "https://app.asana.com/api/1.0/users/me"
def complete_login(self, request, app, token, **kwargs):
resp = (
get_adapter()
.get_requests_session()
.get(self.profile_url, params={"access_token": token.token})
)
extra_data = resp.json()["data"]
return self.get_provider().sociallogin_from_response(request, extra_data)
oauth2_login = OAuth2LoginView.adapter_view(AsanaOAuth2Adapter)
oauth2_callback = OAuth2CallbackView.adapter_view(AsanaOAuth2Adapter)

View File

@@ -0,0 +1,38 @@
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
from .views import AtlassianOAuth2Adapter
class AtlassianAccount(ProviderAccount):
def get_profile_url(self):
return self.account.extra_data.get("picture")
class AtlassianProvider(OAuth2Provider):
id = "atlassian"
name = "Atlassian"
account_class = AtlassianAccount
oauth2_adapter_class = AtlassianOAuth2Adapter
def extract_uid(self, data):
return data["account_id"]
def extract_common_fields(self, data):
return {
"email": data.get("email"),
"name": data.get("name"),
"username": data.get("nickname"),
"email_verified": data.get("email_verified"),
}
def get_default_scope(self):
return ["read:me"]
def get_auth_params(self):
params = super().get_auth_params()
params.update({"audience": "api.atlassian.com", "prompt": "consent"})
return params
provider_classes = [AtlassianProvider]

View File

@@ -0,0 +1,33 @@
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from .provider import AtlassianProvider
class AtlassianTests(OAuth2TestsMixin, TestCase):
provider_id = AtlassianProvider.id
def get_mocked_response(self):
response_data = """
{
"account_type": "atlassian",
"account_id": "[HEROKU-API-KEY-REMOVED]",
"email": "mia@example.com",
"email_verified": true,
"name": "Mia Krystof",
"picture": "https://avatar-management--avatars.us-west-2.prod.public.atl-paas.net/[HEROKU-API-KEY-REMOVED]/1234abcd-9876-54aa-33aa-1234dfsade9487ds",
"account_status": "active",
"nickname": "mkrystof",
"zoneinfo": "Australia/Sydney",
"locale": "en-US",
"extended_profile": {
"job_title": "Designer",
"organization": "mia@example.com",
"department": "Design team",
"location": "Sydney"
}
}"""
return MockedResponse(200, response_data)
def get_expected_to_str(self):
return "mia@example.com"

View File

@@ -0,0 +1,6 @@
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
from .provider import AtlassianProvider
urlpatterns = default_urlpatterns(AtlassianProvider)

View File

@@ -0,0 +1,30 @@
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.models import SocialToken
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
class AtlassianOAuth2Adapter(OAuth2Adapter):
provider_id = "atlassian"
access_token_url = "https://api.atlassian.com/oauth/token"
authorize_url = "https://auth.atlassian.com/authorize"
profile_url = "https://api.atlassian.com/me"
def complete_login(self, request, app, token: SocialToken, **kwargs):
headers = {
"Authorization": f"Bearer {token.token}",
"Accept": "application/json",
}
response = (
get_adapter().get_requests_session().get(self.profile_url, headers=headers)
)
response.raise_for_status()
data = response.json()
return self.get_provider().sociallogin_from_response(request, data)
oauth2_login = OAuth2LoginView.adapter_view(AtlassianOAuth2Adapter)
oauth2_callback = OAuth2CallbackView.adapter_view(AtlassianOAuth2Adapter)

View File

@@ -0,0 +1,31 @@
from allauth.socialaccount.providers.auth0.views import Auth0OAuth2Adapter
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class Auth0Account(ProviderAccount):
def get_avatar_url(self):
return self.account.extra_data.get("picture")
class Auth0Provider(OAuth2Provider):
id = "auth0"
name = "Auth0"
account_class = Auth0Account
oauth2_adapter_class = Auth0OAuth2Adapter
def get_default_scope(self):
return ["openid", "profile", "email"]
def extract_uid(self, data):
return str(data["sub"])
def extract_common_fields(self, data):
return dict(
email=data.get("email"),
username=data.get("username"),
name=data.get("name"),
)
provider_classes = [Auth0Provider]

View File

@@ -0,0 +1,25 @@
from allauth.socialaccount.providers.auth0.provider import Auth0Provider
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
class Auth0Tests(OAuth2TestsMixin, TestCase):
provider_id = Auth0Provider.id
def get_mocked_response(self):
return MockedResponse(
200,
"""
{
"picture": "https://secure.gravatar.com/avatar/123",
"email": "mr.bob@your.Auth0.server.example.com",
"id": 2,
"sub": 2,
"identities": [],
"name": "Mr Bob"
}
""",
)
def get_expected_to_str(self):
return "mr.bob@your.Auth0.server.example.com"

View File

@@ -0,0 +1,5 @@
from allauth.socialaccount.providers.auth0.provider import Auth0Provider
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
urlpatterns = default_urlpatterns(Auth0Provider)

View File

@@ -0,0 +1,31 @@
from allauth.socialaccount import app_settings
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
class Auth0OAuth2Adapter(OAuth2Adapter):
provider_id = "auth0"
settings = app_settings.PROVIDERS.get(provider_id, {})
provider_base_url = settings.get("AUTH0_URL")
access_token_url = "{0}/oauth/token".format(provider_base_url)
authorize_url = "{0}/authorize".format(provider_base_url)
profile_url = "{0}/userinfo".format(provider_base_url)
def complete_login(self, request, app, token, response):
extra_data = (
get_adapter()
.get_requests_session()
.get(self.profile_url, params={"access_token": token.token})
.json()
)
return self.get_provider().sociallogin_from_response(request, extra_data)
oauth2_login = OAuth2LoginView.adapter_view(Auth0OAuth2Adapter)
oauth2_callback = OAuth2CallbackView.adapter_view(Auth0OAuth2Adapter)

View File

@@ -0,0 +1,110 @@
from allauth.account.models import EmailAddress
from allauth.socialaccount import app_settings
from allauth.socialaccount.providers.authentiq.views import (
AuthentiqOAuth2Adapter,
)
from allauth.socialaccount.providers.base import AuthAction, ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class Scope:
NAME = "aq:name"
EMAIL = "email"
PHONE = "phone"
ADDRESS = "address"
LOCATION = "aq:location"
PUSH = "aq:push"
IDENTITY_CLAIMS = frozenset(
[
"sub",
"name",
"given_name",
"family_name",
"middle_name",
"nickname",
"preferred_username",
"profile",
"picture",
"website",
"email",
"email_verified",
"gender",
"birthdate",
"zoneinfo",
"locale",
"phone_number",
"phone_number_verified",
"address",
"updated_at",
"aq:location",
]
)
class AuthentiqAccount(ProviderAccount):
def get_profile_url(self):
return self.account.extra_data.get("profile")
def get_avatar_url(self):
return self.account.extra_data.get("picture")
class AuthentiqProvider(OAuth2Provider):
id = "authentiq"
name = "Authentiq"
account_class = AuthentiqAccount
oauth2_adapter_class = AuthentiqOAuth2Adapter
def get_scope_from_request(self, request):
scope = set(super().get_scope_from_request(request))
scope.add("openid")
if Scope.EMAIL in scope:
modifiers = ""
if app_settings.EMAIL_REQUIRED:
modifiers += "r"
if app_settings.EMAIL_VERIFICATION:
modifiers += "s"
if modifiers:
scope.add(Scope.EMAIL + "~" + modifiers)
scope.remove(Scope.EMAIL)
return list(scope)
def get_default_scope(self):
scope = [Scope.NAME, Scope.PUSH]
if app_settings.QUERY_EMAIL:
scope.append(Scope.EMAIL)
return scope
def get_auth_params_from_request(self, request, action):
ret = super().get_auth_params_from_request(request, action)
if action == AuthAction.REAUTHENTICATE:
ret["prompt"] = "select_account"
return ret
def extract_uid(self, data):
return str(data["sub"])
def extract_common_fields(self, data):
return dict(
username=data.get("preferred_username", data.get("given_name")),
email=data.get("email"),
name=data.get("name"),
first_name=data.get("given_name"),
last_name=data.get("family_name"),
)
def extract_extra_data(self, data):
return {k: v for k, v in data.items() if k in IDENTITY_CLAIMS}
def extract_email_addresses(self, data):
ret = []
email = data.get("email")
if email and data.get("email_verified"):
ret.append(EmailAddress(email=email, verified=True, primary=True))
return ret
provider_classes = [AuthentiqProvider]

View File

@@ -0,0 +1,105 @@
import json
from django.test.client import RequestFactory
from django.test.utils import override_settings
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from .provider import AuthentiqProvider
from .views import AuthentiqOAuth2Adapter
class AuthentiqTests(OAuth2TestsMixin, TestCase):
provider_id = AuthentiqProvider.id
def get_mocked_response(self):
return MockedResponse(
200,
json.dumps(
{
"sub": "ZLARGMFT1M",
"email": "jane@email.invalid",
"email_verified": True,
"given_name": "Jane",
"family_name": "Doe",
}
),
)
def get_expected_to_str(self):
return "jane@email.invalid"
@override_settings(
SOCIALACCOUNT_QUERY_EMAIL=False,
)
def test_default_scopes_no_email(self):
scopes = self.provider.get_default_scope()
self.assertIn("aq:name", scopes)
self.assertNotIn("email", scopes)
@override_settings(
SOCIALACCOUNT_QUERY_EMAIL=True,
)
def test_default_scopes_email(self):
scopes = self.provider.get_default_scope()
self.assertIn("aq:name", scopes)
self.assertIn("email", scopes)
def test_scopes(self):
request = RequestFactory().get(AuthentiqOAuth2Adapter.authorize_url)
scopes = self.provider.get_scope_from_request(request)
self.assertIn("openid", scopes)
self.assertIn("aq:name", scopes)
def test_dynamic_scopes(self):
request = RequestFactory().get(
AuthentiqOAuth2Adapter.authorize_url, dict(scope="foo")
)
scopes = self.provider.get_scope_from_request(request)
self.assertIn("openid", scopes)
self.assertIn("aq:name", scopes)
self.assertIn("foo", scopes)
@override_settings(
SOCIALACCOUNT_QUERY_EMAIL=True,
SOCIALACCOUNT_EMAIL_REQUIRED=True,
SOCIALACCOUNT_EMAIL_VERIFICATION=True,
)
def test_scopes_required_verified_email(self):
request = RequestFactory().get(AuthentiqOAuth2Adapter.authorize_url)
scopes = self.provider.get_scope_from_request(request)
self.assertIn("email~rs", scopes)
self.assertNotIn("email", scopes)
@override_settings(
SOCIALACCOUNT_QUERY_EMAIL=True,
SOCIALACCOUNT_EMAIL_REQUIRED=False,
SOCIALACCOUNT_EMAIL_VERIFICATION=True,
)
def test_scopes_optional_verified_email(self):
request = RequestFactory().get(AuthentiqOAuth2Adapter.authorize_url)
scopes = self.provider.get_scope_from_request(request)
self.assertIn("email~s", scopes)
self.assertNotIn("email", scopes)
@override_settings(
SOCIALACCOUNT_QUERY_EMAIL=True,
SOCIALACCOUNT_EMAIL_REQUIRED=True,
SOCIALACCOUNT_EMAIL_VERIFICATION=False,
)
def test_scopes_required_email(self):
request = RequestFactory().get(AuthentiqOAuth2Adapter.authorize_url)
scopes = self.provider.get_scope_from_request(request)
self.assertIn("email~r", scopes)
self.assertNotIn("email", scopes)
@override_settings(
SOCIALACCOUNT_QUERY_EMAIL=True,
SOCIALACCOUNT_EMAIL_REQUIRED=False,
SOCIALACCOUNT_EMAIL_VERIFICATION=False,
)
def test_scopes_optional_email(self):
request = RequestFactory().get(AuthentiqOAuth2Adapter.authorize_url)
scopes = self.provider.get_scope_from_request(request)
self.assertIn("email", scopes)

View File

@@ -0,0 +1,6 @@
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
from .provider import AuthentiqProvider
urlpatterns = default_urlpatterns(AuthentiqProvider)

View File

@@ -0,0 +1,35 @@
from urllib.parse import urljoin
from allauth.socialaccount import app_settings
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
class AuthentiqOAuth2Adapter(OAuth2Adapter):
provider_id = "authentiq"
settings = app_settings.PROVIDERS.get(provider_id, {})
provider_url = settings.get("PROVIDER_URL", "https://connect.authentiq.io/")
if not provider_url.endswith("/"):
provider_url += "/"
access_token_url = urljoin(provider_url, "token")
authorize_url = urljoin(provider_url, "authorize")
profile_url = urljoin(provider_url, "userinfo")
def complete_login(self, request, app, token, **kwargs):
auth = {"Authorization": "Bearer " + token.token}
resp = get_adapter().get_requests_session().get(self.profile_url, headers=auth)
resp.raise_for_status()
extra_data = resp.json()
login = self.get_provider().sociallogin_from_response(request, extra_data)
return login
oauth2_login = OAuth2LoginView.adapter_view(AuthentiqOAuth2Adapter)
oauth2_callback = OAuth2CallbackView.adapter_view(AuthentiqOAuth2Adapter)

View File

@@ -0,0 +1,34 @@
from allauth.socialaccount.providers.baidu.views import BaiduOAuth2Adapter
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class BaiduAccount(ProviderAccount):
def get_profile_url(self):
return "http://www.baidu.com/p/" + self.account.extra_data.get("uname")
def get_avatar_url(self):
return (
"http://tb.himg.baidu.com/sys/portraitn/item/"
+ self.account.extra_data.get("portrait")
)
def to_str(self):
dflt = super(BaiduAccount, self).to_str()
return self.account.extra_data.get("uname", dflt)
class BaiduProvider(OAuth2Provider):
id = "baidu"
name = "Baidu"
account_class = BaiduAccount
oauth2_adapter_class = BaiduOAuth2Adapter
def extract_uid(self, data):
return data["uid"]
def extract_common_fields(self, data):
return dict(username=data.get("uid"), name=data.get("uname"))
provider_classes = [BaiduProvider]

View File

@@ -0,0 +1,19 @@
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from .provider import BaiduProvider
class BaiduTests(OAuth2TestsMixin, TestCase):
provider_id = BaiduProvider.id
def get_mocked_response(self):
return MockedResponse(
200,
"""
{"portrait": "78c0e9839de59bbde7859ccf43",
"uname": "\u90dd\u56fd\u715c", "uid": "3225892368"}""",
)
def get_expected_to_str(self):
return "\u90dd\u56fd\u715c"

View File

@@ -0,0 +1,6 @@
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
from .provider import BaiduProvider
urlpatterns = default_urlpatterns(BaiduProvider)

View File

@@ -0,0 +1,28 @@
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
class BaiduOAuth2Adapter(OAuth2Adapter):
provider_id = "baidu"
access_token_url = "https://openapi.baidu.com/oauth/2.0/token"
authorize_url = "https://openapi.baidu.com/oauth/2.0/authorize"
profile_url = (
"https://openapi.baidu.com/rest/2.0/passport/users/getLoggedInUser" # noqa
)
def complete_login(self, request, app, token, **kwargs):
resp = (
get_adapter()
.get_requests_session()
.get(self.profile_url, params={"access_token": token.token})
)
extra_data = resp.json()
return self.get_provider().sociallogin_from_response(request, extra_data)
oauth2_login = OAuth2LoginView.adapter_view(BaiduOAuth2Adapter)
oauth2_callback = OAuth2CallbackView.adapter_view(BaiduOAuth2Adapter)

View File

@@ -0,0 +1,2 @@
from .constants import AuthAction, AuthError, AuthProcess # noqa
from .provider import Provider, ProviderAccount, ProviderException # noqa

View File

@@ -0,0 +1,16 @@
class AuthProcess:
LOGIN = "login"
CONNECT = "connect"
REDIRECT = "redirect"
class AuthAction:
AUTHENTICATE = "authenticate"
REAUTHENTICATE = "reauthenticate"
REREQUEST = "rerequest"
class AuthError:
UNKNOWN = "unknown"
CANCELLED = "cancelled" # Cancelled on request of user
DENIED = "denied" # Denied by server

View File

@@ -0,0 +1,363 @@
from typing import Dict, Optional
from django.core.exceptions import ImproperlyConfigured, PermissionDenied
from allauth.account.utils import get_next_redirect_url, get_request_param
from allauth.socialaccount import app_settings
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.internal import statekit
from allauth.socialaccount.providers.base.constants import AuthProcess
class ProviderException(Exception):
pass
class Provider:
name: str # Provided by subclasses
id: str # Provided by subclasses
slug: Optional[str] = None # Provided by subclasses
uses_apps = True
supports_redirect = False
# Indicates whether or not this provider supports logging in by posting an
# access/id-token.
supports_token_authentication = False
def __init__(self, request, app=None):
self.request = request
if self.uses_apps and app is None:
raise ValueError("missing: app")
self.app = app
def __str__(self):
return self.name
@classmethod
def get_slug(cls):
return cls.slug or cls.id
def get_login_url(self, request, next=None, **kwargs):
"""
Builds the URL to redirect to when initiating a login for this
provider.
"""
raise NotImplementedError("get_login_url() for " + self.name)
def redirect_from_request(self, request):
kwargs = self.get_redirect_from_request_kwargs(request)
return self.redirect(request, **kwargs)
def get_redirect_from_request_kwargs(self, request):
kwargs = {}
next_url = get_next_redirect_url(request)
if next_url:
kwargs["next_url"] = next_url
kwargs["process"] = get_request_param(request, "process", AuthProcess.LOGIN)
return kwargs
def redirect(self, request, process, next_url=None, data=None, **kwargs):
"""
Initiate a redirect to the provider.
"""
raise NotImplementedError()
def verify_token(self, request, token):
"""
Verifies the token, returning a `SocialLogin` instance when valid.
Raises a `ValidationError` otherwise.
"""
raise NotImplementedError()
def media_js(self, request):
"""
Some providers may require extra scripts (e.g. a Facebook connect)
"""
return ""
def wrap_account(self, social_account):
return self.account_class(social_account)
def get_settings(self):
return app_settings.PROVIDERS.get(self.id, {})
def sociallogin_from_response(self, request, response):
"""
Instantiates and populates a `SocialLogin` model based on the data
retrieved in `response`. The method does NOT save the model to the
DB.
Data for `SocialLogin` will be extracted from `response` with the
help of the `.extract_uid()`, `.extract_extra_data()`,
`.extract_common_fields()`, and `.extract_email_addresses()`
methods.
:param request: a Django `HttpRequest` object.
:param response: object retrieved via the callback response of the
social auth provider.
:return: A populated instance of the `SocialLogin` model (unsaved).
"""
# NOTE: Avoid loading models at top due to registry boot...
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.models import SocialAccount, SocialLogin
adapter = get_adapter()
uid = self.extract_uid(response)
if not isinstance(uid, str):
raise ValueError(f"uid must be a string: {repr(uid)}")
if len(uid) > app_settings.UID_MAX_LENGTH:
raise ImproperlyConfigured(
f"SOCIALACCOUNT_UID_MAX_LENGTH too small (<{len(uid)})"
)
if not uid:
raise ValueError("uid must be a non-empty string")
extra_data = self.extract_extra_data(response)
common_fields = self.extract_common_fields(response)
socialaccount = SocialAccount(
extra_data=extra_data,
uid=uid,
provider=self.sub_id,
)
email_addresses = self.extract_email_addresses(response)
email = self.cleanup_email_addresses(
common_fields.get("email"),
email_addresses,
email_verified=common_fields.get("email_verified"),
)
if email:
common_fields["email"] = email
sociallogin = SocialLogin(
account=socialaccount, email_addresses=email_addresses
)
user = sociallogin.user = adapter.new_user(request, sociallogin)
user.set_unusable_password()
adapter.populate_user(request, sociallogin, common_fields)
return sociallogin
def extract_uid(self, data):
"""
Extracts the unique user ID from `data`
"""
raise NotImplementedError(
"The provider must implement the `extract_uid()` method"
)
def extract_extra_data(self, data):
"""
Extracts fields from `data` that will be stored in
`SocialAccount`'s `extra_data` JSONField, such as email address, first
name, last name, and phone number.
:return: any JSON-serializable Python structure.
"""
return data
def extract_common_fields(self, data):
"""
Extracts fields from `data` that will be used to populate the
`User` model in the `SOCIALACCOUNT_ADAPTER`'s `populate_user()`
method.
For example:
{'first_name': 'John'}
:return: dictionary of key-value pairs.
"""
return {}
def cleanup_email_addresses(
self, email: Optional[str], addresses: list, email_verified: bool = False
) -> Optional[str]:
# Avoid loading models before adapters have been registered.
from allauth.account.models import EmailAddress
# Move user.email over to EmailAddress
if email and email.lower() not in [a.email.lower() for a in addresses]:
addresses.insert(
0,
EmailAddress(email=email, verified=bool(email_verified), primary=True),
)
# Force verified emails
adapter = get_adapter()
for address in addresses:
if adapter.is_email_verified(self, address.email):
address.verified = True
# Sort in order of importance (primary, verified...)
addresses.sort(key=lambda a: (a.primary, a.verified, a.email), reverse=True)
if not email and addresses:
email = addresses[0].email
return email
def extract_email_addresses(self, data):
"""
For example:
[EmailAddress(email='john@example.com',
verified=True,
primary=True)]
"""
return []
@classmethod
def get_package(cls):
pkg = getattr(cls, "package", None)
if not pkg:
pkg = cls.__module__.rpartition(".")[0]
return pkg
def stash_redirect_state(
self, request, process, next_url=None, data=None, state_id=None, **kwargs
):
"""
Stashes state, returning a (random) state ID using which the state
can be looked up later. Application specific state is stored separately
from (core) allauth state such as `process` and `**kwargs`.
"""
state = {"process": process, "data": data, **kwargs}
if next_url:
state["next"] = next_url
return statekit.stash_state(request, state, state_id=state_id)
def unstash_redirect_state(self, request, state_id):
state = statekit.unstash_state(request, state_id)
if state is None:
raise PermissionDenied()
return state
@property
def sub_id(self) -> str:
return (
(self.app.provider_id or self.app.provider) if self.uses_apps else self.id
)
class ProviderAccount:
def __init__(self, social_account):
self.account = social_account
def get_profile_url(self):
return None
def get_avatar_url(self):
return None
def get_brand(self):
"""
Returns a dict containing an id and name identifying the
brand. Useful when displaying logos next to accounts in
templates.
For most providers, these are identical to the provider. For
OpenID however, the brand can derived from the OpenID identity
url.
"""
provider = self.account.get_provider()
return dict(id=provider.id, name=provider.name)
def __str__(self):
return self.to_str()
def get_user_data(self) -> Optional[Dict]:
"""Typically, the ``extra_data`` directly contains user related keys.
For some providers, however, they are nested below a different key. In
that case, you can override this method so that the base ``__str__()``
will still be able to find the data.
"""
ret = self.account.extra_data
if not isinstance(ret, dict):
ret = None
return ret
def to_str(self):
"""
Returns string representation of this social account. This is the
unique identifier of the account, such as its username or its email
address. It should be meaningful to human beings, which means a numeric
ID number is rarely the appropriate representation here.
Subclasses are meant to override this method.
Users will see the string representation of their social accounts in
the page rendered by the allauth.socialaccount.views.connections view.
The following code did not use to work in the past due to py2
compatibility:
class GoogleAccount(ProviderAccount):
def __str__(self):
dflt = super(GoogleAccount, self).__str__()
return self.account.extra_data.get('name', dflt)
So we have this method `to_str` that can be overridden in a conventional
fashion, without having to worry about it.
"""
user_data = self.get_user_data()
if user_data:
combi_values = {}
tbl = [
# Prefer username -- it's the most human recognizable & unique.
(
None,
[
"username",
"userName",
"user_name",
"login",
"handle",
],
),
# Second best is email
(None, ["email", "Email", "mail", "email_address"]),
(
None,
[
"name",
"display_name",
"displayName",
"Display_Name",
"nickname",
],
),
# Use the full name
(None, ["full_name", "fullName"]),
# Alternatively, try to assemble a full name ourselves.
(
"first_name",
[
"first_name",
"firstname",
"firstName",
"First_Name",
"given_name",
"givenName",
],
),
(
"last_name",
[
"last_name",
"lastname",
"lastName",
"Last_Name",
"family_name",
"familyName",
"surname",
],
),
]
for store_as, variants in tbl:
for key in variants:
value = user_data.get(key)
if isinstance(value, str):
value = value.strip()
if value and not store_as:
return value
combi_values[store_as] = value
first_name = combi_values.get("first_name") or ""
last_name = combi_values.get("last_name") or ""
if first_name or last_name:
return f"{first_name} {last_name}".strip()
return self.get_brand()["name"]

View File

@@ -0,0 +1,16 @@
from django.shortcuts import render
from allauth.account import app_settings as account_app_settings
from allauth.socialaccount import app_settings
def respond_to_login_on_get(request, provider):
if (not app_settings.LOGIN_ON_GET) and request.method == "GET":
return render(
request,
"socialaccount/login." + account_app_settings.TEMPLATE_EXTENSION,
{
"provider": provider,
"process": request.GET.get("process"),
},
)

View File

@@ -0,0 +1,23 @@
from django.http import Http404
from django.views import View
from allauth import app_settings as allauth_settings
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.base.utils import respond_to_login_on_get
class BaseLoginView(View):
provider_id: str # Set in subclasses
def dispatch(self, request, *args, **kwargs):
if allauth_settings.HEADLESS_ONLY:
raise Http404
provider = self.get_provider()
resp = respond_to_login_on_get(request, provider)
if resp:
return resp
return provider.redirect_from_request(request)
def get_provider(self):
provider = get_adapter().get_provider(self.request, self.provider_id)
return provider

View File

@@ -0,0 +1,42 @@
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.basecamp.views import (
BasecampOAuth2Adapter,
)
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class BasecampAccount(ProviderAccount):
def get_avatar_url(self):
return None
def get_user_data(self):
return self.account.extra_data.get("identity", {})
class BasecampProvider(OAuth2Provider):
id = "basecamp"
name = "Basecamp"
account_class = BasecampAccount
oauth2_adapter_class = BasecampOAuth2Adapter
def get_auth_params_from_request(self, request, action):
data = super().get_auth_params_from_request(request, action)
data["type"] = "web_server"
return data
def extract_uid(self, data):
data = data["identity"]
return str(data["id"])
def extract_common_fields(self, data):
data = data["identity"]
return dict(
email=data.get("email_address"),
username=data.get("email_address"),
first_name=data.get("first_name"),
last_name=data.get("last_name"),
name="%s %s" % (data.get("first_name"), data.get("last_name")),
)
provider_classes = [BasecampProvider]

View File

@@ -0,0 +1,46 @@
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from .provider import BasecampProvider
class BasecampTests(OAuth2TestsMixin, TestCase):
provider_id = BasecampProvider.id
def get_mocked_response(self):
return MockedResponse(
200,
"""
{
"expires_at": "2012-03-22T16:56:48-05:00",
"identity": {
"id": 9999999,
"first_name": "Jason Fried",
"last_name": "Jason Fried",
"email_address": "jason@example.com"
},
"accounts": [
{
"product": "bcx",
"id": 88888888,
"name": "Wayne Enterprises, Ltd.",
"href": "https://basecamp.com/88888888/api/v1"
},
{
"product": "bcx",
"id": 77777777,
"name": "Veidt, Inc",
"href": "https://basecamp.com/77777777/api/v1"
},
{
"product": "campfire",
"id": 44444444,
"name": "Acme Shipping Co.",
"href": "https://acme4444444.campfirenow.com"
}
]
}""",
)
def get_expected_to_str(self):
return "jason@example.com"

View File

@@ -0,0 +1,6 @@
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
from .provider import BasecampProvider
urlpatterns = default_urlpatterns(BasecampProvider)

View File

@@ -0,0 +1,27 @@
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
class BasecampOAuth2Adapter(OAuth2Adapter):
provider_id = "basecamp"
access_token_url = (
"https://launchpad.37signals.com/authorization/token?type=web_server" # noqa
)
authorize_url = "https://launchpad.37signals.com/authorization/new"
profile_url = "https://launchpad.37signals.com/authorization.json"
def complete_login(self, request, app, token, **kwargs):
headers = {"Authorization": "Bearer {0}".format(token.token)}
resp = (
get_adapter().get_requests_session().get(self.profile_url, headers=headers)
)
extra_data = resp.json()
return self.get_provider().sociallogin_from_response(request, extra_data)
oauth2_login = OAuth2LoginView.adapter_view(BasecampOAuth2Adapter)
oauth2_callback = OAuth2CallbackView.adapter_view(BasecampOAuth2Adapter)

View File

@@ -0,0 +1,35 @@
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.battlenet.views import (
BattleNetOAuth2Adapter,
)
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class BattleNetAccount(ProviderAccount):
def to_str(self):
battletag = self.account.extra_data.get("battletag")
return battletag or super(BattleNetAccount, self).to_str()
class BattleNetProvider(OAuth2Provider):
id = "battlenet"
name = "Battle.net"
account_class = BattleNetAccount
oauth2_adapter_class = BattleNetOAuth2Adapter
def extract_uid(self, data):
uid = str(data["id"])
if data.get("region") == "cn":
# China is on a different account system. UIDs can clash with US.
return uid + "-cn"
return uid
def extract_common_fields(self, data):
return {"username": data.get("battletag")}
def get_default_scope(self):
# Optional scopes: "sc2.profile", "wow.profile"
return []
provider_classes = [BattleNetProvider]

View File

@@ -0,0 +1,68 @@
import json
from allauth.socialaccount.models import SocialAccount
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from .provider import BattleNetProvider
from .views import _check_errors
class BattleNetTests(OAuth2TestsMixin, TestCase):
provider_id = BattleNetProvider.id
_uid = 123456789
_battletag = "LuckyDragon#1953"
def get_mocked_response(self):
data = {"battletag": self._battletag, "id": self._uid}
return MockedResponse(200, json.dumps(data))
def get_expected_to_str(self):
return self._battletag
def test_valid_response_no_battletag(self):
data = {"id": 12345}
response = MockedResponse(200, json.dumps(data))
self.assertEqual(_check_errors(response), data)
def test_invalid_data(self):
response = MockedResponse(200, json.dumps({}))
with self.assertRaises(OAuth2Error):
# No id, raises
_check_errors(response)
def test_profile_invalid_response(self):
data = {"code": 403, "type": "Forbidden", "detail": "Account Inactive"}
response = MockedResponse(401, json.dumps(data))
with self.assertRaises(OAuth2Error):
# no id, 4xx code, raises
_check_errors(response)
def test_error_response(self):
body = json.dumps({"error": "invalid_token"})
response = MockedResponse(400, body)
with self.assertRaises(OAuth2Error):
# no id, 4xx code, raises
_check_errors(response)
def test_service_not_found(self):
response = MockedResponse(596, "<h1>596 Service Not Found</h1>")
with self.assertRaises(OAuth2Error):
# bad json, 5xx code, raises
_check_errors(response)
def test_invalid_response(self):
response = MockedResponse(200, "invalid json data")
with self.assertRaises(OAuth2Error):
# bad json, raises
_check_errors(response)
def test_extra_data(self):
self.login(self.get_mocked_response())
account = SocialAccount.objects.get(uid=str(self._uid))
self.assertEqual(account.extra_data["battletag"], self._battletag)
self.assertEqual(account.extra_data["id"], self._uid)
self.assertEqual(account.extra_data["region"], "us")

View File

@@ -0,0 +1,6 @@
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
from .provider import BattleNetProvider
urlpatterns = default_urlpatterns(BattleNetProvider)

View File

@@ -0,0 +1,4 @@
from django.core.validators import RegexValidator
BattletagUsernameValidator = RegexValidator(r"^[\w.]+#\d+$")

View File

@@ -0,0 +1,157 @@
"""
OAuth2 Adapter for Battle.net
Resources:
* Battle.net OAuth2 documentation:
https://dev.battle.net/docs/read/oauth
* Battle.net API documentation:
https://dev.battle.net/io-docs
* Original announcement:
https://us.battle.net/en/forum/topic/13979297799
* The Battle.net API forum:
https://us.battle.net/en/forum/15051532/
"""
from django.conf import settings
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
class Region:
APAC = "apac"
CN = "cn"
EU = "eu"
KR = "kr"
SEA = "sea"
TW = "tw"
US = "us"
def _check_errors(response):
try:
data = response.json()
except ValueError: # JSONDecodeError on py3
raise OAuth2Error("Invalid JSON from Battle.net API: %r" % (response.text))
if response.status_code >= 400 or "error" in data:
# For errors, we expect the following format:
# {"error": "error_name", "error_description": "Oops!"}
# For example, if the token is not valid, we will get:
# {
# "error": "invalid_token",
# "error_description": "Invalid access token: abcdef123456"
# }
# For the profile API, this may also look like the following:
# {"code": 403, "type": "Forbidden", "detail": "Account Inactive"}
error = data.get("error", "") or data.get("type", "")
desc = data.get("error_description", "") or data.get("detail", "")
raise OAuth2Error("Battle.net error: %s (%s)" % (error, desc))
# The expected output from the API follows this format:
# {"id": 12345, "battletag": "Example#12345"}
# The battletag is optional.
if "id" not in data:
# If the id is not present, the output is not usable (no UID)
raise OAuth2Error("Invalid data from Battle.net API: %r" % (data))
return data
class BattleNetOAuth2Adapter(OAuth2Adapter):
"""
OAuth2 adapter for Battle.net
https://dev.battle.net/docs/read/oauth
Region is set to us by default, but can be overridden with the
`region` GET parameter when performing a login.
Can be any of eu, us, kr, sea, tw or cn
"""
provider_id = "battlenet"
valid_regions = (
Region.APAC,
Region.CN,
Region.EU,
Region.KR,
Region.SEA,
Region.TW,
Region.US,
)
@property
def battlenet_region(self):
# Check by URI query parameter first.
region = self.request.GET.get("region", "").lower()
if region == Region.SEA:
# South-East Asia uses the same region as US everywhere
return Region.US
if region in self.valid_regions:
return region
# Second, check the provider settings.
region = (
getattr(settings, "SOCIALACCOUNT_PROVIDERS", {})
.get("battlenet", {})
.get("REGION", "us")
)
if region in self.valid_regions:
return region
return Region.US
@property
def battlenet_base_url(self):
region = self.battlenet_region
if region == Region.CN:
return "https://oauth.battlenet.com.cn"
return "https://oauth.battle.net"
@property
def access_token_url(self):
return self.battlenet_base_url + "/token"
@property
def authorize_url(self):
return self.battlenet_base_url + "/authorize"
@property
def profile_url(self):
return self.battlenet_base_url + "/userinfo"
def complete_login(self, request, app, token, **kwargs):
response = (
get_adapter()
.get_requests_session()
.get(
self.profile_url,
headers={"authorization": "Bearer %s" % (token.token)},
)
)
data = _check_errors(response)
# Add the region to the data so that we can have it in `extra_data`.
data["region"] = self.battlenet_region
return self.get_provider().sociallogin_from_response(request, data)
def get_callback_url(self, request, app):
r = super(BattleNetOAuth2Adapter, self).get_callback_url(request, app)
region = request.GET.get("region", "").lower()
# Pass the region down to the callback URL if we specified it
if region and region in self.valid_regions:
r += "?region=%s" % (region)
return r
oauth2_login = OAuth2LoginView.adapter_view(BattleNetOAuth2Adapter)
oauth2_callback = OAuth2CallbackView.adapter_view(BattleNetOAuth2Adapter)

Some files were not shown because too many files have changed in this diff Show More