Source code for django_celery_results.backends.database

import binascii
import json

from celery import maybe_signature
from celery.backends.base import BaseDictBackend
from celery.exceptions import ChordError
from celery.result import allow_join_result, result_from_tuple, GroupResult
from celery.utils.serialization import b64encode, b64decode
from celery.utils.log import get_logger
from kombu.exceptions import DecodeError
from django.db import transaction

from ..models import TaskResult, ChordCounter, GroupResult as GroupResultModel


logger = get_logger(__name__)


class DatabaseBackend(BaseDictBackend):
    """The Django database backend, using models to store task state."""

    # Models used for persistence; all queries go through their default
    # managers, so subclasses can point these at other models.
    TaskModel = TaskResult
    GroupModel = GroupResultModel
    subpolling_interval = 0.5

    def _store_result(
        self,
        task_id,
        result,
        status,
        traceback=None,
        request=None,
        using=None,
    ):
        """Store return value and status of an executed task.

        The result, the children of the current task and the task's input
        arguments are run through :meth:`encode_content` and handed to
        ``TaskModel``'s default manager (``store_result``).

        :param task_id: id of the task the result belongs to.
        :param result: return value (or exception) of the task.
        :param status: task state to record.
        :param traceback: optional traceback to store with the result.
        :param request: optional task request; ``task``, ``hostname`` and
            the argument attributes are read from it when present.
        :param using: forwarded unchanged to the manager's ``store_result``.
        :returns: the *encoded* result, as produced by
            :meth:`encode_content`.
        """
        content_type, content_encoding, result = self.encode_content(result)
        _, _, meta = self.encode_content(
            {'children': self.current_task_children(request)}
        )

        # ``request`` may be None, hence getattr with a default everywhere.
        task_name = getattr(request, 'task', None)
        worker = getattr(request, 'hostname', None)

        # Get input arguments
        if getattr(request, 'argsrepr', None) is not None:
            # task protocol 2
            task_args = request.argsrepr
        else:
            # task protocol 1
            task_args = getattr(request, 'args', None)

        if getattr(request, 'kwargsrepr', None) is not None:
            # task protocol 2
            task_kwargs = request.kwargsrepr
        else:
            # task protocol 1
            task_kwargs = getattr(request, 'kwargs', None)

        # Encode input arguments
        if task_args is not None:
            _, _, task_args = self.encode_content(task_args)

        if task_kwargs is not None:
            _, _, task_kwargs = self.encode_content(task_kwargs)

        self.TaskModel._default_manager.store_result(
            content_type,
            content_encoding,
            task_id,
            result,
            status,
            traceback=traceback,
            meta=meta,
            task_name=task_name,
            task_args=task_args,
            task_kwargs=task_kwargs,
            worker=worker,
            using=using,
        )
        return result

    def _get_task_meta_for(self, task_id):
        """Get task metadata for a task by id.

        Loads the row through ``TaskModel``'s default manager, decodes the
        stored ``meta``, ``result`` and input arguments, and returns the
        merged dictionary passed through ``meta_from_decoded``.
        """
        obj = self.TaskModel._default_manager.get_task(task_id)
        res = obj.as_dict()
        meta = self.decode_content(obj, res.pop('meta', None)) or {}
        result = self.decode_content(obj, res.get('result'))

        task_args = res.get('task_args')
        task_kwargs = res.get('task_kwargs')
        try:
            # Both decodes share one try block: if decoding ``task_args``
            # fails, ``task_kwargs`` is left undecoded as well.
            task_args = self.decode_content(obj, task_args)
            task_kwargs = self.decode_content(obj, task_kwargs)
        except (DecodeError, binascii.Error):
            # Best effort: keep the raw stored values when they cannot be
            # decoded (presumably rows stored without encoding -- TODO
            # confirm).
            pass

        # the right names are args/kwargs, not task_args/task_kwargs,
        # keep both for backward compatibility
        res.update(
            meta,
            result=result,
            task_args=task_args,
            task_kwargs=task_kwargs,
            args=task_args,
            kwargs=task_kwargs,
        )
        return self.meta_from_decoded(res)

    def encode_content(self, data):
        """Encode ``data`` with the backend serializer.

        :returns: a ``(content_type, content_encoding, content)`` tuple;
            ``content`` is additionally base64-encoded when the encoding
            is ``'binary'``.
        """
        content_type, content_encoding, content = self._encode(data)
        if content_encoding == 'binary':
            content = b64encode(content)
        return content_type, content_encoding, content

    def decode_content(self, obj, content):
        """Decode ``content`` that was stored alongside ``obj``.

        :param obj: object whose ``content_encoding`` attribute tells
            whether ``content`` is base64-encoded (``'binary'``).
        :param content: the stored payload.
        :returns: the decoded value, or ``None`` when ``content`` is empty.
        """
        if content:
            if obj.content_encoding == 'binary':
                content = b64decode(content)
            return self.decode(content)
        return None

    def _forget(self, task_id):
        """Delete the stored result for ``task_id``; a missing row is ignored."""
        try:
            self.TaskModel._default_manager.get(task_id=task_id).delete()
        except self.TaskModel.DoesNotExist:
            pass

    def cleanup(self):
        """Delete expired metadata."""
        self.TaskModel._default_manager.delete_expired(self.expires)
        self.GroupModel._default_manager.delete_expired(self.expires)

    def _restore_group(self, group_id):
        """Return the stored metadata for a group by id.

        :returns: the group's ``as_dict()`` mapping with ``"result"``
            replaced by the object rebuilt with ``result_from_tuple`` (or
            ``None`` when nothing decodable is stored); ``None`` when the
            manager returns no group.
        """
        group_result = self.GroupModel._default_manager.get_group(group_id)

        if group_result:
            res = group_result.as_dict()
            decoded_result = self.decode_content(group_result, res["result"])
            res["result"] = None
            if decoded_result:
                res["result"] = result_from_tuple(decoded_result, app=self.app)
            return res
        return None

    def _save_group(self, group_id, group_result):
        """Store return value of group.

        The group is serialized from ``group_result.as_tuple()``.

        :returns: ``group_result`` unchanged.
        """
        content_type, content_encoding, result = self.encode_content(
            group_result.as_tuple()
        )
        self.GroupModel._default_manager.store_group_result(
            content_type, content_encoding, group_id, result
        )
        return group_result

    def _delete_group(self, group_id):
        """Delete the stored result for ``group_id``; a missing row is ignored."""
        try:
            self.GroupModel._default_manager.get_group(group_id).delete()
        # The lookup goes through GroupModel's manager, so the exception to
        # tolerate is the group model's own DoesNotExist.
        except self.GroupModel.DoesNotExist:
            pass

    def apply_chord(self, header_result_args, body, **kwargs):
        """Add a ChordCounter with the expected number of results.

        :param header_result_args: either a ``GroupResult`` or the
            arguments needed to build one via ``self.app.GroupResult``.
        :param body: chord body; its ``"chord_size"`` entry, when set and
            non-zero, overrides the number of header results as the
            counter's initial value.
        """
        if not isinstance(header_result_args, GroupResult):
            # Celery 5.1 provides the GroupResult args
            header_result = self.app.GroupResult(*header_result_args)
        else:
            # celery <5.1 will pass a GroupResult object
            header_result = header_result_args
        results = [r.as_tuple() for r in header_result]
        chord_size = body.get("chord_size", None) or len(results)
        data = json.dumps(results)
        ChordCounter.objects.create(
            group_id=header_result.id, sub_tasks=data, count=chord_size
        )

    def on_chord_part_return(self, request, state, result, **kwargs):
        """Called on finishing each part of a Chord header.

        Decrements the ``ChordCounter`` of the request's group.  When the
        counter reaches zero the counter row is deleted and, if the group
        result reports ready, the chord callback is triggered.
        """
        tid, gid = request.id, request.group
        if not gid or not tid:
            return
        call_callback = False
        with transaction.atomic():
            # We need to know if `count` hits 0.
            # wrap the update in a transaction
            # with a `select_for_update` lock to prevent race conditions.
            # SELECT FOR UPDATE is not supported on all databases
            chord_counter = (
                ChordCounter.objects.select_for_update()
                .filter(group_id=gid).first()
            )
            if chord_counter is None:
                logger.warning("Can't find ChordCounter for Group %s", gid)
                return
            chord_counter.count -= 1
            if chord_counter.count != 0:
                chord_counter.save()
            else:
                # Last task in the chord header has finished
                call_callback = True
                chord_counter.delete()

        # The callback is fired outside the transaction, using the
        # in-memory counter object that was just deleted from the table.
        if call_callback:
            deps = chord_counter.group_result(app=self.app)
            if deps.ready():
                callback = maybe_signature(request.chord, app=self.app)
                trigger_callback(
                    app=self.app,
                    callback=callback,
                    group_result=deps
                )
