# Source code for concurrency.fields

from __future__ import absolute_import, unicode_literals
import copy
import logging
import time
from functools import update_wrapper
from django.db.models.fields import Field
from django.utils.translation import ugettext_lazy as _
from concurrency import forms
from concurrency.api import get_revision_of_object
from concurrency.config import conf
from concurrency.core import ConcurrencyOptions
from concurrency.utils import refetch

try:
    from django.apps import apps

    get_model = apps.get_model
except ImportError:
    from django.db.models.loading import get_model

try:
    from django.db.models.signals import class_prepared, post_migrate
except ImportError:
    # ``post_migrate`` is not importable on this Django version: fall back to
    # ``post_syncdb`` and expose it under the same name. Only ImportError is
    # caught so that unrelated failures are not silently hidden.
    from django.db.models.signals import class_prepared, post_syncdb as post_migrate

logger = logging.getLogger(__name__)

# Seconds since the epoch for 2000-01-01 00:00:00 (time.mktime interprets the
# tuple as local time). IntegerVersionField subtracts this value from the
# timestamp it uses as version number.
OFFSET = int(time.mktime((2000, 1, 1, 0, 0, 0, 0, 0, 0)))


def class_prepared_concurrency_handler(sender, **kwargs):
    """Finish the concurrency set-up of a freshly prepared model class.

    Connected to the ``class_prepared`` signal. Classes without a
    ``_concurrencymeta`` attribute are ignored. For the others:

    * a class that is not the one the options were created for receives
      its own deep copy of the base options, so later changes do not leak
      to the base class;
    * an ``enabled`` attribute declared on an inner ``ConcurrencyMeta``
      class overrides the ``enabled`` option; when ``ConcurrencyMeta`` does
      not declare it the current value is kept (it used to raise
      ``AttributeError``);
    * unless the options are flagged ``manually``, the version field wraps
      the model methods;
    * ``get_concurrency_version`` is attached to the class.

    :param sender: the model class that has just been prepared
    :param kwargs: extra signal arguments (unused)
    """
    if not hasattr(sender, '_concurrencymeta'):
        return

    if sender != sender._concurrencymeta.base:
        origin = getattr(sender._concurrencymeta.base, '_concurrencymeta')
        # Give the subclass private options instead of sharing the base ones.
        setattr(sender, '_concurrencymeta', copy.deepcopy(origin))

    user_meta = getattr(sender, 'ConcurrencyMeta', None)
    # Only override the option when it is explicitly declared.
    if hasattr(user_meta, 'enabled'):
        sender._concurrencymeta.enabled = user_meta.enabled

    if not sender._concurrencymeta.manually:
        sender._concurrencymeta.field.wrap_model(sender)

    setattr(sender, 'get_concurrency_version', get_revision_of_object)


def post_syncdb_concurrency_handler(sender, **kwargs):
    """Create the concurrency triggers on every configured database.

    Signal receiver: the aliases yielded by ``django.db.connections`` are
    collected into a list and handed to ``create_triggers``.

    :param sender: the signal sender (unused)
    :param kwargs: extra signal arguments (unused)
    """
    # Imported lazily, at call time, exactly as the module has always done.
    from concurrency.triggers import create_triggers
    from django.db import connections

    create_triggers(list(connections))


# Run the concurrency set-up for every model class Django prepares; the
# dispatch_uid keeps the receiver from being registered more than once.
class_prepared.connect(class_prepared_concurrency_handler, dispatch_uid='class_prepared_concurrency_handler')


class TriggerRegistry(object):
    """Registry of version fields, stored as ``[app_label, model name]`` pairs.

    Only the pair identifying the owning model is kept; the field itself is
    looked up again through ``get_model`` each time the registry is iterated.
    The backing list is a class attribute, hence shared by every instance.
    """
    # FIXME: this is very bad. it seems required only by tests
    # see
    # https://github.com/pytest-dev/pytest-django/issues/75
    # https://code.djangoproject.com/ticket/22280#comment:20

    _fields = []

    @staticmethod
    def _key(field):
        """Return the ``[app_label, model name]`` pair identifying ``field``."""
        owner = field.model
        return [owner._meta.app_label, owner.__name__]

    def append(self, field):
        """Record the model owning ``field`` (no duplicate check is done)."""
        self._fields.append(self._key(field))

    def __iter__(self):
        """Iterate over the version field of every registered model."""
        resolved = []
        for app_label, model_name in self._fields:
            model = get_model(app_label, model_name)
            resolved.append(model._concurrencymeta.field)
        return iter(resolved)

    def __contains__(self, field):
        """Tell whether the model owning ``field`` is already registered."""
        return self._key(field) in self._fields


