Source code for invenio_records_rest.views

# -*- coding: utf-8 -*-
#
# This file is part of Invenio.
# Copyright (C) 2015, 2016 CERN.
#
# Invenio is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public License as
# published by the Free Software Foundation; either version 2 of the
# License, or (at your option) any later version.
#
# Invenio is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Invenio; if not, write to the
# Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
# MA 02111-1307, USA.
#
# In applying this license, CERN does not
# waive the privileges and immunities granted to it by virtue of its status
# as an Intergovernmental Organization or submit itself to any jurisdiction.

"""REST API resources."""

from __future__ import absolute_import, print_function

import copy
import uuid
from functools import partial, wraps

from elasticsearch.exceptions import RequestError
from flask import Blueprint, abort, current_app, jsonify, make_response, \
    request, url_for
from flask.views import MethodView
from flask_babelex import gettext as _
from invenio_db import db
from invenio_pidstore import current_pidstore
from invenio_pidstore.models import PersistentIdentifier
from invenio_pidstore.resolver import Resolver
from invenio_records.api import Record
from invenio_rest import ContentNegotiatedMethodView
from invenio_rest.decorators import require_content_types
from invenio_rest.errors import RESTValidationError
from invenio_search import RecordsSearch
from jsonpatch import JsonPatchException, JsonPointerException
from jsonschema.exceptions import ValidationError
from sqlalchemy.exc import SQLAlchemyError

from .errors import InvalidDataRESTError, InvalidQueryRESTError, \
    MaxResultWindowRESTError, PatchJSONFailureRESTError, PIDResolveRESTError, \
    SuggestMissingContextRESTError, SuggestNoCompletionsRESTError, \
    UnsupportedMediaRESTError
from .links import default_links_factory
from .proxies import current_records_rest
from .query import default_search_factory
from .utils import obj_or_import_string


def elasticsearch_query_parsing_exception_handler(error):
    """Build the REST response for an Elasticsearch query parsing failure.

    :param error: The exception that triggered the handler. It is not
        inspected; a fixed, translated message is always returned.
    :returns: The response produced by ``InvalidQueryRESTError``.
    """
    return InvalidQueryRESTError(
        description=_('The syntax of the search query is invalid.')
    ).get_response()


def create_error_handlers(blueprint):
    """Create error handlers on blueprint.

    Two handlers are registered: one for ``ValidationError`` and one for
    Elasticsearch ``RequestError``.

    :param blueprint: Records API blueprint.
    :returns: Configured blueprint.
    """
    # Catch record validation errors
    @blueprint.errorhandler(ValidationError)
    def validation_error(error):
        """Catch validation errors."""
        return RESTValidationError().get_response()

    @blueprint.errorhandler(RequestError)
    def elasticsearch_badrequest_error(error):
        """Catch errors of ElasticSearch.

        The root-cause types reported in ``error.info`` are matched against
        the handlers configured under
        ``RECORDS_REST_ELASTICSEARCH_ERROR_HANDLERS``; the first configured
        handler whose cause type is present builds the response.

        :param error: The ``RequestError`` being handled.
        :returns: The response of the matching configured handler.
        :raises RequestError: The original error, when no configured handler
            matches any of its root causes.
        """
        handlers = current_app.config[
            'RECORDS_REST_ELASTICSEARCH_ERROR_HANDLERS']
        cause_types = {c['type'] for c in error.info['error']['root_cause']}

        for cause_type, handler in handlers.items():
            if cause_type in cause_types:
                return handler(error)

        # No handler matched. An exception instance is not a valid response
        # object, so returning it would only trigger a secondary, unrelated
        # failure while the response is built. Re-raise the original error so
        # it is reported as what it actually is.
        raise error

    return blueprint


def create_blueprint(endpoints):
    """Create Invenio-Records-REST blueprint.

    Every URL rule generated by ``create_url_rules`` for each configured
    endpoint is added to the blueprint, then the error handlers are
    installed on it.

    :param endpoints: Dictionary representing the endpoints configuration
        (endpoint name -> keyword options for ``create_url_rules``). A falsy
        value is treated as "no endpoints".
    :returns: Configured blueprint.
    """
    records_blueprint = Blueprint(
        'invenio_records_rest', __name__, url_prefix='')

    endpoint_config = endpoints or {}
    for endpoint_name, endpoint_options in endpoint_config.items():
        url_rules = create_url_rules(endpoint_name, **endpoint_options)
        for url_rule in url_rules:
            records_blueprint.add_url_rule(**url_rule)

    return create_error_handlers(records_blueprint)