def trigger_callback(app, callback, group_result):
    """Queue the chord callback, or mark it as failed.

    Joins ``group_result`` and passes the joined values to
    ``callback.delay``.  If the join or the callback dispatch raises, the
    error is logged and reported through
    ``app.backend.chord_error_from_stack`` as a ``ChordError``.

    Implementation borrowed from `celery.app.builtins.unlock_chord`
    """
    join = (
        group_result.join_native
        if group_result.supports_native_join
        else group_result.join
    )

    try:
        with allow_join_result():
            header_values = join(
                timeout=app.conf.result_chord_join_timeout,
                propagate=True,
            )
    except Exception as join_error:  # pylint: disable=broad-except
        # Name the first failed dependency when one is reported; otherwise
        # fall back to the repr of the exception itself.
        try:
            failed = next(group_result._failed_join_report())
            reason = "Dependency {0.id} raised {1!r}".format(
                failed, join_error
            )
        except StopIteration:
            reason = repr(join_error)
        logger.exception("Chord %r raised: %r", group_result.id, join_error)
        app.backend.chord_error_from_stack(callback, ChordError(reason))
        return

    # The join succeeded: hand the header results to the callback.
    try:
        callback.delay(header_values)
    except Exception as dispatch_error:  # pylint: disable=broad-except
        logger.exception(
            "Chord %r raised: %r", group_result.id, dispatch_error
        )
        callback_failure = ChordError(
            "Callback error: {0!r}".format(dispatch_error)
        )
        app.backend.chord_error_from_stack(callback, exc=callback_failure)