# Module-level registry filled by TriggerVersionField.contribute_to_class.
# A TriggerRegistry is used here instead of a plain list of fields.
_TRIGGERS = TriggerRegistry()

# Unless triggers are managed manually, create them each time the
# post_migrate (or post_syncdb) signal fires.
if not conf.MANUAL_TRIGGERS:
    post_migrate.connect(post_syncdb_concurrency_handler, dispatch_uid='post_syncdb_concurrency_handler')


class VersionField(Field):
    """Base class of the record version (revision number) fields.

    Subclasses are expected to provide ``form_class`` and
    ``_get_next_version(model_instance)``: both are used by this class but
    neither is defined here.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the field from a fixed subset of keyword options.

        Only ``verbose_name``, ``name``, ``db_tablespace``, ``db_column`` and
        ``help_text`` are read from ``kwargs``. Positional arguments and any
        other keyword are ignored, and the default is always ``0``.
        """
        verbose_name = kwargs.get('verbose_name', None)
        name = kwargs.get('name', None)
        db_tablespace = kwargs.get('db_tablespace', None)
        db_column = kwargs.get('db_column', None)
        help_text = kwargs.get('help_text', _('record revision number'))

        super(VersionField, self).__init__(verbose_name, name,
                                           help_text=help_text,
                                           default=0,
                                           db_tablespace=db_tablespace,
                                           db_column=db_column)

    def deconstruct(self):
        """Return the parent's deconstruction with ``default`` forced to 1.

        NOTE(review): ``__init__`` never reads ``default`` and always uses 0,
        so this value looks purely informative -- confirm.
        """
        name, path, args, kwargs = super(VersionField, self).deconstruct()
        kwargs['default'] = 1
        return name, path, args, kwargs

    def get_default(self):
        """Return the default version, always ``0``."""
        return 0

    def get_internal_type(self):
        """Return the internal type name, ``"BigIntegerField"``."""
        return "BigIntegerField"

    def to_python(self, value):
        """Convert ``value`` with ``int()``; unconvertible values raise."""
        return int(value)

    def validate(self, value, model_instance):
        """Accept any value: no validation is performed."""
        pass

    def formfield(self, **kwargs):
        """Build the form field.

        ``form_class`` and ``widget`` supplied by the caller are overridden
        with ``self.form_class`` and the ``forms.VersionField`` widget.
        """
        kwargs['form_class'] = self.form_class
        kwargs['widget'] = forms.VersionField.widget
        return super(VersionField, self).formfield(**kwargs)

    def contribute_to_class(self, cls, name, virtual_only=False):
        """Attach the field to ``cls`` and create its concurrency options.

        The options are created only when ``cls`` has no ``_concurrencymeta``
        yet and is not abstract; they record this field and ``cls`` as base.
        ``virtual_only`` is accepted but not forwarded to the parent class.
        """
        super(VersionField, self).contribute_to_class(cls, name)
        if hasattr(cls, '_concurrencymeta') or cls._meta.abstract:
            return
        setattr(cls, '_concurrencymeta', ConcurrencyOptions())
        cls._concurrencymeta.field = self
        cls._concurrencymeta.base = cls

    def _set_version_value(self, model_instance, value):
        """Store ``int(value)`` as the version of ``model_instance``."""
        setattr(model_instance, self.attname, int(value))

    def pre_save(self, model_instance, add):
        """Return the version to save, computing a new one only on insert."""
        if add:
            value = self._get_next_version(model_instance)
            self._set_version_value(model_instance, value)
        return getattr(model_instance, self.attname)

    @classmethod
    def wrap_model(cls, model, force=False):
        """Wrap the methods of ``model`` once.

        Nothing is done when the model is already flagged ``versioned_save``
        unless ``force`` is true.
        """
        if not force and model._concurrencymeta.versioned_save:
            return
        cls._wrap_model_methods(model)
        model._concurrencymeta.versioned_save = True

    @staticmethod
    def _wrap_model_methods(model):
        """Replace ``model._do_update`` with its version-checking wrapper."""
        old_do_update = getattr(model, '_do_update')
        setattr(model, '_do_update', model._concurrencymeta.field._wrap_do_update(old_do_update))

    def _wrap_do_update(self, func):
        """Return a replacement for ``_do_update`` that checks the version.

        :param func: the original ``_do_update`` being wrapped
        :return: the wrapper, carrying the metadata of ``func``
        """

        def _do_update(model_instance, base_qs, using, pk_val, values, update_fields, forced_update):
            version_field = model_instance._concurrencymeta.field
            # Version currently held by the instance, read before it is bumped.
            old_version = get_revision_of_object(model_instance)

            # The update targets a model other than the one owning the
            # version field: delegate to the original implementation.
            if not version_field.model._meta.abstract and version_field.model is not base_qs.model:
                return func(model_instance, base_qs, using, pk_val, values, update_fields, forced_update)

            # Put the next version both in the values to write and on the
            # instance itself.
            for i, (field, _1, _unused) in enumerate(values):
                if field == version_field:
                    new_version = field._get_next_version(model_instance)
                    values[i] = (field, _1, new_version)
                    field._set_version_value(model_instance, new_version)
                    break

            if values:
                if (model_instance._concurrencymeta.enabled and
                        conf.ENABLED and
                        not getattr(model_instance, '_concurrency_disabled', False) and
                        old_version):
                    # Only a row still holding the old version may be updated.
                    filter_kwargs = {'pk': pk_val, version_field.attname: old_version}
                    updated = base_qs.filter(**filter_kwargs)._update(values) >= 1
                    if not updated:
                        # Nothing matched: restore the old version on the
                        # instance and let the configured callback decide.
                        version_field._set_version_value(model_instance, old_version)
                        updated = conf._callback(model_instance)
                else:
                    # Check disabled, or falsy old version: plain update by pk.
                    filter_kwargs = {'pk': pk_val}
                    updated = base_qs.filter(**filter_kwargs)._update(values) >= 1
            else:
                updated = base_qs.filter(pk=pk_val).exists()

            return updated

        return update_wrapper(_do_update, func)