def create_url_rules(endpoint, list_route=None, item_route=None,
                     pid_type=None, pid_minter=None, pid_fetcher=None,
                     read_permission_factory_imp=None,
                     create_permission_factory_imp=None,
                     update_permission_factory_imp=None,
                     delete_permission_factory_imp=None,
                     record_class=None,
                     record_serializers=None,
                     record_loaders=None,
                     search_class=None,
                     search_serializers=None,
                     search_index=None, search_type=None,
                     default_media_type=None,
                     max_result_window=None, use_options_view=True,
                     search_factory_imp=None, links_factory_imp=None,
                     suggesters=None, default_endpoint_prefix=None):
    """Create Werkzeug URL rules.

    :param endpoint: Name of endpoint.
    :param list_route: Record listing URL route. Required.
    :param item_route: Record URL route (must include ``<pid_value>``
        pattern). Required.
    :param pid_type: Persistent identifier type for endpoint. Required.
    :param pid_minter: Name of the registered minter.
    :param pid_fetcher: Name of the registered fetcher.
    :param read_permission_factory_imp: Import path to factory that creates
        a read permission object for a given record.
    :param create_permission_factory_imp: Import path to factory that
        creates a create permission object for a given record.
    :param update_permission_factory_imp: Import path to factory that
        creates an update permission object for a given record.
    :param delete_permission_factory_imp: Import path to factory that
        creates a delete permission object for a given record.
    :param default_endpoint_prefix: Ignored.
    :param record_class: A record API class or importable string.
    :param record_serializers: Serializers used for records. Required.
    :param record_loaders: Record deserializers for the supported formats,
        keyed by MIME type.
    :param search_class: Import path or class object for the object in
        charge of executing the search queries. The default search class is
        :class:`invenio_search.api.RecordsSearch`.
    :param search_serializers: Serializers used for search results.
        Required.
    :param search_index: Name of the search index used when searching
        records.
    :param search_type: Name of the search type used when searching records.
    :param default_media_type: Default media type for both records and
        search.
    :param max_result_window: Maximum number of results that Elasticsearch
        can provide for the given search index without use of scroll. This
        value should correspond to the Elasticsearch
        ``index.max_result_window`` value for the index.
    :param use_options_view: Determines if a special option view should be
        installed.
    :param search_factory_imp: Factory to parse queries.
    :param links_factory_imp: Factory for record links generation.
    :param suggesters: Suggester fields configuration.

    :returns: A list of dictionaries which can each be passed as keyword
        arguments to ``Blueprint.add_url_rule``.
    """
    # Mandatory settings, checked in the documented order.
    for mandatory in (list_route, item_route, pid_type,
                      search_serializers, record_serializers):
        assert mandatory

    def import_all(mapping):
        """Resolve every value of ``mapping`` to an object."""
        return {mime: obj_or_import_string(func)
                for mime, func in mapping.items()}

    read_permission_factory = obj_or_import_string(
        read_permission_factory_imp)
    create_permission_factory = obj_or_import_string(
        create_permission_factory_imp)
    update_permission_factory = obj_or_import_string(
        update_permission_factory_imp)
    delete_permission_factory = obj_or_import_string(
        delete_permission_factory_imp)
    links_factory = obj_or_import_string(
        links_factory_imp, default=default_links_factory)
    record_class = obj_or_import_string(record_class, default=Record)
    search_class = obj_or_import_string(search_class, default=RecordsSearch)

    # Explicit index/doc type are bound onto the search class; when absent
    # the values declared on the search class ``Meta`` are read instead.
    bound_search_kwargs = {}
    if search_index:
        bound_search_kwargs['index'] = search_index
    else:
        search_index = search_class.Meta.index

    if search_type:
        bound_search_kwargs['doc_type'] = search_type
    else:
        search_type = search_class.Meta.doc_types

    if bound_search_kwargs:
        search_class = partial(search_class, **bound_search_kwargs)

    if record_loaders:
        record_loaders = import_all(record_loaders)
    record_serializers = import_all(record_serializers)
    search_serializers = import_all(search_serializers)

    resolver = Resolver(
        pid_type=pid_type,
        object_type='rec',
        getter=partial(record_class.get_record, with_deleted=True),
    )

    search_factory = obj_or_import_string(
        search_factory_imp, default=default_search_factory)

    list_view = RecordsListResource.as_view(
        RecordsListResource.view_name.format(endpoint),
        resolver=resolver,
        minter_name=pid_minter,
        pid_type=pid_type,
        pid_fetcher=pid_fetcher,
        read_permission_factory=read_permission_factory,
        create_permission_factory=create_permission_factory,
        record_serializers=record_serializers,
        record_loaders=record_loaders,
        search_serializers=search_serializers,
        search_class=search_class,
        default_media_type=default_media_type,
        max_result_window=max_result_window,
        search_factory=search_factory,
        item_links_factory=links_factory,
        record_class=record_class,
    )
    item_view = RecordResource.as_view(
        RecordResource.view_name.format(endpoint),
        resolver=resolver,
        read_permission_factory=read_permission_factory,
        update_permission_factory=update_permission_factory,
        delete_permission_factory=delete_permission_factory,
        serializers=record_serializers,
        loaders=record_loaders,
        search_class=search_class,
        links_factory=links_factory,
        default_media_type=default_media_type,
    )

    views = [
        {'rule': list_route, 'view_func': list_view},
        {'rule': item_route, 'view_func': item_view},
    ]

    if suggesters:
        suggest_view = SuggestResource.as_view(
            SuggestResource.view_name.format(endpoint),
            suggesters=suggesters,
            search_class=search_class,
        )
        views.append(
            {'rule': list_route + '_suggest', 'view_func': suggest_view})

    if not use_options_view:
        return views

    options_view = RecordsListOptionsResource.as_view(
        RecordsListOptionsResource.view_name.format(endpoint),
        search_index=search_index,
        max_result_window=max_result_window,
        default_media_type=default_media_type,
        search_media_types=search_serializers.keys(),
        item_media_types=record_serializers.keys(),
    )
    # The options rule is placed ahead of the list/item/suggest rules.
    options_rule = {
        'rule': "{0}_options".format(list_route),
        'view_func': options_view,
    }
    return [options_rule] + views
