# -*- coding: utf-8 -*-
#
# This file is part of Invenio.
# Copyright (C) 2015, 2016 CERN.
#
# Invenio is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public License as
# published by the Free Software Foundation; either version 2 of the
# License, or (at your option) any later version.
#
# Invenio is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Invenio; if not, write to the
# Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
# MA 02111-1307, USA.
#
# In applying this license, CERN does not
# waive the privileges and immunities granted to it by virtue of its status
# as an Intergovernmental Organization or submit itself to any jurisdiction.
"""OAuth2Server models."""
from __future__ import absolute_import, print_function
import six
from flask import current_app
from flask_babelex import lazy_gettext as _
from flask_login import current_user
from invenio_accounts.models import User
from invenio_db import db
from sqlalchemy.schema import Index
from sqlalchemy_utils.types import URLType
from sqlalchemy_utils.types.encrypted import AesEngine, EncryptedType
from werkzeug.security import gen_salt
from wtforms import validators
from .errors import ScopeDoesNotExists
from .proxies import current_oauth2server
from .validators import validate_redirect_uri, validate_scopes
[docs]def secret_key():
"""Return secret key as bytearray."""
return current_app.config['SECRET_KEY'].encode('utf-8')
[docs]class NoneAesEngine(AesEngine):
"""Filter None values from encrypting."""
[docs] def encrypt(self, value):
"""Encrypt a value on the way in."""
if value is not None:
return super(NoneAesEngine, self).encrypt(value)
[docs] def decrypt(self, value):
"""Decrypt value on the way out."""
if value is not None:
return super(NoneAesEngine, self).decrypt(value)
[docs]class OAuthUserProxy(object):
"""Proxy object to an Invenio User."""
def __init__(self, user):
"""Initialize proxy object with user instance."""
self._user = user
def __getattr__(self, name):
"""Pass any undefined attribute to the underlying object."""
return getattr(self._user, name)
def __getstate__(self):
"""Return the id."""
return self.id
def __setstate__(self, state):
"""Set user info."""
self._user = current_app.extensions['security'].datastore.get_user(
state)
@property
def id(self):
"""Return user identifier."""
return self._user.get_id()
[docs] def check_password(self, password):
"""Check user password."""
return self.password == password
@classmethod
[docs] def get_current_user(cls):
"""Return an instance of current user object."""
return cls(current_user._get_current_object())
[docs]class Scope(object):
"""OAuth scope definition."""
def __init__(self, id_, help_text='', group='', internal=False):
"""Initialize scope values."""
self.id = id_
self.group = group
self.help_text = help_text
self.is_internal = internal
[docs]class Client(db.Model):
"""A client is the app which want to use the resource of a user.
It is suggested that the client is registered by a user on your site, but
it is not required.
The client should contain at least these information:
client_id: A random string
client_secret: A random string
client_type: A string represents if it is confidential
redirect_uris: A list of redirect uris
default_redirect_uri: One of the redirect uris
default_scopes: Default scopes of the client
But it could be better, if you implemented:
allowed_grant_types: A list of grant types
allowed_response_types: A list of response types
validate_scopes: A function to validate scopes
"""
__tablename__ = 'oauth2server_client'
name = db.Column(
db.String(40),
info=dict(
label=_('Name'),
description=_('Name of application (displayed to users).'),
validators=[validators.DataRequired()]
)
)
"""Human readable name of the application."""
description = db.Column(
db.Text(),
default=u'',
info=dict(
label=_('Description'),
description=_('Optional. Description of the application'
' (displayed to users).'),
)
)
"""Human readable description."""
website = db.Column(
URLType(),
info=dict(
label=_('Website URL'),
description=_('URL of your application (displayed to users).'),
),
default=u'',
)
user_id = db.Column(db.ForeignKey(User.id), nullable=True)
"""Creator of the client application."""
client_id = db.Column(db.String(255), primary_key=True)
"""Client application ID."""
client_secret = db.Column(
db.String(255), unique=True, index=True, nullable=False
)
"""Client application secret."""
is_confidential = db.Column(
db.Boolean(name='is_confidential'),
default=True
)
"""Determine if client application is public or not."""
is_internal = db.Column(db.Boolean(name='is_internal'), default=False)
"""Determins if client application is an internal application."""
_redirect_uris = db.Column(db.Text)
"""A newline-separated list of redirect URIs. First is the default URI."""
_default_scopes = db.Column(db.Text)
"""A space-separated list of default scopes of the client.
The value of the scope parameter is expressed as a list of space-delimited,
case-sensitive strings.
"""
user = db.relationship(
User,
backref=db.backref(
"oauth2clients",
cascade="all, delete-orphan",
)
)
"""Relationship to user."""
@property
def allowed_grant_types(self):
"""Return allowed grant types."""
return current_app.config['OAUTH2SERVER_ALLOWED_GRANT_TYPES']
@property
def allowed_response_types(self):
"""Return allowed response types."""
return current_app.config['OAUTH2SERVER_ALLOWED_RESPONSE_TYPES']
# def validate_scopes(self, scopes):
# return self._validate_scopes
@property
def client_type(self):
"""Return client type."""
if self.is_confidential:
return 'confidential'
return 'public'
@property
def redirect_uris(self):
"""Return redirect uris."""
if self._redirect_uris:
return self._redirect_uris.splitlines()
return []
@redirect_uris.setter
def redirect_uris(self, value):
"""Validate and store redirect URIs for client."""
if isinstance(value, six.text_type):
value = value.split("\n")
value = [v.strip() for v in value]
for v in value:
validate_redirect_uri(v)
self._redirect_uris = "\n".join(value) or ""
@property
def default_redirect_uri(self):
"""Return default redirect uri."""
try:
return self.redirect_uris[0]
except IndexError:
pass
@property
def default_scopes(self):
"""List of default scopes for client."""
if self._default_scopes:
return self._default_scopes.split(" ")
return []
@default_scopes.setter
def default_scopes(self, scopes):
"""Set default scopes for client."""
validate_scopes(scopes)
self._default_scopes = " ".join(set(scopes)) if scopes else ""
[docs] def validate_scopes(self, scopes):
"""Validate if client is allowed to access scopes."""
try:
validate_scopes(scopes)
return True
except ScopeDoesNotExists:
return False
[docs] def gen_salt(self):
"""Generate salt."""
self.reset_client_id()
self.reset_client_secret()
[docs] def reset_client_id(self):
"""Reset client id."""
self.client_id = gen_salt(
current_app.config.get('OAUTH2SERVER_CLIENT_ID_SALT_LEN')
)
[docs] def reset_client_secret(self):
"""Reset client secret."""
self.client_secret = gen_salt(
current_app.config.get('OAUTH2SERVER_CLIENT_SECRET_SALT_LEN')
)
[docs]class Token(db.Model):
"""A bearer token is the final token that can be used by the client."""
__tablename__ = 'oauth2server_token'
__table_args__ = (
Index('ix_oauth2server_token_access_token',
'access_token',
unique=True,
mysql_length=255),
Index('ix_oauth2server_token_refresh_token',
'refresh_token',
unique=True,
mysql_length=255),
)
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
"""Object ID."""
client_id = db.Column(
db.String(255), db.ForeignKey(Client.client_id),
nullable=False,
)
"""Foreign key to client application."""
client = db.relationship(
'Client',
backref=db.backref(
'oauth2tokens',
cascade="all, delete-orphan"
))
"""SQLAlchemy relationship to client application."""
user_id = db.Column(
db.Integer, db.ForeignKey(User.id), nullable=True
)
"""Foreign key to user."""
user = db.relationship(
User,
backref=db.backref(
"oauth2tokens",
cascade="all, delete-orphan",
)
)
"""SQLAlchemy relationship to user."""
token_type = db.Column(db.String(255), default='bearer')
"""Token type - only bearer is supported at the moment."""
access_token = db.Column(
EncryptedType(
type_in=db.String(255),
key=secret_key,
),
)
refresh_token = db.Column(
EncryptedType(
type_in=db.String(255),
key=secret_key,
engine=NoneAesEngine,
),
nullable=True,
)
expires = db.Column(db.DateTime, nullable=True)
_scopes = db.Column(db.Text)
is_personal = db.Column(db.Boolean(name='is_personal'), default=False)
"""Personal accesss token."""
is_internal = db.Column(db.Boolean(name='is_internal'), default=False)
"""Determines if token is an internally generated token."""
@property
def scopes(self):
"""Return all scopes.
:returns: A list of scopes.
"""
if self._scopes:
return self._scopes.split()
return []
@scopes.setter
def scopes(self, scopes):
"""Set scopes.
:param scopes: The list of scopes.
"""
validate_scopes(scopes)
self._scopes = " ".join(set(scopes)) if scopes else ""
[docs] def get_visible_scopes(self):
"""Get list of non-internal scopes for token.
:returns: A list of scopes.
"""
return [k for k, s in current_oauth2server.scope_choices()
if k in self.scopes]
@classmethod
[docs] def create_personal(cls, name, user_id, scopes=None, is_internal=False):
"""Create a personal access token.
A token that is bound to a specific user and which doesn't expire, i.e.
similar to the concept of an API key.
:param name: Client name.
:param user_id: User ID.
:param scopes: The list of permitted scopes. (Default: ``None``)
:param is_internal: If ``True`` it's a internal access token.
(Default: ``False``)
:returns: A new access token.
"""
with db.session.begin_nested():
scopes = " ".join(scopes) if scopes else ""
c = Client(
name=name,
user_id=user_id,
is_internal=True,
is_confidential=False,
_default_scopes=scopes
)
c.gen_salt()
t = Token(
client_id=c.client_id,
user_id=user_id,
access_token=gen_salt(
current_app.config.get(
'OAUTH2SERVER_TOKEN_PERSONAL_SALT_LEN')
),
expires=None,
_scopes=scopes,
is_personal=True,
is_internal=is_internal,
)
db.session.add(c)
db.session.add(t)
return t