mirror of
https://github.com/pacnpal/thrillwiki_django_no_react.git
synced 2025-12-23 16:31:08 -05:00
first commit
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,378 @@
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
from django.core.exceptions import (
|
||||
ImproperlyConfigured,
|
||||
MultipleObjectsReturned,
|
||||
)
|
||||
from django.db.models import Q
|
||||
from django.urls import reverse
|
||||
from django.utils.crypto import get_random_string
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.utils import user_email, user_field, user_username
|
||||
from allauth.core.internal.adapter import BaseAdapter
|
||||
from allauth.utils import (
|
||||
deserialize_instance,
|
||||
import_attribute,
|
||||
serialize_instance,
|
||||
valid_email_or_none,
|
||||
)
|
||||
|
||||
from . import app_settings
|
||||
|
||||
|
||||
class DefaultSocialAccountAdapter(BaseAdapter):
|
||||
"""The adapter class allows you to override various functionality of the
|
||||
``allauth.socialaccount`` app. To do so, point ``settings.SOCIALACCOUNT_ADAPTER`` to
|
||||
your own class that derives from ``DefaultSocialAccountAdapter`` and override the
|
||||
behavior by altering the implementation of the methods according to your own
|
||||
needs.
|
||||
"""
|
||||
|
||||
error_messages = {
|
||||
"email_taken": _(
|
||||
"An account already exists with this email address."
|
||||
" Please sign in to that account first, then connect"
|
||||
" your %s account."
|
||||
),
|
||||
"invalid_token": _("Invalid token."),
|
||||
"no_password": _("Your account has no password set up."),
|
||||
"no_verified_email": _("Your account has no verified email address."),
|
||||
"disconnect_last": _(
|
||||
"You cannot disconnect your last remaining third-party account."
|
||||
),
|
||||
"connected_other": _(
|
||||
"The third-party account is already connected to a different account."
|
||||
),
|
||||
}
|
||||
|
||||
def pre_social_login(self, request, sociallogin):
|
||||
"""
|
||||
Invoked just after a user successfully authenticates via a
|
||||
social provider, but before the login is actually processed
|
||||
(and before the pre_social_login signal is emitted).
|
||||
|
||||
You can use this hook to intervene, e.g. abort the login by
|
||||
raising an ImmediateHttpResponse
|
||||
|
||||
Why both an adapter hook and the signal? Intervening in
|
||||
e.g. the flow from within a signal handler is bad -- multiple
|
||||
handlers may be active and are executed in undetermined order.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_authentication_error(
|
||||
self,
|
||||
request,
|
||||
provider,
|
||||
error=None,
|
||||
exception=None,
|
||||
extra_context=None,
|
||||
):
|
||||
"""
|
||||
Invoked when there is an error in the authentication cycle. In this
|
||||
case, pre_social_login will not be reached.
|
||||
|
||||
You can use this hook to intervene, e.g. redirect to an
|
||||
educational flow by raising an ImmediateHttpResponse.
|
||||
"""
|
||||
if hasattr(self, "authentication_error"):
|
||||
warnings.warn(
|
||||
"adapter.authentication_error() is deprecated, use adapter.on_authentication_error()"
|
||||
)
|
||||
|
||||
self.authentication_error(
|
||||
request,
|
||||
provider.id,
|
||||
error=error,
|
||||
exception=exception,
|
||||
extra_context=extra_context,
|
||||
)
|
||||
|
||||
def new_user(self, request, sociallogin):
|
||||
"""
|
||||
Instantiates a new User instance.
|
||||
"""
|
||||
return get_account_adapter().new_user(request)
|
||||
|
||||
def save_user(self, request, sociallogin, form=None):
|
||||
"""
|
||||
Saves a newly signed up social login. In case of auto-signup,
|
||||
the signup form is not available.
|
||||
"""
|
||||
u = sociallogin.user
|
||||
u.set_unusable_password()
|
||||
if form:
|
||||
get_account_adapter().save_user(request, u, form)
|
||||
else:
|
||||
get_account_adapter().populate_username(request, u)
|
||||
sociallogin.save(request)
|
||||
return u
|
||||
|
||||
def populate_user(self, request, sociallogin, data):
|
||||
"""
|
||||
Hook that can be used to further populate the user instance.
|
||||
|
||||
For convenience, we populate several common fields.
|
||||
|
||||
Note that the user instance being populated represents a
|
||||
suggested User instance that represents the social user that is
|
||||
in the process of being logged in.
|
||||
|
||||
The User instance need not be completely valid and conflict
|
||||
free. For example, verifying whether or not the username
|
||||
already exists, is not a responsibility.
|
||||
"""
|
||||
username = data.get("username")
|
||||
first_name = data.get("first_name")
|
||||
last_name = data.get("last_name")
|
||||
email = data.get("email")
|
||||
name = data.get("name")
|
||||
user = sociallogin.user
|
||||
user_username(user, username or "")
|
||||
user_email(user, valid_email_or_none(email) or "")
|
||||
name_parts = (name or "").partition(" ")
|
||||
user_field(user, "first_name", first_name or name_parts[0])
|
||||
user_field(user, "last_name", last_name or name_parts[2])
|
||||
return user
|
||||
|
||||
def get_connect_redirect_url(self, request, socialaccount):
|
||||
"""
|
||||
Returns the default URL to redirect to after successfully
|
||||
connecting a social account.
|
||||
"""
|
||||
url = reverse("socialaccount_connections")
|
||||
return url
|
||||
|
||||
def validate_disconnect(self, account, accounts) -> None:
|
||||
"""
|
||||
Validate whether or not the socialaccount account can be
|
||||
safely disconnected.
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_auto_signup_allowed(self, request, sociallogin):
|
||||
# If email is specified, check for duplicate and if so, no auto signup.
|
||||
auto_signup = app_settings.AUTO_SIGNUP
|
||||
return auto_signup
|
||||
|
||||
def is_open_for_signup(self, request, sociallogin):
|
||||
"""
|
||||
Checks whether or not the site is open for signups.
|
||||
|
||||
Next to simply returning True/False you can also intervene the
|
||||
regular flow by raising an ImmediateHttpResponse
|
||||
"""
|
||||
return get_account_adapter(request).is_open_for_signup(request)
|
||||
|
||||
def get_signup_form_initial_data(self, sociallogin):
|
||||
user = sociallogin.user
|
||||
initial = {
|
||||
"email": user_email(user) or "",
|
||||
"username": user_username(user) or "",
|
||||
"first_name": user_field(user, "first_name") or "",
|
||||
"last_name": user_field(user, "last_name") or "",
|
||||
}
|
||||
return initial
|
||||
|
||||
def deserialize_instance(self, model, data):
|
||||
return deserialize_instance(model, data)
|
||||
|
||||
def serialize_instance(self, instance):
|
||||
return serialize_instance(instance)
|
||||
|
||||
def list_providers(self, request):
|
||||
from allauth.socialaccount.providers import registry
|
||||
|
||||
ret = []
|
||||
provider_classes = registry.get_class_list()
|
||||
apps = self.list_apps(request)
|
||||
apps_map = {}
|
||||
for app in apps:
|
||||
apps_map.setdefault(app.provider, []).append(app)
|
||||
for provider_class in provider_classes:
|
||||
provider_apps = apps_map.get(provider_class.id, [])
|
||||
if not provider_apps:
|
||||
if provider_class.uses_apps:
|
||||
continue
|
||||
provider_apps = [None]
|
||||
for app in provider_apps:
|
||||
provider = provider_class(request=request, app=app)
|
||||
ret.append(provider)
|
||||
return ret
|
||||
|
||||
def get_provider(self, request, provider, client_id=None):
|
||||
"""Looks up a `provider`, supporting subproviders by looking up by
|
||||
`provider_id`.
|
||||
"""
|
||||
from allauth.socialaccount.providers import registry
|
||||
|
||||
provider_class = registry.get_class(provider)
|
||||
if provider_class is None or provider_class.uses_apps:
|
||||
app = self.get_app(request, provider=provider, client_id=client_id)
|
||||
if not provider_class:
|
||||
# In this case, the `provider` argument passed was a
|
||||
# `provider_id`.
|
||||
provider_class = registry.get_class(app.provider)
|
||||
if not provider_class:
|
||||
raise ImproperlyConfigured(f"unknown provider: {app.provider}")
|
||||
return provider_class(request, app=app)
|
||||
elif provider_class:
|
||||
assert not provider_class.uses_apps
|
||||
return provider_class(request, app=None)
|
||||
else:
|
||||
raise ImproperlyConfigured(f"unknown provider: {app.provider}")
|
||||
|
||||
def list_apps(self, request, provider=None, client_id=None):
|
||||
"""SocialApp's can be setup in the database, or, via
|
||||
`settings.SOCIALACCOUNT_PROVIDERS`. This methods returns a uniform list
|
||||
of all known apps matching the specified criteria, and blends both
|
||||
(db/settings) sources of data.
|
||||
"""
|
||||
# NOTE: Avoid loading models at top due to registry boot...
|
||||
from allauth.socialaccount.models import SocialApp
|
||||
|
||||
# Map provider to the list of apps.
|
||||
provider_to_apps = {}
|
||||
|
||||
# First, populate it with the DB backed apps.
|
||||
if request:
|
||||
db_apps = SocialApp.objects.on_site(request)
|
||||
else:
|
||||
db_apps = SocialApp.objects.all()
|
||||
if provider:
|
||||
db_apps = db_apps.filter(Q(provider=provider) | Q(provider_id=provider))
|
||||
if client_id:
|
||||
db_apps = db_apps.filter(client_id=client_id)
|
||||
for app in db_apps:
|
||||
apps = provider_to_apps.setdefault(app.provider, [])
|
||||
apps.append(app)
|
||||
|
||||
# Then, extend it with the settings backed apps.
|
||||
for p, pcfg in app_settings.PROVIDERS.items():
|
||||
app_configs = pcfg.get("APPS")
|
||||
if app_configs is None:
|
||||
app_config = pcfg.get("APP")
|
||||
if app_config is None:
|
||||
continue
|
||||
app_configs = [app_config]
|
||||
|
||||
apps = provider_to_apps.setdefault(p, [])
|
||||
for config in app_configs:
|
||||
app = SocialApp(provider=p)
|
||||
for field in [
|
||||
"name",
|
||||
"provider_id",
|
||||
"client_id",
|
||||
"secret",
|
||||
"key",
|
||||
"settings",
|
||||
]:
|
||||
if field in config:
|
||||
setattr(app, field, config[field])
|
||||
if "certificate_key" in config:
|
||||
warnings.warn("'certificate_key' should be moved into app.settings")
|
||||
app.settings["certificate_key"] = config["certificate_key"]
|
||||
if client_id and app.client_id != client_id:
|
||||
continue
|
||||
if (
|
||||
provider
|
||||
and app.provider_id != provider
|
||||
and app.provider != provider
|
||||
):
|
||||
continue
|
||||
apps.append(app)
|
||||
|
||||
# Flatten the list of apps.
|
||||
apps = []
|
||||
for provider_apps in provider_to_apps.values():
|
||||
apps.extend(provider_apps)
|
||||
return apps
|
||||
|
||||
def get_app(self, request, provider, client_id=None):
|
||||
from allauth.socialaccount.models import SocialApp
|
||||
|
||||
apps = self.list_apps(request, provider=provider, client_id=client_id)
|
||||
if len(apps) > 1:
|
||||
visible_apps = [app for app in apps if not app.settings.get("hidden")]
|
||||
if len(visible_apps) != 1:
|
||||
raise MultipleObjectsReturned
|
||||
apps = visible_apps
|
||||
elif len(apps) == 0:
|
||||
raise SocialApp.DoesNotExist()
|
||||
return apps[0]
|
||||
|
||||
def send_notification_mail(self, *args, **kwargs):
|
||||
return get_account_adapter().send_notification_mail(*args, **kwargs)
|
||||
|
||||
def get_requests_session(self):
|
||||
import requests
|
||||
|
||||
session = requests.Session()
|
||||
session.request = functools.partial(
|
||||
session.request, timeout=app_settings.REQUESTS_TIMEOUT
|
||||
)
|
||||
return session
|
||||
|
||||
def is_email_verified(self, provider, email):
|
||||
"""
|
||||
Returns ``True`` iff the given email encountered during a social
|
||||
login for the given provider is to be assumed verified.
|
||||
|
||||
This can be configured with a ``"verified_email"`` key in the provider
|
||||
app settings, or a ``"VERIFIED_EMAIL"`` in the global provider settings
|
||||
(``SOCIALACCOUNT_PROVIDERS``). Both can be set to ``False`` or
|
||||
``True``, or, a list of domains to match email addresses against.
|
||||
"""
|
||||
verified_email = None
|
||||
if provider.app:
|
||||
verified_email = provider.app.settings.get("verified_email")
|
||||
if verified_email is None:
|
||||
settings = provider.get_settings()
|
||||
verified_email = settings.get("VERIFIED_EMAIL", False)
|
||||
if isinstance(verified_email, bool):
|
||||
pass
|
||||
elif isinstance(verified_email, list):
|
||||
email_domain = email.partition("@")[2].lower()
|
||||
verified_domains = [d.lower() for d in verified_email]
|
||||
verified_email = email_domain in verified_domains
|
||||
else:
|
||||
raise ImproperlyConfigured("verified_email wrongly configured")
|
||||
return verified_email
|
||||
|
||||
def can_authenticate_by_email(self, login, email):
|
||||
"""
|
||||
Returns ``True`` iff authentication by email is active for this login/email.
|
||||
|
||||
This can be configured with a ``"email_authentication"`` key in the provider
|
||||
app settings, or a ``"VERIFIED_EMAIL"`` in the global provider settings
|
||||
(``SOCIALACCOUNT_PROVIDERS``).
|
||||
"""
|
||||
ret = None
|
||||
provider = login.account.get_provider()
|
||||
if provider.app:
|
||||
ret = provider.app.settings.get("email_authentication")
|
||||
if ret is None:
|
||||
ret = app_settings.EMAIL_AUTHENTICATION or provider.get_settings().get(
|
||||
"EMAIL_AUTHENTICATION", False
|
||||
)
|
||||
return ret
|
||||
|
||||
def generate_state_param(self, state: dict) -> str:
|
||||
"""
|
||||
To preserve certain state before the handshake with the provider
|
||||
takes place, and be able to verify/use that state later on, a `state`
|
||||
parameter is typically passed to the provider. By default, a random
|
||||
string sufficies as the state parameter value is actually just a
|
||||
reference/pointer to the actual state. You can use this adapter method
|
||||
to alter the generation of the `state` parameter.
|
||||
"""
|
||||
from allauth.socialaccount.internal.statekit import STATE_ID_LENGTH
|
||||
|
||||
return get_random_string(STATE_ID_LENGTH)
|
||||
|
||||
|
||||
def get_adapter(request=None):
|
||||
return import_attribute(app_settings.ADAPTER)(request)
|
||||
@@ -0,0 +1,69 @@
|
||||
from typing import List
|
||||
|
||||
from django import forms
|
||||
from django.contrib import admin
|
||||
|
||||
from allauth import app_settings
|
||||
from allauth.account.adapter import get_adapter
|
||||
from allauth.socialaccount import providers
|
||||
from allauth.socialaccount.models import SocialAccount, SocialApp, SocialToken
|
||||
|
||||
|
||||
class SocialAppForm(forms.ModelForm):
|
||||
class Meta:
|
||||
model = SocialApp
|
||||
exclude: List[str] = []
|
||||
widgets = {
|
||||
"client_id": forms.TextInput(attrs={"size": "100"}),
|
||||
"key": forms.TextInput(attrs={"size": "100"}),
|
||||
"secret": forms.TextInput(attrs={"size": "100"}),
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields["provider"] = forms.ChoiceField(
|
||||
choices=providers.registry.as_choices()
|
||||
)
|
||||
|
||||
|
||||
class SocialAppAdmin(admin.ModelAdmin):
|
||||
form = SocialAppForm
|
||||
list_display = (
|
||||
"name",
|
||||
"provider",
|
||||
)
|
||||
filter_horizontal = ("sites",) if app_settings.SITES_ENABLED else ()
|
||||
|
||||
|
||||
class SocialAccountAdmin(admin.ModelAdmin):
|
||||
search_fields = []
|
||||
raw_id_fields = ("user",)
|
||||
list_display = ("user", "uid", "provider")
|
||||
list_filter = ("provider",)
|
||||
|
||||
def get_search_fields(self, request):
|
||||
base_fields = get_adapter().get_user_search_fields()
|
||||
return list(map(lambda a: "user__" + a, base_fields))
|
||||
|
||||
|
||||
class SocialTokenAdmin(admin.ModelAdmin):
|
||||
raw_id_fields = (
|
||||
"app",
|
||||
"account",
|
||||
)
|
||||
list_display = ("app", "account", "truncated_token", "expires_at")
|
||||
list_filter = ("app", "app__provider", "expires_at")
|
||||
|
||||
def truncated_token(self, token):
|
||||
max_chars = 40
|
||||
ret = token.token
|
||||
if len(ret) > max_chars:
|
||||
ret = ret[0:max_chars] + "...(truncated)"
|
||||
return ret
|
||||
|
||||
truncated_token.short_description = "Token" # type: ignore[attr-defined]
|
||||
|
||||
|
||||
admin.site.register(SocialApp, SocialAppAdmin)
|
||||
admin.site.register(SocialToken, SocialTokenAdmin)
|
||||
admin.site.register(SocialAccount, SocialAccountAdmin)
|
||||
@@ -0,0 +1,155 @@
|
||||
class AppSettings:
|
||||
def __init__(self, prefix):
|
||||
self.prefix = prefix
|
||||
|
||||
def _setting(self, name, dflt):
|
||||
from allauth.utils import get_setting
|
||||
|
||||
return get_setting(self.prefix + name, dflt)
|
||||
|
||||
@property
|
||||
def QUERY_EMAIL(self):
|
||||
"""
|
||||
Request email address from 3rd party account provider?
|
||||
E.g. using OpenID AX
|
||||
"""
|
||||
from allauth.account import app_settings as account_settings
|
||||
|
||||
return self._setting("QUERY_EMAIL", account_settings.EMAIL_REQUIRED)
|
||||
|
||||
@property
|
||||
def AUTO_SIGNUP(self):
|
||||
"""
|
||||
Attempt to bypass the signup form by using fields (e.g. username,
|
||||
email) retrieved from the social account provider. If a conflict
|
||||
arises due to a duplicate email signup form will still kick in.
|
||||
"""
|
||||
return self._setting("AUTO_SIGNUP", True)
|
||||
|
||||
@property
|
||||
def PROVIDERS(self):
|
||||
"""
|
||||
Provider specific settings
|
||||
"""
|
||||
ret = self._setting("PROVIDERS", {})
|
||||
oidc = ret.get("openid_connect")
|
||||
if oidc:
|
||||
ret["openid_connect"] = self._migrate_oidc(oidc)
|
||||
return ret
|
||||
|
||||
def _migrate_oidc(self, oidc):
|
||||
servers = oidc.get("SERVERS")
|
||||
if servers is None:
|
||||
return oidc
|
||||
ret = {}
|
||||
apps = []
|
||||
for server in servers:
|
||||
app = dict(**server["APP"])
|
||||
app_settings = {}
|
||||
if "token_auth_method" in server:
|
||||
app_settings["token_auth_method"] = server["token_auth_method"]
|
||||
app_settings["server_url"] = server["server_url"]
|
||||
app.update(
|
||||
{
|
||||
"name": server.get("name", ""),
|
||||
"provider_id": server["id"],
|
||||
"settings": app_settings,
|
||||
}
|
||||
)
|
||||
assert app["provider_id"]
|
||||
apps.append(app)
|
||||
ret["APPS"] = apps
|
||||
return ret
|
||||
|
||||
@property
|
||||
def EMAIL_REQUIRED(self):
|
||||
"""
|
||||
The user is required to hand over an email address when signing up
|
||||
"""
|
||||
from allauth.account import app_settings as account_settings
|
||||
|
||||
return self._setting("EMAIL_REQUIRED", account_settings.EMAIL_REQUIRED)
|
||||
|
||||
@property
|
||||
def EMAIL_VERIFICATION(self):
|
||||
"""
|
||||
See email verification method. When `None`, the default
|
||||
`allauth.account` logic kicks in.
|
||||
"""
|
||||
return self._setting("EMAIL_VERIFICATION", None)
|
||||
|
||||
@property
|
||||
def EMAIL_AUTHENTICATION(self):
|
||||
"""Consider a scenario where a social login occurs, and the social
|
||||
account comes with a verified email address (verified by the account
|
||||
provider), but that email address is already taken by a local user
|
||||
account. Additionally, assume that the local user account does not have
|
||||
any social account connected. Now, if the provider can be fully trusted,
|
||||
you can argue that we should treat this scenario as a login to the
|
||||
existing local user account even if the local account does not already
|
||||
have the social account connected, because -- according to the provider
|
||||
-- the user logging in has ownership of the email address. This is how
|
||||
this scenario is handled when `EMAIL_AUTHENTICATION` is set to
|
||||
`True`. As this implies that an untrustworthy provider can login to any
|
||||
local account by fabricating social account data, this setting defaults
|
||||
to `False`. Only set it to `True` if you are using providers that can be
|
||||
fully trusted.
|
||||
"""
|
||||
return self._setting("EMAIL_AUTHENTICATION", False)
|
||||
|
||||
@property
|
||||
def EMAIL_AUTHENTICATION_AUTO_CONNECT(self):
|
||||
"""In case email authentication is applied, this setting controls
|
||||
whether or not the social account is automatically connected to the
|
||||
local account. In case of ``False`` (the default) the local account
|
||||
remains unchanged during the login. In case of ``True``, the social
|
||||
account for which the email matched, is automatically added to the list
|
||||
of social accounts connected to the local account. As a result, even if
|
||||
the user were to change the email address afterwards, social login
|
||||
would still be possible when using ``True``, but not in case of
|
||||
``False``.
|
||||
"""
|
||||
return self._setting("EMAIL_AUTHENTICATION_AUTO_CONNECT", False)
|
||||
|
||||
@property
|
||||
def ADAPTER(self):
|
||||
return self._setting(
|
||||
"ADAPTER",
|
||||
"allauth.socialaccount.adapter.DefaultSocialAccountAdapter",
|
||||
)
|
||||
|
||||
@property
|
||||
def FORMS(self):
|
||||
return self._setting("FORMS", {})
|
||||
|
||||
@property
|
||||
def LOGIN_ON_GET(self):
|
||||
return self._setting("LOGIN_ON_GET", False)
|
||||
|
||||
@property
|
||||
def STORE_TOKENS(self):
|
||||
return self._setting("STORE_TOKENS", False)
|
||||
|
||||
@property
|
||||
def UID_MAX_LENGTH(self):
|
||||
return 191
|
||||
|
||||
@property
|
||||
def SOCIALACCOUNT_STR(self):
|
||||
return self._setting("SOCIALACCOUNT_STR", None)
|
||||
|
||||
@property
|
||||
def REQUESTS_TIMEOUT(self):
|
||||
return self._setting("REQUESTS_TIMEOUT", 5)
|
||||
|
||||
@property
|
||||
def OPENID_CONNECT_URL_PREFIX(self):
|
||||
return self._setting("OPENID_CONNECT_URL_PREFIX", "oidc")
|
||||
|
||||
|
||||
_app_settings = AppSettings("SOCIALACCOUNT_")
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
# See https://peps.python.org/pep-0562/
|
||||
return getattr(_app_settings, name)
|
||||
@@ -0,0 +1,15 @@
|
||||
from django.apps import AppConfig
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from allauth import app_settings
|
||||
|
||||
|
||||
class SocialAccountConfig(AppConfig):
|
||||
name = "allauth.socialaccount"
|
||||
verbose_name = _("Social Accounts")
|
||||
default_auto_field = app_settings.DEFAULT_AUTO_FIELD or "django.db.models.AutoField"
|
||||
|
||||
def ready(self):
|
||||
from allauth.socialaccount.providers import registry
|
||||
|
||||
registry.load()
|
||||
@@ -0,0 +1,66 @@
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.socialaccount.models import (
|
||||
SocialAccount,
|
||||
SocialLogin,
|
||||
SocialToken,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sociallogin_factory(user_factory):
|
||||
def factory(
|
||||
email=None,
|
||||
username=None,
|
||||
with_email=True,
|
||||
provider="unittest-server",
|
||||
uid="123",
|
||||
email_verified=True,
|
||||
with_token=False,
|
||||
):
|
||||
user = user_factory(
|
||||
username=username, email=email, commit=False, with_email=with_email
|
||||
)
|
||||
account = SocialAccount(provider=provider, uid=uid)
|
||||
sociallogin = SocialLogin(user=user, account=account)
|
||||
if with_email:
|
||||
sociallogin.email_addresses = [
|
||||
EmailAddress(email=user.email, verified=email_verified, primary=True)
|
||||
]
|
||||
if with_token:
|
||||
sociallogin.token = SocialToken(token="123", token_secret="456")
|
||||
return sociallogin
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jwt_decode_bypass():
|
||||
@contextmanager
|
||||
def f(jwt_data):
|
||||
with patch("allauth.socialaccount.internal.jwtkit.verify_and_decode") as m:
|
||||
data = {
|
||||
"iss": "https://accounts.google.com",
|
||||
"aud": "client_id",
|
||||
"sub": "123sub",
|
||||
"hd": "example.com",
|
||||
"email": "raymond@example.com",
|
||||
"email_verified": True,
|
||||
"at_hash": "HK6E_P6Dh8Y93mRNtsDB1Q",
|
||||
"name": "Raymond Penners",
|
||||
"picture": "https://lh5.googleusercontent.com/photo.jpg",
|
||||
"given_name": "Raymond",
|
||||
"family_name": "Penners",
|
||||
"locale": "en",
|
||||
"iat": 123,
|
||||
"exp": 456,
|
||||
}
|
||||
data.update(jwt_data)
|
||||
m.return_value = data
|
||||
yield
|
||||
|
||||
return f
|
||||
@@ -0,0 +1,62 @@
|
||||
from django import forms
|
||||
|
||||
from allauth.account.forms import BaseSignupForm
|
||||
from allauth.socialaccount.internal import flows
|
||||
|
||||
from . import app_settings
|
||||
from .adapter import get_adapter
|
||||
from .models import SocialAccount
|
||||
|
||||
|
||||
class SignupForm(BaseSignupForm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.sociallogin = kwargs.pop("sociallogin")
|
||||
initial = get_adapter().get_signup_form_initial_data(self.sociallogin)
|
||||
kwargs.update(
|
||||
{
|
||||
"initial": initial,
|
||||
"email_required": kwargs.get(
|
||||
"email_required", app_settings.EMAIL_REQUIRED
|
||||
),
|
||||
}
|
||||
)
|
||||
super(SignupForm, self).__init__(*args, **kwargs)
|
||||
|
||||
def save(self, request):
|
||||
adapter = get_adapter()
|
||||
user = adapter.save_user(request, self.sociallogin, form=self)
|
||||
self.custom_signup(request, user)
|
||||
return user
|
||||
|
||||
def validate_unique_email(self, value):
|
||||
try:
|
||||
return super(SignupForm, self).validate_unique_email(value)
|
||||
except forms.ValidationError:
|
||||
raise get_adapter().validation_error(
|
||||
"email_taken", self.sociallogin.account.get_provider().name
|
||||
)
|
||||
|
||||
|
||||
class DisconnectForm(forms.Form):
|
||||
account = forms.ModelChoiceField(
|
||||
queryset=SocialAccount.objects.none(),
|
||||
widget=forms.RadioSelect,
|
||||
required=True,
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.request = kwargs.pop("request")
|
||||
self.accounts = SocialAccount.objects.filter(user=self.request.user)
|
||||
super(DisconnectForm, self).__init__(*args, **kwargs)
|
||||
self.fields["account"].queryset = self.accounts
|
||||
|
||||
def clean(self):
|
||||
cleaned_data = super(DisconnectForm, self).clean()
|
||||
account = cleaned_data.get("account")
|
||||
if account:
|
||||
flows.connect.validate_disconnect(self.request, account)
|
||||
return cleaned_data
|
||||
|
||||
def save(self):
|
||||
account = self.cleaned_data["account"]
|
||||
flows.connect.disconnect(self.request, account)
|
||||
@@ -0,0 +1,74 @@
|
||||
from django.http import HttpResponseRedirect
|
||||
from django.shortcuts import render
|
||||
from django.urls import reverse
|
||||
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.account import app_settings as account_settings
|
||||
from allauth.account.utils import user_display
|
||||
from allauth.core.exceptions import ImmediateHttpResponse
|
||||
from allauth.socialaccount import app_settings
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.internal import flows
|
||||
from allauth.socialaccount.providers.base import AuthError
|
||||
|
||||
|
||||
def render_authentication_error(
|
||||
request,
|
||||
provider,
|
||||
error=AuthError.UNKNOWN,
|
||||
exception=None,
|
||||
extra_context=None,
|
||||
):
|
||||
try:
|
||||
if allauth_settings.HEADLESS_ENABLED:
|
||||
from allauth.headless.socialaccount import internal
|
||||
|
||||
internal.on_authentication_error(
|
||||
request,
|
||||
provider=provider,
|
||||
error=error,
|
||||
exception=exception,
|
||||
extra_context=extra_context,
|
||||
)
|
||||
|
||||
if extra_context is None:
|
||||
extra_context = {}
|
||||
get_adapter().on_authentication_error(
|
||||
request,
|
||||
provider,
|
||||
error=error,
|
||||
exception=exception,
|
||||
extra_context=extra_context,
|
||||
)
|
||||
except ImmediateHttpResponse as e:
|
||||
return e.response
|
||||
if error == AuthError.CANCELLED:
|
||||
return HttpResponseRedirect(reverse("socialaccount_login_cancelled"))
|
||||
context = {
|
||||
"auth_error": {
|
||||
"provider": provider,
|
||||
"code": error,
|
||||
"exception": exception,
|
||||
}
|
||||
}
|
||||
context.update(extra_context)
|
||||
return render(
|
||||
request,
|
||||
"socialaccount/authentication_error." + account_settings.TEMPLATE_EXTENSION,
|
||||
context,
|
||||
)
|
||||
|
||||
|
||||
def complete_social_login(request, sociallogin):
|
||||
if sociallogin.is_headless:
|
||||
from allauth.headless.socialaccount import internal
|
||||
|
||||
return internal.complete_login(request, sociallogin)
|
||||
return flows.login.complete_login(request, sociallogin)
|
||||
|
||||
|
||||
def socialaccount_user_display(socialaccount):
|
||||
func = app_settings.SOCIALACCOUNT_STR
|
||||
if not func:
|
||||
return user_display(socialaccount.user)
|
||||
return func(socialaccount)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,9 @@
|
||||
from allauth.socialaccount.internal.flows import (
|
||||
connect,
|
||||
email_authentication,
|
||||
login,
|
||||
signup,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["connect", "login", "signup", "email_authentication"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,130 @@
|
||||
from django.contrib import messages
|
||||
from django.core.exceptions import PermissionDenied, ValidationError
|
||||
from django.http import HttpResponseRedirect
|
||||
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.account import app_settings as account_settings
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.internal import flows
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.core.exceptions import ReauthenticationRequired
|
||||
from allauth.socialaccount import signals
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.models import SocialAccount, SocialLogin
|
||||
|
||||
|
||||
def validate_disconnect(request, account):
|
||||
"""
|
||||
Validate whether or not the socialaccount account can be
|
||||
safely disconnected.
|
||||
"""
|
||||
accounts = SocialAccount.objects.filter(user_id=account.user_id)
|
||||
is_last = not accounts.exclude(pk=account.pk).exists()
|
||||
adapter = get_adapter()
|
||||
if is_last:
|
||||
if allauth_settings.SOCIALACCOUNT_ONLY:
|
||||
raise adapter.validation_error("disconnect_last")
|
||||
# No usable password would render the local account unusable
|
||||
if not account.user.has_usable_password():
|
||||
raise adapter.validation_error("no_password")
|
||||
# No email address, no password reset
|
||||
if (
|
||||
account_settings.EMAIL_VERIFICATION
|
||||
== account_settings.EmailVerificationMethod.MANDATORY
|
||||
):
|
||||
if not EmailAddress.objects.filter(
|
||||
user=account.user, verified=True
|
||||
).exists():
|
||||
raise adapter.validation_error("no_verified_email")
|
||||
adapter.validate_disconnect(account, accounts)
|
||||
|
||||
|
||||
def disconnect(request, account):
|
||||
if account_settings.REAUTHENTICATION_REQUIRED:
|
||||
flows.reauthentication.raise_if_reauthentication_required(request)
|
||||
|
||||
get_account_adapter().add_message(
|
||||
request,
|
||||
messages.INFO,
|
||||
"socialaccount/messages/account_disconnected.txt",
|
||||
)
|
||||
provider = account.get_provider()
|
||||
account.delete()
|
||||
signals.social_account_removed.send(
|
||||
sender=SocialAccount, request=request, socialaccount=account
|
||||
)
|
||||
get_adapter().send_notification_mail(
|
||||
"socialaccount/email/account_disconnected",
|
||||
request.user,
|
||||
context={
|
||||
"account": account,
|
||||
"provider": provider,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def resume_connect(request, serialized_state):
|
||||
sociallogin = SocialLogin.deserialize(serialized_state)
|
||||
return connect(request, sociallogin)
|
||||
|
||||
|
||||
def connect(request, sociallogin):
|
||||
try:
|
||||
ok, action, message = do_connect(request, sociallogin)
|
||||
except PermissionDenied:
|
||||
# This should not happen. Simply redirect to the connections
|
||||
# view (which has a login required)
|
||||
connect_redirect_url = get_adapter().get_connect_redirect_url(
|
||||
request, sociallogin.account
|
||||
)
|
||||
return HttpResponseRedirect(connect_redirect_url)
|
||||
except ReauthenticationRequired:
|
||||
return flows.reauthentication.stash_and_reauthenticate(
|
||||
request,
|
||||
sociallogin.serialize(),
|
||||
"allauth.socialaccount.internal.flows.connect.resume_connect",
|
||||
)
|
||||
except ValidationError:
|
||||
ok, action, message = (
|
||||
False,
|
||||
None,
|
||||
"socialaccount/messages/account_connected_other.txt",
|
||||
)
|
||||
level = messages.INFO if ok else messages.ERROR
|
||||
default_next = get_adapter().get_connect_redirect_url(request, sociallogin.account)
|
||||
next_url = sociallogin.get_redirect_url(request) or default_next
|
||||
get_account_adapter(request).add_message(
|
||||
request,
|
||||
level,
|
||||
message,
|
||||
message_context={"sociallogin": sociallogin, "action": action},
|
||||
)
|
||||
return HttpResponseRedirect(next_url)
|
||||
|
||||
|
||||
def do_connect(request, sociallogin):
|
||||
if request.user.is_anonymous:
|
||||
raise PermissionDenied()
|
||||
if account_settings.REAUTHENTICATION_REQUIRED:
|
||||
flows.reauthentication.raise_if_reauthentication_required(request)
|
||||
message = "socialaccount/messages/account_connected.txt"
|
||||
action = None
|
||||
ok = True
|
||||
if sociallogin.is_existing:
|
||||
if sociallogin.user != request.user:
|
||||
# Social account of other user. For now, this scenario
|
||||
# is not supported. Issue is that one cannot simply
|
||||
# remove the social account from the other user, as
|
||||
# that may render the account unusable.
|
||||
raise get_adapter().validation_error("connected_other")
|
||||
elif not sociallogin.account._state.adding:
|
||||
action = "updated"
|
||||
message = "socialaccount/messages/account_connected_updated.txt"
|
||||
else:
|
||||
action = "added"
|
||||
sociallogin.connect(request, request.user)
|
||||
else:
|
||||
# New account, let's connect
|
||||
action = "added"
|
||||
sociallogin.connect(request, request.user)
|
||||
return ok, action, message
|
||||
@@ -0,0 +1,34 @@
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.account.models import EmailAddress
|
||||
|
||||
|
||||
def wipe_password(request, user, email: str):
|
||||
"""
|
||||
Consider a scenario where an attacker signs up for an account using the
|
||||
email address of a victim. Obviously, the email address cannot be
|
||||
verified, yet the attacker -- knowing the password -- can wait until the
|
||||
victim appears. When the victim signs in using email authentication, it
|
||||
is not obvious that the victim is signing into an account that was not
|
||||
created by the victim. As a result, both the attacker and the victim now
|
||||
have access to the account. To prevent this, we wipe the password of the
|
||||
account in case the email address was not verified, effectively locking
|
||||
out the attacker.
|
||||
"""
|
||||
try:
|
||||
address = EmailAddress.objects.get_for_user(user, email)
|
||||
except EmailAddress.DoesNotExist:
|
||||
address = None
|
||||
if address and address.verified:
|
||||
# Verified email address, no reason to worry.
|
||||
return
|
||||
if user.has_usable_password():
|
||||
user.set_unusable_password()
|
||||
user.save(update_fields=["password"])
|
||||
# Also wipe any other sessions (upstream integrators may hook up to the
|
||||
# ending of the sessions to trigger e.g. backchannel logout.
|
||||
if allauth_settings.USERSESSIONS_ENABLED:
|
||||
from allauth.usersessions.internal.flows.sessions import (
|
||||
end_other_sessions,
|
||||
)
|
||||
|
||||
end_other_sessions(request, user)
|
||||
@@ -0,0 +1,97 @@
|
||||
from django.http import HttpResponseRedirect
|
||||
from django.shortcuts import render
|
||||
|
||||
from allauth.account import app_settings as account_settings
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.utils import perform_login
|
||||
from allauth.core.exceptions import (
|
||||
ImmediateHttpResponse,
|
||||
SignupClosedException,
|
||||
)
|
||||
from allauth.socialaccount import app_settings, signals
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.internal.flows.connect import connect, do_connect
|
||||
from allauth.socialaccount.internal.flows.signup import (
|
||||
clear_pending_signup,
|
||||
process_signup,
|
||||
)
|
||||
from allauth.socialaccount.models import SocialLogin
|
||||
from allauth.socialaccount.providers.base import AuthProcess
|
||||
|
||||
|
||||
def _login(request, sociallogin):
|
||||
sociallogin._accept_login(request)
|
||||
record_authentication(request, sociallogin)
|
||||
return perform_login(
|
||||
request,
|
||||
sociallogin.user,
|
||||
email_verification=app_settings.EMAIL_VERIFICATION,
|
||||
redirect_url=sociallogin.get_redirect_url(request),
|
||||
signal_kwargs={"sociallogin": sociallogin},
|
||||
)
|
||||
|
||||
|
||||
def pre_social_login(request, sociallogin):
|
||||
clear_pending_signup(request)
|
||||
assert not sociallogin.is_existing
|
||||
sociallogin.lookup()
|
||||
get_adapter().pre_social_login(request, sociallogin)
|
||||
signals.pre_social_login.send(
|
||||
sender=SocialLogin, request=request, sociallogin=sociallogin
|
||||
)
|
||||
|
||||
|
||||
def complete_login(request, sociallogin, raises=False):
|
||||
try:
|
||||
pre_social_login(request, sociallogin)
|
||||
process = sociallogin.state.get("process")
|
||||
if process == AuthProcess.REDIRECT:
|
||||
return _redirect(request, sociallogin)
|
||||
elif process == AuthProcess.CONNECT:
|
||||
if raises:
|
||||
do_connect(request, sociallogin)
|
||||
else:
|
||||
return connect(request, sociallogin)
|
||||
else:
|
||||
return _authenticate(request, sociallogin)
|
||||
except SignupClosedException:
|
||||
if raises:
|
||||
raise
|
||||
return render(
|
||||
request,
|
||||
"account/signup_closed." + account_settings.TEMPLATE_EXTENSION,
|
||||
)
|
||||
except ImmediateHttpResponse as e:
|
||||
if raises:
|
||||
raise
|
||||
return e.response
|
||||
|
||||
|
||||
def _redirect(request, sociallogin):
|
||||
next_url = sociallogin.get_redirect_url(request) or "/"
|
||||
return HttpResponseRedirect(next_url)
|
||||
|
||||
|
||||
def _authenticate(request, sociallogin):
|
||||
if request.user.is_authenticated:
|
||||
get_account_adapter(request).logout(request)
|
||||
if sociallogin.is_existing:
|
||||
# Login existing user
|
||||
ret = _login(request, sociallogin)
|
||||
else:
|
||||
# New social user
|
||||
ret = process_signup(request, sociallogin)
|
||||
return ret
|
||||
|
||||
|
||||
def record_authentication(request, sociallogin):
|
||||
from allauth.account.internal.flows.login import record_authentication
|
||||
|
||||
record_authentication(
|
||||
request,
|
||||
"socialaccount",
|
||||
**{
|
||||
"provider": sociallogin.account.provider,
|
||||
"uid": sociallogin.account.uid,
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,114 @@
|
||||
from django.forms import ValidationError
|
||||
|
||||
from allauth.account import app_settings as account_settings
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.internal.flows.manage_email import assess_unique_email
|
||||
from allauth.account.internal.flows.signup import (
|
||||
complete_signup,
|
||||
prevent_enumeration,
|
||||
)
|
||||
from allauth.account.utils import user_username
|
||||
from allauth.core.exceptions import SignupClosedException
|
||||
from allauth.core.internal.httpkit import headed_redirect_response
|
||||
from allauth.socialaccount import app_settings
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.models import SocialLogin
|
||||
|
||||
|
||||
def get_pending_signup(request):
|
||||
data = request.session.get("socialaccount_sociallogin")
|
||||
if data:
|
||||
return SocialLogin.deserialize(data)
|
||||
|
||||
|
||||
def redirect_to_signup(request, sociallogin):
|
||||
request.session["socialaccount_sociallogin"] = sociallogin.serialize()
|
||||
return headed_redirect_response("socialaccount_signup")
|
||||
|
||||
|
||||
def clear_pending_signup(request):
|
||||
request.session.pop("socialaccount_sociallogin", None)
|
||||
|
||||
|
||||
def signup_by_form(request, sociallogin, form):
|
||||
clear_pending_signup(request)
|
||||
user, resp = form.try_save(request)
|
||||
if not resp:
|
||||
resp = complete_social_signup(request, sociallogin)
|
||||
return resp
|
||||
|
||||
|
||||
def process_auto_signup(request, sociallogin):
|
||||
auto_signup = get_adapter().is_auto_signup_allowed(request, sociallogin)
|
||||
if not auto_signup:
|
||||
return False, None
|
||||
email = None
|
||||
if sociallogin.email_addresses:
|
||||
email = sociallogin.email_addresses[0].email
|
||||
# Let's check if auto_signup is really possible...
|
||||
if email:
|
||||
assessment = assess_unique_email(email)
|
||||
if assessment is True:
|
||||
# Auto signup is fine.
|
||||
pass
|
||||
elif assessment is False:
|
||||
# Oops, another user already has this address. We cannot simply
|
||||
# connect this social account to the existing user. Reason is
|
||||
# that the email address may not be verified, meaning, the user
|
||||
# may be a hacker that has added your email address to their
|
||||
# account in the hope that you fall in their trap. We cannot
|
||||
# check on 'email_address.verified' either, because
|
||||
# 'email_address' is not guaranteed to be verified.
|
||||
auto_signup = False
|
||||
# TODO: We redirect to signup form -- user will see email
|
||||
# address conflict only after posting whereas we detected it
|
||||
# here already.
|
||||
else:
|
||||
assert assessment is None
|
||||
# Prevent enumeration is properly turned on, meaning, we cannot
|
||||
# show the signup form to allow the user to input another email
|
||||
# address. Instead, we're going to send the user an email that
|
||||
# the account already exists, and on the outside make it appear
|
||||
# as if an email verification mail was sent.
|
||||
resp = prevent_enumeration(request, email)
|
||||
return False, resp
|
||||
elif app_settings.EMAIL_REQUIRED:
|
||||
# Nope, email is required and we don't have it yet...
|
||||
auto_signup = False
|
||||
return auto_signup, None
|
||||
|
||||
|
||||
def process_signup(request, sociallogin):
|
||||
if not get_adapter().is_open_for_signup(request, sociallogin):
|
||||
raise SignupClosedException()
|
||||
auto_signup, resp = process_auto_signup(request, sociallogin)
|
||||
if resp:
|
||||
return resp
|
||||
if not auto_signup:
|
||||
resp = redirect_to_signup(request, sociallogin)
|
||||
else:
|
||||
# Ok, auto signup it is, at least the email address is ok.
|
||||
# We still need to check the username though...
|
||||
if account_settings.USER_MODEL_USERNAME_FIELD:
|
||||
username = user_username(sociallogin.user)
|
||||
try:
|
||||
get_account_adapter(request).clean_username(username)
|
||||
except ValidationError:
|
||||
# This username is no good ...
|
||||
user_username(sociallogin.user, "")
|
||||
# TODO: This part contains a lot of duplication of logic
|
||||
# ("closed" rendering, create user, send email, in active
|
||||
# etc..)
|
||||
get_adapter().save_user(request, sociallogin, form=None)
|
||||
resp = complete_social_signup(request, sociallogin)
|
||||
return resp
|
||||
|
||||
|
||||
def complete_social_signup(request, sociallogin):
|
||||
return complete_signup(
|
||||
request,
|
||||
user=sociallogin.user,
|
||||
email_verification=app_settings.EMAIL_VERIFICATION,
|
||||
redirect_url=sociallogin.get_redirect_url(request),
|
||||
signal_kwargs={"sociallogin": sociallogin},
|
||||
)
|
||||
@@ -0,0 +1,105 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
from django.core.cache import cache
|
||||
|
||||
import jwt
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.x509 import load_pem_x509_certificate
|
||||
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
|
||||
|
||||
|
||||
def lookup_kid_pem_x509_certificate(keys_data, kid):
|
||||
"""
|
||||
Looks up the key given keys data of the form:
|
||||
|
||||
{"<kid>": "-----BEGIN CERTIFICATE-----\nCERTIFICATE"}
|
||||
"""
|
||||
key = keys_data.get(kid)
|
||||
if key:
|
||||
public_key = load_pem_x509_certificate(
|
||||
key.encode("utf8"), default_backend()
|
||||
).public_key()
|
||||
return public_key
|
||||
|
||||
|
||||
def lookup_kid_jwk(keys_data, kid):
|
||||
"""
|
||||
Looks up the key given keys data of the form:
|
||||
|
||||
{
|
||||
"keys": [
|
||||
{
|
||||
"kty": "RSA",
|
||||
"kid": "W6WcOKB",
|
||||
"use": "sig",
|
||||
"alg": "RS256",
|
||||
"n": "2Zc5d0-zk....",
|
||||
"e": "AQAB"
|
||||
}]
|
||||
}
|
||||
"""
|
||||
for d in keys_data["keys"]:
|
||||
if d["kid"] == kid:
|
||||
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(d))
|
||||
return public_key
|
||||
|
||||
|
||||
def fetch_key(credential, keys_url, lookup):
|
||||
header = jwt.get_unverified_header(credential)
|
||||
# {'alg': 'RS256', 'kid': '[AWS-SECRET-REMOVED]', 'typ': 'JWT'}
|
||||
kid = header["kid"]
|
||||
alg = header["alg"]
|
||||
response = get_adapter().get_requests_session().get(keys_url)
|
||||
response.raise_for_status()
|
||||
keys_data = response.json()
|
||||
key = lookup(keys_data, kid)
|
||||
if not key:
|
||||
raise OAuth2Error(f"Invalid 'kid': '{kid}'")
|
||||
return alg, key
|
||||
|
||||
|
||||
def verify_jti(data: dict) -> None:
|
||||
"""
|
||||
Put the JWT token on a blacklist to prevent replay attacks.
|
||||
"""
|
||||
iss = data.get("iss")
|
||||
exp = data.get("exp")
|
||||
jti = data.get("jti")
|
||||
if iss is None or exp is None or jti is None:
|
||||
return
|
||||
timeout = exp - time.time()
|
||||
key = f"jwt:iss={iss},jti={jti}"
|
||||
if not cache.add(key=key, value=True, timeout=timeout):
|
||||
raise OAuth2Error("token already used")
|
||||
|
||||
|
||||
def verify_and_decode(
|
||||
*, credential, keys_url, issuer, audience, lookup_kid, verify_signature=True
|
||||
):
|
||||
try:
|
||||
if verify_signature:
|
||||
alg, key = fetch_key(credential, keys_url, lookup_kid)
|
||||
algorithms = [alg]
|
||||
else:
|
||||
key = ""
|
||||
algorithms = None
|
||||
data = jwt.decode(
|
||||
credential,
|
||||
key=key,
|
||||
options={
|
||||
"verify_signature": verify_signature,
|
||||
"verify_iss": True,
|
||||
"verify_aud": True,
|
||||
"verify_exp": True,
|
||||
},
|
||||
issuer=issuer,
|
||||
audience=audience,
|
||||
algorithms=algorithms,
|
||||
)
|
||||
verify_jti(data)
|
||||
return data
|
||||
except jwt.PyJWTError as e:
|
||||
raise OAuth2Error("Invalid id_token") from e
|
||||
@@ -0,0 +1,69 @@
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
|
||||
|
||||
STATE_ID_LENGTH = 16
|
||||
MAX_STATES = 10
|
||||
STATES_SESSION_KEY = "socialaccount_states"
|
||||
|
||||
|
||||
def get_oldest_state(
|
||||
states: Dict[str, Tuple[Dict[str, Any], float]], rev: bool = False
|
||||
) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
|
||||
oldest_ts = None
|
||||
oldest_id = None
|
||||
oldest = None
|
||||
for state_id, state_ts in states.items():
|
||||
ts = state_ts[1]
|
||||
if oldest_ts is None or (
|
||||
(rev and ts > oldest_ts) or ((not rev) and oldest_ts > ts)
|
||||
):
|
||||
oldest_ts = ts
|
||||
oldest_id = state_id
|
||||
oldest = state_ts[0]
|
||||
return oldest_id, oldest
|
||||
|
||||
|
||||
def gc_states(states: Dict[str, Tuple[Dict[str, Any], float]]):
|
||||
if len(states) > MAX_STATES:
|
||||
oldest_id, oldest = get_oldest_state(states)
|
||||
if oldest_id:
|
||||
del states[oldest_id]
|
||||
|
||||
|
||||
def get_states(request) -> Dict[str, Tuple[Dict[str, Any], float]]:
|
||||
states = request.session.get(STATES_SESSION_KEY)
|
||||
if not isinstance(states, dict):
|
||||
states = {}
|
||||
return states
|
||||
|
||||
|
||||
def stash_state(request, state: Dict[str, Any], state_id: Optional[str] = None) -> str:
|
||||
states = get_states(request)
|
||||
gc_states(states)
|
||||
if state_id is None:
|
||||
state_id = get_adapter().generate_state_param(state)
|
||||
states[state_id] = (state, time.time())
|
||||
request.session[STATES_SESSION_KEY] = states
|
||||
return state_id
|
||||
|
||||
|
||||
def unstash_state(request, state_id: str) -> Optional[Dict[str, Any]]:
|
||||
state: Optional[Dict[str, Any]] = None
|
||||
states = get_states(request)
|
||||
state_ts = states.get(state_id)
|
||||
if state_ts is not None:
|
||||
state = state_ts[0]
|
||||
del states[state_id]
|
||||
request.session[STATES_SESSION_KEY] = states
|
||||
return state
|
||||
|
||||
|
||||
def unstash_last_state(request) -> Optional[Dict[str, Any]]:
|
||||
states = get_states(request)
|
||||
state_id, state = get_oldest_state(states, rev=True)
|
||||
if state_id:
|
||||
unstash_state(request, state_id)
|
||||
return state
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,36 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from django.utils import timezone
|
||||
|
||||
from allauth.socialaccount.internal.jwtkit import verify_and_decode
|
||||
from allauth.socialaccount.providers.apple.client import jwt_encode
|
||||
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
|
||||
|
||||
|
||||
def test_verify_and_decode(enable_cache):
|
||||
now = timezone.now()
|
||||
payload = {
|
||||
"iss": "https://accounts.google.com",
|
||||
"azp": "client_id",
|
||||
"aud": "client_id",
|
||||
"sub": "108204268033311374519",
|
||||
"hd": "example.com",
|
||||
"locale": "en",
|
||||
"iat": now,
|
||||
"jti": "[AWS-SECRET-REMOVED]",
|
||||
"exp": now + timedelta(hours=1),
|
||||
}
|
||||
id_token = jwt_encode(payload, "secret")
|
||||
for attempt in range(2):
|
||||
try:
|
||||
verify_and_decode(
|
||||
credential=id_token,
|
||||
keys_url="/",
|
||||
issuer=payload["iss"],
|
||||
audience=payload["aud"],
|
||||
lookup_kid=False,
|
||||
verify_signature=False,
|
||||
)
|
||||
assert attempt == 0
|
||||
except OAuth2Error:
|
||||
assert attempt == 1
|
||||
@@ -0,0 +1,48 @@
|
||||
from allauth.socialaccount.internal import statekit
|
||||
|
||||
|
||||
def test_get_oldest_state():
|
||||
states = {
|
||||
"new": [{"id": "new"}, 300],
|
||||
"mid": [{"id": "mid"}, 200],
|
||||
"old": [{"id": "old"}, 100],
|
||||
}
|
||||
state_id, state = statekit.get_oldest_state(states)
|
||||
assert state_id == "old"
|
||||
assert state["id"] == "old"
|
||||
|
||||
|
||||
def test_get_oldest_state_empty():
|
||||
state_id, state = statekit.get_oldest_state({})
|
||||
assert state_id is None
|
||||
assert state is None
|
||||
|
||||
|
||||
def test_gc_states():
|
||||
states = {}
|
||||
for i in range(statekit.MAX_STATES + 1):
|
||||
states[f"state-{i}"] = [{"i": i}, 1000 + i]
|
||||
assert len(states) == statekit.MAX_STATES + 1
|
||||
statekit.gc_states(states)
|
||||
assert len(states) == statekit.MAX_STATES
|
||||
assert "state-0" not in states
|
||||
|
||||
|
||||
def test_stashing(rf):
|
||||
request = rf.get("/")
|
||||
request.session = {}
|
||||
state_id = statekit.stash_state(request, {"foo": "bar"})
|
||||
state2_id = statekit.stash_state(request, {"foo2": "bar2"})
|
||||
state3_id = statekit.stash_state(request, {"foo3": "bar3"})
|
||||
state = statekit.unstash_last_state(request)
|
||||
assert state == {"foo3": "bar3"}
|
||||
state = statekit.unstash_state(request, state3_id)
|
||||
assert state is None
|
||||
state = statekit.unstash_state(request, state2_id)
|
||||
assert state == {"foo2": "bar2"}
|
||||
state = statekit.unstash_state(request, state2_id)
|
||||
assert state is None
|
||||
state = statekit.unstash_state(request, state_id)
|
||||
assert state == {"foo": "bar"}
|
||||
state = statekit.unstash_state(request, state_id)
|
||||
assert state is None
|
||||
@@ -0,0 +1,192 @@
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
from allauth import app_settings
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = (
|
||||
[
|
||||
("sites", "0001_initial"),
|
||||
]
|
||||
if app_settings.SITES_ENABLED
|
||||
else []
|
||||
) + [
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="SocialAccount",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.AutoField(
|
||||
verbose_name="ID",
|
||||
serialize=False,
|
||||
auto_created=True,
|
||||
primary_key=True,
|
||||
),
|
||||
),
|
||||
(
|
||||
"provider",
|
||||
models.CharField(
|
||||
max_length=30,
|
||||
verbose_name="provider",
|
||||
),
|
||||
),
|
||||
(
|
||||
"uid",
|
||||
models.CharField(
|
||||
max_length=getattr(
|
||||
settings, "SOCIALACCOUNT_UID_MAX_LENGTH", 191
|
||||
),
|
||||
verbose_name="uid",
|
||||
),
|
||||
),
|
||||
(
|
||||
"last_login",
|
||||
models.DateTimeField(auto_now=True, verbose_name="last login"),
|
||||
),
|
||||
(
|
||||
"date_joined",
|
||||
models.DateTimeField(auto_now_add=True, verbose_name="date joined"),
|
||||
),
|
||||
(
|
||||
"extra_data",
|
||||
models.TextField(default="{}", verbose_name="extra data"),
|
||||
),
|
||||
(
|
||||
"user",
|
||||
models.ForeignKey(
|
||||
to=settings.AUTH_USER_MODEL, on_delete=models.CASCADE
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "social account",
|
||||
"verbose_name_plural": "social accounts",
|
||||
},
|
||||
bases=(models.Model,),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="SocialApp",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.AutoField(
|
||||
verbose_name="ID",
|
||||
serialize=False,
|
||||
auto_created=True,
|
||||
primary_key=True,
|
||||
),
|
||||
),
|
||||
(
|
||||
"provider",
|
||||
models.CharField(
|
||||
max_length=30,
|
||||
verbose_name="provider",
|
||||
),
|
||||
),
|
||||
("name", models.CharField(max_length=40, verbose_name="name")),
|
||||
(
|
||||
"client_id",
|
||||
models.CharField(
|
||||
help_text="App ID, or consumer key",
|
||||
max_length=100,
|
||||
verbose_name="client id",
|
||||
),
|
||||
),
|
||||
(
|
||||
"secret",
|
||||
models.CharField(
|
||||
help_text="API secret, client secret, or consumer secret",
|
||||
max_length=100,
|
||||
verbose_name="secret key",
|
||||
),
|
||||
),
|
||||
(
|
||||
"key",
|
||||
models.CharField(
|
||||
help_text="Key",
|
||||
max_length=100,
|
||||
verbose_name="key",
|
||||
blank=True,
|
||||
),
|
||||
),
|
||||
]
|
||||
+ (
|
||||
[
|
||||
("sites", models.ManyToManyField(to="sites.Site", blank=True)),
|
||||
]
|
||||
if app_settings.SITES_ENABLED
|
||||
else []
|
||||
),
|
||||
options={
|
||||
"verbose_name": "social application",
|
||||
"verbose_name_plural": "social applications",
|
||||
},
|
||||
bases=(models.Model,),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="SocialToken",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.AutoField(
|
||||
verbose_name="ID",
|
||||
serialize=False,
|
||||
auto_created=True,
|
||||
primary_key=True,
|
||||
),
|
||||
),
|
||||
(
|
||||
"token",
|
||||
models.TextField(
|
||||
help_text='"oauth_token" (OAuth1) or access token (OAuth2)',
|
||||
verbose_name="token",
|
||||
),
|
||||
),
|
||||
(
|
||||
"token_secret",
|
||||
models.TextField(
|
||||
help_text='"oauth_token_secret" (OAuth1) or refresh token (OAuth2)',
|
||||
verbose_name="token secret",
|
||||
blank=True,
|
||||
),
|
||||
),
|
||||
(
|
||||
"expires_at",
|
||||
models.DateTimeField(
|
||||
null=True, verbose_name="expires at", blank=True
|
||||
),
|
||||
),
|
||||
(
|
||||
"account",
|
||||
models.ForeignKey(
|
||||
to="socialaccount.SocialAccount",
|
||||
on_delete=models.CASCADE,
|
||||
),
|
||||
),
|
||||
(
|
||||
"app",
|
||||
models.ForeignKey(
|
||||
to="socialaccount.SocialApp", on_delete=models.CASCADE
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "social application token",
|
||||
"verbose_name_plural": "social application tokens",
|
||||
},
|
||||
bases=(models.Model,),
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="socialtoken",
|
||||
unique_together=set([("app", "account")]),
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="socialaccount",
|
||||
unique_together=set([("provider", "uid")]),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,45 @@
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("socialaccount", "0001_initial"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="socialaccount",
|
||||
name="uid",
|
||||
field=models.CharField(
|
||||
max_length=getattr(settings, "SOCIALACCOUNT_UID_MAX_LENGTH", 191),
|
||||
verbose_name="uid",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="socialapp",
|
||||
name="client_id",
|
||||
field=models.CharField(
|
||||
help_text="App ID, or consumer key",
|
||||
max_length=191,
|
||||
verbose_name="client id",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="socialapp",
|
||||
name="key",
|
||||
field=models.CharField(
|
||||
help_text="Key", max_length=191, verbose_name="key", blank=True
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="socialapp",
|
||||
name="secret",
|
||||
field=models.CharField(
|
||||
help_text="API secret, client secret, or consumer secret",
|
||||
max_length=191,
|
||||
verbose_name="secret key",
|
||||
blank=True,
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,16 @@
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("socialaccount", "0002_token_max_lengths"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="socialaccount",
|
||||
name="extra_data",
|
||||
field=models.TextField(default="{}", verbose_name="extra data"),
|
||||
preserve_default=True,
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,29 @@
|
||||
# Generated by Django 3.2.19 on 2023-06-30 13:16
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("socialaccount", "0003_extra_data_default_dict"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="socialapp",
|
||||
name="provider_id",
|
||||
field=models.CharField(
|
||||
blank=True, max_length=200, verbose_name="provider ID"
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="socialapp",
|
||||
name="settings",
|
||||
field=models.JSONField(blank=True, default=dict),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="socialaccount",
|
||||
name="provider",
|
||||
field=models.CharField(max_length=200, verbose_name="provider"),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,23 @@
|
||||
# Generated by Django 3.2.20 on 2023-09-03 19:46
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("socialaccount", "0004_app_provider_id_settings"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="socialtoken",
|
||||
name="app",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
to="socialaccount.socialapp",
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,17 @@
|
||||
# Generated by Django 3.2.20 on 2023-10-11 09:23
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("socialaccount", "0005_socialtoken_nullable_app"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="socialaccount",
|
||||
name="extra_data",
|
||||
field=models.JSONField(default=dict, verbose_name="extra data"),
|
||||
),
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,411 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import authenticate, get_user_model
|
||||
from django.contrib.sites.shortcuts import get_current_site
|
||||
from django.core.exceptions import ImproperlyConfigured, PermissionDenied
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
import allauth.app_settings
|
||||
from allauth import app_settings as allauth_settings
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.account.utils import (
|
||||
filter_users_by_email,
|
||||
get_next_redirect_url,
|
||||
setup_user_email,
|
||||
)
|
||||
from allauth.core import context
|
||||
from allauth.socialaccount import app_settings, providers, signals
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.internal import statekit
|
||||
from allauth.utils import get_request_param
|
||||
|
||||
|
||||
if not allauth_settings.SOCIALACCOUNT_ENABLED:
|
||||
raise ImproperlyConfigured(
|
||||
"allauth.socialaccount not installed, yet its models are imported."
|
||||
)
|
||||
|
||||
|
||||
class SocialAppManager(models.Manager):
|
||||
def on_site(self, request):
|
||||
if allauth.app_settings.SITES_ENABLED:
|
||||
site = get_current_site(request)
|
||||
return self.filter(sites__id=site.id)
|
||||
return self.all()
|
||||
|
||||
|
||||
class SocialApp(models.Model):
|
||||
objects = SocialAppManager()
|
||||
|
||||
# The provider type, e.g. "google", "telegram", "saml".
|
||||
provider = models.CharField(
|
||||
verbose_name=_("provider"),
|
||||
max_length=30,
|
||||
)
|
||||
# For providers that support subproviders, such as OpenID Connect and SAML,
|
||||
# this ID identifies that instance. SocialAccount's originating from app
|
||||
# will have their `provider` field set to the `provider_id` if available,
|
||||
# else `provider`.
|
||||
provider_id = models.CharField(
|
||||
verbose_name=_("provider ID"),
|
||||
max_length=200,
|
||||
blank=True,
|
||||
)
|
||||
name = models.CharField(verbose_name=_("name"), max_length=40)
|
||||
client_id = models.CharField(
|
||||
verbose_name=_("client id"),
|
||||
max_length=191,
|
||||
help_text=_("App ID, or consumer key"),
|
||||
)
|
||||
secret = models.CharField(
|
||||
verbose_name=_("secret key"),
|
||||
max_length=191,
|
||||
blank=True,
|
||||
help_text=_("API secret, client secret, or consumer secret"),
|
||||
)
|
||||
key = models.CharField(
|
||||
verbose_name=_("key"), max_length=191, blank=True, help_text=_("Key")
|
||||
)
|
||||
settings = models.JSONField(default=dict, blank=True)
|
||||
|
||||
if allauth.app_settings.SITES_ENABLED:
|
||||
# Most apps can be used across multiple domains, therefore we use
|
||||
# a ManyToManyField. Note that Facebook requires an app per domain
|
||||
# (unless the domains share a common base name).
|
||||
# blank=True allows for disabling apps without removing them
|
||||
sites = models.ManyToManyField("sites.Site", blank=True) # type: ignore[var-annotated]
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("social application")
|
||||
verbose_name_plural = _("social applications")
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def get_provider(self, request):
|
||||
provider_class = providers.registry.get_class(self.provider)
|
||||
return provider_class(request=request, app=self)
|
||||
|
||||
|
||||
class SocialAccount(models.Model):
|
||||
user = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE)
|
||||
# Given a `SocialApp` from which this account originates, this field equals
|
||||
# the app's `app.provider_id` if available, `app.provider` otherwise.
|
||||
provider = models.CharField(
|
||||
verbose_name=_("provider"),
|
||||
max_length=200,
|
||||
)
|
||||
# Just in case you're wondering if an OpenID identity URL is going
|
||||
# to fit in a 'uid':
|
||||
#
|
||||
# Ideally, URLField(max_length=1024, unique=True) would be used
|
||||
# for identity. However, MySQL has a max_length limitation of 191
|
||||
# for URLField (in case of utf8mb4). How about
|
||||
# models.TextField(unique=True) then? Well, that won't work
|
||||
# either for MySQL due to another bug[1]. So the only way out
|
||||
# would be to drop the unique constraint, or switch to shorter
|
||||
# identity URLs. Opted for the latter, as [2] suggests that
|
||||
# identity URLs are supposed to be short anyway, at least for the
|
||||
# old spec.
|
||||
#
|
||||
# [1] http://code.djangoproject.com/ticket/2495.
|
||||
# [2] http://openid.net/specs/openid-authentication-1_1.html#limits
|
||||
|
||||
uid = models.CharField(
|
||||
verbose_name=_("uid"), max_length=app_settings.UID_MAX_LENGTH
|
||||
)
|
||||
last_login = models.DateTimeField(verbose_name=_("last login"), auto_now=True)
|
||||
date_joined = models.DateTimeField(verbose_name=_("date joined"), auto_now_add=True)
|
||||
extra_data = models.JSONField(verbose_name=_("extra data"), default=dict)
|
||||
|
||||
class Meta:
|
||||
unique_together = ("provider", "uid")
|
||||
verbose_name = _("social account")
|
||||
verbose_name_plural = _("social accounts")
|
||||
|
||||
def authenticate(self):
|
||||
return authenticate(account=self)
|
||||
|
||||
def __str__(self):
|
||||
from .helpers import socialaccount_user_display
|
||||
|
||||
return socialaccount_user_display(self)
|
||||
|
||||
def get_profile_url(self):
|
||||
return self.get_provider_account().get_profile_url()
|
||||
|
||||
def get_avatar_url(self):
|
||||
return self.get_provider_account().get_avatar_url()
|
||||
|
||||
def get_provider(self, request=None):
|
||||
provider = getattr(self, "_provider", None)
|
||||
if provider:
|
||||
return provider
|
||||
adapter = get_adapter()
|
||||
provider = self._provider = adapter.get_provider(
|
||||
request or context.request, provider=self.provider
|
||||
)
|
||||
return provider
|
||||
|
||||
def get_provider_account(self):
|
||||
return self.get_provider().wrap_account(self)
|
||||
|
||||
|
||||
class SocialToken(models.Model):
|
||||
app = models.ForeignKey(SocialApp, on_delete=models.SET_NULL, blank=True, null=True)
|
||||
account = models.ForeignKey(SocialAccount, on_delete=models.CASCADE)
|
||||
token = models.TextField(
|
||||
verbose_name=_("token"),
|
||||
help_text=_('"oauth_token" (OAuth1) or access token (OAuth2)'),
|
||||
)
|
||||
token_secret = models.TextField(
|
||||
blank=True,
|
||||
verbose_name=_("token secret"),
|
||||
help_text=_('"oauth_token_secret" (OAuth1) or refresh token (OAuth2)'),
|
||||
)
|
||||
expires_at = models.DateTimeField(
|
||||
blank=True, null=True, verbose_name=_("expires at")
|
||||
)
|
||||
|
||||
class Meta:
|
||||
unique_together = ("app", "account")
|
||||
verbose_name = _("social application token")
|
||||
verbose_name_plural = _("social application tokens")
|
||||
|
||||
def __str__(self):
|
||||
return "%s (%s)" % (self._meta.verbose_name, self.pk)
|
||||
|
||||
|
||||
class SocialLogin:
|
||||
"""
|
||||
Represents a social user that is in the process of being logged
|
||||
in. This consists of the following information:
|
||||
|
||||
`account` (`SocialAccount` instance): The social account being
|
||||
logged in. Providers are not responsible for checking whether or
|
||||
not an account already exists or not. Therefore, a provider
|
||||
typically creates a new (unsaved) `SocialAccount` instance. The
|
||||
`User` instance pointed to by the account (`account.user`) may be
|
||||
prefilled by the provider for use as a starting point later on
|
||||
during the signup process.
|
||||
|
||||
`token` (`SocialToken` instance): An optional access token token
|
||||
that results from performing a successful authentication
|
||||
handshake.
|
||||
|
||||
`state` (`dict`): The state to be preserved during the
|
||||
authentication handshake. Note that this state may end up in the
|
||||
url -- do not put any secrets in here. It currently only contains
|
||||
the url to redirect to after login.
|
||||
|
||||
`email_addresses` (list of `EmailAddress`): Optional list of
|
||||
email addresses retrieved from the provider.
|
||||
"""
|
||||
|
||||
account: SocialAccount
|
||||
token: Optional[SocialToken]
|
||||
email_addresses: List[EmailAddress]
|
||||
state: Dict
|
||||
_did_authenticate_by_email: Optional[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user=None,
|
||||
account: Optional[SocialAccount] = None,
|
||||
token: Optional[SocialToken] = None,
|
||||
email_addresses: Optional[List[EmailAddress]] = None,
|
||||
):
|
||||
if token:
|
||||
assert token.account is None or token.account == account
|
||||
self.token = token
|
||||
self.user = user
|
||||
if account:
|
||||
self.account = account
|
||||
self.email_addresses = email_addresses if email_addresses else []
|
||||
self.state = {}
|
||||
|
||||
def connect(self, request, user) -> None:
|
||||
self.user = user
|
||||
self.save(request, connect=True)
|
||||
signals.social_account_added.send(
|
||||
sender=SocialLogin, request=request, sociallogin=self
|
||||
)
|
||||
|
||||
get_adapter().send_notification_mail(
|
||||
"socialaccount/email/account_connected",
|
||||
self.user,
|
||||
context={
|
||||
"account": self.account,
|
||||
"provider": self.account.get_provider(),
|
||||
},
|
||||
)
|
||||
|
||||
@property
|
||||
def is_headless(self) -> bool:
|
||||
return bool(self.state.get("headless"))
|
||||
|
||||
def serialize(self) -> Dict[str, Any]:
|
||||
serialize_instance = get_adapter().serialize_instance
|
||||
ret = dict(
|
||||
account=serialize_instance(self.account),
|
||||
user=serialize_instance(self.user),
|
||||
state=self.state,
|
||||
email_addresses=[serialize_instance(ea) for ea in self.email_addresses],
|
||||
)
|
||||
if self.token:
|
||||
ret["token"] = serialize_instance(self.token)
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, data: Dict[str, Any]) -> "SocialLogin":
|
||||
deserialize_instance = get_adapter().deserialize_instance
|
||||
account = deserialize_instance(SocialAccount, data["account"])
|
||||
user = deserialize_instance(get_user_model(), data["user"])
|
||||
if "token" in data:
|
||||
token = deserialize_instance(SocialToken, data["token"])
|
||||
else:
|
||||
token = None
|
||||
email_addresses = []
|
||||
for ea in data["email_addresses"]:
|
||||
email_address = deserialize_instance(EmailAddress, ea)
|
||||
email_addresses.append(email_address)
|
||||
ret = cls()
|
||||
ret.token = token
|
||||
ret.account = account
|
||||
ret.user = user
|
||||
ret.email_addresses = email_addresses
|
||||
ret.state = data["state"]
|
||||
return ret
|
||||
|
||||
def save(self, request, connect: bool = False) -> None:
|
||||
"""
|
||||
Saves a new account. Note that while the account is new,
|
||||
the user may be an existing one (when connecting accounts)
|
||||
"""
|
||||
user = self.user
|
||||
user.save()
|
||||
self.account.user = user
|
||||
self.account.save()
|
||||
if app_settings.STORE_TOKENS and self.token:
|
||||
self.token.account = self.account
|
||||
self.token.save()
|
||||
if connect:
|
||||
# TODO: Add any new email addresses automatically?
|
||||
pass
|
||||
else:
|
||||
setup_user_email(request, user, self.email_addresses)
|
||||
|
||||
@property
|
||||
def is_existing(self) -> bool:
|
||||
"""When `False`, this social login represents a temporary account, not
|
||||
yet backed by a database record.
|
||||
"""
|
||||
if self.user.pk is None:
|
||||
return False
|
||||
return get_user_model().objects.filter(pk=self.user.pk).exists()
|
||||
|
||||
def lookup(self) -> None:
|
||||
"""Look up the existing local user account to which this social login
|
||||
points, if any.
|
||||
"""
|
||||
self._did_authenticate_by_email = None
|
||||
if not self._lookup_by_socialaccount():
|
||||
self._lookup_by_email()
|
||||
|
||||
def _lookup_by_socialaccount(self) -> bool:
|
||||
assert not self.is_existing
|
||||
try:
|
||||
a = SocialAccount.objects.get(
|
||||
provider=self.account.provider, uid=self.account.uid
|
||||
)
|
||||
# Update account
|
||||
a.extra_data = self.account.extra_data
|
||||
self.account = a
|
||||
self.user = self.account.user
|
||||
a.save()
|
||||
signals.social_account_updated.send(
|
||||
sender=SocialLogin, request=context.request, sociallogin=self
|
||||
)
|
||||
self._store_token()
|
||||
return True
|
||||
except SocialAccount.DoesNotExist:
|
||||
return False
|
||||
|
||||
def _store_token(self) -> None:
|
||||
# Update token
|
||||
if not app_settings.STORE_TOKENS or not self.token:
|
||||
return
|
||||
assert not self.token.pk
|
||||
app = self.token.app
|
||||
if app and not app.pk:
|
||||
# If the app is not stored in the db, leave the FK empty.
|
||||
app = None
|
||||
try:
|
||||
t = SocialToken.objects.get(account=self.account, app=app)
|
||||
t.token = self.token.token
|
||||
if self.token.token_secret:
|
||||
# only update the refresh token if we got one
|
||||
# many oauth2 providers do not resend the refresh token
|
||||
t.token_secret = self.token.token_secret
|
||||
t.expires_at = self.token.expires_at
|
||||
t.save()
|
||||
self.token = t
|
||||
except SocialToken.DoesNotExist:
|
||||
self.token.account = self.account
|
||||
self.token.app = app
|
||||
self.token.save()
|
||||
|
||||
def _lookup_by_email(self) -> None:
|
||||
emails = [e.email for e in self.email_addresses if e.verified]
|
||||
for email in emails:
|
||||
if not get_adapter().can_authenticate_by_email(self, email):
|
||||
continue
|
||||
users = filter_users_by_email(email, prefer_verified=True)
|
||||
if users:
|
||||
self.user = users[0]
|
||||
self._did_authenticate_by_email = email
|
||||
return
|
||||
|
||||
def _accept_login(self, request) -> None:
|
||||
from allauth.socialaccount.internal.flows.email_authentication import (
|
||||
wipe_password,
|
||||
)
|
||||
|
||||
if self._did_authenticate_by_email:
|
||||
wipe_password(request, self.user, self._did_authenticate_by_email)
|
||||
if app_settings.EMAIL_AUTHENTICATION_AUTO_CONNECT:
|
||||
self.connect(context.request, self.user)
|
||||
|
||||
def get_redirect_url(self, request) -> Optional[str]:
|
||||
url = self.state.get("next")
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def state_from_request(cls, request) -> Dict[str, Any]:
|
||||
"""
|
||||
TODO: Deprecated! To be integrated with provider.redirect()
|
||||
"""
|
||||
state = {}
|
||||
next_url = get_next_redirect_url(request)
|
||||
if next_url:
|
||||
state["next"] = next_url
|
||||
state["process"] = get_request_param(request, "process", "login")
|
||||
state["scope"] = get_request_param(request, "scope", "")
|
||||
state["auth_params"] = get_request_param(request, "auth_params", "")
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
def stash_state(cls, request, state: Optional[Dict[str, Any]] = None) -> str:
|
||||
if state is None:
|
||||
# Only for providers that don't support redirect() yet.
|
||||
state = cls.state_from_request(request)
|
||||
return statekit.stash_state(request, state)
|
||||
|
||||
@classmethod
|
||||
def unstash_state(cls, request) -> Optional[Dict[str, Any]]:
|
||||
state = statekit.unstash_last_state(request)
|
||||
if state is None:
|
||||
raise PermissionDenied()
|
||||
return state
|
||||
@@ -0,0 +1,57 @@
|
||||
import importlib
|
||||
from collections import OrderedDict
|
||||
|
||||
from django.apps import apps
|
||||
from django.conf import settings
|
||||
|
||||
from allauth.utils import import_attribute
|
||||
|
||||
|
||||
class ProviderRegistry:
|
||||
def __init__(self):
|
||||
self.provider_map = OrderedDict()
|
||||
self.loaded = False
|
||||
|
||||
def get_class_list(self):
|
||||
self.load()
|
||||
return list(self.provider_map.values())
|
||||
|
||||
def register(self, cls):
|
||||
self.provider_map[cls.id] = cls
|
||||
|
||||
def get_class(self, id):
|
||||
return self.provider_map.get(id)
|
||||
|
||||
def as_choices(self):
|
||||
self.load()
|
||||
for provider_cls in self.provider_map.values():
|
||||
yield (provider_cls.id, provider_cls.name)
|
||||
|
||||
def load(self):
|
||||
# TODO: Providers register with the provider registry when
|
||||
# loaded. Here, we build the URLs for all registered providers. So, we
|
||||
# really need to be sure all providers did register, which is why we're
|
||||
# forcefully importing the `provider` modules here. The overall
|
||||
# mechanism is way to magical and depends on the import order et al, so
|
||||
# all of this really needs to be revisited.
|
||||
if not self.loaded:
|
||||
for app_config in apps.get_app_configs():
|
||||
try:
|
||||
module_name = app_config.name + ".provider"
|
||||
provider_module = importlib.import_module(module_name)
|
||||
except ImportError as e:
|
||||
if e.name != module_name:
|
||||
raise
|
||||
else:
|
||||
provider_settings = getattr(settings, "SOCIALACCOUNT_PROVIDERS", {})
|
||||
for cls in getattr(provider_module, "provider_classes", []):
|
||||
provider_class = provider_settings.get(cls.id, {}).get(
|
||||
"provider_class"
|
||||
)
|
||||
if provider_class:
|
||||
cls = import_attribute(provider_class)
|
||||
self.register(cls)
|
||||
self.loaded = True
|
||||
|
||||
|
||||
registry = ProviderRegistry()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,37 @@
|
||||
from allauth.socialaccount.providers.agave.views import AgaveAdapter
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class AgaveAccount(ProviderAccount):
|
||||
def get_profile_url(self):
|
||||
return self.account.extra_data.get("web_url", "dflt")
|
||||
|
||||
def get_avatar_url(self):
|
||||
return self.account.extra_data.get("avatar_url", "dflt")
|
||||
|
||||
|
||||
class AgaveProvider(OAuth2Provider):
|
||||
id = "agave"
|
||||
name = "Agave"
|
||||
account_class = AgaveAccount
|
||||
oauth2_adapter_class = AgaveAdapter
|
||||
|
||||
def extract_uid(self, data):
|
||||
return str(data.get("create_time"))
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
return dict(
|
||||
email=data.get("email"),
|
||||
username=data.get("username", ""),
|
||||
name=(
|
||||
(data.get("first_name", "") + " " + data.get("last_name", "")).strip()
|
||||
),
|
||||
)
|
||||
|
||||
def get_default_scope(self):
|
||||
scope = ["PRODUCTION"]
|
||||
return scope
|
||||
|
||||
|
||||
provider_classes = [AgaveProvider]
|
||||
@@ -0,0 +1,34 @@
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
from .provider import AgaveProvider
|
||||
|
||||
|
||||
class AgaveTests(OAuth2TestsMixin, TestCase):
|
||||
provider_id = AgaveProvider.id
|
||||
|
||||
def get_mocked_response(self):
|
||||
return MockedResponse(
|
||||
200,
|
||||
"""
|
||||
{
|
||||
"status": "success",
|
||||
"message": "User details retrieved successfully.",
|
||||
"version": "2.0.0-SNAPSHOT-rc3fad",
|
||||
"result": {
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"full_name": "John Doe",
|
||||
"email": "jon@doe.edu",
|
||||
"phone": "",
|
||||
"mobile_phone": "",
|
||||
"status": "Active",
|
||||
"create_time": "20180322043812Z",
|
||||
"username": "jdoe"
|
||||
}
|
||||
}
|
||||
""",
|
||||
)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "jdoe"
|
||||
@@ -0,0 +1,5 @@
|
||||
from allauth.socialaccount.providers.agave.provider import AgaveProvider
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(AgaveProvider)
|
||||
@@ -0,0 +1,41 @@
|
||||
from allauth.socialaccount import app_settings
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
|
||||
|
||||
class AgaveAdapter(OAuth2Adapter):
|
||||
provider_id = "agave"
|
||||
|
||||
settings = app_settings.PROVIDERS.get(provider_id, {})
|
||||
provider_base_url = settings.get("API_URL", "https://public.agaveapi.co")
|
||||
|
||||
access_token_url = "{0}/token".format(provider_base_url)
|
||||
authorize_url = "{0}/authorize".format(provider_base_url)
|
||||
profile_url = "{0}/profiles/v2/me".format(provider_base_url)
|
||||
|
||||
def complete_login(self, request, app, token, response):
|
||||
extra_data = (
|
||||
get_adapter()
|
||||
.get_requests_session()
|
||||
.get(
|
||||
self.profile_url,
|
||||
params={"access_token": token.token},
|
||||
headers={
|
||||
"Authorization": "Bearer " + token.token,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
user_profile = (
|
||||
extra_data.json()["result"] if "result" in extra_data.json() else {}
|
||||
)
|
||||
|
||||
return self.get_provider().sociallogin_from_response(request, user_profile)
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(AgaveAdapter)
|
||||
oauth2_callback = OAuth2CallbackView.adapter_view(AgaveAdapter)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,34 @@
|
||||
from allauth.socialaccount.providers.amazon.views import AmazonOAuth2Adapter
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class AmazonAccount(ProviderAccount):
|
||||
pass
|
||||
|
||||
|
||||
class AmazonProvider(OAuth2Provider):
|
||||
id = "amazon"
|
||||
name = "Amazon"
|
||||
account_class = AmazonAccount
|
||||
oauth2_adapter_class = AmazonOAuth2Adapter
|
||||
|
||||
def get_default_scope(self):
|
||||
return ["profile"]
|
||||
|
||||
def extract_uid(self, data):
|
||||
return str(data["user_id"])
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
# Hackish way of splitting the fullname.
|
||||
# Assumes no middlenames.
|
||||
name = data.get("name", "")
|
||||
first_name, last_name = name, ""
|
||||
if name and " " in name:
|
||||
first_name, last_name = name.split(" ", 1)
|
||||
return dict(
|
||||
email=data.get("email", ""), last_name=last_name, first_name=first_name
|
||||
)
|
||||
|
||||
|
||||
provider_classes = [AmazonProvider]
|
||||
@@ -0,0 +1,24 @@
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
from .provider import AmazonProvider
|
||||
|
||||
|
||||
class AmazonTests(OAuth2TestsMixin, TestCase):
|
||||
provider_id = AmazonProvider.id
|
||||
|
||||
def get_mocked_response(self):
|
||||
return MockedResponse(
|
||||
200,
|
||||
"""
|
||||
{
|
||||
"Profile":{
|
||||
"CustomerId":"amzn1.account.K2LI23KL2LK2",
|
||||
"Name":"John Doe",
|
||||
"PrimaryEmail":"johndoe@example.com"
|
||||
}
|
||||
}""",
|
||||
)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "johndoe@example.com"
|
||||
@@ -0,0 +1,6 @@
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
from .provider import AmazonProvider
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(AmazonProvider)
|
||||
@@ -0,0 +1,33 @@
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
|
||||
|
||||
class AmazonOAuth2Adapter(OAuth2Adapter):
|
||||
provider_id = "amazon"
|
||||
access_token_url = "https://api.amazon.com/auth/o2/token"
|
||||
authorize_url = "http://www.amazon.com/ap/oa"
|
||||
profile_url = "https://api.amazon.com/user/profile"
|
||||
|
||||
def complete_login(self, request, app, token, **kwargs):
|
||||
response = (
|
||||
get_adapter()
|
||||
.get_requests_session()
|
||||
.get(self.profile_url, params={"access_token": token.token})
|
||||
)
|
||||
response.raise_for_status()
|
||||
extra_data = response.json()
|
||||
if "Profile" in extra_data:
|
||||
extra_data = {
|
||||
"user_id": extra_data["Profile"]["CustomerId"],
|
||||
"name": extra_data["Profile"]["Name"],
|
||||
"email": extra_data["Profile"]["PrimaryEmail"],
|
||||
}
|
||||
return self.get_provider().sociallogin_from_response(request, extra_data)
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(AmazonOAuth2Adapter)
|
||||
oauth2_callback = OAuth2CallbackView.adapter_view(AmazonOAuth2Adapter)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,69 @@
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.socialaccount.providers.amazon_cognito.utils import (
|
||||
convert_to_python_bool_if_value_is_json_string_bool,
|
||||
)
|
||||
from allauth.socialaccount.providers.amazon_cognito.views import (
|
||||
AmazonCognitoOAuth2Adapter,
|
||||
)
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class AmazonCognitoAccount(ProviderAccount):
|
||||
def get_avatar_url(self):
|
||||
return self.account.extra_data.get("picture")
|
||||
|
||||
def get_profile_url(self):
|
||||
return self.account.extra_data.get("profile")
|
||||
|
||||
|
||||
class AmazonCognitoProvider(OAuth2Provider):
|
||||
id = "amazon_cognito"
|
||||
name = "Amazon Cognito"
|
||||
account_class = AmazonCognitoAccount
|
||||
oauth2_adapter_class = AmazonCognitoOAuth2Adapter
|
||||
|
||||
def extract_uid(self, data):
|
||||
return str(data["sub"])
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
return {
|
||||
"email": data.get("email"),
|
||||
"first_name": data.get("given_name"),
|
||||
"last_name": data.get("family_name"),
|
||||
}
|
||||
|
||||
def get_default_scope(self):
|
||||
return ["openid", "profile", "email"]
|
||||
|
||||
def extract_email_addresses(self, data):
|
||||
email = data.get("email")
|
||||
verified = convert_to_python_bool_if_value_is_json_string_bool(
|
||||
data.get("email_verified", False)
|
||||
)
|
||||
|
||||
return (
|
||||
[EmailAddress(email=email, verified=verified, primary=True)]
|
||||
if email
|
||||
else []
|
||||
)
|
||||
|
||||
def extract_extra_data(self, data):
|
||||
ret = dict(data)
|
||||
phone_number_verified = data.get("phone_number_verified")
|
||||
if phone_number_verified is not None:
|
||||
ret["phone_number_verified"] = (
|
||||
convert_to_python_bool_if_value_is_json_string_bool(
|
||||
"phone_number_verified"
|
||||
)
|
||||
)
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def get_slug(cls):
|
||||
# IMPORTANT: Amazon Cognito does not support `_` characters
|
||||
# as part of their redirect URI.
|
||||
return super(AmazonCognitoProvider, cls).get_slug().replace("_", "-")
|
||||
|
||||
|
||||
provider_classes = [AmazonCognitoProvider]
|
||||
@@ -0,0 +1,89 @@
|
||||
import json
|
||||
|
||||
from django.test import override_settings
|
||||
|
||||
import pytest
|
||||
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.socialaccount.models import SocialAccount
|
||||
from allauth.socialaccount.providers.amazon_cognito.provider import (
|
||||
AmazonCognitoProvider,
|
||||
)
|
||||
from allauth.socialaccount.providers.amazon_cognito.utils import (
|
||||
convert_to_python_bool_if_value_is_json_string_bool,
|
||||
)
|
||||
from allauth.socialaccount.providers.amazon_cognito.views import (
|
||||
AmazonCognitoOAuth2Adapter,
|
||||
)
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
|
||||
def _get_mocked_claims():
|
||||
return {
|
||||
"sub": "[HEROKU-API-KEY-REMOVED]",
|
||||
"given_name": "John",
|
||||
"family_name": "Doe",
|
||||
"email": "jdoe@example.com",
|
||||
"username": "johndoe",
|
||||
}
|
||||
|
||||
|
||||
@override_settings(
|
||||
SOCIALACCOUNT_PROVIDERS={
|
||||
"amazon_cognito": {"DOMAIN": "https://domain.auth.us-east-1.amazoncognito.com"}
|
||||
}
|
||||
)
|
||||
class AmazonCognitoTestCase(OAuth2TestsMixin, TestCase):
|
||||
provider_id = AmazonCognitoProvider.id
|
||||
|
||||
def get_mocked_response(self):
|
||||
mocked_payload = json.dumps(_get_mocked_claims())
|
||||
return MockedResponse(status_code=200, content=mocked_payload)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "johndoe"
|
||||
|
||||
@override_settings(SOCIALACCOUNT_PROVIDERS={"amazon_cognito": {}})
|
||||
def test_oauth2_adapter_raises_if_domain_settings_is_missing(
|
||||
self,
|
||||
):
|
||||
mocked_response = self.get_mocked_response()
|
||||
|
||||
with self.assertRaises(
|
||||
ValueError,
|
||||
msg=AmazonCognitoOAuth2Adapter.DOMAIN_KEY_MISSING_ERROR,
|
||||
):
|
||||
self.login(mocked_response)
|
||||
|
||||
def test_saves_email_as_verified_if_email_is_verified_in_cognito(
|
||||
self,
|
||||
):
|
||||
mocked_claims = _get_mocked_claims()
|
||||
mocked_claims["email_verified"] = True
|
||||
mocked_payload = json.dumps(mocked_claims)
|
||||
mocked_response = MockedResponse(status_code=200, content=mocked_payload)
|
||||
|
||||
self.login(mocked_response)
|
||||
|
||||
user_id = SocialAccount.objects.get(uid=mocked_claims["sub"]).user_id
|
||||
email_address = EmailAddress.objects.get(user_id=user_id)
|
||||
|
||||
self.assertEqual(email_address.email, mocked_claims["email"])
|
||||
self.assertTrue(email_address.verified)
|
||||
|
||||
def test_provider_slug_replaces_underscores_with_hyphens(self):
|
||||
self.assertTrue("_" not in self.provider.get_slug())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input,output",
|
||||
[
|
||||
(True, True),
|
||||
("true", True),
|
||||
("false", False),
|
||||
(False, False),
|
||||
],
|
||||
)
|
||||
def test_convert_bool(input, output):
|
||||
assert convert_to_python_bool_if_value_is_json_string_bool(input) == output
|
||||
@@ -0,0 +1,7 @@
|
||||
from allauth.socialaccount.providers.amazon_cognito.provider import (
|
||||
AmazonCognitoProvider,
|
||||
)
|
||||
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
|
||||
|
||||
|
||||
urlpatterns = default_urlpatterns(AmazonCognitoProvider)
|
||||
@@ -0,0 +1,7 @@
|
||||
def convert_to_python_bool_if_value_is_json_string_bool(s):
|
||||
if s == "true":
|
||||
return True
|
||||
elif s == "false":
|
||||
return False
|
||||
|
||||
return s
|
||||
@@ -0,0 +1,56 @@
|
||||
from allauth.socialaccount import app_settings
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.models import SocialToken
|
||||
from allauth.socialaccount.providers.oauth2.views import (
|
||||
OAuth2Adapter,
|
||||
OAuth2CallbackView,
|
||||
OAuth2LoginView,
|
||||
)
|
||||
|
||||
|
||||
class AmazonCognitoOAuth2Adapter(OAuth2Adapter):
|
||||
provider_id = "amazon_cognito"
|
||||
|
||||
DOMAIN_KEY_MISSING_ERROR = (
|
||||
'"DOMAIN" key is missing in Amazon Cognito configuration.'
|
||||
)
|
||||
|
||||
@property
|
||||
def settings(self):
|
||||
return app_settings.PROVIDERS.get(self.provider_id, {})
|
||||
|
||||
@property
|
||||
def domain(self):
|
||||
domain = self.settings.get("DOMAIN")
|
||||
|
||||
if domain is None:
|
||||
raise ValueError(self.DOMAIN_KEY_MISSING_ERROR)
|
||||
|
||||
return domain
|
||||
|
||||
@property
|
||||
def access_token_url(self):
|
||||
return "{}/oauth2/token".format(self.domain)
|
||||
|
||||
@property
|
||||
def authorize_url(self):
|
||||
return "{}/oauth2/authorize".format(self.domain)
|
||||
|
||||
@property
|
||||
def profile_url(self):
|
||||
return "{}/oauth2/userInfo".format(self.domain)
|
||||
|
||||
def complete_login(self, request, app, token: SocialToken, **kwargs):
|
||||
headers = {
|
||||
"Authorization": "Bearer {}".format(token.token),
|
||||
}
|
||||
extra_data = (
|
||||
get_adapter().get_requests_session().get(self.profile_url, headers=headers)
|
||||
)
|
||||
extra_data.raise_for_status()
|
||||
|
||||
return self.get_provider().sociallogin_from_response(request, extra_data.json())
|
||||
|
||||
|
||||
oauth2_login = OAuth2LoginView.adapter_view(AmazonCognitoOAuth2Adapter)
|
||||
oauth2_callback = OAuth2CallbackView.adapter_view(AmazonCognitoOAuth2Adapter)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,33 @@
|
||||
from allauth.socialaccount.providers.angellist.views import (
|
||||
AngelListOAuth2Adapter,
|
||||
)
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
|
||||
|
||||
|
||||
class AngelListAccount(ProviderAccount):
|
||||
def get_profile_url(self):
|
||||
return self.account.extra_data.get("angellist_url")
|
||||
|
||||
def get_avatar_url(self):
|
||||
return self.account.extra_data.get("image")
|
||||
|
||||
|
||||
class AngelListProvider(OAuth2Provider):
|
||||
id = "angellist"
|
||||
name = "AngelList"
|
||||
account_class = AngelListAccount
|
||||
oauth2_adapter_class = AngelListOAuth2Adapter
|
||||
|
||||
def extract_uid(self, data):
|
||||
return str(data["id"])
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
return dict(
|
||||
email=data.get("email"),
|
||||
username=data.get("angellist_url").split("/")[-1],
|
||||
name=data.get("name"),
|
||||
)
|
||||
|
||||
|
||||
provider_classes = [AngelListProvider]
|
||||
@@ -0,0 +1,28 @@
|
||||
from allauth.socialaccount.tests import OAuth2TestsMixin
|
||||
from allauth.tests import MockedResponse, TestCase
|
||||
|
||||
from .provider import AngelListProvider
|
||||
|
||||
|
||||
class AngelListTests(OAuth2TestsMixin, TestCase):
|
||||
provider_id = AngelListProvider.id
|
||||
|
||||
def get_mocked_response(self):
|
||||
return MockedResponse(
|
||||
200,
|
||||
"""
|
||||
{"name":"pennersr","id":424732,"bio":"","follower_count":0,
|
||||
"angellist_url":"https://angel.co/dsxtst",
|
||||
"image":"https://angel.co/images/shared/nopic.png",
|
||||
"email":"raymond.penners@example.com","blog_url":null,
|
||||
"online_bio_url":null,"twitter_url":"https://twitter.com/dsxtst",
|
||||
"facebook_url":null,"linkedin_url":null,"aboutme_url":null,
|
||||
"github_url":null,"dribbble_url":null,"behance_url":null,
|
||||
"what_ive_built":null,"locations":[],"roles":[],"skills":[],
|
||||
"investor":false,"scopes":["message","talent","dealflow","comment",
|
||||
"email"]}
|
||||
""",
|
||||
)
|
||||
|
||||
def get_expected_to_str(self):
|
||||
return "raymond.penners@example.com"
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user