def pass_record(f):
    """Decorator to retrieve persistent identifier and record.

    The wrapped view receives ``pid`` and ``record`` keyword arguments, taken
    from the ``data`` attribute of the ``pid_value`` view argument.

    :param f: View method to wrap.
    :returns: The wrapping function.
    :raises PIDResolveRESTError: If a ``SQLAlchemyError`` is raised while
        resolving the identifier or while running the wrapped view.
    """
    @wraps(f)
    def inner(self, pid_value, *args, **kwargs):
        # Bound up front so the error path below can always refer to it,
        # even when resolving ``.data`` is what failed.
        pid = None
        try:
            pid, record = request.view_args['pid_value'].data
            return f(self, pid=pid, record=record, *args, **kwargs)
        except SQLAlchemyError:
            # Report the resolved PID when available, otherwise the raw
            # value taken from the URL.
            raise PIDResolveRESTError(pid if pid is not None else pid_value)
    return inner


def verify_record_permission(permission_factory, record):
    """Check that the current user has the required permissions on record.

    In case the permission check fails, a Flask abort is launched.
    If the user is not authenticated, an HTTP error 401 is returned.
    Otherwise an HTTP error 403 is returned.

    :param permission_factory: permission factory used to check permissions.
    :param record: record whose access is limited.
    """
    # Note, cannot be done in one line due overloading of boolean
    # operations permission object.
    if not permission_factory(record=record).can():
        from flask_login import current_user
        if not current_user.is_authenticated:
            abort(401)
        abort(403)


def need_record_permission(factory_name):
    """Decorator checking that the user has the required permissions on record.

    The factory is looked up by name on the view instance first and, when
    that attribute is falsy, on ``current_records_rest``. If no factory is
    found the view runs without a permission check.

    :param factory_name: name of the factory to retrieve.
    :returns: A decorator for view methods.
    """
    def need_record_permission_builder(f):
        @wraps(f)
        def need_record_permission_decorator(self, record=None, *args,
                                             **kwargs):
            permission_factory = (
                getattr(self, factory_name) or
                getattr(current_records_rest, factory_name)
            )

            # FIXME use context instead
            request._methodview = self

            if permission_factory:
                verify_record_permission(permission_factory, record)
            return f(self, record=record, *args, **kwargs)
        return need_record_permission_decorator
    return need_record_permission_builder


