django.db.transaction.atomic

Here are examples of the Python API django.db.transaction.atomic, taken from open source projects.

200 Examples
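
For orientation, a minimal sketch of the two documented ways to use transaction.atomic, as a decorator and as a context manager; the do_stuff helpers are hypothetical placeholders.

from django.db import IntegrityError, transaction

@transaction.atomic
def viewfunc(request):
    # Everything in this function runs in one database transaction; an
    # uncaught exception rolls the whole thing back.
    do_stuff()

def other_viewfunc(request):
    do_stuff()
    try:
        # A nested atomic block uses a savepoint, so a failure here rolls
        # back only the inner block, not the work already done above.
        with transaction.atomic():
            do_more_stuff()
    except IntegrityError:
        handle_exception()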

Example 51

Project: nodewatcher
Source File: tasks.py
@celery_task(bind=True)
@transaction.atomic
def background_build(self, result_uuid):
    """
    A task for deferred building of a firmware image.

    :param result_uuid: Destination build result UUID
    """

    result = generator_models.BuildResult.objects.get(pk=result_uuid)
    if result.status != generator_models.BuildResult.PENDING:
        return

    # Try to lock the builder for building
    try:
        list(generator_models.Builder.objects.select_for_update(nowait=True).filter(pk=result.builder.pk))
    except db.DatabaseError:
        # Retry the build task again in 15 seconds
        raise self.retry(countdown=15)

    result.status = generator_models.BuildResult.BUILDING
    result.save()

    # Ensure that all CGMs are loaded before doing processing
    loader.load_modules('cgm')
    platform = cgm_base.get_platform(result.builder.platform)

    # Dispatch pre-build signal
    signals.pre_firmware_build.send(sender=None, result=result)

    # Build the firmware and obtain firmware files
    try:
        files = platform.build(result)
    except exceptions.BuildError as e:
        if len(e.args) > 0:
            error_message = 'ERROR: %s' % e.args[0]

            if result.build_log:
                result.build_log += '\n' + error_message
            else:
                result.build_log = error_message

        result.status = generator_models.BuildResult.FAILED
        result.save()

        # Dispatch error signal
        signals.fail_firmware_build.send(sender=None, result=result)
        # Dispatch the result failed event
        generator_events.BuildResultFailed(result).post()
        return
    except:
        result.build_log = 'An internal build error has occurred.\n\n'
        result.build_log += traceback.format_exc()
        result.status = generator_models.BuildResult.FAILED
        result.save()

        # Dispatch error signal
        signals.fail_firmware_build.send(sender=None, result=result)
        # Dispatch the result failed event
        generator_events.BuildResultFailed(result).post()
        return

    # By default, prepend node name and version before firmware filenames.
    node_name = unidecode.unidecode(result.node.config.core.general().name)
    fw_version = result.builder.version.name.replace('.', '')

    for index, (fw_name, fw_file) in enumerate(files[:]):
        files[index] = ('%s-v%s-%s' % (node_name, fw_version, fw_name), fw_file)

    # Dispatch signal that can be used to modify files
    signals.post_firmware_build.send(sender=None, result=result, files=files)

    # Store resulting files and generate the file manifest.
    manifest = {
        'node': {
            'uuid': str(result.node.uuid),
            'name': node_name,
        },
        'firmware': {
            'version': result.builder.version.name,
        },
        'files': []
    }

    for fw_name, fw_file in files:
        r_file = generator_models.BuildResultFile(
            result=result,
            file=uploadedfile.InMemoryUploadedFile(
                io.BytesIO(fw_file),
                None,
                os.path.basename(fw_name),
                'application/octet-stream',
                len(fw_file),
                None
            ),
            checksum_md5=hashlib.md5(fw_file).hexdigest(),
            checksum_sha256=hashlib.sha256(fw_file).hexdigest(),
        )

        manifest_entry = r_file.to_manifest()
        if manifest_entry is not None:
            manifest['files'].append(manifest_entry)

        r_file.save()

    # Store the manifest.
    manifest = json.dumps(manifest)
    generator_models.BuildResultFile(
        result=result,
        file=uploadedfile.InMemoryUploadedFile(
            io.BytesIO(manifest),
            None,
            'manifest.json',
            'text/json',
            len(manifest),
            None
        ),
        checksum_md5=hashlib.md5(manifest).hexdigest(),
        checksum_sha256=hashlib.sha256(manifest).hexdigest(),
        hidden=True,
    ).save()

    result.status = generator_models.BuildResult.OK
    result.save()

    # Dispatch finalize signal
    signals.finalize_firmware_build.send(sender=None, result=result)
    # Dispatch the result ready event
    generator_events.BuildResultReady(result).post()
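
Distilled, the locking idiom in this task looks roughly like the sketch below; the task name and Build model are hypothetical stand-ins.

from celery import shared_task
from django.db import DatabaseError, transaction

from myapp.models import Build  # hypothetical model

@shared_task(bind=True)
@transaction.atomic
def process_build(self, build_pk):
    try:
        # list() evaluates the queryset, which is what actually acquires the
        # row lock; nowait=True raises DatabaseError instead of blocking.
        list(Build.objects.select_for_update(nowait=True).filter(pk=build_pk))
    except DatabaseError:
        # Another worker holds the lock; retry the task in 15 seconds.
        raise self.retry(countdown=15)
    # ... do the work while holding the lock ...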

Example 52

Project: nodewatcher
Source File: base.py
@transaction.atomic(savepoint=False)
def prepare_root_forms(regpoint, request, root=None, data=None, save=False, form_state=None, flags=0):
    """
    Prepares a list of configuration forms for use on a regpoint root's
    configuration page.

    :param regpoint: Registration point name or instance
    :param request: Request instance
    :param root: Registration point root instance for which to generate forms
    :param data: User-supplied POST data
    :param save: Are we performing a save or rendering an initial form
    """

    # Ensure that all registry forms, form processors and CGMs are registered.
    loader.load_modules('forms', 'formprocessors', 'cgm')

    if save and flags & FORM_ONLY_DEFAULTS:
        raise ValueError("You cannot use save and FORM_ONLY_DEFAULTS at the same time!")

    if isinstance(regpoint, basestring):
        regpoint = registration.point(regpoint)

    # Transform data into a mutable dictionary in case an immutable one is passed
    data = copy.copy(data)

    # Prepare context
    context = RegistryFormContext(
        regpoint=regpoint,
        request=request,
        root=root,
        data=data,
        save=save,
        validation_errors=False,
        pending_save_forms={},
        pending_save_foreign_keys={},
        form_state=form_state,
        flags=flags,
    )

    # Parse form actions.
    if data:
        form_actions = json.loads(data.get('ACTIONS', '{}'))
    else:
        form_actions = {}

    for action, options in form_actions.items():
        if action == 'defaults':
            context.form_state.set_using_defaults(options['value'])
        elif action == 'simple_mode':
            # Simple mode should also automatically enable defaults.
            if options['value']:
                context.form_state.set_using_defaults(True)

    if flags & FORM_SET_DEFAULTS:
        context.form_state.set_using_defaults(flags & FORM_DEFAULTS_ENABLED)

    if flags & FORM_INITIAL and flags & FORM_ROOT_CREATE and context.form_state.is_using_defaults():
        # Set simple mode to its configured default value.
        context.form_state.set_using_simple_mode(
            getattr(settings, 'REGISTRY_SIMPLE_MODE', {}).get(regpoint.name, {}).get('default', False)
        )

    # Prepare form processors.
    form_processors = []
    for form_processor in regpoint.get_form_processors():
        form_processor = form_processor()
        form_processor.preprocess(root)
        form_processors.append(form_processor)

    try:
        sid = transaction.savepoint()
        forms = RootRegistryRenderItem(context, prepare_forms(context))

        if flags & (FORM_DEFAULTS | FORM_ONLY_DEFAULTS):
            # Apply form actions before applying defaults.
            for action, options in form_actions.items():
                if action == 'append':
                    context.form_state.append_default_item(options['registry_id'], options['parent_id'])
                elif action == 'remove':
                    context.form_state.remove_item(options['index'])
                elif action == 'simple_mode':
                    context.form_state.set_using_simple_mode(options['value'])

            # Apply form defaults.
            context.form_state.apply_form_defaults(regpoint, flags & FORM_ROOT_CREATE)

            if flags & FORM_ONLY_DEFAULTS:
                # If only defaults application is requested, we should set defaults and then rollback
                # the savepoint in any case; all validation errors are ignored.
                transaction.savepoint_rollback(sid)
                return context.form_state

        # Process forms when saving and there are no validation errors
        if save and root is not None and not context.validation_errors:
            # Resolve form dependencies and save all forms
            for layer, linear_forms in enumerate(toposort.topological_sort(context.pending_save_forms)):
                for info in linear_forms:
                    form = info['form']

                    # Before saving the form perform the validation again so dependent
                    # fields can be recalculated
                    form._clean_fields()
                    form._clean_form()
                    form._post_clean()

                    if form.is_valid():
                        # Save the form and store the instance into partial configuration so
                        # dependent objects can reference the new instance. Before we save,
                        # we also store the form's index into the display_order attribute of
                        # the instances, so that we preserve order when loading back from db.
                        form.instance.display_order = info['index']
                        instance = form.save()
                        # Only overwrite instances at the top layer (forms which have no dependencies
                        # on anything else). Models with dependencies will already be updated when
                        # calling save.
                        if layer == 0 and info['registry_id'] in context.form_state:
                            context.form_state[info['registry_id']][info['index']] = instance

                        for form_id, field in context.pending_save_foreign_keys.get(info['form_id'], []):
                            setattr(
                                context.pending_save_forms[form_id]['form'].instance,
                                field,
                                instance
                            )
                    else:
                        context.validation_errors = True

            # Execute any validation hooks.
            for processor in form_processors:
                try:
                    processor.postprocess(root)
                except RegistryValidationError as e:
                    context.validation_errors = True
                    forms.add_error(e.message)

        if not context.validation_errors:
            # Persist metadata.
            regpoint.set_root_metadata(root, context.form_state.get_metadata())
            root.save()

            transaction.savepoint_commit(sid)
            if flags & FORM_CLEAR_STATE:
                context.form_state.clear_session()
        else:
            transaction.savepoint_rollback(sid)
    except RegistryValidationError:
        transaction.savepoint_rollback(sid)
    except (transaction.TransactionManagementError, django_db.DatabaseError):
        # Do not perform a rollback in case of a database error as this will just raise another
        # database error exception as the transaction has been aborted.
        raise
    except:
        transaction.savepoint_rollback(sid)
        raise

    return forms if not save else (context.validation_errors, forms)
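
The savepoint calls above come from Django's low-level savepoint API; reduced to its core, the pattern is the sketch below (apply_changes and its argument are hypothetical).

from django.db import transaction

@transaction.atomic(savepoint=False)
def apply_changes(obj):
    # savepoint=False joins the caller's transaction rather than opening a
    # savepoint automatically; one savepoint is then managed by hand.
    sid = transaction.savepoint()
    try:
        obj.save()  # any unit of work that may need undoing
    except Exception:
        # Undo just this unit of work without aborting the outer transaction.
        transaction.savepoint_rollback(sid)
        raise
    transaction.savepoint_commit(sid)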

Example 53

Project: nodewatcher
Source File: policy.py
@transaction.atomic(savepoint=False)
def verify_identity(node, mechanism, data):
    """
    Performs identity verification and takes actions based on the configured
    policy. If the verification succeeds, this function returns True.

    :param node: Node instance
    :param mechanism: A subclass of IdentityMechanismConfig that should be used
    :param data: Mechanism-specific data
    :return: True if verification succeeded, False otherwise
    """

    if not issubclass(mechanism, models.IdentityMechanismConfig):
        raise TypeError("Passed identity mechanism class must be a subclass of IdentityMechanismConfig.")

    # If node is not defined, verification fails.
    if not node:
        return False

    # Fetch the identity configuration so we know what is the policy.
    config = node.config.core.identity()
    if not config:
        return False

    # Go through the list of trusted identities and try to match one to the passed data.
    matched_trusted = False
    matched_untrusted = False
    for identity in node.config.core.identity.mechanisms(onlyclass=mechanism):
        if identity.is_match(data):
            # Update last seen timestamp.
            identity.last_seen = timezone.now()
            identity.save()

            if identity.trusted:
                matched_trusted = True
            else:
                matched_untrusted = True

    if not matched_trusted:
        # If nothing matched, check whether we should store the identity.
        is_first_use = False
        if config.trust_policy == 'first':
            # When the configured policy is trust on first use, we check whether there are
            # currently no other identities configured.
            is_first_use = not node.config.core.identity.mechanisms().exists()

        if not matched_untrusted and (config.store_unknown or is_first_use):
            kwargs = mechanism.from_data(data)
            if kwargs is not None:
                kwargs.update({
                    'identity': config,
                    'trusted': is_first_use,
                    'automatically_added': True,
                    'last_seen': timezone.now(),
                })

                identity = node.config.core.identity.mechanisms(create=mechanism, **kwargs)
                identity.save()

                # When the configured policy is trust on first use, we trust this identity.
                if is_first_use:
                    return True

                # Ensure that only one untrusted automatically added identity is stored.
                node.config.core.identity.mechanisms(onlyclass=mechanism).filter(
                    trusted=False,
                    automatically_added=True,
                ).exclude(pk=identity.pk).delete()
            else:
                # Failed to decode data with the target mechanism.
                pass

    # We trust any identity in case of the 'any' policy.
    if config.trust_policy == 'any':
        return True

    return matched_trusted

Example 54

Project: Wooey
Source File: utils.py
def create_job_fileinfo(job):
    parameters = job.get_parameters()
    from ..models import WooeyFile, UserFile
    # first, create a reference to things the script explicitly created that is a parameter
    files = []
    local_storage = get_storage(local=True)
    for field in parameters:
        try:
            if field.parameter.form_field == 'FileField':
                value = field.value
                if value is None:
                    continue
                if isinstance(value, six.string_types):
                    # check if this was ever created and make a fileobject if so
                    if local_storage.exists(value):
                        if not get_storage(local=False).exists(value):
                            get_storage(local=False).save(value, File(local_storage.open(value)))
                        value = field.value
                    else:
                        field.force_value(None)
                        try:
                            with transaction.atomic():
                                field.save()
                        except:
                            sys.stderr.write('{}\n'.format(traceback.format_exc()))
                        continue
                d = {'parameter': field, 'file': value}
                if field.parameter.is_output:
                    full_path = os.path.join(job.save_path, os.path.split(local_storage.path(value))[1])
                    checksum = get_checksum(value, extra=[job.pk, full_path, 'output'])
                    d['checksum'] = checksum
                files.append(d)
        except ValueError:
            continue

    known_files = {i['file'].name for i in files}
    # add the user_output files, these are things which may be missed by the model fields because the script
    # generated them without an explicit arguments reference in the script
    file_groups = {'archives': []}
    absbase = os.path.join(settings.MEDIA_ROOT, job.save_path)
    for root, dirs, dir_files in os.walk(absbase):
        for filename in dir_files:
            new_name = os.path.join(job.save_path, filename)
            if any([i.endswith(new_name) for i in known_files]):
                continue
            try:
                filepath = os.path.join(root, filename)
                if os.path.isdir(filepath):
                    continue
                full_path = os.path.join(job.save_path, filename)
                # this is to make the job output have a unique checksum. If this file is then re-uploaded, it will create
                # a new file to reference in the uploads directory and not link back to the job output.
                checksum = get_checksum(filepath, extra=[job.pk, full_path, 'output'])
                try:
                    storage_file = get_storage_object(full_path)
                except:
                    sys.stderr.write('Error in accessing stored file {}:\n{}'.format(full_path, traceback.format_exc()))
                    continue
                d = {'name': filename, 'file': storage_file, 'size_bytes': storage_file.size, 'checksum': checksum}
                if filename.endswith('.tar.gz') or filename.endswith('.zip'):
                    file_groups['archives'].append(d)
                else:
                    files.append(d)
            except IOError:
                sys.stderr.write('{}'.format(traceback.format_exc()))
                continue

    # establish grouping by inferring common things
    file_groups['all'] = files
    file_groups['image'] = []
    file_groups['tabular'] = []
    file_groups['fasta'] = []

    for filemodel in files:
        fileinfo = get_file_info(filemodel['file'].path)
        filetype = fileinfo.get('type')
        if filetype is not None:
            file_groups[filetype].append(dict(filemodel, **{'preview': fileinfo.get('preview')}))
        else:
            filemodel['preview'] = json.dumps(None)

    # Create our WooeyFile models

    # mark things that are in groups so we don't add this to the 'all' category too to reduce redundancy
    grouped = set([i['file'].path for file_type, groups in six.iteritems(file_groups) for i in groups if file_type != 'all'])
    for file_type, group_files in six.iteritems(file_groups):
        for group_file in group_files:
            if file_type == 'all' and group_file['file'].path in grouped:
                continue
            try:
                preview = group_file.get('preview')
                size_bytes = group_file.get('size_bytes')

                filepath = group_file['file'].path
                save_path = job.get_relative_path(filepath)
                parameter = group_file.get('parameter')

                # get the checksum of the file to see if we need to save it
                checksum = group_file.get('checksum', get_checksum(filepath))
                try:
                    wooey_file = WooeyFile.objects.get(checksum=checksum)
                    file_created = False
                except ObjectDoesNotExist:
                    wooey_file = WooeyFile(
                        checksum=checksum,
                        filetype=file_type,
                        filepreview=preview,
                        size_bytes=size_bytes,
                        filepath=save_path
                    )
                    file_created = True
                userfile_kwargs = {
                    'job': job,
                    'parameter': parameter,
                    'system_file': wooey_file,
                    'filename': os.path.split(filepath)[1]
                }
                try:
                    with transaction.atomic():
                        if file_created:
                            wooey_file.save()
                        job.save()
                        UserFile.objects.get_or_create(**userfile_kwargs)
                except:
                    sys.stderr.write('Error in saving DJFile: {}\n'.format(traceback.format_exc()))
            except:
                sys.stderr.write('Error in saving DJFile: {}\n'.format(traceback.format_exc()))
                continue
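
The with transaction.atomic() blocks wrapped around the individual saves above keep one bad record from aborting the rest; a minimal sketch of that per-object pattern, with hypothetical names.

from django.db import IntegrityError, transaction

def save_all(objects):
    failed = []
    for obj in objects:
        try:
            # Each save gets its own atomic block (a savepoint when a
            # transaction is already open), so a failure rolls back only
            # this one object.
            with transaction.atomic():
                obj.save()
        except IntegrityError:
            failed.append(obj)
    return failed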

Example 55

Project: btb
Source File: views.py
    @permission_required_or_deny("correspondence.manage_correspondence")
    @args_method_decorator(transaction.atomic)
    def post(self, request, mailing_id=None):
        """
        Create a mailing, and autogenerated letters, reflecting the params given.
        """
        params = self.clean_params(request)
        if 'types' not in params:
            raise Http404

        mailing = Mailing.objects.create(editor=request.user)
        to_send = []
        types = set(params['types'])

        kw = {'auto_generated': True, 'sender': request.user}
        if "enqueued" in types:
            to_send += list(Letter.objects.unsent().mail_filter(request.user).filter(
                    recipient__profile__lost_contact=False,
                    mailing__isnull=True,
                    auto_generated=False
            ))
        if "waitlist" in types:
            to_send += list(Letter.objects.create(
                    recipient=p.user, type="waitlist", is_postcard=True, 
                    org=p.user.organization_set.get(), **kw 
                ) for p in Profile.objects.waitlistable().mail_filter(request.user).filter(
                    lost_contact=False
                ).distinct())
        if "consent_form" in types:
            cutoff = params.get("consent_cutoff", "") or datetime.datetime.now()
            to_send += list(Letter.objects.create(
                    recipient=p.user, type="consent_form",
                    org=p.user.organization_set.get(), **kw 
                ) for p in Profile.objects.invitable().mail_filter(request.user).filter(
                    user__date_joined__lte=cutoff,
                    lost_contact=False
                ).distinct())
        if "signup_complete" in types:
            to_send += list(Letter.objects.create(
                    recipient=p.user, type="signup_complete",
                    org=p.user.organization_set.get(), **kw 
                ) for p in Profile.objects.needs_signup_complete_letter().mail_filter(
                    request.user
                ).filter(lost_contact=False).distinct())
        if "first_post" in types:
            to_send += list(Letter.objects.create(
                    recipient=p.user, type="first_post",
                    org=p.user.organization_set.get(), **kw 
                ) for p in Profile.objects.needs_first_post_letter().mail_filter(request.user).filter(lost_contact=False).distinct())
        if "comments" in types:
            comments = list(Comment.objects.unmailed().mail_filter(request.user).filter(
                document__author__profile__lost_contact=False
            ).order_by(
                'document__author', 'document'
            ))
            author = None
            letter = None
            for c in comments:
                if c.document.author != author:
                    doc = c.document
                    author = c.document.author
                    letter = Letter.objects.create(recipient=c.document.author,
                            type="comments",
                            org=c.document.author.organization_set.get(),
                            **kw
                    )
                    to_send.append(letter)
                letter.comments.add(c)
        if "comment_removal" in types:
            removals = list(CommentRemoval.objects.needing_letters().mail_filter(request.user).filter(
                comment__document__author__profile__lost_contact=False
            ).order_by(
                'comment__document__author', 'comment__document'
            ))
            for removal in removals:
                letter = Letter.objects.create(
                    type="comment_removal",
                    recipient=removal.comment.document.author,
                    org=removal.comment.document.author.organization_set.get(),
                    body=removal.post_author_message,
                    send_anonymously=True,
                    **kw)
                letter.comments.add(removal.comment)
                to_send.append(letter)

        mailing.letters.add(*to_send)
        return self.json_response(mailing.light_dict())
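
btb's args_method_decorator adapts transaction.atomic to an instance method; the stock Django equivalent is django.utils.decorators.method_decorator, roughly as sketched below (the view is hypothetical).

from django.db import transaction
from django.http import JsonResponse
from django.utils.decorators import method_decorator
from django.views.generic import View

class MailingView(View):
    @method_decorator(transaction.atomic)
    def post(self, request, *args, **kwargs):
        # Every object created in this handler commits or rolls back together.
        return JsonResponse({'ok': True})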

Example 56

Project: huxley
Source File: models.py
    @classmethod
    def update_assignments(cls, new_assignments):
        '''
        Atomically update the set of country assignments in a transaction.

        For each assignment in the updated list, either update the existing
        one (and delete its delegates), or create a new one if it doesn't
        exist.
        '''
        assignments = cls.objects.all().values()
        assignment_dict = {(a['committee_id'], a['country_id']): a
                           for a in assignments}
        additions = []
        deletions = []
        assigned = set()
        failed_assignments = []

        def add(committee, country, school, rejected):
            additions.append(
                cls(committee_id=committee.id,
                    country_id=country.id,
                    school_id=school.id,
                    rejected=rejected, ))

        def remove(assignment_data):
            deletions.append(assignment_data['id'])

        for committee, country, school, rejected in new_assignments:
            key = (committee, country)
            if key in assigned:
                # Make sure that the same committee/country pair is not being
                # given to more than one school in the upload
                committee = str(committee.name)
                country = str(country.name)
                failed_assignments.append(
                    str((committee, country)) +
                    ' - ASSIGNED TO MORE THAN ONE SCHOOL')
                continue

            # If the assignemnt contains no bad cells, then each value should
            # have the type of its corresponding model.
            is_invalid = False
            if type(committee) is not Committee:
                committee = Committee(name=committee + ' - DOES NOT EXIST')
                is_invalid = True
            if type(country) is not Country:
                country = Country(name=country + ' - DOES NOT EXIST')
                is_invalid = True
            if type(school) is not School:
                school = School(name=school + ' - DOES NOT EXIST')
                is_invalid = True
            if is_invalid:
                failed_assignments.append(
                    str((str(school.name), str(committee.name), str(
                        country.name))))
                continue

            assigned.add(key)
            old_assignment = assignment_dict.get(key)

            if not old_assignment:
                add(committee, country, school, rejected)
                continue

            if old_assignment['school_id'] != school:
                # Remove the old assignment instead of just updating it
                # so that its delegates are deleted by cascade.
                remove(old_assignment)
                add(committee, country, school, rejected)

            del assignment_dict[key]

        if not failed_assignments:
            # Only update assignments if there were no issues
            for old_assignment in assignment_dict.values():
                remove(old_assignment)

            with transaction.atomic():
                Assignment.objects.filter(id__in=deletions).delete()
                Assignment.objects.bulk_create(additions)

        return failed_assignments
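
The deletions and the bulk insert run inside a single atomic block so readers never see a half-applied diff; schematically (names hypothetical):

from django.db import transaction

def apply_diff(model, delete_ids, new_instances):
    # Either both operations happen or neither does.
    with transaction.atomic():
        model.objects.filter(id__in=delete_ids).delete()
        model.objects.bulk_create(new_instances)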

Example 57

Project: django-simple-import
Source File: views.py
@staff_member_required
def do_import(request, import_log_id):
    """ Import the data! """
    import_log = get_object_or_404(ImportLog, id=import_log_id)
    if import_log.import_type == "N" and 'undo' in request.GET and request.GET['undo'] == "True":
        import_log.undo()
        return HttpResponseRedirect(reverse(
                    do_import,
                    kwargs={'import_log_id': import_log.id}) + '?success_undo=True')

    if 'success_undo' in request.GET and request.GET['success_undo'] == "True":
        success_undo = True
    else:
        success_undo = False

    model_class = import_log.import_setting.content_type.model_class()
    import_data = import_log.get_import_file_as_list()
    header_row = import_data.pop(0)
    header_row_field_names = []
    header_row_default = []
    header_row_null_on_empty = []
    error_data = [header_row + ['Error Type', 'Error Details']]
    create_count = 0
    update_count = 0
    fail_count = 0
    if 'commit' in request.GET and request.GET['commit'] == "True":
        commit = True
    else:
        commit = False

    key_column_name = None
    if import_log.update_key and import_log.import_type in ["U", "O"]:
        key_match = import_log.import_setting.columnmatch_set.get(column_name=import_log.update_key)
        key_column_name = key_match.column_name
        key_field_name = key_match.field_name
    for i, cell in enumerate(header_row):
        match = import_log.import_setting.columnmatch_set.get(column_name=cell)
        header_row_field_names += [match.field_name]
        header_row_default += [match.default_value]
        header_row_null_on_empty += [match.null_on_empty]
        if key_column_name is not None and key_column_name.lower() == cell.lower():
            key_index = i

    with transaction.atomic():
        sid = transaction.savepoint()
        for row in import_data:
            try:
                with transaction.atomic():
                    is_created = True
                    if import_log.import_type == "N":
                        new_object = model_class()
                    elif import_log.import_type == "O":
                        filters = {key_field_name: row[key_index]}
                        new_object = model_class.objects.get(**filters)
                        is_created = False
                    elif import_log.import_type == "U":
                        filters = {key_field_name: row[key_index]}
                        new_object = model_class.objects.filter(**filters).first()
                        if new_object is None:
                            new_object = model_class()
                            is_created = False

                    new_object.simple_import_m2ms = {} # Need to deal with these after saving
                    for i, cell in enumerate(row):
                        if header_row_field_names[i]: # skip blank
                            if not import_log.is_empty(cell) or header_row_null_on_empty[i]:
                                set_field_from_cell(import_log, new_object, header_row_field_names[i], cell)
                            elif header_row_default[i]:
                                set_field_from_cell(import_log, new_object, header_row_field_names[i], header_row_default[i])
                    new_object.save()

                    for i, cell in enumerate(row):
                        if header_row_field_names[i]: # skip blank
                            if not import_log.is_empty(cell) or header_row_null_on_empty[i]:
                                set_method_from_cell(import_log, new_object, header_row_field_names[i], cell)
                            elif header_row_default[i]:
                                set_method_from_cell(import_log, new_object, header_row_field_names[i], header_row_default[i])
                    new_object.save()

                    for key in new_object.simple_import_m2ms.keys():
                        value = new_object.simple_import_m2ms[key]
                        m2m = getattr(new_object, key)
                        m2m_model = type(m2m.model())
                        related_field_name = RelationalMatch.objects.get(import_log=import_log, field_name=key).related_field_name
                        m2m_object = m2m_model.objects.get(**{related_field_name:value})
                        m2m.add(m2m_object)

                    if is_created:
                        LogEntry.objects.log_action(
                            user_id         = request.user.pk,
                            content_type_id = ContentType.objects.get_for_model(new_object).pk,
                            object_id       = new_object.pk,
                            object_repr     = smart_text(new_object),
                            action_flag     = ADDITION
                        )
                        create_count += 1
                    else:
                        LogEntry.objects.log_action(
                            user_id         = request.user.pk,
                            content_type_id = ContentType.objects.get_for_model(new_object).pk,
                            object_id       = new_object.pk,
                            object_repr     = smart_text(new_object),
                            action_flag     = CHANGE
                        )
                        update_count += 1
                    ImportedObject.objects.create(
                        import_log = import_log,
                        object_id = new_object.pk,
                        content_type = import_log.import_setting.content_type)
            except IntegrityError:
                exc = sys.exc_info()
                error_data += [row + ["Integrity Error", smart_text(exc[1])]]
                fail_count += 1
            except ObjectDoesNotExist:
                exc = sys.exc_info()
                error_data += [row + ["No Record Found to Update", smart_text(exc[1])]]
                fail_count += 1
            except ValueError:
                exc = sys.exc_info()
                if str(exc[1]).startswith('invalid literal for int() with base 10'):
                    error_data += [row + ["Incompatible Data - A number was expected, but a character was used", smart_text(exc[1])]]
                else:
                    error_data += [row + ["Value Error", smart_text(exc[1])]]
                fail_count += 1
            except:
                error_data += [row + ["Unknown Error"]]
                fail_count += 1
        if not commit:
            transaction.savepoint_rollback(sid)


    if fail_count:
        from io import StringIO
        from django.core.files.base import ContentFile
        from openpyxl.workbook import Workbook
        from openpyxl.writer.excel import save_virtual_workbook

        wb = Workbook()
        ws = wb.worksheets[0]
        ws.title = "Errors"
        filename = 'Errors.xlsx'
        for row in error_data:
            ws.append(row)
        buf = StringIO()
        # Not Python 3 compatible
        #buf.write(str(save_virtual_workbook(wb)))
        import_log.error_file.save(filename, ContentFile(save_virtual_workbook(wb)))
        import_log.save()

    return render(
        request,
        'simple_import/do_import.html',
        {
            'error_data': error_data,
            'create_count': create_count,
            'update_count': update_count,
            'fail_count': fail_count,
            'import_log': import_log,
            'commit': commit,
            'success_undo': success_undo,},
    )
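
This import combines two patterns: a nested atomic block per row, so one bad row cannot abort the run, and a manual savepoint, so the whole run can be rolled back as a dry run. Stripped to the bones (save_row is a hypothetical per-row handler):

from django.db import IntegrityError, transaction

def import_rows(rows, commit=False):
    errors = []
    with transaction.atomic():
        sid = transaction.savepoint()
        for row in rows:
            try:
                # Isolate each row; a failed row rolls back only itself.
                with transaction.atomic():
                    save_row(row)
            except IntegrityError as e:
                errors.append((row, e))
        if not commit:
            # Dry run: undo everything that was just imported.
            transaction.savepoint_rollback(sid)
    return errors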

Example 58

Project: cadasta-platform
Source File: managers.py
    def create_from_form(self, xls_form=None, original_file=None,
                         project=None):
        try:
            with transaction.atomic():
                errors = []
                instance = self.model(
                    xls_form=xls_form,
                    original_file=original_file,
                    project=project
                )
                json = parse_file_to_json(instance.xls_form.file.name)
                has_default_language = (
                    'default_language' in json and
                    json['default_language'] != 'default'
                )
                if (has_default_language and
                   not check_for_language(json['default_language'])):
                    raise InvalidXLSForm(
                        ["Default language code '{}' unknown".format(
                            json['default_language']
                        )]
                    )
                is_multilingual = multilingual_label_check(json['children'])
                if is_multilingual and not has_default_language:
                    raise InvalidXLSForm(["Multilingual XLS forms must have "
                                          "a default_language setting"])
                instance.default_language = json['default_language']
                if instance.default_language == 'default':
                    instance.default_language = ''
                instance.filename = json.get('name')
                instance.title = json.get('title')
                instance.id_string = json.get('id_string')
                instance.version = int(
                    datetime.utcnow().strftime('%Y%m%d%H%M%S%f')[:-4]
                )
                instance.md5_hash = self.get_hash(
                    instance.filename, instance.id_string, instance.version
                )

                survey = create_survey_element_from_dict(json)
                xml_form = survey.xml()
                fix_languages(xml_form)
                xml_form = xml_form.toxml()
                # insert version attr into the xform instance root node
                xml = self.insert_version_attribute(
                    xml_form, instance.filename, instance.version
                )
                name = os.path.join(instance.xml_form.field.upload_to,
                                    os.path.basename(instance.filename))
                url = instance.xml_form.storage.save(
                    '{}.xml'.format(name), xml)
                instance.xml_form = url

                instance.save()

                project.current_questionnaire = instance.id
                project.save()

                create_children(
                    children=json.get('children'),
                    errors=errors,
                    project=project,
                    default_language=instance.default_language,
                    kwargs={'questionnaire': instance}
                )

                # all these errors handled by PyXForm so turning off for now
                # if errors:
                #     raise InvalidXLSForm(errors)

                return instance

        except PyXFormError as e:
            raise InvalidXLSForm([str(e)])

Example 59

Project: cadasta-platform
Source File: model_helper.py
    def upload_submission_data(self, request):
        if 'xml_submission_file' not in request.data.keys():
            raise InvalidXMLSubmission(_('XML submission not found'))

        xml_submission_file = request.data['xml_submission_file'].read()
        full_submission = XFormToDict(
            xml_submission_file.decode('utf-8')).get_dict()

        submission = full_submission[list(full_submission.keys())[0]]

        with transaction.atomic():
            (questionnaire,
             parties, party_resources,
             locations, location_resources,
             tenure_relationships, tenure_resources
             ) = self.create_models(submission)

            party_submissions = [submission]
            location_submissions = [submission]
            tenure_submissions = [submission]

            if 'party_repeat' in submission:
                party_submissions = self._format_repeat(submission, ['party'])
                if 'tenure_type' in party_submissions[0]:
                    tenure_submissions = party_submissions

            elif 'location_repeat' in submission:
                location_submissions = self._format_repeat(
                    submission, ['location']
                )
                if 'tenure_type' in location_submissions[0]:
                    tenure_submissions = location_submissions

            party_resource_files = []
            location_resource_files = []
            tenure_resource_files = []

            for group in party_submissions:
                party_resource_files.extend(
                    self._get_resource_files(group, 'party')
                )

            for group in location_submissions:
                location_resource_files.extend(
                    self._get_resource_files(group, 'location')
                )

            for group in tenure_submissions:
                tenure_resource_files.extend(
                    self._get_resource_files(group, 'tenure')
                )

            resource_data = {
                'project': questionnaire.project,
                'location_resources': location_resource_files,
                'locations': location_resources,
                'party_resources': party_resource_files,
                'parties': party_resources,
                'tenure_resources': tenure_resource_files,
                'tenures': tenure_resources,
            }
            self.upload_resource_files(request, resource_data)

        if XFormSubmission.objects.filter(
                instanceID=submission['meta']['instanceID']).exists():
            return XFormSubmission.objects.get(
                instanceID=submission['meta']['instanceID'])

        xform_submission = XFormSubmission(
            json_submission=full_submission,
            user=request.user,
            questionnaire=questionnaire,
            instanceID=submission['meta']['instanceID']
            )
        return xform_submission, parties, locations, tenure_relationships

Example 60

Project: django-timepiece
Source File: views.py
@login_required
@transaction.atomic
def create_invoice(request):
    pk = request.GET.get('project', None)
    to_date = request.GET.get('to_date', None)
    if not (pk and to_date):
        raise Http404
    from_date = request.GET.get('from_date', None)
    if not request.user.has_perm('crm.generate_project_invoice'):
        return HttpResponseForbidden('Forbidden')
    try:
        to_date = utils.add_timezone(
            datetime.datetime.strptime(to_date, '%Y-%m-%d'))
        if from_date:
            from_date = utils.add_timezone(
                datetime.datetime.strptime(from_date, '%Y-%m-%d'))
    except (ValueError, OverflowError):
        raise Http404
    project = get_object_or_404(Project, pk=pk)
    initial = {
        'project': project,
        'user': request.user,
        'from_date': from_date,
        'to_date': to_date,
    }
    entries_query = {
        'status': Entry.APPROVED,
        'end_time__lt': to_date + relativedelta(days=1),
        'project__id': project.id
    }
    if from_date:
        entries_query.update({'end_time__gte': from_date})
    invoice_form = InvoiceForm(request.POST or None, initial=initial)
    if request.POST and invoice_form.is_valid():
        entries = Entry.no_join.filter(**entries_query)
        if entries.exists():
            # LOCK the entries until our transaction completes - nobody
            # else will be able to lock or change them - see
            # https://docs.djangoproject.com/en/1.4/ref/models/querysets/#select-for-update
            # (This feature requires Django 1.4.)
            # If more than one request is trying to create an invoice from
            # these same entries, then the second one to get to this line will
            # throw a DatabaseError.  That can happen if someone double-clicks
            # the Create Invoice button.
            try:
                entries.select_for_update(nowait=True)
            except DatabaseError:
                # Whoops, we lost the race
                messages.add_message(request, messages.ERROR,
                                     "Lock error trying to get entries")
            else:
                # We got the lock, we can carry on
                invoice = invoice_form.save()
                Entry.no_join.filter(pk__in=entries).update(
                    status=invoice.status, entry_group=invoice)
                messages.add_message(request, messages.INFO,
                                     "Invoice created")
                return HttpResponseRedirect(reverse('view_invoice',
                                                    args=[invoice.pk]))
        else:
            messages.add_message(request, messages.ERROR,
                                 "No entries for invoice")
    else:
        entries = Entry.objects.filter(**entries_query)
        entries = entries.order_by('start_time')
        if not entries:
            raise Http404

    billable_entries = entries.filter(activity__billable=True) \
        .select_related()
    nonbillable_entries = entries.filter(activity__billable=False) \
        .select_related()
    return render(request, 'timepiece/invoice/create.html', {
        'invoice_form': invoice_form,
        'billable_entries': billable_entries,
        'nonbillable_entries': nonbillable_entries,
        'project': project,
        'billable_totals': HourGroup.objects.summaries(billable_entries),
        'nonbillable_totals': HourGroup.objects.summaries(nonbillable_entries),
        'from_date': from_date,
        'to_date': to_date,
    })
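
A caveat when reusing this idiom: select_for_update() returns a lazy queryset, and the rows are only locked once that queryset is evaluated inside a transaction. A minimal sketch that forces evaluation (names hypothetical):

from django.db import DatabaseError, transaction

def lock_entries(queryset):
    with transaction.atomic():
        try:
            # list() evaluates the queryset, which is the point at which
            # the row locks are actually taken.
            locked = list(queryset.select_for_update(nowait=True))
        except DatabaseError:
            return None  # another transaction holds the lock
        # ... update the locked rows here, before the atomic block exits ...
        return locked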

Example 61

Project: django-mysql
Source File: test_operations.py
    def set_up_test_model(
            self, app_label, second_model=False, third_model=False,
            related_model=False, mti_model=False, proxy_model=False,
            unique_together=False, options=False, db_table=None,
            index_together=False):
        """
        Creates a test model state and database table.
        """
        # Delete the tables if they already exist
        table_names = [
            # Start with ManyToMany tables
            '_pony_stables', '_pony_vans',
            # Then standard model tables
            '_pony', '_stable', '_van',
        ]
        tables = [(app_label + table_name) for table_name in table_names]
        with connection.cursor() as cursor:
            table_names = connection.introspection.table_names(cursor)
            connection.disable_constraint_checking()
            sql_delete_table = connection.schema_editor().sql_delete_table
            with transaction.atomic():
                for table in tables:
                    if table in table_names:
                        cursor.execute(sql_delete_table % {
                            "table": connection.ops.quote_name(table),
                        })
            connection.enable_constraint_checking()

        # Make the "current" state
        model_options = {
            "swappable": "TEST_SWAP_MODEL",
            "index_together": [["weight", "pink"]] if index_together else [],
            "unique_together": [["pink", "weight"]] if unique_together else [],
        }
        if options:
            model_options["permissions"] = [("can_groom", "Can groom")]
        if db_table:
            model_options["db_table"] = db_table
        operations = [migrations.CreateModel(
            "Pony",
            [
                ("id", models.AutoField(primary_key=True)),
                ("pink", models.IntegerField(default=3)),
                ("weight", models.FloatField()),
            ],
            options=model_options,
        )]
        if second_model:
            operations.append(migrations.CreateModel(
                "Stable",
                [
                    ("id", models.AutoField(primary_key=True)),
                ]
            ))
        if third_model:
            operations.append(migrations.CreateModel(
                "Van",
                [
                    ("id", models.AutoField(primary_key=True)),
                ]
            ))
        if related_model:
            operations.append(migrations.CreateModel(
                "Rider",
                [
                    ("id", models.AutoField(primary_key=True)),
                    ("pony", models.ForeignKey("Pony")),
                    ("friend", models.ForeignKey("self"))
                ],
            ))
        if mti_model:
            operations.append(migrations.CreateModel(
                "ShetlandPony",
                fields=[
                    ('pony_ptr', models.OneToOneField(
                        auto_created=True,
                        primary_key=True,
                        to_field='id',
                        serialize=False,
                        to='Pony',
                    )),
                    ("cuteness", models.IntegerField(default=1)),
                ],
                bases=['%s.Pony' % app_label],
            ))
        if proxy_model:
            operations.append(migrations.CreateModel(
                "ProxyPony",
                fields=[],
                options={"proxy": True},
                bases=['%s.Pony' % app_label],
            ))

        return self.apply_operations(app_label, ProjectState(), operations)

Example 62

Project: plan
Source File: views.py
def select_course(request, year, semester_type, slug, add=False):
    '''Handle selecting of courses from the course list, changing of
       course names, and removal of courses.'''

    # FIXME: split the three sub-handlers into separate functions?

    try:
        semester = Semester.objects.get(year=year, type=semester_type)
    except Semester.DoesNotExist:
        return shortcuts.redirect(
            'schedule', year, Semester.localize(semester_type), slug)

    if request.method == 'POST':
        if 'submit_add' in request.POST or add:
            lookup = []

            for l in request.POST.getlist('course_add'):
                lookup.extend(l.replace(',', '').split())

            subscriptions = set(Subscription.objects.get_subscriptions(year,
                semester_type, slug).values_list('course__code', flat=True))

            if not lookup:
                localized_semester = Semester.localize(semester_type)
                return shortcuts.redirect(
                    'schedule-advanced', year, localized_semester, slug)

            errors = []
            to_many_subscriptions = False

            student, created = Student.objects.get_or_create(slug=slug)

            for l in lookup:
                try:
                    if len(subscriptions) > settings.TIMETABLE_MAX_COURSES:
                        to_many_subscriptions = True
                        break

                    course = Course.objects.get(
                            code__iexact=l.strip(),
                            semester__year__exact=year,
                            semester__type__exact=semester_type,
                        )

                    Subscription.objects.get_or_create(
                            student=student,
                            course=course,
                        )
                    subscriptions.add(course.code)

                except Course.DoesNotExist:
                    errors.append(l)

            if errors or to_many_subscriptions:
                return shortcuts.render(request, 'error.html', {
                        'courses': errors,
                        'max': settings.TIMETABLE_MAX_COURSES,
                        'slug': slug,
                        'year': year,
                        'type': semester_type,
                        'to_many_subscriptions': to_many_subscriptions,
                    })

            return shortcuts.redirect(
                'change-groups', year, Semester.localize(semester_type), slug)

        elif 'submit_remove' in request.POST:
            with transaction.atomic():
                courses = []
                for c in request.POST.getlist('course_remove'):
                    if c.strip():
                        courses.append(c.strip())

                Subscription.objects.get_subscriptions(year, semester_type, slug). \
                        filter(course__id__in=courses).delete()

                if Subscription.objects.filter(student__slug=slug).count() == 0:
                    Student.objects.filter(slug=slug).delete()

        elif 'submit_name' in request.POST:
            subscriptions = Subscription.objects.get_subscriptions(year, semester_type, slug)

            for u in subscriptions:
                form = forms.CourseAliasForm(request.POST, prefix=u.course_id)

                if form.is_valid():
                    alias = form.cleaned_data['alias'].strip()

                    if alias.upper() == u.course.code.upper() or alias == "":
                        # Leave as blank if we match the current course name
                        alias = ""

                    u.alias = alias
                    u.save()

    return shortcuts.redirect(
        'schedule-advanced', year, Semester.localize(semester_type), slug)

Example 63

Project: SHARE
Source File: util.py
def fetch_abstractcreativework(pks):
    if connection.connection is None:
        connection.cursor()

    with transaction.atomic():
        with connection.connection.cursor(str(uuid.uuid4())) as c:
            c.execute('''
                SELECT json_build_object(
                'id', creativework.id
                , 'type', creativework.type
                , 'title', creativework.title
                , 'description', creativework.description
                , 'is_deleted', creativework.is_deleted
                , 'language', creativework.language
                , 'date_created', creativework.date_created
                , 'date_modified', creativework.date_modified
                , 'date_updated', creativework.date_updated
                , 'date_published', creativework.date_published
                , 'tags', COALESCE(tags, '{}')
                , 'links', COALESCE(links, '{}')
                , 'sources', sources
                , 'subjects', COALESCE(subjects, '{}')
                , 'associations', COALESCE(associations, '{}')
                , 'contributors', COALESCE(contributors, '{}'))
                FROM share_abstractcreativework AS creativework
                LEFT JOIN LATERAL(
                    SELECT json_agg(json_build_object('id', entity.id, 'type', entity.type, 'name', entity.name)) as associations
                    FROM share_association AS association
                    JOIN share_entity AS entity ON association.entity_id = entity.id
                    WHERE association.creative_work_id = creativework.id
                ) AS associations ON true
                LEFT JOIN LATERAL (
                    SELECT json_agg(json_build_object('type', link.type, 'url', link.url)) as links
                    FROM share_throughlinks AS throughlink
                    JOIN share_link AS link ON throughlink.link_id = link.id
                    WHERE throughlink.creative_work_id = creativework.id
                ) AS links ON true
                LEFT JOIN LATERAL (
                    SELECT array_agg(source.long_title) AS sources
                    FROM share_abstractcreativework_sources AS throughsources
                    JOIN share_shareuser AS source ON throughsources.shareuser_id = source.id
                    WHERE throughsources.abstractcreativework_id = creativework.id
                ) AS sources ON true
                LEFT JOIN LATERAL (
                    SELECT array_agg(tag.name) AS tags
                    FROM share_throughtags AS throughtag
                    JOIN share_tag AS tag ON throughtag.tag_id = tag.id
                    WHERE throughtag.creative_work_id = creativework.id
                ) AS tags ON true
                LEFT JOIN LATERAL (
                    SELECT array_agg(subject.name) AS subjects
                    FROM share_throughsubjects AS throughsubject
                    JOIN share_subject AS subject ON throughsubject.subject_id = subject.id
                    WHERE throughsubject.creative_work_id = creativework.id
                ) AS subjects ON true
                LEFT JOIN LATERAL (
                    SELECT json_agg(json_build_object(
                        'id', person.id
                        , 'order_cited', contributor.order_cited
                        , 'bibliographic', contributor.bibliographic
                        , 'cited_name', contributor.cited_name
                        , 'given_name', person.given_name
                        , 'family_name', person.family_name
                        , 'additional_name', person.additional_name
                        , 'suffix', person.suffix
                        , 'identifiers', COALESCE(identifiers, '[]'::json)
                    )) AS contributors
                    FROM share_contributor AS contributor
                    JOIN share_person AS person ON contributor.person_id = person.id
                    LEFT JOIN LATERAL (
                        SELECT json_agg(json_build_object('url', identifier.url, 'base_url', identifier.base_url)) AS identifiers
                        FROM share_throughidentifiers AS throughidentifier
                        JOIN share_identifier as identifier ON throughidentifier.identifier_id = identifier.id
                        WHERE throughidentifier.person_id = person.id
                    ) AS identifiers ON true
                    WHERE contributor.creative_work_id = creativework.id
                ) AS contributors ON true
                WHERE creativework.id IN %s
            ''', (tuple(pks), ))

            while True:
                data = c.fetchone()

                if not data:
                    return

                data = data[0]

                associations = {
                    k + 's': [{**e, 'type': k} for e in v]
                    for k, v in
                    itertools.groupby(data.pop('associations'), lambda x: x['type'].rpartition('.')[-1])
                }

                data['type'] = data['type'].rpartition('.')[-1]
                data['date'] = (data['date_published'] or data['date_updated'] or data['date_created'])

                data['lists'] = {
                    **associations,
                    'links': data.pop('links', []),
                    'contributors': sorted(data.pop('contributors', []), key=lambda x: x['order_cited']),
                }

                data['contributors'] = [
                    ' '.join(x for x in (p['given_name'], p['family_name'], p['additional_name'], p['suffix']) if x)
                    for p in data['lists']['contributors']
                ]

                yield {**data, **{k: [e['name'] for e in v] for k, v in associations.items()}}
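
The pattern above pairs transaction.atomic() with a named (server-side) PostgreSQL cursor: psycopg2 keeps a named cursor alive only while its transaction is open, so the whole fetch loop has to run inside the atomic block. A minimal sketch of the same idiom, with a hypothetical table name:

import uuid

from django.db import connection, transaction


def stream_rows(batch_size=2000):
    # Ensure the underlying psycopg2 connection exists before touching it directly.
    if connection.connection is None:
        connection.cursor()

    with transaction.atomic():
        # Passing a name creates a server-side cursor, so rows are fetched
        # lazily in batches instead of being loaded into memory all at once.
        with connection.connection.cursor(str(uuid.uuid4())) as cursor:
            cursor.itersize = batch_size
            cursor.execute('SELECT id, title FROM myapp_work')  # hypothetical table
            for row in cursor:
                yield row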

Example 64

Project: cgstudiomap
Source File: fields.py
View license
def create_generic_related_manager(superclass, rel):
    """
    Factory function to create a manager that subclasses another manager
    (generally the default manager of a given model) and adds behaviors
    specific to generic relations.
    """

    class GenericRelatedObjectManager(superclass):
        def __init__(self, instance=None):
            super(GenericRelatedObjectManager, self).__init__()

            self.instance = instance

            self.model = rel.model

            content_type = ContentType.objects.db_manager(instance._state.db).get_for_model(
                instance, for_concrete_model=rel.field.for_concrete_model)
            self.content_type = content_type
            self.content_type_field_name = rel.field.content_type_field_name
            self.object_id_field_name = rel.field.object_id_field_name
            self.prefetch_cache_name = rel.field.attname
            self.pk_val = instance._get_pk_val()

            self.core_filters = {
                '%s__pk' % self.content_type_field_name: content_type.id,
                self.object_id_field_name: self.pk_val,
            }

        def __call__(self, **kwargs):
            # We use **kwargs rather than a kwarg argument to enforce the
            # `manager='manager_name'` syntax.
            manager = getattr(self.model, kwargs.pop('manager'))
            manager_class = create_generic_related_manager(manager.__class__, rel)
            return manager_class(instance=self.instance)
        do_not_call_in_templates = True

        def __str__(self):
            return repr(self)

        def get_queryset(self):
            try:
                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
            except (AttributeError, KeyError):
                db = self._db or router.db_for_read(self.model, instance=self.instance)
                return super(GenericRelatedObjectManager, self).get_queryset().using(db).filter(**self.core_filters)

        def get_prefetch_queryset(self, instances, queryset=None):
            if queryset is None:
                queryset = super(GenericRelatedObjectManager, self).get_queryset()

            queryset._add_hints(instance=instances[0])
            queryset = queryset.using(queryset._db or self._db)

            query = {
                '%s__pk' % self.content_type_field_name: self.content_type.id,
                '%s__in' % self.object_id_field_name: set(obj._get_pk_val() for obj in instances)
            }

            # We (possibly) need to convert object IDs to the type of the
            # instances' PK in order to match up instances:
            object_id_converter = instances[0]._meta.pk.to_python
            return (queryset.filter(**query),
                    lambda relobj: object_id_converter(getattr(relobj, self.object_id_field_name)),
                    lambda obj: obj._get_pk_val(),
                    False,
                    self.prefetch_cache_name)

        def add(self, *objs, **kwargs):
            bulk = kwargs.pop('bulk', True)
            db = router.db_for_write(self.model, instance=self.instance)

            def check_and_update_obj(obj):
                if not isinstance(obj, self.model):
                    raise TypeError("'%s' instance expected, got %r" % (
                        self.model._meta.object_name, obj
                    ))
                setattr(obj, self.content_type_field_name, self.content_type)
                setattr(obj, self.object_id_field_name, self.pk_val)

            if bulk:
                pks = []
                for obj in objs:
                    if obj._state.adding or obj._state.db != db:
                        raise ValueError(
                            "%r instance isn't saved. Use bulk=False or save "
                            "the object first." % obj
                        )
                    check_and_update_obj(obj)
                    pks.append(obj.pk)

                self.model._base_manager.using(db).filter(pk__in=pks).update(**{
                    self.content_type_field_name: self.content_type,
                    self.object_id_field_name: self.pk_val,
                })
            else:
                with transaction.atomic(using=db, savepoint=False):
                    for obj in objs:
                        check_and_update_obj(obj)
                        obj.save()
        add.alters_data = True

        def remove(self, *objs, **kwargs):
            if not objs:
                return
            bulk = kwargs.pop('bulk', True)
            self._clear(self.filter(pk__in=[o.pk for o in objs]), bulk)
        remove.alters_data = True

        def clear(self, **kwargs):
            bulk = kwargs.pop('bulk', True)
            self._clear(self, bulk)
        clear.alters_data = True

        def _clear(self, queryset, bulk):
            db = router.db_for_write(self.model, instance=self.instance)
            queryset = queryset.using(db)
            if bulk:
                # `QuerySet.delete()` creates its own atomic block which
                # contains the `pre_delete` and `post_delete` signal handlers.
                queryset.delete()
            else:
                with transaction.atomic(using=db, savepoint=False):
                    for obj in queryset:
                        obj.delete()
        _clear.alters_data = True

        def set(self, objs, **kwargs):
            # Force evaluation of `objs` in case it's a queryset whose value
            # could be affected by `manager.clear()`. Refs #19816.
            objs = tuple(objs)

            bulk = kwargs.pop('bulk', True)
            clear = kwargs.pop('clear', False)

            db = router.db_for_write(self.model, instance=self.instance)
            with transaction.atomic(using=db, savepoint=False):
                if clear:
                    self.clear()
                    self.add(*objs, bulk=bulk)
                else:
                    old_objs = set(self.using(db).all())
                    new_objs = []
                    for obj in objs:
                        if obj in old_objs:
                            old_objs.remove(obj)
                        else:
                            new_objs.append(obj)

                    self.remove(*old_objs)
                    self.add(*new_objs, bulk=bulk)
        set.alters_data = True

        def create(self, **kwargs):
            kwargs[self.content_type_field_name] = self.content_type
            kwargs[self.object_id_field_name] = self.pk_val
            db = router.db_for_write(self.model, instance=self.instance)
            return super(GenericRelatedObjectManager, self).using(db).create(**kwargs)
        create.alters_data = True

        def get_or_create(self, **kwargs):
            kwargs[self.content_type_field_name] = self.content_type
            kwargs[self.object_id_field_name] = self.pk_val
            db = router.db_for_write(self.model, instance=self.instance)
            return super(GenericRelatedObjectManager, self).using(db).get_or_create(**kwargs)
        get_or_create.alters_data = True

        def update_or_create(self, **kwargs):
            kwargs[self.content_type_field_name] = self.content_type
            kwargs[self.object_id_field_name] = self.pk_val
            db = router.db_for_write(self.model, instance=self.instance)
            return super(GenericRelatedObjectManager, self).using(db).update_or_create(**kwargs)
        update_or_create.alters_data = True

    return GenericRelatedObjectManager
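
Throughout this manager, transaction.atomic(using=db, savepoint=False) groups the per-object saves in the non-bulk paths of add(), set() and _clear(): savepoint=False skips the savepoint when an outer atomic block is already open, while still making the batch commit or roll back as one unit. A minimal sketch of that idiom, assuming hypothetical Photo/Tag models with a `photo` foreign key:

from django.db import router, transaction


def add_tags(photo, tags):
    if not tags:
        return
    # Route the writes to whichever database the router picks for this model.
    db = router.db_for_write(type(tags[0]), instance=photo)
    with transaction.atomic(using=db, savepoint=False):
        for tag in tags:
            tag.photo = photo  # hypothetical FK field
            tag.save()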

Example 65

Project: cgstudiomap
Source File: related_descriptors.py
View license
def create_reverse_many_to_one_manager(superclass, rel):
    """
    Create a manager for the reverse side of a many-to-one relation.

    This manager subclasses another manager, generally the default manager of
    the related model, and adds behaviors specific to many-to-one relations.
    """

    class RelatedManager(superclass):
        def __init__(self, instance):
            super(RelatedManager, self).__init__()

            self.instance = instance
            self.model = rel.related_model
            self.field = rel.field

            self.core_filters = {self.field.name: instance}

        def __call__(self, **kwargs):
            # We use **kwargs rather than a kwarg argument to enforce the
            # `manager='manager_name'` syntax.
            manager = getattr(self.model, kwargs.pop('manager'))
            manager_class = create_reverse_many_to_one_manager(manager.__class__, rel)
            return manager_class(self.instance)
        do_not_call_in_templates = True

        def get_queryset(self):
            try:
                return self.instance._prefetched_objects_cache[self.field.related_query_name()]
            except (AttributeError, KeyError):
                db = self._db or router.db_for_read(self.model, instance=self.instance)
                empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls
                qs = super(RelatedManager, self).get_queryset()
                qs._add_hints(instance=self.instance)
                if self._db:
                    qs = qs.using(self._db)
                qs = qs.filter(**self.core_filters)
                for field in self.field.foreign_related_fields:
                    val = getattr(self.instance, field.attname)
                    if val is None or (val == '' and empty_strings_as_null):
                        return qs.none()
                qs._known_related_objects = {self.field: {self.instance.pk: self.instance}}
                return qs

        def get_prefetch_queryset(self, instances, queryset=None):
            if queryset is None:
                queryset = super(RelatedManager, self).get_queryset()

            queryset._add_hints(instance=instances[0])
            queryset = queryset.using(queryset._db or self._db)

            rel_obj_attr = self.field.get_local_related_value
            instance_attr = self.field.get_foreign_related_value
            instances_dict = {instance_attr(inst): inst for inst in instances}
            query = {'%s__in' % self.field.name: instances}
            queryset = queryset.filter(**query)

            # Since we just bypassed this class' get_queryset(), we must manage
            # the reverse relation manually.
            for rel_obj in queryset:
                instance = instances_dict[rel_obj_attr(rel_obj)]
                setattr(rel_obj, self.field.name, instance)
            cache_name = self.field.related_query_name()
            return queryset, rel_obj_attr, instance_attr, False, cache_name

        def add(self, *objs, **kwargs):
            bulk = kwargs.pop('bulk', True)
            objs = list(objs)
            db = router.db_for_write(self.model, instance=self.instance)

            def check_and_update_obj(obj):
                if not isinstance(obj, self.model):
                    raise TypeError("'%s' instance expected, got %r" % (
                        self.model._meta.object_name, obj,
                    ))
                setattr(obj, self.field.name, self.instance)

            if bulk:
                pks = []
                for obj in objs:
                    check_and_update_obj(obj)
                    if obj._state.adding or obj._state.db != db:
                        raise ValueError(
                            "%r instance isn't saved. Use bulk=False or save "
                            "the object first." % obj
                        )
                    pks.append(obj.pk)
                self.model._base_manager.using(db).filter(pk__in=pks).update(**{
                    self.field.name: self.instance,
                })
            else:
                with transaction.atomic(using=db, savepoint=False):
                    for obj in objs:
                        check_and_update_obj(obj)
                        obj.save()
        add.alters_data = True

        def create(self, **kwargs):
            kwargs[self.field.name] = self.instance
            db = router.db_for_write(self.model, instance=self.instance)
            return super(RelatedManager, self.db_manager(db)).create(**kwargs)
        create.alters_data = True

        def get_or_create(self, **kwargs):
            kwargs[self.field.name] = self.instance
            db = router.db_for_write(self.model, instance=self.instance)
            return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
        get_or_create.alters_data = True

        def update_or_create(self, **kwargs):
            kwargs[self.field.name] = self.instance
            db = router.db_for_write(self.model, instance=self.instance)
            return super(RelatedManager, self.db_manager(db)).update_or_create(**kwargs)
        update_or_create.alters_data = True

        # remove() and clear() are only provided if the ForeignKey can have a value of null.
        if rel.field.null:
            def remove(self, *objs, **kwargs):
                if not objs:
                    return
                bulk = kwargs.pop('bulk', True)
                val = self.field.get_foreign_related_value(self.instance)
                old_ids = set()
                for obj in objs:
                    # Is obj actually part of this descriptor set?
                    if self.field.get_local_related_value(obj) == val:
                        old_ids.add(obj.pk)
                    else:
                        raise self.field.remote_field.model.DoesNotExist(
                            "%r is not related to %r." % (obj, self.instance)
                        )
                self._clear(self.filter(pk__in=old_ids), bulk)
            remove.alters_data = True

            def clear(self, **kwargs):
                bulk = kwargs.pop('bulk', True)
                self._clear(self, bulk)
            clear.alters_data = True

            def _clear(self, queryset, bulk):
                db = router.db_for_write(self.model, instance=self.instance)
                queryset = queryset.using(db)
                if bulk:
                    # `QuerySet.update()` is intrinsically atomic.
                    queryset.update(**{self.field.name: None})
                else:
                    with transaction.atomic(using=db, savepoint=False):
                        for obj in queryset:
                            setattr(obj, self.field.name, None)
                            obj.save(update_fields=[self.field.name])
            _clear.alters_data = True

        def set(self, objs, **kwargs):
            # Force evaluation of `objs` in case it's a queryset whose value
            # could be affected by `manager.clear()`. Refs #19816.
            objs = tuple(objs)

            bulk = kwargs.pop('bulk', True)
            clear = kwargs.pop('clear', False)

            if self.field.null:
                db = router.db_for_write(self.model, instance=self.instance)
                with transaction.atomic(using=db, savepoint=False):
                    if clear:
                        self.clear()
                        self.add(*objs, bulk=bulk)
                    else:
                        old_objs = set(self.using(db).all())
                        new_objs = []
                        for obj in objs:
                            if obj in old_objs:
                                old_objs.remove(obj)
                            else:
                                new_objs.append(obj)

                        self.remove(*old_objs, bulk=bulk)
                        self.add(*new_objs, bulk=bulk)
            else:
                self.add(*objs, bulk=bulk)
        set.alters_data = True

    return RelatedManager
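
set() above is the interesting atomic user: it diffs the current related set against the new one, then runs remove() and add() inside a single transaction so no reader observes a half-updated relation. A condensed sketch of the diffing step, where `manager` stands for any related manager exposing remove() and add():

from django.db import transaction


def set_related(manager, objs, db=None):
    # Evaluate up front in case `objs` is a queryset whose contents would
    # change once remove() starts issuing deletes. Refs #19816.
    objs = tuple(objs)
    with transaction.atomic(using=db, savepoint=False):
        old_objs = set(manager.all())
        new_objs = []
        for obj in objs:
            if obj in old_objs:
                old_objs.remove(obj)   # already related: keep it
            else:
                new_objs.append(obj)   # not yet related: add it
        manager.remove(*old_objs)      # anything left over is stale
        manager.add(*new_objs)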

Example 66

Project: cgstudiomap
Source File: related_descriptors.py
View license
def create_forward_many_to_many_manager(superclass, rel, reverse):
    """
    Create a manager for either side of a many-to-many relation.

    This manager subclasses another manager, generally the default manager of
    the related model, and adds behaviors specific to many-to-many relations.
    """

    class ManyRelatedManager(superclass):
        def __init__(self, instance=None):
            super(ManyRelatedManager, self).__init__()

            self.instance = instance

            if not reverse:
                self.model = rel.model
                self.query_field_name = rel.field.related_query_name()
                self.prefetch_cache_name = rel.field.name
                self.source_field_name = rel.field.m2m_field_name()
                self.target_field_name = rel.field.m2m_reverse_field_name()
                self.symmetrical = rel.symmetrical
            else:
                self.model = rel.related_model
                self.query_field_name = rel.field.name
                self.prefetch_cache_name = rel.field.related_query_name()
                self.source_field_name = rel.field.m2m_reverse_field_name()
                self.target_field_name = rel.field.m2m_field_name()
                self.symmetrical = False

            self.through = rel.through
            self.reverse = reverse

            self.source_field = self.through._meta.get_field(self.source_field_name)
            self.target_field = self.through._meta.get_field(self.target_field_name)

            self.core_filters = {}
            for lh_field, rh_field in self.source_field.related_fields:
                core_filter_key = '%s__%s' % (self.query_field_name, rh_field.name)
                self.core_filters[core_filter_key] = getattr(instance, rh_field.attname)

            self.related_val = self.source_field.get_foreign_related_value(instance)
            if None in self.related_val:
                raise ValueError('"%r" needs to have a value for field "%s" before '
                                 'this many-to-many relationship can be used.' %
                                 (instance, self.source_field_name))
            # Even if this relation is not to the pk, we still require a pk
            # value. The expectation is that the instance has already been
            # saved to the DB, although having a pk value isn't a guarantee
            # of that.
            if instance.pk is None:
                raise ValueError("%r instance needs to have a primary key value before "
                                 "a many-to-many relationship can be used." %
                                 instance.__class__.__name__)

        def __call__(self, **kwargs):
            # We use **kwargs rather than a kwarg argument to enforce the
            # `manager='manager_name'` syntax.
            manager = getattr(self.model, kwargs.pop('manager'))
            manager_class = create_forward_many_to_many_manager(manager.__class__, rel, reverse)
            return manager_class(instance=self.instance)
        do_not_call_in_templates = True

        def _build_remove_filters(self, removed_vals):
            filters = Q(**{self.source_field_name: self.related_val})
            # No need to add a subquery condition if removed_vals is a QuerySet without
            # filters.
            removed_vals_filters = (not isinstance(removed_vals, QuerySet) or
                                    removed_vals._has_filters())
            if removed_vals_filters:
                filters &= Q(**{'%s__in' % self.target_field_name: removed_vals})
            if self.symmetrical:
                symmetrical_filters = Q(**{self.target_field_name: self.related_val})
                if removed_vals_filters:
                    symmetrical_filters &= Q(
                        **{'%s__in' % self.source_field_name: removed_vals})
                filters |= symmetrical_filters
            return filters

        def get_queryset(self):
            try:
                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
            except (AttributeError, KeyError):
                qs = super(ManyRelatedManager, self).get_queryset()
                qs._add_hints(instance=self.instance)
                if self._db:
                    qs = qs.using(self._db)
                return qs._next_is_sticky().filter(**self.core_filters)

        def get_prefetch_queryset(self, instances, queryset=None):
            if queryset is None:
                queryset = super(ManyRelatedManager, self).get_queryset()

            queryset._add_hints(instance=instances[0])
            queryset = queryset.using(queryset._db or self._db)

            query = {'%s__in' % self.query_field_name: instances}
            queryset = queryset._next_is_sticky().filter(**query)

            # M2M: need to annotate the query in order to get the primary model
            # that the secondary model was actually related to. We know that
            # there will already be a join on the join table, so we can just add
            # the select.

            # For non-autocreated 'through' models, can't assume we are
            # dealing with PK values.
            fk = self.through._meta.get_field(self.source_field_name)
            join_table = self.through._meta.db_table
            connection = connections[queryset.db]
            qn = connection.ops.quote_name
            queryset = queryset.extra(select={
                '_prefetch_related_val_%s' % f.attname:
                '%s.%s' % (qn(join_table), qn(f.column)) for f in fk.local_related_fields})
            return (
                queryset,
                lambda result: tuple(
                    getattr(result, '_prefetch_related_val_%s' % f.attname)
                    for f in fk.local_related_fields
                ),
                lambda inst: tuple(
                    f.get_db_prep_value(getattr(inst, f.attname), connection)
                    for f in fk.foreign_related_fields
                ),
                False,
                self.prefetch_cache_name,
            )

        def add(self, *objs):
            if not rel.through._meta.auto_created:
                opts = self.through._meta
                raise AttributeError(
                    "Cannot use add() on a ManyToManyField which specifies an "
                    "intermediary model. Use %s.%s's Manager instead." %
                    (opts.app_label, opts.object_name)
                )

            db = router.db_for_write(self.through, instance=self.instance)
            with transaction.atomic(using=db, savepoint=False):
                self._add_items(self.source_field_name, self.target_field_name, *objs)

                # If this is a symmetrical m2m relation to self, add the mirror entry in the m2m table
                if self.symmetrical:
                    self._add_items(self.target_field_name, self.source_field_name, *objs)
        add.alters_data = True

        def remove(self, *objs):
            if not rel.through._meta.auto_created:
                opts = self.through._meta
                raise AttributeError(
                    "Cannot use remove() on a ManyToManyField which specifies "
                    "an intermediary model. Use %s.%s's Manager instead." %
                    (opts.app_label, opts.object_name)
                )
            self._remove_items(self.source_field_name, self.target_field_name, *objs)
        remove.alters_data = True

        def clear(self):
            db = router.db_for_write(self.through, instance=self.instance)
            with transaction.atomic(using=db, savepoint=False):
                signals.m2m_changed.send(sender=self.through, action="pre_clear",
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=None, using=db)

                filters = self._build_remove_filters(super(ManyRelatedManager, self).get_queryset().using(db))
                self.through._default_manager.using(db).filter(filters).delete()

                signals.m2m_changed.send(sender=self.through, action="post_clear",
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=None, using=db)
        clear.alters_data = True

        def set(self, objs, **kwargs):
            if not rel.through._meta.auto_created:
                opts = self.through._meta
                raise AttributeError(
                    "Cannot set values on a ManyToManyField which specifies an "
                    "intermediary model. Use %s.%s's Manager instead." %
                    (opts.app_label, opts.object_name)
                )

            # Force evaluation of `objs` in case it's a queryset whose value
            # could be affected by `manager.clear()`. Refs #19816.
            objs = tuple(objs)

            clear = kwargs.pop('clear', False)

            db = router.db_for_write(self.through, instance=self.instance)
            with transaction.atomic(using=db, savepoint=False):
                if clear:
                    self.clear()
                    self.add(*objs)
                else:
                    old_ids = set(self.using(db).values_list(self.target_field.target_field.attname, flat=True))

                    new_objs = []
                    for obj in objs:
                        fk_val = (self.target_field.get_foreign_related_value(obj)[0]
                            if isinstance(obj, self.model) else obj)

                        if fk_val in old_ids:
                            old_ids.remove(fk_val)
                        else:
                            new_objs.append(obj)

                    self.remove(*old_ids)
                    self.add(*new_objs)
        set.alters_data = True

        def create(self, **kwargs):
            # This check needs to be done here, since we can't later remove this
            # from the method lookup table, as we do with add and remove.
            if not self.through._meta.auto_created:
                opts = self.through._meta
                raise AttributeError(
                    "Cannot use create() on a ManyToManyField which specifies "
                    "an intermediary model. Use %s.%s's Manager instead." %
                    (opts.app_label, opts.object_name)
                )
            db = router.db_for_write(self.instance.__class__, instance=self.instance)
            new_obj = super(ManyRelatedManager, self.db_manager(db)).create(**kwargs)
            self.add(new_obj)
            return new_obj
        create.alters_data = True

        def get_or_create(self, **kwargs):
            db = router.db_for_write(self.instance.__class__, instance=self.instance)
            obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create(**kwargs)
            # We only need to add() if created because if we got an object back
            # from get() then the relationship already exists.
            if created:
                self.add(obj)
            return obj, created
        get_or_create.alters_data = True

        def update_or_create(self, **kwargs):
            db = router.db_for_write(self.instance.__class__, instance=self.instance)
            obj, created = super(ManyRelatedManager, self.db_manager(db)).update_or_create(**kwargs)
            # We only need to add() if created because if we got an object back
            # from get() then the relationship already exists.
            if created:
                self.add(obj)
            return obj, created
        update_or_create.alters_data = True

        def _add_items(self, source_field_name, target_field_name, *objs):
            # source_field_name: the PK fieldname in join table for the source object
            # target_field_name: the PK fieldname in join table for the target object
            # *objs - objects to add. Either object instances, or primary keys of object instances.

            # If there aren't any objects, there is nothing to do.
            from django.db.models import Model
            if objs:
                new_ids = set()
                for obj in objs:
                    if isinstance(obj, self.model):
                        if not router.allow_relation(obj, self.instance):
                            raise ValueError(
                                'Cannot add "%r": instance is on database "%s", value is on database "%s"' %
                                (obj, self.instance._state.db, obj._state.db)
                            )
                        fk_val = self.through._meta.get_field(
                            target_field_name).get_foreign_related_value(obj)[0]
                        if fk_val is None:
                            raise ValueError(
                                'Cannot add "%r": the value for field "%s" is None' %
                                (obj, target_field_name)
                            )
                        new_ids.add(fk_val)
                    elif isinstance(obj, Model):
                        raise TypeError(
                            "'%s' instance expected, got %r" %
                            (self.model._meta.object_name, obj)
                        )
                    else:
                        new_ids.add(obj)

                db = router.db_for_write(self.through, instance=self.instance)
                vals = (self.through._default_manager.using(db)
                        .values_list(target_field_name, flat=True)
                        .filter(**{
                            source_field_name: self.related_val[0],
                            '%s__in' % target_field_name: new_ids,
                        }))
                new_ids = new_ids - set(vals)

                with transaction.atomic(using=db, savepoint=False):
                    if self.reverse or source_field_name == self.source_field_name:
                        # Don't send the signal when we are inserting the
                        # duplicate data row for symmetrical reverse entries.
                        signals.m2m_changed.send(sender=self.through, action='pre_add',
                            instance=self.instance, reverse=self.reverse,
                            model=self.model, pk_set=new_ids, using=db)

                    # Add the ones that aren't there already
                    self.through._default_manager.using(db).bulk_create([
                        self.through(**{
                            '%s_id' % source_field_name: self.related_val[0],
                            '%s_id' % target_field_name: obj_id,
                        })
                        for obj_id in new_ids
                    ])

                    if self.reverse or source_field_name == self.source_field_name:
                        # Don't send the signal when we are inserting the
                        # duplicate data row for symmetrical reverse entries.
                        signals.m2m_changed.send(sender=self.through, action='post_add',
                            instance=self.instance, reverse=self.reverse,
                            model=self.model, pk_set=new_ids, using=db)

        def _remove_items(self, source_field_name, target_field_name, *objs):
            # source_field_name: the PK colname in join table for the source object
            # target_field_name: the PK colname in join table for the target object
            # *objs - objects to remove
            if not objs:
                return

            # Check that all the objects are of the right type
            old_ids = set()
            for obj in objs:
                if isinstance(obj, self.model):
                    fk_val = self.target_field.get_foreign_related_value(obj)[0]
                    old_ids.add(fk_val)
                else:
                    old_ids.add(obj)

            db = router.db_for_write(self.through, instance=self.instance)
            with transaction.atomic(using=db, savepoint=False):
                # Send a signal to the other end if need be.
                signals.m2m_changed.send(sender=self.through, action="pre_remove",
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=old_ids, using=db)
                target_model_qs = super(ManyRelatedManager, self).get_queryset()
                if target_model_qs._has_filters():
                    old_vals = target_model_qs.using(db).filter(**{
                        '%s__in' % self.target_field.target_field.attname: old_ids})
                else:
                    old_vals = old_ids
                filters = self._build_remove_filters(old_vals)
                self.through._default_manager.using(db).filter(filters).delete()

                signals.m2m_changed.send(sender=self.through, action="post_remove",
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=old_ids, using=db)

    return ManyRelatedManager
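
In _add_items() the atomic block spans the pre_add signal, the bulk_create of through-table rows, and the post_add signal, so listeners never see a pre_add whose rows fail to appear. A stripped-down sketch of inserting m2m rows that way, with a hypothetical Membership-style through model linking `group` and `person`:

from django.db import router, transaction


def add_members(group, person_ids, through_model):
    db = router.db_for_write(through_model, instance=group)
    # Skip ids that are already linked so the insert stays idempotent.
    existing = set(
        through_model.objects.using(db)
        .filter(group=group, person_id__in=person_ids)  # hypothetical fields
        .values_list('person_id', flat=True)
    )
    new_ids = set(person_ids) - existing
    with transaction.atomic(using=db, savepoint=False):
        through_model.objects.using(db).bulk_create([
            through_model(group=group, person_id=pk) for pk in new_ids
        ])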

Example 67

Project: django-sellmo
Source File: views.py
View license
@chaining.define(provides=['order', 'process'])
def checkout(request, step=None, **kwargs):

    order = order_from_request(request)
    if not order:
        raise Http404("Nothing to order.")

    data = request.POST if request.method == 'POST' else None

    process = CheckoutProcess(request, order)

    # If a step is given, attempt to go to the given step
    if step:
        try:
            process.step_to(step)
        except ProcessStepNotFound:
            # If the step can not be found, fall back to
            # the latest step.
            step = None
        except ProcessError as error:
            raise Http404(error)

    # Go to the latest step
    if not step:
        try:
            process.step_to_latest()
        except ProcessError as error:
            raise Http404(error)

        yield redirect(
            reverse(
                'checkout:checkout',
                kwargs={'step': process.current_step.key},
                current_app=resolve(request.path).namespace
            )
        )

    elif data:
        # Perform atomic transactions at this point
        with transaction.atomic():
            success = process.feed(data)
            if order.may_change:
                order.calculate()

            if success:
                # Update order
                order.save()

                # Process step was successful; redirect to the next step.
                if not process.is_completed():
                    # Go to the next step
                    yield redirect(
                        reverse(
                            'checkout:checkout',
                            kwargs={'step': process.current_step.key},
                            current_app=resolve(request.path).namespace
                        )
                    )

        # See if we completed the process
        if process.is_completed():
            # Redirect away from this view
            request.session['completed_order'] = order.pk
            yield redirect(
                reverse(
                    'checkout:complete',
                    current_app=resolve(request.path).namespace
                )
            )

    yield chaining.update(
        order=order,
        process=process)

    if (yield chaining.forward).result is None:
        context = {'order': order}
        try:
            result = process.render(request, context=context)
        except ProcessError as error:
            raise Http404(error)

        if result is None:
            raise ViewNotImplemented
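
Here transaction.atomic() wraps process.feed(data) together with the order recalculation and save, so a failure in any checkout step rolls back all of them. A minimal sketch of the view-level pattern, with a hypothetical get_order helper and Order/CheckoutForm stand-ins:

from django import forms
from django.db import transaction
from django.shortcuts import redirect, render


class CheckoutForm(forms.Form):
    # Minimal stand-in for a real checkout step form.
    email = forms.EmailField()


def checkout_step(request):
    order = get_order(request)  # hypothetical helper returning an Order
    form = CheckoutForm(request.POST or None)
    if request.method == 'POST' and form.is_valid():
        # The step data and the recalculated order commit together;
        # an error in either rolls both back.
        with transaction.atomic():
            order.calculate()  # hypothetical method
            order.save()
        return redirect('checkout:complete')
    return render(request, 'checkout/step.html', {'form': form, 'order': order})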

Example 68

Project: django-sellmo
Source File: views.py
View license
@chaining.define(provides=['customer', 'process'])
def registration(request, step=None, **kwargs):

    customer = customer_from_request(request)
    if request.user.is_authenticated():
        user = request.user
    else:
        user = apps.get_model(settings.AUTH_USER_MODEL)()
    customer.user = user

    data = request.POST if request.method == 'POST' else None

    process = get_registration_process(request, customer)

    # If a step is given, attempt to go to the given step
    if step:
        try:
            process.step_to(step)
        except ProcessStepNotFound:
            # If the step can not be found, fall back to
            # the latest step.
            step = None
        except ProcessError as error:
            raise Http404(error)

    # Go to the latest step
    if not step:
        try:
            process.step_to_latest()
        except ProcessError as error:
            raise Http404(error)

        # Create the redirection
        yield redirect(
            reverse(
                'account:registration',
                kwargs={'step': process.current_step.key},
                current_app=resolve(request.path).namespace
            )
        )

    elif data:
        # Perform atomic transactions at this point
        with transaction.atomic():
            if process.feed(data):
                # Update customer
                user = customer.user
                user.save()
                customer.user = user
                customer.save()

                # Process step was successful; redirect to the next step.
                if not process.is_completed():
                    # Go to the next step
                    yield redirect(
                        reverse(
                            'account:registration',
                            kwargs={'step': process.current_step.key},
                            current_app=resolve(request.path).namespace
                        )
                    )

        # See if we completed the process
        if process.is_completed():
            # Assign last completed order
            order = completed_order_from_request(request)
            if order is not None:
                order.customer = customer
                order.save()

            # Redirect to login view
            yield redirect(
                reverse(
                    request.POST.get(
                        'next',
                        request.GET.get('next', LOGIN_VIEW))
                )
            )

    yield chaining.update(
        customer=customer,
        process=process)

    if (yield chaining.forward).result is None:
        context = {'customer': customer}
        try:
            result = process.render(request, context=context)
        except ProcessError as error:
            raise Http404(error)

        if result is None:
            raise ViewNotImplemented
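
The registration flow saves the User and its Customer record inside one atomic block, the standard way to keep a pair of dependent rows consistent. A minimal sketch of that two-model save, where Customer is a hypothetical profile model with a `user` foreign key:

from django.contrib.auth import get_user_model
from django.db import transaction


def register(username, email):
    with transaction.atomic():
        user = get_user_model().objects.create_user(username=username, email=email)
        # If this create fails, the user row above is rolled back too.
        customer = Customer.objects.create(user=user)  # hypothetical model
    return customer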

Example 69

Project: django-sellmo
Source File: models.py
View license
    @transaction.atomic
    def sync(self, product):

        product = product.downcast()

        logger.info("Syncing variations for %s" % (product))

        # Get all variating attributes for this product
        attributes = _attribute.models.Attribute.objects.which_variate_product(
            product)

        # Do we have an attribute which groups ?
        group = attributes.filter(groups=True).first()

        # Convert to a list for performance and indexing.
        attributes = list(attributes)

        if attributes:
            # Create all possible variation combinations
            combinations = itertools.product(
                *[
                    (
                        row['value']
                        for row in
                        _attribute.models.Value.objects.which_variate_product(
                            product).for_attribute(attribute)
                        .smart_values('value')
                        .distinct()
                        .smart_sort(attribute, product)
                    )
                    for attribute in attributes
                ]
            )
        else:
            combinations = []

        # Create variations
        sort_order = 0
        created = []

        for combination in combinations:
            # Find all values which could be in this combination.
            # Values can be explicitly assigned to a variant, or
            # they could variate this product.
            q = Q()
            for attribute, value in zip(attributes, combination):
                q |= value_q(attribute, value)

            all_values = (
                _attribute.models.Value.objects.which_variate_product(product)
                .filter(q)
            )

            # Filter out implicits and explicits
            implicits = all_values.filter(variates=True)
            explicits = all_values.filter(variates=False)

            # Find all remaining variants which can be matched
            # against one or more explicit values in this combination.
            variants = product.variants.filter(
                pk__in=(
                    explicits.order_by('product').distinct().values('product')
                )
            )

            # Find most explicit value combination
            grouped_explicitly = None
            most_explicit = _attribute.models.Value.objects.none()

            for variant in variants:
                current = explicits.filter(product=variant)

                # Make sure this variant does not belong to a different combination
                if current.count() != variant.values.which_variate(
                    product
                ).filter(variates=False).count():
                    continue

                if current.count() == len(combination):
                    # This must be the most explicit combination
                    most_explicit = current
                    break
                else:
                    # If variations are grouped, ignore the grouping
                    # attribute when determining the most explicit variant.
                    a = current.exclude(attribute=group).count()
                    b = most_explicit.exclude(attribute=group).count()

                    # Try to keep track of an explicit grouped value.
                    # We only allow value combinations containing a single
                    # grouped value.
                    if group and not grouped_explicitly and a == 0:
                        try:
                            grouped_explicitly = current.get(attribute=group)
                        except _attribute.models.Value.DoesNotExist:
                            pass

                    if most_explicit.count() == 0 or a > b:
                        # Found more explicit match.. override
                        most_explicit = current

            explicits = most_explicit

            if explicits.count() > 0:
                # A variant did match
                variant = explicits[0].product.downcast()
            else:
                # If no variants are matched, variant equals product
                variant = product

            values = _attribute.models.Value.objects.none()
            # Resolve actual values
            if variant == product:
                # All values are implicit since we don't have a variant.
                values = implicits
            else:
                # Collect all values by combining implicit values and values
                # in the explicits queryset
                implicits = implicits.exclude(
                    attribute__in=explicits.values('attribute')
                )
                values = all_values.filter(
                    Q(pk__in=implicits) | Q(
                        pk__in=explicits
                    )
                )

                # Now account for grouping behaviour
                if group and grouped_explicitly and values.filter(
                    attribute=group
                ).count() == 0:
                    # Seems like the grouped value was left out due to a more
                    # explicit value combination; filter again.
                    values = all_values.filter(
                        Q(pk__in=implicits) | Q(
                            pk__in=explicits
                        ) | Q(pk=grouped_explicitly.pk)
                    )

            # Make sure this combination actually exists
            if values.count() != len(combination):
                continue

            # Generate a unique key (uses slug format) for this variation
            variation_key = values_slug(
                values,
                prefix=variant.slug,
                full=True
            )

            # Generate description (does not use all values, only implicits).
            # This is because unicode(variant) already includes explicit values.
            variation_description = values_description(
                implicits,
                prefix=unicode(variant)
            )

            try:
                # See if variation already exists
                variation = _variation.models.Variation.objects.get(
                    id=variation_key,
                    product=product
                )
            except _variation.models.Variation.DoesNotExist:
                # Create
                variation = _variation.models.Variation.objects.create(
                    id=variation_key,
                    description=variation_description,
                    product=product,
                    variant=variant,
                    sort_order=sort_order
                )
            else:
                # Update
                variation.variant = variant
                variation.sort_order = sort_order
                variation.description = variation_description
                variation.save()

            variation.values.add(*values)
            sort_order += 1

            # Make sure this variation does not get deleted
            created.append(variation_key)

        # Handle grouping
        if group:
            # We need to find a single variant which is common across all
            # variations in this group
            for value in [
                row['value']
                for row in
                _attribute.models.Value.objects.which_variate_product(
                    product
                ).for_attribute(attribute=group)
                .smart_values('value').distinct().order_by()
            ]:

                # Get variations for this grouped attribute / value combination
                qargs = {
                    'values__attribute': group,
                    'values__{0}'.format(
                        group.get_type().get_value_field_name()
                    ): value,
                }
                variations = _variation.models.Variation.objects.filter(
                    product=product
                ).filter(**qargs)

                # Get variant
                qargs = [product_q(group, value)]
                if variations.count() > 1:
                    # Get single variant common across all variations
                    for attribute in attributes:
                        if attribute != group:
                            qargs.append(~product_q(attribute))

                try:
                    variant = product.variants.get(*qargs)
                except _product.models.Product.DoesNotExist:
                    variant = product
                except _product.models.Product.MultipleObjectsReturned:
                    variant = product
                    logger.warning(
                        "Product {product} has multiple variants "
                        "which conflict with the following "
                        "attribute/value combinations: "
                        "{attribute}/{value}".format(
                            product=product,
                            attribute=group,
                            value=value
                        )
                    )

                variations.update(group_variant=variant)

        # Delete any stale variations
        stale = self.filter(product=product) \
                    .exclude(pk__in=created)
        stale.delete()

        # Finally update product
        product.variations_synced = True
        product.save()

        variations_synced.send(self, product=product)
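
Because sync() is decorated with @transaction.atomic, every variation create, update and the final stale-row delete commit as one unit; a crash partway through leaves the previously synced variations untouched. A minimal sketch of decorating a manager method this way, with hypothetical model and helper names:

from django.db import models, transaction


class VariationManager(models.Manager):  # hypothetical manager

    @transaction.atomic
    def sync(self, product):
        # Everything below joins the decorator's transaction, so a partial
        # sync rolls back instead of leaving a mixed state.
        created = []
        for combo in product.combinations():  # hypothetical helper
            variation, _ = self.update_or_create(
                product=product,
                key=combo.key,
                defaults={'sort_order': combo.sort_order},
            )
            created.append(variation.pk)
        # Drop anything the new combination set no longer produces.
        self.filter(product=product).exclude(pk__in=created).delete()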

Example 70

Project: django-djangui
Source File: utils.py
View license
@transaction.atomic
def create_job_fileinfo(job):
    parameters = job.get_parameters()
    from ..models import DjanguiFile
    # first, create references to files the script explicitly created via a parameter
    files = []
    for field in parameters:
        try:
            if field.parameter.form_field == 'FileField':
                value = field.value
                if value is None:
                    continue
                if isinstance(value, six.string_types):
                    # check if this was ever created and make a fileobject if so
                    if get_storage(local=True).exists(value):
                        if not get_storage(local=False).exists(value):
                            get_storage(local=False).save(value, File(get_storage(local=True).open(value)))
                        value = field.value
                    else:
                        field.force_value(None)
                        field.save()
                        continue
                d = {'parameter': field, 'file': value}
                files.append(d)
        except ValueError:
            continue

    known_files = {i['file'].name for i in files}
    # add the user_output files; these may be missed by the model fields because the script
    # generated them without an explicit argument reference in the script
    file_groups = {'archives': []}
    absbase = os.path.join(settings.MEDIA_ROOT, job.save_path)
    for filename in os.listdir(absbase):
        new_name = os.path.join(job.save_path, filename)
        if any([i.endswith(new_name) for i in known_files]):
            continue
        try:
            filepath = os.path.join(absbase, filename)
            if os.path.isdir(filepath):
                continue
            d = {'name': filename, 'file': get_storage_object(os.path.join(job.save_path, filename))}
            if filename.endswith('.tar.gz') or filename.endswith('.zip'):
                file_groups['archives'].append(d)
            else:
                files.append(d)
        except IOError:
            sys.stderr.write('{}\n'.format(traceback.format_exc()))
            continue

    # establish grouping by inferring common things
    file_groups['all'] = files
    import imghdr
    file_groups['images'] = []
    for filemodel in files:
        if imghdr.what(filemodel['file'].path):
            file_groups['images'].append(filemodel)
    file_groups['tabular'] = []
    file_groups['fasta'] = []

    for filemodel in files:
        fileinfo = get_file_info(filemodel['file'].path)
        filetype = fileinfo.get('type')
        if filetype is not None:
            file_groups[filetype].append(dict(filemodel, **{'preview': fileinfo.get('preview')}))
        else:
            filemodel['preview'] = json.dumps(None)

    # Create our DjanguiFile models

    # mark things that are in groups so we don't also add them to the 'all' category, to reduce redundancy
    grouped = set([i['file'].path for file_type, groups in six.iteritems(file_groups) for i in groups if file_type != 'all'])
    for file_type, group_files in six.iteritems(file_groups):
        for group_file in group_files:
            if file_type == 'all' and group_file['file'].path in grouped:
                continue
            try:
                preview = group_file.get('preview')
                dj_file = DjanguiFile(job=job, filetype=file_type, filepreview=preview,
                                    parameter=group_file.get('parameter'))
                filepath = group_file['file'].path
                save_path = job.get_relative_path(filepath)
                dj_file.filepath.name = save_path
                dj_file.save()
            except:
                sys.stderr.write('Error in saving DJFile: {}\n'.format(traceback.format_exc()))
                continue
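
The @transaction.atomic decorator on create_job_fileinfo() guarantees that either the full file manifest is recorded for the job or none of it is. A compact sketch of the same decorator-on-a-function pattern, with a hypothetical JobFile model:

import os

from django.db import transaction


@transaction.atomic
def record_outputs(job, directory):
    # Each insert joins the transaction opened by the decorator, so a
    # failure on any file rolls back the rows created before it.
    for filename in os.listdir(directory):
        path = os.path.join(directory, filename)
        if os.path.isfile(path):
            JobFile.objects.create(job=job, filepath=path)  # hypothetical model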

Example 71

Project: django-djangui
Source File: tasks.py
View license
@celery_app.task(base=DjanguiTask)
def submit_script(**kwargs):
    job_id = kwargs.pop('djangui_job')
    resubmit = kwargs.pop('djangui_resubmit', False)
    rerun = kwargs.pop('rerun', False)
    from .backend import utils
    from .models import DjanguiJob, DjanguiFile
    job = DjanguiJob.objects.get(pk=job_id)

    command = utils.get_job_commands(job=job)
    if resubmit:
        # clone ourselves, setting pk=None seems hackish but it works
        job.pk = None

    # This is where the script works from -- it is the path relative to MEDIA_ROOT, since the media root may
    # change between setups / wherever our user uploads are stored.
    cwd = job.get_output_path()

    abscwd = os.path.abspath(os.path.join(settings.MEDIA_ROOT, cwd))
    job.command = ' '.join(command)
    job.save_path = cwd

    if rerun:
        # clean up the old files; we need to be somewhat aggressive here.
        local_storage = utils.get_storage(local=True)
        remote_storage = utils.get_storage(local=False)
        to_delete = []
        with atomic():
            for dj_file in DjanguiFile.objects.filter(job=job):
                if dj_file.parameter is None or dj_file.parameter.parameter.is_output:
                    to_delete.append(dj_file)
                    path = local_storage.path(dj_file.filepath.name)
                    dj_file.filepath.delete(False)
                    if local_storage.exists(path):
                        local_storage.delete(path)
                    # TODO: This needs to be tested to make sure it's being nuked
                    if remote_storage.exists(path):
                        remote_storage.delete(path)
            [i.delete() for i in to_delete]

    utils.mkdirs(abscwd)
    # make sure we have the script, otherwise download it. This can happen if we have an ephemeral file system or are
    # executing jobs on a worker node.
    script_path = job.script.script_path
    if not utils.get_storage(local=True).exists(script_path.path):
        utils.get_storage(local=True).save(script_path.path, script_path.file)

    job.status = DjanguiJob.RUNNING
    job.save()

    proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=abscwd)

    stdout, stderr = proc.communicate()
    # tar/zip up the generated content for bulk downloads
    def get_valid_file(cwd, name, ext):
        out = os.path.join(cwd, name)
        index = 0
        while os.path.exists(six.u('{}.{}').format(out, ext)):
            index += 1
            out = os.path.join(cwd, six.u('{}_{}').format(name, index))
        return six.u('{}.{}').format(out, ext)

    # fetch the job again in case the database connection was lost during the job or something else changed.
    job = DjanguiJob.objects.get(pk=job_id)

    # if there are files generated, make zip/tar files for download
    if len(os.listdir(abscwd)):
        tar_out = get_valid_file(abscwd, get_valid_filename(job.job_name), 'tar.gz')
        tar = tarfile.open(tar_out, "w:gz")
        tar_name = os.path.splitext(os.path.splitext(os.path.split(tar_out)[1])[0])[0]
        tar.add(abscwd, arcname=tar_name)
        tar.close()

        zip_out = get_valid_file(abscwd, get_valid_filename(job.job_name), 'zip')
        zip = zipfile.ZipFile(zip_out, "w")
        arcname = os.path.splitext(os.path.split(zip_out)[1])[0]
        zip.write(abscwd, arcname=arcname)
        for root, folders, filenames in os.walk(os.path.split(zip_out)[0]):
            for filename in filenames:
                path = os.path.join(root, filename)
                if path == tar_out:
                    continue
                if path == zip_out:
                    continue
                zip.write(path, arcname=os.path.join(arcname, filename))
        zip.close()

        # save all the files generated as well to our default storage for ephemeral storage setups
        if djangui_settings.DJANGUI_EPHEMERAL_FILES:
            for root, folders, files in os.walk(abscwd):
                for filename in files:
                    filepath = os.path.join(root, filename)
                    s3path = os.path.join(root[root.find(cwd):], filename)
                    remote = utils.get_storage(local=False)
                    exists = remote.exists(s3path)
                    filesize = remote.size(s3path) if exists else 0  # size() may fail on missing paths
                    if not exists or (exists and filesize == 0):
                        if exists:
                            remote.delete(s3path)
                        remote.save(s3path, File(open(filepath, 'rb')))

    utils.create_job_fileinfo(job)


    job.stdout = stdout
    job.stderr = stderr
    job.status = DjanguiJob.COMPLETED
    job.save()

    return (stdout, stderr)
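
The rerun branch above groups the DjanguiFile deletions in a single atomic() block so the rows disappear together or not at all. One caveat: the storage deletions inside the block are filesystem/S3 side effects and will not be undone if the transaction rolls back; only the row deletes are transactional. The row side, condensed:

from django.db import transaction

def purge_output_rows(job):
    with transaction.atomic():
        stale = DjanguiFile.objects.filter(job=job)
        for dj_file in stale:
            dj_file.filepath.delete(False)  # storage side effect: NOT rolled back
        stale.delete()  # the row deletes commit or roll back as one unit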

Example 72

Project: ADL_LRS
Source File: views.py
View license
@transaction.atomic
def handle_request(request, more_id=None):
    validators = {
        reverse('lrs:statements').lower(): {
            "POST": req_validate.statements_post,
            "GET": req_validate.statements_get,
            "PUT": req_validate.statements_put,
            "HEAD": req_validate.statements_get
        },
        reverse('lrs:statements_more_placeholder').lower(): {
            "GET": req_validate.statements_more_get,
            "HEAD": req_validate.statements_more_get
        },
        reverse('lrs:activity_state').lower(): {
            "POST": req_validate.activity_state_post,
            "PUT": req_validate.activity_state_put,
            "GET": req_validate.activity_state_get,
            "HEAD": req_validate.activity_state_get,
            "DELETE": req_validate.activity_state_delete
        },
        reverse('lrs:activity_profile').lower(): {
            "POST": req_validate.activity_profile_post,
            "PUT": req_validate.activity_profile_put,
            "GET": req_validate.activity_profile_get,
            "HEAD": req_validate.activity_profile_get,
            "DELETE": req_validate.activity_profile_delete
        },
        reverse('lrs:activities').lower(): {
            "GET": req_validate.activities_get,
            "HEAD": req_validate.activities_get
        },
        reverse('lrs:agent_profile').lower(): {
            "POST": req_validate.agent_profile_post,
            "PUT": req_validate.agent_profile_put,
            "GET": req_validate.agent_profile_get,
            "HEAD": req_validate.agent_profile_get,
            "DELETE": req_validate.agent_profile_delete
        },
        reverse('lrs:agents').lower(): {
            "GET": req_validate.agents_get,
            "HEAD": req_validate.agents_get
        }
    }
    processors = {
        reverse('lrs:statements').lower(): {
            "POST": req_process.statements_post,
            "GET": req_process.statements_get,
            "HEAD": req_process.statements_get,
            "PUT": req_process.statements_put
        },
        reverse('lrs:statements_more_placeholder').lower(): {
            "GET": req_process.statements_more_get,
            "HEAD": req_process.statements_more_get
        },
        reverse('lrs:activity_state').lower(): {
            "POST": req_process.activity_state_post,
            "PUT": req_process.activity_state_put,
            "GET": req_process.activity_state_get,
            "HEAD": req_process.activity_state_get,
            "DELETE": req_process.activity_state_delete
        },
        reverse('lrs:activity_profile').lower(): {
            "POST": req_process.activity_profile_post,
            "PUT": req_process.activity_profile_put,
            "GET": req_process.activity_profile_get,
            "HEAD": req_process.activity_profile_get,
            "DELETE": req_process.activity_profile_delete
        },
        reverse('lrs:activities').lower(): {
            "GET": req_process.activities_get,
            "HEAD": req_process.activities_get
        },
        reverse('lrs:agent_profile').lower(): {
            "POST": req_process.agent_profile_post,
            "PUT": req_process.agent_profile_put,
            "GET": req_process.agent_profile_get,
            "HEAD": req_process.agent_profile_get,
            "DELETE": req_process.agent_profile_delete
        },
        reverse('lrs:agents').lower(): {
            "GET": req_process.agents_get,
            "HEAD": req_process.agents_get
        }
    }

    try:
        r_dict = req_parse.parse(request, more_id)
        path = request.path.lower()
        if path.endswith('/'):
            path = path.rstrip('/')
        # Cut off more_id
        if 'more' in path:
            path = "%s/%s" % (reverse('lrs:statements').lower(), "more")
        req_dict = validators[path][r_dict['method']](r_dict)
        return processors[path][req_dict['method']](req_dict)
    except (BadRequest, OauthBadRequest, HttpResponseBadRequest) as err:
        log_exception(request.path, err)
        response = HttpResponse(err.message, status=400)
    except (Unauthorized, OauthUnauthorized) as autherr:
        log_exception(request.path, autherr)
        response = HttpResponse(autherr, status=401)
        response['WWW-Authenticate'] = 'Basic realm="ADLLRS"'
    except Forbidden as forb:
        log_exception(request.path, forb)
        response = HttpResponse(forb.message, status=403)
    except NotFound as nf:
        log_exception(request.path, nf)
        response = HttpResponse(nf.message, status=404)
    except Conflict as c:
        log_exception(request.path, c)
        response = HttpResponse(c.message, status=409)
    except PreconditionFail as pf:
        log_exception(request.path, pf)
        response = HttpResponse(pf.message, status=412)
    except Exception as err:
        log_exception(request.path, err)
        response = HttpResponse(err.message, status=500)   
    return response
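
Example 72 applies @transaction.atomic to the whole view, so every write made while handling the request commits only if the view returns normally. Note that exceptions caught inside the view (as the handlers above do) will not by themselves roll the transaction back; only an exception that propagates out of the decorated function does, although an unhandled database error still marks the transaction for rollback. A minimal sketch, where process_request is a hypothetical dispatch helper:

from django.db import transaction
from django.http import HttpResponse

@transaction.atomic
def my_view(request):
    # Writes made by process_request() commit together when the view
    # returns; an uncaught exception rolls them all back.
    result = process_request(request)  # hypothetical helper
    return HttpResponse(result)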

Example 73

Project: hue
Source File: layermapping.py
View license
    def __init__(self, model, data, mapping, layer=0,
                 source_srs=None, encoding='utf-8',
                 transaction_mode='commit_on_success',
                 transform=True, unique=None, using=None):
        """
        A LayerMapping object is initialized using the given Model (not an instance),
        a DataSource (or string path to an OGR-supported data file), and a mapping
        dictionary.  See the module level docstring for more details and keyword
        argument usage.
        """
        # Getting the DataSource and the associated Layer.
        if isinstance(data, six.string_types):
            self.ds = DataSource(data, encoding=encoding)
        else:
            self.ds = data
        self.layer = self.ds[layer]

        self.using = using if using is not None else router.db_for_write(model)
        self.spatial_backend = connections[self.using].ops

        # Setting the mapping & model attributes.
        self.mapping = mapping
        self.model = model

        # Checking the layer -- initialization of the object will fail if
        # things don't check out beforehand.
        self.check_layer()

        # Getting the geometry column associated with the model (an
        # exception will be raised if there is no geometry column).
        if self.spatial_backend.mysql:
            transform = False
        else:
            self.geo_field = self.geometry_field()

        # Checking the source spatial reference system, and getting
        # the coordinate transformation object (unless the `transform`
        # keyword is set to False)
        if transform:
            self.source_srs = self.check_srs(source_srs)
            self.transform = self.coord_transform()
        else:
            self.transform = transform

        # Setting the encoding for OFTString fields, if specified.
        if encoding:
            # Making sure the encoding exists, if not a LookupError
            # exception will be thrown.
            from codecs import lookup
            lookup(encoding)
            self.encoding = encoding
        else:
            self.encoding = None

        if unique:
            self.check_unique(unique)
            transaction_mode = 'autocommit' # Has to be set to autocommit.
            self.unique = unique
        else:
            self.unique = None

        # Setting the transaction decorator with the function in the
        # transaction modes dictionary.
        self.transaction_mode = transaction_mode
        if transaction_mode == 'autocommit':
            self.transaction_decorator = None
        elif transaction_mode == 'commit_on_success':
            self.transaction_decorator = transaction.atomic
        else:
            raise LayerMapError('Unrecognized transaction mode: %s' % transaction_mode)
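
LayerMapping stores transaction.atomic itself as an attribute and applies it later, letting callers pick the transaction mode at construction time. A stripped-down sketch of the same dispatch (using ValueError in place of the project-specific LayerMapError):

from django.db import transaction

def pick_decorator(transaction_mode):
    # Return the callable (or None) that the save routine wraps itself with.
    if transaction_mode == 'autocommit':
        return None
    if transaction_mode == 'commit_on_success':
        return transaction.atomic
    raise ValueError('Unrecognized transaction mode: %s' % transaction_mode)

def run(func, transaction_mode='commit_on_success'):
    decorator = pick_decorator(transaction_mode)
    return decorator(func)() if decorator else func()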

Example 74

Project: hue
Source File: tests.py
View license
    def test_filter(self):
        Company.objects.create(
            name="Example Inc.", num_employees=2300, num_chairs=5,
            ceo=Employee.objects.create(firstname="Joe", lastname="Smith")
        )
        Company.objects.create(
            name="Foobar Ltd.", num_employees=3, num_chairs=4,
            ceo=Employee.objects.create(firstname="Frank", lastname="Meyer")
        )
        Company.objects.create(
            name="Test GmbH", num_employees=32, num_chairs=1,
            ceo=Employee.objects.create(firstname="Max", lastname="Mustermann")
        )

        company_query = Company.objects.values(
            "name", "num_employees", "num_chairs"
        ).order_by(
            "name", "num_employees", "num_chairs"
        )

        # We can filter for companies where the number of employees is greater
        # than the number of chairs.
        self.assertQuerysetEqual(
            company_query.filter(num_employees__gt=F("num_chairs")), [
                {
                    "num_chairs": 5,
                    "name": "Example Inc.",
                    "num_employees": 2300,
                },
                {
                    "num_chairs": 1,
                    "name": "Test GmbH",
                    "num_employees": 32
                },
            ],
            lambda o: o
        )

        # We can set one field to have the value of another field
        # Make sure we have enough chairs
        company_query.update(num_chairs=F("num_employees"))
        self.assertQuerysetEqual(
            company_query, [
                {
                    "num_chairs": 2300,
                    "name": "Example Inc.",
                    "num_employees": 2300
                },
                {
                    "num_chairs": 3,
                    "name": "Foobar Ltd.",
                    "num_employees": 3
                },
                {
                    "num_chairs": 32,
                    "name": "Test GmbH",
                    "num_employees": 32
                }
            ],
            lambda o: o
        )

        # We can perform arithmetic operations in expressions
        # Make sure we have 2 spare chairs
        company_query.update(num_chairs=F("num_employees")+2)
        self.assertQuerysetEqual(
            company_query, [
                {
                    'num_chairs': 2302,
                    'name': 'Example Inc.',
                    'num_employees': 2300
                },
                {
                    'num_chairs': 5,
                    'name': 'Foobar Ltd.',
                    'num_employees': 3
                },
                {
                    'num_chairs': 34,
                    'name': 'Test GmbH',
                    'num_employees': 32
                }
            ],
            lambda o: o,
        )

        # The standard order of operations is followed
        company_query.update(
            num_chairs=F('num_employees') + 2 * F('num_employees')
        )
        self.assertQuerysetEqual(
            company_query, [
                {
                    'num_chairs': 6900,
                    'name': 'Example Inc.',
                    'num_employees': 2300
                },
                {
                    'num_chairs': 9,
                    'name': 'Foobar Ltd.',
                    'num_employees': 3
                },
                {
                    'num_chairs': 96,
                    'name': 'Test GmbH',
                    'num_employees': 32
                }
            ],
            lambda o: o,
        )

        # The order of operations can be overridden with parentheses
        company_query.update(
            num_chairs=((F('num_employees') + 2) * F('num_employees'))
        )
        self.assertQuerysetEqual(
            company_query, [
                {
                    'num_chairs': 5294600,
                    'name': 'Example Inc.',
                    'num_employees': 2300
                },
                {
                    'num_chairs': 15,
                    'name': 'Foobar Ltd.',
                    'num_employees': 3
                },
                {
                    'num_chairs': 1088,
                    'name': 'Test GmbH',
                    'num_employees': 32
                }
            ],
            lambda o: o,
        )

        # The value of a foreign key can be copied over to another
        # foreign key.
        self.assertEqual(
            Company.objects.update(point_of_contact=F('ceo')),
            3
        )
        self.assertQuerysetEqual(
            Company.objects.all(), [
                "Joe Smith",
                "Frank Meyer",
                "Max Mustermann",
            ],
            lambda c: six.text_type(c.point_of_contact),
            ordered=False
        )

        c = Company.objects.all()[0]
        c.point_of_contact = Employee.objects.create(firstname="Guido", lastname="van Rossum")
        c.save()

        # F Expressions can also span joins
        self.assertQuerysetEqual(
            Company.objects.filter(ceo__firstname=F("point_of_contact__firstname")), [
                "Foobar Ltd.",
                "Test GmbH",
            ],
            lambda c: c.name,
            ordered=False
        )

        Company.objects.exclude(
            ceo__firstname=F("point_of_contact__firstname")
        ).update(name="foo")
        self.assertEqual(
            Company.objects.exclude(
                ceo__firstname=F('point_of_contact__firstname')
            ).get().name,
            "foo",
        )

        with transaction.atomic():
            with self.assertRaises(FieldError):
                Company.objects.exclude(
                    ceo__firstname=F('point_of_contact__firstname')
                ).update(name=F('point_of_contact__lastname'))

        # F expressions can be used to update attributes on single objects
        test_gmbh = Company.objects.get(name="Test GmbH")
        self.assertEqual(test_gmbh.num_employees, 32)
        test_gmbh.num_employees = F("num_employees") + 4
        test_gmbh.save()
        test_gmbh = Company.objects.get(pk=test_gmbh.pk)
        self.assertEqual(test_gmbh.num_employees, 36)

        # F expressions cannot be used to update attributes which are foreign
        # keys, or attributes which involve joins.
        test_gmbh.point_of_contact = None
        test_gmbh.save()
        self.assertTrue(test_gmbh.point_of_contact is None)
        def test():
            test_gmbh.point_of_contact = F("ceo")
        self.assertRaises(ValueError, test)

        test_gmbh.point_of_contact = test_gmbh.ceo
        test_gmbh.save()
        test_gmbh.name = F("ceo__last_name")
        self.assertRaises(FieldError, test_gmbh.save)

        # F expressions cannot be used to update attributes on objects which do
        # not yet exist in the database
        acme = Company(
            name="The Acme Widget Co.", num_employees=12, num_chairs=5,
            ceo=test_gmbh.ceo
        )
        acme.num_employees = F("num_employees") + 16
        self.assertRaises(TypeError, acme.save)
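
The `with transaction.atomic(): with self.assertRaises(...)` nesting near the end of the test is the standard idiom for provoking an expected database error inside a TestCase: the inner atomic() savepoint absorbs the failure, so the transaction wrapping the test stays usable for later assertions. Schematically, with a placeholder model:

from django.db import IntegrityError, transaction
from django.test import TestCase

class ExampleTests(TestCase):
    def test_expected_failure(self):
        Thing.objects.create(unique_field='x')  # placeholder model
        with transaction.atomic():
            with self.assertRaises(IntegrityError):
                Thing.objects.create(unique_field='x')
        # The savepoint rolled back, so further queries still work here.
        self.assertEqual(Thing.objects.count(), 1)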

Example 75

Project: 1flow
Source File: url.py
View license
    def absolutize_url(self, requests_response=None, force=False, commit=True):
        """ Make the current article URL absolute.

        Eg. transform:

        http://feedproxy.google.com/~r/francaistechcrunch/~3/hEIhLwVyEEI/

        into:

        http://techcrunch.com/2013/05/18/hell-no-tumblr-users-wont-go-to-yahoo/ # NOQA
            ?utm_source=feeurner&utm_medium=feed&utm_campaign=Feed%3A+francaistechcrunch+%28TechCrunch+en+Francais%29 # NOQA

        and then remove all these F*G utm_* parameters to get a clean
        final URL for the current article.

        Returns ``True`` if the operation succeeded, ``False`` if the
        absolutization pointed out that the current article is a
        duplicate of another. In this case the caller should stop its
        processing because the current article will be marked for deletion.

        Can also return ``None`` if absolutizing is disabled globally
        in ``constance`` configuration.
        """

        # Another example: http://rss.lefigaro.fr/~r/lefigaro/laune/~3/7jgyrQ-PmBA/story01.htm # NOQA

        if self.absolutize_url_must_abort(force=force, commit=commit):
            return

        if requests_response is None:
            try:
                requests_response = requests.get(self.url)

            except requests.ConnectionError as e:
                statsd.gauge('articles.counts.url_errors', 1, delta=True)
                message = u'Connection error while absolutizing “%s”: %s'
                args = (self.url, str(e), )

                self.url_error = message % args
                # Don't waste a version just for that.
                self.save_without_historical_record()

                LOGGER.error(message, *args)
                return

        if not requests_response.ok or requests_response.status_code != 200:

            message = u'HTTP Error %s while absolutizing “%s”: %s'
            args = (
                requests_response.status_code,
                requests_response.url,
                requests_response.reason
            )

            with statsd.pipeline() as spipe:
                spipe.gauge('articles.counts.url_errors', 1, delta=True)

                if requests_response.status_code in (404, ):
                    self.is_orphaned = True

                    # This is not handled by the post_save()
                    # which acts only at article creation.
                    spipe.gauge('articles.counts.orphaned', 1, delta=True)

            self.url_error = message % args

            # Don't waste a version just for that.
            self.save_without_historical_record()

            LOGGER.error(message, *args)
            return

        #
        # NOTE: we could also get it eventually from r.headers['link'],
        #       which contains '<another_url>'. We need to strip out
        #       the '<>', and re-absolutize this link, because in the
        #       example it's another redirector. Also r.links is a good
        #       candidate but in the example I used, it contains the
        #       shortlink, which must be re-resolved too.
        #
        #       So: as we already are at the final address *now*, no need
        #       bothering re-following another which would lead us to the
        #       the same final place.
        #

        final_url = clean_url(requests_response.url)

        # LOGGER.info(u'\n\nFINAL: %s vs. ORIG: %s\n\n', final_url, self.url)

        if final_url != self.url:

            # Just for displaying purposes, see below.
            old_url = self.url

            if self.url_error:
                statsd.gauge('articles.counts.url_errors', -1, delta=True)

            # Even if we are a duplicate, we came until here and everything
            # went fine. We won't need to lookup again the absolute URL.
            statsd.gauge('articles.counts.absolutes', 1, delta=True)
            self.url_absolute = True
            self.url_error = None

            self.url = final_url

            try:
                if self.name.endswith(old_url):
                    self.name = self.name.replace(old_url, final_url)
            except Exception:
                LOGGER.exception(u'Could not replace URL in name of %s #%s',
                                 self._meta.model.__name__, self.id)

            duplicate = False

            with transaction.atomic():
                # Without the atomic() block, saving the current article
                # (beiing a duplicate) will trigger the IntegrityError,
                # but will render the current SQL context unusable, unable
                # to register duplicate, potentially leading to massive
                # inconsistencies in the caller's context.
                try:
                    # Don't waste a version just for that.
                    self.save_without_historical_record()

                except IntegrityError:
                    duplicate = True

            if duplicate:
                params = {
                    '%s___url' % self._meta.model.__name__: final_url
                }
                original = BaseItem.objects.get(**params)

                # Just to display the right “old” one in logs.
                self.url = old_url

                LOGGER.info(u'%s #%s is a duplicate of #%s, '
                            u'registering as such.',
                            self._meta.model.__name__, self.id, original.id)

                original.register_duplicate(self)
                return False

            # Any other exception will raise. This is intentional.
            else:
                LOGGER.info(u'URL of %s (#%s) successfully absolutized '
                            u'from %s to %s.', self._meta.model.__name__,
                            self.id, old_url, final_url)

        else:
            # Don't do the job twice.
            if self.url_error:
                statsd.gauge('articles.counts.url_errors', -1, delta=True)

            statsd.gauge('articles.counts.absolutes', 1, delta=True)
            self.url_absolute = True
            self.url_error = None

            # Don't waste a version just for that.
            self.save_without_historical_record()

        return True
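
The duplicate-detection idiom above is worth isolating: the savepoint opened by atomic() keeps the surrounding transaction usable when the unique constraint fires, so the caller can go on to resolve the original row. In miniature:

from django.db import IntegrityError, transaction

def save_or_flag_duplicate(article):
    try:
        with transaction.atomic():
            article.save()
    except IntegrityError:
        return True  # a row with this URL already exists; caller resolves it
    return False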

Example 76

Project: 1flow
Source File: article.py
View license
    @classmethod
    def create_article(cls, title, url, feeds, **kwargs):
        """ Returns ``True`` if article created, ``False`` if a pure duplicate
            (already exists in the same feed), ``None`` if exists but not in
            the same feed. If more than one feed given, only returns ``True``
            or ``False`` (mutualized state is not checked). """

        tags = kwargs.pop('tags', [])

        if url is None:
            # We have to build a reliable orphaned URL, because orphaned
            # articles are often duplicates. RSS feeds serve us many times
            # the same article, without any URL, and we keep recording it
            # as new (but orphaned) content… Seen 20141111 on Chuck Norris
            # facts, where the content is in the title, and there is no URL.
            # We have 860k+ items, out of 1k real facts… Doomed.
            url = ARTICLE_ORPHANED_BASE + generate_orphaned_hash(title, feeds)

            defaults = {
                'name': title,
                'is_orphaned': True,

                # Skip absolutization, it's useless.
                'url_absolute': True
            }

            defaults.update(kwargs)

            article, created = cls.objects.get_or_create(url=url,
                                                         defaults=defaults)

            # HEADS UP: no statsd here, it's handled by post_save().

        else:
            url = clean_url(url)

            defaults = {'name': title}
            defaults.update(kwargs)

            article, created = cls.objects.get_or_create(url=url,
                                                         defaults=defaults)

        if created:
            created_retval = True

            LOGGER.info(u'Created %sarticle %s %s.', u'orphaned '
                        if article.is_orphaned else u'', article.id,
                        u'in feed(s) {0}'.format(_format_feeds(feeds))
                        if feeds else u'without any feed')

        else:
            created_retval = False

            if article.duplicate_of_id:
                LOGGER.info(u'Swapping duplicate %s %s for master %s on '
                            u'the fly.', article._meta.verbose_name,
                            article.id, article.duplicate_of_id)

                article = article.duplicate_of

            if len(feeds) == 1 and feeds[0] not in article.feeds.all():
                # This article is already there, but has not yet been
                # fetched for this feed. It's mutualized, and as such
                # it is considered as partly new. At least, it's not
                # as bad as being a true duplicate.
                created_retval = None

                LOGGER.info(u'Mutualized article %s in feed(s) %s.',
                            article.id, _format_feeds(feeds))

                article.create_reads(feeds=feeds)

            else:
                # No statsd, because we didn't create any record in database.
                LOGGER.info(u'Duplicate article %s in feed(s) %s.',
                            article.id, _format_feeds(feeds))

            # Special case where a mutualized article arrives from RSS
            # (with date/author) while it was already here from Twitter
            # (no date/author). Post-processing of original data will
            # handle the authors, but at least we update the date now for
            # users to have sorted articles until original data is
            # post-processed (this can take time, given the server load).
            if article.date_published is None:
                date_published = kwargs.get('date_published', None)

                if date_published is not None:
                    article.date_published = date_published
                    article.save()

        # Tags & feeds are ManyToMany, they
        # need the article to be saved before.

        if tags:
            try:
                with transaction.atomic():
                    article.tags.add(*tags)

            except IntegrityError:
                LOGGER.exception(u'Could not add tags %s to article %s',
                                 tags, article.id)

        if feeds:
            try:
                with transaction.atomic():
                    article.feeds.add(*feeds)

            except Exception:
                LOGGER.exception(u'Could not add feeds to article %s',
                                 article.id)

        # Get a chance to catch the duplicate if workers were fast.
        # At the cost of another DB read, this will save some work
        # in repair scripts, and avoid some writes when creating reads.
        article = cls.objects.get(id=article.id)

        if article.duplicate_of_id:
            if settings.DEBUG:
                LOGGER.debug(u'Caught on-the-fly duplicate %s, returning '
                             u'master %s instead.', article.id,
                             article.duplicate_of_id)

            return article.duplicate_of, False

        return article, created_retval
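
Both many-to-many add() calls above run inside their own atomic() block, so an IntegrityError raised by a racing worker poisons only that savepoint rather than the whole connection. Condensed:

from django.db import IntegrityError, transaction

def attach_relations(article, tags, feeds):
    for manager, items in ((article.tags, tags), (article.feeds, feeds)):
        if not items:
            continue
        try:
            with transaction.atomic():
                manager.add(*items)
        except IntegrityError:
            pass  # a concurrent worker already added them; safe to ignore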

Example 77

Project: 1flow
Source File: email.py
View license
    @classmethod
    def create_email(cls, email_data, feeds, **kwargs):
        """ Returns ``True`` if email created, ``False`` if a pure duplicate
            (already exists in the same feed), ``None`` if exists but not in
            the same feed. If more than one feed given, only returns ``True``
            or ``False`` (mutualized state is not checked). """

        email = email_data.get('email')

        message_body = {
            'plain': u'',
            'html': u'',
        }

        name = None
        message_id = None

        for part in email.walk():
            if settings.DEBUG:
                # Print a log of debugging output about the email structure.
                part_content_type = part.get_content_type()
                LOGGER.debug(u'   |> part %s (%s keys)',
                             part_content_type, len(part.keys()))

                for key, value in part.items():
                    LOGGER.debug(u'      |> %s (len: %s): %s',
                                 key, len(value),
                                 unicode(value)[:40])

            if 'subject' in part and name is None:

                # Some subjects are long text wrapped at ~80 chars.
                name = part.get('subject').replace(
                    u'\r', u'').replace(u'\n', u'')

            if 'message-id' in part and message_id is None:
                message_id = part.get('message-id').strip()

            if part.is_multipart():
                # Multipart parts are just glue,
                # skip to the interesting parts.
                continue

            if not part.get_content_type().startswith('text'):
                LOGGER.error(u'Skipped e-mail %s part (not implemented yet).',
                             part.get_content_type())
                continue

            part_payload = part.get_payload(decode=True)
            charset = part.get_charset()

            if not bool(charset):
                content_type = part.get('Content-Type',

                                        # Using this as default allows
                                        # to have only one mechanic to
                                        # extract the charset.
                                        'text/plain; charset="utf-8"')
                content_type_parts = [
                    x.strip().lower() for x in content_type.split(';')
                ]

                if len(content_type_parts) != 2:
                    LOGGER.error(u'Could not get email part charset, thus '
                                 u'skipped. E-mail body will probably be '
                                 u'incomplete!')
                    continue

                charset = content_type_parts[1].split(
                    u'=')[1].replace('"', '').replace("'", "")

            # LOGGER.info(u'payload: %s of %s, charset=%s',
            #             len(part_payload), type(part_payload), charset)

            # Key off the part's own content type; content_type_parts is
            # only defined when the charset header was missing.
            message_body_key = (
                'plain' if part.get_content_type() == 'text/plain' else 'html'
            )

            # Concatenate every body part to get a full body.
            if isinstance(part_payload, str):
                try:
                    message_body[message_body_key] += part_payload.decode(
                        charset)

                except LookupError:
                    message_body[message_body_key] += part_payload.decode(
                        'utf-8', errors='replace')

            else:
                message_body[message_body_key] += part_payload

        # HTML content has precedence, because data will be richer.
        # In case of newsletters / mailing-lists, HTML content will
        # allow us to follow links, while text-only content will
        # [perhaps] not, or less easily. BTW, text/plain content will
        # not contain markdown links, and we will lose the ability
        # to render them in the GUI, while HTML will be converted to
        # Markdown as usual and the user will see richer content.
        if message_body['html']:
            content_type = CONTENT_TYPES.HTML
            content = message_body['html']

        else:
            content_type = CONTENT_TYPES.MARKDOWN
            content = message_body['plain']

        defaults = {
            'name': name,
            'origin': kwargs.pop('origin', ORIGINS.EMAIL),
            'date_published': email_data.get('date'),
            'content': content,
            'content_type': content_type,
        }

        defaults.update(kwargs)

        if message_id.startswith(u'<'):
            # Remove the <> around the ID.
            message_id = message_id[1:-1]

        email, created = cls.objects.get_or_create(message_id=message_id,
                                                   defaults=defaults)

        if created:
            LOGGER.info(u'Created email #%s in feed(s) %s.', message_id,
                        u', '.join(unicode(f) for f in feeds))

            if feeds:
                try:
                    with transaction.atomic():
                        email.feeds.add(*feeds)

                except IntegrityError:
                    LOGGER.exception(u'Integrity error on created email #%s',
                                     message_id)
                    pass

            od = email.add_original_data(
                'email',
                value=email_data.get('raw_email'),

                # We will commit at next call.
                commit=False,

                # We do not launch the post-processing
                # task, it's not implemented yet anyway.
                launch_task=False
            )

            email.add_original_data(
                'matching_rule',
                value=email_data.get('meta'),
                # Use the OD returned before to commit on.
                original_data=od,
                # No post-processing for matching rules.
                launch_task=False
            )

            return email, True

        # —————————————————————————————————————————————————————— existing email

        # Get a chance to catch a duplicate if workers were fast.
        if email.duplicate_of_id:
            LOGGER.info(u'Swapping duplicate email #%s with master #%s on '
                        u'the fly.', email.id, email.duplicate_of_id)

            email = email.duplicate_of

        created_retval = False

        previous_feeds_count = email.feeds.count()

        try:
            with transaction.atomic():
                email.feeds.add(*feeds)

        except IntegrityError:
            # Race condition when backfill_if_needed() is run after
            # reception of first item in a stream, and they both create
            # the same email.
            LOGGER.exception(u'Integrity error when adding feeds %s to '
                             u'email #%s', feeds, message_id)

        else:
            if email.feeds.count() > previous_feeds_count:
                # This email is already there, but has not yet been
                # fetched for this feed. It's mutualized, and as such
                # it is considered as partly new. At least, it's not
                # as bad as being a true duplicate.
                created_retval = None

                LOGGER.info(u'Mutualized email #%s #%s in feed(s) %s.',
                            message_id, email.id,
                            u', '.join(unicode(f) for f in feeds))

                email.create_reads(feeds=feeds)

            else:
                # No statsd, because we didn't create any record in database.
                LOGGER.info(u'Duplicate email “%s” #%s #%s in feed(s) %s.',
                            name, message_id, email.id,
                            u', '.join(unicode(f) for f in feeds))

        return email, created_retval
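
The existing-email branch tells "mutualized" apart from "pure duplicate" by comparing the feed count before and after the guarded add(). The same test in isolation:

from django.db import IntegrityError, transaction

def add_to_feeds(email, feeds):
    """Return None if the email was mutualized into a new feed,
    False if it was already everywhere (a pure duplicate)."""
    before = email.feeds.count()
    try:
        with transaction.atomic():
            email.feeds.add(*feeds)
    except IntegrityError:
        return False
    return None if email.feeds.count() > before else False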

Example 78

Project: 1flow
Source File: checks.py
View license
@task(name="oneflow.core.tasks.global_duplicates_checker", queue='check')
def global_duplicates_checker(limit=None, force=False):
    """ Check that duplicate articles have no more Reads anywhere.

    Fix it if not, and update all counters accordingly.

    :param limit: integer, the maximum number of duplicates to check.
        Default: none.
    :param force: boolean, default ``False``, allows bypassing and
        reacquiring the lock.
    """

    if config.CHECK_DUPLICATES_DISABLED:
        LOGGER.warning(u'Duplicates check disabled in configuration.')
        return

    # This task runs once a day. Acquire the lock for just a
    # little more time to avoid over-parallelized runs.
    my_lock = RedisExpiringLock('check_all_duplicates', expire_time=3600 * 25)

    if not my_lock.acquire():
        if force:
            my_lock.release()
            my_lock.acquire()
            LOGGER.warning(u'Forcing duplicates check…')

        else:
            # Avoid running this task over and over again in the queue
            # if the previous instance did not yet terminate. Happens
            # when scheduled task runs too quickly.
            LOGGER.warning(u'global_duplicates_checker() is already '
                           u'locked, aborting.')
            return

    if limit is None:
        limit = config.CHECK_DUPLICATES_LIMIT

    start_time = pytime.time()
    duplicates = BaseItem.objects.duplicate()

    total_dupes_count  = duplicates.count()
    total_reads_count  = 0
    processed_dupes    = 0
    done_dupes_count   = 0
    purged_dupes_count = 0

    purge_after_weeks_count = max(1, config.CHECK_DUPLICATES_PURGE_AFTER_WEEKS)
    purge_after_weeks_count = min(52, purge_after_weeks_count)

    purge_before_date = now() - timedelta(days=purge_after_weeks_count * 7)

    LOGGER.info(u'Done counting (took %s of pure SQL joy), starting procedure.',
                naturaldelta(pytime.time() - start_time))

    with benchmark(u"Check {0}/{1} duplicates".format(limit or u'all',
                   total_dupes_count)):

        try:
            for duplicate in duplicates.iterator():
                reads = duplicate.reads.all()

                processed_dupes += 1

                if reads.exists():
                    done_dupes_count  += 1
                    reads_count        = reads.count()
                    total_reads_count += reads_count

                    LOGGER.info(u'Duplicate %s #%s still has %s reads, fixing…',
                                duplicate._meta.model.__name__,
                                duplicate.id, reads_count)

                    duplicate.duplicate_of.register_duplicate(
                        duplicate, force=duplicate.duplicate_status
                        == DUPLICATE_STATUS.FINISHED)

                if duplicate.duplicate_status == DUPLICATE_STATUS.FINISHED:
                    #
                    # TODO: check we didn't get some race-conditions new
                    #       dependencies between the moment the duplicate
                    #       was marked duplicate and now.

                    if duplicate.date_created < purge_before_date:
                        try:
                            with transaction.atomic():
                                duplicate.delete()
                        except Exception:
                            LOGGER.exception(u'Exception while deleting '
                                             u'duplicate %s #%s',
                                             duplicate._meta.model.__name__,
                                             duplicate.id)
                            continue

                        purged_dupes_count += 1
                        LOGGER.info(u'Purged duplicate %s #%s from database.',
                                    duplicate._meta.model.__name__,
                                    duplicate.id)

                elif duplicate.duplicate_status in (
                    DUPLICATE_STATUS.NOT_REPLACED,
                        DUPLICATE_STATUS.FAILED):
                    # Something went wrong, perhaps the
                    # task was purged before being run.
                    duplicate.duplicate_of.register_duplicate(duplicate)
                    done_dupes_count += 1

                elif duplicate.duplicate_status is None:
                    # Something went very wrong. If the article is a known
                    # duplicate, its status field should have been set to
                    # at least NOT_REPLACED.
                    duplicate.duplicate_of.register_duplicate(duplicate)
                    done_dupes_count += 1

                    LOGGER.error(u'Corrected duplicate %s #%s found with no '
                                 u'status.', duplicate._meta.model.__name__,
                                 duplicate.id)

                if limit and processed_dupes >= limit:
                    break

        finally:
            my_lock.release()

    LOGGER.info(u'global_duplicates_checker(): %s/%s duplicates processed '
                u'(%.2f%%; limit: %s), %s corrected (%.2f%%), '
                u'%s purged (%.2f%%); %s reads altered.',

                processed_dupes, total_dupes_count,
                processed_dupes * 100.0 / total_dupes_count,

                limit or u'none',

                done_dupes_count,
                (done_dupes_count * 100.0 / processed_dupes)
                if processed_dupes else 0.0,

                purged_dupes_count,
                (purged_dupes_count * 100.0 / processed_dupes)
                if processed_dupes else 0.0,

                total_reads_count)

Example 79

Project: crowdsource-platform
Source File: task.py
View license
    def create(self, **kwargs):
        project = kwargs['project']
        skipped = False
        task_worker = {}
        with self.lock:
            with transaction.atomic():  # select_for_update(nowait=False)
                # noinspection SqlResolve
                query = '''
                    SELECT
                      t.id,
                      p.id

                    FROM crowdsourcing_task t INNER JOIN (SELECT
                                                            group_id,
                                                            max(id) id
                                                          FROM crowdsourcing_task
                                                          WHERE deleted_at IS NULL
                                                          GROUP BY group_id) t_max ON t_max.id = t.id
                      INNER JOIN crowdsourcing_project p ON p.id = t.project_id
                      INNER JOIN (
                                   SELECT
                                     t.group_id,
                                     sum(t.own)    own,
                                     sum(t.others) others
                                   FROM (
                                          SELECT
                                            t.group_id,
                                            CASE WHEN tw.worker_id = (%(worker_id)s)
                                              THEN 1
                                            ELSE 0 END own,
                                            CASE WHEN (tw.worker_id IS NOT NULL AND tw.worker_id <> (%(worker_id)s))
                                             AND tw.status NOT IN (4, 6, 7)
                                              THEN 1
                                            ELSE 0 END others
                                          FROM crowdsourcing_task t
                                            LEFT OUTER JOIN crowdsourcing_taskworker tw ON (t.id =
                                                                                            tw.task_id)
                                          WHERE exclude_at IS NULL AND t.deleted_at IS NULL) t
                                   GROUP BY t.group_id) t_count ON t_count.group_id = t.group_id
                    WHERE t_count.own = 0 AND t_count.others < p.repetition AND p.id=(%(project_id)s)
                    AND p.status = 3 LIMIT 1
                    '''

                tasks = models.Task.objects.raw(query, params={'project_id': project,
                                                               'worker_id': kwargs['worker'].id})

                if not len(list(tasks)):
                    # noinspection SqlResolve
                    tasks = models.Task.objects.raw(
                        '''
                            SELECT
                                t.id,
                                t.group_id,
                                p.id project_id
                            FROM crowdsourcing_task t INNER JOIN (SELECT
                                                                    group_id,
                                                                    max(id) id
                                                                  FROM crowdsourcing_task
                                                                  WHERE deleted_at IS NULL
                                                                  GROUP BY group_id) t_max ON t_max.id = t.id
                              INNER JOIN crowdsourcing_project p ON p.id = t.project_id
                              INNER JOIN (
                                           SELECT
                                             t.group_id,
                                             sum(t.own)    own,
                                             sum(t.others) others
                                           FROM (
                                                  SELECT
                                                    t.group_id,
                                                    CASE WHEN tw.worker_id = (%(worker_id)s) AND tw.status <> 6
                                                      THEN 1
                                                    ELSE 0 END own,
                                                    CASE WHEN (tw.worker_id IS NOT NULL
                                                    AND tw.worker_id <> (%(worker_id)s))
                                                     AND tw.status NOT IN (4, 6, 7)
                                                      THEN 1
                                                    ELSE 0 END others
                                                  FROM crowdsourcing_task t
                                                    LEFT OUTER JOIN crowdsourcing_taskworker tw ON (t.id =
                                                                                                    tw.task_id)
                                                  WHERE exclude_at IS NULL AND t.deleted_at IS NULL) t
                                           GROUP BY t.group_id) t_count ON t_count.group_id = t.group_id
                            WHERE t_count.own = 0 AND t_count.others < p.repetition AND p.id=(%(project_id)s)
                            AND p.status = 3 LIMIT 1
                        ''', params={'project_id': project, 'worker_id': kwargs['worker'].id})
                    skipped = True
                if len(list(tasks)) and not skipped:
                    task_worker = models.TaskWorker.objects.create(worker=kwargs['worker'], task=tasks[0])
                elif len(list(tasks)) and skipped:
                    task_worker = models.TaskWorker.objects.get(worker=kwargs['worker'],
                                                                task__group_id=tasks[0].group_id)
                    task_worker.status = models.TaskWorker.STATUS_IN_PROGRESS
                    task_worker.task_id = tasks[0].id
                    task_worker.save()
                else:
                    return {}, 204
                return task_worker, 200
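
Example 79 serializes task assignment with an in-process lock plus a database transaction around the raw claim queries. A simplified ORM equivalent using select_for_update; Task and TaskWorker stand in for the real models, and the availability filter is reduced to its essence:

import threading

from django.db import transaction

assignment_lock = threading.Lock()

def claim_task(worker, project_id):
    with assignment_lock:
        with transaction.atomic():
            taken = TaskWorker.objects.values('task_id')
            # The row lock stops two workers from claiming the same task.
            task = (Task.objects.select_for_update()
                    .filter(project_id=project_id)
                    .exclude(id__in=taken)
                    .first())
            if task is None:
                return None
            return TaskWorker.objects.create(worker=worker, task=task)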

Example 80

Project: crowdsource-platform
Source File: viewsets.py
View license
    @detail_route(methods=['post'], permission_classes=[IsValidHITAssignment], url_path='submit-results')
    def submit_results(self, request, *args, **kwargs):
        mturk_assignment = self.get_object()
        template_items = request.data.get('items', [])

        with transaction.atomic():
            task_worker_results = TaskWorkerResult.objects.filter(task_worker_id=mturk_assignment.task_worker.id)
            serializer = TaskWorkerResultSerializer(data=template_items, many=True)

            if serializer.is_valid():
                if task_worker_results.count() != 0:
                    serializer.update(task_worker_results, serializer.validated_data)
                else:
                    serializer.create(task_worker=mturk_assignment.task_worker)

                if mturk_assignment.status == TaskWorker.STATUS_SKIPPED:
                    in_progress_assignment = MTurkAssignment.objects. \
                        filter(hit=mturk_assignment.hit, assignment_id=mturk_assignment.assignment_id,
                               status=TaskWorker.STATUS_IN_PROGRESS).first()

                    if in_progress_assignment is not None and in_progress_assignment.task_worker is not None:
                        in_progress_assignment.status = TaskWorker.STATUS_SKIPPED
                        in_progress_assignment.task_worker.status = TaskWorker.STATUS_SKIPPED
                        in_progress_assignment.task_worker.save()

                        in_progress_assignment.save()

                mturk_assignment.task_worker.task_status = TaskWorker.STATUS_SUBMITTED
                mturk_assignment.task_worker.status = TaskWorker.STATUS_SUBMITTED
                mturk_assignment.task_worker.save()

                mturk_assignment.status = TaskWorker.STATUS_SUBMITTED
                mturk_assignment.save()

                task_worker = mturk_assignment.task_worker

                task_data = task_worker.task.data

                redis_publisher = RedisPublisher(facility='bot',
                                                 users=[task_worker.task.project.owner])
                task = task_worker.task
                message = {
                    "type": "REGULAR",
                    "payload": {
                        'project_id': task_worker.task.project_id,
                        'project_key': ProjectSerializer().get_hash_id(task_worker.task.project),
                        'task_id': task_worker.task_id,
                        'task_group_id': task_worker.task.group_id,
                        'taskworker_id': task_worker.id,
                        'worker_id': task_worker.worker_id
                    }
                }
                if task.project.is_review:
                    match_group = MatchGroup.objects.get(batch=task.batch)
                    if is_final_review(task.batch_id):
                        message = {
                            "type": "REVIEW",
                            "payload": {
                                "match_group_id": match_group.id,
                                'project_key': ProjectSerializer().get_hash_id(task_worker.task.project),
                                "is_done": True
                            }
                        }
                message = RedisMessage(json.dumps(message))

                redis_publisher.publish_message(message)
                update_worker_cache.delay([task_worker.worker_id], constants.TASK_SUBMITTED)

                if task.project.is_review:
                    winner_id = task_worker_results[0].result
                    update_ts_scores(task_worker, winner_id=winner_id)

                if "gold_truth" in task_data:
                    truth = dict()
                    truth["message"] = "truth"
                    truth["truth"] = task_data.get("gold_truth")
                    return Response(data=truth, status=status.HTTP_200_OK)

                return Response(data={'message': 'Success'}, status=status.HTTP_200_OK)
            else:
                return Response(serializer.errors, status.HTTP_400_BAD_REQUEST)
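
All of the writes in submit_results happen inside one atomic() block: the results, the worker's status, the skipped-assignment fixup, and the assignment itself commit together or not at all. The core status flip, stripped of the messaging code:

from django.db import transaction

def mark_submitted(assignment, submitted_status):
    # Either both rows reflect the submission or neither does.
    with transaction.atomic():
        assignment.task_worker.status = submitted_status
        assignment.task_worker.save()
        assignment.status = submitted_status
        assignment.save()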

Example 81

Project: netbox
Source File: run_inventory.py
View license
    def handle(self, *args, **options):

        def create_modules(modules, parent=None):
            for module in modules:
                m = Module(device=device, parent=parent, name=module['name'], part_id=module['part_id'],
                           serial=module['serial'], discovered=True)
                m.save()
                create_modules(module.get('modules', []), parent=m)

        # Credentials
        if options['username']:
            self.username = options['username']
        if options['password']:
            self.password = getpass("Password: ")

        # Attempt to inventory only active devices
        device_list = Device.objects.filter(status=True)

        # --site: Include only devices belonging to specified site(s)
        if options['site']:
            sites = Site.objects.filter(slug__in=options['site'])
            if sites:
                site_names = [s.name for s in sites]
                self.stdout.write("Running inventory for these sites: {}".format(', '.join(site_names)))
            else:
                raise CommandError("One or more sites specified but none found.")
            device_list = device_list.filter(rack__site__in=sites)

        # --name: Filter devices by name matching a regex
        if options['name']:
            device_list = device_list.filter(name__iregex=options['name'])

        # --full: Gather inventory data for *all* devices
        if options['full']:
            self.stdout.write("WARNING: Running inventory for all devices! Prior data will be overwritten. (--full)")

        # --fake: Gather data but do not update the database
        if options['fake']:
            self.stdout.write("WARNING: Inventory data will not be saved! (--fake)")

        device_count = device_list.count()
        self.stdout.write("** Found {} devices...".format(device_count))

        for i, device in enumerate(device_list, start=1):

            self.stdout.write("[{}/{}] {}: ".format(i, device_count, device.name), ending='')

            # Skip inactive devices
            if not device.status:
                self.stdout.write("Skipped (inactive)")
                continue

            # Skip devices without primary_ip set
            if not device.primary_ip:
                self.stdout.write("Skipped (no primary IP set)")
                continue

            # Skip devices which have already been inventoried if not doing a full update
            if device.serial and not options['full']:
                self.stdout.write("Skipped (Serial: {})".format(device.serial))
                continue

            RPC = device.get_rpc_client()
            if not RPC:
                self.stdout.write("Skipped (no RPC client available for platform {})".format(device.platform))
                continue

            # Connect to device and retrieve inventory info
            try:
                with RPC(device, self.username, self.password) as rpc_client:
                    inventory = rpc_client.get_inventory()
            except KeyboardInterrupt:
                raise
            except (AuthenticationError, AuthenticationException):
                self.stdout.write("Authentication error!")
                continue
            except Exception as e:
                self.stdout.write("Error: {}".format(e))
                continue

            if options['verbosity'] > 1:
                self.stdout.write("")
                self.stdout.write("\tSerial: {}".format(inventory['chassis']['serial']))
                self.stdout.write("\tDescription: {}".format(inventory['chassis']['description']))
                for module in inventory['modules']:
                    self.stdout.write("\tModule: {} / {} ({})".format(module['name'], module['part_id'],
                                                                      module['serial']))
            else:
                self.stdout.write("{} ({})".format(inventory['chassis']['description'], inventory['chassis']['serial']))

            if not options['fake']:
                with transaction.atomic():
                    # Update device serial
                    if device.serial != inventory['chassis']['serial']:
                        device.serial = inventory['chassis']['serial']
                        device.save()
                    Module.objects.filter(device=device, discovered=True).delete()
                    create_modules(inventory.get('modules', []))

        self.stdout.write("Finished!")

Example 82

Project: commcare-hq
Source File: payment_handlers.py
View license
    def process_request(self, request):
        customer = None
        amount = self.get_charge_amount(request)
        card = request.POST.get('stripeToken')
        remove_card = request.POST.get('removeCard')
        is_saved_card = request.POST.get('selectedCardType') == 'saved'
        save_card = request.POST.get('saveCard') and not is_saved_card
        autopay = request.POST.get('autopayCard')
        billing_account = BillingAccount.get_account_by_domain(self.domain)
        generic_error = {
            'error': {
                'message': _(
                    "Something went wrong while processing your payment. "
                    "We're working quickly to resolve the issue. No charges "
                    "were issued. Please try again in a few hours."
                ),
            },
        }
        try:
            with transaction.atomic():
                if remove_card:
                    self.payment_method.remove_card(card)
                    return {'success': True, 'removedCard': card, }
                if save_card:
                    card = self.payment_method.create_card(card, billing_account, self.domain, autopay=autopay)
                if save_card or is_saved_card:
                    customer = self.payment_method.customer

                payment_record = PaymentRecord.create_record(
                    self.payment_method, 'temp', amount
                )
                self.update_credits(payment_record)

                charge = self.create_charge(amount, card=card, customer=customer)

            payment_record.transaction_id = charge.id
            payment_record.save()
            self.update_payment_information(billing_account)
        except stripe.error.CardError as e:
            # card was declined
            return e.json_body
        except (
            stripe.error.AuthenticationError,
            stripe.error.InvalidRequestError,
            stripe.error.APIConnectionError,
            stripe.error.StripeError,
        ) as e:
            log_accounting_error(
                "A payment for %(cost_item)s failed due "
                "to a Stripe %(error_class)s: %(error_msg)s" % {
                    'error_class': e.__class__.__name__,
                    'cost_item': self.cost_item_name,
                    'error_msg': e.json_body['error']
                }
            )
            return generic_error
        except Exception as e:
            log_accounting_error(
                "A payment for %(cost_item)s failed due to: %(error_msg)s" % {
                    'cost_item': self.cost_item_name,
                    'error_msg': e,
                }
            )
            return generic_error

        try:
            self.send_email(payment_record)
        except Exception:
            log_accounting_error(
                "Failed to send out an email receipt for "
                "payment related to PaymentRecord No. %s. "
                "Everything else succeeded."
                % payment_record.id
            )

        return {
            'success': True,
            'card': card,
            'wasSaved': save_card,
            'changedBalance': amount,
        }
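
One caveat worth flagging here: the atomic block above includes a network call to Stripe, and a charge made there is not undone if the database transaction later rolls back. On Django 1.9+, transaction.on_commit() is one way to defer side effects until the data is safely committed; a rough sketch, where send_receipt is a hypothetical callable:

from django.db import transaction

def record_payment(payment_record, send_receipt):
    with transaction.atomic():
        payment_record.save()
        # Runs only after a successful commit; never on rollback.
        transaction.on_commit(lambda: send_receipt(payment_record.pk))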

Example 83

Project: django-cms
Source File: placeholderadmin.py
View license
    @method_decorator(require_POST)
    @xframe_options_sameorigin
    @transaction.atomic
    def copy_plugins(self, request):
        """
        POST request should have the following data:

        - source_language
        - source_placeholder_id
        - source_plugin_id (optional)
        - target_language
        - target_placeholder_id
        - target_plugin_id (optional, new parent)
        """
        source_language = request.POST['source_language']
        source_placeholder_id = request.POST['source_placeholder_id']
        source_plugin_id = request.POST.get('source_plugin_id', None)
        target_language = request.POST['target_language']
        target_placeholder_id = request.POST['target_placeholder_id']
        target_plugin_id = request.POST.get('target_plugin_id', None)
        source_placeholder = get_object_or_404(Placeholder, pk=source_placeholder_id)
        target_placeholder = get_object_or_404(Placeholder, pk=target_placeholder_id)

        if not target_language or target_language not in get_language_list():
            return HttpResponseBadRequest(force_text(_("Language must be set to a supported language!")))

        copy_to_clipboard = target_placeholder.pk == request.toolbar.clipboard.pk

        if source_plugin_id:
            source_plugin = get_object_or_404(CMSPlugin, pk=source_plugin_id)
            reload_required = requires_reload(PLUGIN_COPY_ACTION, [source_plugin])
            if source_plugin.plugin_type == "PlaceholderPlugin":
                # if it is a PlaceholderReference plugin, only copy the plugins it references
                inst, cls = source_plugin.get_plugin_instance(self)
                plugins = inst.placeholder_ref.get_plugins_list()
            else:
                plugins = list(
                    source_placeholder.get_plugins().filter(
                        path__startswith=source_plugin.path,
                        depth__gte=source_plugin.depth).order_by('path')
                )
        else:
            plugins = list(
                source_placeholder.get_plugins(language=source_language).order_by('path'))
            reload_required = requires_reload(PLUGIN_COPY_ACTION, plugins)

        if copy_to_clipboard:
            has_permissions = self.has_copy_plugins_permission(request, plugins)
        else:
            # Plugins are being copied from a placeholder in another language
            # using the "Copy from language" placeholder action.
            # Check if the user can copy plugins from source placeholder to
            # target placeholder.
            has_permissions = self.has_copy_from_placeholder_permission(
                request,
                source_placeholder,
                target_placeholder,
                plugins,
            )

        if not has_permissions:
            return HttpResponseForbidden(force_text(
                _('You do not have permission to copy these plugins.')))

        if copy_to_clipboard and not source_plugin_id and not target_plugin_id:
            # if we copy a whole placeholder to the clipboard, create a
            # PlaceholderReference plugin instead and fill it with the content
            # of the source_placeholder.
            ref = PlaceholderReference()
            ref.name = source_placeholder.get_label()
            ref.plugin_type = "PlaceholderPlugin"
            ref.language = target_language
            ref.placeholder = target_placeholder
            ref.save()
            ref.copy_from(source_placeholder, source_language)
        else:
            copy_plugins.copy_plugins_to(
                plugins, target_placeholder, target_language, target_plugin_id)

        plugin_list = CMSPlugin.objects.filter(
                language=target_language,
                placeholder=target_placeholder
            ).order_by('path')
        reduced_list = []

        for plugin in plugin_list:
            reduced_list.append(
                {
                    'id': plugin.pk, 'type': plugin.plugin_type, 'parent': plugin.parent_id,
                    'position': plugin.position, 'desc': force_text(plugin.get_short_description()),
                    'language': plugin.language, 'placeholder_id': plugin.placeholder_id
                }
            )

        self.post_copy_plugins(request, source_placeholder, target_placeholder, plugins)

        # When this is executed we are in the admin class of the source placeholder
        # It can be a page or a model with a placeholder field.
        # Because of this we need to get the admin class instance of the
        # target placeholder and call post_copy_plugins() on it.
        # By doing this we make sure that both the source and target are
        # informed of the operation.
        target_placeholder_admin = self._get_attached_admin(target_placeholder)

        if (target_placeholder_admin and
                target_placeholder_admin.model != self.model):
            target_placeholder_admin.post_copy_plugins(
                request,
                source_placeholder=source_placeholder,
                target_placeholder=target_placeholder,
                plugins=plugins,
            )

        json_response = {'plugin_list': reduced_list, 'reload': reload_required}
        return HttpResponse(json.dumps(json_response), content_type='application/json')
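
Decorators apply bottom-up, so transaction.atomic is innermost here: the transaction only opens for requests that have already passed require_POST and the X-Frame-Options wrapper. The same stacking on a plain function view, with an illustrative body:

from django.db import transaction
from django.http import JsonResponse
from django.views.decorators.http import require_POST

@require_POST
@transaction.atomic  # innermost: opens only once require_POST lets the request through
def move_plugin(request):
    # ... every ORM write in here shares a single transaction ...
    return JsonResponse({'ok': True})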

Example 84

Project: django-cms
Source File: api.py
View license
@transaction.atomic
def create_page(title, template, language, menu_title=None, slug=None,
                apphook=None, apphook_namespace=None, redirect=None, meta_description=None,
                created_by='python-api', parent=None,
                publication_date=None, publication_end_date=None,
                in_navigation=False, soft_root=False, reverse_id=None,
                navigation_extenders=None, published=False, site=None,
                login_required=False, limit_visibility_in_menu=constants.VISIBILITY_ALL,
                position="last-child", overwrite_url=None,
                xframe_options=Page.X_FRAME_OPTIONS_INHERIT, with_revision=None):
    """
    Create a CMS Page and its title for the given language

    See docs/extending_cms/api_reference.rst for more info
    """
    if with_revision in (True, False):
        _raise_revision_warning()

    # validate template
    if template != TEMPLATE_INHERITANCE_MAGIC:
        assert template in [tpl[0] for tpl in get_cms_setting('TEMPLATES')]
        get_template(template)

    # validate site
    if not site:
        site = Site.objects.get_current()
    else:
        assert isinstance(site, Site)

    # validate language:
    assert language in get_language_list(site), get_cms_setting('LANGUAGES').get(site.pk)

    # set default slug:
    if not slug:
        slug = generate_valid_slug(title, parent, language)

    # validate parent
    if parent:
        assert isinstance(parent, Page)
        parent = Page.objects.get(pk=parent.pk)

    # validate publication date
    if publication_date:
        assert isinstance(publication_date, datetime.date)

    # validate publication end date
    if publication_end_date:
        assert isinstance(publication_end_date, datetime.date)

    if navigation_extenders:
        raw_menus = menu_pool.get_menus_by_attribute("cms_enabled", True)
        menus = [menu[0] for menu in raw_menus]
        assert navigation_extenders in menus

    # validate menu visibility
    accepted_limitations = (constants.VISIBILITY_ALL, constants.VISIBILITY_USERS, constants.VISIBILITY_ANONYMOUS)
    assert limit_visibility_in_menu in accepted_limitations

    # validate position
    assert position in ('last-child', 'first-child', 'left', 'right')
    if parent:
        if position in ('last-child', 'first-child'):
            parent_id = parent.pk
        else:
            parent_id = parent.parent_id
    else:
        parent_id = None
    # validate and normalize apphook
    if apphook:
        application_urls = _verify_apphook(apphook, apphook_namespace)
    else:
        application_urls = None

    # ugly permissions hack
    if created_by and isinstance(created_by, get_user_model()):
        _thread_locals.user = created_by
        created_by = getattr(created_by, get_user_model().USERNAME_FIELD)
    else:
        _thread_locals.user = None

    if reverse_id:
        if Page.objects.drafts().filter(reverse_id=reverse_id, site=site).exists():
            raise FieldError('A page with the reverse_id="%s" already exists.' % reverse_id)

    page = Page(
        created_by=created_by,
        changed_by=created_by,
        parent_id=parent_id,
        publication_date=publication_date,
        publication_end_date=publication_end_date,
        in_navigation=in_navigation,
        soft_root=soft_root,
        reverse_id=reverse_id,
        navigation_extenders=navigation_extenders,
        template=template,
        application_urls=application_urls,
        application_namespace=apphook_namespace,
        site=site,
        login_required=login_required,
        limit_visibility_in_menu=limit_visibility_in_menu,
        xframe_options=xframe_options,
    )

    # This saves the page
    page = page.add_root(instance=page)

    if parent:
        page = page.move(target=parent, pos=position)

    create_title(
        language=language,
        title=title,
        menu_title=menu_title,
        slug=slug,
        redirect=redirect,
        meta_description=meta_description,
        page=page,
        overwrite_url=overwrite_url,
    )

    if published:
        page.publish(language)

    del _thread_locals.user

    page = page.reload()

    # Avoid an extra query when accessing the site
    # for the newly created page.
    page._site_cache = site
    return page
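
Because the whole function is wrapped in @transaction.atomic, a failed assertion or the FieldError above rolls everything back, so no half-created Page is left behind. A typical call might look like this; the template name is illustrative and must match one of your CMS_TEMPLATES entries:

from cms.api import create_page

page = create_page(
    title='About us',
    template='fullwidth.html',
    language='en',
    slug='about-us',
    in_navigation=True,
)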

Example 85

Project: django-cms
Source File: api.py
View license
@transaction.atomic
def add_plugin(placeholder, plugin_type, language, position='last-child',
               target=None, **data):
    """
    Add a plugin to a placeholder

    See docs/extending_cms/api_reference.rst for more info
    """
    # validate placeholder
    assert isinstance(placeholder, Placeholder)

    # validate and normalize plugin type
    plugin_model, plugin_type = _verify_plugin_type(plugin_type)
    if target:
        if position == 'last-child':
            if CMSPlugin.node_order_by:
                position = 'sorted-child'
            new_pos = CMSPlugin.objects.filter(parent=target).count()
            parent_id = target.pk
        elif position == 'first-child':
            new_pos = 0
            if CMSPlugin.node_order_by:
                position = 'sorted-child'
            parent_id = target.pk
        elif position == 'left':
            new_pos = target.position
            if CMSPlugin.node_order_by:
                position = 'sorted-sibling'
            parent_id = target.parent_id
        elif position == 'right':
            new_pos = target.position + 1
            if CMSPlugin.node_order_by:
                position = 'sorted-sibling'
            parent_id = target.parent_id
        else:
            raise Exception('position not supported: %s' % position)
        if position == 'last-child' or position == 'first-child':
            qs = CMSPlugin.objects.filter(language=language, parent=target, position__gte=new_pos,
                                          placeholder=placeholder)
        else:
            qs = CMSPlugin.objects.filter(language=language, parent=target.parent_id, position__gte=new_pos,
                                          placeholder=placeholder)
        for pl in qs:
            pl.position += 1
            pl.save()
    else:
        if position == 'last-child':
            new_pos = CMSPlugin.objects.filter(language=language, parent__isnull=True, placeholder=placeholder).count()
        else:
            new_pos = 0
            for pl in CMSPlugin.objects.filter(language=language, parent__isnull=True, position__gte=new_pos,
                                               placeholder=placeholder):
                pl.position += 1
                pl.save()
        parent_id = None
    plugin_base = CMSPlugin(
        plugin_type=plugin_type,
        placeholder=placeholder,
        position=new_pos,
        language=language,
        parent_id=parent_id,
    )

    plugin_base = plugin_base.add_root(instance=plugin_base)

    if target:
        plugin_base = plugin_base.move(target, pos=position)
    plugin = plugin_model(**data)
    plugin_base.set_base_attr(plugin)
    plugin.save()
    return plugin
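
add_plugin follows the same pattern: the sibling-position reshuffling and the final save share one transaction, so an error cannot leave the position sequence with a gap. A short usage sketch; the 'content' slot and the TextPlugin body field are assumptions about the project setup:

from cms.api import add_plugin

placeholder = page.placeholders.get(slot='content')  # `page` as in the previous sketch
plugin = add_plugin(placeholder, 'TextPlugin', 'en', body='<p>Hello</p>')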

Example 86

Project: django
Source File: fields.py
View license
def create_generic_related_manager(superclass, rel):
    """
    Factory function to create a manager that subclasses another manager
    (generally the default manager of a given model) and adds behaviors
    specific to generic relations.
    """

    class GenericRelatedObjectManager(superclass):
        def __init__(self, instance=None):
            super(GenericRelatedObjectManager, self).__init__()

            self.instance = instance

            self.model = rel.model

            content_type = ContentType.objects.db_manager(instance._state.db).get_for_model(
                instance, for_concrete_model=rel.field.for_concrete_model)
            self.content_type = content_type
            self.content_type_field_name = rel.field.content_type_field_name
            self.object_id_field_name = rel.field.object_id_field_name
            self.prefetch_cache_name = rel.field.attname
            self.pk_val = instance._get_pk_val()

            self.core_filters = {
                '%s__pk' % self.content_type_field_name: content_type.id,
                self.object_id_field_name: self.pk_val,
            }

        def __call__(self, **kwargs):
            # We use **kwargs rather than a kwarg argument to enforce the
            # `manager='manager_name'` syntax.
            manager = getattr(self.model, kwargs.pop('manager'))
            manager_class = create_generic_related_manager(manager.__class__, rel)
            return manager_class(instance=self.instance)
        do_not_call_in_templates = True

        def __str__(self):
            return repr(self)

        def _apply_rel_filters(self, queryset):
            """
            Filter the queryset for the instance this manager is bound to.
            """
            db = self._db or router.db_for_read(self.model, instance=self.instance)
            return queryset.using(db).filter(**self.core_filters)

        def get_queryset(self):
            try:
                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
            except (AttributeError, KeyError):
                queryset = super(GenericRelatedObjectManager, self).get_queryset()
                return self._apply_rel_filters(queryset)

        def get_prefetch_queryset(self, instances, queryset=None):
            if queryset is None:
                queryset = super(GenericRelatedObjectManager, self).get_queryset()

            queryset._add_hints(instance=instances[0])
            queryset = queryset.using(queryset._db or self._db)

            query = {
                '%s__pk' % self.content_type_field_name: self.content_type.id,
                '%s__in' % self.object_id_field_name: set(obj._get_pk_val() for obj in instances)
            }

            # We (possibly) need to convert object IDs to the type of the
            # instances' PK in order to match up instances:
            object_id_converter = instances[0]._meta.pk.to_python
            return (queryset.filter(**query),
                    lambda relobj: object_id_converter(getattr(relobj, self.object_id_field_name)),
                    lambda obj: obj._get_pk_val(),
                    False,
                    self.prefetch_cache_name)

        def add(self, *objs, **kwargs):
            bulk = kwargs.pop('bulk', True)
            db = router.db_for_write(self.model, instance=self.instance)

            def check_and_update_obj(obj):
                if not isinstance(obj, self.model):
                    raise TypeError("'%s' instance expected, got %r" % (
                        self.model._meta.object_name, obj
                    ))
                setattr(obj, self.content_type_field_name, self.content_type)
                setattr(obj, self.object_id_field_name, self.pk_val)

            if bulk:
                pks = []
                for obj in objs:
                    if obj._state.adding or obj._state.db != db:
                        raise ValueError(
                            "%r instance isn't saved. Use bulk=False or save "
                            "the object first." % obj
                        )
                    check_and_update_obj(obj)
                    pks.append(obj.pk)

                self.model._base_manager.using(db).filter(pk__in=pks).update(**{
                    self.content_type_field_name: self.content_type,
                    self.object_id_field_name: self.pk_val,
                })
            else:
                with transaction.atomic(using=db, savepoint=False):
                    for obj in objs:
                        check_and_update_obj(obj)
                        obj.save()
        add.alters_data = True

        def remove(self, *objs, **kwargs):
            if not objs:
                return
            bulk = kwargs.pop('bulk', True)
            self._clear(self.filter(pk__in=[o.pk for o in objs]), bulk)
        remove.alters_data = True

        def clear(self, **kwargs):
            bulk = kwargs.pop('bulk', True)
            self._clear(self, bulk)
        clear.alters_data = True

        def _clear(self, queryset, bulk):
            db = router.db_for_write(self.model, instance=self.instance)
            queryset = queryset.using(db)
            if bulk:
                # `QuerySet.delete()` creates its own atomic block which
                # contains the `pre_delete` and `post_delete` signal handlers.
                queryset.delete()
            else:
                with transaction.atomic(using=db, savepoint=False):
                    for obj in queryset:
                        obj.delete()
        _clear.alters_data = True

        def set(self, objs, **kwargs):
            # Force evaluation of `objs` in case it's a queryset whose value
            # could be affected by `manager.clear()`. Refs #19816.
            objs = tuple(objs)

            bulk = kwargs.pop('bulk', True)
            clear = kwargs.pop('clear', False)

            db = router.db_for_write(self.model, instance=self.instance)
            with transaction.atomic(using=db, savepoint=False):
                if clear:
                    self.clear()
                    self.add(*objs, bulk=bulk)
                else:
                    old_objs = set(self.using(db).all())
                    new_objs = []
                    for obj in objs:
                        if obj in old_objs:
                            old_objs.remove(obj)
                        else:
                            new_objs.append(obj)

                    self.remove(*old_objs)
                    self.add(*new_objs, bulk=bulk)
        set.alters_data = True

        def create(self, **kwargs):
            kwargs[self.content_type_field_name] = self.content_type
            kwargs[self.object_id_field_name] = self.pk_val
            db = router.db_for_write(self.model, instance=self.instance)
            return super(GenericRelatedObjectManager, self).using(db).create(**kwargs)
        create.alters_data = True

        def get_or_create(self, **kwargs):
            kwargs[self.content_type_field_name] = self.content_type
            kwargs[self.object_id_field_name] = self.pk_val
            db = router.db_for_write(self.model, instance=self.instance)
            return super(GenericRelatedObjectManager, self).using(db).get_or_create(**kwargs)
        get_or_create.alters_data = True

        def update_or_create(self, **kwargs):
            kwargs[self.content_type_field_name] = self.content_type
            kwargs[self.object_id_field_name] = self.pk_val
            db = router.db_for_write(self.model, instance=self.instance)
            return super(GenericRelatedObjectManager, self).using(db).update_or_create(**kwargs)
        update_or_create.alters_data = True

    return GenericRelatedObjectManager
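
The savepoint=False argument used throughout this manager is worth noting: when these methods run inside an outer atomic block, no savepoint is created, so a failure dooms the whole outer transaction rather than just the inner block. The behavior in isolation:

from django.db import transaction

def save_all(objs, using='default'):
    with transaction.atomic(using=using):                    # outer transaction
        with transaction.atomic(using=using, savepoint=False):
            # No savepoint here: if any save() fails, there is
            # nothing to roll back to short of the whole block.
            for obj in objs:
                obj.save()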

Example 87

Project: django
Source File: layermapping.py
View license
    def __init__(self, model, data, mapping, layer=0,
                 source_srs=None, encoding='utf-8',
                 transaction_mode='commit_on_success',
                 transform=True, unique=None, using=None):
        """
        A LayerMapping object is initialized using the given Model (not an instance),
        a DataSource (or string path to an OGR-supported data file), and a mapping
        dictionary.  See the module level docstring for more details and keyword
        argument usage.
        """
        # Getting the DataSource and the associated Layer.
        if isinstance(data, six.string_types):
            self.ds = DataSource(data, encoding=encoding)
        else:
            self.ds = data
        self.layer = self.ds[layer]

        self.using = using if using is not None else router.db_for_write(model)
        self.spatial_backend = connections[self.using].ops

        # Setting the mapping & model attributes.
        self.mapping = mapping
        self.model = model

        # Checking the layer -- initialization of the object will fail if
        # things don't check out before hand.
        self.check_layer()

        # Getting the geometry column associated with the model (an
        # exception will be raised if there is no geometry column).
        if connections[self.using].features.supports_transform:
            self.geo_field = self.geometry_field()
        else:
            transform = False

        # Checking the source spatial reference system, and getting
        # the coordinate transformation object (unless the `transform`
        # keyword is set to False)
        if transform:
            self.source_srs = self.check_srs(source_srs)
            self.transform = self.coord_transform()
        else:
            self.transform = transform

        # Setting the encoding for OFTString fields, if specified.
        if encoding:
            # Making sure the encoding exists, if not a LookupError
            # exception will be thrown.
            from codecs import lookup
            lookup(encoding)
            self.encoding = encoding
        else:
            self.encoding = None

        if unique:
            self.check_unique(unique)
            transaction_mode = 'autocommit'  # Has to be set to autocommit.
            self.unique = unique
        else:
            self.unique = None

        # Setting the transaction decorator with the function in the
        # transaction modes dictionary.
        self.transaction_mode = transaction_mode
        if transaction_mode == 'autocommit':
            self.transaction_decorator = None
        elif transaction_mode == 'commit_on_success':
            self.transaction_decorator = transaction.atomic
        else:
            raise LayerMapError('Unrecognized transaction mode: %s' % transaction_mode)
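
Here transaction.atomic is stored undecorated and applied later, which lets LayerMapping choose the transaction behavior per instance. Roughly how such a stored decorator gets applied; the names below are illustrative, not LayerMapping's internals:

from django.db import transaction

class RowSaver(object):
    def __init__(self, commit_on_success=True):
        # Keep the decorator itself as data; apply it lazily, or not at all.
        self.transaction_decorator = (
            transaction.atomic if commit_on_success else None)

    def save(self, write_rows):
        # write_rows is a hypothetical zero-argument callable.
        if self.transaction_decorator is not None:
            write_rows = self.transaction_decorator(write_rows)
        write_rows()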

Example 88

Project: django-extensions
Source File: syncdata.py
View license
    @signalcommand
    @transaction.atomic
    def handle(self, *fixture_labels, **options):
        """ Main method of a Django command """
        from django.db.models import get_apps
        from django.core import serializers
        from django.conf import settings

        self.style = no_style()

        verbosity = int(options.get('verbosity', 1))
        show_traceback = options.get('traceback', False)

        # Keep a count of the installed objects and fixtures
        fixture_count = 0
        object_count = 0
        objects_per_fixture = []
        models = set()

        humanize = lambda dirname: dirname and "'%s'" % dirname or 'absolute path'

        # Get a cursor (even though we don't need one yet). This has
        # the side effect of initializing the test database (if
        # it isn't already initialized).
        cursor = connection.cursor()

        app_fixtures = [os.path.join(os.path.dirname(app.__file__), 'fixtures') for app in get_apps()]
        for fixture_label in fixture_labels:
            parts = fixture_label.split('.')
            if len(parts) == 1:
                fixture_name = fixture_label
                formats = serializers.get_public_serializer_formats()
            else:
                fixture_name, format = '.'.join(parts[:-1]), parts[-1]
                if format in serializers.get_public_serializer_formats():
                    formats = [format]
                else:
                    formats = []

            if formats:
                if verbosity > 1:
                    print("Loading '%s' fixtures..." % fixture_name)
            else:
                sys.stderr.write(self.style.ERROR("Problem installing fixture '%s': %s is not a known serialization format." % (fixture_name, format)))
                transaction.rollback()
                return

            if os.path.isabs(fixture_name):
                fixture_dirs = [fixture_name]
            else:
                fixture_dirs = app_fixtures + list(settings.FIXTURE_DIRS) + ['']

            for fixture_dir in fixture_dirs:
                if verbosity > 1:
                    print("Checking %s for fixtures..." % humanize(fixture_dir))

                label_found = False
                for format in formats:
                    if verbosity > 1:
                        print("Trying %s for %s fixture '%s'..." % (humanize(fixture_dir), format, fixture_name))
                    try:
                        full_path = os.path.join(fixture_dir, '.'.join([fixture_name, format]))
                        fixture = open(full_path, 'r')
                        if label_found:
                            fixture.close()
                            print(self.style.ERROR("Multiple fixtures named '%s' in %s. Aborting." % (fixture_name, humanize(fixture_dir))))
                            transaction.rollback()
                            return
                        else:
                            fixture_count += 1
                            objects_per_fixture.append(0)
                            if verbosity > 0:
                                print("Installing %s fixture '%s' from %s." % (format, fixture_name, humanize(fixture_dir)))
                            try:
                                objects_to_keep = {}
                                objects = serializers.deserialize(format, fixture)
                                for obj in objects:
                                    object_count += 1
                                    objects_per_fixture[-1] += 1

                                    class_ = obj.object.__class__
                                    if class_ not in objects_to_keep:
                                        objects_to_keep[class_] = set()
                                    objects_to_keep[class_].add(obj.object)

                                    models.add(class_)
                                    obj.save()

                                if options.get('remove'):
                                    self.remove_objects_not_in(objects_to_keep, verbosity)

                                label_found = True
                            except (SystemExit, KeyboardInterrupt):
                                raise
                            except Exception:
                                import traceback
                                fixture.close()
                                transaction.rollback()
                                if show_traceback:
                                    traceback.print_exc()
                                else:
                                    sys.stderr.write(self.style.ERROR("Problem installing fixture '%s': %s\n" % (full_path, traceback.format_exc())))
                                return
                            fixture.close()
                    except:
                        if verbosity > 1:
                            print("No %s fixture '%s' in %s." % (format, fixture_name, humanize(fixture_dir)))

        # If any of the fixtures we loaded contain 0 objects, assume that an
        # error was encountered during fixture loading.
        if 0 in objects_per_fixture:
            sys.stderr.write(
                self.style.ERROR("No fixture data found for '%s'. (File format may be invalid.)" % fixture_name))
            transaction.rollback()
            return

        # If we found even one object in a fixture, we need to reset the
        # database sequences.
        if object_count > 0:
            sequence_sql = connection.ops.sequence_reset_sql(self.style, models)
            if sequence_sql:
                if verbosity > 1:
                    print("Resetting sequences")
                for line in sequence_sql:
                    cursor.execute(line)

        transaction.commit()

        if object_count == 0:
            if verbosity > 1:
                print("No fixtures found.")
        else:
            if verbosity > 0:
                print("Installed %d object(s) from %d fixture(s)" % (object_count, fixture_count))

        # Close the DB connection. This is required as a workaround for an
        # edge case in MySQL: if the same connection is used to
        # create tables, load data, and query, the query can return
        # incorrect results. See Django #7572, MySQL #37735.
        connection.close()
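
Worth flagging: this handler is decorated with @transaction.atomic yet still calls transaction.rollback() and transaction.commit(), which belong to the pre-1.6 transaction API and raise TransactionManagementError inside an atomic block on modern Django. Inside atomic(), the supported way to abort without raising is transaction.set_rollback(); a sketch, with is_valid as a hypothetical check:

from django.db import transaction

@transaction.atomic
def load_fixtures(labels, is_valid):
    for label in labels:
        if not is_valid(label):
            # Mark the transaction for rollback without raising;
            # everything done so far is discarded when the block exits.
            transaction.set_rollback(True)
            return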

Example 89

View license
    def handle_noargs(self, **options):
        try:
            from django.apps import apps
        except ImportError:
            # Don't bother migrating old South tables; first migrate to Django 1.7.
            raise CommandError("This is a Django 1.7+ command only")

        Entry = get_entry_model()
        CategoryM2M = Entry.categories.through
        old_fk = CategoryM2M._meta.get_field('category')
        CurrentModel = old_fk.rel.to
        self.stdout.write("Current Entry.categories model: <{0}.{1}>".format(
            CurrentModel._meta.app_label, CurrentModel._meta.object_name
        ))

        old = options['from']
        new = options['to']
        if not old or not new:
            raise CommandError("Expected --from and --to options")

        if old.lower() == 'categories.category' and 'categories' not in settings.INSTALLED_APPS:
            # Can't import it in a Django 1.8+ project.
            OldModel = DummyCategoryBase
        else:
            try:
                OldModel = apps.get_model(old)
            except LookupError as e:
                raise CommandError("Invalid --from value: {0}".format(e))

        if not issubclass(OldModel, MPTTModel):
            raise CommandError("Expected MPTT model for --from value")

        try:
            NewModel = apps.get_model(new)
        except LookupError as e:
            raise CommandError("Invalid --to value: {0}".format(e))

        if not issubclass(NewModel, MPTTModel):
            raise CommandError("Expected MPTT model for --to value")

        if NewModel.objects.all().exists():
            raise CommandError("New model already has records, it should be an empty table!")

        old_i18n = issubclass(OldModel, TranslatableModel)
        new_i18n = issubclass(NewModel, TranslatableModel)
        old_title = _detect_title_field(OldModel)
        new_title = _detect_title_field(NewModel)
        mptt_fields = "lft, rght, tree_id, level, parent_id"

        with transaction.atomic():
            if not old_i18n and not new_i18n:
                # Untranslated to untranslated
                self.stdout.write("* Copying category fields...")
                with connection.cursor() as cursor:
                    cursor.execute(
                        'INSERT INTO {new_model}(id, slug, {new_title}, {mptt_fields})'
                        ' SELECT id, slug, {old_title}, {mptt_fields} FROM {old_model}'.format(
                            new_model=NewModel._meta.db_table,
                            new_title=new_title,
                            old_model=OldModel._meta.db_table,
                            old_title=old_title,
                            mptt_fields=mptt_fields,
                        ))
            elif not old_i18n and new_i18n:
                # Untranslated to translated
                # - base table fields
                with connection.cursor() as cursor:
                    self.stdout.write("* Copying category base fields...")
                    cursor.execute(
                        'INSERT INTO {new_model}(id, {mptt_fields})'
                        ' SELECT id, {mptt_fields} FROM {old_model}'.format(
                            new_model=NewModel._meta.db_table,
                            old_model=OldModel._meta.db_table,
                            mptt_fields=mptt_fields,
                        ))
                    # - create translations on fallback language
                    self.stdout.write("* Creating category translations...")
                    cursor.execute(
                        'INSERT INTO {new_translations}(master_id, language_code, slug, {new_title})'
                        ' SELECT id, %s, slug, {old_title} FROM {old_model}'.format(
                            new_translations=NewModel._parler_meta.root_model._meta.db_table,
                            new_title=new_title,
                            old_model=OldModel._meta.db_table,
                            old_title=old_title,
                        ), [appsettings.FLUENT_BLOGS_DEFAULT_LANGUAGE_CODE])
            elif old_i18n and not new_i18n:
                # Reverse, translated to untranslated. Take fallback only
                # Convert all fields back to the single-language table.
                self.stdout.write("* Copying category fields and fallback language fields...")
                for old_category in OldModel.objects.all():
                    translations = old_category.translations.all()
                    try:
                        # Try default translation
                        old_translation = translations.get(language_code=appsettings.FLUENT_BLOGS_DEFAULT_LANGUAGE_CODE)
                    except ObjectDoesNotExist:
                        try:
                            # Try internal fallback
                            old_translation = translations.get(language_code__in=('en-us', 'en'))
                        except ObjectDoesNotExist:
                            # Hope there is a single translation
                            old_translation = translations.get()

                    fields = dict(
                        id=old_category.id,
                        lft=old_category.lft,
                        rght=old_category.rght,
                        tree_id=old_category.tree_id,
                        level=old_category.level,
                        # parler fields
                        _language_code=old_translation.language_code,
                        slug=old_category.slug
                    )
                    fields[new_title] = getattr(old_translation, old_title)
                    NewModel.objects.create(**fields)

            elif old_i18n and new_i18n:
                # Translated to translated
                # - base table
                with connection.cursor() as cursor:
                    self.stdout.write("* Copying category base fields...")
                    cursor.execute(
                        'INSERT INTO {new_model}(id, {mptt_fields})'
                        ' SELECT id, {mptt_fields} FROM {old_model}'.format(
                            new_model=NewModel._meta.db_table,
                            old_model=OldModel._meta.db_table,
                            mptt_fields=mptt_fields,
                        ))
                    # - all translations
                    self.stdout.write("* Copying category translations...")
                    cursor.execute(
                        'INSERT INTO {new_translations}(master_id, language_code, slug, {new_title})'
                        ' SELECT id, language_code, slug, {old_title} FROM {old_translations}'.format(
                            new_translations=NewModel._parler_meta.root_model._meta.db_table,
                            new_title=new_title,
                            old_translations=OldModel._parler_meta.root_model._meta.db_table,
                            old_title=old_title,
                        ))
            else:
                raise NotImplementedError()  # impossible combination

            self.stdout.write("* Switching M2M foreign key constraints...")
            __, __, __, kwargs = old_fk.deconstruct()
            kwargs['to'] = NewModel
            new_fk = models.ForeignKey(**kwargs)
            new_fk.set_attributes_from_name(old_fk.name)
            with connection.schema_editor() as schema_editor:
                schema_editor.alter_field(CategoryM2M, old_fk, new_fk)

        self.stdout.write("Done.\n")
        self.stdout.write("You may now remove the old category app from your project, INSTALLED_APPS and database.\n")

Example 90

Project: django-import-export
Source File: resources.py
View license
    def import_row(self, row, instance_loader, using_transactions=True, dry_run=False, **kwargs):
        """
        Imports data from ``tablib.Dataset``. Refer to :doc:`import_workflow`
        for a more complete description of the whole import process.

        :param row: A ``dict`` of the row to import

        :param instance_loader: The instance loader to be used to load the row

        :param using_transactions: If ``using_transactions`` is set, a transaction
            is used to wrap the import

        :param dry_run: If ``dry_run`` is set, or an error occurs, the transaction
            will be rolled back.
        """
        row_result = self.get_row_result_class()()
        row_result.import_type = RowResult.IMPORT_TYPE_ERROR
        try:
            self.before_import_row(row, **kwargs)
            instance, new = self.get_or_init_instance(instance_loader, row)
            self.after_import_instance(instance, new, **kwargs)
            if new:
                row_result.import_type = RowResult.IMPORT_TYPE_NEW
            else:
                row_result.import_type = RowResult.IMPORT_TYPE_UPDATE
            row_result.new_record = new
            original = deepcopy(instance)
            diff = Diff(self, original, new)
            if self.for_delete(row, instance):
                if new:
                    row_result.import_type = RowResult.IMPORT_TYPE_SKIP
                    diff.compare_with(self, None, dry_run)
                else:
                    row_result.import_type = RowResult.IMPORT_TYPE_DELETE
                    self.delete_instance(instance, using_transactions, dry_run)
                    diff.compare_with(self, None, dry_run)
            else:
                self.import_obj(instance, row, dry_run)
                if self.skip_row(instance, original):
                    row_result.import_type = RowResult.IMPORT_TYPE_SKIP
                else:
                    with transaction.atomic():
                        self.save_instance(instance, using_transactions, dry_run)
                    self.save_m2m(instance, row, using_transactions, dry_run)
                diff.compare_with(self, instance, dry_run)
            row_result.diff = diff.as_html()
            # Add object info to RowResult for LogEntry
            if row_result.import_type != RowResult.IMPORT_TYPE_SKIP:
                row_result.object_id = instance.pk
                row_result.object_repr = force_text(instance)
            self.after_import_row(row, row_result, **kwargs)
        except Exception as e:
            # There is no point logging a transaction error for each row
            # when only the original error is likely to be relevant
            if not isinstance(e, TransactionManagementError):
                logging.exception(e)
            tb_info = traceback.format_exc()
            row_result.errors.append(self.get_error_result_class()(e, tb_info, row))
        return row_result
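
Note that only save_instance() gets its own atomic() here; when using_transactions is set, the whole import is typically wrapped one level up, so this inner block acts as a savepoint and a failing row does not destroy the rows already imported. The nested pattern in isolation, with save_row as a hypothetical per-row callable:

from django.db import transaction

def import_rows(rows, save_row):
    with transaction.atomic():              # one transaction for the import
        for row in rows:
            try:
                with transaction.atomic():  # savepoint per row
                    save_row(row)
            except Exception:
                # Rolled back to the savepoint; rows already
                # imported in this run are preserved.
                continue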

Example 91

Project: django-nose
Source File: runner.py
View license
    def setup_databases(self):
        """Setup databases, skipping DB creation if requested and possible."""
        for alias in connections:
            connection = connections[alias]
            creation = connection.creation
            test_db_name = creation._get_test_db_name()

            # Mess with the DB name so other things operate on a test DB
            # rather than the real one. This is done in create_test_db when
            # we don't monkeypatch it away with _skip_create_test_db.
            orig_db_name = connection.settings_dict['NAME']
            connection.settings_dict['NAME'] = test_db_name

            if _should_create_database(connection):
                # We're not using _skip_create_test_db, so put the DB name
                # back:
                connection.settings_dict['NAME'] = orig_db_name

                # Since we replaced the connection with the test DB, closing
                # the connection will avoid pooling issues with SQLAlchemy. The
                # issue is trying to CREATE/DROP the test database using a
                # connection to a DB that was established with that test DB.
                # MySQLdb doesn't allow it, and SQLAlchemy attempts to reuse
                # the existing connection from its pool.
                connection.close()
            else:
                # Reset auto-increment sequences. Apparently, SUMO's tests are
                # horrid and coupled to certain numbers.
                cursor = connection.cursor()
                style = no_style()

                if uses_mysql(connection):
                    reset_statements = _mysql_reset_sequences(
                        style, connection)
                else:
                    reset_statements = connection.ops.sequence_reset_sql(
                        style, self._get_models_for_connection(connection))

                if hasattr(transaction, "atomic"):
                    with transaction.atomic(using=connection.alias):
                        for reset_statement in reset_statements:
                            cursor.execute(reset_statement)
                else:
                    # Django < 1.6
                    for reset_statement in reset_statements:
                        cursor.execute(reset_statement)
                    transaction.commit_unless_managed(using=connection.alias)

                # Each connection has its own creation object, so this affects
                # only a single connection:
                creation.create_test_db = MethodType(
                    _skip_create_test_db, creation)

        Command.handle = _foreign_key_ignoring_handle

        # With our class patch, does nothing but return some connection
        # objects:
        return super(NoseTestSuiteRunner, self).setup_databases()
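
The hasattr(transaction, 'atomic') probe is a version guard: atomic() arrived in Django 1.6, and commit_unless_managed() was removed in 1.8, so exactly one branch can exist on any given release. The same feature detection extracted into a helper:

from django.db import transaction

def execute_all(cursor, statements, alias):
    if hasattr(transaction, 'atomic'):       # Django 1.6+
        with transaction.atomic(using=alias):
            for sql in statements:
                cursor.execute(sql)
    else:                                    # Django < 1.6
        for sql in statements:
            cursor.execute(sql)
        transaction.commit_unless_managed(using=alias)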

Example 92

Project: site
Source File: ratings.py
View license
def rate_contest(contest):
    cursor = connection.cursor()
    cursor.execute('''
        SELECT judge_rating.user_id, judge_rating.rating, judge_rating.volatility, r.times
        FROM judge_rating INNER JOIN
             judge_contest ON (judge_contest.id = judge_rating.contest_id) INNER JOIN (
            SELECT judge_rating.user_id AS id, MAX(judge_contest.end_time) AS last_time,
                   COUNT(judge_rating.user_id) AS times
            FROM judge_contestparticipation INNER JOIN
                 judge_rating ON (judge_rating.user_id = judge_contestparticipation.user_id) INNER JOIN
                 judge_contest ON (judge_contest.id = judge_rating.contest_id)
            WHERE judge_contestparticipation.contest_id = %s AND judge_contest.end_time < %s AND
                  judge_contestparticipation.user_id NOT IN (
                      SELECT profile_id FROM judge_contest_rate_exclude WHERE contest_id = %s
                  ) AND judge_contestparticipation.virtual = 0
            GROUP BY judge_rating.user_id
            ORDER BY judge_contestparticipation.score DESC, judge_contestparticipation.cumtime ASC
        ) AS r ON (judge_rating.user_id = r.id AND judge_contest.end_time = r.last_time)
    ''', (contest.id, contest.end_time, contest.id))
    data = {user: (rating, volatility, times) for user, rating, volatility, times in cursor.fetchall()}
    cursor.close()

    users = contest.users.order_by('-score', 'cumtime').annotate(submissions=Count('submission')) \
                   .exclude(user_id__in=contest.rate_exclude.all()).filter(virtual=0)\
                   .values_list('id', 'user_id', 'score', 'cumtime')
    if not contest.rate_all:
        users = users.filter(submissions__gt=0)
    users = list(tie_ranker(users, key=itemgetter(2, 3)))
    participation_ids = [user[1][0] for user in users]
    user_ids = [user[1][1] for user in users]
    ranking = map(itemgetter(0), users)
    old_data = [data.get(user, (1200, 535, 0)) for user in user_ids]
    old_rating = map(itemgetter(0), old_data)
    old_volatility = map(itemgetter(1), old_data)
    times_ranked = map(itemgetter(2), old_data)
    rating, volatility = recalculate_ratings(old_rating, old_volatility, ranking, times_ranked)

    now = timezone.now()
    ratings = [Rating(user_id=id, contest=contest, rating=r, volatility=v, last_rated=now, participation_id=p, rank=z)
               for id, p, r, v, z in izip(user_ids, participation_ids, rating, volatility, ranking)]
    cursor = connection.cursor()
    cursor.execute('CREATE TEMPORARY TABLE _profile_rating_update(id integer, rating integer)')
    cursor.executemany('INSERT INTO _profile_rating_update VALUES (%s, %s)', zip(user_ids, rating))
    with transaction.atomic():
        Rating.objects.filter(contest=contest).delete()
        Rating.objects.bulk_create(ratings)
        cursor.execute('''
            UPDATE `%s` p INNER JOIN `_profile_rating_update` tmp ON (p.id = tmp.id)
            SET p.rating = tmp.rating
        ''' % Profile._meta.db_table)
    cursor.execute('DROP TABLE _profile_rating_update')
    cursor.close()
    return old_rating, old_volatility, ranking, times_ranked, rating, volatility
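
The atomic block covers the delete, the bulk_create and the raw UPDATE, so a failure at any point leaves the previous ratings untouched; the temporary table is managed outside the block because its lifetime is tied to the connection rather than the transaction. The delete-then-bulk-create core, reduced to a sketch with the model class passed in:

from django.db import transaction

def replace_ratings(Rating, contest, new_ratings):
    # Either both statements take effect or neither does.
    with transaction.atomic():
        Rating.objects.filter(contest=contest).delete()
        Rating.objects.bulk_create(new_ratings)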

Example 93

Project: django-moderation
Source File: models.py
View license
    def _moderate(self, new_status, by, reason):
        # See register.py pre_save_handler() for the case where the model is
        # reset to its old values, and the new values are stored in the
        # ModeratedObject. In such cases, on approval, we should restore the
        # changes to the base object by saving the one attached to the
        # ModeratedObject.

        if (self.status == MODERATION_STATUS_PENDING and
                new_status == MODERATION_STATUS_APPROVED and
                not self.moderator.visible_until_rejected):
            base_object = self.changed_object
            base_object_force_save = True
        else:
            # The model in the database contains the most recent data already,
            # or we're not ready to approve the changes stored in
            # ModeratedObject.
            obj_class = self.changed_object.__class__
            pk = self.changed_object.pk
            base_object = obj_class._default_manager.get(pk=pk)
            base_object_force_save = False

        if new_status == MODERATION_STATUS_APPROVED:
            # This version is now approved, and will be reverted to if
            # future changes are rejected by a moderator.
            self.state = MODERATION_READY_STATE

        self.status = new_status
        self.on = datetime.datetime.now()
        self.by = by
        self.reason = reason
        self.save()

        if self.moderator.visibility_column:
            old_visible = getattr(base_object,
                                  self.moderator.visibility_column)

            if new_status == MODERATION_STATUS_APPROVED:
                new_visible = True
            elif new_status == MODERATION_STATUS_REJECTED:
                new_visible = False
            else:  # MODERATION_STATUS_PENDING
                new_visible = self.moderator.visible_until_rejected

            if new_visible != old_visible:
                setattr(base_object, self.moderator.visibility_column,
                        new_visible)
                base_object_force_save = True

        if base_object_force_save:
            # avoid triggering pre/post_save_handler
            with transaction.atomic(using=None, savepoint=False):
                base_object.save_base(raw=True)
                # The _save_parents call is required for models with an
                # inherited visibility_column.
                base_object._save_parents(base_object.__class__, None, None)

        if self.changed_by:
            self.moderator.inform_user(self.content_object, self.changed_by)
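
In the final atomic() call, using=None simply targets the default database alias, and savepoint=False joins any transaction already open on it without adding a savepoint. Condensed to its essentials, with save_raw as an illustrative wrapper:

from django.db import transaction

def save_raw(obj):
    # using=None means the default database alias; savepoint=False
    # joins an already-open atomic block without creating a savepoint.
    with transaction.atomic(using=None, savepoint=False):
        obj.save_base(raw=True)  # raw=True lets signal handlers opt out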

Example 94

Project: ecommerce
Source File: baskets.py
View license
    def create(self, request, *args, **kwargs):
        """Add products to the authenticated user's basket.

        Expects an array of product objects, 'products', each containing a SKU, in the request
        body. The SKUs are used to populate the user's basket with the corresponding products.

        The caller indicates whether checkout should occur by providing a Boolean value
        in the request body, 'checkout'. If checkout operations are requested and the
        contents of the user's basket are free, an order is placed immediately.

        If checkout operations are requested but the contents of the user's basket are not
        free, pre-payment operations are performed instead of placing an order. The caller
        indicates which payment processor to use by providing a string in the request body,
        'payment_processor_name'.

        Protected by JWT authentication. Consuming services (e.g., the LMS)
        must authenticate themselves by passing a JWT in the Authorization
        HTTP header, prepended with the string 'JWT '. The JWT payload should
        contain user details. At a minimum, these details must include a
        username; providing an email is recommended.

        Arguments:
            request (HttpRequest): With parameters 'products', 'checkout', and
                'payment_processor_name' in the body.

        Returns:
            200 if a basket was created successfully; the basket ID is included in the response body along with
                either an order number corresponding to the placed order (None if one wasn't placed) or
                payment information (None if payment isn't required).
            400 if the client provided invalid data or attempted to add an unavailable product to their basket,
                with reason for the failure in JSON format.
            401 if an unauthenticated request is denied permission to access the endpoint.
            429 if the client has made requests at a rate exceeding that allowed by the configured rate limit.
            500 if an error occurs when attempting to initiate checkout.

        Examples:
            Create a basket for the user with username 'Saul' as follows. Successful fulfillment
            requires that a user with username 'Saul' exists on the LMS, and that EDX_API_KEY be
            configured within both the LMS and the ecommerce service.

            >>> url = 'http://localhost:8002/api/v2/baskets/'
            >>> token = jwt.encode({'username': 'Saul', 'email': '[email protected]'}, 'insecure-secret-key')
            >>> headers = {
                'content-type': 'application/json',
                'Authorization': 'JWT ' + token
            }

            If checkout is not desired:

            >>> data = {'products': [{'sku': 'SOME-SEAT'}, {'sku': 'SOME-OTHER-SEAT'}], 'checkout': False}
            >>> response = requests.post(url, data=json.dumps(data), headers=headers)
            >>> response.json()
            {
                'id': 7,
                'order': None,
                'payment_data': None
            }

            If the product with SKU 'FREE-SEAT' is free and checkout is desired:

            >>> data = {'products': [{'sku': 'FREE-SEAT'}], 'checkout': True, 'payment_processor_name': 'paypal'}
            >>> response = requests.post(url, data=json.dumps(data), headers=headers)
            >>> response.json()
            {
                'id': 7,
                'order': {'number': 'OSCR-100007'},
                'payment_data': None
            }

            If the product with SKU 'PAID-SEAT' is not free and checkout is desired:

            >>> data = {'products': [{'sku': 'PAID-SEAT'}], 'checkout': True, 'payment_processor_name': 'paypal'}
            >>> response = requests.post(url, data=json.dumps(data), headers=headers)
            >>> response.json()
            {
                'id': 7,
                'order': None,
                'payment_data': {
                    'payment_processor_name': 'paypal',
                    'payment_form_data': {...},
                    'payment_page_url': 'https://www.someexternallyhostedpaymentpage.com'
                }
            }
        """
        # Explicitly delimit operations which will be rolled back if an exception occurs.
        # atomic() context managers guarantee atomicity at the points where we are modifying
        # data (baskets, then orders), ensuring that we don't leave the system in a dirty
        # state in the event of an error.
        with transaction.atomic():
            basket = Basket.create_basket(request.site, request.user)
            basket_id = basket.id

            attribute_cookie_data(basket, request)

            requested_products = request.data.get('products')
            if requested_products:
                for requested_product in requested_products:
                    # Ensure the requested products exist
                    sku = requested_product.get('sku')
                    if sku:
                        try:
                            product = data_api.get_product(sku)
                        except api_exceptions.ProductNotFoundError as error:
                            return self._report_bad_request(
                                error.message,
                                api_exceptions.PRODUCT_NOT_FOUND_USER_MESSAGE
                            )
                    else:
                        return self._report_bad_request(
                            api_exceptions.SKU_NOT_FOUND_DEVELOPER_MESSAGE,
                            api_exceptions.SKU_NOT_FOUND_USER_MESSAGE
                        )

                    # Ensure the requested products are available for purchase before adding them to the basket
                    availability = basket.strategy.fetch_for_product(product).availability
                    if not availability.is_available_to_buy:
                        return self._report_bad_request(
                            api_exceptions.PRODUCT_UNAVAILABLE_DEVELOPER_MESSAGE.format(
                                sku=sku,
                                availability=availability.message
                            ),
                            api_exceptions.PRODUCT_UNAVAILABLE_USER_MESSAGE
                        )

                    basket.add_product(product)
                    logger.info('Added product with SKU [%s] to basket [%d]', sku, basket_id)

                    # Call signal handler to notify listeners that something has been added to the basket
                    basket_addition = get_class('basket.signals', 'basket_addition')
                    basket_addition.send(sender=basket_addition, product=product, user=request.user,
                                         request=request, basket=basket)
            else:
                # If no products were included in the request, we cannot check out.
                return self._report_bad_request(
                    api_exceptions.PRODUCT_OBJECTS_MISSING_DEVELOPER_MESSAGE,
                    api_exceptions.PRODUCT_OBJECTS_MISSING_USER_MESSAGE
                )

        if request.data.get('checkout') is True:
            # Begin the checkout process, if requested, with the requested payment processor.
            payment_processor_name = request.data.get('payment_processor_name')
            if payment_processor_name:
                try:
                    payment_processor = get_processor_class_by_name(payment_processor_name)
                except payment_exceptions.ProcessorNotFoundError as error:
                    return self._report_bad_request(
                        error.message,
                        payment_exceptions.PROCESSOR_NOT_FOUND_USER_MESSAGE
                    )
            else:
                payment_processor = get_default_processor_class()

            try:
                response_data = self._checkout(basket, payment_processor(request.site), request)
            except Exception as ex:  # pylint: disable=broad-except
                basket.delete()
                logger.exception('Failed to initiate checkout for Basket [%d]. The basket has been deleted.', basket_id)
                return Response({'developer_message': ex.message}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
        else:
            # Return a serialized basket, if checkout was not requested.
            response_data = self._generate_basic_response(basket)

        return Response(response_data, status=status.HTTP_200_OK)
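
The view above keeps only the data-modifying work (basket creation and product addition) inside transaction.atomic(), and runs checkout, which may call out to an external payment processor, after the transaction has committed. A minimal sketch of that separation; MyBasket, lookup_product, and external_checkout are hypothetical names, not part of the code above:

from django.db import transaction

def create_basket_and_checkout(request):
    # Group the database writes so a failure part-way through rolls back
    # the basket and all of its lines together.
    with transaction.atomic():
        basket = MyBasket.objects.create(owner=request.user)  # hypothetical model
        for item in request.data.get('products', []):
            basket.add_product(lookup_product(item['sku']))  # hypothetical helper

    # The slow, failure-prone external call happens after the commit,
    # so it never holds a database transaction open.
    return external_checkout(basket)  # hypothetical payment call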

Example 95

Project: ecommerce
Source File: coupons.py
View license
    def create(self, request, *args, **kwargs):
        """Adds coupon to the user's basket.

        Expects request array to contain all the necessary data (listed out below).
        This information is then used to create a coupon product, add to a
        basket and create an order from it.

        Arguments:
            request (HttpRequest): With parameters title, client,
            stock_record_ids, start_date, end_date, code, benefit_type, benefit_value,
            voucher_type, quantity, price, category, note and invoice data in the body.

        Returns:
            200 if the order was created successfully; the basket ID is included in the response
                body along with the order ID and payment information.
            400 if a custom code is received that already exists,
                or if a course mode is selected that is not supported.
            401 if an unauthenticated request is denied permission to access the endpoint.
            429 if the client has made requests at a rate exceeding that allowed by the configured rate limit.
            500 if an error occurs when attempting to create a coupon.
        """
        category_data = request.data.get('category')
        code = request.data.get('code')
        course_seat_types = request.data.get('course_seat_types')
        max_uses = request.data.get('max_uses')
        partner = request.site.siteconfiguration.partner
        stock_record_ids = request.data.get('stock_record_ids')
        voucher_type = request.data.get('voucher_type')

        with transaction.atomic():
            if code:
                try:
                    Voucher.objects.get(code=code)
                    return Response(
                        'A coupon with code {code} already exists.'.format(code=code),
                        status=status.HTTP_400_BAD_REQUEST
                    )
                except Voucher.DoesNotExist:
                    pass

            if course_seat_types:
                course_seat_types = prepare_course_seat_types(course_seat_types)

            try:
                category = Category.objects.get(name=category_data['name'])
            except Category.DoesNotExist:
                return Response(
                    'Category {category_name} not found.'.format(category_name=category_data['name']),
                    status=status.HTTP_404_NOT_FOUND
                )
            except KeyError:
                return Response('Invalid Coupon Category data.', status=status.HTTP_400_BAD_REQUEST)

            # A maximum number of uses can disturb the predefined behaviour of the
            # different voucher types. Therefore, we enforce here that max_uses
            # cannot be set for SINGLE_USE voucher types.
            if max_uses and voucher_type != Voucher.SINGLE_USE:
                max_uses = int(max_uses)
            else:
                max_uses = None

            # When a blacklisted course mode is received, respond with a 400 error.
            # Audit modes do not have a certificate type, so accessing it raises
            # an AttributeError, which is handled the same way.
            if stock_record_ids:
                seats = Product.objects.filter(stockrecords__id__in=stock_record_ids)
                for seat in seats:
                    try:
                        if seat.attr.certificate_type in settings.BLACK_LIST_COUPON_COURSE_MODES:
                            return Response('Course mode not supported', status=status.HTTP_400_BAD_REQUEST)
                    except AttributeError:
                        return Response('Course mode not supported', status=status.HTTP_400_BAD_REQUEST)

                stock_records_string = ' '.join(str(record_id) for record_id in stock_record_ids)
                coupon_catalog, __ = get_or_create_catalog(
                    name='Catalog for stock records: {}'.format(stock_records_string),
                    partner=partner,
                    stock_record_ids=stock_record_ids
                )
            else:
                coupon_catalog = None

            coupon_product = create_coupon_product(
                benefit_type=request.data.get('benefit_type'),
                benefit_value=request.data.get('benefit_value'),
                catalog=coupon_catalog,
                catalog_query=request.data.get('catalog_query'),
                category=category,
                code=code,
                course_seat_types=course_seat_types,
                email_domains=request.data.get('email_domains'),
                end_datetime=dateutil.parser.parse(request.data.get('end_datetime')),
                max_uses=max_uses,
                note=request.data.get('note'),
                partner=partner,
                price=request.data.get('price'),
                quantity=request.data.get('quantity'),
                start_datetime=dateutil.parser.parse(request.data.get('start_datetime')),
                title=request.data.get('title'),
                voucher_type=voucher_type
            )

            basket = prepare_basket(request, coupon_product)

            # Create an order now since payment is handled out of band via an invoice.
            client, __ = BusinessClient.objects.get_or_create(name=request.data.get('client'))
            invoice_data = self.create_update_data_dict(data=request.data, fields=Invoice.UPDATEABLE_INVOICE_FIELDS)
            response_data = self.create_order_for_invoice(
                basket, coupon_id=coupon_product.id, client=client, invoice_data=invoice_data
            )

            return Response(response_data, status=status.HTTP_200_OK)
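
One subtlety in the view above: atomic() rolls back only when an exception escapes the block, so the early return Response(...) calls commit anything already written; here they are safe because they all occur before the first write. A toy illustration of the difference, using a hypothetical Counter model:

from django.db import transaction

def apply_increment(counter, amount):
    with transaction.atomic():
        counter.value += amount
        counter.save()
        if amount > 100:
            # Raising is what triggers the rollback; a plain return
            # at this point would commit the save() above.
            raise ValueError('amount too large; changes rolled back')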

Example 96

Project: ecommerce
Source File: views.py
View license
    def post(self, request):
        """Process a CyberSource merchant notification and place an order for paid products as appropriate."""

        # Note (CCB): Orders should not be created until the payment processor has validated the response's signature.
        # This validation is performed in the handle_payment method. After that method succeeds, the response can be
        # safely assumed to have originated from CyberSource.
        cybersource_response = request.POST.dict()
        basket = None
        transaction_id = None

        try:
            transaction_id = cybersource_response.get('transaction_id')
            order_number = cybersource_response.get('req_reference_number')
            basket_id = OrderNumberGenerator().basket_id(order_number)

            logger.info(
                'Received CyberSource merchant notification for transaction [%s], associated with basket [%d].',
                transaction_id,
                basket_id
            )

            basket = self._get_basket(basket_id)

            if not basket:
                logger.error('Received payment for non-existent basket [%s].', basket_id)
                return HttpResponse(status=400)
        finally:
            # Store the response in the database regardless of its authenticity.
            ppr = self.payment_processor.record_processor_response(cybersource_response, transaction_id=transaction_id,
                                                                   basket=basket)

        try:
            # Explicitly delimit operations which will be rolled back if an exception occurs.
            with transaction.atomic():
                try:
                    self.handle_payment(cybersource_response, basket)
                except InvalidSignatureError:
                    logger.exception(
                        'Received an invalid CyberSource response. The payment response was recorded in entry [%d].',
                        ppr.id
                    )
                    return HttpResponse(status=400)
                except (UserCancelled, TransactionDeclined) as exception:
                    logger.info(
                        'CyberSource payment did not complete for basket [%d] because [%s]. '
                        'The payment response was recorded in entry [%d].',
                        basket.id,
                        exception.__class__.__name__,
                        ppr.id
                    )
                    return HttpResponse()
                except PaymentError:
                    logger.exception(
                        'CyberSource payment failed for basket [%d]. The payment response was recorded in entry [%d].',
                        basket.id,
                        ppr.id
                    )
                    return HttpResponse()
        except:  # pylint: disable=bare-except
            logger.exception('Attempts to handle payment for basket [%d] failed.', basket.id)
            return HttpResponse(status=500)

        try:
            # Note (CCB): In the future, if we do end up shipping physical products, we will need to
            # properly implement shipping methods. For more, see
            # http://django-oscar.readthedocs.org/en/latest/howto/how_to_configure_shipping.html.
            shipping_method = NoShippingRequired()
            shipping_charge = shipping_method.calculate(basket)

            # Note (CCB): This calculation assumes the payment processor has not sent a partial authorization,
            # thus we use the amounts stored in the database rather than those received from the payment processor.
            order_total = OrderTotalCalculator().calculate(basket, shipping_charge)
            billing_address = self._get_billing_address(cybersource_response)

            user = basket.owner

            self.handle_order_placement(
                order_number,
                user,
                basket,
                None,
                shipping_method,
                shipping_charge,
                billing_address,
                order_total,
                request=request
            )

            return HttpResponse()
        except:  # pylint: disable=bare-except
            logger.exception(self.order_placement_failure_msg, basket.id)
            return HttpResponse(status=500)
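
A notable ordering decision above: the raw CyberSource response is persisted in the finally clause, before and outside the atomic() block, so the audit record survives even when handle_payment raises and the payment writes roll back. A condensed sketch of that pattern; record_response and handle_payment are hypothetical stand-ins:

from django.db import transaction

def process_notification(response_data, basket):
    # Persist the audit record first, outside the transaction below,
    # so it is kept no matter how payment handling goes.
    record = record_response(response_data)  # hypothetical audit helper

    try:
        with transaction.atomic():
            handle_payment(response_data, basket)  # hypothetical; may raise
    except Exception:
        # The payment writes were rolled back, but the audit record
        # created above is already committed.
        return 'error (response recorded in entry %d)' % record.id
    return 'ok'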

Example 97

Project: edx-ora2
Source File: ai.py
View license
    @classmethod
    @transaction.atomic
    def create_classifier_set(cls, classifiers_dict, rubric, algorithm_id, course_id, item_id):
        """
        Create a set of classifiers.

        Args:
            classifiers_dict (dict): Mapping of criterion names to
                JSON-serializable classifiers.
            rubric (Rubric): The rubric model.
            algorithm_id (unicode): The ID of the algorithm used to train the classifiers.
            course_id (unicode): The ID of the course that the classifier is going to grade.
            item_id (unicode): The item within the course that the classifier is trained to grade.

        Returns:
            AIClassifierSet

        Raises:
            ClassifierSerializeError
            ClassifierUploadError
            InvalidRubricSelection
            DatabaseError

        """
        # Create the classifier set
        classifier_set = cls.objects.create(
            rubric=rubric, algorithm_id=algorithm_id, item_id=item_id, course_id=course_id
        )

        # Retrieve the criteria for this rubric,
        # then organize them by criterion name
        try:
            rubric_index = rubric.index
        except DatabaseError as ex:
            msg = (
                u"An unexpected error occurred while retrieving rubric criteria with the"
                u"rubric hash {rh} and algorithm_id {aid}: {ex}"
            ).format(rh=rubric.content_hash, aid=algorithm_id, ex=ex)
            logger.exception(msg)
            raise

        # Check that we have classifiers for all criteria in the rubric
        # Ignore criteria that have no options: since these have only written feedback,
        # we can't assign them a score.
        all_criteria = set(classifiers_dict.keys())
        all_criteria |= set(
            criterion.name for criterion in 
            rubric_index.find_criteria_without_options()
        )
        missing_criteria = rubric_index.find_missing_criteria(all_criteria)
        if missing_criteria:
            raise IncompleteClassifierSet(missing_criteria)

        # Create classifiers for each criterion
        for criterion_name, classifier_data in classifiers_dict.iteritems():
            classifier = AIClassifier.objects.create(
                classifier_set=classifier_set,
                criterion=rubric_index.find_criterion(criterion_name)
            )

            # Serialize the classifier data and upload
            try:
                contents = ContentFile(json.dumps(classifier_data))
            except (TypeError, ValueError, UnicodeDecodeError) as ex:
                msg = (
                    u"Could not serialize classifier data as JSON: {ex}"
                ).format(ex=ex)
                raise ClassifierSerializeError(msg)

            filename = uuid4().hex
            try:
                classifier.classifier_data.save(filename, contents)
            except Exception as ex:
                full_filename = upload_to_path(classifier, filename)
                msg = (
                    u"Could not upload classifier data to {filename}: {ex}"
                ).format(filename=full_filename, ex=ex)
                raise ClassifierUploadError(msg)

        return classifier_set
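
When stacking decorators as in create_classifier_set above, @classmethod must remain outermost: @transaction.atomic wraps the plain function first, and the result is then turned into a classmethod. A minimal self-contained sketch of the same arrangement, with Pair as a hypothetical model:

from django.db import models, transaction

class Pair(models.Model):  # hypothetical model
    name = models.CharField(max_length=50)

    @classmethod
    @transaction.atomic  # must sit under @classmethod
    def create_pair(cls, first_name, second_name):
        # Both rows are created, or neither is.
        first = cls.objects.create(name=first_name)
        second = cls.objects.create(name=second_name)
        return first, second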

Example 98

Project: edx-ora2
Source File: training.py
View license
@transaction.atomic
def deserialize_training_examples(examples, rubric_dict):
    """
    Deserialize training examples to Django models.

    Args:
        examples (list of dict): The serialized training examples.
        rubric_dict (dict): The serialized rubric.

    Returns:
        list of TrainingExamples

    Raises:
        InvalidRubric
        InvalidRubricSelection
        InvalidTrainingExample

    Example usage:

        >>> options = [
        >>>     {
        >>>         "order_num": 0,
        >>>         "name": "poor",
        >>>         "explanation": "Poor job!",
        >>>         "points": 0,
        >>>     },
        >>>     {
        >>>         "order_num": 1,
        >>>         "name": "good",
        >>>         "explanation": "Good job!",
        >>>         "points": 1,
        >>>     },
        >>>     {
        >>>         "order_num": 2,
        >>>         "name": "excellent",
        >>>         "explanation": "Excellent job!",
        >>>         "points": 2,
        >>>     },
        >>> ]
        >>>
        >>> rubric = {
        >>>     "prompts": [
        >>>         {"description": "Prompt 1"}
        >>>         {"description": "Prompt 2"}
        >>>         {"description": "Prompt 3"}
        >>>     ],
        >>>     "criteria": [
        >>>         {
        >>>             "order_num": 0,
        >>>             "name": "vocabulary",
        >>>             "prompt": "How varied is the vocabulary?",
        >>>             "options": options
        >>>         },
        >>>         {
        >>>             "order_num": 1,
        >>>             "name": "grammar",
        >>>             "prompt": "How correct is the grammar?",
        >>>             "options": options
        >>>         }
        >>>     ]
        >>> }
        >>>
        >>> examples = [
        >>>     {
        >>>         'answer': {
        >>>             'parts': [
        >>>                 {'text': 'Answer part 1'},
        >>>                 {'text': 'Answer part 2'},
        >>>                 {'text': 'Answer part 3'}
        >>>             ]
        >>>         },
        >>>         'options_selected': {
        >>>             'vocabulary': 'good',
        >>>             'grammar': 'excellent'
        >>>         }
        >>>     },
        >>>     {
        >>>         'answer': u'Doler',
        >>>         'options_selected': {
        >>>             'vocabulary': 'good',
        >>>             'grammar': 'poor'
        >>>         }
        >>>     }
        >>> ]
        >>>
        >>> examples = deserialize_training_examples(examples, rubric)

    """
    # Parse the rubric
    # This will raise an exception if the serialized rubric is invalid.
    rubric = rubric_from_dict(rubric_dict)

    # Parse each example
    created_examples = []
    for example_dict in examples:

        # Try to retrieve the example from the cache
        cache_key, content_hash = TrainingExample.cache_key(example_dict['answer'], example_dict['options_selected'], rubric)
        example = cache.get(cache_key)

        # If we couldn't retrieve the example from the cache, create it
        if example is None:
            # Validate the training example
            is_valid, errors = validate_training_example_format(example_dict)
            if not is_valid:
                raise InvalidTrainingExample("; ".join(errors))

            # Get or create the training example
            try:
                example = TrainingExample.objects.get(content_hash=content_hash)
            except TrainingExample.DoesNotExist:
                try:
                    with transaction.atomic():
                        example = TrainingExample.create_example(
                            example_dict['answer'], example_dict['options_selected'], rubric
                        )
                except IntegrityError:
                    example = TrainingExample.objects.get(content_hash=content_hash)

            # Add the example to the cache
            cache.set(cache_key, example)

        created_examples.append(example)

    return created_examples
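
The inner transaction.atomic() around TrainingExample.create_example is what makes the IntegrityError recoverable: the function already runs inside the outer @transaction.atomic, so the inner block becomes a savepoint, and without it the caught error would leave the outer transaction unusable (PostgreSQL in particular refuses further queries after an error). A stripped-down version of this get-or-create-under-race pattern, with Tag as a hypothetical model whose name field is unique:

from django.db import IntegrityError, transaction

@transaction.atomic
def get_or_make_tag(name):
    try:
        return Tag.objects.get(name=name)  # hypothetical model, unique name
    except Tag.DoesNotExist:
        try:
            # The nested atomic() creates a savepoint, so a failed INSERT
            # does not poison the surrounding transaction.
            with transaction.atomic():
                return Tag.objects.create(name=name)
        except IntegrityError:
            # Another process won the race; fetch the row it created.
            return Tag.objects.get(name=name)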

Example 99

Project: edx-ora2
Source File: models.py
View license
    @classmethod
    @transaction.atomic
    def start_workflow(cls, submission_uuid, step_names, on_init_params):
        """
        Start a new workflow.

        Args:
            submission_uuid (str): The UUID of the submission associated with this workflow.
            step_names (list): The names of the assessment steps in the workflow.
            on_init_params (dict): The parameters to pass to each assessment module
                on init.  Keys are the assessment step names.

        Returns:
            AssessmentWorkflow

        Raises:
            SubmissionNotFoundError
            SubmissionRequestError
            SubmissionInternalError
            DatabaseError
            Assessment-module specific errors
        """
        submission_dict = sub_api.get_submission_and_student(submission_uuid)

        staff_auto_added = False
        if 'staff' not in step_names:
            staff_auto_added = True
            new_list = ['staff']
            new_list.extend(step_names)
            step_names = new_list

        # Create the workflow and step models in the database
        # For now, set the status to waiting; we'll modify it later
        # based on the first step in the workflow.
        workflow = cls.objects.create(
            submission_uuid=submission_uuid,
            status=AssessmentWorkflow.STATUS.waiting,
            course_id=submission_dict['student_item']['course_id'],
            item_id=submission_dict['student_item']['item_id']
        )
        workflow_steps = [
            AssessmentWorkflowStep(
                workflow=workflow, name=step, order_num=i
            )
            for i, step in enumerate(step_names)
        ]
        workflow.steps.add(*workflow_steps)

        # Initialize the assessment APIs
        has_started_first_step = False
        for step in workflow_steps:
            api = step.api()

            if api is not None:
                # Initialize the assessment module
                # We do this for every assessment module
                on_init_func = getattr(api, 'on_init', lambda submission_uuid, **params: None)
                on_init_func(submission_uuid, **on_init_params.get(step.name, {}))

                # If we auto-added a staff step, it is optional and should be marked complete immediately
                if step.name == "staff" and staff_auto_added:
                    step.assessment_completed_at = now()
                    step.save()

                # For the first valid step, update the workflow status
                # and notify the assessment module that it's being started
                if not has_started_first_step:
                    # Update the workflow
                    workflow.status = step.name
                    workflow.save()

                    # Notify the assessment module that it's being started
                    on_start_func = getattr(api, 'on_start', lambda submission_uuid: None)
                    on_start_func(submission_uuid)

                    # Remember that we've already started the first step
                    has_started_first_step = True

        # Update the workflow (in case some of the assessment modules are automatically complete)
        # We do NOT pass in requirements, on the assumption that any assessment module
        # that accepts requirements would NOT automatically complete.
        workflow.update_from_assessments(None)

        # Return the newly created workflow
        return workflow
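
One caveat with running callbacks such as on_init inside the transaction, as start_workflow does, is that any external side effects they trigger (emails, task queues) fire even if the transaction later rolls back. On Django 1.9 and later, transaction.on_commit can defer such side effects until the commit actually succeeds; a minimal sketch, with create_workflow_rows and notify as hypothetical callables:

from django.db import transaction

@transaction.atomic
def start(create_workflow_rows, notify):
    workflow = create_workflow_rows()  # hypothetical: performs the DB writes
    # The callback runs only if this atomic block commits; on rollback
    # it is discarded, so no notification goes out for a dead workflow.
    transaction.on_commit(lambda: notify(workflow.id))
    return workflow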

Example 100

Project: edx-platform
Source File: views.py
View license
    @method_decorator(login_required)
    @method_decorator(transaction.atomic)
    def get(self, request, course_id, error=None):
        """Displays the course mode choice page.

        Args:
            request (`Request`): The Django Request object.
            course_id (unicode): The slash-separated course key.

        Keyword Args:
            error (unicode): If provided, display this error message
                on the page.

        Returns:
            Response

        """
        course_key = CourseKey.from_string(course_id)

        # Check whether the user has access to this course
        # based on country access rules.
        embargo_redirect = embargo_api.redirect_if_blocked(
            course_key,
            user=request.user,
            ip_address=get_ip(request),
            url=request.path
        )
        if embargo_redirect:
            return redirect(embargo_redirect)

        enrollment_mode, is_active = CourseEnrollment.enrollment_mode_for_user(request.user, course_key)
        modes = CourseMode.modes_for_course_dict(course_key)
        ecommerce_service = EcommerceService()

        # We assume that, if 'professional' is one of the modes, it should be the *only* mode.
        # If both 'professional' and 'no-id-professional' are present, default to no-id-professional.
        has_enrolled_professional = (CourseMode.is_professional_slug(enrollment_mode) and is_active)
        if CourseMode.has_professional_mode(modes) and not has_enrolled_professional:
            purchase_workflow = request.GET.get("purchase_workflow", "single")
            verify_url = reverse('verify_student_start_flow', kwargs={'course_id': unicode(course_key)})
            redirect_url = "{url}?purchase_workflow={workflow}".format(url=verify_url, workflow=purchase_workflow)
            if ecommerce_service.is_enabled(request.user):
                professional_mode = modes.get(CourseMode.NO_ID_PROFESSIONAL_MODE) or modes.get(CourseMode.PROFESSIONAL)
                if purchase_workflow == "single" and professional_mode.sku:
                    redirect_url = ecommerce_service.checkout_page_url(professional_mode.sku)
                if purchase_workflow == "bulk" and professional_mode.bulk_sku:
                    redirect_url = ecommerce_service.checkout_page_url(professional_mode.bulk_sku)
            return redirect(redirect_url)

        # If there isn't a verified mode available, then there's nothing
        # to do on this page.  The user has almost certainly been auto-registered
        # in the "honor" track by this point, so we send the user
        # to the dashboard.
        if not CourseMode.has_verified_mode(modes):
            return redirect(reverse('dashboard'))

        # If a user has already paid, redirect them to the dashboard.
        if is_active and (enrollment_mode in CourseMode.VERIFIED_MODES + [CourseMode.NO_ID_PROFESSIONAL_MODE]):
            return redirect(reverse('dashboard'))

        donation_for_course = request.session.get("donation_for_course", {})
        chosen_price = donation_for_course.get(unicode(course_key), None)

        course = modulestore().get_course(course_key)
        if CourseEnrollment.is_enrollment_closed(request.user, course):
            locale = to_locale(get_language())
            enrollment_end_date = format_datetime(course.enrollment_end, 'short', locale=locale)
            params = urllib.urlencode({'course_closed': enrollment_end_date})
            return redirect('{0}?{1}'.format(reverse('dashboard'), params))

        # When a credit mode is available, students will be given the option
        # to upgrade from a verified mode to a credit mode at the end of the course.
        # This allows students who have completed photo verification to be eligible
        # for university credit.
        # Since credit isn't one of the selectable options on the track selection page,
        # we need to check *all* available course modes in order to determine whether
        # a credit mode is available.  If so, then we show slightly different messaging
        # for the verified track.
        has_credit_upsell = any(
            CourseMode.is_credit_mode(mode) for mode
            in CourseMode.modes_for_course(course_key, only_selectable=False)
        )

        context = {
            "course_modes_choose_url": reverse(
                "course_modes_choose",
                kwargs={'course_id': course_key.to_deprecated_string()}
            ),
            "modes": modes,
            "has_credit_upsell": has_credit_upsell,
            "course_name": course.display_name_with_default_escaped,
            "course_org": course.display_org_with_default,
            "course_num": course.display_number_with_default,
            "chosen_price": chosen_price,
            "error": error,
            "responsive": True,
            "nav_hidden": True,
        }
        if "verified" in modes:
            verified_mode = modes["verified"]
            context["suggested_prices"] = [
                decimal.Decimal(x.strip())
                for x in verified_mode.suggested_prices.split(",")
                if x.strip()
            ]
            context["currency"] = verified_mode.currency.upper()
            context["min_price"] = verified_mode.min_price
            context["verified_name"] = verified_mode.name
            context["verified_description"] = verified_mode.description

            if verified_mode.sku:
                context["use_ecommerce_payment_flow"] = ecommerce_service.is_enabled(request.user)
                context["ecommerce_payment_page"] = ecommerce_service.payment_page_url()
                context["sku"] = verified_mode.sku
                context["bulk_sku"] = verified_mode.bulk_sku

        return render_to_response("course_modes/choose.html", context)
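
The view above applies login_required and transaction.atomic to a class-based view's get method through method_decorator, Django's standard adapter for using function decorators on methods. A minimal sketch of the same wiring, with ChooseModeView as a hypothetical view:

from django.contrib.auth.decorators import login_required
from django.db import transaction
from django.http import HttpResponse
from django.utils.decorators import method_decorator
from django.views.generic import View

class ChooseModeView(View):  # hypothetical view
    @method_decorator(login_required)
    @method_decorator(transaction.atomic)
    def get(self, request, *args, **kwargs):
        # The whole handler runs inside a single transaction,
        # and only authenticated users reach it.
        return HttpResponse('ok')

Newer Django versions also allow hoisting these onto the class with method_decorator's name argument, e.g. @method_decorator(transaction.atomic, name='get') applied to the class itself.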