class IntegerVersionField(VersionField):
    """
    Version Field that returns a "unique" version number for the record.

    The version number is produced using time.time() * 1000000, to get the
    benefits of microsecond if the system clock provides them.
    """
    form_class = forms.VersionField

    def _get_next_version(self, model_instance):
        """Return the next version for ``model_instance``.

        The result is the larger of the current version plus one and the
        current timestamp, so the version never decreases.
        """
        old_value = getattr(model_instance, self.attname, 0)
        # NOTE(review): OFFSET is expressed in seconds while the timestamp is
        # in microseconds -- confirm the subtraction is intended as written.
        return max(int(old_value) + 1, (int(time.time() * 1000000) - OFFSET))
class AutoIncVersionField(VersionField):
    """
    Version Field increment the revision number each commit
    """
    form_class = forms.VersionField

    def _get_next_version(self, model_instance):
        """Return the current version plus one (1 when none is set yet)."""
        return int(getattr(model_instance, self.attname, 0)) + 1
class TriggerVersionField(VersionField):
    """
    Version Field increment the revision number each commit

    The Python side never computes a new value to store (``pre_save`` always
    returns 1); the stored revision is presumably changed by the database
    trigger that ``check`` looks for -- confirm against concurrency.triggers.
    """
    form_class = forms.VersionField

    def __init__(self, *args, **kwargs):
        """Accept the extra ``trigger_name`` keyword, then defer to the base."""
        self._trigger_name = kwargs.pop('trigger_name', None)
        self._trigger_exists = False
        super(TriggerVersionField, self).__init__(*args, **kwargs)

    def contribute_to_class(self, cls, name, virtual_only=False):
        """Attach the field to ``cls`` and register it in ``_TRIGGERS``."""
        super(TriggerVersionField, self).contribute_to_class(cls, name)
        # NOTE(review): this reads "(not abstract) or proxy"; if
        # "not (abstract or proxy)" was meant the test is wrong -- confirm.
        if not cls._meta.abstract or cls._meta.proxy:
            if self not in _TRIGGERS:
                _TRIGGERS.append(self)

    def check(self, **kwargs):
        """Run the base field checks and warn when the trigger is missing.

        :return: the messages of the parent ``check`` plus a
                 ``concurrency.W001`` warning when no trigger is found on
                 the database used for writes
        """
        # Keep the messages produced by the base class instead of
        # discarding them.
        errors = list(super(TriggerVersionField, self).check(**kwargs))
        model = self.model
        from django.db import router, connections
        from concurrency.triggers import factory
        from django.core.checks import Warning

        alias = router.db_for_write(model)
        connection = connections[alias]
        f = factory(connection)
        if not f.get_trigger(self):
            errors.append(
                Warning(
                    'Missed trigger for field {}'.format(self),
                    hint=None,
                    obj=None,
                    id='concurrency.W001',
                )
            )
        return errors

    @property
    def trigger_name(self):
        """Name of the trigger, as computed by ``get_trigger_name``."""
        from concurrency.triggers import get_trigger_name
        return get_trigger_name(self)

    def _get_next_version(self, model_instance):
        """Return the current version unchanged (1 when none is set)."""
        # always returns the same value
        return int(getattr(model_instance, self.attname, 1))

    def pre_save(self, model_instance, add):
        """Return the constant 1, whatever the instance holds."""
        # always returns the same value
        return 1

    @staticmethod
    def _increment_version_number(obj):
        """Add one to the version held in memory by ``obj``."""
        old_value = get_revision_of_object(obj)
        setattr(obj, obj._concurrencymeta.field.attname, int(old_value) + 1)

    @staticmethod
    def _wrap_model_methods(model):
        """Wrap ``_do_update`` like the base class, then wrap ``save`` too."""
        super(TriggerVersionField, TriggerVersionField)._wrap_model_methods(model)
        old_save = getattr(model, 'save')
        setattr(model, 'save', model._concurrencymeta.field._wrap_save(old_save))

    @staticmethod
    def _wrap_save(func):
        """Return a ``save`` replacement that keeps the in-memory version fresh.

        The wrapper accepts an extra ``refetch`` keyword. After the original
        ``save`` the in-memory version is incremented; when ``refetch`` is
        true the object is fetched again, its version is copied onto the
        instance and the refetched object is returned instead of the result
        of ``save``.
        """

        def inner(self, force_insert=False, force_update=False, using=None, **kwargs):
            do_refetch = kwargs.pop('refetch', False)
            ret = func(self, force_insert, force_update, using, **kwargs)
            TriggerVersionField._increment_version_number(self)
            if do_refetch:
                ret = refetch(self)
                setattr(self,
                        self._concurrencymeta.field.attname,
                        get_revision_of_object(ret))
            return ret

        return update_wrapper(inner, func)


try:
    from south.modelsinspector import add_introspection_rules

    rules = [
        (
            (IntegerVersionField, AutoIncVersionField, TriggerVersionField),
            [],
            {"verbose_name": ["verbose_name", {"default": None}],
             "name": ["name", {"default": None}],
             "help_text": ["help_text", {"default": ''}],
             "db_column": ["db_column", {"default": None}],
             "db_tablespace": ["db_tablespace", {"default": None}],
             "default": ["default", {"default": 1}],
             "manually": ["manually", {"default": False}]})
    ]

    # The patterns must cover every class listed in the rule above:
    # TriggerVersionField was missing from this list.
    add_introspection_rules(rules, [r"^concurrency\.fields\.IntegerVersionField",
                                    r"^concurrency\.fields\.AutoIncVersionField",
                                    r"^concurrency\.fields\.TriggerVersionField"])
except ImportError as e:
    from django.conf import settings

    # South is optional: its absence is an error only when the project
    # lists it in INSTALLED_APPS.
    if 'south' in settings.INSTALLED_APPS:
        raise e