class RecordsListOptionsResource(MethodView):
    """Resource for displaying options about records list/item views."""

    view_name = '{0}_list_options'

    def __init__(self, search_index=None, max_result_window=None,
                 default_media_type=None, search_media_types=None,
                 item_media_types=None):
        """Initialize method view.

        :param search_index: Key used to look up the sort options.
        :param max_result_window: Reported result window; a falsy value
            falls back to 10000.
        :param default_media_type: Reported default media type.
        :param search_media_types: Media types available for search results.
        :param item_media_types: Media types available for single records.
        """
        self.search_index = search_index
        self.max_result_window = max_result_window or 10000
        self.default_media_type = default_media_type
        self.item_media_types = item_media_types
        self.search_media_types = search_media_types

    def get(self):
        """Get options.

        :returns: A JSON response with the sort fields configured for the
            search index (ordered by their ``order`` value), the result
            window, the default media type and the sorted media type lists.
        """
        opts = current_app.config['RECORDS_REST_SORT_OPTIONS'].get(
            self.search_index)

        sort_fields = []
        if opts:
            for key, item in sorted(opts.items(), key=lambda x: x[1]['order']):
                sort_fields.append(
                    {key: dict(
                        title=item['title'],
                        default_order=item.get('default_order', 'asc'))}
                )

        return jsonify(dict(
            sort_fields=sort_fields,
            max_result_window=self.max_result_window,
            default_media_type=self.default_media_type,
            search_media_types=sorted(self.search_media_types),
            item_media_types=sorted(self.item_media_types),
        ))


class RecordsListResource(ContentNegotiatedMethodView):
    """Resource for records listing."""

    view_name = '{0}_list'

    def __init__(self, resolver=None, minter_name=None, pid_type=None,
                 pid_fetcher=None, read_permission_factory=None,
                 create_permission_factory=None, search_class=None,
                 record_serializers=None,
                 record_loaders=None,
                 search_serializers=None, default_media_type=None,
                 max_result_window=None, search_factory=None,
                 item_links_factory=None, record_class=None, **kwargs):
        """Constructor.

        Search serializers are used for ``GET`` and record serializers for
        ``POST``. The create permission factory and the loaders fall back to
        the ones exposed by ``current_records_rest`` when not given, the
        result window falls back to 10000 and the record class to ``Record``.
        """
        super(RecordsListResource, self).__init__(
            method_serializers={
                'GET': search_serializers,
                'POST': record_serializers,
            },
            default_method_media_type={
                'GET': default_media_type,
                'POST': default_media_type,
            },
            default_media_type=default_media_type,
            **kwargs)
        self.resolver = resolver
        self.pid_type = pid_type
        self.minter = current_pidstore.minters[minter_name]
        self.pid_fetcher = current_pidstore.fetchers[pid_fetcher]
        self.read_permission_factory = read_permission_factory
        self.create_permission_factory = create_permission_factory or \
            current_records_rest.create_permission_factory
        self.search_class = search_class
        self.max_result_window = max_result_window or 10000
        # The search factory receives this view as its first argument.
        self.search_factory = partial(search_factory, self)
        self.item_links_factory = item_links_factory
        self.loaders = record_loaders or \
            current_records_rest.loaders
        self.record_class = record_class or Record

    def get(self, **kwargs):
        """Search records.

        The ``page`` (default 1) and ``size`` (default 10) request values
        select the slice of results; ``self``, ``prev`` and ``next`` links
        are generated for the response.

        :returns: the search result containing hits and aggregations as
            returned by invenio-search.
        :raises MaxResultWindowRESTError: If ``page * size`` reaches the
            maximum result window.
        """
        page = request.values.get('page', 1, type=int)
        size = request.values.get('size', 10, type=int)
        if page * size >= self.max_result_window:
            raise MaxResultWindowRESTError()

        # Arguments that must be added in prev/next links
        urlkwargs = dict()
        search = self.search_class().params(version=True)
        search = search[(page - 1) * size:page * size]

        search, qs_kwargs = self.search_factory(search)
        urlkwargs.update(qs_kwargs)

        # Execute search
        search_result = search.execute()

        # Generate links for prev/next
        urlkwargs.update(
            size=size,
            _external=True,
        )
        endpoint = '.{0}_list'.format(
            current_records_rest.default_endpoint_prefixes[self.pid_type])
        links = dict(self=url_for(endpoint, page=page, **urlkwargs))
        if page > 1:
            links['prev'] = url_for(endpoint, page=page - 1, **urlkwargs)
        if size * page < search_result.hits.total and \
                size * page < self.max_result_window:
            links['next'] = url_for(endpoint, page=page + 1, **urlkwargs)

        return self.make_response(
            pid_fetcher=self.pid_fetcher,
            search_result=search_result.to_dict(),
            links=links,
            item_links_factory=self.item_links_factory,
        )

    @need_record_permission('create_permission_factory')
    def post(self, **kwargs):
        """Create a record.

        Procedure description:

        #. First of all, the `create_permission_factory` permissions are
            checked.

        #. Then, the record is deserialized by the proper loader.

        #. A second call to the `create_permission_factory` factory is done:
            it differs from the previous call because this time the record is
            passed as parameter.

        #. A `uuid` is generated for the record and the minter is called.

        #. The record class is called to create the record.

        #. The HTTP response is built with the help of the item link factory.

        :returns: The created record.
        :raises UnsupportedMediaRESTError: If no loader is registered for the
            request MIME type.
        :raises InvalidDataRESTError: If the loader returns ``None``.
        """
        if request.mimetype not in self.loaders:
            raise UnsupportedMediaRESTError(request.mimetype)

        data = self.loaders[request.mimetype]()
        if data is None:
            raise InvalidDataRESTError()

        # Check permissions
        permission_factory = self.create_permission_factory
        if permission_factory:
            verify_record_permission(permission_factory, data)

        # Create uuid for record
        record_uuid = uuid.uuid4()
        # Create persistent identifier
        pid = self.minter(record_uuid, data=data)
        # Create record
        record = self.record_class.create(data, id_=record_uuid)

        db.session.commit()

        response = self.make_response(
            pid, record, 201, links_factory=self.item_links_factory)

        # Add location headers
        endpoint = '.{0}_item'.format(
            current_records_rest.default_endpoint_prefixes[pid.pid_type])
        location = url_for(endpoint, pid_value=pid.pid_value, _external=True)
        response.headers.extend(dict(location=location))
        return response


