Django: Clone a model instance on update

Often, we need to clone a Django model instance while updating its fields' values. For example, if we have a price model that needs updating, it needs to be cloned first so that the existing related objects keep referring to the current pricing values. So, while updating, we need to keep the current instance as-is, create a new instance from the existing data, and apply the updated data to the new instance instead.

If we don't do this, then whenever we change a relevant field's value, the related objects become subject to the new values. This in turn produces new pricing that differs from the originally calculated/saved one. This might cause serious reliability problems: for example, an initial price/charge of 20 USD silently becoming 30 USD.
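
To make the problem concrete, here is a minimal plain-Python sketch with no Django involved. `Price` and `Order` are hypothetical stand-ins for the models, and `dataclasses.replace` stands in for the clone step; it only illustrates the idea, not the actual implementation.

```python
from dataclasses import dataclass, replace


@dataclass
class Price:
    """Hypothetical stand-in for a pricing model row."""
    amount: int
    active: bool = True


@dataclass
class Order:
    """Hypothetical stand-in for an object that refers to a price."""
    price: Price

    def charge(self):
        return self.price.amount


# Updating in place: the existing order silently picks up the new amount.
price = Price(amount=20)
order = Order(price=price)
price.amount = 30
in_place_charge = order.charge()

# Cloning on update: the old record stays as-is, the copy gets the change.
old_price = Price(amount=20)
old_order = Order(price=old_price)
new_price = replace(old_price, amount=30)
old_price.active = False
cloned_charge = old_order.charge()

print(in_place_charge, cloned_charge)  # 30 20
```

The old order still charges 20 because it keeps pointing at the untouched record, while new orders can be attached to the clone.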

So, to handle this automatic cloning of instances on update, I've written a class that can be used as a decorator on the save method of the model. Here it is (all the necessary details are in the docstring):

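import types

from django.db import models

try:
    basestring  # Python 2
except NameError:
    basestring = str  # Python 3


class clone_on_update(object):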
    """Descriptor to clone the model instance when any field
    is updated. This is meant to be set as a decorator on `save`
    method of the model.

    Initialization Args:
        - `condition` (str/callable) [Optional] (default: None)
        - `active_status_field` (str) [Optional] (default: None)
        - `update_one_to_one` (bool) [Optional] (default: True)

    #### condition
    An optional argument `condition` can be passed to do the
    cloning conditionally, the cloning only takes place if
    `condition` is a truthy value. `condition` can also be a
    callable, which in turn is called to get the boolean status.
    A value of `None` indicates the cloning should be done
    unconditionally.

    The model object can also be referred to in the condition using
    the string format `self.<attribute_chain>`. The object *must*
    be referred to with the string `self`. The final `attribute` can
    also be a callable. For example:

    @clone_on_update(condition='self.foobar.exists')  # callable
    def save(self, *args, **kwargs):
        ...

    @clone_on_update(condition='self.is_valid')  # property
    def save(self, *args, **kwargs):
        ...

    You can't refer to the object in condition while the decorator
    is called because it does not exist at that point, so this
    approach can be used to refer to any object attribute if needed.
    Also note that only attribute accesses can be used as the
    condition string and the string **must** start with `self.` to
    get this feature (as regular callables can be used directly).

    #### active_status_field
    `active_status_field` can be the name of a BooleanField which
    would be set to False on the old instance and will be set to
    True on new instance.

    #### update_one_to_one
    If `update_one_to_one` is set to True, all related one-to-one
    fields are updated (using `clone_on_update`) and the related
    one-to-one instances are set to the newly created instances.
    If set to False, all related one-to-one fields are set to None,
    so the fields must be NULL-able in that case.

    **NOTE**: Relations that are not defined on the model as
    fields are not copied to the new instance and hence
    need to be assigned manually.
    """

    def __init__(
            self,
            condition=None,
            active_status_field=None,
            update_one_to_one=True,
    ):
        self.condition = condition

        if active_status_field is not None:
            if not isinstance(active_status_field, basestring):
                raise TypeError(
                    'active_status_field must be a string.'
                )
        self.active_status_field = active_status_field

        self.update_one_to_one = update_one_to_one

        self.save_func = None
        self.obj = None

        if self._is_save_func(condition):
            self.save_func = condition
            self.condition = None

    def __call__(self, save_func):

        if not self._is_save_func(save_func):
            raise ValueError(
                'save_func must be the `save` method on the model class.'
            )

        self.save_func = save_func
        return self

    def _is_save_func(self, func):
        """Returns whether `func` is a callable named `save`."""

        return (
            isinstance(func, (types.FunctionType, types.MethodType))
            and func.__name__ == 'save'
        )

    @property
    def _condition_is_met(self):
        """Returns whether the condition is met i.e.
        is truthy or evaluates to truthy value.
        """

        if self.condition is None:
            return True

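        # Resolve into a local variable: the descriptor is shared by all
        # instances, so the original condition must not be overwritten.
        condition = self.condition
        if isinstance(condition, basestring):
            condition = self._condition_from_str

        if callable(condition):
            return condition()

        return condition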

    @property
    def _condition_from_str(self):
        """If the condition is in format `self.<attribute>`,
        parses that for object attributes and returns the
        eventual object. Otherwise returns the condition
        string as-is.
        """

        condition = self.condition

        if not condition.startswith('self.'):
            return condition

        if self.obj is None:
            return condition

        condition_obj = self.obj
        splitted_attrs = condition.split('.')[1:]
        for attr in splitted_attrs:
            condition_obj = getattr(condition_obj, attr)

        return condition_obj

    def _get_fields(self, obj):
        """Returns a tuple of (one_to_one_fields,
        m2m_fields, all_fields_without_m2m_fields)
        belonging to the model object `obj`.
        """

        m2m_fields = obj._meta.many_to_many

        all_fields_without_m2m_fields = obj._meta.fields
        one_to_one_fields = [
            field
            for field in all_fields_without_m2m_fields
            if isinstance(field, models.OneToOneField)
        ]

        return one_to_one_fields, m2m_fields, all_fields_without_m2m_fields

    def __get__(self, obj, cls=None):

        if self.save_func is None:
            raise ValueError(
                'No `save` method found.'
            )

        if obj is None:
            return self.save_func

        self.obj = obj

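        def inner(*args, **kwargs):
            # The descriptor is shared by all instances of the model, so
            # re-bind the current object right before checking the condition.
            self.obj = obj
            if obj.pk is not None and self._condition_is_met: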

                one_to_one_fields, m2m_fields, _ = self._get_fields(obj)

                # PK
                orig_pk = obj.pk
                obj.pk = None

                # ID
                try:
                    _ = obj.id
                except AttributeError:
                    pass
                else:
                    obj.id = None

                if self.active_status_field is not None:
                    setattr(obj, self.active_status_field, True)

                self._update_one_to_one_fields(obj, one_to_one_fields)

                obj.save()

                if self.active_status_field is not None:
                    obj.__class__.objects.filter(pk=orig_pk).update(
                        **{self.active_status_field: False}
                    )

                self._update_many_to_many_fields(orig_pk, obj, m2m_fields)

            self.save_func(obj, *args, **kwargs)
        return inner

    def _update_one_to_one_fields(self, obj, one_to_one_fields):
        """Updates all related one to one fields."""

        if self.update_one_to_one:
            self._full_update_one_to_one_fields(obj, one_to_one_fields)
        else:
            for field in one_to_one_fields:
                setattr(obj, field.name, None)

    def _full_update_one_to_one_fields(self, obj, one_to_one_fields):
        """Run clone_on_update on each related one to one fields."""

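        for field in one_to_one_fields:
            orig_one_to_one_obj = getattr(obj, field.name)
            if orig_one_to_one_obj is None:
                # Nothing to clone for an unset (NULL) relation.
                continue

            # Note: this replaces `save` on the related model class itself.
            setattr(
                orig_one_to_one_obj.__class__,
                'save',
                self.__class__(orig_one_to_one_obj.__class__.save)
            )
            # `save` is now bound through the descriptor, so the
            # instance must not be passed explicitly.
            orig_one_to_one_obj.save()

            # orig_one_to_one_obj is now a new one
            setattr(obj, field.name, orig_one_to_one_obj)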

    def _update_many_to_many_fields(self, orig_pk, obj, m2m_fields):
        """Updates all related many-to-many fields."""

        if m2m_fields:
            orig_obj = obj.__class__.objects.get(pk=orig_pk)

        for field in m2m_fields:
            orig_m2m_relation_obj = getattr(orig_obj, field.name)
            new_m2m_relation_obj = getattr(obj, field.name)

            new_m2m_relation_obj.set(orig_m2m_relation_obj.all())

This is also available as a Gist on GitHub.
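
The two mechanisms the class relies on (a decorator object that doubles as a descriptor, and resolving a `self.<attribute_chain>` condition string against the instance) can be tried out without Django. The following is only a stripped-down sketch under that assumption: `run_before_save`, `hook` and `Document` are hypothetical names made up for illustration, and the hook stands in for the cloning step.

```python
import functools


class run_before_save(object):
    """Hypothetical, simplified cousin of `clone_on_update`.

    Calls `hook(obj)` before the wrapped method, but only when the
    `self.<attribute_chain>` condition resolves to a truthy value.
    """

    def __init__(self, hook, condition=None):
        self.hook = hook
        self.condition = condition
        self.func = None

    def __call__(self, func):
        # Used as `@run_before_save(...)`: keep the function and put the
        # descriptor itself on the class in its place.
        self.func = func
        return self

    def _condition_is_met(self, obj):
        if self.condition is None:
            return True
        # 'self.a.b' -> getattr(getattr(obj, 'a'), 'b')
        attrs = self.condition.split('.')[1:]
        value = functools.reduce(getattr, attrs, obj)
        return bool(value() if callable(value) else value)

    def __get__(self, obj, cls=None):
        if obj is None:
            # Accessed on the class: hand back the plain function.
            return self.func

        def inner(*args, **kwargs):
            if self._condition_is_met(obj):
                self.hook(obj)
            return self.func(obj, *args, **kwargs)
        return inner


class Document(object):
    """Plain class standing in for a Django model."""

    def __init__(self, dirty):
        self.dirty = dirty
        self.log = []

    @run_before_save(lambda obj: obj.log.append('hook'),
                     condition='self.dirty')
    def save(self):
        self.log.append('save')


changed = Document(dirty=True)
changed.save()
unchanged = Document(dirty=False)
unchanged.save()
print(changed.log, unchanged.log)  # ['hook', 'save'] ['save']
```

Because the instance is only known when `__get__` runs, the condition has to be given as a string (or a callable) and resolved at call time, which is the same reason the docstring above requires the `self.` prefix.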