class RecordResource(ContentNegotiatedMethodView):
    """Resource for record items."""

    view_name = '{0}_item'

    def __init__(self, resolver=None, read_permission_factory=None,
                 update_permission_factory=None,
                 delete_permission_factory=None, default_media_type=None,
                 links_factory=None,
                 loaders=None, search_class=None,
                 **kwargs):
        """Constructor.

        :param resolver: Persistent identifier resolver instance.
        :param loaders: Request body loaders keyed by MIME type; falls back
            to ``current_records_rest.loaders`` when falsy.
        """
        super(RecordResource, self).__init__(
            method_serializers={
                'DELETE': {'*/*': lambda *args: make_response(*args), },
            },
            default_method_media_type={
                'GET': default_media_type,
                'PUT': default_media_type,
                'DELETE': '*/*',
                'PATCH': default_media_type,
            },
            default_media_type=default_media_type,
            **kwargs)
        self.resolver = resolver
        self.search_class = search_class
        self.read_permission_factory = read_permission_factory
        self.update_permission_factory = update_permission_factory
        self.delete_permission_factory = delete_permission_factory
        self.links_factory = links_factory
        self.loaders = loaders or current_records_rest.loaders

    @pass_record
    @need_record_permission('delete_permission_factory')
    def delete(self, pid, record, **kwargs):
        """Delete a record.

        Procedure description:

        #. The record is resolved reading the pid value from the url.

        #. The ETag is checked.

        #. The record is deleted.

        #. All PIDs are marked as DELETED.

        :param pid: Persistent identifier for record.
        :param record: Record object.
        :returns: An empty body with status code 204.
        """
        # NOTE(review): this checks ``record.model.version_id`` while the
        # other methods check ``record.revision_id`` -- confirm intended.
        self.check_etag(str(record.model.version_id))

        record.delete()
        # mark all PIDs as DELETED
        all_pids = PersistentIdentifier.query.filter(
            PersistentIdentifier.object_type == pid.object_type,
            PersistentIdentifier.object_uuid == pid.object_uuid,
        ).all()
        for rec_pid in all_pids:
            if not rec_pid.is_deleted():
                rec_pid.delete()
        db.session.commit()
        return '', 204

    @pass_record
    @need_record_permission('read_permission_factory')
    def get(self, pid, record, **kwargs):
        """Get a record.

        Procedure description:

        #. The record is resolved reading the pid value from the url.

        #. The ETag and If-Modified-Since is checked.

        #. The HTTP response is built with the help of the link factory.

        :param pid: Persistent identifier for record.
        :param record: Record object.
        :returns: The requested record.
        """
        etag = str(record.revision_id)
        self.check_etag(etag)
        self.check_if_modified_since(record.updated, etag=etag)

        return self.make_response(
            pid, record, links_factory=self.links_factory
        )

    @require_content_types('application/json-patch+json')
    @pass_record
    @need_record_permission('update_permission_factory')
    def patch(self, pid, record, **kwargs):
        """Modify a record.

        The data should be a JSON-patch, which will be applied to the record.

        Procedure description:

        #. The record is deserialized using the proper loader.

        #. The ETag is checked.

        #. The record is patched.

        #. The HTTP response is built with the help of the link factory.

        :param pid: Persistent identifier for record.
        :param record: Record object.
        :returns: The modified record.
        :raises InvalidDataRESTError: If the loader returns ``None``.
        :raises PatchJSONFailureRESTError: If applying the patch fails.
        """
        data = self.loaders[request.mimetype]()
        if data is None:
            raise InvalidDataRESTError()

        self.check_etag(str(record.revision_id))
        try:
            record = record.patch(data)
        except (JsonPatchException, JsonPointerException):
            raise PatchJSONFailureRESTError()

        record.commit()
        db.session.commit()
        return self.make_response(
            pid, record, links_factory=self.links_factory)

    @pass_record
    @need_record_permission('update_permission_factory')
    def put(self, pid, record, **kwargs):
        """Replace a record.

        The body should be a JSON object, which will fully replace the current
        record metadata.

        Procedure description:

        #. The ETag is checked.

        #. The record is updated by calling the record API `clear()`,
           `update()` and then `commit()`.

        #. The HTTP response is built with the help of the link factory.

        :param pid: Persistent identifier for record.
        :param record: Record object.
        :returns: The modified record.
        :raises UnsupportedMediaRESTError: If no loader is registered for the
            request MIME type.
        :raises InvalidDataRESTError: If the loader returns ``None``.
        """
        if request.mimetype not in self.loaders:
            raise UnsupportedMediaRESTError(request.mimetype)

        data = self.loaders[request.mimetype]()
        if data is None:
            raise InvalidDataRESTError()

        self.check_etag(str(record.revision_id))

        record.clear()
        record.update(data)
        record.commit()
        db.session.commit()
        return self.make_response(
            pid, record, links_factory=self.links_factory)


class SuggestResource(MethodView):
    """Resource for records suggests."""

    view_name = '{0}_suggest'

    def __init__(self, suggesters, search_class=None, **kwargs):
        """Constructor.

        :param suggesters: Suggester configuration keyed by the request
            parameter name that triggers it.
        :param search_class: Callable returning the search object used to
            run the suggestions.
        """
        self.suggesters = suggesters
        self.search_class = search_class

    def get(self, **kwargs):
        """Get suggestions.

        One completion is requested for every configured suggester whose
        name is present, with a non-empty value, in the request values. An
        optional ``size`` request value is applied to each completion.

        :returns: A JSON response mapping each requested suggester name to
            its entry in the suggest response.
        :raises SuggestMissingContextRESTError: If a suggester declares a
            completion context whose value is missing from the request.
        :raises SuggestNoCompletionsRESTError: If no configured suggester was
            requested.
        """
        completions = []
        size = request.values.get('size', type=int)

        for k in self.suggesters:
            val = request.values.get(k, type=str)
            if val:
                # Get completion suggestions. Work on a copy so the shared
                # configuration is never modified by a request.
                opts = copy.deepcopy(self.suggesters[k])

                if 'context' in opts.get('completion', {}):
                    ctx_field = opts['completion']['context']
                    ctx_val = request.values.get(ctx_field, type=str)
                    if not ctx_val:
                        raise SuggestMissingContextRESTError
                    opts['completion']['context'] = {
                        ctx_field: ctx_val
                    }

                if size:
                    opts['completion']['size'] = size

                completions.append((k, val, opts))

        if not completions:
            raise SuggestNoCompletionsRESTError(
                ', '.join(sorted(self.suggesters)))

        # Add completions
        s = self.search_class()
        for field, val, opts in completions:
            s = s.suggest(field, val, **opts)

        # Execute search
        response = s.execute_suggest().to_dict()

        result = {field: response[field] for field, _val, _opts in completions}

        return make_response(jsonify(result))