os.path.basename

Here are examples of the Python API os.path.basename, taken from open source projects.
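
Before the examples, a minimal sketch of the function's own behavior (standard library only, no project code assumed): os.path.basename(path) returns the final component of a path, i.e. everything after the last separator, and a path ending in a separator yields an empty string.

import os.path

print(os.path.basename('/usr/local/bin/python'))  # 'python'
print(os.path.basename('/usr/local/bin/'))        # '' for a trailing separator
print(os.path.basename('setup.py'))               # 'setup.py'

# A companion idiom, used by several examples below, strips the extension too:
print(os.path.splitext(os.path.basename('/tmp/archive.tar'))[0])  # 'archive'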

200 Examples

Example 1

Project: virt-test
Source File: bootstrap.py
def create_subtests_cfg(t_type):
    root_dir = data_dir.get_root_dir()

    specific_test_list = []
    specific_file_list = []
    specific_subdirs = asset.get_test_provider_subdirs(t_type)
    provider_names_specific = asset.get_test_provider_names(t_type)

    provider_info_specific = []
    for specific_provider in provider_names_specific:
        provider_info_specific.append(asset.get_test_provider_info(specific_provider))

    for subdir in specific_subdirs:
        specific_test_list += data_dir.SubdirGlobList(subdir,
                                                      '*.py',
                                                      test_filter)
        specific_file_list += data_dir.SubdirGlobList(subdir,
                                                      '*.cfg',
                                                      config_filter)

    shared_test_list = []
    shared_file_list = []
    shared_subdirs = asset.get_test_provider_subdirs('generic')
    provider_names_shared = asset.get_test_provider_names('generic')

    provider_info_shared = []
    for shared_provider in provider_names_shared:
        provider_info_shared.append(asset.get_test_provider_info(shared_provider))

    if not t_type == 'lvsb':
        for subdir in shared_subdirs:
            shared_test_list += data_dir.SubdirGlobList(subdir,
                                                        '*.py',
                                                        test_filter)
            shared_file_list += data_dir.SubdirGlobList(subdir,
                                                        '*.cfg',
                                                        config_filter)

    all_specific_test_list = []
    for test in specific_test_list:
        for p in provider_info_specific:
            provider_base_path = p['backends'][t_type]['path']
            if provider_base_path in test:
                provider_name = p['name']
                break

        basename = os.path.basename(test)
        if basename != "__init__.py":
            all_specific_test_list.append("%s.%s" %
                                          (provider_name,
                                           basename.split(".")[0]))
    all_shared_test_list = []
    for test in shared_test_list:
        for p in provider_info_shared:
            provider_base_path = p['backends']['generic']['path']
            if provider_base_path in test:
                provider_name = p['name']
                break

        basename = os.path.basename(test)
        if basename != "__init__.py":
            all_shared_test_list.append("%s.%s" %
                                        (provider_name,
                                         basename.split(".")[0]))

    all_specific_test_list.sort()
    all_shared_test_list.sort()
    all_test_list = set(all_specific_test_list + all_shared_test_list)

    first_subtest_file = []
    last_subtest_file = []
    non_dropin_tests = []
    tmp = []

    for shared_file in shared_file_list:
        provider_name = None
        for p in provider_info_shared:
            provider_base_path = p['backends']['generic']['path']
            if provider_base_path in shared_file:
                provider_name = p['name']
                break

        shared_file_obj = open(shared_file, 'r')
        for line in shared_file_obj.readlines():
            line = line.strip()
            if line.startswith("type"):
                cartesian_parser = cartesian_config.Parser()
                cartesian_parser.parse_string(line)
                td = cartesian_parser.get_dicts().next()
                values = td['type'].split(" ")
                for value in values:
                    if value not in non_dropin_tests:
                        non_dropin_tests.append("%s.%s" %
                                                (provider_name, value))

        shared_file_name = os.path.basename(shared_file)
        shared_file_name = shared_file_name.split(".")[0]
        if shared_file_name in first_subtest[t_type]:
            if [provider_name, shared_file] not in first_subtest_file:
                first_subtest_file.append([provider_name, shared_file])
        elif shared_file_name in last_subtest[t_type]:
            if [provider_name, shared_file] not in last_subtest_file:
                last_subtest_file.append([provider_name, shared_file])
        else:
            if [provider_name, shared_file] not in tmp:
                tmp.append([provider_name, shared_file])
    shared_file_list = tmp

    tmp = []
    for shared_file in specific_file_list:
        provider_name = None
        for p in provider_info_specific:
            provider_base_path = p['backends'][t_type]['path']
            if provider_base_path in shared_file:
                provider_name = p['name']
                break

        shared_file_obj = open(shared_file, 'r')
        for line in shared_file_obj.readlines():
            line = line.strip()
            if line.startswith("type"):
                cartesian_parser = cartesian_config.Parser()
                cartesian_parser.parse_string(line)
                td = cartesian_parser.get_dicts().next()
                values = td['type'].split(" ")
                for value in values:
                    if value not in non_dropin_tests:
                        non_dropin_tests.append("%s.%s" %
                                                (provider_name, value))

        shared_file_name = os.path.basename(shared_file)
        shared_file_name = shared_file_name.split(".")[0]
        if shared_file_name in first_subtest[t_type]:
            if [provider_name, shared_file] not in first_subtest_file:
                first_subtest_file.append([provider_name, shared_file])
        elif shared_file_name in last_subtest[t_type]:
            if [provider_name, shared_file] not in last_subtest_file:
                last_subtest_file.append([provider_name, shared_file])
        else:
            if [provider_name, shared_file] not in tmp:
                tmp.append([provider_name, shared_file])
    specific_file_list = tmp

    non_dropin_tests.sort()
    non_dropin_tests = set(non_dropin_tests)
    dropin_tests = all_test_list - non_dropin_tests
    dropin_file_list = []
    tmp_dir = data_dir.get_tmp_dir()
    if not os.path.isdir(tmp_dir):
        os.makedirs(tmp_dir)

    for dropin_test in dropin_tests:
        provider = dropin_test.split(".")[0]
        d_type = dropin_test.split(".")[-1]
        autogen_cfg_path = os.path.join(tmp_dir,
                                        '%s.cfg' % dropin_test)
        autogen_cfg_file = open(autogen_cfg_path, 'w')
        autogen_cfg_file.write("# Drop-in test - auto generated snippet\n")
        autogen_cfg_file.write("- %s:\n" % dropin_test)
        autogen_cfg_file.write("    virt_test_type = %s\n" % t_type)
        autogen_cfg_file.write("    type = %s\n" % d_type)
        autogen_cfg_file.close()
        dropin_file_list.append([provider, autogen_cfg_path])

    dropin_file_list_2 = []
    dropin_tests = os.listdir(os.path.join(data_dir.get_root_dir(), "dropin"))
    dropin_cfg_path = os.path.join(tmp_dir, 'dropin.cfg')
    dropin_cfg_file = open(dropin_cfg_path, 'w')
    dropin_cfg_file.write("# Auto generated snippet for dropin tests\n")
    dropin_cfg_file.write("- dropin:\n")
    dropin_cfg_file.write("    variants:\n")
    for dropin_test in dropin_tests:
        if dropin_test == "README":
            continue
        dropin_cfg_file.write("        - %s:\n" % dropin_test)
        dropin_cfg_file.write("            virt_test_type = %s\n" % t_type)
        dropin_cfg_file.write("            type = dropin\n")
        dropin_cfg_file.write("            start_vm = no\n")
        dropin_cfg_file.write("            dropin_path = %s\n" % dropin_test)
    dropin_cfg_file.close()
    dropin_file_list_2.append(['io-github-autotest-qemu', dropin_cfg_path])

    subtests_cfg = os.path.join(root_dir, 'backends', t_type, 'cfg',
                                'subtests.cfg')
    subtests_file = open(subtests_cfg, 'w')
    subtests_file.write(
        "# Do not edit, auto generated file from subtests config\n")

    subtests_file.write("variants subtest:\n")
    write_subtests_files(first_subtest_file, subtests_file)
    write_subtests_files(specific_file_list, subtests_file, t_type)
    write_subtests_files(shared_file_list, subtests_file)
    write_subtests_files(dropin_file_list, subtests_file)
    write_subtests_files(dropin_file_list_2, subtests_file)
    write_subtests_files(last_subtest_file, subtests_file)

    subtests_file.close()
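
The os.path.basename pattern in this example reduces each test path to a 'provider.module' name, skipping package markers. A minimal standalone sketch of that pattern, with a hypothetical test path:

import os.path

def test_name(provider_name, test_path):
    # Reduce e.g. '/providers/qemu/tests/migrate.py' to 'provider.migrate'.
    base = os.path.basename(test_path)
    if base == "__init__.py":
        return None  # package markers are skipped
    return "%s.%s" % (provider_name, base.split(".")[0])

print(test_name("io-github-autotest-qemu", "/providers/qemu/tests/migrate.py"))
# io-github-autotest-qemu.migrate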

Example 2

Project: avocado-vt
Source File: bootstrap.py
def create_subtests_cfg(t_type):
    specific_test_list = []
    specific_file_list = []
    specific_subdirs = asset.get_test_provider_subdirs(t_type)
    provider_names_specific = asset.get_test_provider_names(t_type)
    config_filter = get_config_filter()

    provider_info_specific = []
    for specific_provider in provider_names_specific:
        provider_info_specific.append(
            asset.get_test_provider_info(specific_provider))

    for subdir in specific_subdirs:
        specific_test_list += data_dir.SubdirGlobList(subdir,
                                                      '*.py',
                                                      test_filter)
        specific_file_list += data_dir.SubdirGlobList(subdir,
                                                      '*.cfg',
                                                      config_filter)

    shared_test_list = []
    shared_file_list = []
    shared_subdirs = asset.get_test_provider_subdirs('generic')
    provider_names_shared = asset.get_test_provider_names('generic')

    provider_info_shared = []
    for shared_provider in provider_names_shared:
        provider_info_shared.append(
            asset.get_test_provider_info(shared_provider))

    if not t_type == 'lvsb':
        for subdir in shared_subdirs:
            shared_test_list += data_dir.SubdirGlobList(subdir,
                                                        '*.py',
                                                        test_filter)
            shared_file_list += data_dir.SubdirGlobList(subdir,
                                                        '*.cfg',
                                                        config_filter)

    all_specific_test_list = []
    for test in specific_test_list:
        for p in provider_info_specific:
            provider_base_path = p['backends'][t_type]['path']
            if provider_base_path in test:
                provider_name = p['name']
                break

        basename = os.path.basename(test)
        if basename != "__init__.py":
            all_specific_test_list.append("%s.%s" %
                                          (provider_name,
                                           basename.split(".")[0]))
    all_shared_test_list = []
    for test in shared_test_list:
        for p in provider_info_shared:
            provider_base_path = p['backends']['generic']['path']
            if provider_base_path in test:
                provider_name = p['name']
                break

        basename = os.path.basename(test)
        if basename != "__init__.py":
            all_shared_test_list.append("%s.%s" %
                                        (provider_name,
                                         basename.split(".")[0]))

    all_specific_test_list.sort()
    all_shared_test_list.sort()

    first_subtest_file = []
    last_subtest_file = []
    non_dropin_tests = []
    tmp = []

    for shared_file in shared_file_list:
        provider_name = None
        for p in provider_info_shared:
            provider_base_path = p['backends']['generic']['path']
            if provider_base_path in shared_file:
                provider_name = p['name']
                break

        shared_file_obj = open(shared_file, 'r')
        for line in shared_file_obj.readlines():
            line = line.strip()
            if line.startswith("type"):
                cartesian_parser = cartesian_config.Parser()
                cartesian_parser.parse_string(line)
                td = cartesian_parser.get_dicts().next()
                values = td['type'].split(" ")
                for value in values:
                    if value not in non_dropin_tests:
                        non_dropin_tests.append("%s.%s" %
                                                (provider_name, value))

        shared_file_name = os.path.basename(shared_file)
        shared_file_name = shared_file_name.split(".")[0]
        if shared_file_name in first_subtest[t_type]:
            if [provider_name, shared_file] not in first_subtest_file:
                first_subtest_file.append([provider_name, shared_file])
        elif shared_file_name in last_subtest[t_type]:
            if [provider_name, shared_file] not in last_subtest_file:
                last_subtest_file.append([provider_name, shared_file])
        else:
            if [provider_name, shared_file] not in tmp:
                tmp.append([provider_name, shared_file])
    shared_file_list = tmp

    tmp = []
    for shared_file in specific_file_list:
        provider_name = None
        for p in provider_info_specific:
            provider_base_path = p['backends'][t_type]['path']
            if provider_base_path in shared_file:
                provider_name = p['name']
                break

        shared_file_obj = open(shared_file, 'r')
        for line in shared_file_obj.readlines():
            line = line.strip()
            if line.startswith("type"):
                cartesian_parser = cartesian_config.Parser()
                cartesian_parser.parse_string(line)
                td = cartesian_parser.get_dicts().next()
                values = td['type'].split(" ")
                for value in values:
                    if value not in non_dropin_tests:
                        non_dropin_tests.append("%s.%s" %
                                                (provider_name, value))

        shared_file_name = os.path.basename(shared_file)
        shared_file_name = shared_file_name.split(".")[0]
        if shared_file_name in first_subtest[t_type]:
            if [provider_name, shared_file] not in first_subtest_file:
                first_subtest_file.append([provider_name, shared_file])
        elif shared_file_name in last_subtest[t_type]:
            if [provider_name, shared_file] not in last_subtest_file:
                last_subtest_file.append([provider_name, shared_file])
        else:
            if [provider_name, shared_file] not in tmp:
                tmp.append([provider_name, shared_file])
    specific_file_list = tmp

    subtests_cfg = os.path.join(data_dir.get_backend_dir(t_type), 'cfg',
                                'subtests.cfg')
    subtests_file = open(subtests_cfg, 'w')
    subtests_file.write(
        "# Do not edit, auto generated file from subtests config\n")

    subtests_file.write("variants subtest:\n")
    write_subtests_files(first_subtest_file, subtests_file)
    write_subtests_files(specific_file_list, subtests_file, t_type)
    write_subtests_files(shared_file_list, subtests_file)
    write_subtests_files(last_subtest_file, subtests_file)

    subtests_file.close()
    logging.debug("Config file %s auto generated from subtest samples",
                  subtests_cfg)

Example 3

Project: radical.pilot
Source File: staging_directives.py
def expand_staging_directive(staging_directive):
    """Take an abbreviated or compressed staging directive and expand it.

    """

    # Convert single entries into a list
    if not isinstance(staging_directive, list):
        staging_directive = [staging_directive]

    # Use this to collect the return value
    new_staging_directive = []

    # We loop over the list of staging directives
    for sd in staging_directive:

        if isinstance(sd, basestring):

            # We detected a string, convert into dict.  The interpretation
            # differs depending of redirection characters being present in the
            # string.

            append = False
            if '>>' in sd:
                src, tgt = sd.split('>>', 1)
                append = True
            elif '>' in sd:
                src, tgt = sd.split('>', 1)
                append = False
            elif '<<' in sd:
                tgt, src = sd.split('<<', 1)
                append = True
            elif '<' in sd:
                tgt, src = sd.split('<', 1)
                append = False
            else:
                src, tgt = sd, os.path.basename(sd)
                append = False

            if append:
                logger.warn("append mode on staging not supported (ignored)")

            new_sd = {'source':   src.strip(),
                      'target':   tgt.strip(),
                      'action':   DEFAULT_ACTION,
                      'flags':    DEFAULT_FLAGS,
                      'priority': DEFAULT_PRIORITY
            }
            logger.debug("Converting string '%s' into dict '%s'" % (sd, new_sd))
            new_staging_directive.append(new_sd)

        elif isinstance(sd, dict):
            # We detected a dict, will have to distinguish between single and multiple entries

            if 'action' in sd:
                action = sd['action']
            else:
                action = DEFAULT_ACTION

            if 'flags' in sd:
                flags = sd['flags']
            else:
                flags = DEFAULT_FLAGS

            if 'priority' in sd:
                priority = sd['priority']
            else:
                priority = DEFAULT_PRIORITY

            if not 'source' in sd:
                raise Exception("Staging directive dict has no source member!")
            source = sd['source']

            if 'target' in sd:
                target = sd['target']
            else:
                # Set target to None, as inferring it depends on the type of source
                target = None

            if isinstance(source, basestring) or isinstance(source, saga.Url):

                if target:
                    # Detect asymmetry in source and target length
                    if isinstance(target, list):
                        raise Exception("Source is singular but target is a list")
                else:
                    # We had no target specified, assume the basename of source
                    if isinstance(source, basestring):
                        target = os.path.basename(source)
                    elif isinstance(source, saga.Url):
                        target = os.path.basename(source.path)
                    else:
                        raise Exception("Source %s is neither a string nor a Url (%s)!" %
                                        (source, type(source)))

                # This is a regular entry, complete and append it
                new_sd = {'source':   source,
                          'target':   target,
                          'action':   action,
                          'flags':    flags,
                          'priority': priority,
                }
                new_staging_directive.append(new_sd)
                logger.debug("Completing entry '%s'" % new_sd)

            elif isinstance(source, list):
                # We detected a list of sources, we need to expand it

                # We will break up the list entries in source into an equal length list of dicts
                new_sds = []

                if target:
                    # Target is also specified, make sure it is a list of equal length

                    if not isinstance(target, list):
                        raise Exception("Both source and target are specified, but target is not a list")

                    if len(source) != len(target):
                        raise Exception("Source (%d) and target (%d) are lists of different length" % (len(source), len(target)))

                    # Now that we have established that the list are of equal size we can combine them
                    for src_entry, tgt_entry in zip(source, target):

                        new_sd = {'source':   src_entry,
                                  'target':   tgt_entry,
                                  'action':   action,
                                  'flags':    flags,
                                  'priority': priority
                        }
                        new_sds.append(new_sd)
                else:
                    # Target is not specified, use the source for the target too.

                    # Go over all entries in the list and create an equal length list of dicts.
                    for src_entry in source:

                        if isinstance(src_entry, basestring):
                            target = os.path.basename(src_entry)
                        elif isinstance(src_entry, saga.Url):
                            target = os.path.basename(src_entry.path)
                        else:
                            raise Exception("Source %s is neither a string nor a Url (%s)!" %
                                            (src_entry, type(src_entry)))

                        new_sd = {'source':   src_entry,
                                  'target':   target,
                                  'action':   action,
                                  'flags':    flags,
                                  'priority': priority
                        }
                        new_sds.append(new_sd)

                logger.debug("Converting list '%s' into dicts '%s'" % (source, new_sds))

                # Add the content of the local list to global list
                new_staging_directive.extend(new_sds)

            else:
                raise Exception("Source %s is neither an entry nor a list (%s)!" %
                                (source, type(source)))

        else:
            raise Exception("Unknown type of staging directive: %s (%s)" % (sd, type(sd)))

    return new_staging_directive
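
Here os.path.basename supplies the fallback for omitted targets: a staged file defaults to its own name. A minimal sketch of that fallback in isolation (hypothetical helper, not part of radical.pilot):

import os.path

def default_target(source, target=None):
    # When no explicit target is given, stage the file under its basename,
    # mirroring the fallback branches above.
    return target if target is not None else os.path.basename(source)

print(default_target('/data/input/run42.dat'))            # 'run42.dat'
print(default_target('/data/input/run42.dat', 'in.dat'))  # 'in.dat'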

Example 4

Project: entropy
Source File: transceivers.py
    def _transceive(self, uri):

        fine = set()
        broken = set()
        fail = False
        crippled_uri = EntropyTransceiver.get_uri_name(uri)
        action = 'push'
        if self.download:
            action = 'pull'
        elif self.remove:
            action = 'remove'

        try:
            txc = EntropyTransceiver(uri)
            if const_isnumber(self.speed_limit):
                txc.set_speed_limit(self.speed_limit)
            txc.set_output_interface(self._entropy)
        except TransceiverConnectionError:
            print_traceback()
            return True, fine, broken # issues

        maxcount = len(self.myfiles)
        counter = 0

        with txc as handler:

            for mypath in self.myfiles:

                base_dir = self.txc_basedir

                if isinstance(mypath, tuple):
                    if len(mypath) < 2:
                        continue
                    base_dir, mypath = mypath

                if not handler.is_dir(base_dir):
                    handler.makedirs(base_dir)

                mypath_fn = os.path.basename(mypath)
                remote_path = os.path.join(base_dir, mypath_fn)

                syncer = handler.upload
                myargs = (mypath, remote_path)
                if self.download:
                    syncer = handler.download
                    local_path = os.path.join(self.local_basedir, mypath_fn)
                    myargs = (remote_path, local_path)
                elif self.remove:
                    syncer = handler.delete
                    myargs = (remote_path,)

                fallback_syncer, fallback_args = None, None
                # upload -> remote copy heuristic support:
                # if a package file might have already been uploaded
                # to the remote mirror, look in other repositories'
                # package directories for a file with the same md5 and
                # name. If one is available, use remote copy instead
                # of upload to save bandwidth.
                if self._copy_herustic and (syncer == handler.upload):
                    # copy heuristic support enabled
                    # we are uploading
                    new_syncer, new_args = self._copy_herustic_support(
                        handler, mypath, base_dir, remote_path)
                    if new_syncer is not None:
                        fallback_syncer, fallback_args = syncer, myargs
                        syncer, myargs = new_syncer, new_args
                        action = "copy"

                counter += 1
                tries = 0
                done = False
                lastrc = None

                while tries < 5:
                    tries += 1
                    self._entropy.output(
                        "[%s|#%s|(%s/%s)] %s: %s" % (
                            blue(crippled_uri),
                            darkgreen(str(tries)),
                            blue(str(counter)),
                            bold(str(maxcount)),
                            blue(action),
                            red(os.path.basename(mypath)),
                        ),
                        importance = 0,
                        level = "info",
                        header = red(" @@ ")
                    )
                    rc = syncer(*myargs)
                    if (not rc) and (fallback_syncer is not None):
                        # if we have a fallback syncer, try it first
                        # before giving up.
                        rc = fallback_syncer(*myargs)

                    if rc and not (self.download or self.remove):
                        remote_md5 = handler.get_md5(remote_path)
                        rc = self.handler_verify_upload(mypath, uri,
                            counter, maxcount, tries, remote_md5 = remote_md5)
                    if rc:
                        self._entropy.output(
                            "[%s|#%s|(%s/%s)] %s %s: %s" % (
                                        blue(crippled_uri),
                                        darkgreen(str(tries)),
                                        blue(str(counter)),
                                        bold(str(maxcount)),
                                        blue(action),
                                        _("successful"),
                                        red(os.path.basename(mypath)),
                            ),
                            importance = 0,
                            level = "info",
                            header = darkgreen(" @@ ")
                        )
                        done = True
                        fine.add(uri)
                        break
                    else:
                        self._entropy.output(
                            "[%s|#%s|(%s/%s)] %s %s: %s" % (
                                        blue(crippled_uri),
                                        darkgreen(str(tries)),
                                        blue(str(counter)),
                                        bold(str(maxcount)),
                                        blue(action),
                                        brown(_("failed, retrying")),
                                        red(os.path.basename(mypath)),
                                ),
                            importance = 0,
                            level = "warning",
                            header = brown(" @@ ")
                        )
                        lastrc = rc
                        continue

                if not done:

                    self._entropy.output(
                        "[%s|(%s/%s)] %s %s: %s - %s: %s" % (
                                blue(crippled_uri),
                                blue(str(counter)),
                                bold(str(maxcount)),
                                blue(action),
                                darkred("failed, giving up"),
                                red(os.path.basename(mypath)),
                                _("error"),
                                lastrc,
                        ),
                        importance = 1,
                        level = "error",
                        header = darkred(" !!! ")
                    )

                    if mypath not in self.critical_files:
                        self._entropy.output(
                            "[%s|(%s/%s)] %s: %s, %s..." % (
                                blue(crippled_uri),
                                blue(str(counter)),
                                bold(str(maxcount)),
                                blue(_("not critical")),
                                os.path.basename(mypath),
                                blue(_("continuing")),
                            ),
                            importance = 1,
                            level = "warning",
                            header = brown(" @@ ")
                        )
                        continue

                    fail = True
                    broken.add((uri, lastrc))
                    # next mirror
                    break

        return fail, fine, broken
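
The basename usage in this example is the mirror-path pattern: a local file keeps its own name when placed under a remote base directory. In isolation, with hypothetical paths:

import os.path

base_dir = '/mirror/packages'             # hypothetical remote base directory
mypath = '/var/tmp/entropy/foo-1.0.tbz2'  # hypothetical local package file
remote_path = os.path.join(base_dir, os.path.basename(mypath))
print(remote_path)  # '/mirror/packages/foo-1.0.tbz2'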

Example 5

Project: InterMol
Source File: convert.py
def main(args=None):
    logger.info('Beginning InterMol conversion')
    if not args:
        args = vars(parse_args(args))

    if args.get('gromacs_path'):
        gmx.GMX_PATH = args['gromacs_path']
    if args.get('lammps_path'):
        lmp.LMP_PATH = args['lammps_path']
    if args.get('desmond_path'):
        des.DES_PATH = args['desmond_path']
    if args.get('amber_path'):
        amb.AMB_PATH = args['amber_path']
    if args.get('charmm_path'):
        crm.CRM_PATH = args['charmm_path']

    if args.get('verbose'):
        h.setLevel(logging.DEBUG)

    # Print warnings.
    warnings.simplefilter("always")
    if not args.get('force'):
        # Warnings will be treated as exceptions unless force flag is used.
        warnings.simplefilter("error")

    # --------------- PROCESS INPUTS ----------------- #
    if args.get('gro_in'):
        gromacs_files = args['gro_in']

        prefix = os.path.splitext(os.path.basename(gromacs_files[0]))[0]
        # Find the top file since order of inputs is not enforced.
        top_in = [x for x in gromacs_files if x.endswith('.top')]
        assert(len(top_in) == 1)
        top_in = os.path.abspath(top_in[0])

        # Find the gro file since order of inputs is not enforced.
        gro_in = [x for x in gromacs_files if x.endswith('.gro')]
        assert(len(gro_in) == 1)
        gro_in = os.path.abspath(gro_in[0])
        system = gmx.load(top_in, gro_in)

    elif args.get('des_in'):
        cms_file = args['des_in']
        prefix = os.path.splitext(os.path.basename(cms_file))[0]
        system = des.load(cms_file=cms_file)

    elif args.get('lmp_in'):
        lammps_file = args['lmp_in']
        prefix = os.path.splitext(os.path.basename(lammps_file))[0]
        system = lmp.load(in_file=lammps_file)

    elif args.get('amb_in'):
        amber_files = args['amb_in']
        prefix = os.path.splitext(os.path.basename(amber_files[0]))[0]

        # Find the prmtop file since order of inputs is not enforced.
        prmtop_in = [x for x in amber_files if x.endswith('.prmtop')]
        assert(len(prmtop_in) == 1)
        prmtop_in = os.path.abspath(prmtop_in[0])

        # Find the crd file since order of inputs is not enforced, nor is the suffix
        crd_in = [x for x in amber_files if (x.endswith('.rst7') or x.endswith('.crd') or x.endswith('.rst') or x.endswith('.inpcrd'))]
        assert(len(crd_in) == 1)
        crd_in = os.path.abspath(crd_in[0])

        structure = parmed.amber.AmberParm(prmtop_in,crd_in)
        #Make GROMACS topology
        parmed_system = parmed.gromacs.GromacsTopologyFile.from_structure(structure)

        # write out the files.  Should write them out in the proper directory (the one reading in)
        pathprefix = os.path.dirname(prmtop_in)
        fromamber_top_in = os.path.join(pathprefix, prefix + '_from_amber.top')
        fromamber_gro_in = os.path.join(pathprefix, prefix + '_from_amber.gro')
        parmed.gromacs.GromacsTopologyFile.write(parmed_system, fromamber_top_in)
        parmed.gromacs.GromacsGroFile.write(parmed_system, fromamber_gro_in, precision = 8)

        # now, read in using gromacs
        system = gmx.load(fromamber_top_in, fromamber_gro_in)

    elif args.get('crm_in'):

        #logger.error('CHARMM can\'t currently be used as an input file type for conversions')
        #sys.exit(1)    

        charmm_input_file = args['crm_in']
        prefix = os.path.splitext(os.path.basename(charmm_input_file))[0]
        # we need to find the parameter and structure files by reading the input file.
        with open(charmm_input_file) as cinf:
            lines = cinf.readlines()
        box = []
        rtfs = []
        prms = []
        strms = []
        topsuffixes = ['.rtf','.top']
        prmsuffixes = ['.prm','.par']

        for line in lines:
            # Find the psf file; the path is the last token on the line
            if '.psf' in line:
                psffile = line.split()[-1]
            if '.crd' in line:
                crdfile = line.split()[-1]  # coordinate file at the end of the line
            if '.str' in line:
                strms.append(line.split()[-1])  # stream file at the end of the line
            # topology files (any of the topology suffixes) need to be read in
            # before the prms
            if any(x in line for x in topsuffixes):
                rtfs.append(line.split()[-1])
            # parameter files (any of the parameter suffixes)
            if any(x in line for x in prmsuffixes):
                prms.append(line.split()[-1])
            if 'set box ' in line:   # will need to handle general variables
                boxlength_vars = line.split()
                boxval = np.float(boxlength_vars[2])
            if 'crystal define' in line:
                boxangle_vars = line.split()
                boxtype = boxangle_vars[2]
                boxvecs = boxangle_vars[3:6]
                for i, b in enumerate(boxvecs):
                    if '@box' in b:
                        boxvecs[i] = boxval
                    else:
                        boxvecs[i] = np.float(b)
                boxangles = np.array(boxangle_vars[6:9],float)
                box = np.append(np.array(boxvecs),boxangles)

        #load in the parameters
        psf = parmed.load_file(psffile)
        parameterset = parmed.charmm.CharmmParameterSet()
        for tfile in rtfs:
            parameterset.read_topology_file(tfile)
        for pfile in prms:
            parameterset.read_parameter_file(pfile)
        for sfile in strms:
            parameterset.read_stream_file(sfile)
        psf.load_parameters(parameterset)

        if len(box) == 6:
            psf.box = box
        # now load in the coordinates
        crd = parmed.load_file(crdfile)
        try:
            if len(crd.coordinates.shape) == 3:
                coords = crd.coordinates[0]
            else:
                coords = crd.coordinates
        except:
            logger.error('No coordinates in %s' % (crd))

        if coords.shape != (len(psf.atoms), 3):
            logger.error('Mismatch in number of coordinates (%d) and '
                '3*number of atoms (%d)' % (len(coords), 3*len(psf.atoms)))
        # Set the coordinates now, since creating the parm may re-order the
        # atoms in order to maintain contiguous molecules
        psf.coordinates = coords

        # copy the box over to the .psf
        if hasattr(crd, 'box') and crd.box is not None:
            if len(crd.box.shape) == 1:
                crdbox = crd.box
            else:
                # Trajectory
                crdbox = crd.box[0]

            if len(crdbox) == 3:
                psf.box = list(crdbox) + [90.0, 90.0, 90.0]
            elif len(crdbox) == 6:
                psf.box = list(crdbox)
            else:
                logger.error('Unexpected box array shape')

        #Make GROMACS topology
        parmed_system = parmed.gromacs.GromacsTopologyFile.from_structure(psf)

        # write out the files.  Should write them out in the proper directory (the one reading in)
        pathprefix = os.path.dirname(charmm_input_file)
        fromcharmm_top_in = os.path.join(pathprefix, prefix + '_from_charmm.top')
        fromcharmm_gro_in = os.path.join(pathprefix,prefix + '_from_charmm.gro')
        parmed.gromacs.GromacsTopologyFile.write(parmed_system, fromcharmm_top_in)
        parmed.gromacs.GromacsGroFile.write(parmed_system, fromcharmm_gro_in, precision = 8)

        # now, read in using gromacs
        system = gmx.load(fromcharmm_top_in, fromcharmm_gro_in)
    
    else:
        logger.error('No input file')
        sys.exit(1)

    # --------------- WRITE OUTPUTS ----------------- #
    if not args.get('oname'):
        oname = '{0}_converted'.format(prefix)
    else:
        oname = args['oname']
    oname = os.path.abspath(os.path.join(args['odir'], oname))  # Prepend output directory to oname.

    output_status = dict()
    # TODO: factor out exception handling
    if args.get('gromacs'):
        try:
            gmx.save('{0}.top'.format(oname), '{0}.gro'.format(oname), system)
        except Exception as e:
            logger.exception(e)
            output_status['gromacs'] = e
        else:
            output_status['gromacs'] = 'Converted'

    if args.get('lammps'):
        try:
            lmp.save('{0}.input'.format(oname), system, nonbonded_style=args.get('lmp_settings'))
        except Exception as e:
            logger.exception(e)
            output_status['lammps'] = e
        else:
            output_status['lammps'] = 'Converted'

    if args.get('desmond'):
        try:
            des.save('{0}.cms'.format(oname), system)
        except Exception as e:
            logger.exception(e)
            output_status['desmond'] = e
        else:
            output_status['desmond'] = 'Converted'

    if args.get('amber'):
        # NOTE: Although in theory this should work fine, the gromacs
        # output that InterMol produces include rb_torsions, which
        # AMBER can't handle. They are EQUIVALENT to non-rb torsion
        # parameters, but rb torsions are the preferred intermediate
        # in GROMACS because it can handle as special cases all of
        # different versions of torsions used in the supported
        # programs so far.

        # first, check if the gro and top files exist from writing
        gro_out = oname + '.gro'
        top_out = oname + '.top'
        top = None
        e = None
        if os.path.isfile(gro_out) and os.path.isfile(top_out):
            # if so, use these files.  Load them into ParmEd
            try:
                top = parmed.load_file(top_out, xyz=gro_out)
                prmtop_out = oname + '.prmtop'
                crd_out = oname + '.rst7'
                try:
                    top.save(oname + '.prmtop', overwrite=True)
                except Exception as e:
                    output_status['amber'] = e
                try:        
                    top.save(oname + '.rst7', overwrite=True)
                except Exception as e:
                    output_status['amber'] = e
                if e is None:
                    output_status['amber'] = 'Converted'
            except Exception as e:
                output_status['amber'] = e
        else:
            logger.warn("Can't convert to AMBER unless GROMACS is also selected")

    if args.get('charmm'):
        # currently, this only works if amb_in is used. Reason is that
        # charmm does not support RB dihedrals.
        if args.get('amb_in'):
            e = None
            # if so, use the structure object.
            try:
                parmed_system = parmed.charmm.CharmmPsfFile.from_structure(structure)
                charmm_output_psf = '{0}.psf'.format(oname)
                charmm_output_rtf = '{0}.rtf'.format(oname)
                charmm_output_prm = '{0}.prm'.format(oname)
                charmm_output_crd = '{0}.crd'.format(oname)
                # we need these arrays for energy output
                prms = [charmm_output_prm]
                rtfs = [charmm_output_rtf]
                parmed.charmm.CharmmParameterSet.write(
                    parmed.charmm.CharmmParameterSet.from_structure(structure),
                    top=charmm_output_rtf,
                    par=charmm_output_prm)
                structure.save(charmm_output_psf, format='psf',overwrite=True)
                parmed.charmm.CharmmCrdFile.write(parmed_system, charmm_output_crd)
                
            except Exception as e:
                output_status['charmm'] = e
            if e is None:
                output_status['charmm'] = 'Converted'
        else: 
            logger.warn("Can't convert to CHARMM unless inputs are in AMBER")

    # --------------- ENERGY EVALUATION ----------------- #

    if args.get('energy'):
        # Run control file paths.
        tests_path = os.path.abspath(os.path.dirname(intermol.tests.__file__))

        # default locations of setting control files

        mdp_in_default = os.path.abspath(os.path.join(tests_path, 'gromacs', 'grompp.mdp'))
        cfg_in_default = os.path.abspath(os.path.join(tests_path, 'desmond', 'onepoint.cfg'))
        in_in_default = os.path.abspath(os.path.join(tests_path, 'amber', 'min.in'))

        # this hardcoding of which input energy files to use should not be here; it should be in the testing files.
        # Evaluate input energies.
        if args.get('gro_in'):
            if args.get('inefile'):
                if os.path.splitext(args.get('inefile'))[-1] != '.mdp':
                    logger.warn("GROMACS energy settings file does not end with .mdp")
                mdp_in = args['inefile']
            else:
                mdp_in = mdp_in_default
            input_type = 'gromacs'
            e_in, e_infile = gmx.energies(top_in, gro_in, mdp_in, gmx.GMX_PATH)

        elif args.get('lmp_in'):
            if args.get('inefile'):
                logger.warn("LAMMPS energy settings should not require a separate infile")
            e_in, e_infile = lmp.energies(lammps_file, lmp.LMP_PATH)

        elif args.get('des_in'):
            if args.get('inefile'):
                if os.path.splitext(args.get('inefile'))[-1] != '.cfg':
                    logger.warn("DESMOND energy settings file does not end with .cfg")
                cfg_in = args['inefile']
            else:
                cfg_in = cfg_in_default
            input_type = 'desmond'
            e_in, e_infile = des.energies(cms_file, cfg_in, des.DES_PATH)

        elif args.get('amb_in'):
            if args.get('inefile'):
                if os.path.splitext(args.get('inefile'))[-1] != '.in':
                    logger.warn("AMBER energy settings file does not end with .in")
                in_in = args['inefile']
            else:
                in_in = in_in_default
            input_type = 'amber'
            e_in, e_infile = amb.energies(prmtop_in, crd_in, in_in, amb.AMB_PATH)

        elif args.get('crm_in'):
            if args.get('inefile'):
                logger.warn("Original CHARMM input file is being used, not the supplied input file")
            input_type = 'charmm'
            # returns energy file
            e_in, e_infile = crm.energies(args.get('crm_in'), crm.CRM_PATH)
        else:
            logger.warn('No input files identified! Code should have never made it here!')

        # Evaluate output energies.
        output_type = []
        e_outfile = []
        e_out = []

        if args.get('gromacs') and output_status['gromacs'] == 'Converted':
            output_type.append('gromacs')
            if args.get('gromacs_set'):
                mdp = args.get('gromacs_set')
            else:
                mdp = mdp_in_default
            try:
                out, outfile = gmx.energies('{0}.top'.format(oname),
                                            '{0}.gro'.format(oname),
                                            mdp, gmx.GMX_PATH)
            except Exception as e:
                record_exception(logger, e_out, e_outfile, e)
                output_status['gromacs'] = e
            else:
                output_status['gromacs'] = potential_energy_diff(e_in, out)
                e_out.append(out)
                e_outfile.append(outfile)

        if args.get('lammps') and output_status['lammps'] == 'Converted':
            output_type.append('lammps')
            try:
                out, outfile = lmp.energies('{0}.input'.format(oname),
                                            lmp.LMP_PATH)
            except Exception as e:
                record_exception(logger, e_out, e_outfile, e)
                output_status['lammps'] = e
            else:
                output_status['lammps'] = potential_energy_diff(e_in, out)
                e_out.append(out)
                e_outfile.append(outfile)

        if args.get('desmond') and output_status['desmond'] == 'Converted':
            output_type.append('desmond')
            if args.get('desmond_set'):
                cfg = args.get('desmond_set')
            else:
                cfg = cfg_in_default
            try:
                out, outfile = des.energies('{0}.cms'.format(oname),
                                            cfg, des.DES_PATH)
            except Exception as e:
                record_exception(logger, e_out, e_outfile, e)
                output_status['desmond'] = e
            else:
                output_status['desmond'] = potential_energy_diff(e_in, out)
                e_out.append(out)
                e_outfile.append(outfile)

        if args.get('charmm') and output_status['charmm'] == 'Converted':
            output_type.append('charmm')
            if args.get('inefile'):
                logger.warn("CHARMM energy input file not used, information recreated from command line options")
            inpfile = os.path.join(oname,'{0}.inp'.format(oname))
            crm.write_input_file(inpfile, charmm_output_psf, rtfs, prms, [],
                                 crm.pick_crystal_type(structure.box),
                                 structure.box, charmm_output_crd,
                                 args.get('charmm_settings'))
            try:
                out, outfile = crm.energies(inpfile, crm.CRM_PATH)
            except Exception as e:
                record_exception(logger, e_out, e_outfile, e)
                output_status['charmm'] = e
            else:
                output_status['charmm'] = potential_energy_diff(e_in, out)
                e_out.append(out)
                e_outfile.append(outfile)


        # Display energy comparison results.
        out = ['InterMol Conversion Energy Comparison Results', '',
               '{0} input energy file: {1}'.format(input_type, e_infile)]
        for out_type, file in zip(output_type, e_outfile):
            out.append('{0} output energy file: {1}'.format(out_type, file))
        out += summarize_energy_results(e_in, e_out, input_type, output_type, args.get('noncanonical'))
        logger.info('\n'.join(out))

    logger.info('Finished!')
    return output_status
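
One idiom recurs throughout this example: deriving an output prefix from an input filename via os.path.splitext(os.path.basename(path))[0]. A minimal sketch with a hypothetical input path:

import os.path

path = '/sims/water/topol.top'  # hypothetical input file
prefix = os.path.splitext(os.path.basename(path))[0]
print(prefix)                          # 'topol'
print('{0}_converted'.format(prefix))  # 'topol_converted', the default oname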

Example 6

Project: aamporter
Source File: aamporter.py
def main():
    usage = """

%prog [options] path/to/plist [path/to/more/plists..]
%prog --build-product-plist [path/to/CCP/pkg/file.ccp] [--munki-update-for BaseProductPkginfoName]

The first form will check and cache updates for the channels listed in the product plists
given as arguments.

The second form will generate a product plist containing all channel IDs contained in the
installer metadata. Accepts either a path to a .ccp file (from Creative Cloud Packager) or
a mounted ESD volume path for CS6-and-earlier installers.

See %prog --help for more options and the README for more detail."""

    o = optparse.OptionParser(usage=usage)
    o.add_option("-l", "--platform", type='choice', choices=['mac', 'win'], default='mac',
        help="Download Adobe updates for Mac or Windows. Available options are 'mac' or 'win', defaults to 'mac'.")
    o.add_option("-m", "--munkiimport", action="store_true", default=False,
        help="Process downloaded updates with munkiimport using options defined in %s." % os.path.basename(settings_plist))
    o.add_option("-r", "--include-revoked", action="store_true", default=False,
        help="Include updates that have been marked as revoked in Adobe's feed XML.")
    o.add_option("--skip-cc", action="store_true", default=False,
        help=("Skip updates for Creative Cloud updates. Useful for certain updates for "
              "CS-era applications that incorporate CC subscription updates."))
    o.add_option("-f", "--force-import", action="store_true", default=False,
        help="Run munkiimport even if it finds an identical pkginfo and installer_item_hash in the repo.")
    o.add_option("-c", "--make-catalogs", action="store_true", default=False,
        help="Automatically run makecatalogs after importing into Munki.")
    o.add_option("-p", "--product-plist", "--plist", action="append", default=[],
        help="Deprecated option for specifying product plists, kept for compatibility. Instead, pass plist paths \
as arguments.")
    o.add_option("-b", "--build-product-plist", action="store",
        help="Given a path to either a mounted Adobe product ESD installer or a .ccp file from a package built with CCP, \
save a product plist containing every Channel ID found for the product. Plist is saved to the current working directory.")
    o.add_option("-u", "--munki-update-for", action="store",
        help="To be used with the --build-product-plist option, specifies the base Munki product.")
    o.add_option("-v", "--verbose", action="count", default=0,
        help="Output verbosity. Can be specified either '-v' or '-vv'.")
    o.add_option("--no-colors", action="store_true", default=False,
        help="Disable colored ANSI output.")
    o.add_option("--no-progressbar", action="store_true", default=False,
        help="Disable the progress indicator.")

    opts, args = o.parse_args()

    # setup logging
    global L
    L = logging.getLogger('com.github.aamporter')
    log_stdout_handler = logging.StreamHandler(stream=sys.stdout)
    log_stdout_handler.setFormatter(ColorFormatter(
        use_color=not opts.no_colors))
    L.addHandler(log_stdout_handler)
    # INFO is level 30, so each verbose option count lowers level by 10
    L.setLevel(INFO - (10 * opts.verbose))

    # arg/opt processing
    if len(sys.argv) == 1:
        o.print_usage()
        sys.exit(0)

    # any args we just pass through to the "legacy" --product-plist/--plist options
    if args:
        opts.product_plist.extend(args)
    if opts.munki_update_for and not opts.build_product_plist:
        errorExit("--munki-update-for requires the --build-product-plist option!")
    if not opts.build_product_plist and not opts.product_plist:
        errorExit("One of --product-plist or --build-product-plist must be specified!")
    if opts.platform == 'win' and opts.munkiimport:
        errorExit("Cannot use the --munkiimport option with --platform win option!")

    if opts.build_product_plist:
        esd_path = opts.build_product_plist
        if esd_path.endswith('/'):
            esd_path = esd_path[0:-1]
        plist = buildProductPlist(esd_path, opts.munki_update_for)
        if not plist:
            errorExit("Couldn't build payloads from path %s." % esd_path)
        else:
            if opts.munki_update_for:
                output_plist_name = opts.munki_update_for
            else:
                output_plist_name = os.path.basename(esd_path.replace(' ', ''))
            output_plist_name += '.plist'
            output_plist_file = os.path.join(os.getcwd(), output_plist_name)
            if os.path.exists(output_plist_file):
                errorExit("A file already exists at %s, not going to overwrite." %
                    output_plist_file)
            try:
                plistlib.writePlist(plist, output_plist_file)
            except:
                errorExit("Error writing plist to %s" % output_plist_file)
            print "Product plist written to %s" % output_plist_file
            sys.exit(0)

    # munki sanity checks
    if opts.munkiimport:
        if not os.path.exists('/usr/local/munki'):
            errorExit("No Munki installation could be found. Get it at http://code.google.com/p/munki")
        sys.path.insert(0, MUNKI_DIR)
        munkiimport_prefs = os.path.expanduser('~/Library/Preferences/com.googlecode.munki.munkiimport.plist')
        if pref('munki_tool') == 'munkiimport':
            if not os.path.exists(munkiimport_prefs):
                errorExit("Your Munki repo seems to not be configured. Run munkiimport --configure first.")
            try:
                import imp
                # munkiimport doesn't end in .py, so we use imp to make it available to the import system
                imp.load_source('munkiimport', os.path.join(MUNKI_DIR, 'munkiimport'))
                import munkiimport
                munkiimport.REPO_PATH = munkiimport.pref('repo_path')
            except ImportError:
                errorExit("There was an error importing munkilib, which is needed for --munkiimport functionality.")

            # rewrite some of munkiimport's function names since they were changed to
            # snake case around 2.6.1:
            # https://github.com/munki/munki/commit/e3948104e869a6a5eb6b440559f4c57144922e71
            try:
                munkiimport.repoAvailable()
            except AttributeError:
                munkiimport.repoAvailable = munkiimport.repo_available
                munkiimport.makePkgInfo = munkiimport.make_pkginfo
                munkiimport.findMatchingPkginfo = munkiimport.find_matching_pkginfo
                munkiimport.makeCatalogs = munkiimport.make_catalogs
            if not munkiimport.repoAvailable():
                errorExit("The Munki repo cannot be located. This tool is not interactive; first ensure the repo is mounted.")

    # set up the cache path
    local_cache_path = pref('local_cache_path')
    if os.path.exists(local_cache_path) and not os.path.isdir(local_cache_path):
        errorExit("Local cache path %s was specified and exists, but it is not a directory!" %
            local_cache_path)
    elif not os.path.exists(local_cache_path):
        try:
            os.mkdir(local_cache_path)
        except OSError:
            errorExit("Local cache path %s could not be created. Verify permissions." %
                local_cache_path)
        except:
            errorExit("Unknown error creating local cache path %s." % local_cache_path)
    # os.access() returns a bool rather than raising, so check it directly
    if not os.access(local_cache_path, os.W_OK):
        errorExit("Cannot write to local cache path %s!" % local_cache_path)

    # load our product plists
    product_plists = []
    for plist_path in opts.product_plist:
        try:
            plist = plistlib.readPlist(plist_path)
        except:
            errorExit("Couldn't read plist at %s!" % plist_path)
        if 'channels' not in plist.keys():
            errorExit("Plist at %s is missing a 'channels' array, which is required." % plist_path)
        else:
            product_plists.append(plist)

    # sanity-check the settings plist for unknown keys
    if os.path.exists(settings_plist):
        try:
            app_options = plistlib.readPlist(settings_plist)
        except:
            errorExit("There was an error loading the settings plist at %s" % settings_plist)
        for k in app_options.keys():
            if k not in supported_settings_keys:
                print "Warning: Unknown setting in %s: %s" % (os.path.basename(settings_plist), k)

    L.log(INFO, "Starting aamporter run..")
    if opts.munkiimport:
        L.log(INFO, "Will import into Munki (--munkiimport option given).")

    L.log(DEBUG, "aamporter preferences:")
    for key in supported_settings_keys:
        L.log(DEBUG, " - {0}: {1}".format(key, pref(key)))

    if (sys.version_info.minor, sys.version_info.micro) == (7, 10):
        global NONSSL_ADOBE_URL
        NONSSL_ADOBE_URL = True
        L.log(VERBOSE, ("Python 2.7.10 detected, using HTTP feed URLs to work "
                        "around SSL issues."))

    # pull feed info and populate channels
    L.log(INFO, "Retrieving feed data..")
    feed = getFeedData(opts.platform)
    parsed = parseFeedData(feed)
    channels = getChannelsFromProductPlists(product_plists)
    L.log(INFO, "Processing the following Channel IDs:")
    for channel in sorted(channels):
        L.log(INFO, "  - %s" % channel)

    # begin caching run and build updates dictionary with product/version info
    updates = {}
    for channelid in channels.keys():
        L.log(VERBOSE, "Getting updates for Channel ID %s.." % channelid)
        channel_updates = getUpdatesForChannel(channelid, parsed)
        if not channel_updates:
            L.log(DEBUG, "No updates for channel %s" % channelid)
            continue
        channel_updates = addUpdatesXML(channel_updates, opts.platform, skipTargetLicensingCC=opts.skip_cc)

        for update in channel_updates:
            L.log(VERBOSE, "Considering update %s, %s.." % (update.product, update.version))

            if opts.include_revoked is False:
                highest_version = getHighestVersionOfProduct(channel_updates, update.product)
                if update.version != highest_version:
                    L.log(DEBUG, "%s is not the highest version available (%s) for this update. Skipping.." % (
                        update.version, highest_version))
                    continue

                if updateIsRevoked(update.channel, update.product, update.version, parsed):
                    L.log(DEBUG, "Update is revoked. Skipping update.")
                    continue

            file_element = update.xml.find('InstallFiles/File')
            if file_element is None:
                L.log(DEBUG, "No File XML element found. Skipping update.")
            else:
                filename = file_element.find('Name').text
                update_bytes = file_element.find('Size').text
                description = update.xml.find('Description/en_US').text
                display_name = update.xml.find('DisplayName/en_US').text

                if update.product not in updates.keys():
                    updates[update.product] = {}
                if update.version not in updates[update.product].keys():
                    updates[update.product][update.version] = {}
                    updates[update.product][update.version]['channel_ids'] = []
                    updates[update.product][update.version]['update_for'] = []
                updates[update.product][update.version]['channel_ids'].append(update.channel)
                for opt in ['munki_repo_destination_path',
                            'munki_update_for',
                            'makepkginfo_options']:
                    if opt in channels[update.channel].keys():
                        updates[update.product][update.version][opt] = channels[update.channel][opt]
                updates[update.product][update.version]['description'] = description
                updates[update.product][update.version]['display_name'] = display_name
                dmg_url = urljoin(getURL('updates'), UPDATE_PATH_PREFIX + opts.platform) + \
                        '/%s/%s/%s' % (update.product, update.version, filename)
                output_filename = os.path.join(local_cache_path, "%s-%s.%s" % (
                        update.product, update.version, 'dmg' if opts.platform == 'mac' else 'zip'))
                updates[update.product][update.version]['local_path'] = output_filename
                need_to_dl = True
                if os.path.exists(output_filename):
                    we_have_bytes = os.stat(output_filename).st_size
                    if we_have_bytes == int(update_bytes):
                        L.log(INFO, "Skipping download of %s %s, it is already cached."
                            % (update.product, update.version))
                        need_to_dl = False
                    else:
                        L.log(VERBOSE, "Incomplete download (%s bytes on disk, should be %s), re-starting." % (
                            we_have_bytes, update_bytes))
                if need_to_dl:
                    L.log(INFO, "Downloading %s %s (%s bytes) to %s" % (update.product, update.version, update_bytes, output_filename))
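                    # note: urllib.urlretrieve invokes reporthook(block_count,
                    # block_size, total_size) after each chunk, which is
                    # presumably what drives the progress bar here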
                    if opts.no_progressbar:
                        urllib.urlretrieve(dmg_url, output_filename)
                    else:
                        urllib.urlretrieve(dmg_url, output_filename, reporthook)

    L.log(INFO, "Done caching updates.")

    # begin munkiimport run
    if opts.munkiimport:
        L.log(INFO, "Beginning Munki imports..")
        for (update_name, update_meta) in updates.items():
            for (version_name, version_meta) in update_meta.items():
                need_to_import = True
                item_name = "%s%s" % (update_name.replace('-', '_'),
                    pref('munki_pkginfo_name_suffix'))
                # Do 'exists in repo' checks if we're not forcing imports
                if opts.force_import is False and pref("munki_tool") == "munkiimport":
                    pkginfo = munkiimport.makePkgInfo(['--name',
                                            item_name,
                                            version_meta['local_path']],
                                            False)
                    # Cribbed from munkiimport
                    L.log(VERBOSE, "Looking for a matching pkginfo for %s %s.." % (
                        item_name, version_name))
                    matchingpkginfo = munkiimport.findMatchingPkginfo(pkginfo)
                    if matchingpkginfo:
                        L.log(VERBOSE, "Got a matching pkginfo.")
                        if ('installer_item_hash' in matchingpkginfo and
                            matchingpkginfo['installer_item_hash'] ==
                            pkginfo.get('installer_item_hash')):
                            need_to_import = False
                            L.log(INFO,
                                ("We have an exact match for %s %s in the repo. Skipping.." % (
                                    item_name, version_name)))
                    else:
                        need_to_import = True

                if need_to_import:
                    munkiimport_opts = pref('munkiimport_options')[:]
                    if pref("munki_tool") == 'munkiimport':
                        if 'munki_repo_destination_path' in version_meta.keys():
                            subdir = version_meta['munki_repo_destination_path']
                        else:
                            subdir = pref('munki_repo_destination_path')
                        munkiimport_opts.append('--subdirectory')
                        munkiimport_opts.append(subdir)
                    if not version_meta.get('munki_update_for'):
                        L.log(WARNING,
                            "Warning: {0} does not have an 'update_for' key "
                            "specified in the product plist!".format(item_name))
                        update_catalogs = []
                    else:
                        # handle case of munki_update_for being either a list or a string
                        flatten = lambda *n: (e for a in n
                            for e in (flatten(*a) if isinstance(a, (tuple, list)) else (a,)))
                        update_catalogs = list(flatten(version_meta['munki_update_for']))
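                        # illustrative (made-up values): list(flatten('Photoshop'))
                        # and list(flatten(['Photoshop', 'Bridge'])) both yield
                        # flat lists of strings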
                        for base_product in update_catalogs:
                            munkiimport_opts.append('--update_for')
                            munkiimport_opts.append(base_product)
                    munkiimport_opts.extend(['--name', item_name,
                                             '--displayname', version_meta['display_name'],
                                             '--description', version_meta['description']])

                    if 'makepkginfo_options' in version_meta:
                        L.log(VERBOSE,
                            "Appending makepkginfo options: %s" %
                            " ".join(version_meta['makepkginfo_options']))
                        munkiimport_opts += version_meta['makepkginfo_options']

                    if pref('munki_tool') == 'munkiimport':
                        import_cmd = ['/usr/local/munki/munkiimport', '--nointeractive']
                    elif pref('munki_tool') == 'makepkginfo':
                        import_cmd = ['/usr/local/munki/makepkginfo']
                    else:
                        # TODO: validate this pref earlier
                        L.log(ERROR, "Not sure what tool you wanted to use; munki_tool should be "
                              "'munkiimport' or 'makepkginfo' but we got '%s'. Skipping import."
                              % pref('munki_tool'))
                        break
                    # Load our app munkiimport options overrides last
                    import_cmd += munkiimport_opts
                    import_cmd.append(version_meta['local_path'])

                    L.log(INFO, "Importing {0} {1} into Munki. Update for: {2}".format(
                        item_name, version_name, ', '.join(update_catalogs)))
                    L.log(VERBOSE, "Calling %s on %s version %s, file %s." % (
                        pref('munki_tool'),
                        update_name,
                        version_name,
                        version_meta['local_path']))
                    munkiprocess = subprocess.Popen(import_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                    # wait for the process to terminate
                    stdout, stderr = munkiprocess.communicate()
                    import_retcode = munkiprocess.returncode
                    if import_retcode:
                        L.log(ERROR, "munkiimport returned an error. Skipping update..")
                    else:
                        if pref('munki_tool') == 'makepkginfo':
                            plist_path = os.path.splitext(version_meta['local_path'])[0] + ".plist"
                            with open(plist_path, "w") as plist:
                                plist.write(stdout)
                                L.log(INFO, "pkginfo written to %s" % plist_path)


        L.log(INFO, "Done Munki imports.")
        if opts.make_catalogs:
            munkiimport.makeCatalogs()

Example 7

Project: ursgal
Source File: unify_csv_1_0_0.py
View license
def main(input_file=None, output_file=None, scan_rt_lookup=None,
         params=None, search_engine=None, score_colname=None,
         upeptide_mapper=None):
    '''
    Arguments:
        input_file (str): input filename of csv which should be unified
        output_file (str): output filename of csv after unifying
        scan_rt_lookup (dict): dictionary with entries of scanID to
            retention time under key 'scan_2_rt'
        upeptide_mapper: optional peptide mapper instance; if None a new
            ursgal.UPeptideMapper is created
        params (dict): params as passed by ursgal
        search_engine(str): the search engine the csv file stems from
        score_colname (str): the column names of the search engine's
            score (i.e. 'OMSSA:pvalue')

    List of fixes

    All engines
        * Retention Time (s) is correctly set using _ursgal_lookup.pkl.
          During mzML conversion to mgf, the retention time for every spectrum
          is stored in an internal lookup and used later for setting the RT.
        * All modifications are checked against params['modifications'],
          converted to the names given there and sorted by position.
        * Fixed modifications are added to 'Modifications' if not reported
          by the engine.
        * The monoisotopic m/z for each line is calculated (uCalc m/z),
          since not all engines report the monoisotopic m/z.
        * Mass accuracy calculation (in ppm), also taking into account that
          the monoisotopic peak is not always picked.
        * All peptide sequences are remapped to their corresponding proteins,
          ensuring correct start, stop, pre and post amino acids. Correct
          enzymatic cleavage is also checked.
        * Rows describing the same PSM (i.e. when two proteins share the
          same peptide) are merged into one row.

    X!Tandem
        * 'RTINSECONDS=' is stripped from Spectrum Title if present in .mgf or
          in search result.

    Myrimatch
        * Spectrum Title is corrected
        * 15N labels are not formatted correctly; these modifications are
          removed for further analysis.
        * When using 15N modifications on amino acids together with
          Carbamidomethyl, Myrimatch sometimes reports Carboxymethylation
          on Cysteine.

    MS-GF+
        * 15N labels are not formatted correctly; these modifications are
          removed for further analysis.
        * 'Is decoy' column is properly set to true/false.
        * Carbamidomethyl is updated and set if the label is 15N.

    OMSSA
        * Carbamidomethyl is updated and set
        * Selenocysteine is not reported with the correct unimod modification

    MS-Amanda
        * Selenocysteine is not reported with the correct unimod modification
        * Multiple protein IDs per peptide are split into separate entries
          (done in MS-Amanda postflight).
        * Short protein IDs are mapped to the full protein IDs; it is checked
          which peptides map to which protein ID (done in MS-Amanda
          postflight).

    '''
    print(
        '''
[ unifycsv ] Converting {0} of engine {1} to unified CSV format...
        '''.format(
            os.path.basename(input_file),
            search_engine,
        )
    )

    # get the rows which define a unique PSM (i.e. sequence+spec+score...)
    psm_defining_colnames = get_psm_defining_colnames(score_colname)
    joinchar              = params['translations']['protein_delimiter']
    do_not_delete         = False
    created_tmp_files     = []
    use15N                = False

    if 'label' in params.keys():
        if params['label'] == '15N':
            use15N = True
    else:
        params['label'] = '14N'
    # print(use15N)
    # exit()
    aa_exception_dict = params['translations']['aa_exception_dict']
    n_term_replacement = {
        'Ammonia-loss' : None,
        'Trimethyl'    : None,
        'Gly->Val'     : None,
    }
    fixed_mods = {}
    opt_mods   = {}
    modname2aa = {}
    cam        = False

    # mod pattern
    mod_pattern = re.compile( r''':(?P<pos>[0-9]*$)''' )
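    # illustrative: for a string like 'Carbamidomethyl:4' the pattern captures
    # pos '4'; the modification name is the slice up to match.start()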

    for modification in params['translations']['modifications']:
        aa = modification.split(',')[0]
        mod_type = modification.split(',')[1]
        pos = modification.split(',')[2]
        name = modification.split(',')[3]
        if name not in modname2aa.keys():
            modname2aa[name] = []
        modname2aa[name].append(aa)
        if 'N-term' in pos:
            n_term_replacement[name] = aa
        if mod_type == 'fix':
            fixed_mods[aa] = name
        if mod_type == 'opt':
            opt_mods[aa] = name
        if 'C,fix,any,Carbamidomethyl' in modification:
            cam = True

    cc = ursgal.ChemicalComposition()
    ursgal.GlobalUnimodMapper._reparseXML()
    de_novo_engines = ['novor', 'pepnovo', 'uninovo', 'unknown_engine']
    database_search_engines = [
        'msamanda',
        'msgf',
        'myrimatch',
        'omssa',
        'xtandem'
    ]
    de_novo = False
    database_search = False
    for de_novo_engine in de_novo_engines:
        if de_novo_engine in search_engine.lower():
            de_novo = True
    for db_se in database_search_engines:
        if db_se in search_engine.lower():
            database_search = True

    if upeptide_mapper is None:
        upapa = ursgal.UPeptideMapper()
    else:
        upapa = upeptide_mapper

    if database_search is True:
        target_decoy_peps = set()
        non_enzymatic_peps = set()
        pep_map_lookup = {}
        fasta_lookup_name = upapa.build_lookup_from_file(
            params['translations']['database'],
            force=False
        )
    # print('Cached!')
    # input()
    psm_counter = Counter()
    # if a PSM with multiple rows is found (i.e. in omssa results), the psm
    # rows are merged afterwards

    output_file_object = open(output_file, 'w')
    protein_id_output = open(output_file + '_full_protein_names.txt', 'w')
    mz_buffer = {}
    csv_kwargs = {
        'extrasaction' : 'ignore'
    }
    if sys.platform == 'win32':
        csv_kwargs['lineterminator'] = '\n'
    else:
        csv_kwargs['lineterminator'] = '\r\n'
    total_lines = len(list(csv.reader(open(input_file,'r'))))
    ze_only_buffer = {}

    allowed_aa, cleavage_site, inhibitor_aa = params['translations']['enzyme'].split(';')
    allowed_aa += '-'

    with open( input_file, 'r' ) as in_file:
        csv_input  = csv.DictReader(
            in_file
        )

        output_fieldnames = list(csv_input.fieldnames)
        for remove_fieldname in [
            'proteinacc_start_stop_pre_post_;',
            'Start',
            'Stop',
            'NIST score',
            'gi',
            'Accession',
        ]:
            if remove_fieldname not in output_fieldnames:
                continue
            output_fieldnames.remove(remove_fieldname)
        new_fieldnames = [
            'uCalc m/z',
            'Accuracy (ppm)',
            'Protein ID',
            'Sequence Start',
            'Sequence Stop',
            'Sequence Pre AA',
            'Sequence Post AA',
        ]


        for new_fieldname in new_fieldnames:
            if new_fieldname not in output_fieldnames:
                output_fieldnames.insert(-5,new_fieldname)
        csv_output = csv.DictWriter(
            output_file_object,
            output_fieldnames,
            **csv_kwargs
        )
        csv_output.writeheader()
        print('''[ unify_cs ] parsing csv''')
        import time
        for line_nr, line_dict in enumerate(csv_input):
            if line_nr % 500 == 0:
                print(
                    '[ unify_cs ] Processing line number: {0}/{1} .. '.format(
                        line_nr,
                        total_lines,
                    ),
                    end='\r'
                )

            if line_dict['Spectrum Title'] != '':
                '''
                Valid for:
                    OMSSA
                    MSGF+
                    X!Tandem
                '''
                if 'RTINSECONDS=' in line_dict['Spectrum Title']:
                    line_2_split = line_dict['Spectrum Title'].split(' ')[0].strip()
                else:
                    line_2_split = line_dict['Spectrum Title']
                line_dict['Spectrum Title'] = line_2_split

                input_file_basename, spectrum_id, _spectrum_id, charge = line_2_split.split('.')
                pure_input_file_name = ''

            elif 'scan=' in line_dict['Spectrum ID']:
                pure_input_file_name                = os.path.basename(
                    line_dict['Raw data location']
                )
                input_file_basename = pure_input_file_name.split(".")[0]
                # not using os.path.splitext because we could have multiple file
                # extensions (e.g. ".mzml.gz")

                '''
                Valid for:
                    myrimatch
                '''
                spectrum_id = line_dict['Spectrum ID'].split('=')[-1]
                line_dict['Spectrum Title'] = '{0}.{1}.{1}.{2}'.format(
                    input_file_basename,
                    spectrum_id,
                    line_dict['Charge']
                )

            elif line_dict['Spectrum Title'] == '':
                '''
                Valid for:
                    Novor
                '''
                pure_input_file_name = os.path.basename(
                    line_dict['Raw data location']
                )
                input_file_basename = pure_input_file_name.split(".")[0]
                spectrum_id = line_dict['Spectrum ID']
                line_dict['Spectrum Title'] = '{0}.{1}.{1}.{2}'.format(
                    input_file_basename,
                    spectrum_id,
                    line_dict['Charge']
                )
            else:
                raise Exception( 'New csv format present for engine {0}'.format( search_engine ) )

            #update spectrum ID from block above
            line_dict['Spectrum ID'] = spectrum_id


            # now check for the basename in the scan rt lookup
            # possible cases:
            #   - input_file_basename
            #   - input_file_basename + prefix
            #   - input_file_basename - prefix

            input_file_basename_for_rt_lookup = None
            if input_file_basename in scan_rt_lookup.keys():
                input_file_basename_for_rt_lookup = input_file_basename
            else:
                basename_with_prefix = '{0}_{1}'.format(
                    params['prefix'],
                    input_file_basename
                )
                basename_without_prefix  = input_file_basename.replace(
                    params['prefix'],
                    ''
                )
                if basename_with_prefix in scan_rt_lookup.keys():
                    input_file_basename_for_rt_lookup = basename_with_prefix
                elif basename_without_prefix in scan_rt_lookup.keys():
                    input_file_basename_for_rt_lookup = basename_without_prefix
                else:
                    print(
                        '''
Could not find scan ID {0} in scan_rt_lookup[ {1} ]
                        '''.format(
                            spectrum_id,
                            input_file_basename
                        )
                    )

            retention_time_in_minutes = \
                scan_rt_lookup[ input_file_basename_for_rt_lookup ][ 'scan_2_rt' ]\
                    [ spectrum_id ]

            # check whether the lookup stores minutes or seconds and convert
            # to seconds where needed
            if scan_rt_lookup[ input_file_basename_for_rt_lookup ]['unit'] == 'second':
                rt_corr_factor = 1
            else:
                rt_corr_factor = 60
            line_dict['Retention Time (s)'] = float( retention_time_in_minutes ) * rt_corr_factor
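            # illustrative: a lookup value of 25.5 stored in minutes becomes
            # 25.5 * 60 = 1530.0 in 'Retention Time (s)'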

            #
            # now lets buffer for real !! :)
            #
            _ze_ultra_buffer_key_ = '{Sequence} || {Charge} || {Modifications} || '.format( **line_dict ) + params['label']
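            # illustrative buffer key (made-up peptide):
            # 'ELVISLIVESK || 2 || Carbamidomethyl:3 || 14N'
            # identical (sequence, charge, modifications, label) combinations
            # reuse the modification mapping and m/z calculation done below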
            if _ze_ultra_buffer_key_ not in ze_only_buffer.keys():
                line_dict_update = {}
                #
                # Modification block

                # some engines do not report fixed modifications
                # include in unified csv
                if fixed_mods != {}:
                    for pos, aminoacid in enumerate(line_dict['Sequence']):
                        if aminoacid in fixed_mods.keys():
                            name = fixed_mods[ aminoacid ]
                            tmp = '{0}:{1}'.format(
                                name,
                                pos + 1
                            )
                            if tmp in line_dict['Modifications']:
                                # everything is ok :)
                                pass
                            else:
                                tmp_mods = line_dict['Modifications'].split(';')
                                tmp_mods.append(tmp)
                                line_dict['Modifications'] = ';'.join( tmp_mods )

                # Myrimatch and MS-GF+ cannot handle 15N that easily and
                # report all affected AAs as 'unknown modification'.
                # Note: masses are checked below to avoid any mismatch
                if use15N:
                    if 'myrimatch' in search_engine.lower() or \
                            'msgfplus_v9979' in search_engine.lower():
                        for p in range(1,len(line_dict['Sequence'])+1):
                                line_dict['Modifications'] = \
                                    line_dict['Modifications'].replace(
                                        'unknown modification:{0}'.format(p),
                                        '',
                                        1,
                                    )
                    if 'myrimatch' in search_engine.lower():
                        if 'Carboxymethyl' in line_dict['Modifications'] and cam == True:
                            line_dict['Modifications'] = line_dict['Modifications'].replace(
                                'Carboxymethyl',
                                'Carbamidomethyl'
                            )
                        elif 'Delta:H(6)C(3)O(1)' in line_dict['Modifications']:
                            line_dict['Modifications'] = line_dict['Modifications'].replace(
                                'Delta:H(6)C(3)O(1)',
                                'Carbamidomethyl'
                            )

                tmp_mods = []
                for modification in line_dict['Modifications'].split(';'):
                    Nterm = False
                    Cterm = False
                    skip_mod = False
                    if modification == '':
                        continue
                    pos, mod = None, None
                    match = mod_pattern.search( modification )
                    pos = int( match.group('pos') )
                    mod = modification[ :match.start() ]
                    assert pos is not None, '''
                            The format of the modification {0}
                            is not recognized by ursgal'''.format(
                                modification
                            )
                    if pos <= 1:
                        Nterm = True
                        new_pos = 1
                    elif pos > len(line_dict['Sequence']):
                        Cterm = True
                        new_pos = len(line_dict['Sequence'])
                    else:
                        new_pos = pos
                    aa = line_dict['Sequence'][ new_pos - 1 ].upper()
                    # if aa in fixed_mods.keys():
                    #     fixed_mods[ aminoacid ]
                    #     # fixed mods are corrected/added already
                    #     continue
                    if mod in modname2aa.keys():
                        correct_mod = False
                        if aa in modname2aa[mod]:
                            # everything is ok
                            correct_mod = True
                        elif Nterm or Cterm:
                            if '*' in modname2aa[mod]:
                                correct_mod = True
                                # still is ok
                        assert correct_mod is True,'''
                                A modification was reported for an aminoacid for which it was not defined
                                unify_csv cannot deal with this, please check your parameters and engine output
                                reported modification: {0} on {1}
                                modifications in parameters: {2}
                                '''.format(
                                    mod,
                                    aa,
                                    params['translations']['modifications']
                                )
                    elif 'unknown modification' == mod:
                        modification_known = False
                        if aa in opt_mods.keys():
                            # fixed mods are corrected/added already
                            modification = '{0}:{1}'.format(opt_mods[aa],new_pos)
                            modification_known = True
                        assert modification_known == True,'''
                                unify csv does not work for the given unknown modification for
                                {0} {1} aa: {2}
                                maybe an unknown modification with terminal position was given?
                                '''.format(
                                    line_dict['Sequence'], modification, aa
                                )
                    else:
                        if aa in fixed_mods.keys() and use15N \
                            and 'msgfplus' in search_engine.lower():
                            if pos != 0:
                                mod = float(mod) - ursgal.ursgal_kb.DICT_15N_DIFF[aa]
                        try:
                            name_list = ursgal.GlobalUnimodMapper.appMass2name_list(
                                round(float(mod), 3), decimal_places = 3
                            )
                        except:
                            print('''
                                A modification was reported that was not included in the search parameters
                                unify_csv cannot deal with this, please check your parameters and engine output
                                reported modification: {0}
                                modifications in parameters: {1}
                                '''.format(mod, params['translations']['modifications'])
                            )
                            raise Exception('unify_csv failed because a '\
                                'modification was reported that was not '\
                                'given in params.'
                                '{0}'.format(modification)
                            )
                        mapped_mod = False
                        for name in name_list:
                            if name in modname2aa.keys():
                                if aa in modname2aa[name]:
                                    modification = '{0}:{1}'.format(name,new_pos)
                                    mapped_mod = True
                                elif Nterm and '*' in modname2aa[name]:
                                    modification = '{0}:{1}'.format(name,0)
                                    mapped_mod = True
                                else:
                                    continue
                            elif use15N and name in [
                                'Label:15N(1)',
                                'Label:15N(2)',
                                'Label:15N(3)',
                                'Label:15N(4)' 
                            ]:
                                mapped_mod = True
                                skip_mod = True
                                break
                        assert mapped_mod is True, '''
                                A mass was reported that does not map on any unimod or userdefined modification
                                or the modified aminoacid is not the specified one
                                unify_csv cannot deal with this, please check your parameters and engine output
                                reported mass: {0}
                                maps on: {1}
                                reported modified aminoacid: {2}
                                modifications in parameters: {3}
                                '''.format(
                                    mod,
                                    name_list,
                                    aa,
                                    params['translations']['modifications']
                                )
                    if modification in tmp_mods or skip_mod is True:
                        continue
                    tmp_mods.append(modification)
                line_dict_update['Modifications'] = ';'.join( tmp_mods )
                #
                # ^^--------- REPLACED MODIFICATIONS! ---------------^
                #
                for unimod_name in n_term_replacement.keys():
                    if '{0}:1'.format(unimod_name) in line_dict_update['Modifications'].split(';'):
                        if unimod_name in modname2aa.keys():
                            aa = modname2aa[unimod_name]
                            if aa != ['*']:
                                if line_dict['Sequence'][0] in aa:
                                    continue
                        line_dict_update['Modifications'] = line_dict_update['Modifications'].replace(
                            '{0}:1'.format( unimod_name ),
                            '{0}:0'.format( unimod_name )
                            )

                for aa_to_replace, replace_dict in aa_exception_dict.items():
                    if aa_to_replace in line_dict['Sequence']:
                        #change mods only if unimod has to be changed...
                        if 'unimod_name' in replace_dict.keys():
                            for r_pos, aa in enumerate(line_dict['Sequence']):
                                if aa == aa_to_replace:
                                    index_of_U = r_pos + 1
                                    unimod_name = replace_dict['unimod_name']
                                    if cam and replace_dict['original_aa'] == 'C':
                                        unimod_name = replace_dict['unimod_name_with_cam']
                                    new_mod = '{0}:{1}'.format(
                                        unimod_name,
                                        index_of_U
                                    )
                                    if line_dict_update['Modifications'] == '':
                                        line_dict_update['Modifications'] += new_mod
                                    else:
                                        line_dict_update['Modifications'] += ';{0}'.format(
                                            new_mod
                                        )
                        line_dict['Sequence'] = line_dict['Sequence'].replace(
                            aa_to_replace,
                            replace_dict['original_aa']
                        )

                line_dict_update['Sequence'] = line_dict['Sequence']
                #
                # ^^--------- REPLACED SEQUENCE! ---------------^
                #
                # remove duplicate and empty entries from the ';'-separated list
                if line_dict_update['Modifications'] != '':
                    tmp = []
                    for e in line_dict_update['Modifications'].split(';'):
                        if e == '':
                            # skip empty strings caused by doubled ';'
                            continue
                        else:
                            # other way to do it...
                            # pos_of_split_point = re.search( ':\d*\Z', e )
                            # pattern = re.compile( r''':(?P<pos>[0-9]*$)''' )
                            for occ, match in enumerate( mod_pattern.finditer( e )):
                                mod = e[:match.start()]
                                mod_pos = e[match.start()+1:]
                                # mod, pos = e.split(':')
                                m = (int(mod_pos), mod)
                                if m not in tmp:
                                    tmp.append( m )
                    tmp.sort()
                    line_dict_update['Modifications'] = ';'.join(
                        [
                            '{m}:{p}'.format( m=mod, p=pos) for pos, mod in tmp
                        ]
                    )

                # calculate m/z
                cc.use(
                    '{Sequence}#{Modifications}'.format(
                        **line_dict_update
                    )
                )
                if use15N:
                    number_N = dc( cc['N'] )
                    cc['15N'] = number_N
                    del cc['N']
                    if cam:
                        c_count = line_dict_update['Sequence'].count('C')
                        cc['14N'] = c_count
                        cc['15N'] -= c_count
                    # mass = mass + ( DIFFERENCE_14N_15N * number_N )
                mass = cc._mass()
                calc_mz = ursgal.ucore.calculate_mz(
                    mass,
                    line_dict['Charge']
                )
                # mz_buffer[ buffer_key ] = calc_mz

                line_dict_update['uCalc m/z'] = calc_mz
                # if 'msamanda' in search_engine.lower():
                    # ms amanda does not return calculated mz values
                if line_dict['Calc m/z'] == '':
                    line_dict_update['Calc m/z'] = calc_mz

                line_dict_update['Accuracy (ppm)'] = \
                    (float(line_dict['Exp m/z']) - line_dict_update['uCalc m/z'])/line_dict_update['uCalc m/z'] * 1e6
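                # worked example (illustrative): Exp m/z 500.2505 vs uCalc
                # 500.2500 gives 0.0005 / 500.25 * 1e6, i.e. roughly 1.0 ppm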
                prec_m_accuracy = (params['translations']['precursor_mass_tolerance_minus'] + params['translations']['precursor_mass_tolerance_plus'])/2
                i = 0
                while abs(line_dict_update['Accuracy (ppm)']) > prec_m_accuracy:
                    i += 1
                    if i > len(params['translations']['precursor_isotope_range'].split(','))-1:
                        break
                    isotope = params['translations']['precursor_isotope_range'].split(',')[i]
                    isotope = int(isotope)
                    if isotope == 0:
                        continue
                    calc_mz = ursgal.ucore.calculate_mz(
                        mass + isotope*1.008664904,
                        line_dict['Charge']
                    )
                    line_dict_update['Accuracy (ppm)'] = \
                        (float(line_dict['Exp m/z']) - calc_mz)/calc_mz * 1e6
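                    # 1.008664904 u is approximately the neutron mass: each
                    # iteration tests whether the engine matched a later
                    # isotope peak instead of the monoisotopic one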

                # ------------
                # BUFFER END
                # -----------
                ze_only_buffer[ _ze_ultra_buffer_key_ ] = line_dict_update

            line_dict_update = ze_only_buffer[ _ze_ultra_buffer_key_ ]
            line_dict.update( line_dict_update )

            # protein block, only for database search engine
            if database_search is True:
                # remap peptides to proteins, check correct enzymatic
                # cleavage and decoy assignment
                lookup_identifier = '{0}><{1}'.format(
                    line_dict['Sequence'],
                    fasta_lookup_name
                )
                if lookup_identifier not in pep_map_lookup.keys():
                    tmp_decoy = set()
                    # tmp_protein_id = {}

                    upeptide_maps = upapa.map_peptide(
                        peptide    = line_dict['Sequence'],
                        fasta_name = fasta_lookup_name
                    )
                    '''
                    <><><><><><><><><><><><><>
                    '''
                    # assert upeptide_maps != [],'''
                    #         The peptide {0} could not be mapped to the
                    #         given database {1}

                    #         {2}

                    #         '''.format(
                    #             line_dict['Sequence'],
                    #             fasta_lookup_name,
                    #             ''
                    #         )
                    if upeptide_maps == []:
                        print('''
[ WARNING ] The peptide {0} could not be mapped to the
[ WARNING ] given database {1}
[ WARNING ] {2}
[ WARNING ] This PSM will be skipped.
                            '''.format(
                                line_dict['Sequence'],
                                fasta_lookup_name,
                                ''
                            )
                        )
                        continue

                    sorted_upeptide_maps = sorted( upeptide_maps, key=lambda x: x['id'] )
                    # sorted(bacterial_protein_collector[race].items(),key=lambda x: x[1]['psm_count'])
                    # print()
                    # print(line_dict['Sequence'])
                    # print(sorted_upeptide_maps)
                    protein_mapping_dict = None
                    last_protein_id = None
                    for protein in sorted_upeptide_maps:
                        # print(line_dict)
                        # print(protein)
                        add_protein   = False
                        nterm_correct = False
                        cterm_correct = False
                        if params['translations']['keep_asp_pro_broken_peps'] is True:
                            if line_dict['Sequence'][-1] == 'D' and\
                                    protein['post'] == 'P':
                                cterm_correct = True
                            if line_dict['Sequence'][0] == 'P' and\
                                    protein['pre'] == 'D':
                                nterm_correct = True

                        if cleavage_site == 'C':
                            if protein['pre'] in allowed_aa\
                                    or protein['start'] in [1, 2, 3]:
                                if line_dict['Sequence'][0] not in inhibitor_aa\
                                        or protein['start'] in [1, 2, 3]:
                                    nterm_correct = True
                            if protein['post'] not in inhibitor_aa:
                                if line_dict['Sequence'][-1] in allowed_aa\
                                     or protein['post'] == '-':
                                    cterm_correct = True

                        elif cleavage_site == 'N':
                            if protein['post'] in allowed_aa:
                                if line_dict['Sequence'][-1] not in inhibitor_aa\
                                        or protein['post'] == '-':
                                    cterm_correct = True
                            if protein['pre'] not in inhibitor_aa\
                                or protein['start'] in [1, 2, 3]:
                                if line_dict['Sequence'][0] in allowed_aa\
                                    or protein['start'] in [1, 2, 3]:
                                    nterm_correct = True

                        if params['translations']['semi_enzyme'] is True:
                            if cterm_correct is True or nterm_correct is True:
                                add_protein = True
                        elif cterm_correct is True and nterm_correct is True:
                            add_protein = True

                        if add_protein is True:
                            # print(add_protein)
                            # print(cterm_correct, nterm_correct)
                            if protein_mapping_dict is None:
                                protein_mapping_dict = {
                                    'Protein ID'       : protein['id'],
                                    'Sequence Start'   : str(protein['start']),
                                    'Sequence Stop'    : str(protein['end']),
                                    'Sequence Pre AA'  : protein['pre'],
                                    'Sequence Post AA' : protein['post'],
                                }
                            else:
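                                # a peptide can map to several positions within
                                # the same protein; in that case only the
                                # position columns are extended (joined with
                                # ';'), while a genuinely new protein ID is
                                # appended with the configured protein_delimiter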
                                if protein['id'] == last_protein_id:
                                    tmp_join_char = ';'
                                else:
                                    tmp_join_char = joinchar

                                    protein_mapping_dict['Protein ID' ] += '{0}{1}'.format(tmp_join_char, protein['id'])

                                protein_mapping_dict['Sequence Start'   ] += '{0}{1}'.format(tmp_join_char, str(protein['start']))
                                protein_mapping_dict['Sequence Stop'    ] += '{0}{1}'.format(tmp_join_char, str(protein['end']))
                                protein_mapping_dict['Sequence Pre AA'  ] += '{0}{1}'.format(tmp_join_char, protein['pre'])
                                protein_mapping_dict['Sequence Post AA' ] += '{0}{1}'.format(tmp_join_char, protein['post'])

                            # print(protein_mapping_dict['Protein ID' ])
                            last_protein_id = protein['id']

                            # mzidentml-lib does not always set 'Is decoy' correctly
                            # (it's always 'false' for MS-GF+ results), this is fixed here:
                            if params['translations']['decoy_tag'] in protein['id']:
                                tmp_decoy.add('true')
                            else:
                                tmp_decoy.add('false')

                    if protein_mapping_dict is None:
                        non_enzymatic_peps.add(line_dict['Sequence'])
                        continue

                    if len(protein_mapping_dict['Protein ID']) >= 2000:
                        print(
                            '{0}: {1}'.format(
                                line_dict['Sequence'],
                                protein_mapping_dict['Protein ID']
                            ),
                            file = protein_id_output
                        )
                        protein_mapping_dict['Protein ID'] = protein_mapping_dict['Protein ID'][:1990] + ' ...'
                        do_not_delete = True

                    if len(tmp_decoy) >= 2:
                        target_decoy_peps.add(line_dict['Sequence'])
                        protein_mapping_dict['Is decoy'] = 'true'
                    else:
                        protein_mapping_dict['Is decoy'] = list(tmp_decoy)[0]

                    pep_map_lookup[ lookup_identifier ] = protein_mapping_dict

                buffered_protein_mapping_dict = pep_map_lookup[lookup_identifier]
                line_dict.update( buffered_protein_mapping_dict )
                # count each PSM occurrence to check whether row-merging is needed:
                psm = tuple([line_dict[x] for x in psm_defining_colnames])
                psm_counter[psm] += 1

            csv_output.writerow(line_dict)
            '''
                to_be_written_csv_lines.append( line_dict )
            '''
    output_file_object.close()

    if database_search is True:
        # upapa.purge_fasta_info( fasta_lookup_name )
        if len(non_enzymatic_peps) != 0:
            print( '''
                [ WARNING ] The following peptides could not be mapped to the
                [ WARNING ] given database {0}
                [ WARNING ] with correct enzymatic cleavage sites:
                [ WARNING ] {1}
                [ WARNING ] These PSMs were skipped.'''.format(
            params['translations']['database'],
            non_enzymatic_peps
            ))
        if len(target_decoy_peps) != 0:
            print(
                '''
                [ WARNING ] The following peptides occurred in a target as well as a decoy protein
                [ WARNING ] {0}
                [ WARNING ] 'Is decoy' has been set to 'true' '''.format(
                    target_decoy_peps,
                )
            )

    # if there are multiple rows for a PSM, we have to merge them aka rewrite the csv...
    if psm_counter != Counter():
        if max(psm_counter.values()) > 1:
            merge_duplicate_psm_rows(output_file, psm_counter, psm_defining_colnames, params['translations']['psm_merge_delimiter'])
            '''
            to_be_written_csv_lines = merge_duplicate_psm_rows(
                to_be_written_csv_lines,
                psm_counter
            )
            '''
        '''
        do output_file magic with to_be_written_csv_lines
        '''
    if do_not_delete is False:
        created_tmp_files.append( output_file + '_full_protein_names.txt' )
    return created_tmp_files

Example 8

Project: MITMf
Source File: filepwn.py
View license
    def binaryGrinder(self, binaryFile):
        """
        Feed potential binaries into this function,
        it will return the result PatchedBinary, False, or None
        """

        with open(binaryFile, 'r+b') as f:
            binaryTMPHandle = f.read()

        binaryHeader = binaryTMPHandle[:4]
        result = None
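        # magic values checked below: 'MZ' marks PE/COFF, 7f454c46 ('\x7fELF')
        # marks ELF, and cefaedfe/cffaedfe/cafebabe mark Mach-O images,
        # including fat binaries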

        try:
            if binaryHeader[:2] == 'MZ':  # PE/COFF
                pe = pefile.PE(data=binaryTMPHandle, fast_load=True)
                magic = pe.OPTIONAL_HEADER.Magic
                machineType = pe.FILE_HEADER.Machine

                # update when supporting more than one arch
                if (magic == int('20B', 16) and machineType == 0x8664 and
                   self.WindowsType.lower() in ['all', 'x64']):
                    add_section = False
                    cave_jumping = False
                    if self.WindowsIntelx64['PATCH_TYPE'].lower() == 'append':
                        add_section = True
                    elif self.WindowsIntelx64['PATCH_TYPE'].lower() == 'jump':
                        cave_jumping = True

                    # if automatic override
                    if self.WindowsIntelx64['PATCH_METHOD'].lower() == 'automatic':
                        cave_jumping = True

                    targetFile = pebin.pebin(FILE=binaryFile,
                                             OUTPUT=os.path.basename(binaryFile),
                                             SHELL=self.WindowsIntelx64['SHELL'],
                                             HOST=self.WindowsIntelx64['HOST'],
                                             PORT=int(self.WindowsIntelx64['PORT']),
                                             ADD_SECTION=add_section,
                                             CAVE_JUMPING=cave_jumping,
                                             IMAGE_TYPE=self.WindowsType,
                                             RUNAS_ADMIN=self.str2bool(self.WindowsIntelx64['RUNAS_ADMIN']),
                                             PATCH_DLL=self.str2bool(self.WindowsIntelx64['PATCH_DLL']),
                                             SUPPLIED_SHELLCODE=self.WindowsIntelx64['SUPPLIED_SHELLCODE'],
                                             ZERO_CERT=self.str2bool(self.WindowsIntelx64['ZERO_CERT']),
                                             PATCH_METHOD=self.WindowsIntelx64['PATCH_METHOD'].lower(),
                                             SUPPLIED_BINARY=self.WindowsIntelx64['SUPPLIED_BINARY'],
                                             )

                    result = targetFile.run_this()

                elif (machineType == 0x14c and
                      self.WindowsType.lower() in ['all', 'x86']):
                    add_section = False
                    cave_jumping = False
                    # add_section wins for cave_jumping
                    # default is single for BDF
                    if self.WindowsIntelx86['PATCH_TYPE'].lower() == 'append':
                        add_section = True
                    elif self.WindowsIntelx86['PATCH_TYPE'].lower() == 'jump':
                        cave_jumping = True

                    # if automatic override
                    if self.WindowsIntelx86['PATCH_METHOD'].lower() == 'automatic':
                        cave_jumping = True
                        add_section = False

                    targetFile = pebin.pebin(FILE=binaryFile,
                                             OUTPUT=os.path.basename(binaryFile),
                                             SHELL=self.WindowsIntelx86['SHELL'],
                                             HOST=self.WindowsIntelx86['HOST'],
                                             PORT=int(self.WindowsIntelx86['PORT']),
                                             ADD_SECTION=add_section,
                                             CAVE_JUMPING=cave_jumping,
                                             IMAGE_TYPE=self.WindowsType,
                                             RUNAS_ADMIN=self.str2bool(self.WindowsIntelx86['RUNAS_ADMIN']),
                                             PATCH_DLL=self.str2bool(self.WindowsIntelx86['PATCH_DLL']),
                                             SUPPLIED_SHELLCODE=self.WindowsIntelx86['SUPPLIED_SHELLCODE'],
                                             ZERO_CERT=self.str2bool(self.WindowsIntelx86['ZERO_CERT']),
                                             PATCH_METHOD=self.WindowsIntelx86['PATCH_METHOD'].lower(),
                                             SUPPLIED_BINARY=self.WindowsIntelx86['SUPPLIED_BINARY'],
                                             XP_MODE=self.str2bool(self.WindowsIntelx86['XP_MODE'])
                                             )

                    result = targetFile.run_this()

            elif binaryHeader[:4].encode('hex') == '7f454c46':  # ELF

                targetFile = elfbin.elfbin(FILE=binaryFile, SUPPORT_CHECK=False)
                targetFile.support_check()

                if targetFile.class_type == 0x1:
                    # x86CPU Type
                    targetFile = elfbin.elfbin(FILE=binaryFile,
                                               OUTPUT=os.path.basename(binaryFile),
                                               SHELL=self.LinuxIntelx86['SHELL'],
                                               HOST=self.LinuxIntelx86['HOST'],
                                               PORT=int(self.LinuxIntelx86['PORT']),
                                               SUPPLIED_SHELLCODE=self.LinuxIntelx86['SUPPLIED_SHELLCODE'],
                                               IMAGE_TYPE=self.LinuxType
                                               )
                    result = targetFile.run_this()
                elif targetFile.class_type == 0x2:
                    # x64
                    targetFile = elfbin.elfbin(FILE=binaryFile,
                                               OUTPUT=os.path.basename(binaryFile),
                                               SHELL=self.LinuxIntelx64['SHELL'],
                                               HOST=self.LinuxIntelx64['HOST'],
                                               PORT=int(self.LinuxIntelx64['PORT']),
                                               SUPPLIED_SHELLCODE=self.LinuxIntelx64['SUPPLIED_SHELLCODE'],
                                               IMAGE_TYPE=self.LinuxType
                                               )
                    result = targetFile.run_this()

            elif binaryHeader[:4].encode('hex') in ['cefaedfe', 'cffaedfe', 'cafebabe']:  # Macho
                targetFile = machobin.machobin(FILE=binaryFile, SUPPORT_CHECK=False)
                targetFile.support_check()

                # ONE CHIP SET MUST HAVE PRIORITY in FAT FILE

                if targetFile.FAT_FILE is True:
                    if self.FatPriority == 'x86':
                        targetFile = machobin.machobin(FILE=binaryFile,
                                                       OUTPUT=os.path.basename(binaryFile),
                                                       SHELL=self.MachoIntelx86['SHELL'],
                                                       HOST=self.MachoIntelx86['HOST'],
                                                       PORT=int(self.MachoIntelx86['PORT']),
                                                       SUPPLIED_SHELLCODE=self.MachoIntelx86['SUPPLIED_SHELLCODE'],
                                                       FAT_PRIORITY=self.FatPriority
                                                       )
                        result = targetFile.run_this()

                    elif self.FatPriority == 'x64':
                        targetFile = machobin.machobin(FILE=binaryFile,
                                                       OUTPUT=os.path.basename(binaryFile),
                                                       SHELL=self.MachoIntelx64['SHELL'],
                                                       HOST=self.MachoIntelx64['HOST'],
                                                       PORT=int(self.MachoIntelx64['PORT']),
                                                       SUPPLIED_SHELLCODE=self.MachoIntelx64['SUPPLIED_SHELLCODE'],
                                                       FAT_PRIORITY=self.FatPriority
                                                       )
                        result = targetFile.run_this()

                elif targetFile.mach_hdrs[0]['CPU Type'] == '0x7':
                    targetFile = machobin.machobin(FILE=binaryFile,
                                                   OUTPUT=os.path.basename(binaryFile),
                                                   SHELL=self.MachoIntelx86['SHELL'],
                                                   HOST=self.MachoIntelx86['HOST'],
                                                   PORT=int(self.MachoIntelx86['PORT']),
                                                   SUPPLIED_SHELLCODE=self.MachoIntelx86['SUPPLIED_SHELLCODE'],
                                                   FAT_PRIORITY=self.FatPriority
                                                   )
                    result = targetFile.run_this()

                elif targetFile.mach_hdrs[0]['CPU Type'] == '0x1000007':
                    targetFile = machobin.machobin(FILE=binaryFile,
                                                   OUTPUT=os.path.basename(binaryFile),
                                                   SHELL=self.MachoIntelx64['SHELL'],
                                                   HOST=self.MachoIntelx64['HOST'],
                                                   PORT=int(self.MachoIntelx64['PORT']),
                                                   SUPPLIED_SHELLCODE=self.MachoIntelx64['SUPPLIED_SHELLCODE'],
                                                   FAT_PRIORITY=self.FatPriority
                                                   )
                    result = targetFile.run_this()

            return result

        except Exception as e:
            self.log.error("Exception in binaryGrinder {0}".format(e))
            return None
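
The detail worth copying from this example is OUTPUT=os.path.basename(binaryFile): the patched binary is written under the original filename, no matter how deep the input path is. A minimal sketch of that naming convention (stage_output and dest_dir are hypothetical helpers, not part of the elfbin/machobin API):

import os
import shutil

def stage_output(binary_path, dest_dir="backdoored"):
    # Keep only the filename component; the input's directory layout is irrelevant.
    out_name = os.path.basename(binary_path)
    if not os.path.isdir(dest_dir):
        os.makedirs(dest_dir)
    out_path = os.path.join(dest_dir, out_name)
    shutil.copy(binary_path, out_path)  # stand-in for the actual patching step
    return out_path

# stage_output("/mnt/share/bin/app") -> "backdoored/app"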

Example 9

Project: bamsurgeon
Source File: addsv.py
View license
def makemut(args, bedline, alignopts):

    if args.seed is not None: random.seed(int(args.seed) + int(bedline.strip().split()[1]))

    mutid = '_'.join(map(str, bedline.strip().split()))
    try:
        bamfile = pysam.Samfile(args.bamFileName, 'rb')
        reffile = pysam.Fastafile(args.refFasta)
        logfn = '_'.join(map(os.path.basename, bedline.strip().split())) + ".log"
        logfile = open('addsv_logs_' + os.path.basename(args.outBamFile) + '/' + os.path.basename(args.outBamFile) + '_' + logfn, 'w')
        exclfile = args.tmpdir + '/' + '.'.join((mutid, 'exclude', str(uuid4()), 'txt'))
        exclude = open(exclfile, 'w')

        # optional CNV file
        cnv = None
        if (args.cnvfile):
            cnv = pysam.Tabixfile(args.cnvfile, 'r')

        # temporary file to hold mutated reads
        outbam_mutsfile = args.tmpdir + '/' + '.'.join((mutid, str(uuid4()), "muts.bam"))

        c = bedline.strip().split()
        chrom  = c[0]
        start  = int(c[1])
        end    = int(c[2])
        araw   = c[3:len(c)] # INV, DEL, INS seqfile.fa TSDlength, DUP
 
        # translocation specific
        trn_chrom = None
        trn_start = None
        trn_end   = None

        is_transloc = c[3] == 'TRN'

        if is_transloc:
            start -= 3000
            end   += 3000
            if start < 0: start = 0

            trn_chrom = c[4]
            trn_start = int(c[5]) - 3000
            trn_end   = int(c[5]) + 3000
            if trn_start < 0: trn_start = 0

        actions = map(lambda x: x.strip(),' '.join(araw).split(','))

        svfrac = float(args.svfrac) # default, can be overridden by cnv file

        if cnv: # CNV file is present
            if chrom in cnv.contigs:
                for cnregion in cnv.fetch(chrom,start,end):
                    cn = float(cnregion.strip().split()[3]) # expect chrom,start,end,CN
                    sys.stdout.write("INFO\t" + now() + "\t" + mutid + "\t" + ' '.join(("copy number in sv region:",chrom,str(start),str(end),"=",str(cn))) + "\n")
                    svfrac = 1.0/float(cn)
                    assert svfrac <= 1.0
                    sys.stdout.write("INFO\t" + now() + "\t" + mutid + "\tadjusted MAF: " + str(svfrac) + "\n")

        print "INFO\t" + now() + "\t" + mutid + "\tinterval:", c
        print "INFO\t" + now() + "\t" + mutid + "\tlength:", end-start

        # modify start and end if interval is too short
        minctglen = int(args.minctglen)

        # adjust if minctglen is too short
        if minctglen < 3*int(args.maxlibsize):
            minctglen = 3*int(args.maxlibsize)

        if end-start < minctglen:
            adj   = minctglen - (end-start)
            start = start - adj/2
            end   = end + adj/2

            print "INFO\t" + now() + "\t" + mutid + "\tnote: interval size was too short, adjusted: %s:%d-%d" % (chrom,start,end)

        dfrac = discordant_fraction(args.bamFileName, chrom, start, end)
        print "INFO\t" + now() + "\t" + mutid + "\tdiscordant fraction:", dfrac

        maxdfrac = 0.1 # FIXME make a parameter
        if dfrac > maxdfrac:
            sys.stderr.write("WARN\t" + now() + "\t" + mutid + "\tdiscordant fraction > " + str(maxdfrac) + " aborting mutation!\n")
            return None, None

        contigs = ar.asm(chrom, start, end, args.bamFileName, reffile, int(args.kmersize), args.tmpdir, mutid=mutid, debug=args.debug)

        trn_contigs = None
        if is_transloc:
            trn_contigs = ar.asm(trn_chrom, trn_start, trn_end, args.bamFileName, reffile, int(args.kmersize), args.tmpdir, mutid=mutid, debug=args.debug)

        maxcontig = sorted(contigs)[-1] if contigs else None

        trn_maxcontig = None
        if is_transloc and trn_contigs: trn_maxcontig = sorted(trn_contigs)[-1]

        # make sure assembly actually produced a contig before touching .seq
        if maxcontig is None:
            sys.stderr.write("WARN\t" + now() + "\t" + mutid + "\tmaxcontig has length 0, aborting mutation!\n")
            return None, None

        if is_transloc and trn_maxcontig is None:
            sys.stderr.write("WARN\t" + now() + "\t" + mutid + "\ttransloc maxcontig has length 0, aborting mutation!\n")
            return None, None

        # be strict about contig quality
        if re.search('N', maxcontig.seq):
            sys.stderr.write("WARN\t" + now() + "\t" + mutid + "\tcontig dropped due to ambiguous base (N), aborting mutation.\n")
            return None, None

        if is_transloc and re.search('N', trn_maxcontig.seq):
            sys.stderr.write("WARN\t" + now() + "\t" + mutid + "\tcontig dropped due to ambiguous base (N), aborting mutation.\n")
            return None, None

        print "INFO\t" + now() + "\t" + mutid + "\tbest contig length:", sorted(contigs)[-1].len

        if is_transloc:
            print "INFO\t" + now() + "\t" + mutid + "\tbest transloc contig length:", sorted(trn_contigs)[-1].len

        # trim contig to get best ungapped aligned region to ref.
        maxcontig, refseq, alignstats, refstart, refend, qrystart, qryend, tgtstart, tgtend = trim_contig(mutid, chrom, start, end, maxcontig, reffile)

        if maxcontig is None:
            sys.stderr.write("WARN\t" + now() + "\t" + mutid + "\tbest contig did not have sufficent match to reference, aborting mutation.\n")
            return None, None
    
        print "INFO\t" + now() + "\t" + mutid + "\tstart, end, tgtstart, tgtend, refstart, refend:", start, end, tgtstart, tgtend, refstart, refend

        if is_transloc:
            trn_maxcontig, trn_refseq, trn_alignstats, trn_refstart, trn_refend, trn_qrystart, trn_qryend, trn_tgtstart, trn_tgtend = trim_contig(mutid, trn_chrom, trn_start, trn_end, trn_maxcontig, reffile)
            print "INFO\t" + now() + "\t" + mutid + "\ttrn_start, trn_end, trn_tgtstart, trn_tgtend, trn_refstart, trn_refend:", trn_start, trn_end, trn_tgtstart, trn_tgtend, trn_refstart, trn_refend

        # is there enough room to make mutations?
        if maxcontig.len < 3*int(args.maxlibsize):
            sys.stderr.write("WARN\t" + now() + "\t" + mutid + "\tbest contig too short to make mutation!\n")
            return None, None

        if is_transloc and trn_maxcontig.len < 3*int(args.maxlibsize):
            sys.stderr.write("WARN\t" + now() + "\t" + mutid + "\tbest transloc contig too short to make mutation!\n")
            return None, None

        # make mutation in the largest contig
        mutseq = ms.MutableSeq(maxcontig.seq)

        if is_transloc: trn_mutseq = ms.MutableSeq(trn_maxcontig.seq)

        # support for multiple mutations
        for actionstr in actions:
            a = actionstr.split()
            action = a[0]

            print "INFO\t" + now() + "\t" + mutid + "\taction: ", actionstr, action

            insseqfile = None
            insseq = ''
            tsdlen = 0  # target site duplication length
            ndups = 0   # number of tandem dups
            dsize = 0.0 # deletion size fraction
            dlen = 0
            ins_motif = None

            if action == 'INS':
                assert len(a) > 1 # insertion syntax: INS <file.fa> [optional TSDlen]
                insseqfile = a[1]
                if not (os.path.exists(insseqfile) or insseqfile == 'RND'): # not a file... is it a sequence? (support indel ins.)
                    assert re.search('^[ATGCatgc]*$',insseqfile) # make sure it's a sequence
                    insseq = insseqfile.upper()
                    insseqfile = None
                if len(a) > 2: # third field for insertion is TSD length
                    tsdlen = int(a[2])

                if len(a) > 3: # fourth field for insertion is motif, format = 'NNNN^NNNN' where ^ is the cut site
                    ins_motif = a[3]
                    assert '^' in ins_motif, 'insertion motif specification requires cut site defined by ^'

            if action == 'DUP':
                if len(a) > 1:
                    ndups = int(a[1])
                else:
                    ndups = 1

            if action == 'DEL':
                if len(a) > 1:
                    dsize = float(a[1])
                    if dsize > 1.0: # if DEL size is not a fraction, interpret as bp
                        # since DEL 1 is default, if DEL 1 is specified, interpret as 1 bp deletion
                        dlen = int(dsize)
                        dsize = 1.0
                else:
                    dsize = 1.0

            if action == 'TRN':
                pass


            logfile.write(">" + chrom + ":" + str(refstart) + "-" + str(refend) + " BEFORE\n" + str(mutseq) + "\n")

            if action == 'INS':
                inspoint = mutseq.length()/2
                if ins_motif is not None:
                    inspoint = mutseq.find_site(ins_motif, left_trim=int(args.maxlibsize), right_trim=int(args.maxlibsize))

                if insseqfile: # seq in file
                    if insseqfile == 'RND':
                        assert args.inslib is not None # insertion library needs to exist
                        insseqfile = random.choice(args.inslib.keys())
                        print "INFO\t" + now() + "\t" + mutid + "\tchose sequence from insertion library: " + insseqfile
                        mutseq.insertion(inspoint, args.inslib[insseqfile], tsdlen)

                    else:
                        mutseq.insertion(inspoint, singleseqfa(insseqfile, mutid=mutid), tsdlen)

                else: # seq is input
                    mutseq.insertion(inspoint, insseq, tsdlen)

                logfile.write("\t".join(('ins',chrom,str(refstart),str(refend),action,str(mutseq.length()),str(inspoint),str(insseqfile),str(tsdlen),str(svfrac))) + "\n")

            elif action == 'INV':
                invstart = int(args.maxlibsize)
                invend = mutseq.length() - invstart
                mutseq.inversion(invstart,invend)
                logfile.write("\t".join(('inv',chrom,str(refstart),str(refend),action,str(mutseq.length()),str(invstart),str(invend),str(svfrac))) + "\n")

            elif action == 'DEL':
                delstart = int(args.maxlibsize)
                delend = mutseq.length() - delstart
                if dlen == 0: # bp size not specified, delete fraction of contig
                    dlen = int((float(delend-delstart) * dsize)+0.5) 

                dadj = delend-delstart-dlen
                if dadj < 0:
                    dadj = 0
                    sys.stderr.write("WARN\t" + now() + "\t" + mutid + "\twarning: deletion of length 0\n")

                delstart += dadj/2
                delend   -= dadj/2

                mutseq.deletion(delstart,delend)
                logfile.write("\t".join(('del',chrom,str(refstart),str(refend),action,str(mutseq.length()),str(delstart),str(delend),str(dlen),str(svfrac))) + "\n")

            elif action == 'DUP':
                dupstart = int(args.maxlibsize)
                dupend = mutseq.length() - dupstart
                mutseq.duplication(dupstart,dupend,ndups)
                logfile.write("\t".join(('dup',chrom,str(refstart),str(refend),action,str(mutseq.length()),str(dupstart),str(dupend),str(ndups),str(svfrac))) + "\n")

            elif action == 'TRN':
                mutseq.fusion(mutseq.length()/2, trn_mutseq, trn_mutseq.length()/2)
                logfile.write("\t".join(('trn',chrom,str(refstart),str(refend),action,str(mutseq.length()),trn_chrom,str(trn_refstart),str(trn_refend),str(trn_mutseq.length()),str(svfrac))) + "\n")

            else:
                raise ValueError("ERROR\t" + now() + "\t" + mutid + "\t: mutation not one of: INS,INV,DEL,DUP,TRN\n")

            logfile.write(">" + chrom + ":" + str(refstart) + "-" + str(refend) +" AFTER\n" + str(mutseq) + "\n")

        pemean, pesd = float(args.ismean), float(args.issd) 
        print "INFO\t" + now() + "\t" + mutid + "\tset paired end mean distance: " + str(args.ismean)
        print "INFO\t" + now() + "\t" + mutid + "\tset paired end distance stddev: " + str(args.issd)

        # simulate reads
        (fq1, fq2) = runwgsim(maxcontig, mutseq.seq, svfrac, actions, exclude, pemean, pesd, args.tmpdir, mutid=mutid, seed=args.seed, trn_contig=trn_maxcontig)

        outreads = aligners.remap_fastq(args.aligner, fq1, fq2, args.refFasta, outbam_mutsfile, alignopts, mutid=mutid, threads=1)

        if outreads == 0:
            sys.stderr.write("WARN\t" + now() + "\t" + mutid + "\toutbam " + outbam_mutsfile + " has no mapped reads!\n")
            return None, None

        print "INFO\t" + now() + "\t" + mutid + "\ttemporary bam: " + outbam_mutsfile

        exclude.close()
        bamfile.close()

        return outbam_mutsfile, exclfile

    except Exception as e:
        sys.stderr.write("*"*60 + "\nencountered error in mutation spikein: " + bedline + "\n")
        traceback.print_exc(file=sys.stderr)
        sys.stderr.write("*"*60 + "\n")
        return None, None
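
Both basename calls in this example are about log naming: os.path.basename is mapped over every whitespace-separated BED field (a no-op for plain tokens, but it strips directories from any path-like field such as an insertion FASTA), and the log directory embeds the basename of the output BAM. A small sketch of the same scheme (logfile_path is a hypothetical helper):

import os

def logfile_path(out_bam, bedline):
    # basename() leaves plain tokens alone but strips directories from path-like fields
    fields = [os.path.basename(t) for t in bedline.strip().split()]
    logfn = "_".join(fields) + ".log"
    logdir = "addsv_logs_" + os.path.basename(out_bam)
    return os.path.join(logdir, os.path.basename(out_bam) + "_" + logfn)

# logfile_path("/data/out/tumour.bam", "chr1 1000 2000 DEL")
#   -> "addsv_logs_tumour.bam/tumour.bam_chr1_1000_2000_DEL.log"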

Example 10

Project: mkmov
Source File: twod.py
View license
    def lights(self,minvar=None,maxvar=None):
        """function to do some sanity checks on the files and find out where the time dim is.
        
        """
        _lg.info("Lights! Looking at your netCDF files...")
        var_timedims=[]

        #create bias files
        if self.arguments['--bias']:
            #following example in http://linux.die.net/man/1/ncdiff
            ncout='ncra '+' '.join(self.filelist)+' '+self.workingfolder+'mean.nc'
            _lg.info("Creating mean file: " + ncout)
            subprocess.call(ncout,shell=True)

            ncout='ncwa -O -a '+self.arguments['--bias']+' '+self.workingfolder+'mean.nc '+self.workingfolder+'mean_notime.nc'
            _lg.info("Removing time dimension from mean file: " + ncout)
            subprocess.call(ncout,shell=True)

            difffol=self.workingfolder+'difffiles/'
            mkdir_sub(self.workingfolder+'difffiles/')
            newfilelist=[]
            cnt=0
            for f in self.filelist:
                ncout='ncdiff '+' '+f+' '+self.workingfolder+'mean_notime.nc '+difffol+os.path.basename(f)[:-3]+'_diff_'+str(cnt).zfill(5)+'.nc'
                _lg.info("Creating anomaly file: " + ncout)
                subprocess.call(ncout,shell=True)
                newfilelist.append(difffol+os.path.basename(f)[:-3]+'_diff_'+str(cnt).zfill(5)+'.nc')
                cnt+=1

            self.filelist=newfilelist
                
        #error checks files, are all similar
        for f in self.filelist:
            if not os.path.exists(f):
                _lg.error("Input file: " + str(os.path.basename(f))  + " does not exist.")
                sys.exit("Input file: " + str(os.path.basename(f))  + " does not exist.")

            ifile=Dataset(f, 'r')

            if self.variable_name not in ifile.variables.keys():
                _lg.error("Variable: " + str(self.variable_name) + " does not exist in netcdf4 file.")
                _lg.error("Options are: " + str(ifile.variables.keys()) )
                sys.exit("Variable: " + str(self.variable_name) + " does not exist in netcdf4 file.")

            #what shape is the passed variable? Do some error checks
            self.var_len=len(ifile.variables[self.variable_name].shape)
            if self.var_len==2:
                if len(self.arguments['FILE_NAME'])==1:
                    #hmm, haven't actually tried this case!
                    _lg.error("Variable: " + str(self.variable_name) + " has only two dimensions and you only fed mkmov one file so I don't know where your time dimension is.")
                    sys.exit()
                elif len(self.arguments['FILE_NAME'])>1: #have tested this on AVISO works okay
                    pass
                    
            #the 'obvious' case; one file with one time dim and two spatial dims
            if self.var_len==3: 
                pass

            #tricky, which dims are time/random_dim/spatial1/spatial2?
            if self.var_len==4:
                if self.arguments['--4dvar']:
                    _lg.debug("Variable: " + str(self.variable_name) + " has four dimensions. Following your argument, we will plot depth level: "+self.arguments['--4dvar'] )
                    self.depthlvl=int(self.arguments['--4dvar'])
                else:
                    _lg.warning("Variable: " + str(self.variable_name) + " has four dimensions. MkMov will assume the second dim is depth/height and plot the first level.")
                    self.depthlvl=0

            ifile_dim_keys=list(dict(ifile.dimensions).keys())

            #find unlimited dimension
            findunlim=[ifile.dimensions[dim].isunlimited() for dim in ifile_dim_keys]
            dim_unlim_num=[i for i, x in enumerate(findunlim) if x]
            if len(dim_unlim_num)==0:
                _lg.warning("Input file: " + str(os.path.basename(f))  + " has no unlimited dimension, which dim is time?")
                # sys.exit("Input file: " + str(os.path.basename(f))  + " has no unlimited dimension, which dim is time?")
            elif len(dim_unlim_num)>1:
                _lg.warning("Input file: " + str(os.path.basename(f))  + " has more than one unlimited dimension.")
                # sys.exit("Input file: " + str(os.path.basename(f))  + " has more than one unlimited dimension.")
            else:
                timename=ifile_dim_keys[dim_unlim_num[0]]
                var_timedim=[i for i, x in enumerate(ifile.variables[self.variable_name].dimensions) if x==timename][0]
                var_timedims.append(var_timedim)
                ifile.close()
                continue #NOTE I'm a continue!

            #okay so we didn't find time as an unlimited dimension, perhaps it has a sensible name?
            if 'time' in ifile_dim_keys:
                timename='time'
            elif 't' in ifile_dim_keys:
                timename='t'
            elif 'Time' in ifile_dim_keys:
                timename='Time'
            else:
                timename=''

            if timename!='':
                if self.var_len>2:
                    _lg.info("Good news, we think we found the time dimension it's called: " + timename )
                    var_timedim=[i for i, x in enumerate(ifile.variables[self.variable_name].dimensions) if x==timename][0]
                    var_timedims.append(var_timedim)


            # the case where there are only two dimensions, assumed to vary across each file (e.g. mwf-ers2 files)
            if self.var_len==2:
                var_timedims=[-1]

            ifile.close()

        #check all time dimensions are in the same place across all files..
        if var_timedims[1:]==var_timedims[:-1]:
            self.timedim=var_timedims[0]
        else:
            _lg.error("(Unlimited) 'time' dimension was not the same across all files, fatal error.")
            sys.exit("(Unlimited) 'time' dimension was not the same across all files, fatal error.")

        #get max and min values for timeseries. This is expensive :(
        if (minvar is None) and (maxvar is None):
            mins=[]
            maxs=[]
            for f in self.filelist:
                ifile=Dataset(f, 'r')
                name_of_array=self.getdata(ifile)

                mins.append(np.min(name_of_array))
                maxs.append(np.max(name_of_array))
                ifile.close()

            self.minvar=np.min(mins)
            self.maxvar=np.max(maxs)

        if (minvar is not None) and (maxvar is not None):
            #user specified the range (both bounds are required here)
            self.minvar=float(minvar)
            self.maxvar=float(maxvar)

        if (self.arguments['--x'] is not None) and (self.arguments['--y'] is not None):
            ifile=Dataset(self.filelist[0], 'r') #they should all be the same.
            xvar=ifile.variables[self.arguments['--x']][:]
            yvar=ifile.variables[self.arguments['--y']][:]
            self.x,self.y=np.meshgrid(xvar,yvar)
            ifile.close()
        elif (self.arguments['--x2d'] is not None) and (self.arguments['--y2d'] is not None):
            ifile=Dataset(self.filelist[0], 'r') #they should all be the same.
            self.x=ifile.variables[self.arguments['--x2d']][:]

            if self.arguments['--fixdateline']:
                #fix the dateline
                for index in np.arange(np.shape(self.x)[0]):
                    if len(np.where(np.sign(self.x[index,:])==-1)[0])==0:
                        _lg.warning("MkMov couldn't find your dateline, skipping the 'fix'.")
                        break

                    start=np.where(np.sign(self.x[index,:])==-1)[0][0]
                    self.x[index,start:]=self.x[index,start:]+360

            self.y=ifile.variables[self.arguments['--y2d']][:]
            ifile.close()
        else:
            ifile=Dataset(self.filelist[0], 'r')
            name_of_array=np.shape(ifile.variables[self.variable_name])
            if self.var_len==4:
                name_of_array=[name_of_array[0]]+[e for e in name_of_array[2:]]

            if self.timedim==-1:
                self.x,self.y=np.meshgrid(np.arange(name_of_array[1]),\
                        np.arange(name_of_array[0]))
            else:
                self.x,self.y=np.meshgrid(np.arange(name_of_array[self.timedim+2]),\
                        np.arange(name_of_array[self.timedim+1]))

            ifile.close()

        return
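
The anomaly filenames above come from slicing ".nc" off with os.path.basename(f)[:-3]; os.path.splitext does the same job without hard-coding a three-character extension. A hedged sketch of the equivalent naming (diff_name is hypothetical):

import os

def diff_name(path, workdir, cnt):
    # 'ocean.nc' -> ('ocean', '.nc'); works for any extension length
    stem, _ext = os.path.splitext(os.path.basename(path))
    return os.path.join(workdir, "difffiles", stem + "_diff_" + str(cnt).zfill(5) + ".nc")

# diff_name("/runs/ocean.nc", "/tmp/mkmov", 3) -> "/tmp/mkmov/difffiles/ocean_diff_00003.nc"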

Example 11

Project: autospec
Source File: tarball.py
View license
def download_tarball(url_argument, name_argument, archives, target_dir):
    global name
    global rawname
    global version
    global url
    global path
    global tarball_prefix
    global gcov_file
    # go naming
    global golibpath
    global go_pkgname

    url = url_argument
    tarfile = os.path.basename(url)
    pattern_options = [
        r"(.*?)[\-_](v*[0-9]+[alpha\+_spbfourcesigedsvstableP0-9\.\-\~]*)\.src\.(tgz|tar|zip)",
        r"(.*?)[\-_](v*[0-9]+[alpha\+_sbpfourcesigedsvstableP0-9\.\-\~]*)\.(tgz|tar|zip)",
        r"(.*?)[\-_](v*[0-9]+[a-zalpha\+_spbfourcesigedsvstableP0-9\.\-\~]*)\.orig\.tar",
        r"(.*?)[\-_](v*[0-9]+[\+_spbfourcesigedsvstableP0-9\.\~]*)(-.*?)?\.tar",
    ]
    for pattern in pattern_options:
        p = re.compile(pattern)
        m = p.search(tarfile)
        if m:
            name = m.group(1).strip()
            version = m.group(2).strip()
            b = version.find("-")
            if b >= 0:
                version = version[:b]
            break

    rawname = name
    # R package
    if url_argument.find("cran.r-project.org") > 0 or url_argument.find("cran.rstudio.com") > 0:
        buildpattern.set_build_pattern("R", 10)
        files.want_dev_split = 0
        buildreq.add_buildreq("clr-R-helpers")
        p = re.compile(r"([A-Za-z0-9]+)_(v*[0-9]+[\+_spbfourcesigedsvstableP0-9\.\~\-]*)\.tar\.gz")
        m = p.search(tarfile)
        if m:
            name = "R-" + m.group(1).strip()
            rawname = m.group(1).strip()
            version = m.group(2).strip()
            b = version.find("-")
            if b >= 0:
                version = version[:b]

    if url_argument.find("pypi.python.org") > 0:
        buildpattern.set_build_pattern("distutils", 10)
        url_argument = "http://pypi.debian.net/" + name + "/" + tarfile
    if url_argument.find("pypi.debian.net") > 0:
        buildpattern.set_build_pattern("distutils", 10)

    if url_argument.find(".cpan.org/CPAN/") > 0:
        buildpattern.set_build_pattern("cpan", 10)
        if name:
            name = "perl-" + name
    if url_argument.find(".metacpan.org/") > 0:
        buildpattern.set_build_pattern("cpan", 10)
        if name:
            name = "perl-" + name

    if "github.com" in url_argument:
        # golibpath = golang_libpath(url_argument)
        # go_pkgname = golang_name(url_argument)
        # define regex accepted for valid packages
        github_patterns = [r"https://github.com/.*/(.*?)/archive/(.*)-final.tar",
                           r"https://github.com/.*/.*/archive/[0-9a-fA-F]{1,40}\/(.*)\-(.*).tar",
                           r"https://github.com/.*/(.*?)/archive/(.*).zip",
                           r"https://github.com/.*/(.*?)/archive/v?(.*).tar"]

        for pattern in github_patterns:
            p = re.compile(pattern)
            m = p.search(url_argument)
            if m:
                name = m.group(1).strip()
                version = m.group(2).strip()
                b = version.find("-")
                if b > 0:
                    version = version[:b]
                break

    if url_argument.find("bitbucket.org") > 0:
        p = re.compile(r"https://bitbucket.org/.*/(.*?)/get/[a-zA-Z_-]*([0-9][0-9_.]*).tar")
        m = p.search(url_argument)
        if m:
            name = m.group(1).strip()
            version = m.group(2).strip().replace('_', '.')
        else:
            version = "1"

    # ruby
    if url_argument.find("rubygems.org/") > 0:
        buildpattern.set_build_pattern("ruby", 10)
        p = re.compile(r"(.*?)[\-_](v*[0-9]+[alpha\+_spbfourcesigedsvstableP0-9\.\-\~]*)\.gem")
        m = p.search(tarfile)
        if m:
            buildreq.add_buildreq("ruby")
            buildreq.add_buildreq("rubygem-rdoc")
            name = "rubygem-" + m.group(1).strip()
            rawname = m.group(1).strip()
            version = m.group(2).strip()
            b = version.find("-")
            if b >= 0:
                version = version[:b]

    # override from commandline
    if name_argument and name_argument[0] != name:
        pattern = name_argument[0] + r"[\-]*(.*)\.(tgz|tar|zip)"
        p = re.compile(pattern)
        m = p.search(tarfile)
        if m:
            name = name_argument[0]
            rawname = name
            version = m.group(1).strip()
            b = version.find("-")
            if b >= 0 and version.find("-beta") < 0:
                version = version[:b]
            if version.startswith('.'):
                version = version[1:]
        else:
            name = name_argument[0]

    if not name:
        split = url_argument.split('/')
        if len(split) > 3 and split[-2] in ('archive', 'tarball'):
            name = split[-3]
            version = split[-1]
            if version.startswith('v'):
                version = version[1:]
            # remove extension
            version = '.'.join(version.split('.')[:-1])
            if version.endswith('.tar'):
                version = '.'.join(version.split('.')[:-1])

    b = version.find("-")
    if b >= 0 and version.find("-beta") < 0:
        b = b + 1
        version = version[b:]

    if len(version) > 0 and version[0] in ['v', 'r']:
        version = version[1:]

    assert name != ""

    if not target_dir:
        build.download_path = os.getcwd() + "/" + name
    else:
        build.download_path = target_dir
    call("mkdir -p %s" % build.download_path)

    gcov_path = build.download_path + "/" + name + ".gcov"
    if os.path.isfile(gcov_path):
        gcov_file = name + ".gcov"

    tarball_path = check_or_get_file(url, tarfile)
    sha1 = get_sha1sum(tarball_path)
    with open(build.download_path + "/upstream", "w") as file:
        file.write(sha1 + "/" + tarfile + "\n")

    tarball_prefix = name + "-" + version
    if tarfile.lower().endswith('.zip'):
        tarball_contents = subprocess.check_output(
            ["unzip", "-l", tarball_path], universal_newlines=True)
        if tarball_contents and len(tarball_contents.splitlines()) > 3:
            tarball_prefix = tarball_contents.splitlines()[3].rsplit("/")[0].split()[-1]
        extract_cmd = "unzip -d {0} {1}".format(build.base_path, tarball_path)

    elif tarfile.lower().endswith('.gem'):
        tarball_contents = subprocess.check_output(
            ["gem", "unpack", "--verbose", tarball_path], universal_newlines=True)
        extract_cmd = "gem unpack --target={0} {1}".format(build.base_path, tarball_path)
        if tarball_contents:
            tarball_prefix = tarball_contents.splitlines()[-1].rsplit("/")[-1]
            if tarball_prefix.endswith("'"):
                tarball_prefix = tarball_prefix[:-1]
    else:
        extract_cmd, tarball_prefix = build_untar(tarball_path)

    if version == "":
        version = "1"

    print("\n")

    print("Processing", url_argument)
    print(
        "=============================================================================================")
    print("Name        :", name)
    print("Version     :", version)
    print("Prefix      :", tarball_prefix)

    with open(build.download_path + "/Makefile", "w") as file:
        file.write("PKG_NAME := " + name + "\n")
        file.write("URL := " + url_argument + "\n")
        file.write("ARCHIVES :=")
        for archive in archives:
            file.write(" {}".format(archive))
        file.write("\n")
        file.write("\n")
        file.write("include ../common/Makefile.common\n")

    shutil.rmtree("{}".format(build.base_path), ignore_errors=True)
    os.makedirs("{}".format(build.output_path))
    call("mkdir -p %s" % build.download_path)
    call(extract_cmd)

    path = build.base_path + tarball_prefix

    for archive, destination in zip(archives[::2], archives[1::2]):
        source_tarball_path = check_or_get_file(archive, os.path.basename(archive))
        if source_tarball_path.lower().endswith('.zip'):
            tarball_contents = subprocess.check_output(
                ["unzip", "-l", source_tarball_path], universal_newlines=True)
            if tarball_contents and len(tarball_contents.splitlines()) > 3:
                source_tarball_prefix = tarball_contents.splitlines()[3].rsplit("/")[0].split()[-1]
            extract_cmd = "unzip -d {0} {1}".format(build.base_path, source_tarball_path)
        else:
            extract_cmd, source_tarball_prefix = build_untar(source_tarball_path)
        buildpattern.archive_details[archive + "prefix"] = source_tarball_prefix
        call(extract_cmd)
        tar_files = glob.glob("{0}{1}/*".format(build.base_path, source_tarball_prefix))
        move_cmd = "mv "
        for tar_file in tar_files:
            move_cmd += tar_file + " "
        move_cmd += '{0}/{1}'.format(path, destination)

        mkdir_cmd = "mkdir -p "
        mkdir_cmd += '{0}/{1}'.format(path, destination)

        print("mkdir " + mkdir_cmd)
        call(mkdir_cmd)
        call(move_cmd)

        sha1 = get_sha1sum(source_tarball_path)
        with open(build.download_path + "/upstream", "a") as file:
            file.write(sha1 + "/" + os.path.basename(archive) + "\n")

Example 12

Project: nzbToMedia
Source File: nzbToMedia.py
View license
def main(args, section=None):
    # Initialize the config
    core.initialize(section)

    logger.info("#########################################################")
    logger.info("## ..::[{0}]::.. ##".format(os.path.basename(__file__)))
    logger.info("#########################################################")

    # debug command line options
    logger.debug("Options passed into nzbToMedia: {0}".format(args))

    # Post-Processing Result
    result = [0, ""]
    status = 0

    # NZBGet
    if 'NZBOP_SCRIPTDIR' in os.environ:
        # Check if the script is called from nzbget 11.0 or later
        if os.environ['NZBOP_VERSION'][0:5] < '11.0':
            logger.error("NZBGet Version {0} is not supported. Please update NZBGet.".format(os.environ['NZBOP_VERSION']))
            sys.exit(core.NZBGET_POSTPROCESS_ERROR)

        logger.info("Script triggered from NZBGet Version {0}.".format(os.environ['NZBOP_VERSION']))

        # Check if the script is called from nzbget 13.0 or later
        if 'NZBPP_TOTALSTATUS' in os.environ:
            if not os.environ['NZBPP_TOTALSTATUS'] == 'SUCCESS':
                logger.info("Download failed with status {0}.".format(os.environ['NZBPP_STATUS']))
                status = 1

        else:
            # Check par status
            if os.environ['NZBPP_PARSTATUS'] == '1' or os.environ['NZBPP_PARSTATUS'] == '4':
                logger.warning("Par-repair failed, setting status \"failed\"")
                status = 1

            # Check unpack status
            if os.environ['NZBPP_UNPACKSTATUS'] == '1':
                logger.warning("Unpack failed, setting status \"failed\"")
                status = 1

            if os.environ['NZBPP_UNPACKSTATUS'] == '0' and os.environ['NZBPP_PARSTATUS'] == '0':
                # Unpack was skipped due to nzb-file properties or due to errors during par-check

                if int(os.environ['NZBPP_HEALTH']) < 1000:
                    logger.warning(
                        "Download health is compromised and Par-check/repair disabled or no .par2 files found. Setting status \"failed\"")
                    logger.info("Please check your Par-check/repair settings for future downloads.")
                    status = 1

                else:
                    logger.info(
                        "Par-check/repair disabled or no .par2 files found, and Unpack not required. Health is ok so handle as though download successful")
                    logger.info("Please check your Par-check/repair settings for future downloads.")

        # Check for download_id to pass to CouchPotato
        download_id = ""
        failureLink = None
        if 'NZBPR_COUCHPOTATO' in os.environ:
            download_id = os.environ['NZBPR_COUCHPOTATO']
        elif 'NZBPR_DRONE' in os.environ:
            download_id = os.environ['NZBPR_DRONE']
        elif 'NZBPR_SONARR' in os.environ:
            download_id = os.environ['NZBPR_SONARR']
        if 'NZBPR__DNZB_FAILURE' in os.environ:
            failureLink = os.environ['NZBPR__DNZB_FAILURE']

        # All checks done, now launching the script.
        clientAgent = 'nzbget'
        result = process(os.environ['NZBPP_DIRECTORY'], inputName=os.environ['NZBPP_NZBNAME'], status=status,
                         clientAgent=clientAgent, download_id=download_id, inputCategory=os.environ['NZBPP_CATEGORY'],
                         failureLink=failureLink)
    # SABnzbd Pre 0.7.17
    elif len(args) == core.SABNZB_NO_OF_ARGUMENTS:
        # SABnzbd argv:
        # 1 The final directory of the job (full path)
        # 2 The original name of the NZB file
        # 3 Clean version of the job name (no path info and ".nzb" removed)
        # 4 Indexer's report number (if supported)
        # 5 User-defined category
        # 6 Group that the NZB was posted in e.g. alt.binaries.x
        # 7 Status of post processing. 0 = OK, 1=failed verification, 2=failed unpack, 3=1+2
        clientAgent = 'sabnzbd'
        logger.info("Script triggered from SABnzbd")
        result = process(args[1], inputName=args[2], status=args[7], inputCategory=args[5], clientAgent=clientAgent,
                         download_id='')
    # SABnzbd 0.7.17+
    elif len(args) >= core.SABNZB_0717_NO_OF_ARGUMENTS:
        # SABnzbd argv:
        # 1 The final directory of the job (full path)
        # 2 The original name of the NZB file
        # 3 Clean version of the job name (no path info and ".nzb" removed)
        # 4 Indexer's report number (if supported)
        # 5 User-defined category
        # 6 Group that the NZB was posted in e.g. alt.binaries.x
        # 7 Status of post processing. 0 = OK, 1=failed verification, 2=failed unpack, 3=1+2
        # 8 Failure URL
        clientAgent = 'sabnzbd'
        logger.info("Script triggered from SABnzbd 0.7.17+")
        result = process(args[1], inputName=args[2], status=args[7], inputCategory=args[5], clientAgent=clientAgent,
                         download_id='', failureLink=''.join(args[8:]))
    # Generic program
    elif len(args) > 5 and args[5] == 'generic':
        logger.info("Script triggered from generic program")
        result = process(args[1], inputName=args[2], inputCategory=args[3], download_id=args[4])
    else:
        # Perform Manual Post-Processing
        logger.warning("Invalid number of arguments received from client, Switching to manual run mode ...")

        for section, subsections in core.SECTIONS.items():
            for subsection in subsections:
                if not core.CFG[section][subsection].isenabled():
                    continue
                for dirName in getDirs(section, subsection, link='move'):
                    logger.info("Starting manual run for {0}:{1} - Folder: {2}".format(section, subsection, dirName))
                    logger.info("Checking database for download info for {0} ...".format(os.path.basename(dirName)))

                    core.DOWNLOADINFO = get_downloadInfo(os.path.basename(dirName), 0)
                    if core.DOWNLOADINFO:
                        logger.info("Found download info for {0}, "
                                    "setting variables now ...".format
                                    (os.path.basename(dirName)))
                        clientAgent = text_type(core.DOWNLOADINFO[0].get('client_agent', 'manual'))
                        download_id = text_type(core.DOWNLOADINFO[0].get('input_id', ''))
                    else:
                        logger.info('Unable to locate download info for {0}, '
                                    'continuing to try and process this release ...'.format
                                    (os.path.basename(dirName)))
                        clientAgent = 'manual'
                        download_id = ''

                    if clientAgent and clientAgent.lower() not in core.NZB_CLIENTS:
                        continue

                    try:
                        dirName = dirName.encode(core.SYS_ENCODING)
                    except UnicodeError:
                        pass
                    inputName = os.path.basename(dirName)
                    try:
                        inputName = inputName.encode(core.SYS_ENCODING)
                    except UnicodeError:
                        pass

                    results = process(dirName, inputName, 0, clientAgent=clientAgent,
                                      download_id=download_id or None, inputCategory=subsection)
                    if results[0] != 0:
                        logger.error("A problem was reported when trying to perform a manual run for {0}:{1}.".format
                                     (section, subsection))
                        result = results

    if result[0] == 0:
        logger.info("The {0} script completed successfully.".format(args[0]))
        if result[1]:
            print(result[1] + "!")
        if 'NZBOP_SCRIPTDIR' in os.environ:  # return code for nzbget v11
            del core.MYAPP
            return core.NZBGET_POSTPROCESS_SUCCESS
    else:
        logger.error("A problem was reported in the {0} script.".format(args[0]))
        if result[1]:
            print(result[1] + "!")
        if 'NZBOP_SCRIPTDIR' in os.environ:  # return code for nzbget v11
            del core.MYAPP
            return core.NZBGET_POSTPROCESS_ERROR
    del core.MYAPP
    return result[0]
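
In the manual-run branch the release name is just os.path.basename(dirName); one caveat is that basename() returns an empty string when the path ends in a separator, which os.path.normpath avoids. A minimal sketch (release_name is a hypothetical helper, not part of nzbToMedia):

import os

def release_name(dir_name):
    # normpath drops any trailing separator so basename() never returns ''
    return os.path.basename(os.path.normpath(dir_name))

# release_name("/downloads/Show.S01E01/") -> "Show.S01E01"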

Example 13

Project: ck-env
Source File: customize.py
View license
def setup(i):
    """
    Input:  {
              cfg              - meta of this soft entry
              self_cfg         - meta of module soft
              ck_kernel        - import CK kernel module (to reuse functions)

              host_os_uoa      - host OS UOA
              host_os_uid      - host OS UID
              host_os_dict     - host OS meta

              target_os_uoa    - target OS UOA
              target_os_uid    - target OS UID
              target_os_dict   - target OS meta

              target_device_id - target device ID (if via ADB)

              tags             - list of tags used to search this entry

              env              - updated environment vars from meta
              customize        - updated customize vars from meta

              deps             - resolved dependencies for this soft

              interactive      - if 'yes', can ask questions, otherwise quiet
            }

    Output: {
              return       - return code =  0, if successful
                                         >  0, if error
              (error)      - error text if return > 0

              bat          - prepared string for bat file
            }

    """

    import os

    # Get variables
    ck=i['ck_kernel']
    s=''

    iv=i.get('interactive','')

    env=i.get('env',{})
    cfg=i.get('cfg',{})
    deps=i.get('deps',{})
    tags=i.get('tags',[])
    cus=i.get('customize',{})

    target_d=i.get('target_os_dict',{})
    win=target_d.get('windows_base','')
    mic=target_d.get('intel_mic','')
    remote=target_d.get('remote','')
    mingw=target_d.get('mingw','')
    tbits=target_d.get('bits','')

    ep=cus.get('env_prefix','')
    pi=cus.get('path_install','')

    host_d=i.get('host_os_dict',{})
    tosd=i.get('target_os_dict',{})
    sdirs=host_d.get('dir_sep','')
    tsep=tosd.get('dir_sep','')

    fp=cus.get('full_path','')
    if win=='yes':
       f1='tbb.lib'
       f2='tbbmalloc.lib'
       f3='tbbproxy.lib'
       f1d='tbb.dll'
       f2d='tbbmalloc.dll'
       f3d='tbbproxy.dll'
    else:
       f1=''
       f2=''
       f3=''
       f1d='libtbb.so'
       f2d='libtbbmalloc.so'
       f3d='libproxy.so'

    if fp.find('lib_release')>0:
       lib=os.path.basename(fp)

       p1=os.path.dirname(fp)
       pi=os.path.dirname(p1)

       cus['path_lib']=p1
       cus['path_include']=pi+sdirs+'include'

       cus['path_dynamic_lib']=p1

       cus['dynamic_lib']=f1d
       cus['extra_dynamic_libs']={'libtbbmalloc':f2d,
                                  'libtbbproxy':f3d}

    else:
       lib=os.path.basename(fp)
       pl=os.path.dirname(fp)

       pi=pl
       found=False
       sd=[]
       while True:
          if os.path.isdir(os.path.join(pi,'include')):
             found=True

             break
          pix=os.path.dirname(pi)
          if pix==pi:
             break
          sd.append(os.path.basename(pi))
          pi=pix

       if not found:
          return {'return':1, 'error':'can\'t find root dir of the TBB installation'}

       if win=='yes':
          px=os.path.join(pi,'bin')
          if len(sd)>1:
             px=os.path.join(px,sd[1],sd[0])
          cus['path_bin']=px
       else:
          cus['path_bin']=pl

       cus['path_lib']=pl
       cus['path_include']=pi+tsep+'include'

       cus['path_dynamic_lib']=pl
       cus['dynamic_lib']=f1d
       cus['extra_dynamic_libs']={'libtbbmalloc':f2d,
                                  'libtbbproxy':f3d}

       cus['static_lib']=f1
       cus['extra_static_libs']={'libtbbmalloc':f2,
                                 'libtbbproxy':f3}



       env[ep+'_STATIC_NAME']=f1
       env[ep+'_DYNAMIC_NAME']=f1d

       env[ep+'_MALLOC_STATIC_NAME']=f2
       env[ep+'_MALLOC_DYNAMIC_NAME']=f2d

       env[ep+'_PROXY_STATIC_NAME']=f3
       env[ep+'_PROXY_DYNAMIC_NAME']=f3d

    return {'return':0, 'bat':s}
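
The directory walk above pairs os.path.dirname (to climb toward the root) with os.path.basename (to record each skipped component in sd); the loop ends when dirname() returns its own input, which only happens at the filesystem root. A condensed sketch of that search, assuming the same 'include'-directory convention (find_install_root is hypothetical):

import os

def find_install_root(full_path):
    # Climb from the library file toward '/', remembering skipped components.
    skipped = []
    cur = os.path.dirname(full_path)
    while True:
        if os.path.isdir(os.path.join(cur, "include")):
            return cur, skipped
        parent = os.path.dirname(cur)
        if parent == cur:  # dirname('/') == '/', i.e. the filesystem root
            return None, skipped
        skipped.append(os.path.basename(cur))
        cur = parent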

Example 14

Project: letsencrypt-nosudo
Source File: sign_csr.py
View license
def sign_csr(pubkey, csr, email=None, file_based=False):
    """Use the ACME protocol to get an ssl certificate signed by a
    certificate authority.

    :param string pubkey: Path to the user account public key.
    :param string csr: Path to the certificate signing request.
    :param string email: An optional user account contact email
                         (defaults to webmaster@<shortest_domain>)
    :param bool file_based: An optional flag indicating that the
                            hosting should be file-based rather
                            than providing a simple python HTTP
                            server.

    :returns: Signed Certificate (PEM format)
    :rtype: string

    """
    #CA = "https://acme-staging.api.letsencrypt.org"
    CA = "https://acme-v01.api.letsencrypt.org"
    TERMS = "https://letsencrypt.org/documents/LE-SA-v1.1.1-August-1-2016.pdf"
    nonce_req = urllib2.Request("{0}/directory".format(CA))
    nonce_req.get_method = lambda : 'HEAD'

    def _b64(b):
        "Shortcut function to go from bytes to jwt base64 string"
        return base64.urlsafe_b64encode(b).replace("=", "")

    # Step 1: Get account public key
    sys.stderr.write("Reading pubkey file...\n")
    proc = subprocess.Popen(["openssl", "rsa", "-pubin", "-in", pubkey, "-noout", "-text"],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = proc.communicate()
    if proc.returncode != 0:
        raise IOError("Error loading {0}".format(pubkey))
    pub_hex, pub_exp = re.search(
        "Modulus(?: \((?:2048|4096) bit\)|)\:\s+00:([a-f0-9\:\s]+?)Exponent\: ([0-9]+)",
        out, re.MULTILINE|re.DOTALL).groups()
    pub_mod = binascii.unhexlify(re.sub("(\s|:)", "", pub_hex))
    pub_mod64 = _b64(pub_mod)
    pub_exp = int(pub_exp)
    pub_exp = "{0:x}".format(pub_exp)
    pub_exp = "0{0}".format(pub_exp) if len(pub_exp) % 2 else pub_exp
    pub_exp = binascii.unhexlify(pub_exp)
    pub_exp64 = _b64(pub_exp)
    header = {
        "alg": "RS256",
        "jwk": {
            "e": pub_exp64,
            "kty": "RSA",
            "n": pub_mod64,
        },
    }
    accountkey_json = json.dumps(header['jwk'], sort_keys=True, separators=(',', ':'))
    thumbprint = _b64(hashlib.sha256(accountkey_json).digest())
    sys.stderr.write("Found public key!\n")

    # Step 2: Get the domain names to be certified
    sys.stderr.write("Reading csr file...\n")
    proc = subprocess.Popen(["openssl", "req", "-in", csr, "-noout", "-text"],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = proc.communicate()
    if proc.returncode != 0:
        raise IOError("Error loading {0}".format(csr))
    domains = set([])
    common_name = re.search("Subject:.*? CN=([^\s,;/]+)", out)
    if common_name is not None:
        domains.add(common_name.group(1))
    subject_alt_names = re.search("X509v3 Subject Alternative Name: \n +([^\n]+)\n", out, re.MULTILINE|re.DOTALL)
    if subject_alt_names is not None:
        for san in subject_alt_names.group(1).split(", "):
            if san.startswith("DNS:"):
                domains.add(san[4:])
    sys.stderr.write("Found domains {0}\n".format(", ".join(domains)))

    # Step 3: Ask user for contact email
    if not email:
        default_email = "[email protected]{0}".format(min(domains, key=len))
        stdout = sys.stdout
        sys.stdout = sys.stderr
        input_email = raw_input("STEP 1: What is your contact email? ({0}) ".format(default_email))
        email = input_email if input_email else default_email
        sys.stdout = stdout

    # Step 4: Generate the payloads that need to be signed
    # registration
    sys.stderr.write("Building request payloads...\n")
    reg_nonce = urllib2.urlopen(nonce_req).headers['Replay-Nonce']
    reg_raw = json.dumps({
        "resource": "new-reg",
        "contact": ["mailto:{0}".format(email)],
        "agreement": TERMS,
    }, sort_keys=True, indent=4)
    reg_b64 = _b64(reg_raw)
    reg_protected = copy.deepcopy(header)
    reg_protected.update({"nonce": reg_nonce})
    reg_protected64 = _b64(json.dumps(reg_protected, sort_keys=True, indent=4))
    reg_file = tempfile.NamedTemporaryFile(dir=".", prefix="register_", suffix=".json")
    reg_file.write("{0}.{1}".format(reg_protected64, reg_b64))
    reg_file.flush()
    reg_file_name = os.path.basename(reg_file.name)
    reg_file_sig = tempfile.NamedTemporaryFile(dir=".", prefix="register_", suffix=".sig")
    reg_file_sig_name = os.path.basename(reg_file_sig.name)
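    # NOTE: the temp files are created with dir=".", so os.path.basename() yields the
    # bare local filenames the user can paste into the openssl commands printed below.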

    # need a signature for each domain identifier
    ids = []
    for domain in domains:
        sys.stderr.write("Building request for {0}...\n".format(domain))
        id_nonce = urllib2.urlopen(nonce_req).headers['Replay-Nonce']
        id_raw = json.dumps({
            "resource": "new-authz",
            "identifier": {
                "type": "dns",
                "value": domain,
            },
        }, sort_keys=True)
        id_b64 = _b64(id_raw)
        id_protected = copy.deepcopy(header)
        id_protected.update({"nonce": id_nonce})
        id_protected64 = _b64(json.dumps(id_protected, sort_keys=True, indent=4))
        id_file = tempfile.NamedTemporaryFile(dir=".", prefix="domain_", suffix=".json")
        id_file.write("{0}.{1}".format(id_protected64, id_b64))
        id_file.flush()
        id_file_name = os.path.basename(id_file.name)
        id_file_sig = tempfile.NamedTemporaryFile(dir=".", prefix="domain_", suffix=".sig")
        id_file_sig_name = os.path.basename(id_file_sig.name)
        ids.append({
            "domain": domain,
            "protected64": id_protected64,
            "data64": id_b64,
            "file": id_file,
            "file_name": id_file_name,
            "sig": id_file_sig,
            "sig_name": id_file_sig_name,
        })

    # need signature for the final certificate issuance
    sys.stderr.write("Building request for CSR...\n")
    proc = subprocess.Popen(["openssl", "req", "-in", csr, "-outform", "DER"],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    csr_der, err = proc.communicate()
    csr_der64 = _b64(csr_der)
    csr_nonce = urllib2.urlopen(nonce_req).headers['Replay-Nonce']
    csr_raw = json.dumps({
        "resource": "new-cert",
        "csr": csr_der64,
    }, sort_keys=True, indent=4)
    csr_b64 = _b64(csr_raw)
    csr_protected = copy.deepcopy(header)
    csr_protected.update({"nonce": csr_nonce})
    csr_protected64 = _b64(json.dumps(csr_protected, sort_keys=True, indent=4))
    csr_file = tempfile.NamedTemporaryFile(dir=".", prefix="cert_", suffix=".json")
    csr_file.write("{0}.{1}".format(csr_protected64, csr_b64))
    csr_file.flush()
    csr_file_name = os.path.basename(csr_file.name)
    csr_file_sig = tempfile.NamedTemporaryFile(dir=".", prefix="cert_", suffix=".sig")
    csr_file_sig_name = os.path.basename(csr_file_sig.name)

    # Step 5: Ask the user to sign the registration and requests
    sys.stderr.write("""\
STEP 2: You need to sign some files (replace 'user.key' with your user private key).

openssl dgst -sha256 -sign user.key -out {0} {1}
{2}
openssl dgst -sha256 -sign user.key -out {3} {4}

""".format(
    reg_file_sig_name, reg_file_name,
    "\n".join("openssl dgst -sha256 -sign user.key -out {0} {1}".format(i['sig_name'], i['file_name']) for i in ids),
    csr_file_sig_name, csr_file_name))

    stdout = sys.stdout
    sys.stdout = sys.stderr
    raw_input("Press Enter when you've run the above commands in a new terminal window...")
    sys.stdout = stdout

    # Step 6: Load the signatures
    reg_file_sig.seek(0)
    reg_sig64 = _b64(reg_file_sig.read())
    for n, i in enumerate(ids):
        i['sig'].seek(0)
        i['sig64'] = _b64(i['sig'].read())

    # Step 7: Register the user
    sys.stderr.write("Registering {0}...\n".format(email))
    reg_data = json.dumps({
        "header": header,
        "protected": reg_protected64,
        "payload": reg_b64,
        "signature": reg_sig64,
    }, sort_keys=True, indent=4)
    reg_url = "{0}/acme/new-reg".format(CA)
    try:
        resp = urllib2.urlopen(reg_url, reg_data)
        result = json.loads(resp.read())
    except urllib2.HTTPError as e:
        err = e.read()
        # skip already registered accounts
        if "Registration key is already in use" in err:
            sys.stderr.write("Already registered. Skipping...\n")
        else:
            sys.stderr.write("Error: reg_data:\n")
            sys.stderr.write("POST {0}\n".format(reg_url))
            sys.stderr.write(reg_data)
            sys.stderr.write("\n")
            sys.stderr.write(err)
            sys.stderr.write("\n")
            raise

    # Step 8: Request challenges for each domain
    responses = []
    tests = []
    for n, i in enumerate(ids):
        sys.stderr.write("Requesting challenges for {0}...\n".format(i['domain']))
        id_data = json.dumps({
            "header": header,
            "protected": i['protected64'],
            "payload": i['data64'],
            "signature": i['sig64'],
        }, sort_keys=True, indent=4)
        id_url = "{0}/acme/new-authz".format(CA)
        try:
            resp = urllib2.urlopen(id_url, id_data)
            result = json.loads(resp.read())
        except urllib2.HTTPError as e:
            sys.stderr.write("Error: id_data:\n")
            sys.stderr.write("POST {0}\n".format(id_url))
            sys.stderr.write(id_data)
            sys.stderr.write("\n")
            sys.stderr.write(e.read())
            sys.stderr.write("\n")
            raise
        challenge = [c for c in result['challenges'] if c['type'] == "http-01"][0]
        keyauthorization = "{0}.{1}".format(challenge['token'], thumbprint)

        # challenge request
        sys.stderr.write("Building challenge responses for {0}...\n".format(i['domain']))
        test_nonce = urllib2.urlopen(nonce_req).headers['Replay-Nonce']
        test_raw = json.dumps({
            "resource": "challenge",
            "keyAuthorization": keyauthorization,
        }, sort_keys=True, indent=4)
        test_b64 = _b64(test_raw)
        test_protected = copy.deepcopy(header)
        test_protected.update({"nonce": test_nonce})
        test_protected64 = _b64(json.dumps(test_protected, sort_keys=True, indent=4))
        test_file = tempfile.NamedTemporaryFile(dir=".", prefix="challenge_", suffix=".json")
        test_file.write("{0}.{1}".format(test_protected64, test_b64))
        test_file.flush()
        test_file_name = os.path.basename(test_file.name)
        test_file_sig = tempfile.NamedTemporaryFile(dir=".", prefix="challenge_", suffix=".sig")
        test_file_sig_name = os.path.basename(test_file_sig.name)
        tests.append({
            "uri": challenge['uri'],
            "protected64": test_protected64,
            "data64": test_b64,
            "file": test_file,
            "file_name": test_file_name,
            "sig": test_file_sig,
            "sig_name": test_file_sig_name,
        })

        # challenge response for server
        responses.append({
            "uri": ".well-known/acme-challenge/{0}".format(challenge['token']),
            "data": keyauthorization,
        })

    # Step 9: Ask the user to sign the challenge responses
    sys.stderr.write("""\
STEP 3: You need to sign some more files (replace 'user.key' with your user private key).

{0}

""".format(
    "\n".join("openssl dgst -sha256 -sign user.key -out {0} {1}".format(
        i['sig_name'], i['file_name']) for i in tests)))

    stdout = sys.stdout
    sys.stdout = sys.stderr
    raw_input("Press Enter when you've run the above commands in a new terminal window...")
    sys.stdout = stdout

    # Step 10: Load the response signatures
    for n, i in enumerate(ids):
        tests[n]['sig'].seek(0)
        tests[n]['sig64'] = _b64(tests[n]['sig'].read())

    # Step 11: Ask the user to host the token on their server
    for n, i in enumerate(ids):
        if file_based:
            sys.stderr.write("""\
STEP {0}: Please update your server to serve the following file at this URL:

--------------
URL: http://{1}/{2}
File contents: \"{3}\"
--------------

Notes:
- Do not include the quotes in the file.
- The file should be one line without any spaces.

""".format(n + 4, i['domain'], responses[n]['uri'], responses[n]['data']))

            stdout = sys.stdout
            sys.stdout = sys.stderr
            raw_input("Press Enter when you've got the file hosted on your server...")
            sys.stdout = stdout
        else:
            sys.stderr.write("""\
STEP {0}: You need to run this command on {1} (don't stop the python command until the next step).

sudo python -c "import BaseHTTPServer; \\
    h = BaseHTTPServer.BaseHTTPRequestHandler; \\
    h.do_GET = lambda r: r.send_response(200) or r.end_headers() or r.wfile.write('{2}'); \\
    s = BaseHTTPServer.HTTPServer(('0.0.0.0', 80), h); \\
    s.serve_forever()"

""".format(n + 4, i['domain'], responses[n]['data']))

            stdout = sys.stdout
            sys.stdout = sys.stderr
            raw_input("Press Enter when you've got the python command running on your server...")
            sys.stdout = stdout

        # Step 12: Let the CA know you're ready for the challenge
        sys.stderr.write("Requesting verification for {0}...\n".format(i['domain']))
        test_data = json.dumps({
            "header": header,
            "protected": tests[n]['protected64'],
            "payload": tests[n]['data64'],
            "signature": tests[n]['sig64'],
        }, sort_keys=True, indent=4)
        test_url = tests[n]['uri']
        try:
            resp = urllib2.urlopen(test_url, test_data)
            test_result = json.loads(resp.read())
        except urllib2.HTTPError as e:
            sys.stderr.write("Error: test_data:\n")
            sys.stderr.write("POST {0}\n".format(test_url))
            sys.stderr.write(test_data)
            sys.stderr.write("\n")
            sys.stderr.write(e.read())
            sys.stderr.write("\n")
            raise

        # Step 13: Wait for CA to mark test as valid
        sys.stderr.write("Waiting for {0} challenge to pass...\n".format(i['domain']))
        while True:
            try:
                resp = urllib2.urlopen(test_url)
                challenge_status = json.loads(resp.read())
            except urllib2.HTTPError as e:
                sys.stderr.write("Error: test_data:\n")
                sys.stderr.write("GET {0}\n".format(test_url))
                sys.stderr.write(test_data)
                sys.stderr.write("\n")
                sys.stderr.write(e.read())
                sys.stderr.write("\n")
                raise
            if challenge_status['status'] == "pending":
                time.sleep(2)
            elif challenge_status['status'] == "valid":
                sys.stderr.write("Passed {0} challenge!\n".format(i['domain']))
                break
            else:
                raise KeyError("'{0}' challenge did not pass: {1}".format(i['domain'], challenge_status))

    # Step 14: Get the certificate signed
    sys.stderr.write("Requesting signature...\n")
    csr_file_sig.seek(0)
    csr_sig64 = _b64(csr_file_sig.read())
    csr_data = json.dumps({
        "header": header,
        "protected": csr_protected64,
        "payload": csr_b64,
        "signature": csr_sig64,
    }, sort_keys=True, indent=4)
    csr_url = "{0}/acme/new-cert".format(CA)
    try:
        resp = urllib2.urlopen(csr_url, csr_data)
        signed_der = resp.read()
    except urllib2.HTTPError as e:
        sys.stderr.write("Error: csr_data:\n")
        sys.stderr.write("POST {0}\n".format(csr_url))
        sys.stderr.write(csr_data)
        sys.stderr.write("\n")
        sys.stderr.write(e.read())
        sys.stderr.write("\n")
        raise

    # Step 15: Convert the signed cert from DER to PEM
    sys.stderr.write("Certificate signed!\n")

    if file_based:
        sys.stderr.write("You can remove the acme-challenge file from your webserver now.\n")
    else:
        sys.stderr.write("You can stop running the python command on your server (Ctrl+C works).\n")

    signed_der64 = base64.b64encode(signed_der)
    signed_pem = """\
-----BEGIN CERTIFICATE-----
{0}
-----END CERTIFICATE-----
""".format("\n".join(textwrap.wrap(signed_der64, 64)))

    return signed_pem
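
In this example os.path.basename trims the NamedTemporaryFile paths down to bare file names, so the openssl signing commands shown to the user reference the challenge_*.json and challenge_*.sig files in the current directory by name alone. A minimal sketch of that pattern (the prefix and suffix values mirror the example above):

import os
import tempfile

# NamedTemporaryFile(dir=".") yields a path inside the current directory;
# basename reduces it to the bare file name used in the openssl commands.
tmp = tempfile.NamedTemporaryFile(dir=".", prefix="challenge_", suffix=".json")
print(os.path.basename(tmp.name))  # e.g. "challenge_k2fj3a.json"
tmp.close()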

Example 15

View license
def main():
    ''' main '''
    #UPDATE
    # Create a feature layer from the input point features if it is not one already
    #df = arcpy.mapping.ListDataFrames(mxd)[0]
    pointFeatureName = os.path.basename(pointFeatures)
    layerExists = False

    try:
        # Check that area to number is a polygon
        descArea = arcpy.Describe(areaToNumber)
        areaGeom = descArea.shapeType
        arcpy.AddMessage("Shape type: " + str(areaGeom))
        if (descArea.shapeType != "Polygon"):
            raise Exception("ERROR: The area to number must be a polygon.")

        gisVersion = arcpy.GetInstallInfo()["Version"]
        global appEnvironment
        appEnvironment = Utilities.GetApplication()
        if DEBUG == True: arcpy.AddMessage("App environment: " + appEnvironment)

        global mxd
        global df
        global aprx
        global mp
        global mapList
        # mxd, df, aprx, mp = None, None, None, None
        #if gisVersion == "1.0": #Pro:
        if appEnvironment == "ARCGIS_PRO":
            from arcpy import mp
            aprx = arcpy.mp.ArcGISProject("CURRENT")
            mapList = aprx.listMaps()[0]
            for lyr in mapList.listLayers():
                if lyr.name == pointFeatureName:
                    layerExists = True
        #else:
        if appEnvironment == "ARCMAP":
            from arcpy import mapping
            mxd = arcpy.mapping.MapDocument('CURRENT')
            df = arcpy.mapping.ListDataFrames(mxd)[0]
            for lyr in arcpy.mapping.ListLayers(mxd):
                if lyr.name == pointFeatureName:
                    layerExists = True

        if layerExists == False:
            arcpy.MakeFeatureLayer_management(pointFeatures, pointFeatureName)
        else:
            pointFeatureName = pointFeatures

        # Select all the points that are inside of area
        arcpy.AddMessage("Selecting points from (" + str(os.path.basename(pointFeatureName)) +\
                         ") inside of the area (" + str(os.path.basename(areaToNumber)) + ")")
        selectionLayer = arcpy.SelectLayerByLocation_management(pointFeatureName, "INTERSECT",
                                                                areaToNumber, "#", "NEW_SELECTION")
        if DEBUG == True:
            arcpy.AddMessage("Selected " + str(arcpy.GetCount_management(pointFeatureName).getOutput(0)) + " points")

        # If no output FC is specified, set it to a temporary one; this will be copied to the input and then deleted.
        # Sort layer by upper right across and then down spatially,
        overwriteFC = False
        global outputFeatureClass
        if outputFeatureClass == "":
            outputFeatureClass = "tempSortedPoints"
            overwriteFC = True;
        arcpy.AddMessage("Sorting the selected points geographically, right to left, top to bottom")
        arcpy.Sort_management(selectionLayer, outputFeatureClass, [["Shape", "ASCENDING"]])


        # Number the fields
        arcpy.AddMessage("Numbering the fields")
        i = 1
        cursor = arcpy.UpdateCursor(outputFeatureClass)
        for row in cursor:
            row.setValue(numberingField, i)
            cursor.updateRow(row)
            i += 1


        # Clear the selection
        arcpy.AddMessage("Clearing the selection")
        arcpy.SelectLayerByAttribute_management(pointFeatureName, "CLEAR_SELECTION")


        # Overwrite the Input Point Features, and then delete the temporary output feature class
        targetLayerName = ""
        if (overwriteFC):
            arcpy.AddMessage("Copying the features to the input, and then deleting the temporary feature class")
            desc = arcpy.Describe(pointFeatureName)
            overwriteFC = os.path.join(os.path.sep, desc.path, pointFeatureName)
            fields = (numberingField, "SHAPE@")
            overwriteCursor = arcpy.da.UpdateCursor(overwriteFC, fields)
            for overwriteRow in overwriteCursor:
                sortedPointsCursor = arcpy.da.SearchCursor(outputFeatureClass, fields)
                for sortedRow in sortedPointsCursor:
                    if sortedRow[1].equals(overwriteRow[1]):
                        overwriteRow[0] = sortedRow[0]
                overwriteCursor.updateRow(overwriteRow)
            arcpy.Delete_management(outputFeatureClass)

            #UPDATE
            #if layerExists == False:
                #layerToAdd = arcpy.mapping.Layer(pointFeatureName)
                #arcpy.mapping.AddLayer(df, layerToAdd, "AUTO_ARRANGE")
            targetLayerName = pointFeatureName
        else:
            #UPDATE
            #layerToAdd = arcpy.mapping.Layer(outputFeatureClass)
            #arcpy.mapping.AddLayer(df, layerToAdd, "AUTO_ARRANGE")
            targetLayerName = os.path.basename(outputFeatureClass)


        # Get and label the output feature
        if appEnvironment == "ARCGIS_PRO":
            results = arcpy.MakeFeatureLayer_management(outputFeatureClass, targetLayerName).getOutput(0)
            mapList.addLayer(results, "AUTO_ARRANGE")
            layer = findLayerByName(targetLayerName)
            if(layer):
                labelFeatures(layer, numberingField)
        elif appEnvironment == "ARCMAP":
            arcpy.AddMessage("Adding features to map (" + str(targetLayerName) + ")...")
            arcpy.MakeFeatureLayer_management(outputFeatureClass, targetLayerName)
            layer = arcpy.mapping.Layer(targetLayerName)
            arcpy.mapping.AddLayer(df, layer, "AUTO_ARRANGE")
            arcpy.AddMessage("Labeling output features (" + str(targetLayerName) + ")...")
            layer = findLayerByName(targetLayerName)
            if (layer):
                labelFeatures(layer, numberingField)
        else:
            arcpy.AddMessage("Non-map application, skipping labeling...")


        arcpy.SetParameter(3, outputFeatureClass)


    except arcpy.ExecuteError:
        # Get the tool error messages
        msgs = arcpy.GetMessages()
        arcpy.AddError(msgs)
        print(msgs)

    except:
        # Get the traceback object
        tb = sys.exc_info()[2]
        tbinfo = traceback.format_tb(tb)[0]

        # Concatenate information together concerning the error into a message string
        pymsg = "PYTHON ERRORS:\nTraceback info:\n" + tbinfo + "\nError Info:\n" + str(sys.exc_info()[1])
        msgs = "ArcPy ERRORS:\n" + arcpy.GetMessages() + "\n"

        # Return python error messages for use in script tool or Python Window
        arcpy.AddError(pymsg)
        arcpy.AddError(msgs)

        # Print Python error messages for use in Python / Python Window
        print(pymsg + "\n")
        print(msgs)
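
The tool above uses os.path.basename twice: once to turn the full path of the input feature class into a layer name it can look up in the map, and again to shorten workspace paths in the progress messages. A minimal standalone sketch (the geodatabase paths are hypothetical):

import os

# basename yields the feature class / layer name from a full workspace path.
pointFeatures = "C:/data/project.gdb/ObservationPoints"
areaToNumber = "C:/data/project.gdb/AreaOfInterest"

pointFeatureName = os.path.basename(pointFeatures)
print(pointFeatureName)  # "ObservationPoints"

# The same call keeps the status messages short and readable.
msg = "Selecting points from (%s) inside of the area (%s)" % (
    os.path.basename(pointFeatures), os.path.basename(areaToNumber))
print(msg)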

Example 16

Project: C-PAC
Source File: extract_data.py
View license
def extract_data(c, param_map):
    """
    Method to generate a CPAC input subject list
    python file. The method extracts anatomical
    and functional data for each site( if multiple site)
    and/or scan parameters for each site and put it into
    a data structure read by python

    Example:
    subjects_list =[
       {
        'subject_id' : '0050386',
        'unique_id' : 'session_1',
        'anat': '/Users/home/data/NYU/0050386/session_1/anat_1/anat.nii.gz',
        'rest':{
            'rest_1_rest' : '/Users/home/data/NYU/0050386/session_1/rest_1/rest.nii.gz',
            'rest_2_rest' : '/Users/home/data/NYU/0050386/session_1/rest_2/rest.nii.gz',
            }
        'scan_parameters':{
            'tr': '2',
            'acquisition': 'alt+z2',
            'reference': '17',
            'first_tr': '',
            'last_tr': '',
            }
        },
    ]

    or

    subjects_list =[
       {
        'subject_id' : '0050386',
        'unique_id' : 'session_1',
        'anat': '/Users/home/data/NYU/0050386/session_1/anat_1/anat.nii.gz',
        'rest':{
            'rest_1_rest' : '/Users/home/data/NYU/0050386/session_1/rest_1/rest.nii.gz',
            'rest_2_rest' : '/Users/home/data/NYU/0050386/session_1/rest_2/rest.nii.gz',
            }
          },
    ]

    """

    #method to read each line of the file into list
    #returns list
    def get_list(arg):
        if isinstance(arg, list):
            ret_list = arg
        else:
            ret_list = [fline.rstrip('\r\n') for fline in open(arg, 'r').readlines()]

        return ret_list

    exclusion_list = []
    if c.exclusionSubjectList is not None:
        exclusion_list = get_list(c.exclusionSubjectList)

    subject_list = []
    if c.subjectList is not None:
        subject_list = get_list(c.subjectList)

    #check if Template is correct
    def checkTemplate(template):

        if template.count('%s') != 2:
            msg = "Please provide '%s' in the template" \
                  "where your site and subjects are present"\
                  "Please see examples"
            logging.exception(msg)
            raise Exception(msg)

        filename, ext = os.path.splitext(os.path.basename(template))
        ext = os.path.splitext(filename)[1] + ext

        if ext not in [".nii", ".nii.gz"]:
            msg = "Invalid file name", os.path.basename(template)
            logging.exception(msg)
            raise Exception(msg)

    def get_site_list(path):
        base, relative = path.split('%s')
        sites = os.listdir(base)
        return sites
    
    def check_length(scan_name, file_name):
        
        if len(file_name) > 30:
            msg = "filename- %s is too long."\
                   "It should not be more than 30 characters."%(file_name)
            logging.exception(msg)
            raise Exception(msg)
        
        if len(scan_name) - len(os.path.splitext(os.path.splitext(file_name)[0])[0]) >= 40:
            msg = "scan name %s is too long. "\
                  "It should not be more than 40 characters."\
                  % (scan_name.replace("_" + os.path.splitext(os.path.splitext(file_name)[0])[0], ''))
            logging.exception(msg)
            raise Exception(msg)

    def create_site_subject_mapping(base, relative):

        #mapping between site and subject
        site_subject_map = {}
        base_path_list = []

        if c.siteList is not None:
            site_list = get_list(c.siteList)
        else:
            site_list = get_site_list(base)

        for site in site_list:
            paths = glob.glob(string.replace(base, '%s', site))
            base_path_list.extend(paths)
            for path in paths:
                for sub in os.listdir(path):
                    #check if subject is present in subject_list
                    if subject_list:
                        if sub in subject_list and sub not in exclusion_list:
                            site_subject_map[sub] = site
                    elif sub not in exclusion_list:
                        if sub not in '.DS_Store':
                            site_subject_map[sub] = site

        return base_path_list, site_subject_map

    #method to split the input template path
    #into base, path before subject directory
    #and relative, path after subject directory
    def getPath(template):

        checkTemplate(template)
        base, relative = template.rsplit("%s", 1)
        base, subject_map = create_site_subject_mapping(base, relative)
        base.sort()
        relative = relative.lstrip("/")
        return base, relative, subject_map

    #get anatomical base path and anatomical relative path
    anat_base, anat_relative = getPath(c.anatomicalTemplate)[:2]

    #get functional base path, functional relative path and site-subject map
    func_base, func_relative, subject_map = getPath(c.functionalTemplate)

    if not anat_base:
        msg = "Anatomical Data template incorrect. No such file or directory %s", anat_base
        logging.exception(msg)
        raise Exception(msg)

    if not func_base:
        msg = "Functional Data template incorrect. No such file or directory %s, func_base"
        logging.exception(msg)
        raise Exception(msg)
        
    if len(anat_base) != len(func_base):
        msg1 = "Some sites are missing, Please check your template"\
              , anat_base, "!=", func_base
        logging.exception(msg1)
        
        msg2 = " Base length Unequal. Some sites are missing."\
               "extract_data doesn't script support this.Please" \
               "Provide your own subjects_list file"
        logging.exception(msg2)
        raise Exception(msg2)

    #calculate the length of relative paths(path after subject directory)
    func_relative_len = len(func_relative.split('/'))
    anat_relative_len = len(anat_relative.split('/'))

    def check_for_sessions(relative_path, path_length):
        """
        Method to check if there are sessions present

        """
        #default
        session_present = False
        session_path = 'session_1'

        #session present if path_length is equal to 3
        if path_length == 3:
            relative_path_list = relative_path.split('/')
            session_path = relative_path_list[0]
            relative_path = string.join(relative_path_list[1:], "/")
            session_present = True
        elif path_length > 3:
            msg = "extract_data script currently doesn't support this directory structure."\
                  "Please provide the subjects_list file to run CPAC."\
                  "For more information refer to manual"
            logging.exception(msg)
            raise Exception(msg)
        return session_present, session_path, relative_path

    func_session_present, func_session_path, func_relative = \
        check_for_sessions(func_relative, func_relative_len)

    anat_session_present, anat_session_path, anat_relative = \
        check_for_sessions(anat_relative, anat_relative_len)

    f = open(os.path.join(c.outputSubjectListLocation, "CPAC_subject_list_%s.yml" % c.subjectListName[0]), 'wb')



    def fetch_path(i, anat_sub, func_sub, session_id):
        """
        Method to extract anatomical and functional
        path for a session and print to file

        Parameters
        ----------
        i : int
            index of site
        anat_sub : string
            string containing subject/ concatenated
            subject-session path for anatomical file
        func_sub: string
            string containing subject/ concatenated
            subject-session path for functional file
        session_id: string
            session

        Raises
        ------
        Exception
        """

        try:

            def print_begin_of_file(sub, session_id):
                print >> f, "-"
                print >> f, "    subject_id: '" + sub + "'"
                print >> f, "    unique_id: '" + session_id + "'" 

            def print_end_of_file(sub):
                if param_map is not None:
                    try:
                        logging.debug("site for sub %s -> %s" %(sub, subject_map.get(sub)))
                        logging.debug("scan parameters for the above site %s"%param_map.get(subject_map.get(sub)))
                        print >> f, "    scan_parameters:"
                        print >> f, "        tr: '" + param_map.get(subject_map.get(sub))[4] + "'" 
                        print >> f, "        acquisition: '" + param_map.get(subject_map.get(sub))[0] + "'" 
                        print >> f, "        reference: '" + param_map.get(subject_map.get(sub))[3] + "'"
                        print >> f, "        first_tr: '" + param_map.get(subject_map.get(sub))[1] + "'"
                        print >> f, "        last_tr: '" + param_map.get(subject_map.get(sub))[2] + "'"
                    except:
                        msg = " No Parameter values for the %s site is defined in the scan"\
                              " parameters csv file" %subject_map.get(sub)
                        raise ValueError(msg)

            #get anatomical file
            anat_base_path = os.path.join(anat_base[i], anat_sub)
            func_base_path = os.path.join(func_base[i], func_sub)

            anat = None
            func = None

            anat = glob.glob(os.path.join(anat_base_path, anat_relative))
            func = glob.glob(os.path.join(func_base_path, func_relative))

            if anat and func:
                print_begin_of_file(anat_sub.split("/")[0], session_id)
                print >> f, "    anat: '" + os.path.realpath(anat[0]) + "'"
                print >> f, "    rest: "

                #iterate for each rest session
                for iter in func:
                    #get scan_id
                    iterable = os.path.splitext(os.path.splitext(iter.replace(func_base_path, '').lstrip("/"))[0])[0]
                    iterable = iterable.replace("/", "_")
                    check_length(iterable, os.path.basename(os.path.realpath(iter)))
                    print>>f, "      " + iterable + ": '" + os.path.realpath(iter) + "'"
                
                print_end_of_file(anat_sub.split("/")[0])
                
            else:
                logging.debug("skipping subject %s"%anat_sub.split("/")[0])
        
        except ValueError, e:

            logging.exception(str(e))
            raise

        except Exception, e:

            err_msg = 'Exception while fetching anatomical and functional ' \
                      'paths: \n' + str(e)

            logging.exception(err_msg)
            raise Exception(err_msg)



    def walk(index, sub):
        """
        Method which walks across each subject
        path in the data site path

        Parameters
        ----------
        index : int
            index of site
        sub : string
            subject_id

        Raises
        ------
        Exception
        """
        try:

            if func_session_present:
                #if there are sessions
                if "*" in func_session_path:
                    session_list = glob.glob(os.path.join(func_base[index], os.path.join(sub, func_session_path)))
                else:
                    session_list = [func_session_path]

                if session_list:
                    for session in session_list:
                        session_id = os.path.basename(session)
                        if anat_session_present:
                            if func_session_path == anat_session_path:
                                fetch_path(index, os.path.join(sub, session_id), os.path.join(sub, session_id), session_id)
                            else:
                                fetch_path(index, os.path.join(sub, anat_session_path), os.path.join(sub, session_id), session_id)
                        else:
                            fetch_path(index, sub, os.path.join(sub, session_id), session_id)
                else:
                    logging.debug("Skipping subject %s", sub)

            else:
                logging.debug("No sessions")
                session_id = ''
                fetch_path(index, sub, sub, session_id)

        except Exception, e:

            logging.exception(str(e))
            raise

        except:

            err_msg = 'Please make sure sessions are consistent across all ' \
                      'subjects.\n\n'

            logging.exception(err_msg)
            raise Exception(err_msg)


    try:
        for i in range(len(anat_base)):
            for sub in os.listdir(anat_base[i]):
                #check if subject is present in subject_list
                if subject_list:
                    if sub in subject_list and sub not in exclusion_list:
                        logging.debug("extracting data for subject: %s", sub)
                        walk(i, sub)
                #check that subject is not in exclusion list
                elif sub not in exclusion_list and sub not in '.DS_Store':
                    logging.debug("extracting data for subject: %s", sub)
                    walk(i, sub)

        
        name = os.path.join(c.outputSubjectListLocation, 'CPAC_subject_list.yml')
        print "Extraction Successfully Completed...Input Subjects_list for CPAC - %s" % name

    except Exception, e:

        logging.exception(str(e))
        raise

    finally:

        f.close()
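
checkTemplate above relies on a detail worth calling out: compressed NIfTI files carry two extensions, so a single os.path.splitext on the basename only strips ".gz". Splitting a second time and recombining recovers the compound suffix for validation. A standalone sketch of that check:

import os

template = "/data/%s/%s/anat/mprage.nii.gz"

# First split: ("mprage.nii", ".gz"); splitting the remaining name again and
# recombining yields the full ".nii.gz" suffix.
filename, ext = os.path.splitext(os.path.basename(template))
ext = os.path.splitext(filename)[1] + ext
print(ext)  # ".nii.gz"
assert ext in [".nii", ".nii.gz"]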

Example 17

Project: C-PAC
Source File: extract_data_multiscan.py
View license
def extract_data(c, param_map):
    """
    Method to generate a CPAC input subject list
    python file. The method extracts anatomical
    functional data and scan parameters for each 
    site( if multiple site) and for each scan 
    and put it into a data structure read by python

    Note:
    -----
    Use this tool only if the scan parameters are different 
    for each scan as shown in the example below.
    
    Example:
    --------
    subjects_list = [
        {
            'subject_id': '0021001',
            'unique_id': 'session2',
            'anat': '/home/data/multiband_data/NKITRT/0021001/anat/mprage.nii.gz',
            'rest':{
              'RfMRI_mx_1400_rest': '/home/data/multiband_data/NKITRT/0021001/session2/RfMRI_mx_1400/rest.nii.gz',
              'RfMRI_mx_645_rest': '/home/data/multiband_data/NKITRT/0021001/session2/RfMRI_mx_645/rest.nii.gz',
              'RfMRI_std_2500_rest': '/home/data/multiband_data/NKITRT/0021001/session2/RfMRI_std_2500/rest.nii.gz',
              },
            'scan_parameters':{
                'TR':{
                    'RfMRI_mx_1400_rest': '1.4',
                    'RfMRI_mx_645_rest': '1.4',
                    'RfMRI_std_2500_rest': '2.5',
                    },
                'Acquisition':{
                    'RfMRI_mx_1400_rest': '/home/data/1400.txt',
                    'RfMRI_mx_645_rest': '/home/data/645.txt',
                    'RfMRI_std_2500_rest': '/home/data/2500.txt',
                    },
                'Reference':{
                    'RfMRI_mx_1400_rest': '32',
                    'RfMRI_mx_645_rest': '20',
                    'RfMRI_std_2500_rest': '19',
                    },
                'FirstTR':{
                    'RfMRI_mx_1400_rest': '7',
                    'RfMRI_mx_645_rest': '15',
                    'RfMRI_std_2500_rest': '4',
                    },
                'LastTR':{
                    'RfMRI_mx_1400_rest': '440',
                    'RfMRI_mx_645_rest': '898',
                    'RfMRI_std_2500_rest': 'None',
                    },
                }
        },

    ]
    """

    #method to read each line of the file into list
    #returns list
    def get_list(arg):
        if isinstance(arg, list):
            ret_list = arg
        else:
            ret_list = [fline.rstrip('\r\n') for fline in open(arg, 'r').readlines()]

        return ret_list

    exclusion_list = []
    if c.exclusionSubjectList is not None:
        exclusion_list = get_list(c.exclusionSubjectList)

    subject_list = []
    if c.subjectList is not None:
        subject_list = get_list(c.subjectList)

    #check if Template is correct
    def checkTemplate(template):

        if template.count('%s') != 2:
            raise Exception("Please provide '%s' in the template "\
                            "where your site and subjects are present. "\
                            "Please see examples.")

        filename, ext = os.path.splitext(os.path.basename(template))
        ext = os.path.splitext(filename)[1] + ext

        if ext not in [".nii", ".nii.gz"]:
            raise Exception("Invalid file name", os.path.basename(template))

    def get_site_list(path):
        base = path.split('%s')[0]
        sites = os.listdir(base)
        return sites

    def check_length(scan_name, file_name):
               
        if len(file_name) > 30:
            msg = "filename- %s is too long."\
                   "It should not be more than 30 characters."%(file_name)
            raise Exception(msg)
        
        if len(scan_name) - len(os.path.splitext(os.path.splitext(file_name)[0])[0]) >= 20:
            msg = "scan name %s is too long. "\
                  "It should not be more than 20 characters."\
                  % (scan_name.replace("_" + os.path.splitext(os.path.splitext(file_name)[0])[0], ''))
            raise Exception(msg)
        


    def create_site_subject_mapping(base, relative):

        #mapping between site and subject
        site_subject_map = {}
        base_path_list = []

        if c.siteList is not None:
            site_list = get_list(c.siteList)
        else:
            site_list = get_site_list(base)

        for site in site_list:
            paths = glob.glob(string.replace(base, '%s', site))
            base_path_list.extend(paths)
            for path in paths:
                for sub in os.listdir(path):
                    #check if subject is present in subject_list
                    if subject_list:
                        if sub in subject_list and sub not in exclusion_list:
                            site_subject_map[sub] = site
                    elif sub not in exclusion_list:
                        if sub not in '.DS_Store':
                            site_subject_map[sub] = site

        return base_path_list, site_subject_map

    #method to split the input template path
    #into base, path before subject directory
    #and relative, path after subject directory
    def getPath(template):

        checkTemplate(template)
        base, relative = template.rsplit("%s", 1)
        base, subject_map = create_site_subject_mapping(base, relative)
        base.sort()
        relative = relative.lstrip("/")
        return base, relative, subject_map

    #get anatomical base path and anatomical relative path
    anat_base, anat_relative = getPath(c.anatomicalTemplate)[:2]

    #get functional base path, functional relative path and site-subject map
    func_base, func_relative, subject_map = getPath(c.functionalTemplate)

    if not anat_base:
        print "No such file or directory ", anat_base
        raise Exception("Anatomical Data template incorrect")

    if not func_base:
        print "No such file or directory", func_base
        raise Exception("Functional Data template incorrect")

    if len(anat_base) != len(func_base):
        print "Some sites are missing, Please check your"\
              "template", anat_base, "!=", func_base
        raise Exception(" Base length Unequal. Some sites are missing."\
                           "extract_data doesn't script support this.Please" \
                           "Provide your own subjects_list file")

    #calculate the length of relative paths(path after subject directory)
    func_relative_len = len(func_relative.split('/'))
    anat_relative_len = len(anat_relative.split('/'))

    def check_for_sessions(relative_path, path_length):
        """
        Method to check if there are sessions present

        """
        #default
        session_present = False
        session_path = 'session_1'

        #session present if path_length is equal to 3
        if path_length == 3:
            relative_path_list = relative_path.split('/')
            session_path = relative_path_list[0]
            relative_path = string.join(relative_path_list[1:], "/")
            session_present = True
        elif path_length > 3:
            raise Exception("extract_data script currently doesn't support "\
                            "this directory structure. Please provide the "\
                            "subjects_list file to run CPAC. "\
                            "For more information refer to the manual.")

        return session_present, session_path, relative_path

#    if func_relative_len!= anat_relative_len:
#        raise Exception(" extract_data script currently doesn't"\
#                          "support different relative paths for"\
#                          "Anatomical and functional files")

    func_session_present, func_session_path, func_relative = \
        check_for_sessions(func_relative, func_relative_len)

    anat_session_present, anat_session_path, anat_relative = \
        check_for_sessions(anat_relative, anat_relative_len)

    f = open(os.path.join(c.outputSubjectListLocation, "CPAC_subject_list.yml"), 'wb')

    def fetch_path(i, anat_sub, func_sub, session_id):
        """
        Method to extract anatomical and functional
        path for a session and print to file

        Parameters
        ----------
        i : int
            index of site
        anat_sub : string
            string containing subject/ concatenated
            subject-session path for anatomical file
        func_sub: string
            string containing subject/ concatenated
            subject-session path for functional file
        session_id: string
            session

        Raises
        ------
        Exception
        """

        try:

            def print_begin_of_file(sub, session_id):
                print >> f, "-"
                print >> f, "    subject_id: '" + sub + "'"
                print >> f, "    unique_id: '" + session_id + "'"

            def print_end_of_file(sub, scan_list):
                if param_map is not None:
                    def print_scan_param(index):
                        try:
                            for scan in scan_list:
                                print>>f,  "            " + scan[1] + ": '" + \
                                param_map.get((subject_map.get(sub), scan[0]))[index] + "'"
                        
                        except:
                            raise Exception("No parameter values for the %s site and %s scan are defined in the scan"\
                                            " parameters csv file" % (subject_map.get(sub), scan[0]))

                    print "site for sub", sub, "->", subject_map.get(sub)
                    print >>f, "    scan_parameters: "
                    print >> f, "        tr:" 
                    print_scan_param(4) 
                    print >> f, "        acquisition:" 
                    print_scan_param(0) 
                    print >> f, "        reference:" 
                    print_scan_param(3) 
                    print >> f, "        first_tr:" 
                    print_scan_param(1) 
                    print >> f, "        last_tr:" 
                    print_scan_param(2) 

 
            #get anatomical file
            anat_base_path = os.path.join(anat_base[i], anat_sub)
            func_base_path = os.path.join(func_base[i], func_sub)

            anat = None
            func = None

            anat = glob.glob(os.path.join(anat_base_path, anat_relative))
            func = glob.glob(os.path.join(func_base_path, func_relative))
            scan_list = []
            if anat and func:
                print_begin_of_file(anat_sub.split("/")[0], session_id)
                print >> f, "    anat: '" + anat[0] + "'" 
                print >>f, "    rest: "

                #iterate for each rest session
                for iter in func:
                    #get scan_id
                    iterable = os.path.splitext(os.path.splitext(iter.replace(func_base_path,'').lstrip("/"))[0])[0]
                    scan_name = iterable.replace("/", "_")
                    scan_list.append((os.path.dirname(iterable), scan_name))
                    check_length(scan_name, os.path.basename(iter))
                    print>>f,  "      " + scan_name + ": '" + iter +  "'"
                print_end_of_file(anat_sub.split("/")[0], scan_list)

        except Exception:
            raise

    def walk(index, sub):
        """
        Method which walks across each subject
        path in the data site path

        Parameters
        ----------
        index : int
            index of site
        sub : string
            subject_id

        Raises
        ------
        Exception
        """
        try:

            if func_session_present:
                #if there are sessions
                if "*" in func_session_path:
                    session_list = glob.glob(os.path.join(func_base[index], os.path.join(sub, func_session_path)))
                else:
                    session_list = [func_session_path]

                for session in session_list:
                    session_id = os.path.basename(session)
                    if anat_session_present:
                        if func_session_path == anat_session_path:
                            fetch_path(index, os.path.join(sub, session_id), os.path.join(sub, session_id), session_id)
                        else:
                            fetch_path(index, os.path.join(sub, anat_session_path), os.path.join(sub, session_id), session_id)
                    else:
                        fetch_path(index, sub, os.path.join(sub, session_id), session_id)
            else:
                print "No sessions"
                session_id = ''
                fetch_path(index, sub, sub, session_id)

        except Exception:
            raise
        except:
            print "Please make sessions are consistent across all subjects"
            raise

    try:
        for i in range(len(anat_base)):
            for sub in os.listdir(anat_base[i]):
                #check if subject is present in subject_list
                if subject_list:
                    if sub in subject_list and sub not in exclusion_list:
                        print "extracting data for subject: ", sub
                        walk(i, sub)
                #check that subject is not in exclusion list
                elif sub not in exclusion_list and sub not in '.DS_Store':
                    print "extracting data for subject: ", sub
                    walk(i, sub)

        
        name = os.path.join(c.outputSubjectListLocation, 'CPAC_subject_list.yml')
        print "Extraction Complete...Input Subjects_list for CPAC - %s" % name
    except Exception:
        raise
    finally:
        f.close()
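
Both extract_data variants reduce globbed session directories to bare session identifiers with os.path.basename before writing the subject list entries. A minimal sketch (the directory layout is hypothetical):

import glob
import os

# Each glob match is a full path such as "/data/site1/0021001/session2";
# basename keeps only the trailing session directory name.
for session in glob.glob("/data/site1/0021001/session*"):
    session_id = os.path.basename(session)
    print(session_id)  # e.g. "session2"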

Example 18

Project: ganga
Source File: DiracFile.py
View license
    def put(self, lfn='', force=False, uploadSE="", replicate=False):
        """
        Try to upload file sequentially to storage elements defined in configDirac['allDiracSE'].
        File will be uploaded to the first SE that the upload command succeeds for.

        The file is uploaded to the SE described by the DiracFile.defaultSE attribute

        Alternatively, the user can specify an uploadSE which contains an SE
        which the file is to be uploaded to.

        If the user wants to replicate this file (or files) across all SEs then they should set replicate = True.

        Return value will be either the stdout from the dirac upload command if not
        using the wildcard characters '*?[]' in the namePattern.
        If the wildcard characters are used then the return value will be a list containing
        newly created DiracFile objects which were the result of glob-ing the wildcards.

        The objects in this list will have been uploaded or had their failureReason attribute populated if the
        upload failed.
        """

        if self.lfn != "" and force == False and lfn == '':
            logger.warning("Warning you're about to 'put' this DiracFile: %s on the grid as it already has an lfn: %s" % (self.namePattern, self.lfn))
            decision = raw_input('y / [n]:')
            while not (decision.lower() in ['y', 'n'] or decision.lower() == ''):
                decision = raw_input('y / [n]:')

            if decision.lower() == 'y':
                pass
            else:
                return

        if (lfn != '' and self.lfn != '') and force == False:
            logger.warning("Warning you're attempting to put this DiracFile: %s" % self.namePattern)
            logger.warning("It currently has an LFN associated with it: %s" % self.lfn)
            logger.warning("Do you want to continue and attempt to upload to: %s" % lfn)
            decision = raw_input('y / [n]:')
            while not (decision.lower() in ['y', 'n', '']):
                decision = raw_input('y / [n]:')

            if decision.lower() == 'y':
                pass
            else:
                return

        if lfn and os.path.basename(lfn) != self.namePattern:
            logger.warning("Changing namePattern from: '%s' to '%s' during put operation" % (self.namePattern, os.path.basename(lfn)))

        if lfn:
            self.lfn = lfn

        # Looks like we will only need this for the interactive uploading of jobs.
        # Also, if any backend needs a dirac upload on the client, then when downloaded
        # this will upload and then delete the file.

        if self.namePattern == "":
            if self.lfn != '':
                logger.warning("'Put'-ing a file with ONLY an existing LFN makes no sense!")
            raise GangaException('Can\'t upload a file without a local file name.')

        sourceDir = self.localDir
        if self.localDir is None:
            sourceDir = os.getcwd()
            # attached to a job, use the joboutputdir
            if self._parent != None and os.path.isdir(self.getJobObject().outputdir):
                sourceDir = self.getJobObject().outputdir

        if not os.path.isdir(sourceDir):
            raise GangaException('localDir attribute is not a valid dir, don\'t know from which dir to take the file')

        if regex.search(self.namePattern) is not None:
            if self.lfn != "":
                logger.warning("Cannot specify a single lfn for a wildcard namePattern")
                logger.warning("LFN will be generated automatically")
                self.lfn = ""

        if not self.remoteDir:
            try:
                job = self.getJobObject()
                lfn_folder = os.path.join("GangaUploadedFiles", "GangaJob_%s" % job.getFQID('.'))
            except AssertionError:
                t = datetime.datetime.now()
                this_date = t.strftime("%H.%M_%A_%d_%B_%Y")
                lfn_folder = os.path.join("GangaUploadedFiles", 'GangaFiles_%s' % this_date)
            self.lfn = os.path.join(DiracFile.diracLFNBase(), lfn_folder, self.namePattern)

        if self.remoteDir[:4] == 'LFN:':
            lfn_base = self.remoteDir[4:]
        else:
            lfn_base = self.remoteDir

        if uploadSE == "":
            if self.defaultSE != "":
                storage_elements = [self.defaultSE]
            else:
                if configDirac['allDiracSE']:
                    storage_elements = [random.choice(configDirac['allDiracSE'])]
                else:
                    raise GangaException("Can't upload a file without a valid defaultSE or storageSE, please provide one")
        elif isinstance(uploadSE, list):
            storage_elements = uploadSE
        else:
            storage_elements = [uploadSE]

        outputFiles = GangaList()
        for this_file in glob.glob(os.path.join(sourceDir, self.namePattern)):
            name = this_file

            if not os.path.exists(name):
                if not self.compressed:
                    raise GangaException('Cannot upload file. File "%s" must exist!' % name)
                name += '.gz'
                if not os.path.exists(name):
                    raise GangaException('File "%s" must exist!' % name)
            else:
                if self.compressed:
                    os.system('gzip -c %s > %s.gz' % (name, name))
                    name += '.gz'
                    if not os.path.exists(name):
                        raise GangaException('File "%s" must exist!' % name)

            if lfn == "":
                lfn = os.path.join(lfn_base, os.path.basename(name))

            #lfn = os.path.join(os.path.dirname(self.lfn), this_file)

            d = DiracFile()
            d.namePattern = os.path.basename(name)
            d.compressed = self.compressed
            d.localDir = sourceDir
            stderr = ''
            stdout = ''
            logger.debug('Uploading file \'%s\' to \'%s\' as \'%s\'' % (name, storage_elements[0], lfn))
            logger.debug('execute: uploadFile("%s", "%s", %s)' % (lfn, name, str([storage_elements[0]])))
            stdout = execute('uploadFile("%s", "%s", %s)' % (lfn, name, str([storage_elements[0]])))
            if type(stdout) == str:
                logger.warning("Couldn't upload file '%s': \'%s\'" % (os.path.basename(name), stdout))
                continue
            if stdout.get('OK', False) and lfn in stdout.get('Value', {'Successful': {}})['Successful']:
                # when doing the two step upload delete the temp file
                if self.compressed or self._parent != None:
                    os.remove(name)
                # need another eval as datetime needs to be included.
                guid = stdout['Value']['Successful'][lfn].get('GUID', '')
                if regex.search(self.namePattern) is not None:
                    d.lfn = lfn
                    d.remoteDir = os.path.dirname(lfn)
                    d.locations = stdout['Value']['Successful'][lfn].get('allDiracSE', '')
                    d.guid = guid
                    outputFiles.append(GPIProxyObjectFactory(d))
                    continue
                else:
                    self.lfn = lfn
                    self.remoteDir = os.path.dirname(lfn)
                    self.locations = stdout['Value']['Successful'][lfn].get('allDiracSE', '')
                    self.guid = guid
                # return ## WHY?
            else:
                failureReason = "Error in uploading file %s : %s" % (os.path.basename(name), str(stdout))
                logger.error(failureReason)
                if regex.search(self.namePattern) is not None:
                    d.failureReason = failureReason
                    outputFiles.append(GPIProxyObjectFactory(d))
                    continue
                self.failureReason = failureReason
                return str(stdout)

        if replicate == True:

            if len(outputFiles) == 1 or len(outputFiles) == 0:
                storage_elements.pop(0)
                for se in storage_elements:
                    self.replicate(se)
            else:
                storage_elements.pop(0)
                for this_file in outputFiles:
                    for se in storage_elements:
                        this_file.replicate(se)

        if len(outputFiles) > 0:
            return outputFiles
        else:
            outputFiles.append(self)
            return outputFiles
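
DiracFile.put above also applies os.path.basename to grid paths: it compares the file-name part of a user-supplied LFN against namePattern and warns on a mismatch, and later appends the local file's basename to the generated LFN directory. A minimal sketch of the consistency check (the LFN value is made up):

import os

namePattern = "ntuple.root"
lfn = "/grid/user/GangaUploadedFiles/GangaJob_42/histos.root"

# LFNs use "/" separators, so os.path.basename applies to them as well.
if os.path.basename(lfn) != namePattern:
    print("Changing namePattern from: '%s' to '%s' during put operation"
          % (namePattern, os.path.basename(lfn)))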

Example 20

Project: ganga
Source File: AthenaJediRTHandler.py
View license
    def master_prepare(self,app,appconfig):
        '''Prepare the master job'''

        from pandatools import Client
        from pandatools import MiscUtils
        from pandatools import AthenaUtils
        from pandatools import PsubUtils

        # create a random number for this submission to allow multiple use of containers
        self.rndSubNum = random.randint(1111,9999)

        job = app._getParent()
        logger.debug('AthenaJediRTHandler master_prepare called for %s', job.getFQID('.')) 

        if app.useRootCoreNoBuild:
            logger.info('Athena.useRootCoreNoBuild is True, setting Panda.nobuild=True.')
            job.backend.nobuild = True

        if job.backend.bexec and job.backend.nobuild:
            raise ApplicationConfigurationError(None,"Contradicting options: job.backend.bexec and job.backend.nobuild are both enabled.")

        if job.backend.requirements.rootver != '' and job.backend.nobuild:
            raise ApplicationConfigurationError(None,"Contradicting options: job.backend.requirements.rootver given and job.backend.nobuild are enabled.")
        
        # Switch on compilation flag if bexec is set or libds is empty
        if job.backend.bexec != '' or not job.backend.nobuild:
            app.athena_compile = True
            for sj in job.subjobs:
                sj.application.athena_compile = True
            logger.info('"job.backend.nobuild=False" or "job.backend.bexec" is set - Panda build job is enabled.')

        if job.backend.nobuild:
            app.athena_compile = False
            for sj in job.subjobs:
                sj.application.athena_compile = False
            logger.info('"job.backend.nobuild=True" or "--nobuild" chosen - Panda build job is switched off.')

        # check for auto datri
        if job.outputdata.location != '':
            if not PsubUtils.checkDestSE(job.outputdata.location,job.outputdata.datasetname,False):
                raise ApplicationConfigurationError(None,"Problems with outputdata.location setting '%s'" % job.outputdata.location)

        # validate application
        if not app.atlas_release and not job.backend.requirements.rootver and app.atlas_exetype not in ['EXE']:
            raise ApplicationConfigurationError(None,"application.atlas_release is not set. Did you run application.prepare()?")

        self.dbrelease = app.atlas_dbrelease
        if self.dbrelease != '' and self.dbrelease != 'LATEST' and self.dbrelease.find(':') == -1:
            raise ApplicationConfigurationError(None,"ERROR : invalid argument for DB Release. Must be 'LATEST' or 'DatasetName:FileName'")

        self.runConfig = AthenaUtils.ConfigAttr(app.atlas_run_config)
        for k in self.runConfig.keys():
            self.runConfig[k]=AthenaUtils.ConfigAttr(self.runConfig[k])
        if not app.atlas_run_dir:
            raise ApplicationConfigurationError(None,"application.atlas_run_dir is not set. Did you run application.prepare()?")
 
        self.rundirectory = app.atlas_run_dir
        self.cacheVer = ''
        if app.atlas_project and app.atlas_production:
            self.cacheVer = "-" + app.atlas_project + "_" + app.atlas_production

        # handle different atlas_exetypes
        self.job_options = ''
        if app.atlas_exetype == 'TRF':
            self.job_options += ' '.join([os.path.basename(fopt.name) for fopt in app.option_file])

            #if not job.outputdata.outputdata:
            #    raise ApplicationConfigurationError(None,"job.outputdata.outputdata is required for atlas_exetype in ['PYARA','ARES','TRF','ROOT','EXE' ] and Panda backend")
            #raise ApplicationConfigurationError(None,"Sorry TRF on Panda backend not yet supported")

            if app.options:
                self.job_options += ' %s ' % app.options
                
        elif app.atlas_exetype == 'ATHENA':
            
            if len(app.atlas_environment) > 0 and app.atlas_environment[0].find('DBRELEASE_OVERRIDE')==-1:
                logger.warning("Passing of environment variables to Athena using Panda not supported. Ignoring atlas_environment setting.")
                
            if job.outputdata.outputdata:
                raise ApplicationConfigurationError(None,"job.outputdata.outputdata must be empty if atlas_exetype='ATHENA' and Panda backend is used (outputs are auto-detected)")
            if app.options:
                if app.options.startswith('-c'):
                    self.job_options += ' %s ' % app.options
                else:
                    self.job_options += ' -c %s ' % app.options

                logger.warning('The value of j.application.options has been prepended with " -c " ')
                logger.warning('Please make sure to use proper quotes for the values of j.application.options !')

            self.job_options += ' '.join([os.path.basename(fopt.name) for fopt in app.option_file])

            # check for TAG compression
            if 'subcoll.tar.gz' in app.append_to_user_area:
                self.job_options = ' uncompress.py ' + self.job_options
                
        elif app.atlas_exetype in ['PYARA','ARES','ROOT','EXE']:

            #if not job.outputdata.outputdata:
            #    raise ApplicationConfigurationError(None,"job.outputdata.outputdata is required for atlas_exetype in ['PYARA','ARES','TRF','ROOT','EXE' ] and Panda backend")
            self.job_options += ' '.join([os.path.basename(fopt.name) for fopt in app.option_file])

            # sort out environment variables
            env_str = ""
            if len(app.atlas_environment) > 0:
                for env_var in app.atlas_environment:
                    env_str += "export %s ; " % env_var
            else: 
                env_str = ""

            # below fixes issue with runGen -- job_options are executed by os.system when dbrelease is used, and by the shell otherwise
            ## - REMOVED FIX DUE TO CHANGE IN PILOT - MWS 8/11/11
            if job.backend.requirements.usecommainputtxt:
                input_str = '/bin/echo %IN > input.txt; cat input.txt; '
            else:
                input_str = '/bin/echo %IN | sed \'s/,/\\\n/g\' > input.txt; cat input.txt; '
            if app.atlas_exetype == 'PYARA':
                self.job_options = env_str + input_str + ' python ' + self.job_options
            elif app.atlas_exetype == 'ARES':
                self.job_options = env_str + input_str + ' athena.py ' + self.job_options
            elif app.atlas_exetype == 'ROOT':
                self.job_options = env_str + input_str + ' root -b -q ' + self.job_options
            elif app.atlas_exetype == 'EXE':
                self.job_options = env_str + input_str + self.job_options

            if app.options:
                self.job_options += ' %s ' % app.options

        if self.job_options == '':
            raise ApplicationConfigurationError(None,"No Job Options found!")
        logger.info('Running job options: %s'%self.job_options)

        # validate dbrelease
        if self.dbrelease != "LATEST":
            self.dbrFiles,self.dbrDsList = getDBDatasets(self.job_options,'',self.dbrelease)

        # handle the output dataset
        if job.outputdata:
            if job.outputdata._name != 'DQ2OutputDataset':
                raise ApplicationConfigurationError(None,'Panda backend supports only DQ2OutputDataset')
        else:
            logger.info('Adding missing DQ2OutputDataset')
            job.outputdata = DQ2OutputDataset()

        # validate the output dataset name (and make it a container)
        job.outputdata.datasetname,outlfn = dq2outputdatasetname(job.outputdata.datasetname, job.id, job.outputdata.isGroupDS, job.outputdata.groupname)
        if not job.outputdata.datasetname.endswith('/'):
            job.outputdata.datasetname+='/'

        # add extOutFiles
        self.extOutFile = []
        for tmpName in job.outputdata.outputdata:
            if tmpName != '':
                self.extOutFile.append(tmpName)
        for tmpName in job.backend.extOutFile:
            if tmpName != '':
                self.extOutFile.append(tmpName)

        # use the shared area if possible
        tmp_user_area_name = app.user_area.name
        if app.is_prepared is not True:
            from Ganga.Utility.files import expandfilename
            shared_path = os.path.join(expandfilename(getConfig('Configuration')['gangadir']),'shared',getConfig('Configuration')['user'])
            tmp_user_area_name = os.path.join(os.path.join(shared_path,app.is_prepared.name),os.path.basename(app.user_area.name))


        # Add inputsandbox to user_area
        if job.inputsandbox:
            logger.warning("Submitting Panda job with inputsandbox. This may slow the submission slightly.")

            if tmp_user_area_name:
                inpw = os.path.dirname(tmp_user_area_name)
                self.inputsandbox = os.path.join(inpw, 'sources.%s.tar' % commands.getoutput('uuidgen 2> /dev/null'))
            else:
                inpw = job.getInputWorkspace()
                self.inputsandbox = inpw.getPath('sources.%s.tar' % commands.getoutput('uuidgen 2> /dev/null'))

            if tmp_user_area_name:
                rc, output = commands.getstatusoutput('cp %s %s.gz' % (tmp_user_area_name, self.inputsandbox))
                if rc:
                    logger.error('Copying user_area failed with status %d',rc)
                    logger.error(output)
                    raise ApplicationConfigurationError(None,'Packing inputsandbox failed.')
                rc, output = commands.getstatusoutput('gunzip %s.gz' % (self.inputsandbox))
                if rc:
                    logger.error('Unzipping user_area failed with status %d',rc)
                    logger.error(output)
                    raise ApplicationConfigurationError(None,'Packing inputsandbox failed.')

            for fname in [os.path.abspath(f.name) for f in job.inputsandbox]:
                fname = fname.rstrip(os.sep)  # rstrip returns a new string; the result must be kept
                path = os.path.dirname(fname)
                fn = os.path.basename(fname)

                #app.atlas_run_dir
                # get Athena versions
                rc, out = AthenaUtils.getAthenaVer()
                # failed
                if not rc:
                    #raise ApplicationConfigurationError(None, 'CMT could not parse correct environment ! \n Did you start/setup ganga in the run/ or cmt/ subdirectory of your athena analysis package ?')
                    logger.warning("CMT could not parse correct environment for inputsandbox - will use the atlas_run_dir as default")
                    
                    # as we don't have to be in the run dir now, create a copy of the run_dir directory structure and use that
                    input_dir = os.path.dirname(self.inputsandbox)
                    run_path = "%s/sbx_tree/%s" % (input_dir, app.atlas_run_dir)
                    rc, output = commands.getstatusoutput("mkdir -p %s" % run_path)
                    if not rc:
                        # copy this sandbox file
                        rc, output = commands.getstatusoutput("cp %s %s" % (fname, run_path))
                        if not rc:
                            path = os.path.join(input_dir, 'sbx_tree')
                            fn = os.path.join(app.atlas_run_dir, fn)
                        else:
                            raise ApplicationConfigurationError(None, "Couldn't copy file %s to recreate run_dir for input sandbox" % fname)
                    else:
                        raise ApplicationConfigurationError(None, "Couldn't create directory structure to match run_dir %s for input sandbox" % run_path)

                else:
                    userarea = out['workArea']

                    # strip the path from the filename if present in the userarea
                    ua = os.path.abspath(userarea)
                    if ua in path:
                        fn = fname[len(ua)+1:]
                        path = ua

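                # -h dereferences symlinks, -r appends to the archive, and -C
                # makes tar store fn relative to 'path' rather than its full path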
                rc, output = commands.getstatusoutput('tar -h -r -f %s -C %s %s' % (self.inputsandbox, path, fn))
                if rc:
                    logger.error('Packing inputsandbox failed with status %d',rc)
                    logger.error(output)
                    raise ApplicationConfigurationError(None,'Packing inputsandbox failed.')

            # remove sandbox tree if created
            if "sbx_tree" in os.listdir(os.path.dirname(self.inputsandbox)):                
                rc, output = commands.getstatusoutput("rm -r %s/sbx_tree" % os.path.dirname(self.inputsandbox))
                if rc:
                    raise ApplicationConfigurationError(None, "Couldn't remove directory structure used for input sandbox")
                
            rc, output = commands.getstatusoutput('gzip %s' % (self.inputsandbox))
            if rc:
                logger.error('Packing inputsandbox failed with status %d',rc)
                logger.error(output)
                raise ApplicationConfigurationError(None,'Packing inputsandbox failed.')
            self.inputsandbox += ".gz"
        else:
            self.inputsandbox = tmp_user_area_name

        # job name
        jobName = 'ganga.%s' % MiscUtils.wrappedUuidGen()

        # make task
        taskParamMap = {}
        # Enforce that outputdataset name ends with / for container
        if not job.outputdata.datasetname.endswith('/'):
            job.outputdata.datasetname = job.outputdata.datasetname + '/'

        taskParamMap['taskName'] = job.outputdata.datasetname

        taskParamMap['uniqueTaskName'] = True
        taskParamMap['vo'] = 'atlas'
        taskParamMap['architecture'] = AthenaUtils.getCmtConfig(athenaVer=app.atlas_release, cmtConfig=app.atlas_cmtconfig)
        if app.atlas_release:
            taskParamMap['transUses'] = 'Atlas-%s' % app.atlas_release
        else:
            taskParamMap['transUses'] = ''
        taskParamMap['transHome'] = 'AnalysisTransforms' + self.cacheVer  # +nightVer

        configSys = getConfig('System')
        gangaver = configSys['GANGA_VERSION'].lower()
        if not gangaver:
            gangaver = "ganga"

        if app.atlas_exetype in ["ATHENA", "TRF"]:
            taskParamMap['processingType'] = '{0}-jedi-athena'.format(gangaver)
        else:
            taskParamMap['processingType'] = '{0}-jedi-run'.format(gangaver)

        #if options.eventPickEvtList != '':
        #    taskParamMap['processingType'] += '-evp'
        taskParamMap['prodSourceLabel'] = 'user'
        if job.backend.site != 'AUTO':
            taskParamMap['cloud'] = Client.PandaSites[job.backend.site]['cloud']
            taskParamMap['site'] = job.backend.site
        elif job.backend.requirements.cloud != None and not job.backend.requirements.anyCloud:
            taskParamMap['cloud'] = job.backend.requirements.cloud
        if job.backend.requirements.excluded_sites != []:
            taskParamMap['excludedSite'] = expandExcludedSiteList( job )

        # if only a single site specified, don't set includedSite
        #if job.backend.site != 'AUTO':
        #    taskParamMap['includedSite'] = job.backend.site
        #taskParamMap['cliParams'] = fullExecString
        if job.backend.requirements.noEmail:
            taskParamMap['noEmail'] = True
        if job.backend.requirements.skipScout:
            taskParamMap['skipScout'] = True
        if not app.atlas_exetype in ["ATHENA", "TRF"]: 
            taskParamMap['nMaxFilesPerJob'] = job.backend.requirements.maxNFilesPerJob
        if job.backend.requirements.disableAutoRetry:
            taskParamMap['disableAutoRetry'] = 1
        # source URL
        matchURL = re.search("(http.*://[^/]+)/",Client.baseURLCSRVSSL)
        if matchURL != None:
            taskParamMap['sourceURL'] = matchURL.group(1)

        # dataset names
        outDatasetName = job.outputdata.datasetname
        logDatasetName = re.sub('/$','.log/',job.outputdata.datasetname)
        # log
        taskParamMap['log'] = {'dataset': logDatasetName,
                               'container': logDatasetName,
                               'type':'template',
                               'param_type':'log',
                               'value':'{0}.${{SN}}.log.tgz'.format(logDatasetName[:-1])
                               }
        # job parameters
        if app.atlas_exetype in ["ATHENA", "TRF"]:
            taskParamMap['jobParameters'] = [
                {'type':'constant',
                 'value': ' --sourceURL ${SURL}',
                 },
                ]
        else:
            taskParamMap['jobParameters'] = [
                {'type':'constant',
                 'value': '-j "" --sourceURL ${SURL}',
                 },
                ]

        taskParamMap['jobParameters'] += [
            {'type':'constant',
             'value': '-r {0}'.format(self.rundirectory),
             },
            ]

        # Add the --trf option to jobParameters if required
        if app.atlas_exetype == "TRF":
            taskParamMap['jobParameters'] += [{'type': 'constant', 'value': '--trf'}]

        # output
        # output files
        outMap = {}
        if app.atlas_exetype in ["ATHENA", "TRF"]:
            outMap, tmpParamList = AthenaUtils.convertConfToOutput(self.runConfig, self.extOutFile, job.outputdata.datasetname, destination=job.outputdata.location)
            taskParamMap['jobParameters'] += [
                {'type':'constant',
                 'value': '-o "%s" ' % outMap
                 },
                ]
            taskParamMap['jobParameters'] += tmpParamList 

        else: 
            if job.outputdata.outputdata:
                for tmpLFN in job.outputdata.outputdata:
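                    # derive the output LFN stem from the first two dot-separated
                    # fields of the dataset name (or the whole name minus the
                    # trailing '/'), then append jobset ID and serial number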
                    if len(job.outputdata.datasetname.split('.')) > 2:
                        lfn = '{0}.{1}'.format(*job.outputdata.datasetname.split('.')[:2])
                    else:
                        lfn = job.outputdata.datasetname[:-1]
                    lfn += '.$JOBSETID._${{SN/P}}.{0}'.format(tmpLFN)
                    dataset = '{0}_{1}/'.format(job.outputdata.datasetname[:-1],tmpLFN)
                    taskParamMap['jobParameters'] += MiscUtils.makeJediJobParam(lfn,dataset,'output',hidden=True, destination=job.outputdata.location)
                    outMap[tmpLFN] = lfn

                taskParamMap['jobParameters'] += [ 
                    {'type':'constant',
                     'value': '-o "{0}"'.format(str(outMap)),
                     },
                    ]

        if app.atlas_exetype in ["ATHENA"]:
            # jobO parameter
            tmpJobO = self.job_options
            # replace full-path jobOs
            for tmpFullName,tmpLocalName in AthenaUtils.fullPathJobOs.iteritems():
                tmpJobO = re.sub(tmpFullName,tmpLocalName,tmpJobO)
            # modify one-liner for G4 random seeds
            if self.runConfig.other.G4RandomSeeds > 0:
                if app.options != '':
                    tmpJobO = re.sub('-c "%s" ' % app.options,
                                     '-c "%s;from G4AtlasApps.SimFlags import SimFlags;SimFlags.SeedsG4=${RNDMSEED}" ' \
                                         % app.options,tmpJobO)
                else:
                    tmpJobO = '-c "from G4AtlasApps.SimFlags import SimFlags;SimFlags.SeedsG4=${RNDMSEED}" '
                dictItem = {'type':'template',
                            'param_type':'number',
                            'value':'${RNDMSEED}',
                            'hidden':True,
                            'offset':self.runConfig.other.G4RandomSeeds,
                            }
                taskParamMap['jobParameters'] += [dictItem]
        elif app.atlas_exetype in ["TRF"]:
            # replace parameters for TRF
            tmpJobO = self.job_options
            # output : basenames are in outMap['IROOT'] through extOutFile
            for tmpName,tmpLFN in outMap['IROOT']:
                tmpJobO = tmpJobO.replace('%OUT.' + tmpName,tmpName)
            # replace DBR
            tmpJobO = re.sub('%DB=[^ \'\";]+','${DBR}',tmpJobO)

        if app.atlas_exetype in ["TRF"]:
            taskParamMap['useLocalIO'] = 1

        # build
        if job.backend.nobuild:
            taskParamMap['jobParameters'] += [
                {'type':'constant',
                 'value': '-a {0}'.format(os.path.basename(self.inputsandbox)),
                 },
                ]
        else:
            taskParamMap['jobParameters'] += [
                {'type':'constant',
                 'value': '-l ${LIB}',
                 },
                ]

        #
        # input
        if job.inputdata and job.inputdata._name == 'DQ2Dataset':
            if job.backend.requirements.nFilesPerJob > 0 and job.inputdata.number_of_files == 0 and job.backend.requirements.split > 0:
                job.inputdata.number_of_files = job.backend.requirements.nFilesPerJob * job.backend.requirements.split

        if job.inputdata and job.inputdata._name == 'DQ2Dataset' and job.inputdata.number_of_files != 0:
            taskParamMap['nFiles'] = job.inputdata.number_of_files
        elif job.backend.requirements.nFilesPerJob > 0 and job.backend.requirements.split > 0:
            # pathena does this for some reason even if there are no input files
            taskParamMap['nFiles'] = job.backend.requirements.nFilesPerJob * job.backend.requirements.split
        if job.backend.requirements.nFilesPerJob > 0:    
            taskParamMap['nFilesPerJob'] = job.backend.requirements.nFilesPerJob
            
        if job.backend.requirements.nEventsPerFile > 0:    
            taskParamMap['nEventsPerFile'] = job.backend.requirements.nEventsPerFile

        if job.backend.requirements.nGBPerJob not in [0, 'MAX']:
            try:
                if job.backend.requirements.nGBPerJob != 'MAX':
                    job.backend.requirements.nGBPerJob = int(job.backend.requirements.nGBPerJob)
            except (ValueError, TypeError):
                logger.error("nGBPerJob must be an integer or MAX")
            # check negative
            if job.backend.requirements.nGBPerJob <= 0:
                logger.error("nGBPerJob must be positive")

            # don't set MAX since it is the default on the server side
            if job.backend.requirements.nGBPerJob not in [-1, 'MAX']:
                taskParamMap['nGBPerJob'] = job.backend.requirements.nGBPerJob

        if app.atlas_exetype in ["ATHENA", "TRF"]:
            inputMap = {}
            if job.inputdata and job.inputdata._name == 'DQ2Dataset':
                tmpDict = {'type':'template',
                           'param_type':'input',
                           'value':'-i "${IN/T}"',
                           'dataset': ','.join(job.inputdata.dataset),
                           'expand':True,
                           'exclude':'\.log\.tgz(\.\d+)*$',
                           }
                #if options.inputType != '':
                #    tmpDict['include'] = options.inputType
                taskParamMap['jobParameters'].append(tmpDict)
                taskParamMap['dsForIN'] = ','.join(job.inputdata.dataset)
                inputMap['IN'] = ','.join(job.inputdata.dataset)
            else:
                # no input
                taskParamMap['noInput'] = True
                if job.backend.requirements.split > 0:
                    taskParamMap['nEvents'] = job.backend.requirements.split
                else:
                    taskParamMap['nEvents'] = 1
                taskParamMap['nEventsPerJob'] = 1
                taskParamMap['jobParameters'] += [
                    {'type':'constant',
                     'value': '-i "[]"',
                     },
                    ]
        else:
            if job.inputdata and job.inputdata._name == 'DQ2Dataset':
                tmpDict = {'type':'template',
                           'param_type':'input',
                           'value':'-i "${IN/T}"',
                           'dataset': ','.join(job.inputdata.dataset),
                           'expand':True,
                           'exclude':'\.log\.tgz(\.\d+)*$',
                           }
                #if options.nSkipFiles != 0:
                #    tmpDict['offset'] = options.nSkipFiles
                taskParamMap['jobParameters'].append(tmpDict)
                taskParamMap['dsForIN'] = ','.join(job.inputdata.dataset)
            else:
                # no input
                taskParamMap['noInput'] = True
                if job.backend.requirements.split > 0:
                    taskParamMap['nEvents'] = job.backend.requirements.split
                else:
                    taskParamMap['nEvents'] = 1
                taskParamMap['nEventsPerJob'] = 1

        # param for DBR     
        if self.dbrelease != '':
            dbrDS = self.dbrelease.split(':')[0]
            # change LATEST to DBR_LATEST
            if dbrDS == 'LATEST':
                dbrDS = 'DBR_LATEST'
            dictItem = {'type':'template',
                        'param_type':'input',
                        'value':'--dbrFile=${DBR}',
                        'dataset':dbrDS,
                            }
            taskParamMap['jobParameters'] += [dictItem]
        # no expansion
        #if options.notExpandDBR:
        #dictItem = {'type':'constant',
        #            'value':'--noExpandDBR',
        #            }
        #taskParamMap['jobParameters'] += [dictItem]

        # secondary FIXME disabled
        self.secondaryDSs = {}
        if self.secondaryDSs != {}:
            inMap = {}
            streamNames = []
            for tmpDsName,tmpMap in self.secondaryDSs.iteritems():
                # make template item
                streamName = tmpMap['streamName']
                dictItem = MiscUtils.makeJediJobParam('${'+streamName+'}',tmpDsName,'input',hidden=True,
                                                      expand=True,include=tmpMap['pattern'],offset=tmpMap['nSkip'],
                                                      nFilesPerJob=tmpMap['nFiles'])
                taskParamMap['jobParameters'] += dictItem
                inMap[streamName] = 'tmp_'+streamName 
                streamNames.append(streamName)
            # make constant item
            strInMap = str(inMap)
            # set placeholders
            for streamName in streamNames:
                strInMap = strInMap.replace("'tmp_"+streamName+"'",'${'+streamName+'/T}')
            dictItem = {'type':'constant',
                        'value':'--inMap "%s"' % strInMap,
                        }
            taskParamMap['jobParameters'] += [dictItem]

        # misc
        jobParameters = ''
        # use Athena packages
        if app.atlas_exetype == 'ARES' or (app.atlas_exetype in ['PYARA','ROOT','EXE'] and app.useAthenaPackages):
            jobParameters += "--useAthenaPackages "
            
        # use RootCore
        if app.useRootCore or app.useRootCoreNoBuild:
            jobParameters += "--useRootCore "
            
        # use mana
        if app.useMana:
            jobParameters += "--useMana "
            if app.atlas_release != "":
                jobParameters += "--manaVer %s " % app.atlas_release
        # root
        if app.atlas_exetype in ['PYARA','ROOT','EXE'] and job.backend.requirements.rootver != '':
            rootver = re.sub('/','.', job.backend.requirements.rootver)
            jobParameters += "--rootVer %s " % rootver

        # write input to txt
        #if options.writeInputToTxt != '':
        #    jobParameters += "--writeInputToTxt %s " % options.writeInputToTxt
        # debug parameters
        #if options.queueData != '':
        #    jobParameters += "--overwriteQueuedata=%s " % options.queueData
        # JEM
        #if options.enableJEM:
        #    jobParameters += "--enable-jem "
        #    if options.configJEM != '':
        #        jobParameters += "--jem-config %s " % options.configJEM

        # set task param
        if jobParameters != '':
            taskParamMap['jobParameters'] += [ 
                {'type':'constant',
                 'value': jobParameters,
                 },
                ]

        # force stage-in
        if job.backend.accessmode == "LocalIO":
            taskParamMap['useLocalIO'] = 1

        # set jobO parameter
        if app.atlas_exetype in ["ATHENA", "TRF"]:
            taskParamMap['jobParameters'] += [
                {'type':'constant',
                 'value': '-j "',
                 'padding':False,
                 },
                ]
            taskParamMap['jobParameters'] += PsubUtils.convertParamStrToJediParam(tmpJobO,inputMap,job.outputdata.datasetname[:-1], True,False)
            taskParamMap['jobParameters'] += [
                {'type':'constant',
                 'value': '"',
                 },
                ]

        else:
            taskParamMap['jobParameters'] += [ {'type':'constant',
                                                'value': '-p "{0}"'.format(urllib.quote(self.job_options)),
                                                },
                                               ]

        # build step
        if not job.backend.nobuild:
            jobParameters = '-i ${IN} -o ${OUT} --sourceURL ${SURL} '

            if job.backend.bexec != '':
                jobParameters += ' --bexec "%s" ' % urllib.quote(job.backend.bexec)

            if app.atlas_exetype == 'ARES' or (app.atlas_exetype in ['PYARA','ROOT','EXE'] and app.useAthenaPackages):
                # use Athena packages
                jobParameters += "--useAthenaPackages "
            # use RootCore
            if app.useRootCore or app.useRootCoreNoBuild:
                jobParameters += "--useRootCore "

            # run directory
            if app.atlas_exetype in ['PYARA','ARES','ROOT','EXE']:
                jobParameters += '-r {0} '.format(self.rundirectory)
                
            # no compile
            #if options.noCompile:
            #    jobParameters += "--noCompile "
            # use mana
            if app.useMana:
                jobParameters += "--useMana "
                if app.atlas_release != "":
                    jobParameters += "--manaVer %s " % app.atlas_release

            # root
            if app.atlas_exetype in ['PYARA','ROOT','EXE'] and job.backend.requirements.rootver != '':
                rootver = re.sub('/','.', job.backend.requirements.rootver)
                jobParameters += "--rootVer %s " % rootver
                
            # cmt config
            if app.atlas_exetype in ['PYARA','ARES','ROOT','EXE']:
                if not app.atlas_cmtconfig in ['','NULL',None]:
                    jobParameters += " --cmtConfig %s " % app.atlas_cmtconfig
                                            
            
            #cmtConfig         = AthenaUtils.getCmtConfig(athenaVer=app.atlas_release, cmtConfig=app.atlas_cmtconfig)
            #if cmtConfig:
            #    jobParameters += "--cmtConfig %s " % cmtConfig
            # debug parameters
            #if options.queueData != '':
            #    jobParameters += "--overwriteQueuedata=%s " % options.queueData
            # set task param
            taskParamMap['buildSpec'] = {
                'prodSourceLabel':'panda',
                'archiveName':os.path.basename(self.inputsandbox),
                'jobParameters':jobParameters,
                }


        # enable merging
        if job.backend.requirements.enableMerge:
            jobParameters = '-r {0} '.format(self.rundirectory)
            if 'exec' in job.backend.requirements.configMerge and job.backend.requirements.configMerge['exec'] != '':
                jobParameters += '-j "{0}" '.format(job.backend.requirements.configMerge['exec'])
            if not job.backend.nobuild:
                jobParameters += '-l ${LIB} '
            else:
                jobParameters += '-a {0} '.format(os.path.basename(self.inputsandbox))
                jobParameters += "--sourceURL ${SURL} "
            jobParameters += '${TRN_OUTPUT:OUTPUT} ${TRN_LOG:LOG}'
            taskParamMap['mergeSpec'] = {}
            taskParamMap['mergeSpec']['useLocalIO'] = 1
            taskParamMap['mergeSpec']['jobParameters'] = jobParameters
            taskParamMap['mergeOutput'] = True    
            
        # Selected by Jedi
        #if not app.atlas_exetype in ['PYARA','ROOT','EXE']:
        #    taskParamMap['transPath'] = 'http://atlpan.web.cern.ch/atlpan/runAthena-00-00-12'

        logger.debug(taskParamMap)

        # upload sources
        if self.inputsandbox and not job.backend.libds:
            uploadSources(os.path.dirname(self.inputsandbox),os.path.basename(self.inputsandbox))

            if not self.inputsandbox == tmp_user_area_name:
                logger.info('Removing source tarball %s ...' % self.inputsandbox )
                os.remove(self.inputsandbox)

        return taskParamMap
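
A recurring idiom in master_prepare is flattening a list of option-file objects to bare file names with ' '.join([os.path.basename(fopt.name) for fopt in app.option_file]): on the worker node the files sit in the run directory, so any submission-side directory prefix would be wrong there. A condensed sketch of just that step (OptionFile is a stand-in for Ganga's file objects, not part of its API):

import os
from collections import namedtuple

# Stand-in for Ganga's option-file objects, which expose the path as .name.
OptionFile = namedtuple('OptionFile', ['name'])

option_file = [OptionFile('/home/user/analysis/run/jobOptions.py'),
               OptionFile('/home/user/analysis/run/extra_config.py')]

# Keep only the trailing file names for the remote command line.
job_options = ' '.join(os.path.basename(fopt.name) for fopt in option_file)
print(job_options)  # jobOptions.py extra_config.py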

Example 21

Project: agdc
Source File: dem_tiler.py
View license
    def create_tiles(self, filename=None, level_name=None, tile_type_id=None):
        # Set default values to instance values
        filename = filename or self.filename
        level_name = level_name or self.level_name
        tile_type_id = tile_type_id or self.default_tile_type_id
        nodata_value = None
        
        tile_type_info = self.tile_type_dict[tile_type_id]
        
        dem_band_info = self.bands[tile_type_id].get(('DERIVED', level_name))
        assert dem_band_info, 'No band level information defined for level %s' % level_name
        
        def find_tiles(x_index = None, y_index = None):
            """Find any tile records for current dataset
            returns dict of tile information keyed by tile_id
            """
            db_cursor2 = self.db_connection.cursor()
            
            sql = """-- Check for any existing tiles
select
tile_id,
x_index,
y_index,
tile_type_id,
tile_pathname,
dataset_id,
tile_class_id,
tile_size
from tile_footprint
inner join tile using(x_index, y_index, tile_type_id)
inner join dataset using(dataset_id)
inner join processing_level using(level_id)
where tile_type_id = %(tile_type_id)s
and (%(x_index)s is null or x_index = %(x_index)s)
and (%(y_index)s is null or y_index = %(y_index)s)
and level_name = %(level_name)s
and ctime is not null
;
"""
            params = {'x_index': x_index,
                  'y_index': y_index,
                  'tile_type_id': tile_type_info['tile_type_id'],
                  'level_name': level_name}
                          
            log_multiline(logger.debug, db_cursor2.mogrify(sql, params), 'SQL', '\t')
            db_cursor2.execute(sql, params)
            tile_info = {}
            for record in db_cursor2:
                tile_info_dict = {
                    'x_index': record[1],
                    'y_index': record[2],
                    'tile_type_id': record[3],
                    'tile_pathname': record[4],
                    'dataset_id': record[5],
                    'tile_class_id': record[6],
                    'tile_size': record[7]
                    }
                tile_info[record[0]] = tile_info_dict # Keyed by tile_id
                
            log_multiline(logger.debug, tile_info, 'tile_info', '\t')
            return tile_info

        
        # Function create_tiles starts here
        db_cursor = self.db_connection.cursor()
        
        dataset = gdal.Open(filename)
        assert dataset, 'Unable to open dataset %s' % filename
        spatial_reference = osr.SpatialReference()
        spatial_reference.ImportFromWkt(dataset.GetProjection())
        geotransform = dataset.GetGeoTransform()
        logger.debug('geotransform = %s', geotransform)
            
        latlong_spatial_reference = spatial_reference.CloneGeogCS()
        coord_transform_to_latlong = osr.CoordinateTransformation(spatial_reference, latlong_spatial_reference)
        
        tile_spatial_reference = osr.SpatialReference()
        s = re.match('EPSG:(\d+)', tile_type_info['crs'])
        if s:
            epsg_code = int(s.group(1))
            logger.debug('epsg_code = %d', epsg_code)
            assert tile_spatial_reference.ImportFromEPSG(epsg_code) == 0, 'Invalid EPSG code for tile projection'
        else:
            assert tile_spatial_reference.ImportFromWkt(tile_type_info['crs']) == 0, 'Invalid WKT for tile projection'
        
        logger.debug('Tile WKT = %s', tile_spatial_reference.ExportToWkt())
            
        coord_transform_to_tile = osr.CoordinateTransformation(spatial_reference, tile_spatial_reference)
        
        # Need to keep tile and lat/long references separate even though they may be equivalent
        # Upper Left
        ul_x, ul_y = geotransform[0], geotransform[3]
        ul_lon, ul_lat, _z = coord_transform_to_latlong.TransformPoint(ul_x, ul_y, 0)
        tile_ul_x, tile_ul_y, _z = coord_transform_to_tile.TransformPoint(ul_x, ul_y, 0)
        # Upper Right
        ur_x, ur_y = geotransform[0] + geotransform[1] * dataset.RasterXSize, geotransform[3]
        ur_lon, ur_lat, _z = coord_transform_to_latlong.TransformPoint(ur_x, ur_y, 0)
        tile_ur_x, tile_ur_y, _z = coord_transform_to_tile.TransformPoint(ur_x, ur_y, 0)
        # Lower Right
        lr_x, lr_y = geotransform[0] + geotransform[1] * dataset.RasterXSize, geotransform[3] + geotransform[5] * dataset.RasterYSize
        lr_lon, lr_lat, _z = coord_transform_to_latlong.TransformPoint(lr_x, lr_y, 0)
        tile_lr_x, tile_lr_y, _z = coord_transform_to_tile.TransformPoint(lr_x, lr_y, 0)
        # Lower Left
        ll_x, ll_y = geotransform[0], geotransform[3] + geotransform[5] * dataset.RasterYSize
        ll_lon, ll_lat, _z = coord_transform_to_latlong.TransformPoint(ll_x, ll_y, 0)
        tile_ll_x, tile_ll_y, _z = coord_transform_to_tile.TransformPoint(ll_x, ll_y, 0)
        
        tile_min_x = min(tile_ul_x, tile_ll_x)
        tile_max_x = max(tile_ur_x, tile_lr_x)
        tile_min_y = min(tile_ll_y, tile_lr_y)
        tile_max_y = max(tile_ul_y, tile_ur_y)
        
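        # convert projected extents into a tile index range relative to the
        # tile origin: floor() for the lower bound, ceil() for the upper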
        tile_index_range = (int(floor((tile_min_x - tile_type_info['x_origin']) / tile_type_info['x_size'])), 
                    int(floor((tile_min_y - tile_type_info['y_origin']) / tile_type_info['y_size'])), 
                    int(ceil((tile_max_x - tile_type_info['x_origin']) / tile_type_info['x_size'])), 
                    int(ceil((tile_max_y - tile_type_info['y_origin']) / tile_type_info['y_size'])))
        
        sql = """-- Find dataset_id for given path
select dataset_id
from dataset 
where dataset_path like '%%' || %(basename)s
"""
        params = {'basename': os.path.basename(filename)}
        log_multiline(logger.debug, db_cursor.mogrify(sql, params), 'SQL', '\t')
        db_cursor.execute(sql, params)
        result = db_cursor.fetchone()
        if result: # Record already exists
            dataset_id = result[0]
            if self.refresh:
                logger.info('Updating existing record for %s', filename)
                
                sql = """
update dataset 
  set level_id = (select level_id from processing_level where upper(level_name) = upper(%(processing_level)s)),
  datetime_processed = %(datetime_processed)s,
  dataset_size = %(dataset_size)s,
  crs = %(crs)s,
  ll_x = %(ll_x)s,
  ll_y = %(ll_y)s,
  lr_x = %(lr_x)s,
  lr_y = %(lr_y)s,
  ul_x = %(ul_x)s,
  ul_y = %(ul_y)s,
  ur_x = %(ur_x)s,
  ur_y = %(ur_y)s,
  x_pixels = %(x_pixels)s,
  y_pixels = %(y_pixels)s
where dataset_id = %(dataset_id)s;

select %(dataset_id)s
"""
            else:
                logger.info('Skipping existing record for %s', filename)
                return
        else: # Record doesn't already exist
            logger.info('Creating new record for %s', filename)
            dataset_id = None       
                    
            sql = """-- Create new dataset record
insert into dataset(
  dataset_id, 
  acquisition_id, 
  dataset_path, 
  level_id,
  datetime_processed,
  dataset_size,
  crs,
  ll_x,
  ll_y,
  lr_x,
  lr_y,
  ul_x,
  ul_y,
  ur_x,
  ur_y,
  x_pixels,
  y_pixels
  )
select
  nextval('dataset_id_seq') as dataset_id,
  null as acquisition_id,
  %(dataset_path)s,
  (select level_id from processing_level where upper(level_name) = upper(%(processing_level)s)),
  %(datetime_processed)s,
  %(dataset_size)s,
  %(crs)s,
  %(ll_x)s,
  %(ll_y)s,
  %(lr_x)s,
  %(lr_y)s,
  %(ul_x)s,
  %(ul_y)s,
  %(ur_x)s,
  %(ur_y)s,
  %(x_pixels)s,
  %(y_pixels)s
where not exists
  (select dataset_id
  from dataset
  where dataset_path = %(dataset_path)s
  );

select dataset_id 
from dataset
where dataset_path = %(dataset_path)s
;
"""
        dataset_size = self.getFileSizekB(filename) # Need size in kB to match other datasets 
        
        # same params for insert or update
        params = {'dataset_id': dataset_id,
            'dataset_path': filename,
            'processing_level': level_name,
            'datetime_processed': None,
            'dataset_size': dataset_size,
            'll_lon': ll_lon,
            'll_lat': ll_lat,
            'lr_lon': lr_lon,
            'lr_lat': lr_lat,
            'ul_lon': ul_lon,
            'ul_lat': ul_lat,
            'ur_lon': ur_lon,
            'ur_lat': ur_lat,
            'crs': dataset.GetProjection(),
            'll_x': ll_x,
            'll_y': ll_y,
            'lr_x': lr_x,
            'lr_y': lr_y,
            'ul_x': ul_x,
            'ul_y': ul_y,
            'ur_x': ur_x,
            'ur_y': ur_y,
            'x_pixels': dataset.RasterXSize,
            'y_pixels': dataset.RasterYSize,
            'gcp_count': None,
            'mtl_text': None,
            'cloud_cover': None
            }
        
        log_multiline(logger.debug, db_cursor.mogrify(sql, params), 'SQL', '\t')    
        db_cursor.execute(sql, params)
        result = db_cursor.fetchone() # Retrieve new dataset_id if required
        dataset_id = dataset_id or result[0]

        tile_output_root = os.path.join(self.tile_root, 
                                        tile_type_info['tile_directory'],
                                        level_name, 
                                        os.path.basename(filename)
                                        )
        logger.debug('tile_output_root = %s', tile_output_root)
        self.create_directory(tile_output_root)

        work_directory = os.path.join(self.temp_dir,
                                      os.path.basename(filename)
                                      )
        logger.debug('work_directory = %s', work_directory)
        self.create_directory(work_directory)
                
        for x_index in range(tile_index_range[0], tile_index_range[2]):
            for y_index in range(tile_index_range[1], tile_index_range[3]): 
                
                tile_info = find_tiles(x_index, y_index)
                
                if tile_info:
                    logger.info('Skipping existing tile (%d, %d)', x_index, y_index)
                    continue

                tile_basename = '_'.join([level_name,
                                          re.sub('\+', '', '%+04d_%+04d' % (x_index, y_index))]) + tile_type_info['file_extension']
                
                tile_output_path = os.path.join(tile_output_root, tile_basename)
                                                                   
                # Check whether this tile has already been processed
                if not self.lock_object(tile_output_path):
                    logger.warning('Tile %s already being processed - skipping.', tile_output_path)
                    continue
                
                try:                 
                    self.remove(tile_output_path)
                    
                    temp_tile_path = os.path.join(self.temp_dir, tile_basename)
                                                                       
                    tile_extents = (tile_type_info['x_origin'] + x_index * tile_type_info['x_size'], 
                                tile_type_info['y_origin'] + y_index * tile_type_info['y_size'], 
                                tile_type_info['x_origin'] + (x_index + 1) * tile_type_info['x_size'], 
                                tile_type_info['y_origin'] + (y_index + 1) * tile_type_info['y_size']) 
                    logger.debug('tile_extents = %s', tile_extents)
                    
                    command_string = 'gdalwarp'
                    if not self.debug:
                        command_string += ' -q'
                    command_string += ' -t_srs %s -te %f %f %f %f -tr %f %f -tap -r %s' % (
                        tile_type_info['crs'],
                        tile_extents[0], tile_extents[1], tile_extents[2], tile_extents[3], 
                        tile_type_info['x_pixel_size'], tile_type_info['y_pixel_size'],
                        dem_band_info[10]['resampling_method']
                        )
                    
                    if nodata_value is not None:
                        command_string += ' -srcnodata %d -dstnodata %d' % (nodata_value, nodata_value)
                                                                          
                    command_string += ' -of %s' % tile_type_info['file_format']
                    
                    if tile_type_info['format_options']:
                        for format_option in tile_type_info['format_options'].split(','):
                            command_string += ' -co %s' % format_option
                        
                    command_string += ' -overwrite %s %s' % (
                        filename,
                        temp_tile_path
                        )
                    
                    logger.debug('command_string = %s', command_string)
                    
                    result = execute(command_string=command_string)
                    
                    if result['stdout']:
                        log_multiline(logger.info, result['stdout'], 'stdout from ' + command_string, '\t') 
                    
                    if result['returncode']:
                        log_multiline(logger.error, result['stderr'], 'stderr from ' + command_string, '\t')
                        raise Exception('%s failed' % command_string)
                    
                    temp_dataset = gdal.Open(temp_tile_path)
                                    
                    gdal_driver = gdal.GetDriverByName(tile_type_info['file_format'])
                    #output_dataset = gdal_driver.Create(output_tile_path, 
                    #                                    nbar_dataset.RasterXSize, nbar_dataset.RasterYSize,
                    #                                    1, nbar_dataset.GetRasterBand(1).DataType,
                    #                                    tile_type_info['format_options'].split(','))
                    output_dataset = gdal_driver.Create(tile_output_path, 
                                                        temp_dataset.RasterXSize, temp_dataset.RasterYSize,
                                                        len(dem_band_info), 
                                                        temp_dataset.GetRasterBand(1).DataType,
                                                        tile_type_info['format_options'].split(','))
                    assert output_dataset, 'Unable to create output dataset %s' % tile_output_path
                    output_geotransform = temp_dataset.GetGeoTransform()                                 
                    output_dataset.SetGeoTransform(output_geotransform)
                    output_dataset.SetProjection(temp_dataset.GetProjection()) 
                    
                    elevation_array = temp_dataset.GetRasterBand(1).ReadAsArray()
                    del temp_dataset
                    self.remove(temp_tile_path)
        
                    pixel_x_size = abs(output_geotransform[1])
                    pixel_y_size = abs(output_geotransform[5])
                    x_m_array, y_m_array = self.get_pixel_size_grids(output_dataset)
                    
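                    # Sobel output is ~8x the per-pixel gradient, hence the /8;
                    # the pixel_size / metre-grid factors then convert the
                    # gradient from per-map-unit to per-metre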
                    dzdx_array = ndimage.sobel(elevation_array, axis=1)/(8. * abs(output_geotransform[1]))
                    dzdx_array = numexpr.evaluate("dzdx_array * pixel_x_size / x_m_array")
                    del x_m_array
                    
                    dzdy_array = ndimage.sobel(elevation_array, axis=0)/(8. * abs(output_geotransform[5]))
                    dzdy_array = numexpr.evaluate("dzdy_array * pixel_y_size / y_m_array")
                    del y_m_array

                    for band_file_number in sorted(dem_band_info.keys()):
                        output_band_number = dem_band_info[band_file_number]['tile_layer']
                        output_band = output_dataset.GetRasterBand(output_band_number)
                        
                        if band_file_number == 10: # Elevation
                            output_band.WriteArray(elevation_array)
                            del elevation_array
                            
                        elif band_file_number == 20: # Slope    
                            hypotenuse_array = numpy.hypot(dzdx_array, dzdy_array)
                            slope_array = numexpr.evaluate("arctan(hypotenuse_array) / RADIANS_PER_DEGREE")
                            del hypotenuse_array
                            output_band.WriteArray(slope_array)
                            del slope_array
                            
                        elif band_file_number == 30: # Aspect
                            # Convert angles from conventional radians to compass heading 0-360
                            aspect_array = numexpr.evaluate("(450 - arctan2(dzdy_array, -dzdx_array) / RADIANS_PER_DEGREE) % 360")
                            output_band.WriteArray(aspect_array)
                            del aspect_array

                        if nodata_value is not None:
                            output_band.SetNoDataValue(nodata_value)
                        output_band.FlushCache()
                    
                    #===========================================================
                    # # This is not strictly necessary - copy metadata to output dataset
                    # output_dataset_metadata = temp_dataset.GetMetadata()
                    # if output_dataset_metadata:
                    #    output_dataset.SetMetadata(output_dataset_metadata) 
                    #    log_multiline(logger.debug, output_dataset_metadata, 'output_dataset_metadata', '\t')    
                    #===========================================================
                    
                    output_dataset.FlushCache()
                    del output_dataset
                    logger.info('Finished writing dataset %s', tile_output_path)
                    
                    tile_size = self.getFileSizeMB(tile_output_path)
    
                    sql = """-- Insert new tile_footprint record if necessary
    insert into tile_footprint (
      x_index, 
      y_index, 
      tile_type_id, 
      x_min, 
      y_min, 
      x_max, 
      y_max
      )
    select
      %(x_index)s, 
      %(y_index)s, 
      %(tile_type_id)s, 
      %(x_min)s, 
      %(y_min)s, 
      %(x_max)s, 
      %(y_max)s
    where not exists
      (select 
        x_index, 
        y_index, 
        tile_type_id
      from tile_footprint
      where x_index = %(x_index)s 
        and y_index = %(y_index)s 
        and tile_type_id = %(tile_type_id)s);
    
    -- Update any existing tile record
    update tile
    set 
      tile_pathname = %(tile_pathname)s,
      tile_class_id = %(tile_class_id)s,
      tile_size = %(tile_size)s,
      ctime = now()
    where 
      x_index = %(x_index)s
      and y_index = %(y_index)s
      and tile_type_id = %(tile_type_id)s
      and dataset_id = %(dataset_id)s;
    
    -- Insert new tile record if necessary
    insert into tile (
      tile_id,
      x_index,
      y_index,
      tile_type_id,
      dataset_id,
      tile_pathname,
      tile_class_id,
      tile_size,
      ctime
      )  
    select
      nextval('tile_id_seq'::regclass),
      %(x_index)s,
      %(y_index)s,
      %(tile_type_id)s,
      %(dataset_id)s,
      %(tile_pathname)s,
      %(tile_class_id)s,
      %(tile_size)s,
      now()
    where not exists
      (select tile_id
      from tile
      where 
        x_index = %(x_index)s
        and y_index = %(y_index)s
        and tile_type_id = %(tile_type_id)s
        and dataset_id = %(dataset_id)s
      );
    """  
                    params = {'x_index': x_index,
                              'y_index': y_index,
                              'tile_type_id': tile_type_info['tile_type_id'],
                              'x_min': tile_extents[0], 
                              'y_min': tile_extents[1], 
                              'x_max': tile_extents[2], 
                              'y_max': tile_extents[3],
                              'dataset_id': dataset_id, 
                              'tile_pathname': tile_output_path,
                              'tile_class_id': 1,
                              'tile_size': tile_size
                              }
                    
                    log_multiline(logger.debug, db_cursor.mogrify(sql, params), 'SQL', '\t')
                    db_cursor.execute(sql, params)
                            
                    self.db_connection.commit()  
                finally:
                    self.unlock_object(tile_output_path)
                
        logger.info('Finished creating all tiles')
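
create_tiles uses os.path.basename in two roles: matching a stored dataset_path by its trailing name in the SQL lookup, and keying the per-dataset output and work directories so datasets from different source directories cannot collide. A reduced sketch of the directory half (the root paths are illustrative):

import os

def dataset_directories(tile_root, tile_directory, level_name, temp_dir, filename):
    # Use the file name, not the full path, as the per-dataset directory key.
    base = os.path.basename(filename)
    tile_output_root = os.path.join(tile_root, tile_directory, level_name, base)
    work_directory = os.path.join(temp_dir, base)
    return tile_output_root, work_directory

print(dataset_directories('/data/tiles', '1deg_tiles', 'DSM', '/tmp',
                          '/ingest/srtm/dsm_mosaic.tif'))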

Example 22

Project: calibre-kobo-driver
Source File: common.py
View license
def modify_epub(container, filename, metadata=None, opts=None):
    # Avoid a mutable default argument; a shared dict would leak state between calls.
    opts = opts or {}
    debug_print('modify_epub opts: ' + str(opts))
    # Search for the ePub cover
    found_cover = False
    opf = container.opf
    cover_meta_node = opf.xpath('./opf:metadata/opf:meta[@name="cover"]',
                                namespaces=OPF_NAMESPACES)
    if len(cover_meta_node) > 0:
        cover_meta_node = cover_meta_node[0]
        cover_id = cover_meta_node.attrib[
            "content"] if "content" in cover_meta_node.attrib else None
        if cover_id is not None:
            debug_print(
                "KoboTouchExtended:common:modify_epub:Found cover image ID "
                "'{0}'".format(cover_id))
            cover_node = opf.xpath(
                './opf:manifest/opf:item[@id="{0}"]'.format(cover_id),
                namespaces=OPF_NAMESPACES)
            if len(cover_node) > 0:
                cover_node = cover_node[0]
                if "properties" not in cover_node.attrib or cover_node.attrib[
                        "properties"] != "cover-image":
                    debug_print(
                        "KoboTouchExtended:common:modify_epub:Setting cover-image property")
                    cover_node.set("properties", "cover-image")
                    container.dirty(container.opf_name)
                    found_cover = True
    # It's possible that the cover image can't be detected this way. Try looking for the cover image ID in the OPF manifest.
    if not found_cover:
        debug_print(
            "KoboTouchExtended:common:modify_epub:Looking for cover image in OPF manifest")
        node_list = opf.xpath(
            './opf:manifest/opf:item[(translate(@id, \'ABCDEFGHIJKLMNOPQRSTUVWXYZ\', \'abcdefghijklmnopqrstuvwxyz\')="cover" or starts-with(translate(@id, \'ABCDEFGHIJKLMNOPQRSTUVWXYZ\', \'abcdefghijklmnopqrstuvwxyz\'), "cover")) and starts-with(@media-type, "image")]',
            namespaces=OPF_NAMESPACES)
        if len(node_list) > 0:
            node = node_list[0]
            if "properties" not in node.attrib or node.attrib[
                    "properties"] != 'cover-image':
                debug_print(
                    "KoboTouchExtended:common:modify_epub:Setting cover-image")
                node.set("properties", "cover-image")
                container.dirty(container.opf_name)
                found_cover = True

    # Because of the changes made to the markup here, cleanup needs to be done before any other content file processing
    container.forced_cleanup()
    if 'clean_markup' in opts and opts['clean_markup'] is True:
        container.clean_markup()

    # Hyphenate files?
    if 'no-hyphens' in opts and opts['no-hyphens'] is True:
        nohyphen_css = PersistentTemporaryFile(suffix="_nohyphen",
                                               prefix="kepub_")
        nohyphen_css.write(get_resources("css/no-hyphens.css"))
        nohyphen_css.close()
        css_path = os.path.basename(container.copy_file_to_container(
            nohyphen_css.name, name='kte-css/no-hyphens.css'))
        container.add_content_file_reference("kte-css/{0}".format(css_path))
    elif 'hyphenate' in opts and opts['hyphenate'] is True:
        if ('replace_lang' not in opts or
                opts['replace_lang'] is not True) or (
                    metadata is not None and
                    metadata.language == NULL_VALUES['language']):
            debug_print(
                "KoboTouchExtended:common:modify_epub:WARNING - Hyphenation is enabled but not overriding content file language. Hyphenation may use the wrong dictionary.")
        hyphenation_css = PersistentTemporaryFile(suffix='_hyphenate',
                                                  prefix='kepub_')
        hyphenation_css.write(get_resources('css/hyphenation.css'))
        hyphenation_css.close()
        css_path = os.path.basename(
            container.copy_file_to_container(hyphenation_css.name,
                                             name='kte-css/hyphenation.css'))
        container.add_content_file_reference("kte-css/{0}".format(css_path))

    # Override content file language
    if 'replace_lang' in opts and opts['replace_lang'] is True and (
            metadata is not None and
            metadata.language != NULL_VALUES["language"]):
        # First override for the OPF file
        lang_node = container.opf_xpath('//opf:metadata/dc:language')
        if len(lang_node) > 0:
            debug_print(
                "KoboTouchExtended:common:modify_epub:Overriding OPF language")
            lang_node = lang_node[0]
            lang_node.text = metadata.language
        else:
            debug_print(
                "KoboTouchExtended:common:modify_epub:Setting OPF language")
            metadata_node = container.opf_xpath('//opf:metadata')[0]
            lang_node = metadata_node.makeelement("{%s}language" %
                                                  OPF_NAMESPACES['dc'])
            lang_node.text = metadata.language
            container.insert_into_xml(metadata_node, lang_node)
        container.dirty(container.opf_name)

        # Now override for content files
        for name in container.get_html_names():
            debug_print(
                "KoboTouchExtended:common:modify_epub:Overriding content file language :: {0}".format(
                    name))
            root = container.parsed(name)
            root.attrib["{%s}lang" % XML_NAMESPACE] = metadata.language
            root.attrib["lang"] = metadata.language

    # Now smarten punctuation
    if 'smarten_punctuation' in opts and opts['smarten_punctuation'] is True:
        container.smarten_punctuation()

    if 'extended_kepub_features' in opts and opts[
            'extended_kepub_features'] is True:
        if metadata is not None:
            debug_print(
                "KoboTouchExtended:common:modify_epub:Adding extended Kobo features to {0} by {1}".format(
                    metadata.title, ' and '.join(metadata.authors)))
        # Add the Kobo span tags
        container.add_kobo_spans()
        # Add the Kobo style hacks div tags
        container.add_kobo_divs()

        skip_js = False
        # Check to see if there's already a kobo*.js in the ePub
        for name in container.name_path_map:
            if kobo_js_re.match(name):
                skip_js = True
                break
        if not skip_js:
            if os.path.isfile(reference_kepub):
                reference_container = EpubContainer(reference_kepub,
                                                    default_log)
                for name in reference_container.name_path_map:
                    if kobo_js_re.match(name):
                        jsname = container.copy_file_to_container(
                            os.path.join(reference_container.root, name),
                            name='kobo.js')
                        container.add_content_file_reference(jsname)
                        break

        # Add the Kobo style hacks
        stylehacks_css = PersistentTemporaryFile(suffix='_stylehacks',
                                                 prefix='kepub_')
        stylehacks_css.write(get_resources('css/style-hacks.css'))
        stylehacks_css.close()
        css_path = os.path.basename(
            container.copy_file_to_container(stylehacks_css.name,
                                             name='kte-css/stylehacks.css'))
        container.add_content_file_reference("kte-css/{0}".format(css_path))
    os.unlink(filename)
    container.commit(filename)
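
The copy_file_to_container call above returns the path of the copied file inside the ePub container, and os.path.basename keeps only the file name before it is re-prefixed with the kte-css/ directory for the content reference. A minimal sketch of that idiom, with the container object standing in for the plugin's real container API:

import os.path

def add_css_reference(container, temp_path, name):
    # copy_file_to_container is assumed to return the internal path of
    # the copied file; only the file name portion is kept.
    css_path = os.path.basename(
        container.copy_file_to_container(temp_path, name=name))
    container.add_content_file_reference("kte-css/{0}".format(css_path))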

Example 23

Project: LASIF
Source File: validator.py
View license
    def _validate_event_files(self):
        """
        Validates all event files in the currently active project.

        The following tasks are performed:
            * Validate against QuakeML 1.2 scheme.
            * Check for duplicate ids amongst all QuakeML files.
            * Make sure they contain at least one origin, magnitude and focal
              mechanism object.
            * Some simple sanity checks to make sure the event depth and the
              moment tensor values are reasonable. This is rather fragile and
              mainly intended to detect values specified in the wrong units.
            * Events that are too close in time. Events that are less than one
              hour apart can in general not be used for adjoint tomography.
              This will naturally also detect duplicate events.
        """
        import collections
        import itertools
        import math
        from obspy import read_events
        from obspy.io.quakeml.core import _validate as validate_quakeml
        from lxml import etree

        print "Validating %i event files ..." % self.comm.events.count()

        # Start with the schema validation.
        print "\tValidating against QuakeML 1.2 schema ",
        all_valid = True
        for event in self.comm.events.get_all_events().values():
            filename = event["filename"]
            self._flush_point()
            if validate_quakeml(filename) is not True:
                all_valid = False
                msg = (
                    "ERROR: "
                    "The QuakeML file '{basename}' did not validate against "
                    "the QuakeML 1.2 schema. Unfortunately the error messages "
                    "delivered by lxml are not useful at all. To get useful "
                    "error messages make sure jing is installed "
                    "('brew install jing' (OSX) or "
                    "'sudo apt-get install jing' (Debian/Ubuntu)) and "
                    "execute the following command:\n\n"
                    "\tjing http://quake.ethz.ch/schema/rng/QuakeML-1.2.rng "
                    "{filename}\n\n"
                    "Alternatively you could also use the "
                    "'lasif add_spud_event' command to redownload the event "
                    "if it is in the GCMT "
                    "catalog.\n\n").format(
                    basename=os.path.basename(filename),
                    filename=os.path.relpath(filename))
                self._add_report(msg)
        if all_valid is True:
            self._print_ok_message()
        else:
            self._print_fail_message()

        # Now check for duplicate public IDs.
        print "\tChecking for duplicate public IDs ",
        ids = collections.defaultdict(list)
        for event in self.comm.events.get_all_events().values():
            filename = event["filename"]
            self._flush_point()
            # Now walk all files and collect all public ids. Each should be
            # unique!
            with open(filename, "rt") as fh:
                for event, elem in etree.iterparse(fh, events=("start",)):
                    if "publicID" not in elem.keys() or \
                            elem.tag.endswith("eventParameters"):
                        continue
                    ids[elem.get("publicID")].append(filename)
        ids = {key: list(set(value)) for (key, value) in ids.iteritems()
               if len(value) > 1}
        if not ids:
            self._print_ok_message()
        else:
            self._print_fail_message()
            self._add_report(
                "Found the following duplicate publicIDs:\n" +
                "\n".join(["\t%s in files: %s" % (
                    id_string,
                    ", ".join([os.path.basename(i) for i in faulty_files]))
                    for id_string, faulty_files in ids.iteritems()]),
                error_count=len(ids))

        def print_warning(filename, message):
            self._add_report("WARNING: File '{event_name}' "
                             "contains {msg}.\n".format(
                                 event_name=os.path.basename(filename),
                                 msg=message))

        # Performing simple sanity checks.
        print "\tPerforming some basic sanity checks ",
        all_good = True
        for event in self.comm.events.get_all_events().values():
            filename = event["filename"]
            self._flush_point()
            cat = read_events(filename)
            filename = os.path.basename(filename)
            # Check that all files contain exactly one event!
            if len(cat) != 1:
                all_good = False
                print_warning(filename, "%i events instead of only one." %
                              len(cat))
            event = cat[0]

            # Sanity checks related to the origin.
            if not event.origins:
                all_good = False
                print_warning(filename, "no origin")
                continue
            origin = event.preferred_origin() or event.origins[0]
            if (origin.depth % 100.0):
                all_good = False
                print_warning(
                    filename, "a depth of %.1f meters. This kind of accuracy "
                              "seems unrealistic. The depth in the QuakeML "
                              "file has to be specified in meters. Checking "
                              "all other QuakeML files for the correct units "
                              "might be a good idea"
                    % origin.depth)
            if (origin.depth > (800.0 * 1000.0)):
                all_good = False
                print_warning(filename, "a depth of more than 800 km. This is"
                                        " likely wrong.")

            # Sanity checks related to the magnitude.
            if not event.magnitudes:
                all_good = False
                print_warning(filename, "no magnitude")
                continue

            # Sanity checks related to the focal mechanism.
            if not event.focal_mechanisms:
                all_good = False
                print_warning(filename, "no focal mechanism")
                continue

            focmec = event.preferred_focal_mechanism() or \
                event.focal_mechanisms[0]
            if not hasattr(focmec, "moment_tensor") or \
                    not focmec.moment_tensor:
                all_good = False
                print_warning(filename, "no moment tensor")
                continue

            mt = focmec.moment_tensor
            if not hasattr(mt, "tensor") or \
                    not mt.tensor:
                all_good = False
                print_warning(filename, "no actual moment tensor")
                continue
            tensor = mt.tensor

            # Convert the moment tensor to a magnitude and see if it is
            # reasonable.
            mag_in_file = event.preferred_magnitude() or event.magnitudes[0]
            mag_in_file = mag_in_file.mag
            M_0 = 1.0 / math.sqrt(2.0) * math.sqrt(
                tensor.m_rr ** 2 + tensor.m_tt ** 2 + tensor.m_pp ** 2)
            magnitude = 2.0 / 3.0 * math.log10(M_0) - 6.0
            # Use some buffer to account for different magnitudes.
            if not (mag_in_file - 1.0) < magnitude < (mag_in_file + 1.0):
                all_good = False
                print_warning(
                    filename, "a moment tensor that would result in a moment "
                              "magnitude of %.2f. The magnitude specified in "
                              "the file is %.2f. Please check that all "
                              "components of the tensor are in Newton * meter"
                    % (magnitude, mag_in_file))

        if all_good is True:
            self._print_ok_message()
        else:
            self._print_fail_message()

        # Collect event times
        event_infos = self.comm.events.get_all_events().values()

        # Now check the time distribution of events.
        print "\tChecking for duplicates and events too close in time %s" % \
              (self.comm.events.count() * "."),
        all_good = True
        # Sort the events by time.
        event_infos = sorted(event_infos, key=lambda x: x["origin_time"])
        # Loop over adjacent indices.
        a, b = itertools.tee(event_infos)
        next(b, None)
        for event_1, event_2 in itertools.izip(a, b):
            time_diff = abs(event_2["origin_time"] - event_1["origin_time"])
            # If time difference is under one hour, it could be either a
            # duplicate event or interfering events.
            if time_diff <= 3600.0:
                all_good = False
                self._add_report(
                    "WARNING: "
                    "The time difference between events '{file_1}' and "
                    "'{file_2}' is only {diff:.1f} minutes. This could "
                    "be either due to a duplicate event or events that have "
                    "interfering waveforms.\n".format(
                        file_1=event_1["filename"],
                        file_2=event_2["filename"],
                        diff=time_diff / 60.0))
        if all_good is True:
            self._print_ok_message()
        else:
            self._print_fail_message()

        # Check that all events fall within the chosen boundaries.
        print "\tAssure all events are in chosen domain %s" % \
              (self.comm.events.count() * "."),
        all_good = True
        domain = self.comm.project.domain
        for event in event_infos:
            if domain.point_in_domain(latitude=event["latitude"],
                                      longitude=event["longitude"]):
                continue
            all_good = False
            self._add_report(
                "\nWARNING: "
                "Event '{filename}' is out of bounds of the chosen domain."
                "\n".format(filename=event["filename"]))
        if all_good is True:
            self._print_ok_message()
        else:
            self._print_fail_message()

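In this validator, os.path.basename serves a purely cosmetic role: report and warning messages show just the event file's name rather than its full path. A self-contained sketch of that reporting idiom, with hypothetical data of my own choosing:

import os.path

def format_duplicate_report(ids):
    # ids maps a publicID to the list of files it occurs in; only the
    # base names are printed so the report stays readable.
    return "Found the following duplicate publicIDs:\n" + "\n".join(
        "\t%s in files: %s" % (
            id_string, ", ".join(os.path.basename(f) for f in files))
        for id_string, files in sorted(ids.items()))

print(format_duplicate_report(
    {"smi:local/event/1": ["/data/EVENTS/a.xml", "/data/EVENTS/b.xml"]}))
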
Example 24

Project: nsec3map
Source File: map.py
View license
def parse_arguments(argv):
    long_opts = [
            'aggressive=',
            'auto',
            'binary',
            'continue=',
            'end=',
            'help',
            'ignore-overlapping',
            'input=',
            'label-counter=',
            'ldh',
            'limit-rate=',
            'max-retries=',
            'mixed',
            'nsec',
            'nsec3',
            'omit-soa-check',
            'output=',
            'predict',
            'processes=',
            'query-mode=',
            'queue-element-size=',
            'quiet',
            'start=',
            'timeout=',
            'no-openssl',
            'verbose',
            'version'
    ]
    options = default_options()
    opts = '3AMNabc:e:f:hi:lm:no:pqs:v'
    try:
        opts, args = getopt.gnu_getopt(argv[1:], opts, long_opts)
    except getopt.GetoptError, err:
        log.fatal_exit(2, err, "\n", "Try `",
                str(os.path.basename(argv[0])), 
                " --help' for more information.")

    for opt, arg in opts:
        if opt in ('-h', '--help'):
            usage(os.path.basename(argv[0]))
            sys.exit(0)

        elif opt in ('-a', '--auto'):
            options['zone_type'] = 'auto'

        elif opt in ('-n', '--nsec'):
            options['zone_type'] = 'nsec'

        elif opt in ('-3', '--nsec3'):
            options['zone_type'] = 'nsec3'

        elif opt in ('-c', '--continue'):
            options['input'] = options['output'] = arg

        elif opt in ('-i', '--input'):
            options['input'] = arg

        elif opt in ('-o', '--output'):
            options['output'] = arg

        elif opt in ('--label-counter',):
            try:
                options['label_counter'] = long(arg, 0)
            except ValueError:
                invalid_argument(opt, arg)
            if options['label_counter']  < 0:
                invalid_argument(opt, arg)

        elif opt in ('--ignore-overlapping',):
            options['ignore_overlapping'] = True

        elif opt in ('-m', '--query-mode'):
            if arg not in ('mixed', 'NSEC', 'A'):
                invalid_argument(opt, arg)
            options['query_mode'] = arg

        elif opt in ('-M', '--mixed'):
            options['query_mode']  = 'mixed'

        elif opt in ('-A',):
            options['query_mode']  = 'A'

        elif opt in ('-N',):
            options['query_mode']  = 'NSEC'

        elif opt in ('-l', '--ldh'):
            options['query_chars'] = 'ldh'

        elif opt in ('-b', '--binary'):
            options['query_chars'] = 'binary'

        elif opt in ('-e', '--end'):
            options['end'] = arg

        elif opt in ('--limit-rate',):
            try:
                options['query_interval'] = _query_interval(arg)
            except ValueError:
                invalid_argument(opt, arg)

        elif opt in ('--max-retries',):
            try:
                options['max_retries'] = int(arg)
            except ValueError:
                invalid_argument(opt, arg)
            if options['max_retries'] < -1:
                invalid_argument(opt, arg)


        elif opt in ('--omit-soa-check',):
            options['soa_check'] = False

        elif opt in ('-f', '--aggressive',):
            try:
                options['aggressive'] = int(arg)
            except ValueError:
                invalid_argument(opt, arg)
            if options['aggressive'] < 1:
                invalid_argument(opt, arg)

        elif opt in ('-p', '--predict',):
            options['predict'] = True

        elif opt in ('--processes',):
            try:
                options['processes'] = int(arg)
            except ValueError:
                invalid_argument(opt, arg)
            if options['processes'] < 1:
                invalid_argument(opt, arg)


        elif opt in ('--queue-element-size',):
            try:
                options['queue_element_size'] = int(arg)
            except ValueError:
                invalid_argument(opt, arg)
            if options['queue_element_size']  < 1:
                invalid_argument(opt, arg)

        elif opt in ('-q', '--quiet'):
            options['progress'] = False

        elif opt in ('-s', '--start'):
            options['start'] = arg
            
        elif opt in ('--timeout',):
            try:
                options['timeout'] = int(arg)
            except ValueError:
                invalid_argument(opt, arg)
            if options['timeout']  < 1:
                invalid_argument(opt, arg)

        elif opt in ('--no-openssl',):
            options['use_openssl'] = False

        elif opt in ('-v', '--verbose'):
            log.logger.loglevel += 1

        elif opt in ('--version',):
            version()
            sys.exit(0)

        else:
            invalid_argument(opt, "")

    if len(args) < 1:
        log.fatal_exit(2, 'missing arguments', "\n", "Try `",
                str(os.path.basename(argv[0])), 
                " --help' for more information.")
    else:
        zone = n3map.name.fqdn_from_text(args[-1])
        if len(args) >= 2:
            ns_names = args[:-1]
            nslist = queryprovider.nameserver_from_text(*ns_names)
        else:
            ns_names = query_ns_records(zone)
            nslist = queryprovider.nameserver_from_text(*ns_names)
            for ns in nslist:
                log.info("using nameserver: ", str(ns))

    return (options, nslist, zone)

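Both fatal_exit calls above apply os.path.basename to argv[0] so that usage hints print the bare program name instead of the full invocation path. A minimal sketch of the same idiom (program_name is a hypothetical helper, not part of n3map):

import os.path
import sys

def program_name(argv=None):
    # Strip the directory from argv[0]: "/usr/local/bin/n3map" -> "n3map"
    argv = sys.argv if argv is None else argv
    return os.path.basename(argv[0])

if __name__ == '__main__':
    sys.stderr.write("Try `%s --help' for more information.\n"
                     % program_name())
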
Example 25

Project: baidu-fuse
Source File: baidufuse.py
View license
    def _add_file_to_buffer(self, path,file_info):
        foo = File()
        foo['st_ctime'] = file_info['local_ctime']
        foo['st_mtime'] = file_info['local_mtime']
        foo['st_mode'] = (stat.S_IFDIR | 0777) if file_info['isdir'] \
            else (stat.S_IFREG | 0777)
        foo['st_nlink'] = 2 if file_info['isdir'] else 1
        foo['st_size'] = file_info['size']
        self.buffer[path] = foo

    def _del_file_from_buffer(self,path):
        self.buffer.pop(path)

    def getattr(self, path, fh=None):
        #print 'getattr *',path
        # First check whether the file is already in the cache

        if not self.buffer.has_key(path):
            print path, 'cache miss'
            #print self.buffer
            #print self.traversed_folder
            jdata = json.loads(self.disk.meta([path]).content)
            try:
                if 'info' not in jdata:
                    raise FuseOSError(errno.ENOENT)
                if jdata['errno'] != 0:
                    raise FuseOSError(errno.ENOENT)
                file_info = jdata['info'][0]
                self._add_file_to_buffer(path,file_info)
                st = self.buffer[path].getDict()
                return st
            except:
                raise FuseOSError(errno.ENOENT)
        else:
            #print path, 'cache hit'
            return self.buffer[path].getDict()



    def readdir(self, path, offset):
        self.uploadLock.acquire()
        while True:
            try:
                foo = json.loads(self.disk.list_files(path).text)
                break
            except:
                print 'error'


        files = ['.', '..']
        abs_files = []  # absolute paths of the files in this folder
        for file in foo['list']:
            files.append(file['server_filename'])
            abs_files.append(file['path'])
        # Cache the file info for this folder, querying meta info in batches

        # Update: the meta API cannot return more than 100 records per call,
        # so split the files into ceil(file_num / 100.0) groups by residue class
        if not self.traversed_folder.has_key(path) or self.traversed_folder[path] == False:
            print 'caching', path
            file_num = len(abs_files)
            group = int(math.ceil(file_num / 100.0))
            for i in range(group):
                obj = [f for n,f in enumerate(abs_files) if n % group == i] # one group of files
                while 1:
                    try:
                        ret = json.loads(self.disk.meta(obj).text)
                        break
                    except:
                        print 'error'

                for file_info in ret['info']:
                    if not self.buffer.has_key(file_info['path']):
                        self._add_file_to_buffer(file_info['path'],file_info)
            #print self.buffer
            print 'finished caching', path
            self.traversed_folder[path] = True
        for r in files:
            yield r
        self.uploadLock.release()

    def _update_file_manual(self,path):
        while 1:
            try:
                jdata = json.loads(self.disk.meta([path]).content)
                break
            except:
                print 'error'

        if 'info' not in jdata:
            raise FuseOSError(errno.ENOENT)
        if jdata['errno'] != 0:
            raise FuseOSError(errno.ENOENT)
        file_info = jdata['info'][0]
        self._add_file_to_buffer(path,file_info)


    def rename(self, old, new):
        #logging.debug('* rename',old,os.path.basename(new))
        print '*'*10,'RENAME CALLED',old,os.path.basename(new),type(old),type(new)
        while True:
            try:
                ret = self.disk.rename([(old,os.path.basename(new))]).content
                jdata = json.loads(ret)
                break
            except:
                print 'error'

        if jdata['errno'] != 0:
            # The file name already exists; delete the original file
            print self.disk.delete([new]).content
            print self.disk.rename([(old,os.path.basename(new))])
        self._update_file_manual(new)
        self.buffer.pop(old)


    def open(self, path, flags):
        self.readLock.acquire()
        print '*'*10,'OPEN CALLED',path,flags
        #print '[****]',path
        """
        Permission denied

        accmode = os.O_RDONLY | os.O_WRONLY | os.O_RDWR
        if (flags & accmode) != os.O_RDONLY:
            raise FuseOSError(errno.EACCES)
        """
        self.fd += 1
        self.readLock.release()
        
        return self.fd

    def create(self, path, mode,fh=None):
        # Create a file
        # Non-ASCII (Chinese) paths are problematic
        print '*'*10,'CREATE CALLED',path,mode,type(path)
        #if 'outputstream' not in path:
        tmp_file = tempfile.TemporaryFile('r+w+b')
        foo = self.disk.upload(os.path.dirname(path),tmp_file,os.path.basename(path)).content
        ret = json.loads(foo)
        print ret
        print 'create-not-outputstream',ret
        if ret['path'] != path:
            # The file already exists
            print 'file already exists'
            raise FuseOSError(errno.EEXIST)
        '''
        else:
            print 'create:',path
            foo = File()
            foo['st_ctime'] = int(time.time())
            foo['st_mtime'] = int(time.time())
            foo['st_mode'] = (stat.S_IFREG | 0777)
            foo['st_nlink'] = 1
            foo['st_size'] = 0
            self.buffer[path] = foo
        '''


        '''
        dict(st_mode=(stat.S_IFREG | mode), st_nlink=1,
                                st_size=0, st_ctime=time.time(), st_mtime=time.time(),
                                st_atime=time.time())
        '''
        self.fd += 1
        return 0

    def write(self, path, data, offset, fp):
        # Called while a file is being uploaded
        # 4 KB (4096 bytes) per block; data holds the block's contents
        # The last block is detected by len(data) < 4096
        # file size = offset of the last block + len(data)

        # 4 KB per write is too slow, so batch up about 2 MB and upload at once

        #print '*'*10,path,offset, len(data)

        def _block_size(stream):
            stream.seek(0,2)
            return stream.tell()

        _BLOCK_SIZE = 16 * 2 ** 20
        # Work done for the first block
        if offset == 0:
            #self.uploadLock.acquire()
            #self.readLock.acquire()
            # Initialize the list of block md5s
            self.upload_blocks[path] = {'tmp':None,
                                        'blocks':[]}
            # Create a temporary buffer file
            tmp_file = tempfile.TemporaryFile('r+w+b')
            self.upload_blocks[path]['tmp'] = tmp_file

        # Write the data to the temp file; once it reaches _BLOCK_SIZE, upload the block and reset the temp file
        try:
            tmp = self.upload_blocks[path]['tmp']
        except KeyError:
            return 0
        tmp.write(data)

        if _block_size(tmp) > _BLOCK_SIZE:
            print path, 'uploading block'
            tmp.seek(0)
            try:
                foo = self.disk.upload_tmpfile(tmp,callback=ProgressBar()).content
                foofoo = json.loads(foo)
                block_md5 = foofoo['md5']
            except:
                 print foo



            # Record this block's md5 in upload_blocks
            self.upload_blocks[path]['blocks'].append(block_md5)
            # Create a fresh temporary buffer file
            self.upload_blocks[path]['tmp'].close()
            tmp_file = tempfile.TemporaryFile('r+w+b')
            self.upload_blocks[path]['tmp'] = tmp_file
            print 'created temp file', tmp_file.name

        # Work done for the last block
        if len(data) < 4096:
            # Check for a file with the same name and delete it if present
            while True:
                try:
                    foo = self.disk.meta([path]).content
                    foofoo = json.loads(foo)
                    break
                except:
                    print 'error'


            if foofoo['errno'] == 0:
                logging.debug('Deleted the file which has the same name.')
                self.disk.delete([path])
            # Check whether anything is left to upload
            if _block_size(tmp) != 0:
                # The temp file still holds data that must be uploaded
                print path, 'uploading final block, size', _block_size(tmp)
                tmp.seek(0)
                while True:
                    try:
                        foo = self.disk.upload_tmpfile(tmp,callback=ProgressBar()).content
                        foofoo = json.loads(foo)
                        break
                    except:
                        print 'exception, retry.'

                block_md5 = foofoo['md5']
                # 在 upload_blocks 中插入本块的 md5
                self.upload_blocks[path]['blocks'].append(block_md5)

            # Call upload_superfile to merge the uploaded blocks
            print 'merging file', path, type(path)
            self.disk.upload_superfile(path,self.upload_blocks[path]['blocks'])
            # Drop this path's data from upload_blocks
            self.upload_blocks.pop(path)
            # Refresh the local file list cache
            self._update_file_manual(path)
            #self.readLock.release()
            #self.uploadLock.release()
        return len(data)


    def mkdir(self, path, mode):
        logger.debug("mkdir is:" + path)
        self.disk.mkdir(path)

    def rmdir(self, path):
        logger.debug("rmdir is:" + path)
        self.disk.delete([path])

    def read(self, path, size, offset, fh):
        #print '*'*10,'READ CALLED',path,size,offset
        #logger.debug("read is: " + path)
        paras = {'Range': 'bytes=%s-%s' % (offset, offset + size - 1)}
        while True:
            try:
                foo = self.disk.download(path, headers=paras).content
                return foo
            except:
                pass

    access = None
    statfs = None

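This FUSE layer leans on os.path.basename in two places: rename() passes only the new base name to the backend API, and create() splits the path into a directory argument for upload() and a file name. A small sketch of that split (split_for_upload is a hypothetical helper, not part of baidu-fuse):

import os.path

def split_for_upload(path):
    # The upload API above takes the target directory and the file
    # name as separate arguments, so the FUSE path is split once.
    return os.path.dirname(path), os.path.basename(path)

directory, filename = split_for_upload('/music/album/track01.mp3')
# directory == '/music/album', filename == 'track01.mp3'
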
Example 26

Project: baidu-fuse
Source File: baidufuse2.py
View license
    def _add_file_to_buffer(self, path,file_info):
        foo = File()
        foo['st_ctime'] = file_info['local_ctime']
        foo['st_mtime'] = file_info['local_mtime']
        foo['st_mode'] = (stat.S_IFDIR | 0777) if file_info['isdir'] \
            else (stat.S_IFREG | 0777)
        foo['st_nlink'] = 2 if file_info['isdir'] else 1
        foo['st_size'] = file_info['size']
        self.buffer[path] = foo

    def _del_file_from_buffer(self,path):
        self.buffer.pop(path)

    def getattr(self, path, fh=None):
        #print 'getattr *',path
        # First check whether the file is already in the cache

        if not self.buffer.has_key(path):
            print path, 'cache miss'
            #print self.buffer
            #print self.traversed_folder
            jdata = json.loads(self.disk.meta([path]).content)
            try:
                if 'info' not in jdata:
                    raise FuseOSError(errno.ENOENT)
                if jdata['errno'] != 0:
                    raise FuseOSError(errno.ENOENT)
                file_info = jdata['info'][0]
                self._add_file_to_buffer(path,file_info)
                st = self.buffer[path].getDict()
                return st
            except:
                raise FuseOSError(errno.ENOENT)
        else:
            #print path, 'cache hit'
            return self.buffer[path].getDict()



    def readdir(self, path, offset):
        
        while True:
            try:
                logger.debug(u'reading directory ' + path)
                foo = json.loads(self.disk.list_files(path).text)
                break
            except Exception as s:
                logger.error('error',str(s))



        files = ['.', '..']
        abs_files = []  # absolute paths of the files in this folder
        for file in foo['list']:
            files.append(file['server_filename'])
            abs_files.append(file['path'])
        # Cache the file info for this folder, querying meta info in batches

        # Update: the meta API cannot return more than 100 records per call,
        # so split the files into ceil(file_num / 100.0) groups by residue class
        if not self.traversed_folder.has_key(path) or self.traversed_folder[path] == False:
            logger.debug(u'caching ' + path)
            file_num = len(abs_files)
            group = int(math.ceil(file_num / 100.0))
            for i in range(group):
                obj = [f for n,f in enumerate(abs_files) if n % group == i] # one group of files
                while 1:
                    try:
                        ret = json.loads(self.disk.meta(obj).text)
                        break
                    except:
                        print 'error'

                for file_info in ret['info']:
                    if not self.buffer.has_key(file_info['path']):
                        self._add_file_to_buffer(file_info['path'],file_info)
            #print self.buffer
            print 'finished caching', path
            self.traversed_folder[path] = True
        for r in files:
            yield r

    def _update_file_manual(self,path):
        while 1:
            try:
                jdata = json.loads(self.disk.meta([path]).content)
                break
            except:
                print 'error'

        if 'info' not in jdata:
            raise FuseOSError(errno.ENOENT)
        if jdata['errno'] != 0:
            raise FuseOSError(errno.ENOENT)
        file_info = jdata['info'][0]
        self._add_file_to_buffer(path,file_info)


    def rename(self, old, new):
        #logging.debug('* rename',old,os.path.basename(new))
        print '*'*10,'RENAME CALLED',old,os.path.basename(new),type(old),type(new)
        while True:
            try:
                ret = self.disk.rename([(old,os.path.basename(new))]).content
                jdata = json.loads(ret)
                break
            except:
                print 'error'

        if jdata['errno'] != 0:
            # The file name already exists; delete the original file
            print self.disk.delete([new]).content
            print self.disk.rename([(old,os.path.basename(new))])
        self._update_file_manual(new)
        self.buffer.pop(old)


    def open(self, path, flags):
        self.readLock.acquire()
        print '*'*10,'OPEN CALLED',path,flags
        #print '[****]',path
        """
        Permission denied

        accmode = os.O_RDONLY | os.O_WRONLY | os.O_RDWR
        if (flags & accmode) != os.O_RDONLY:
            raise FuseOSError(errno.EACCES)
        """
        self.fd += 1
        self.readLock.release()
        
        return self.fd

    def create(self, path, mode,fh=None):
        # Create a file
        # Non-ASCII (Chinese) paths are problematic
        print '*'*10,'CREATE CALLED',path,mode,type(path)
        #if 'outputstream' not in path:
        tmp_file = tempfile.TemporaryFile('r+w+b')
        foo = self.disk.upload(os.path.dirname(path),tmp_file,os.path.basename(path)).content
        ret = json.loads(foo)
        print ret
        print 'create-not-outputstream',ret
        if ret['path'] != path:
            # The file already exists
            print 'file already exists'
            raise FuseOSError(errno.EEXIST)
        '''
        else:
            print 'create:',path
            foo = File()
            foo['st_ctime'] = int(time.time())
            foo['st_mtime'] = int(time.time())
            foo['st_mode'] = (stat.S_IFREG | 0777)
            foo['st_nlink'] = 1
            foo['st_size'] = 0
            self.buffer[path] = foo
        '''


        '''
        dict(st_mode=(stat.S_IFREG | mode), st_nlink=1,
                                st_size=0, st_ctime=time.time(), st_mtime=time.time(),
                                st_atime=time.time())
        '''
        self.fd += 1
        return 0

    def write(self, path, data, offset, fp):
        # Called while a file is being uploaded
        # 4 KB (4096 bytes) per block; data holds the block's contents
        # The last block is detected by len(data) < 4096
        # file size = offset of the last block + len(data)

        # 4 KB per write is too slow, so batch up about 2 MB and upload at once

        #print '*'*10,path,offset, len(data)

        def _block_size(stream):
            stream.seek(0,2)
            return stream.tell()

        _BLOCK_SIZE = 16 * 2 ** 20
        # Work done for the first block
        if offset == 0:
            #self.uploadLock.acquire()
            #self.readLock.acquire()
            # Initialize the list of block md5s
            self.upload_blocks[path] = {'tmp':None,
                                        'blocks':[]}
            # Create a temporary buffer file
            tmp_file = tempfile.TemporaryFile('r+w+b')
            self.upload_blocks[path]['tmp'] = tmp_file

        # Write the data to the temp file; once it reaches _BLOCK_SIZE, upload the block and reset the temp file
        try:
            tmp = self.upload_blocks[path]['tmp']
        except KeyError:
            return 0
        tmp.write(data)

        if _block_size(tmp) > _BLOCK_SIZE:
            print path, 'uploading block'
            tmp.seek(0)
            try:
                foo = self.disk.upload_tmpfile(tmp,callback=ProgressBar()).content
                foofoo = json.loads(foo)
                block_md5 = foofoo['md5']
            except:
                 print foo



            # Record this block's md5 in upload_blocks
            self.upload_blocks[path]['blocks'].append(block_md5)
            # Create a fresh temporary buffer file
            self.upload_blocks[path]['tmp'].close()
            tmp_file = tempfile.TemporaryFile('r+w+b')
            self.upload_blocks[path]['tmp'] = tmp_file
            print 'created temp file', tmp_file.name

        # Work done for the last block
        if len(data) < 4096:
            # Check for a file with the same name and delete it if present
            while True:
                try:
                    foo = self.disk.meta([path]).content
                    foofoo = json.loads(foo)
                    break
                except:
                    print 'error'


            if foofoo['errno'] == 0:
                logger.debug('Deleted the file which has the same name.')
                self.disk.delete([path])
            # Check whether anything is left to upload
            if _block_size(tmp) != 0:
                # The temp file still holds data that must be uploaded
                print path, 'uploading final block, size', _block_size(tmp)
                tmp.seek(0)
                while True:
                    try:
                        foo = self.disk.upload_tmpfile(tmp,callback=ProgressBar()).content
                        foofoo = json.loads(foo)
                        break
                    except:
                        print 'exception, retry.'

                block_md5 = foofoo['md5']
                # 在 upload_blocks 中插入本块的 md5
                self.upload_blocks[path]['blocks'].append(block_md5)

            # Call upload_superfile to merge the uploaded blocks
            print 'merging file', path, type(path)
            self.disk.upload_superfile(path,self.upload_blocks[path]['blocks'])
            # Drop this path's data from upload_blocks
            self.upload_blocks.pop(path)
            # Refresh the local file list cache
            self._update_file_manual(path)
            #self.readLock.release()
            #self.uploadLock.release()
        return len(data)


    def mkdir(self, path, mode):
        logger.debug("mkdir is:" + path)
        self.disk.mkdir(path)

    def rmdir(self, path):
        logger.debug("rmdir is:" + path)
        self.disk.delete([path])

    def read(self, path, size, offset, fh):
        print '*'*10,'READ CALLED',path,size,offset
        #logger.debug("read is: " + path)
        # Download via an external tool instead, checking the size of the downloaded temp file on each read

        if offset == 0:
            tmp = tempfile.mktemp()
            url = self.disk.download_url([path])[0]
            logger.debug('%s started downloader' % url)
            """
            thread = threading.Thread(target=self.downlaoder, args=(url, tmp))
            thread.start()
            while thread.isAlive():
                pass
            """
            number = 5

            cookies = ';'.join(['%s=%s' % (k, v)
                        for k, v in self.disk.session.cookies.items()])
            cmd = 'axel --alternate -n{0} -H "Cookies:{1}" "{2}" -o "{3}"'.format(number, cookies, url, tmp)
            logger.debug('now start axel on %s' % path)
            os.system(cmd)
            logger.debug('axel on %s done.' % path)

            # self.downloader(url, tmp)
            logger.debug('%s downloaded' % url)
            self.downloading_files[path] = (tmp, open(tmp,'rb'))

        file_handler = self.downloading_files[path][1]
        return file_handler.read(size)


        """
        paras = {'Range': 'bytes=%s-%s' % (offset, offset + size - 1)}
        while True:
            try:
                foo = self.disk.download(path, headers=paras).content
                return foo
            except:
                pass
        """

    def downloader(self, url, path):
        number = 5

        cookies = ';'.join(['%s=%s' % (k, v)
                    for k, v in self.disk.session.cookies.items()])
        cmd = 'axel --alternate -n{0} -H "Cookie:{1}" "{2}" -o "{3}"'.format(number, cookies, url, path)
        logger.debug('now start axel on %s' % path)
        os.system(cmd)
        logger.debug('axel on %s done.' % path)
        return

    access = None
    statfs = None
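
A note on the pattern: write() above buffers incoming FUSE writes in a temporary file, uploads a block and records its md5 whenever the buffer exceeds _BLOCK_SIZE, and finally calls upload_superfile to merge the recorded md5 list into one remote file. Below is a minimal, self-contained sketch of that buffering scheme, where upload_block and merge_blocks are hypothetical stand-ins for the disk.upload_tmpfile and disk.upload_superfile calls:

import tempfile

_BLOCK_SIZE = 4 * 1024 * 1024  # assumed block size; the real value comes from the module above

class BlockUploader(object):
    """Buffer sequential writes and flush fixed-size blocks to an uploader."""

    def __init__(self, upload_block, merge_blocks):
        # upload_block(fileobj) -> md5 string; merge_blocks(path, md5_list)
        # Both callables are assumptions standing in for the pan-disk API above.
        self.upload_block = upload_block
        self.merge_blocks = merge_blocks
        self.tmp = tempfile.TemporaryFile('w+b')
        self.block_md5s = []

    def write(self, data):
        self.tmp.write(data)
        if self.tmp.tell() >= _BLOCK_SIZE:  # buffer is append-only, so tell() is its size
            self._flush()
        return len(data)

    def _flush(self):
        if self.tmp.tell() == 0:
            return
        self.tmp.seek(0)
        self.block_md5s.append(self.upload_block(self.tmp))
        self.tmp.close()
        self.tmp = tempfile.TemporaryFile('w+b')  # start a fresh buffer

    def finish(self, path):
        self._flush()  # upload whatever remains as the final block
        self.merge_blocks(path, self.block_md5s)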

Example 29

Project: PytheM
Source File: bdf_proxy.py
View license
    def binaryGrinder(self, binaryFile):
        """
        Feed potential binaries into this function,
        it will return the result PatchedBinary, False, or None
        """
        with open(binaryFile, 'r+b') as f:
            binaryTMPHandle = f.read()

        binaryHeader = binaryTMPHandle[:4]
        result = None

        try:
            if binaryHeader[:2] == 'MZ':  # PE/COFF
                pe = pefile.PE(data=binaryTMPHandle, fast_load=True)
                magic = pe.OPTIONAL_HEADER.Magic
                machineType = pe.FILE_HEADER.Machine

                # update when supporting more than one arch
                if (magic == int('20B', 16) and machineType == 0x8664 and
                   self.WindowsType.lower() in ['all', 'x64']):
                    add_section = False
                    cave_jumping = False
                    if self.WindowsIntelx64['PATCH_TYPE'].lower() == 'append':
                        add_section = True
                    elif self.WindowsIntelx64['PATCH_TYPE'].lower() == 'jump':
                        cave_jumping = True

                    # if automatic override
                    if self.WindowsIntelx64['PATCH_METHOD'].lower() == 'automatic':
                        cave_jumping = True

                    targetFile = pebin.pebin(FILE=binaryFile,
                                             OUTPUT=os.path.basename(binaryFile),
                                             SHELL=self.WindowsIntelx64['SHELL'],
                                             HOST=self.WindowsIntelx64['HOST'],
                                             PORT=int(self.WindowsIntelx64['PORT']),
                                             ADD_SECTION=add_section,
                                             CAVE_JUMPING=cave_jumping,
                                             IMAGE_TYPE=self.WindowsType,
                                             RUNAS_ADMIN=self.str2bool(self.WindowsIntelx64['RUNAS_ADMIN']),
                                             PATCH_DLL=self.str2bool(self.WindowsIntelx64['PATCH_DLL']),
                                             SUPPLIED_SHELLCODE=self.WindowsIntelx64['SUPPLIED_SHELLCODE'],
                                             ZERO_CERT=self.str2bool(self.WindowsIntelx64['ZERO_CERT']),
                                             PATCH_METHOD=self.WindowsIntelx64['PATCH_METHOD'].lower(),
                                             SUPPLIED_BINARY=self.WindowsIntelx64['SUPPLIED_BINARY'],
                                             IDT_IN_CAVE=self.str2bool(self.WindowsIntelx64['IDT_IN_CAVE']),
                                             CODE_SIGN=self.str2bool(self.WindowsIntelx64['CODE_SIGN']),
                                             PREPROCESS=self.str2bool(self.WindowsIntelx64['PREPROCESS']),
                                             )

                    result = targetFile.run_this()

                elif (machineType == 0x14c and
                      self.WindowsType.lower() in ['all', 'x86']):
                    add_section = False
                    cave_jumping = False
                    # add_section wins for cave_jumping
                    # default is single for BDF
                    if self.WindowsIntelx86['PATCH_TYPE'].lower() == 'append':
                        add_section = True
                    elif self.WindowsIntelx86['PATCH_TYPE'].lower() == 'jump':
                        cave_jumping = True

                    # if automatic override
                    if self.WindowsIntelx86['PATCH_METHOD'].lower() == 'automatic':
                        cave_jumping = True
                        add_section = False

                    targetFile = pebin.pebin(FILE=binaryFile,
                                             OUTPUT=os.path.basename(binaryFile),
                                             SHELL=self.WindowsIntelx86['SHELL'],
                                             HOST=self.WindowsIntelx86['HOST'],
                                             PORT=int(self.WindowsIntelx86['PORT']),
                                             ADD_SECTION=add_section,
                                             CAVE_JUMPING=cave_jumping,
                                             IMAGE_TYPE=self.WindowsType,
                                             RUNAS_ADMIN=self.str2bool(self.WindowsIntelx86['RUNAS_ADMIN']),
                                             PATCH_DLL=self.str2bool(self.WindowsIntelx86['PATCH_DLL']),
                                             SUPPLIED_SHELLCODE=self.WindowsIntelx86['SUPPLIED_SHELLCODE'],
                                             ZERO_CERT=self.str2bool(self.WindowsIntelx86['ZERO_CERT']),
                                             PATCH_METHOD=self.WindowsIntelx86['PATCH_METHOD'].lower(),
                                             SUPPLIED_BINARY=self.WindowsIntelx86['SUPPLIED_BINARY'],
                                             XP_MODE=self.str2bool(self.WindowsIntelx86['XP_MODE']),
                                             IDT_IN_CAVE=self.str2bool(self.WindowsIntelx86['IDT_IN_CAVE']),
                                             CODE_SIGN=self.str2bool(self.WindowsIntelx86['CODE_SIGN']),
                                             PREPROCESS=self.str2bool(self.WindowsIntelx86['PREPROCESS']),
                                             )

                    result = targetFile.run_this()

            elif binaryHeader[:4].encode('hex') == '7f454c46':  # ELF

                targetFile = elfbin.elfbin(FILE=binaryFile, SUPPORT_CHECK=False)
                targetFile.support_check()

                if targetFile.class_type == 0x1:
                    # x86CPU Type
                    targetFile = elfbin.elfbin(FILE=binaryFile,
                                               OUTPUT=os.path.basename(binaryFile),
                                               SHELL=self.LinuxIntelx86['SHELL'],
                                               HOST=self.LinuxIntelx86['HOST'],
                                               PORT=int(self.LinuxIntelx86['PORT']),
                                               SUPPLIED_SHELLCODE=self.LinuxIntelx86['SUPPLIED_SHELLCODE'],
                                               IMAGE_TYPE=self.LinuxType,
                                               PREPROCESS=self.str2bool(self.LinuxIntelx86['PREPROCESS']),
                                               )
                    result = targetFile.run_this()
                elif targetFile.class_type == 0x2:
                    # x64
                    targetFile = elfbin.elfbin(FILE=binaryFile,
                                               OUTPUT=os.path.basename(binaryFile),
                                               SHELL=self.LinuxIntelx64['SHELL'],
                                               HOST=self.LinuxIntelx64['HOST'],
                                               PORT=int(self.LinuxIntelx64['PORT']),
                                               SUPPLIED_SHELLCODE=self.LinuxIntelx64['SUPPLIED_SHELLCODE'],
                                               IMAGE_TYPE=self.LinuxType,
                                               PREPROCESS=self.str2bool(self.LinuxIntelx64['PREPROCESS']),
                                               )
                    result = targetFile.run_this()

            elif binaryHeader[:4].encode('hex') in ['cefaedfe', 'cffaedfe', 'cafebabe']:  # Macho
                targetFile = machobin.machobin(FILE=binaryFile, SUPPORT_CHECK=False)
                targetFile.support_check()

                # ONE CHIP SET MUST HAVE PRIORITY in FAT FILE

                if targetFile.FAT_FILE is True:
                    if self.FatPriority == 'x86':
                        targetFile = machobin.machobin(FILE=binaryFile,
                                                       OUTPUT=os.path.basename(binaryFile),
                                                       SHELL=self.MachoIntelx86['SHELL'],
                                                       HOST=self.MachoIntelx86['HOST'],
                                                       PORT=int(self.MachoIntelx86['PORT']),
                                                       SUPPLIED_SHELLCODE=self.MachoIntelx86['SUPPLIED_SHELLCODE'],
                                                       FAT_PRIORITY=self.FatPriority,
                                                       PREPROCESS=self.str2bool(self.MachoIntelx86['PREPROCESS']),
                                                       )
                        result = targetFile.run_this()

                    elif self.FatPriority == 'x64':
                        targetFile = machobin.machobin(FILE=binaryFile,
                                                       OUTPUT=os.path.basename(binaryFile),
                                                       SHELL=self.MachoIntelx64['SHELL'],
                                                       HOST=self.MachoIntelx64['HOST'],
                                                       PORT=int(self.MachoIntelx64['PORT']),
                                                       SUPPLIED_SHELLCODE=self.MachoIntelx64['SUPPLIED_SHELLCODE'],
                                                       FAT_PRIORITY=self.FatPriority,
                                                       PREPROCESS=self.str2bool(self.MachoIntelx64['PREPROCESS']),
                                                       )
                        result = targetFile.run_this()

                elif targetFile.mach_hdrs[0]['CPU Type'] == '0x7':
                    targetFile = machobin.machobin(FILE=binaryFile,
                                                   OUTPUT=os.path.basename(binaryFile),
                                                   SHELL=self.MachoIntelx86['SHELL'],
                                                   HOST=self.MachoIntelx86['HOST'],
                                                   PORT=int(self.MachoIntelx86['PORT']),
                                                   SUPPLIED_SHELLCODE=self.MachoIntelx86['SUPPLIED_SHELLCODE'],
                                                   FAT_PRIORITY=self.FatPriority,
                                                   PREPROCESS=self.str2bool(self.MachoIntelx86['PREPROCESS']),
                                                   )
                    result = targetFile.run_this()

                elif targetFile.mach_hdrs[0]['CPU Type'] == '0x1000007':
                    targetFile = machobin.machobin(FILE=binaryFile,
                                                   OUTPUT=os.path.basename(binaryFile),
                                                   SHELL=self.MachoIntelx64['SHELL'],
                                                   HOST=self.MachoIntelx64['HOST'],
                                                   PORT=int(self.MachoIntelx64['PORT']),
                                                   SUPPLIED_SHELLCODE=self.MachoIntelx64['SUPPLIED_SHELLCODE'],
                                                   FAT_PRIORITY=self.FatPriority,
                                                   PREPROCESS=self.str2bool(self.MachoIntelx64['PREPROCESS']),
                                                   )
                    result = targetFile.run_this()

            return result

        except Exception as e:
            EnhancedOutput.print_error('binaryGrinder: {0}'.format(e))
            EnhancedOutput.logging_warning("Exception in binaryGrinder {0}".format(e))
            return None
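
binaryGrinder above dispatches on the file's leading magic bytes: 'MZ' for PE/COFF, 7f454c46 for ELF, and cefaedfe/cffaedfe/cafebabe for Mach-O (including FAT binaries). A small sketch of just that header check, written with bytes literals instead of the Python 2 str.encode('hex') comparison used above:

ELF_MAGIC = b'\x7fELF'                  # hex 7f454c46
MACHO_MAGICS = {b'\xce\xfa\xed\xfe',    # 32-bit Mach-O
                b'\xcf\xfa\xed\xfe',    # 64-bit Mach-O
                b'\xca\xfe\xba\xbe'}    # FAT binary

def classify_binary(path):
    """Return 'pe', 'elf', 'macho', or None based on the first four bytes."""
    with open(path, 'rb') as f:
        header = f.read(4)
    if header[:2] == b'MZ':
        return 'pe'
    if header == ELF_MAGIC:
        return 'elf'
    if header in MACHO_MAGICS:
        return 'macho'
    return None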

Example 31

Project: disk_perf_test_tool
Source File: fio.py
View license
    def do_run(self, node, barrier, fio_cfg, pos, nolog=False):
        if self.use_sudo:
            sudo = "sudo "
        else:
            sudo = ""

        bash_file = """
#!/bin/bash

function get_dev() {{
    if [ -b "$1" ] ; then
        echo $1
    else
        echo $(df "$1" | tail -1 | awk '{{print $1}}')
    fi
}}

function log_io_activiti(){{
    local dest="$1"
    local dev=$(get_dev "$2")
    local sleep_time="$3"
    dev=$(basename "$dev")

    echo $dev

    for (( ; ; )) ; do
        grep -E "\\b$dev\\b" /proc/diskstats >> "$dest"
        sleep $sleep_time
    done
}}

sync
cd {exec_folder}

log_io_activiti {io_log_file} {test_file} 1 &
pid="$!"

{fio_path}fio --output-format=json --output={out_file} --alloc-size=262144 {job_file} >{err_out_file} 2>&1
echo $? >{res_code_file}
kill -9 $pid

"""

        exec_folder = self.config.remote_dir

        if self.use_system_fio:
            fio_path = ""
        else:
            if not exec_folder.endswith("/"):
                fio_path = exec_folder + "/"
            else:
                fio_path = exec_folder

        bash_file = bash_file.format(out_file=self.results_file,
                                     job_file=self.task_file,
                                     err_out_file=self.err_out_file,
                                     res_code_file=self.exit_code_file,
                                     exec_folder=exec_folder,
                                     fio_path=fio_path,
                                     test_file=self.config_params['FILENAME'],
                                     io_log_file=self.io_log_file).strip()

        with node.connection.open_sftp() as sftp:
            save_to_remote(sftp, self.task_file, str(fio_cfg))
            save_to_remote(sftp, self.sh_file, bash_file)

        exec_time = execution_time(fio_cfg)

        timeout = int(exec_time + max(300, exec_time))
        soft_tout = exec_time

        begin = time.time()

        fnames_before = run_on_node(node)("ls -1 " + exec_folder, nolog=True)

        barrier.wait()

        task = BGSSHTask(node, self.use_sudo)
        task.start(sudo + "bash " + self.sh_file)

        while True:
            try:
                task.wait(soft_tout, timeout)
                break
            except paramiko.SSHException:
                pass

            try:
                node.connection.close()
            except:
                pass

            reconnect(node.connection, node.conn_url)

        end = time.time()
        rossh = run_on_node(node)
        fnames_after = rossh("ls -1 " + exec_folder, nolog=True)

        conn_id = node.get_conn_id().replace(":", "_")
        if not nolog:
            logger.debug("Test on node {0} is finished".format(conn_id))

        log_files_pref = []
        if 'write_lat_log' in fio_cfg.vals:
            fname = fio_cfg.vals['write_lat_log']
            log_files_pref.append(fname + '_clat')
            log_files_pref.append(fname + '_lat')
            log_files_pref.append(fname + '_slat')

        if 'write_iops_log' in fio_cfg.vals:
            fname = fio_cfg.vals['write_iops_log']
            log_files_pref.append(fname + '_iops')

        if 'write_bw_log' in fio_cfg.vals:
            fname = fio_cfg.vals['write_bw_log']
            log_files_pref.append(fname + '_bw')

        files = collections.defaultdict(lambda: [])
        all_files = [os.path.basename(self.results_file)]
        new_files = set(fnames_after.split()) - set(fnames_before.split())

        for fname in new_files:
            if fname.endswith('.log') and fname.split('.')[0] in log_files_pref:
                name, _ = os.path.splitext(fname)
                if fname.count('.') == 1:
                    tp = name.split("_")[-1]
                    cnt = 0
                else:
                    tp_cnt = name.split("_")[-1]
                    tp, cnt = tp_cnt.split('.')
                files[tp].append((int(cnt), fname))
                all_files.append(fname)
            elif fname == os.path.basename(self.io_log_file):
                files['iops'].append(('sys', fname))
                all_files.append(fname)

        arch_name = self.join_remote('wally_result.tar.gz')
        tmp_dir = os.path.join(self.config.log_directory, 'tmp_' + conn_id)

        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)

        os.mkdir(tmp_dir)
        loc_arch_name = os.path.join(tmp_dir, 'wally_result.{0}.tar.gz'.format(conn_id))
        file_full_names = " ".join(all_files)

        try:
            os.unlink(loc_arch_name)
        except:
            pass

        with node.connection.open_sftp() as sftp:
            try:
                exit_code = read_from_remote(sftp, self.exit_code_file)
            except IOError:
                logger.error("No exit code file found on %s. Looks like process failed to start",
                             conn_id)
                return None

            err_out = read_from_remote(sftp, self.err_out_file)
            exit_code = exit_code.strip()

            if exit_code != '0':
                msg = "fio exit with code {0}: {1}".format(exit_code, err_out)
                logger.critical(msg.strip())
                raise StopTestError("fio failed")

            rossh("rm -f {0}".format(arch_name), nolog=True)
            pack_files_cmd = "cd {0} ; tar zcvf {1} {2}".format(exec_folder, arch_name, file_full_names)
            rossh(pack_files_cmd, nolog=True)
            sftp.get(arch_name, loc_arch_name)

        unpack_files_cmd = "cd {0} ; tar xvzf {1} >/dev/null".format(tmp_dir, loc_arch_name)
        subprocess.check_call(unpack_files_cmd, shell=True)
        os.unlink(loc_arch_name)

        for ftype, fls in files.items():
            for idx, fname in fls:
                cname = os.path.join(tmp_dir, fname)
                loc_fname = "{0}_{1}_{2}.{3}.log".format(pos, conn_id, ftype, idx)
                loc_path = os.path.join(self.config.log_directory, loc_fname)
                os.rename(cname, loc_path)

        cname = os.path.join(tmp_dir,
                             os.path.basename(self.results_file))
        loc_fname = "{0}_{1}_rawres.json".format(pos, conn_id)
        loc_path = os.path.join(self.config.log_directory, loc_fname)
        os.rename(cname, loc_path)
        os.rmdir(tmp_dir)

        remove_remote_res_files_cmd = "cd {0} ; rm -f {1} {2}".format(exec_folder,
                                                                      arch_name,
                                                                      file_full_names)
        rossh(remove_remote_res_files_cmd, nolog=True)
        return begin, end
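
The log-collection loop above depends on fio's log file naming: '<base>_<type>.log' for a single job and '<base>_<type>.<n>.log' when numjobs produces several, which is why it branches on fname.count('.'). A sketch of just that parse, under the same naming assumption:

import os

def parse_fio_log_name(fname):
    """Split a fio log name like 'test_clat.1.log' into (type, job_index)."""
    name, _ = os.path.splitext(fname)   # strip the trailing '.log'
    if fname.count('.') == 1:           # single-job form: no per-job index
        return name.split('_')[-1], 0
    tp_cnt = name.split('_')[-1]        # e.g. 'clat.1'
    tp, cnt = tp_cnt.split('.')
    return tp, int(cnt)

assert parse_fio_log_name('test_clat.log') == ('clat', 0)
assert parse_fio_log_name('test_clat.3.log') == ('clat', 3)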

Example 33

Project: fedora-software
Source File: importcomponents.py
View license
    def handle(self, *args, **options):

        # check arguments
        if len(args) == 1:
            xml_file = args[0]
        elif len(args) == 0:
            try:
                # try to find the file in rpm database
                import rpm
                header = rpm.TransactionSet().dbMatch('name', 'appstream-data').next()
                xml_file = filter(lambda f: f[0].endswith('.xml.gz'), header.fiFromHeader())[0][0]
            except:
                if settings.DEBUG:
                    raise
                raise CommandError(
                    'Failed to find xml file provided by appstream-data package. '\
                    'Specify path to the file as an argument.\nType {} help importcomponents'.format(
                        os.path.basename(sys.argv[0])
                    ))
        elif len(args) > 1:
            raise CommandError('Invalid number of arguments.\nType {} help importcomponents'.format(
                os.path.basename(sys.argv[0])))

        logger.info('Reading %s' % xml_file)

        try:
            tree = ElementTree.fromstring(gzip.open(xml_file,'rb').read())
        except Exception as e:
            if settings.DEBUG:
                raise
            raise CommandError('Failed to read content of {xml_file}: {e}\nType {manage} help importcomponents'.format(
                xml_file = xml_file, e = e, manage = os.path.basename(sys.argv[0])
                ))

        component_nodes_count = len(tree)
        logger.info('Parsed {} component nodes'.format(component_nodes_count))

        errors = 0
        component_ids = []
        for c_node in tree:
            c_type      = 'unknown'
            c_type_id   = 'unknown'

            try:
                with transaction.atomic():
                    c_type      = c_node.attrib['type']
                    c_type_id   = c_node.find('id').text
                    c_pkgname   = c_node.find('pkgname').text
                    try:
                        c_project_license = c_node.find('project_license').text
                    except:
                        c_project_license = None

                    logger.info('Importing component {}/{} ({}/{})'.format(
                        c_type, c_type_id, len(component_ids)+1, component_nodes_count,
                    ))

                    # create component
                    c = Component.objects.get_or_create(
                        type            = c_type,
                        type_id         = c_type_id,
                        pkgname         = c_pkgname,
                        project_license = c_project_license,
                    )[0]

                    lang_attr = '{http://www.w3.org/XML/1998/namespace}lang'

                    # create names
                    c.names.all().delete()
                    for name_node in c_node.findall('name'):
                        c.names.add(ComponentName(
                            lang = name_node.attrib.get(lang_attr),
                            name = name_node.text,
                        ))

                    # create summaries
                    c.summaries.all().delete()
                    for summary_node in c_node.findall('summary'):
                        c.summaries.add(ComponentSummary(
                            lang = summary_node.attrib.get(lang_attr),
                            summary = summary_node.text,
                        ))

                    # create descriptions
                    c.descriptions.all().delete()
                    for description_node in c_node.findall('description'):
                        c.descriptions.add(ComponentDescription(
                            lang = description_node.attrib.get(lang_attr),
                            description = ElementTree.tostring(description_node, method="html"),
                        ))

                    # create icons
                    c.icons.all().delete()
                    for icon_node in c_node.findall('icon'):
                        c.icons.add(ComponentIcon(
                            icon    = icon_node.text,
                            type    = icon_node.attrib.get('type'),
                            height  = icon_node.attrib.get('height'),
                            width   = icon_node.attrib.get('width'),
                        ))

                    # create categories
                    c.categories.all().delete()
                    categories_node = c_node.find('categories')
                    if categories_node is not None:
                        for category_node in categories_node.findall('category'):
                            c.categories.add(Category.objects.get_or_create(
                                slug        = slugify(category_node.text),
                                category    = category_node.text,
                            )[0])

                    # create keywords
                    c.keywords.all().delete()
                    keywords_node = c_node.find('keywords')
                    if keywords_node is not None:
                        for keyword_node in keywords_node.findall('keyword'):
                            c.keywords.add(Keyword.objects.get_or_create(
                                lang    = keyword_node.attrib.get(lang_attr),
                                keyword = keyword_node.text,
                            )[0])

                    # create urls
                    c.urls.all().delete()
                    for url_node in c_node.findall('url'):
                        if url_node.text is not None:
                            c.urls.add(ComponentUrl(
                                url     = url_node.text,
                                type    = url_node.attrib.get('type'),
                            ))

                    # create screenshots
                    c.screenshots.all().delete()
                    screenshots_node = c_node.find('screenshots')
                    if screenshots_node is not None:
                        for screenshot_node in screenshots_node.findall('screenshot'):
                            screenshot = ComponentScreenshot(
                                type = screenshot_node.attrib.get('type'),
                            )
                            c.screenshots.add(screenshot)
                            for image_node in screenshot_node.findall('image'):
                                screenshot.images.add(ComponentScreenshotImage(
                                    image   = image_node.text,
                                    type    = image_node.attrib.get('type'),
                                    height  = image_node.attrib.get('height'),
                                    width   = image_node.attrib.get('width'),
                                ))

                    # create releases
                    c.releases.all().delete()
                    releases_node = c_node.find('releases')
                    if releases_node is not None:
                        for release_node in releases_node.findall('release'):
                            c.releases.add(ComponentRelease(
                                version     = release_node.attrib.get('version'),
                                timestamp   = datetime.utcfromtimestamp(
                                    int(release_node.attrib.get('timestamp'))
                                ).replace(tzinfo=utc)
                            ))

                    # create languages
                    c.languages.all().delete()
                    languages_node = c_node.find('languages')
                    if languages_node is not None:
                        for lang_node in languages_node.findall('lang'):
                            c.languages.add(ComponentLanguage(
                                percentage  = lang_node.attrib.get('percentage'),
                                lang        = lang_node.text,
                            ))

                    # create metadata
                    c.metadata.all().delete()
                    metadata_node = c_node.find('metadata')
                    if metadata_node is not None:
                        for value_node in metadata_node.findall('value'):
                            c.metadata.add(ComponentMetadata(
                                key     = value_node.attrib.get('key'),
                                value   = value_node.text,
                            ))

            except Exception as e:
                logger.error('Failed to import node {}/{}: {}'.format(c_type, c_type_id, e))
                if settings.DEBUG:
                    raise
                errors += 1
            else:
                component_ids.append(c.id)

        # check errors
        if errors > 0:
            raise CommandError('Failed to import components: {} error(s)'.format(errors))
        else:
            logger.info('Successfully imported {} components'.format(len(component_ids)))

        # delete stale components
        deleted_components_count = 0
        for c in Component.objects.all():
            if c.id not in component_ids:
                logger.info('Deleting stale component {}/{}'.format(c.type, c.type_id))
                c.delete()
                deleted_components_count += 1

        if deleted_components_count > 0:
            logger.info('Successfully deleted {} stale components'.format(deleted_components_count))
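
Both CommandError hints above pass sys.argv[0] through os.path.basename so the message shows the bare program name rather than the full path it was invoked with. A tiny sketch of the pattern; usage_hint is a hypothetical helper, not part of the command above:

import os
import sys

def usage_hint(subcommand):
    """Build a 'Type X help Y' hint using only the program's base name."""
    # sys.argv[0] may be an absolute path such as '/usr/bin/manage.py';
    # basename reduces it to 'manage.py' for a readable message.
    prog = os.path.basename(sys.argv[0])
    return 'Type {} help {}'.format(prog, subcommand)

# e.g.: raise CommandError('Invalid number of arguments.\n' + usage_hint('importcomponents'))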

Example 34

Project: Arelle
Source File: saveLoadableExcel.py
View license
def saveLoadableExcel(dts, excelFile):
    from arelle import ModelDocument, XmlUtil
    from openpyxl import Workbook, cell
    from openpyxl.styles import Font, PatternFill, Border, Alignment, Color, fills, Side
    from openpyxl.worksheet.dimensions import ColumnDimension
    
    workbook = Workbook(encoding="utf-8")
    # remove pre-existing worksheets
    while len(workbook.worksheets)>0:
        workbook.remove_sheet(workbook.worksheets[0])
    conceptsWs = workbook.create_sheet(title="Concepts")
    dtsWs = workbook.create_sheet(title="DTS")
    
    # identify type of taxonomy
    conceptsWsHeaders = None
    cellFontArgs = None
    for doc in dts.urlDocs.values():
        if doc.type == ModelDocument.Type.SCHEMA and doc.inDTS:
            for i in range(len(headersStyles)):
                if re.match(headersStyles[i][0], doc.targetNamespace):
                    cellFontArgs = headersStyles[i][1] # use as arguments to Font()
                    conceptsWsHeaders = headersStyles[i][2]
                    break
    if conceptsWsHeaders is None:
        dts.info("error:saveLoadableExcel",
         _("Referenced taxonomy style not identified, assuming general pattern."),
         modelObject=dts)
        cellFontArgs = headersStyles[-1][1] # use as arguments to Font()
        conceptsWsHeaders = headersStyles[-1][2]
        
        
    hdrCellFont = Font(**cellFontArgs)
    hdrCellFill = PatternFill(patternType=fills.FILL_SOLID,
                              fgColor=Color("00FFBF5F")) # Excel's light orange fill color
    cellFont = Font(**cellFontArgs)

    def writeCell(ws,row,col,value,fontBold=False,borders=True,indent=0,hAlign=None,vAlign=None,hdr=False):
        cell = ws.cell(row=row,column=col)
        cell.value = value
        if hdr:
            cell.font = hdrCellFont
            cell.fill = hdrCellFill
            if not hAlign: hAlign = "center"
            if not vAlign: vAlign = "center"
        else:
            cell.font = cellFont
            if not hAlign: hAlign = "left"
            if not vAlign: vAlign = "top"
        if borders:
            cell.border = Border(top=Side(border_style="thin"),
                                 left=Side(border_style="thin"),
                                 right=Side(border_style="thin"),
                                 bottom=Side(border_style="thin"))
        cell.alignment = Alignment(horizontal=hAlign, vertical=vAlign, wrap_text=True, indent=indent)
            
    # sheet 1 col widths
    for i, hdr in enumerate(conceptsWsHeaders):
        colLetter = cell.get_column_letter(i+1)
        conceptsWs.column_dimensions[colLetter] = ColumnDimension(conceptsWs, customWidth=True)
        conceptsWs.column_dimensions[colLetter].width = headerWidths.get(hdr[1], 40)                                   
        
    # sheet 2 headers
    for i, hdr in enumerate(dtsWsHeaders):
        colLetter = cell.get_column_letter(i+1)
        dtsWs.column_dimensions[colLetter] = ColumnDimension(dtsWs, customWidth=True)
        dtsWs.column_dimensions[colLetter].width = hdr[1]
        writeCell(dtsWs, 1, i+1, hdr[0], hdr=True)
        
    # referenced taxonomies
    conceptsRow = 1
    dtsRow = 3
    # identify extension schema
    extensionSchemaDoc = None
    if dts.modelDocument.type == ModelDocument.Type.SCHEMA:
        extensionSchemaDoc = dts.modelDocument
    elif dts.modelDocument.type == ModelDocument.Type.INSTANCE:
        for doc, docReference in dts.modelDocument.referencesDocument.items():
            if docReference.referenceType == "href":
                extensionSchemaDoc = doc
                break
    if extensionSchemaDoc is None:
        dts.info("error:saveLoadableExcel",
         _("Unable to identify extension taxonomy."),
         modelObject=dts)
        return
            
    for doc, docReference in extensionSchemaDoc.referencesDocument.items():
        if docReference.referenceType == "import" and doc.targetNamespace != XbrlConst.xbrli:
            writeCell(dtsWs, dtsRow, 1, "import") 
            writeCell(dtsWs, dtsRow, 2, "schema") 
            writeCell(dtsWs, dtsRow, 3, XmlUtil.xmlnsprefix(doc.xmlRootElement, doc.targetNamespace)) 
            writeCell(dtsWs, dtsRow, 4, doc.uri) 
            writeCell(dtsWs, dtsRow, 5, doc.targetNamespace) 
            dtsRow += 1
                
    dtsRow += 1
    
    doc = extensionSchemaDoc
    writeCell(dtsWs, dtsRow, 1, "extension") 
    writeCell(dtsWs, dtsRow, 2, "schema") 
    writeCell(dtsWs, dtsRow, 3, XmlUtil.xmlnsprefix(doc.xmlRootElement, doc.targetNamespace)) 
    writeCell(dtsWs, dtsRow, 4, os.path.basename(doc.uri)) 
    writeCell(dtsWs, dtsRow, 5, doc.targetNamespace) 
    dtsRow += 1

    for doc, docReference in extensionSchemaDoc.referencesDocument.items():
        if docReference.referenceType == "href" and doc.type == ModelDocument.Type.LINKBASE:
            linkbaseType = ""
            role = docReference.referringModelObject.get("{http://www.w3.org/1999/xlink}role") or ""
            if role.startswith("http://www.xbrl.org/2003/role/") and role.endswith("LinkbaseRef"):
                linkbaseType = os.path.basename(role)[0:-11]
            writeCell(dtsWs, dtsRow, 1, "extension") 
            writeCell(dtsWs, dtsRow, 2, "linkbase") 
            writeCell(dtsWs, dtsRow, 3, linkbaseType) 
            writeCell(dtsWs, dtsRow, 4, os.path.basename(doc.uri)) 
            writeCell(dtsWs, dtsRow, 5, "") 
            dtsRow += 1
            
    dtsRow += 1

    # extended link roles defined in this document
    for roleURI, roleTypes in sorted(dts.roleTypes.items(), 
                                     # sort on definition if any else URI
                                     key=lambda item: (item[1][0].definition or item[0])):
        for roleType in roleTypes:
            if roleType.modelDocument == extensionSchemaDoc:
                writeCell(dtsWs, dtsRow, 1, "extension") 
                writeCell(dtsWs, dtsRow, 2, "role") 
                writeCell(dtsWs, dtsRow, 3, "") 
                writeCell(dtsWs, dtsRow, 4, roleType.definition) 
                writeCell(dtsWs, dtsRow, 5, roleURI) 
                dtsRow += 1
                
    # tree walk recursive function
    def treeWalk(row, depth, concept, preferredLabel, arcrole, preRelSet, visited):
        if concept is not None:
            # calc parents
            calcRelSet = dts.relationshipSet(XbrlConst.summationItem, preRelSet.linkrole)
            calcRel = None
            for modelRel in calcRelSet.toModelObject(concept):
                calcRel = modelRel
                break
            for i, hdr in enumerate(conceptsWsHeaders):
                colType = hdr[1]
                value = ""
                if colType == "name":
                    value = str(concept.name)
                elif colType == "prefix" and concept.qname is not None:
                    value = concept.qname.prefix
                elif colType == "type" and concept.type is not None:
                    value = str(concept.type.qname)
                elif colType == "substitutionGroup":
                    value = str(concept.substitutionGroupQname)
                elif colType == "abstract":
                    value = "true" if concept.isAbstract else "false"
                elif colType == "nillable":
                    if concept.isNillable:
                        value = "true"
                elif colType == "periodType":
                    value = concept.periodType
                elif colType == "balance":
                    value = concept.balance
                elif colType == "label":
                    role = hdr[2]
                    lang = hdr[3]
                    if role == XbrlConst.standardLabel:
                        if "indented" in hdr:
                            roleUri = preferredLabel
                        elif "overridePreferred" in hdr:
                            if preferredLabel and preferredLabel != XbrlConst.standardLabel:
                                roleUri = role
                            else:
                                roleUri = "**no value**" # skip putting a value in this column
                        else:
                            roleUri = role
                    else:
                        roleUri = role
                    if roleUri != "**no value**":
                        value = concept.label(roleUri,
                                              linkroleHint=preRelSet.linkrole,
                                              lang=lang,
                                              fallbackToQname=(role == XbrlConst.standardLabel))
                elif colType == "preferredLabel" and preferredLabel:
                    if preferredLabel.startswith("http://www.xbrl.org/2003/role/"):
                        value = os.path.basename(preferredLabel)
                    else:
                        value = preferredLabel
                elif colType == "calculationParent" and calcRel is not None:
                    calcParent = calcRel.fromModelObject
                    if calcParent is not None:
                        value = str(calcParent.qname)
                elif colType == "calculationWeight" and calcRel is not None:
                    value = calcRel.weight
                elif colType == "depth":
                    value = depth
                if "indented" in hdr:
                    indent = min(depth, MAXINDENT)
                else:
                    indent = 0
                writeCell(conceptsWs, row, i+1, value, indent=indent)
            row += 1
            if concept not in visited:
                visited.add(concept)
                for modelRel in preRelSet.fromModelObject(concept):
                    if modelRel.toModelObject is not None:
                        row = treeWalk(row, depth + 1, modelRel.toModelObject, modelRel.preferredLabel, arcrole, preRelSet, visited)
                visited.remove(concept)
        return row
    
    # use presentation relationships for conceptsWs
    arcrole = XbrlConst.parentChild
    # sort URIs by definition
    linkroleUris = []
    relationshipSet = dts.relationshipSet(arcrole)
    if relationshipSet:
        for linkroleUri in relationshipSet.linkRoleUris:
            modelRoleTypes = dts.roleTypes.get(linkroleUri)
            if modelRoleTypes:
                roledefinition = (modelRoleTypes[0].genLabel(strip=True) or modelRoleTypes[0].definition or linkroleUri)                    
            else:
                roledefinition = linkroleUri
            linkroleUris.append((roledefinition, linkroleUri))
        linkroleUris.sort()
    
        # for each URI in definition order
        for roledefinition, linkroleUri in linkroleUris:
            # write linkrole
            writeCell(conceptsWs, conceptsRow, 1, (roledefinition or linkroleUri), borders=False)  # ELR has no borders, just font specified
            conceptsRow += 1
            # write header row
            for i, hdr in enumerate(conceptsWsHeaders):
                writeCell(conceptsWs, conceptsRow, i+1, hdr[0], hdr=True)
            conceptsRow += 1
            # elr relationships for tree walk
            linkRelationshipSet = dts.relationshipSet(arcrole, linkroleUri)
            for rootConcept in linkRelationshipSet.rootConcepts:
                conceptsRow = treeWalk(conceptsRow, 0, rootConcept, None, arcrole, linkRelationshipSet, set())
            conceptsRow += 1 # double space rows between tables
    else:
        # write header row
        for i, hdr in enumerate(conceptsWsHeaders):
            writeCell(conceptsWs, conceptsRow, i+1, hdr[0], hdr=True)
        conceptsRow += 1
        # get role and lang from the label column header
        role = None
        lang = None
        for hdr in conceptsWsHeaders:
            if hdr[1] == "label":
                role = hdr[2]
                lang = hdr[3]
        lbls = defaultdict(list)        
        for concept in set(dts.qnameConcepts.values()): # may be twice if unqualified, with and without namespace
            lbls[concept.label(role,lang=lang)].append(concept.objectId())
        srtLbls = sorted(lbls.keys())
        excludedNamespaces = XbrlConst.ixbrlAll.union(
            (XbrlConst.xbrli, XbrlConst.link, XbrlConst.xlink, XbrlConst.xl,
             XbrlConst.xbrldt,
             XbrlConst.xhtml))
        for label in srtLbls:
            for objectId in lbls[label]:
                concept = dts.modelObject(objectId)
                if concept.modelDocument.targetNamespace not in excludedNamespaces:
                    for i, hdr in enumerate(conceptsWsHeaders):
                        colType = hdr[1]
                        value = ""
                        if colType == "name":
                            value = str(concept.qname.localName)
                        elif colType == "prefix":
                            value = concept.qname.prefix
                        elif colType == "type":
                            value = str(concept.type.qname)
                        elif colType == "substitutionGroup":
                            value = str(concept.substitutionGroupQname)
                        elif colType == "abstract":
                            value = "true" if concept.isAbstract else "false"
                        elif colType == "periodType":
                            value = concept.periodType
                        elif colType == "balance":
                            value = concept.balance
                        elif colType == "label":
                            role = hdr[2]
                            lang = hdr[3]
                            value = concept.label(role, lang=lang)
                        elif colType == "depth":
                            value = 0
                        if "indented" in hdr:
                            indent = min(0, MAXINDENT)
                        else:
                            indent = 0
                        writeCell(conceptsWs, conceptsRow, i, value, indent=indent) 
                    conceptsRow += 1
    
    try: 
        workbook.save(excelFile)
        dts.info("info:saveLoadableExcel",
            _("Saved Excel file: %(excelFile)s"), 
            excelFile=os.path.basename(excelFile),
            modelXbrl=dts)
    except Exception as ex:
        dts.error("exception:saveLoadableExcel",
            _("File saving exception: %(error)s"), error=ex,
            modelXbrl=dts)
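
A note on the os.path.basename calls above: they are applied to URIs, not filesystem paths. Because role URIs and file URLs are slash-separated, basename simply returns the last segment, which is how "http://www.xbrl.org/2003/role/terseLabel" is trimmed to "terseLabel" in the preferredLabel column, and how the saved Excel file is reported by file name only. A minimal sketch of that idiom (the second path is illustrative):

import os.path

# basename treats any slash-separated string as a path and keeps the last segment
print(os.path.basename("http://www.xbrl.org/2003/role/terseLabel"))  # -> terseLabel
print(os.path.basename("/some/dir/report.xlsx"))                     # -> report.xlsx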

Example 36

Project: Arelle
Source File: DTS.py
View license
def checkFilingDTS(val, modelDocument, visited):
    global targetNamespaceDatePattern, efmFilenamePattern, roleTypePattern, arcroleTypePattern, \
            arcroleDefinitionPattern, namePattern, linkroleDefinitionBalanceIncomeSheet, \
            namespacesConflictPattern
    if targetNamespaceDatePattern is None:
        targetNamespaceDatePattern = re.compile(r"/([12][0-9]{3})-([01][0-9])-([0-3][0-9])|"
                                            r"/([12][0-9]{3})([01][0-9])([0-3][0-9])|")
        efmFilenamePattern = re.compile(r"^[a-z0-9][a-zA-Z0-9_\.\-]*(\.xsd|\.xml)$")
        roleTypePattern = re.compile(r"^.*/role/[^/\s]+$")
        arcroleTypePattern = re.compile(r"^.*/arcrole/[^/\s]+$")
        arcroleDefinitionPattern = re.compile(r"^.*[^\s]+.*$")  # at least one non-whitespace character
        namePattern = re.compile("[][()*+?\\\\/^{}|@#%^=~`\"';:,<>&$\u00a3\u20ac]") # u20ac=Euro, u00a3=pound sterling 
        linkroleDefinitionBalanceIncomeSheet = re.compile(r"[^-]+-\s+Statement\s+-\s+.*(income|balance|financial\W+position)",
                                                          re.IGNORECASE)
        namespacesConflictPattern = re.compile(r"http://(xbrl\.us|fasb\.org|xbrl\.sec\.gov)/(dei|us-types|us-roles|rr)/([0-9]{4}-[0-9]{2}-[0-9]{2})$")
        
    visited.append(modelDocument)
    for referencedDocument, modelDocumentReference in modelDocument.referencesDocument.items():
        #6.07.01 no includes
        if modelDocumentReference.referenceType == "include":
            val.modelXbrl.error("SBR.NL.2.2.0.18",
                _("Taxonomy schema %(schema)s includes %(include)s, only import is allowed"),
                modelObject=modelDocumentReference.referringModelObject,
                    schema=os.path.basename(modelDocument.uri), 
                    include=os.path.basename(referencedDocument.uri))
        if referencedDocument not in visited:
            checkFilingDTS(val, referencedDocument, visited)
            
    if val.disclosureSystem.standardTaxonomiesDict is None:
        pass

    if (modelDocument.type == ModelDocument.Type.SCHEMA and 
        modelDocument.targetNamespace not in val.disclosureSystem.baseTaxonomyNamespaces and
        modelDocument.uri.startswith(val.modelXbrl.uriDir)):
        
        # check schema contents types
        definesLinkroles = False
        definesArcroles = False
        definesLinkParts = False
        definesAbstractItems = False
        definesNonabstractItems = False
        definesConcepts = False
        definesTuples = False
        definesPresentationTuples = False
        definesSpecificationTuples = False
        definesTypes = False
        definesEnumerations = False
        definesDimensions = False
        definesDomains = False
        definesHypercubes = False
                
        genrlSpeclRelSet = val.modelXbrl.relationshipSet(XbrlConst.generalSpecial)
        for modelConcept in modelDocument.xmlRootElement.iterdescendants(tag="{http://www.w3.org/2001/XMLSchema}element"):
            if isinstance(modelConcept,ModelConcept):
                # 6.7.16 name not duplicated in standard taxonomies
                name = modelConcept.get("name")
                if name is None: 
                    name = ""
                    if modelConcept.get("ref") is not None:
                        continue    # don't validate ref's here
                for c in val.modelXbrl.nameConcepts.get(name, []):
                    if c.modelDocument != modelDocument:
                        if not (genrlSpeclRelSet.isRelated(modelConcept, "child", c) or genrlSpeclRelSet.isRelated(c, "child", modelConcept)):
                            val.modelXbrl.error("SBR.NL.2.2.2.02",
                                _("Concept %(concept)s is also defined in standard taxonomy schema %(standardSchema)s without a general-special relationship"),
                                modelObject=c, concept=modelConcept.qname, standardSchema=os.path.basename(c.modelDocument.uri))
                ''' removed RH 2011-12-23 corresponding set up of table in ValidateFiling
                if val.validateSBRNL and name in val.nameWordsTable:
                    if not any( any( genrlSpeclRelSet.isRelated(c, "child", modelConcept)
                                     for c in val.modelXbrl.nameConcepts.get(partialWordName, []))
                                for partialWordName in val.nameWordsTable[name]):
                        val.modelXbrl.error("SBR.NL.2.3.2.01",
                            _("Concept %(specialName)s is appears to be missing a general-special relationship to %(generalNames)s"),
                            modelObject=c, specialName=modelConcept.qname, generalNames=', or to '.join(val.nameWordsTable[name]))
                '''

                if modelConcept.isTuple:
                    if modelConcept.substitutionGroupQname.localName == "presentationTuple" and modelConcept.substitutionGroupQname.namespaceURI.endswith("/basis/sbr/xbrl/xbrl-syntax-extension"): # namespace may change each year
                        definesPresentationTuples = True
                    elif modelConcept.substitutionGroupQname.localName == "specificationTuple" and modelConcept.substitutionGroupQname.namespaceURI.endswith("/basis/sbr/xbrl/xbrl-syntax-extension"): # namespace may change each year
                        definesSpecificationTuples = True
                    else:
                        definesTuples = True
                    definesConcepts = True
                    if modelConcept.isAbstract:
                        val.modelXbrl.error("SBR.NL.2.2.2.03",
                            _("Concept %(concept)s is an abstract tuple"),
                            modelObject=modelConcept, concept=modelConcept.qname)
                    if tupleCycle(val,modelConcept):
                        val.modelXbrl.error("SBR.NL.2.2.2.07",
                            _("Tuple %(concept)s has a tuple cycle"),
                            modelObject=modelConcept, concept=modelConcept.qname)
                    if modelConcept.get("nillable") != "false" and modelConcept.isRoot:
                        val.modelXbrl.error("SBR.NL.2.2.2.17", #don't want default, just what was really there
                            _("Tuple %(concept)s must have nillable='false'"),
                            modelObject=modelConcept, concept=modelConcept.qname)
                elif modelConcept.isItem:
                    definesConcepts = True
                if modelConcept.abstract == "true":
                    if modelConcept.isRoot:
                        if modelConcept.get("nillable") != "false": #don't want default, just what was really there
                            val.modelXbrl.error("SBR.NL.2.2.2.16",
                                _("Abstract root concept %(concept)s must have nillable='false'"),
                            modelObject=modelConcept, concept=modelConcept.qname)
                        if modelConcept.typeQname != XbrlConst.qnXbrliStringItemType:
                            val.modelXbrl.error("SBR.NL.2.2.2.21",
                                _("Abstract root concept %(concept)s must have type='xbrli:stringItemType'"),
                            modelObject=modelConcept, concept=modelConcept.qname)
                    if modelConcept.balance:
                        val.modelXbrl.error("SBR.NL.2.2.2.22",
                            _("Abstract concept %(concept)s must not have a balance attribute"),
                            modelObject=modelConcept, concept=modelConcept.qname)
                    if modelConcept.isHypercubeItem:
                        definesHypercubes = True
                    elif modelConcept.isDimensionItem:
                        definesDimensions = True
                    elif modelConcept.substitutionGroupQname and modelConcept.substitutionGroupQname.localName in ("domainItem","domainMemberItem"):
                        definesDomains = True
                    elif modelConcept.isItem:
                        definesAbstractItems = True
                else:   # not abstract
                    if modelConcept.isItem:
                        definesNonabstractItems = True
                        if not (modelConcept.label(preferredLabel=XbrlConst.documentationLabel,fallbackToQname=False,lang="nl") or
                                val.modelXbrl.relationshipSet(XbrlConst.conceptReference).fromModelObject(modelConcept) or
                                modelConcept.genLabel(role=XbrlConst.genDocumentationLabel,lang="nl") or
                                val.modelXbrl.relationshipSet(XbrlConst.elementReference).fromModelObject(modelConcept)):
                            val.modelXbrl.error("SBR.NL.2.2.2.28",
                                _("Concept %(concept)s must have a documentation label or reference"),
                                modelObject=modelConcept, concept=modelConcept.qname)
                if modelConcept.balance and not modelConcept.instanceOfType(XbrlConst.qnXbrliMonetaryItemType):
                    val.modelXbrl.error("SBR.NL.2.2.2.24",
                        _("Non-monetary concept %(concept)s must not have a balance attribute"),
                        modelObject=modelConcept, concept=modelConcept.qname)
                if modelConcept.isLinkPart:
                    definesLinkParts = True
                    val.modelXbrl.error("SBR.NL.2.2.5.01",
                        _("Link:part concept %(concept)s is not allowed"),
                        modelObject=modelConcept, concept=modelConcept.qname)
                    if not modelConcept.genLabel(fallbackToQname=False,lang="nl"):
                        val.modelXbrl.error("SBR.NL.2.2.5.02",
                            _("Link part definition %(concept)s must have a generic label in language 'nl'"),
                            modelObject=modelConcept, concept=modelConcept.qname)

        # 6.7.9 role types authority
        for e in modelDocument.xmlRootElement.iterdescendants(tag="{http://www.xbrl.org/2003/linkbase}roleType"):
            if isinstance(e,ModelObject):
                roleURI = e.get("roleURI")
                # 6.7.10 only one role type declaration in DTS
                modelRoleTypes = val.modelXbrl.roleTypes.get(roleURI)
                if modelRoleTypes is not None:
                    modelRoleType = modelRoleTypes[0]
                    definition = modelRoleType.definitionNotStripped
                    usedOns = modelRoleType.usedOns
                    if usedOns & XbrlConst.standardExtLinkQnames or XbrlConst.qnGenLink in usedOns:
                        definesLinkroles = True
                        if not e.genLabel():
                            val.modelXbrl.error("SBR.NL.2.2.3.03",
                                _("Link RoleType %(roleType)s missing a generic standard label"),
                                modelObject=e, roleType=roleURI)
                        nlLabel = e.genLabel(lang="nl")
                        if definition != nlLabel:
                            val.modelXbrl.error("SBR.NL.2.2.3.04",
                                _("Link RoleType %(roleType)s definition does not match NL standard generic label, \ndefinition: %(definition)s \nNL label: %(label)s"),
                                modelObject=e, roleType=roleURI, definition=definition, label=nlLabel)
                    if definition and (definition[0].isspace() or definition[-1].isspace()):
                        val.modelXbrl.error("SBR.NL.2.2.3.07",
                            _('Link RoleType %(roleType)s definition has leading or trailing spaces: "%(definition)s"'),
                            modelObject=e, roleType=roleURI, definition=definition)

        # 6.7.13 arcrole types authority
        for e in modelDocument.xmlRootElement.iterdescendants(tag="{http://www.xbrl.org/2003/linkbase}arcroleType"):
            if isinstance(e,ModelObject):
                arcroleURI = e.get("arcroleURI")
                definesArcroles = True
                val.modelXbrl.error("SBR.NL.2.2.4.01",
                    _("Arcrole type definition is not allowed: %(arcroleURI)s"),
                    modelObject=e, arcroleURI=arcroleURI)
                    
        for appinfoElt in modelDocument.xmlRootElement.iter(tag="{http://www.w3.org/2001/XMLSchema}appinfo"):
            for nonLinkElt in appinfoElt.iterdescendants():
                if isinstance(nonLinkElt, ModelObject) and nonLinkElt.namespaceURI != XbrlConst.link:
                    val.modelXbrl.error("SBR.NL.2.2.11.05",
                        _("Appinfo contains disallowed non-link element %(element)s"),
                        modelObject=nonLinkElt, element=nonLinkElt.qname)

        for cplxTypeElt in modelDocument.xmlRootElement.iter(tag="{http://www.w3.org/2001/XMLSchema}complexType"):
            choiceElt = cplxTypeElt.find("{http://www.w3.org/2001/XMLSchema}choice")
            if choiceElt is not None:
                val.modelXbrl.error("SBR.NL.2.2.11.09",
                    _("ComplexType contains disallowed xs:choice element"),
                    modelObject=choiceElt)
                
        for cplxContentElt in modelDocument.xmlRootElement.iter(tag="{http://www.w3.org/2001/XMLSchema}complexContent"):
            if XmlUtil.descendantAttr(cplxContentElt, "http://www.w3.org/2001/XMLSchema", ("extension","restriction"), "base") != "sbr:placeholder":
                val.modelXbrl.error("SBR.NL.2.2.11.10",
                    _("ComplexContent is disallowed"),
                    modelObject=cplxContentElt)

        for typeEltTag in ("{http://www.w3.org/2001/XMLSchema}complexType",
                            "{http://www.w3.org/2001/XMLSchema}simpleType"):
            for typeElt in modelDocument.xmlRootElement.iter(tag=typeEltTag):
                definesTypes = True
                name = typeElt.get("name")
                if name:
                    if not name[0].islower() or not name.isalnum():
                        val.modelXbrl.error("SBR.NL.3.2.8.09",
                            _("Type name attribute must be lower camelcase: %(name)s."),
                            modelObject=typeElt, name=name)
        
        for enumElt in modelDocument.xmlRootElement.iter(tag="{http://www.w3.org/2001/XMLSchema}enumeration"):
            definesEnumerations = True
            if any(not valueElt.genLabel(lang="nl")
                   for valueElt in enumElt.iter(tag="{http://www.w3.org/2001/XMLSchema}value")):
                val.modelXbrl.error("SBR.NL.2.2.7.05",
                    _("Enumeration element has value(s) without generic label."),
                    modelObject=enumElt)

        if (definesLinkroles + definesArcroles + definesLinkParts +
            definesAbstractItems + definesNonabstractItems + 
            definesTuples + definesPresentationTuples + definesSpecificationTuples + definesTypes +
            definesEnumerations + definesDimensions + definesDomains + 
            definesHypercubes) != 1:
            schemaContents = []
            if definesLinkroles: schemaContents.append(_("linkroles"))
            if definesArcroles: schemaContents.append(_("arcroles"))
            if definesLinkParts: schemaContents.append(_("link parts"))
            if definesAbstractItems: schemaContents.append(_("abstract items"))
            if definesNonabstractItems: schemaContents.append(_("nonabstract items"))
            if definesTuples: schemaContents.append(_("tuples"))
            if definesPresentationTuples: schemaContents.append(_("sbrPresentationTuples"))
            if definesSpecificationTuples: schemaContents.append(_("sbrSpecificationTuples"))
            if definesTypes: schemaContents.append(_("types"))
            if definesEnumerations: schemaContents.append(_("enumerations"))
            if definesDimensions: schemaContents.append(_("dimensions"))
            if definesDomains: schemaContents.append(_("domains"))
            if definesHypercubes: schemaContents.append(_("hypercubes"))
            if schemaContents:
                if not ((definesTuples or definesPresentationTuples or definesSpecificationTuples) and
                        not (definesLinkroles or definesArcroles or definesLinkParts or definesAbstractItems or
                             definesTypes or definesDimensions or definesDomains or definesHypercubes)):
                    val.modelXbrl.error("SBR.NL.2.2.1.01",
                        _("Taxonomy schema may only define one of these: %(contents)s"),
                        modelObject=modelDocument, contents=', '.join(schemaContents))
            elif not any(refDoc.inDTS and refDoc.targetNamespace not in val.disclosureSystem.baseTaxonomyNamespaces
                         for refDoc in modelDocument.referencesDocument.keys()): # no linkbase ref or includes
                val.modelXbrl.error("SBR.NL.2.2.1.01",
                    _("Taxonomy schema must be a DTS entrypoint OR define linkroles OR arcroles OR link:parts OR context fragments OR abstract items OR tuples OR non-abstract elements OR types OR enumerations OR dimensions OR domains OR hypercubes"),
                    modelObject=modelDocument)
        if definesConcepts ^ any(  # xor so either concepts and no label LB or no concepts and has label LB
                   (refDoc.type == ModelDocument.Type.LINKBASE and
                    XmlUtil.descendant(refDoc.xmlRootElement, XbrlConst.link, "labelLink") is not None)
                   for refDoc in modelDocument.referencesDocument.keys()): # no label linkbase
            val.modelXbrl.error("SBR.NL.2.2.1.02",
                _("A schema that defines concepts MUST have a linked 2.1 label linkbase"),
                modelObject=modelDocument)
        if (definesNonabstractItems or definesTuples) and not any(  # was xor but changed to and not per RH 1/11/12
                   (refDoc.type == ModelDocument.Type.LINKBASE and
                   (XmlUtil.descendant(refDoc.xmlRootElement, XbrlConst.link, "referenceLink") is not None or
                    XmlUtil.descendant(refDoc.xmlRootElement, XbrlConst.link, "label", "{http://www.w3.org/1999/xlink}role", "http://www.xbrl.org/2003/role/documentation" ) is not None))
                    for refDoc in modelDocument.referencesDocument.keys()):
            val.modelXbrl.error("SBR.NL.2.2.1.03",
                _("A schema that defines non-abstract items MUST have a linked (2.1) reference linkbase AND/OR a label linkbase with @xlink:role=documentation"),
                modelObject=modelDocument)

    elif modelDocument.type == ModelDocument.Type.LINKBASE:
        pass
    visited.remove(modelDocument)
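
In this example os.path.basename is used purely for display: the validation messages report just the schema file name rather than its full URI (see the SBR.NL.2.2.0.18 and SBR.NL.2.2.2.02 messages above). A hedged sketch of the same pattern, with made-up URIs:

import os.path

schema_uri = "http://example.com/taxonomy/entry-point.xsd"    # hypothetical URI
include_uri = "http://example.com/taxonomy/common-types.xsd"  # hypothetical URI
print("Taxonomy schema %s includes %s, only import is allowed"
      % (os.path.basename(schema_uri), os.path.basename(include_uri)))
# -> Taxonomy schema entry-point.xsd includes common-types.xsd, only import is allowed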

Example 37

Project: Arelle
Source File: DTS.py
View license
def checkFilingDTS(val, modelDocument, visited):
    global targetNamespaceDatePattern, efmFilenamePattern, roleTypePattern, arcroleTypePattern, \
            arcroleDefinitionPattern, namePattern, linkroleDefinitionBalanceIncomeSheet, \
            namespacesConflictPattern
    if targetNamespaceDatePattern is None:
        targetNamespaceDatePattern = re.compile(r"/([12][0-9]{3})-([01][0-9])-([0-3][0-9])|"
                                            r"/([12][0-9]{3})([01][0-9])([0-3][0-9])|")
        efmFilenamePattern = re.compile(r"^[a-z0-9][a-zA-Z0-9_\.\-]*(\.xsd|\.xml)$")
        roleTypePattern = re.compile(r"^.*/role/[^/\s]+$")
        arcroleTypePattern = re.compile(r"^.*/arcrole/[^/\s]+$")
        arcroleDefinitionPattern = re.compile(r"^.*[^\\s]+.*$")  # at least one non-whitespace character
        namePattern = re.compile("[][()*+?\\\\/^{}|@#%^=~`\"';:,<>&$\u00a3\u20ac]") # u20ac=Euro, u00a3=pound sterling 
        linkroleDefinitionBalanceIncomeSheet = re.compile(r"[^-]+-\s+Statement\s+-\s+.*(income|balance|financial\W+position)",
                                                          re.IGNORECASE)
        namespacesConflictPattern = re.compile(r"http://(xbrl\.us|fasb\.org|xbrl\.sec\.gov)/(dei|us-types|us-roles|rr)/([0-9]{4}-[0-9]{2}-[0-9]{2})$")
        
    visited.append(modelDocument)
    for referencedDocument, modelDocumentReference in modelDocument.referencesDocument.items():
        #6.07.01 no includes
        if modelDocumentReference.referenceType == "include":
            val.modelXbrl.error("SBR.NL.2.2.0.18",
                _("Taxonomy schema %(schema)s includes %(include)s, only import is allowed"),
                modelObject=modelDocumentReference.referringModelObject,
                    schema=os.path.basename(modelDocument.uri), 
                    include=os.path.basename(referencedDocument.uri))
        if referencedDocument not in visited:
            checkFilingDTS(val, referencedDocument, visited)
            
    if val.disclosureSystem.standardTaxonomiesDict is None:
        pass

    if (modelDocument.type == ModelDocument.Type.SCHEMA and 
        modelDocument.targetNamespace not in val.disclosureSystem.baseTaxonomyNamespaces and
        modelDocument.uri.startswith(val.modelXbrl.uriDir)):
        
        # check schema contents types
        definesLinkroles = False
        definesArcroles = False
        definesLinkParts = False
        definesAbstractItems = False
        definesNonabstractItems = False
        definesConcepts = False
        definesTuples = False
        definesPresentationTuples = False
        definesSpecificationTuples = False
        definesTypes = False
        definesEnumerations = False
        definesDimensions = False
        definesDomains = False
        definesHypercubes = False
                
        genrlSpeclRelSet = val.modelXbrl.relationshipSet(XbrlConst.generalSpecial)
        for modelConcept in modelDocument.xmlRootElement.iterdescendants(tag="{http://www.w3.org/2001/XMLSchema}element"):
            if isinstance(modelConcept,ModelConcept):
                # 6.7.16 name not duplicated in standard taxonomies
                name = modelConcept.get("name")
                if name is None: 
                    name = ""
                    if modelConcept.get("ref") is not None:
                        continue    # don't validate ref's here
                for c in val.modelXbrl.nameConcepts.get(name, []):
                    if c.modelDocument != modelDocument:
                        if not (genrlSpeclRelSet.isRelated(modelConcept, "child", c) or genrlSpeclRelSet.isRelated(c, "child", modelConcept)):
                            val.modelXbrl.error("SBR.NL.2.2.2.02",
                                _("Concept %(concept)s is also defined in standard taxonomy schema %(standardSchema)s without a general-special relationship"),
                                modelObject=c, concept=modelConcept.qname, standardSchema=os.path.basename(c.modelDocument.uri))
                ''' removed RH 2011-12-23 corresponding set up of table in ValidateFiling
                if val.validateSBRNL and name in val.nameWordsTable:
                    if not any( any( genrlSpeclRelSet.isRelated(c, "child", modelConcept)
                                     for c in val.modelXbrl.nameConcepts.get(partialWordName, []))
                                for partialWordName in val.nameWordsTable[name]):
                        val.modelXbrl.error("SBR.NL.2.3.2.01",
                            _("Concept %(specialName)s is appears to be missing a general-special relationship to %(generalNames)s"),
                            modelObject=c, specialName=modelConcept.qname, generalNames=', or to '.join(val.nameWordsTable[name]))
                '''

                if modelConcept.isTuple:
                    if modelConcept.substitutionGroupQname.localName == "presentationTuple" and modelConcept.substitutionGroupQname.namespaceURI.endswith("/basis/sbr/xbrl/xbrl-syntax-extension"): # namespace may change each year
                        definesPresentationTuples = True
                    elif modelConcept.substitutionGroupQname.localName == "specificationTuple" and modelConcept.substitutionGroupQname.namespaceURI.endswith("/basis/sbr/xbrl/xbrl-syntax-extension"): # namespace may change each year
                        definesSpecificationTuples = True
                    else:
                        definesTuples = True
                    definesConcepts = True
                    if modelConcept.isAbstract:
                        val.modelXbrl.error("SBR.NL.2.2.2.03",
                            _("Concept %(concept)s is an abstract tuple"),
                            modelObject=modelConcept, concept=modelConcept.qname)
                    if tupleCycle(val,modelConcept):
                        val.modelXbrl.error("SBR.NL.2.2.2.07",
                            _("Tuple %(concept)s has a tuple cycle"),
                            modelObject=modelConcept, concept=modelConcept.qname)
                    if modelConcept.get("nillable") != "false" and modelConcept.isRoot:
                        val.modelXbrl.error("SBR.NL.2.2.2.17", #don't want default, just what was really there
                            _("Tuple %(concept)s must have nillable='false'"),
                            modelObject=modelConcept, concept=modelConcept.qname)
                elif modelConcept.isItem:
                    definesConcepts = True
                if modelConcept.abstract == "true":
                    if modelConcept.isRoot:
                        if modelConcept.get("nillable") != "false": #don't want default, just what was really there
                            val.modelXbrl.error("SBR.NL.2.2.2.16",
                                _("Abstract root concept %(concept)s must have nillable='false'"),
                            modelObject=modelConcept, concept=modelConcept.qname)
                        if modelConcept.typeQname != XbrlConst.qnXbrliStringItemType:
                            val.modelXbrl.error("SBR.NL.2.2.2.21",
                                _("Abstract root concept %(concept)s must have type='xbrli:stringItemType'"),
                            modelObject=modelConcept, concept=modelConcept.qname)
                    if modelConcept.balance:
                        val.modelXbrl.error("SBR.NL.2.2.2.22",
                            _("Abstract concept %(concept)s must not have a balance attribute"),
                            modelObject=modelConcept, concept=modelConcept.qname)
                    if modelConcept.isHypercubeItem:
                        definesHypercubes = True
                    elif modelConcept.isDimensionItem:
                        definesDimensions = True
                    elif modelConcept.substitutionGroupQname and modelConcept.substitutionGroupQname.localName in ("domainItem","domainMemberItem"):
                        definesDomains = True
                    elif modelConcept.isItem:
                        definesAbstractItems = True
                else:   # not abstract
                    if modelConcept.isItem:
                        definesNonabstractItems = True
                        if not (modelConcept.label(preferredLabel=XbrlConst.documentationLabel,fallbackToQname=False,lang="nl") or
                                val.modelXbrl.relationshipSet(XbrlConst.conceptReference).fromModelObject(c) or
                                modelConcept.genLabel(role=XbrlConst.genDocumentationLabel,lang="nl") or
                                val.modelXbrl.relationshipSet(XbrlConst.elementReference).fromModelObject(c)):
                            val.modelXbrl.error("SBR.NL.2.2.2.28",
                                _("Concept %(concept)s must have a documentation label or reference"),
                                modelObject=modelConcept, concept=modelConcept.qname)
                if modelConcept.balance and not modelConcept.instanceOfType(XbrlConst.qnXbrliMonetaryItemType):
                    val.modelXbrl.error("SBR.NL.2.2.2.24",
                        _("Non-monetary concept %(concept)s must not have a balance attribute"),
                        modelObject=modelConcept, concept=modelConcept.qname)
                if modelConcept.isLinkPart:
                    definesLinkParts = True
                    val.modelXbrl.error("SBR.NL.2.2.5.01",
                        _("Link:part concept %(concept)s is not allowed"),
                        modelObject=modelConcept, concept=modelConcept.qname)
                    if not modelConcept.genLabel(fallbackToQname=False,lang="nl"):
                        val.modelXbrl.error("SBR.NL.2.2.5.02",
                            _("Link part definition %(concept)s must have a generic label in language 'nl'"),
                            modelObject=modelConcept, concept=modelConcept.qname)

        # 6.7.9 role types authority
        for e in modelDocument.xmlRootElement.iterdescendants(tag="{http://www.xbrl.org/2003/linkbase}roleType"):
            if isinstance(e,ModelObject):
                roleURI = e.get("roleURI")
                # 6.7.10 only one role type declaration in DTS
                modelRoleTypes = val.modelXbrl.roleTypes.get(roleURI)
                if modelRoleTypes is not None:
                    modelRoleType = modelRoleTypes[0]
                    definition = modelRoleType.definitionNotStripped
                    usedOns = modelRoleType.usedOns
                    if usedOns & XbrlConst.standardExtLinkQnames or XbrlConst.qnGenLink in usedOns:
                        definesLinkroles = True
                        if not e.genLabel():
                            val.modelXbrl.error("SBR.NL.2.2.3.03",
                                _("Link RoleType %(roleType)s missing a generic standard label"),
                                modelObject=e, roleType=roleURI)
                        nlLabel = e.genLabel(lang="nl")
                        if definition != nlLabel:
                            val.modelXbrl.error("SBR.NL.2.2.3.04",
                                _("Link RoleType %(roleType)s definition does not match NL standard generic label, \ndefinition: %(definition)s \nNL label: %(label)s"),
                                modelObject=e, roleType=roleURI, definition=definition, label=nlLabel)
                    if definition and (definition[0].isspace() or definition[-1].isspace()):
                        val.modelXbrl.error("SBR.NL.2.2.3.07",
                            _('Link RoleType %(roleType)s definition has leading or trailing spaces: "%(definition)s"'),
                            modelObject=e, roleType=roleURI, definition=definition)

        # 6.7.13 arcrole types authority
        for e in modelDocument.xmlRootElement.iterdescendants(tag="{http://www.xbrl.org/2003/linkbase}arcroleType"):
            if isinstance(e,ModelObject):
                arcroleURI = e.get("arcroleURI")
                definesArcroles = True
                val.modelXbrl.error("SBR.NL.2.2.4.01",
                    _("Arcrole type definition is not allowed: %(arcroleURI)s"),
                    modelObject=e, arcroleURI=arcroleURI)
                    
        for appinfoElt in modelDocument.xmlRootElement.iter(tag="{http://www.w3.org/2001/XMLSchema}appinfo"):
            for nonLinkElt in appinfoElt.iterdescendants():
                if isinstance(nonLinkElt, ModelObject) and nonLinkElt.namespaceURI != XbrlConst.link:
                    val.modelXbrl.error("SBR.NL.2.2.11.05",
                        _("Appinfo contains disallowed non-link element %(element)s"),
                        modelObject=nonLinkElt, element=nonLinkElt.qname)

        for cplxTypeElt in modelDocument.xmlRootElement.iter(tag="{http://www.w3.org/2001/XMLSchema}complexType"):
            choiceElt = cplxTypeElt.find("{http://www.w3.org/2001/XMLSchema}choice")
            if choiceElt is not None:
                val.modelXbrl.error("SBR.NL.2.2.11.09",
                    _("ComplexType contains disallowed xs:choice element"),
                    modelObject=choiceElt)
                
        for cplxContentElt in modelDocument.xmlRootElement.iter(tag="{http://www.w3.org/2001/XMLSchema}complexContent"):
            if XmlUtil.descendantAttr(cplxContentElt, "http://www.w3.org/2001/XMLSchema", ("extension","restriction"), "base") != "sbr:placeholder":
                val.modelXbrl.error("SBR.NL.2.2.11.10",
                    _("ComplexContent is disallowed"),
                    modelObject=cplxContentElt)

        for typeEltTag in ("{http://www.w3.org/2001/XMLSchema}complexType",
                            "{http://www.w3.org/2001/XMLSchema}simpleType"):
            for typeElt in modelDocument.xmlRootElement.iter(tag=typeEltTag):
                definesTypes = True
                name = typeElt.get("name")
                if name:
                    if not name[0].islower() or not name.isalnum():
                        val.modelXbrl.error("SBR.NL.3.2.8.09",
                            _("Type name attribute must be lower camelcase: %(name)s."),
                            modelObject=typeElt, name=name)
        
        for enumElt in modelDocument.xmlRootElement.iter(tag="{http://www.w3.org/2001/XMLSchema}enumeration"):
            definesEnumerations = True
            if any(not valueElt.genLabel(lang="nl")
                   for valueElt in enumElt.iter(tag="{http://www.w3.org/2001/XMLSchema}value")):
                val.modelXbrl.error("SBR.NL.2.2.7.05",
                    _("Enumeration element has value(s) without generic label."),
                    modelObject=enumElt)

        if (definesLinkroles + definesArcroles + definesLinkParts +
            definesAbstractItems + definesNonabstractItems + 
            definesTuples + definesPresentationTuples + definesSpecificationTuples + definesTypes +
            definesEnumerations + definesDimensions + definesDomains + 
            definesHypercubes) != 1:
            schemaContents = []
            if definesLinkroles: schemaContents.append(_("linkroles"))
            if definesArcroles: schemaContents.append(_("arcroles"))
            if definesLinkParts: schemaContents.append(_("link parts"))
            if definesAbstractItems: schemaContents.append(_("abstract items"))
            if definesNonabstractItems: schemaContents.append(_("nonabstract items"))
            if definesTuples: schemaContents.append(_("tuples"))
            if definesPresentationTuples: schemaContents.append(_("sbrPresentationTuples"))
            if definesSpecificationTuples: schemaContents.append(_("sbrSpecificationTuples"))
            if definesTypes: schemaContents.append(_("types"))
            if definesEnumerations: schemaContents.append(_("enumerations"))
            if definesDimensions: schemaContents.append(_("dimensions"))
            if definesDomains: schemaContents.append(_("domains"))
            if definesHypercubes: schemaContents.append(_("hypercubes"))
            if schemaContents:
                if not ((definesTuples or definesPresentationTuples or definesSpecificationTuples) and
                        not (definesLinkroles or definesArcroles or definesLinkParts or definesAbstractItems or
                             definesTypes or definesDimensions or definesDomains or definesHypercubes)):
                    val.modelXbrl.error("SBR.NL.2.2.1.01",
                        _("Taxonomy schema may only define one of these: %(contents)s"),
                        modelObject=modelDocument, contents=', '.join(schemaContents))
            elif not any(refDoc.inDTS and refDoc.targetNamespace not in val.disclosureSystem.baseTaxonomyNamespaces
                         for refDoc in modelDocument.referencesDocument.keys()): # no linkbase ref or includes
                val.modelXbrl.error("SBR.NL.2.2.1.01",
                    _("Taxonomy schema must be a DTS entrypoint OR define linkroles OR arcroles OR link:parts OR context fragments OR abstract items OR tuples OR non-abstract elements OR types OR enumerations OR dimensions OR domains OR hypercubes"),
                    modelObject=modelDocument)
        if definesConcepts ^ any(  # xor so either concepts and no label LB or no concepts and has label LB
                   (refDoc.type == ModelDocument.Type.LINKBASE and
                    XmlUtil.descendant(refDoc.xmlRootElement, XbrlConst.link, "labelLink") is not None)
                   for refDoc in modelDocument.referencesDocument.keys()): # no label linkbase
            val.modelXbrl.error("SBR.NL.2.2.1.02",
                _("A schema that defines concepts MUST have a linked 2.1 label linkbase"),
                modelObject=modelDocument)
        if (definesNonabstractItems or definesTuples) and not any(  # was xor but changed to and not per RH 1/11/12
                   (refDoc.type == ModelDocument.Type.LINKBASE and
                   (XmlUtil.descendant(refDoc.xmlRootElement, XbrlConst.link, "referenceLink") is not None or
                    XmlUtil.descendant(refDoc.xmlRootElement, XbrlConst.link, "label", "{http://www.w3.org/1999/xlink}role", "http://www.xbrl.org/2003/role/documentation" ) is not None))
                    for refDoc in modelDocument.referencesDocument.keys()):
            val.modelXbrl.error("SBR.NL.2.2.1.03",
                _("A schema that defines non-abstract items MUST have a linked (2.1) reference linkbase AND/OR a label linkbase with @xlink:role=documentation"),
                modelObject=modelDocument)

    elif modelDocument.type == ModelDocument.Type.LINKBASE:
        pass
    visited.remove(modelDocument)
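
The `definesConcepts ^ any(...)` test above relies on Python's `^` acting as logical XOR on booleans: the SBR.NL.2.2.1.02 error fires when a schema defines concepts but no label linkbase is linked, or the reverse. A minimal standalone sketch of that pattern (variable names are illustrative, not Arelle's):

definesConcepts = True
hasLabelLinkbase = False
# bool ^ bool is logical XOR: True when exactly one side holds
if definesConcepts ^ hasLabelLinkbase:
    print("error: concepts defined without a label linkbase (or vice versa)")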

Example 38

Project: Arelle
Source File: Validate.py
View license
    def validateTestcase(self, testcase):
        self.modelXbrl.info("info", "Testcase", modelDocument=testcase)
        self.modelXbrl.viewModelObject(testcase.objectId())
        if hasattr(testcase, "testcaseVariations"):
            for modelTestcaseVariation in testcase.testcaseVariations:
                # update ui thread via modelManager (running in background here)
                self.modelXbrl.modelManager.viewModelObject(self.modelXbrl, modelTestcaseVariation.objectId())
                # is this a versioning report?
                resultIsVersioningReport = modelTestcaseVariation.resultIsVersioningReport
                resultIsXbrlInstance = modelTestcaseVariation.resultIsXbrlInstance
                resultIsTaxonomyPackage = modelTestcaseVariation.resultIsTaxonomyPackage
                formulaOutputInstance = None
                inputDTSes = defaultdict(list)
                baseForElement = testcase.baseForElement(modelTestcaseVariation)
                # try to load instance document
                self.modelXbrl.info("info", _("Variation %(id)s %(name)s: %(expected)s - %(description)s"),
                                    modelObject=modelTestcaseVariation, 
                                    id=modelTestcaseVariation.id, 
                                    name=modelTestcaseVariation.name, 
                                    expected=modelTestcaseVariation.expected, 
                                    description=modelTestcaseVariation.description)
                errorCaptureLevel = modelTestcaseVariation.severityLevel # default is INCONSISTENCY
                parameters = modelTestcaseVariation.parameters.copy()
                for readMeFirstUri in modelTestcaseVariation.readMeFirstUris:
                    if isinstance(readMeFirstUri,tuple):
                        # dtsName is for formula instances, but is from/to dts if versioning
                        dtsName, readMeFirstUri = readMeFirstUri
                    elif resultIsVersioningReport:
                        if inputDTSes: dtsName = "to"
                        else: dtsName = "from"
                    else:
                        dtsName = None
                    if resultIsVersioningReport and dtsName: # build multi-schemaRef containing document
                        if dtsName in inputDTSes:
                            dtsName = inputDTSes[dtsName]
                        else:
                            modelXbrl = ModelXbrl.create(self.modelXbrl.modelManager, 
                                         Type.DTSENTRIES,
                                         self.modelXbrl.modelManager.cntlr.webCache.normalizeUrl(readMeFirstUri[:-4] + ".dts", baseForElement),
                                         isEntry=True,
                                         errorCaptureLevel=errorCaptureLevel)
                        DTSdoc = modelXbrl.modelDocument
                        DTSdoc.inDTS = True
                        doc = modelDocumentLoad(modelXbrl, readMeFirstUri, base=baseForElement)
                        if doc is not None:
                            DTSdoc.referencesDocument[doc] = ModelDocumentReference("import", DTSdoc.xmlRootElement)  #fake import
                            doc.inDTS = True
                    elif resultIsTaxonomyPackage:
                        from arelle import PackageManager, PrototypeInstanceObject
                        dtsName = readMeFirstUri
                        modelXbrl = PrototypeInstanceObject.XbrlPrototype(self.modelXbrl.modelManager, readMeFirstUri)
                        PackageManager.packageInfo(self.modelXbrl.modelManager.cntlr, readMeFirstUri, reload=True, errors=modelXbrl.errors)
                    else: # not a multi-schemaRef versioning report
                        if self.useFileSource.isArchive:
                            modelXbrl = ModelXbrl.load(self.modelXbrl.modelManager, 
                                                       readMeFirstUri,
                                                       _("validating"), 
                                                       base=baseForElement,
                                                       useFileSource=self.useFileSource,
                                                       errorCaptureLevel=errorCaptureLevel)
                        else: # need own file source, may need instance discovery
                            filesource = FileSource.FileSource(readMeFirstUri, self.modelXbrl.modelManager.cntlr)
                            if filesource and not filesource.selection and filesource.isArchive:
                                for _archiveFile in filesource.dir: # find instance document in archive
                                    filesource.select(_archiveFile)
                                    if ModelDocument.Type.identify(filesource, filesource.url) in (ModelDocument.Type.INSTANCE, ModelDocument.Type.INLINEXBRL):
                                        break # use this selection
                            modelXbrl = ModelXbrl.load(self.modelXbrl.modelManager, 
                                                       filesource,
                                                       _("validating"), 
                                                       base=baseForElement,
                                                       errorCaptureLevel=errorCaptureLevel)
                        modelXbrl.isTestcaseVariation = True
                    if modelXbrl.modelDocument is None:
                        modelXbrl.error("arelle:notLoaded",
                             _("Testcase %(id)s %(name)s document not loaded: %(file)s"),
                             modelXbrl=testcase, id=modelTestcaseVariation.id, name=modelTestcaseVariation.name, file=os.path.basename(readMeFirstUri))
                        self.determineNotLoadedTestStatus(modelTestcaseVariation)
                        modelXbrl.close()
                    elif resultIsVersioningReport or resultIsTaxonomyPackage:
                        inputDTSes[dtsName] = modelXbrl
                    elif modelXbrl.modelDocument.type == Type.VERSIONINGREPORT:
                        ValidateVersReport.ValidateVersReport(self.modelXbrl).validate(modelXbrl)
                        self.determineTestStatus(modelTestcaseVariation, modelXbrl.errors)
                        modelXbrl.close()
                    elif testcase.type == Type.REGISTRYTESTCASE:
                        self.instValidator.validate(modelXbrl)  # required to set up dimensions, etc
                        self.instValidator.executeCallTest(modelXbrl, modelTestcaseVariation.id, 
                                   modelTestcaseVariation.cfcnCall, modelTestcaseVariation.cfcnTest)
                        self.determineTestStatus(modelTestcaseVariation, modelXbrl.errors)
                        self.instValidator.close()
                        modelXbrl.close()
                    else:
                        inputDTSes[dtsName].append(modelXbrl)
                        # validate except for formulas
                        _hasFormulae = modelXbrl.hasFormulae
                        modelXbrl.hasFormulae = False
                        try:
                            for pluginXbrlMethod in pluginClassMethods("TestcaseVariation.Xbrl.Loaded"):
                                pluginXbrlMethod(self.modelXbrl, modelXbrl, modelTestcaseVariation)
                            self.instValidator.validate(modelXbrl, parameters)
                            for pluginXbrlMethod in pluginClassMethods("TestcaseVariation.Xbrl.Validated"):
                                pluginXbrlMethod(self.modelXbrl, modelXbrl)
                        except Exception as err:
                            modelXbrl.error("exception:" + type(err).__name__,
                                _("Testcase variation validation exception: %(error)s, instance: %(instance)s"),
                                modelXbrl=modelXbrl, instance=modelXbrl.modelDocument.basename, error=err, exc_info=True)
                        modelXbrl.hasFormulae = _hasFormulae
                if resultIsVersioningReport and modelXbrl.modelDocument:
                    versReportFile = modelXbrl.modelManager.cntlr.webCache.normalizeUrl(
                        modelTestcaseVariation.versioningReportUri, baseForElement)
                    if os.path.exists(versReportFile): #validate existing
                        modelVersReport = ModelXbrl.load(self.modelXbrl.modelManager, versReportFile, _("validating existing version report"))
                        if modelVersReport and modelVersReport.modelDocument and modelVersReport.modelDocument.type == Type.VERSIONINGREPORT:
                            ValidateVersReport.ValidateVersReport(self.modelXbrl).validate(modelVersReport)
                            self.determineTestStatus(modelTestcaseVariation, modelVersReport.errors)
                            modelVersReport.close()
                    elif len(inputDTSes) == 2:
                        ModelVersReport.ModelVersReport(self.modelXbrl).diffDTSes(
                              versReportFile, inputDTSes["from"], inputDTSes["to"])
                        modelTestcaseVariation.status = "generated"
                    else:
                        modelXbrl.error("arelle:notLoaded",
                             _("Testcase %(id)s %(name)s DTSes not loaded, unable to generate versioning report: %(file)s"),
                             modelXbrl=testcase, id=modelTestcaseVariation.id, name=modelTestcaseVariation.name, file=os.path.basename(readMeFirstUri))
                        modelTestcaseVariation.status = "failed"
                    for inputDTS in inputDTSes.values():
                        inputDTS.close()
                    del inputDTSes # dereference
                elif resultIsTaxonomyPackage:
                    self.determineTestStatus(modelTestcaseVariation, modelXbrl.errors)
                    modelXbrl.close()
                elif inputDTSes:
                    # validate schema, linkbase, or instance
                    modelXbrl = inputDTSes[None][0]
                    for dtsName, inputDTS in inputDTSes.items():  # input instances are also parameters
                        if dtsName: # named instance
                            parameters[dtsName] = (None, inputDTS) #inputDTS is a list of modelXbrl's (instance DTSes)
                        elif len(inputDTS) > 1: # standard-input-instance with multiple instance documents
                            parameters[XbrlConst.qnStandardInputInstance] = (None, inputDTS) # allow error detection in validateFormula
                    if modelXbrl.hasTableRendering or modelTestcaseVariation.resultIsTable:
                        RenderingEvaluator.init(modelXbrl)
                    if modelXbrl.hasFormulae:
                        try:
                            # validate only formulae
                            self.instValidator.parameters = parameters
                            ValidateFormula.validate(self.instValidator)
                        except Exception as err:
                            modelXbrl.error("exception:" + type(err).__name__,
                                _("Testcase formula variation validation exception: %(error)s, instance: %(instance)s"),
                                modelXbrl=modelXbrl, instance=modelXbrl.modelDocument.basename, error=err, exc_info=True)
                    if modelTestcaseVariation.resultIsInfoset and self.modelXbrl.modelManager.validateInfoset:
                        for pluginXbrlMethod in pluginClassMethods("Validate.Infoset"):
                            pluginXbrlMethod(modelXbrl, modelTestcaseVariation.resultInfosetUri)
                        infoset = ModelXbrl.load(self.modelXbrl.modelManager, 
                                                 modelTestcaseVariation.resultInfosetUri,
                                                   _("loading result infoset"), 
                                                   base=baseForElement,
                                                   useFileSource=self.useFileSource,
                                                   errorCaptureLevel=errorCaptureLevel)
                        if infoset.modelDocument is None:
                            modelXbrl.error("arelle:notLoaded",
                                _("Testcase %(id)s %(name)s result infoset not loaded: %(file)s"),
                                modelXbrl=testcase, id=modelTestcaseVariation.id, name=modelTestcaseVariation.name, 
                                file=os.path.basename(modelTestcaseVariation.resultInfosetUri))
                            modelTestcaseVariation.status = "result infoset not loadable"
                        else:   # check infoset
                            ValidateInfoset.validate(self.instValidator, modelXbrl, infoset)
                        infoset.close()
                    if modelTestcaseVariation.resultIsTable: # and self.modelXbrl.modelManager.validateInfoset:
                        # diff (or generate) table infoset
                        resultTableUri = modelXbrl.modelManager.cntlr.webCache.normalizeUrl(modelTestcaseVariation.resultTableUri, baseForElement)
                        if not any(alternativeValidation(modelXbrl, resultTableUri)
                                   for alternativeValidation in pluginClassMethods("Validate.TableInfoset")):
                            ViewFileRenderedGrid.viewRenderedGrid(modelXbrl, resultTableUri, diffToFile=True)  # set diffToFile=False to save the generated infoset files instead of diffing
                    self.instValidator.close()
                    extraErrors = []
                    for pluginXbrlMethod in pluginClassMethods("TestcaseVariation.Validated"):
                        pluginXbrlMethod(self.modelXbrl, modelXbrl, extraErrors)
                    self.determineTestStatus(modelTestcaseVariation, [e for inputDTSlist in inputDTSes.values() for inputDTS in inputDTSlist for e in inputDTS.errors] + extraErrors) # include infoset errors in status
                    if modelXbrl.formulaOutputInstance and self.noErrorCodes(modelTestcaseVariation.actual): 
                        # if an output instance is created, and no string error codes, ignoring dict of assertion results, validate it
                        modelXbrl.formulaOutputInstance.hasFormulae = False #  block formulae on output instance (so assertion of input is not lost)
                        self.instValidator.validate(modelXbrl.formulaOutputInstance, modelTestcaseVariation.parameters)
                        self.determineTestStatus(modelTestcaseVariation, modelXbrl.formulaOutputInstance.errors)
                        if self.noErrorCodes(modelTestcaseVariation.actual): # if still 'clean' pass it forward for comparison to expected result instance
                            formulaOutputInstance = modelXbrl.formulaOutputInstance
                            modelXbrl.formulaOutputInstance = None # prevent it from being closed now
                        self.instValidator.close()
                    compareIxResultInstance = getattr(modelXbrl, "extractedInlineInstance", False) and modelTestcaseVariation.resultXbrlInstanceUri
                    if compareIxResultInstance:
                        formulaOutputInstance = modelXbrl # compare modelXbrl to generated output instance
                        errMsgPrefix = "ix"
                    else: # delete input instances before formula output comparison
                        for inputDTSlist in inputDTSes.values():
                            for inputDTS in inputDTSlist:
                                inputDTS.close()
                        del inputDTSes # dereference
                        errMsgPrefix = "formula"
                    if resultIsXbrlInstance and formulaOutputInstance and formulaOutputInstance.modelDocument:
                        expectedInstance = ModelXbrl.load(self.modelXbrl.modelManager, 
                                                   modelTestcaseVariation.resultXbrlInstanceUri,
                                                   _("loading expected result XBRL instance"), 
                                                   base=baseForElement,
                                                   useFileSource=self.useFileSource,
                                                   errorCaptureLevel=errorCaptureLevel)
                        if expectedInstance.modelDocument is None:
                            self.modelXbrl.error("{}:expectedResultNotLoaded".format(errMsgPrefix),
                                _("Testcase %(id)s %(name)s expected result instance not loaded: %(file)s"),
                                modelXbrl=testcase, id=modelTestcaseVariation.id, name=modelTestcaseVariation.name, 
                                file=os.path.basename(modelTestcaseVariation.resultXbrlInstanceUri),
                                messageCodes=("formula:expectedResultNotLoaded","ix:expectedResultNotLoaded"))
                            modelTestcaseVariation.status = "result not loadable"
                        else:   # compare facts
                            if len(expectedInstance.facts) != len(formulaOutputInstance.facts):
                                formulaOutputInstance.error("{}:resultFactCounts".format(errMsgPrefix),
                                    _("Formula output %(countFacts)s facts, expected %(expectedFacts)s facts"),
                                    modelXbrl=modelXbrl, countFacts=len(formulaOutputInstance.facts),
                                    expectedFacts=len(expectedInstance.facts),
                                    messageCodes=("formula:resultFactCounts","ix:resultFactCounts"))
                            else:
                                formulaOutputFootnotesRelSet = ModelRelationshipSet(formulaOutputInstance, "XBRL-footnotes")
                                expectedFootnotesRelSet = ModelRelationshipSet(expectedInstance, "XBRL-footnotes")
                                def factFootnotes(fact, footnotesRelSet):
                                    footnotes = []
                                    footnoteRels = footnotesRelSet.fromModelObject(fact)
                                    if footnoteRels:
                                        # rels are usually processed in the same order in both instances; sort by labels to be deterministic
                                        for i, footnoteRel in enumerate(sorted(footnoteRels,
                                                                               key=lambda r: (r.fromLabel,r.toLabel))):
                                            modelObject = footnoteRel.toModelObject
                                            if isinstance(modelObject, ModelResource):
                                                footnotes.append("Footnote {}: {}".format(
                                                   i+1, # compare footnote with normalize-space
                                                   re.sub(r'\s+', ' ', collapseWhitespace(modelObject.stringValue))))
                                            elif isinstance(modelObject, ModelFact):
                                                footnotes.append("Footnoted fact {}: {} context: {} value: {}".format(
                                                    i+1,
                                                    modelObject.qname,
                                                    modelObject.contextID,
                                                    collapseWhitespace(modelObject.value)))
                                    return footnotes
                                for expectedInstanceFact in expectedInstance.facts:
                                    unmatchedFactsStack = []
                                    formulaOutputFact = formulaOutputInstance.matchFact(expectedInstanceFact, unmatchedFactsStack, deemP0inf=True)
                                    if formulaOutputFact is None:
                                        if unmatchedFactsStack: # get missing nested tuple fact, if possible
                                            missingFact = unmatchedFactsStack[-1]
                                        else:
                                            missingFact = expectedInstanceFact
                                        formulaOutputInstance.error("{}:expectedFactMissing".format(errMsgPrefix),
                                            _("Output missing expected fact %(fact)s"),
                                            modelXbrl=missingFact, fact=missingFact.qname,
                                            messageCodes=("formula:expectedFactMissing","ix:expectedFactMissing"))
                                    else: # compare footnotes
                                        expectedInstanceFactFootnotes = factFootnotes(expectedInstanceFact, expectedFootnotesRelSet)
                                        formulaOutputFactFootnotes = factFootnotes(formulaOutputFact, formulaOutputFootnotesRelSet)
                                        if expectedInstanceFactFootnotes != formulaOutputFactFootnotes:
                                            formulaOutputInstance.error("{}:expectedFactFootnoteDifference".format(errMsgPrefix),
                                                _("Output expected fact %(fact)s expected footnotes %(footnotes1)s produced footnotes %(footnotes2)s"),
                                                modelXbrl=(formulaOutputFact,expectedInstanceFact), fact=expectedInstanceFact.qname, footnotes1=expectedInstanceFactFootnotes, footnotes2=formulaOutputFactFootnotes,
                                                messageCodes=("formula:expectedFactFootnoteDifference","ix:expectedFactFootnoteDifference"))

                            # for debugging uncomment next line to save generated instance document
                            # formulaOutputInstance.saveInstance(r"c:\temp\test-out-inst.xml")
                        expectedInstance.close()
                        del expectedInstance # dereference
                        self.determineTestStatus(modelTestcaseVariation, formulaOutputInstance.errors)
                        formulaOutputInstance.close()
                        del formulaOutputInstance
                    if compareIxResultInstance:
                        for inputDTSlist in inputDTSes.values():
                            for inputDTS in inputDTSlist:
                                inputDTS.close()
                        del inputDTSes # dereference
                # update ui thread via modelManager (running in background here)
                self.modelXbrl.modelManager.viewModelObject(self.modelXbrl, modelTestcaseVariation.objectId())
                    
            self.modelXbrl.modelManager.showStatus(_("ready"), 2000)
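
Throughout validateTestcase, os.path.basename is applied to URIs so that error messages carry only the file name (file=os.path.basename(readMeFirstUri) and similar). basename splits on '/', so it works on URLs as well as local paths. A small self-contained sketch of the same reporting pattern (the function and message text are illustrative, not part of Arelle):

import os.path

def report_not_loaded(testcase_id, name, uri):
    # log only the file name, not the whole URI, to keep messages readable
    print("Testcase %s %s document not loaded: %s"
          % (testcase_id, name, os.path.basename(uri)))

report_not_loaded("V-01", "basic", "http://example.com/tests/100-instance.xml")
# prints: Testcase V-01 basic document not loaded: 100-instance.xml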

Example 39

Project: Arelle
Source File: ValidateFilingDimensions.py
View license
def checkDimensions(val, drsELRs):
    
    fromConceptELRs = defaultdict(set)
    hypercubes = set()
    hypercubesInLinkrole = defaultdict(set)
    domainsInLinkrole = defaultdict(set)
    for ELR in drsELRs:
        domainMemberRelationshipSet = val.modelXbrl.relationshipSet( XbrlConst.domainMember, ELR)
                            
        # check Hypercubes in ELR, accumulate list of primary items
        positiveAxisTableSources = defaultdict(set)
        positiveHypercubes = set()
        primaryItems = set()
        for hasHypercubeArcrole in (XbrlConst.all, XbrlConst.notAll):
            hasHypercubeRelationships = val.modelXbrl.relationshipSet(
                             hasHypercubeArcrole, ELR).fromModelObjects()
            for hasHcRels in hasHypercubeRelationships.values():
                for hasHcRel in hasHcRels:
                    sourceConcept = hasHcRel.fromModelObject
                    primaryItems.add(sourceConcept)
                    hc = hasHcRel.toModelObject
                    hypercubes.add(hc)
                    if hasHypercubeArcrole == XbrlConst.all:
                        positiveHypercubes.add(hc)
                        if not hasHcRel.isClosed:
                            val.modelXbrl.error("SBR.NL.2.3.6.04",
                                _("All hypercube %(hypercube)s in DRS role %(linkrole)s, does not have closed='true'"),
                                modelObject=hasHcRel, hypercube=hc.qname, linkrole=ELR)
                    elif hasHypercubeArcrole == XbrlConst.notAll:
                        if hasHcRel.isClosed:
                            val.modelXbrl.error(("EFM.6.16.06", "GFM.1.08.06"),
                                _("Not all hypercube %(hypercube)s in DRS role %(linkrole)s, does not have closed='false'"),
                                modelObject=hasHcRel, hypercube=hc.qname, linkrole=ELR, primaryItem=sourceConcept.qname)
                        if hc in positiveHypercubes:
                            val.modelXbrl.error(("EFM.6.16.08", "GFM.1.08.08"),
                                _("Not all hypercube %(hypercube)s in DRS role %(linkrole)s, is also the target of a positive hypercube"),
                                modelObject=hasHcRel, hypercube=hc.qname, linkrole=ELR, linkroleDefinition=val.modelXbrl.roleTypeDefinition(ELR), primaryItem=sourceConcept.qname)
                    dimELR = hasHcRel.targetRole
                    dimTargetRequired = (dimELR is not None)
                    if not dimELR:
                        dimELR = ELR
                    hypercubesInLinkrole[dimELR].add(hc) # this is the elr containing the HC-dim relations
                    hcDimRels = val.modelXbrl.relationshipSet(
                             XbrlConst.hypercubeDimension, dimELR).fromModelObject(hc)
                    if dimTargetRequired and len(hcDimRels) == 0:
                        val.modelXbrl.error(("EFM.6.16.09", "GFM.1.08.09"),
                            _("Table %(hypercube)s in DRS role %(linkrole)s, missing targetrole consecutive relationship"),
                            modelObject=hasHcRel, hypercube=hc.qname, fromConcept=sourceConcept.qname, toConcept=hc.qname, 
                            linkrole=ELR, linkroleDefinition=val.modelXbrl.roleTypeDefinition(ELR),
                            arcroleURI=hasHcRel.arcrole, arcrole=os.path.basename(hasHcRel.arcrole))
                    for hcDimRel in hcDimRels:
                        dim = hcDimRel.toModelObject
                        if isinstance(dim, ModelConcept):
                            domELR = hcDimRel.targetRole
                            domTargetRequired = (domELR is not None)
                            if not domELR:
                                if dim.isExplicitDimension:
                                    domELR = dimELR
                                    if val.validateSBRNL:
                                        val.modelXbrl.error("SBR.NL.2.3.5.04",
                                            _("Hypercube %(hypercube)s in DRS role %(linkrole)s, missing targetrole to dimension %(dimension)s consecutive relationship"),
                                            modelObject=hcDimRel, hypercube=hc.qname, linkrole=ELR, dimension=dim.qname)
                            else:
                                if dim.isTypedDimension and val.validateSBRNL:
                                    val.modelXbrl.error("SBR.NL.2.3.5.07",
                                        _("Typed dimension %(dimension)s in DRS role %(linkrole)s, has targetrole consecutive relationship"),
                                        modelObject=hcDimRel, dimension=dim.qname, linkrole=ELR)
                            if hasHypercubeArcrole == XbrlConst.all:
                                positiveAxisTableSources[dim].add(sourceConcept)
                            elif hasHypercubeArcrole == XbrlConst.notAll and \
                                 (dim not in positiveAxisTableSources or \
                                  not commonAncestor(domainMemberRelationshipSet,
                                                  sourceConcept, positiveAxisTableSources[dim])):
                                val.modelXbrl.error(("EFM.6.16.07", "GFM.1.08.08"),
                                    _("Negative table axis %(dimension)s in DRS role %(linkrole)s, not in any positive table in same role"),
                                     modelObject=hcDimRel, dimension=dim.qname, linkrole=ELR, linkroleDefinition=val.modelXbrl.roleTypeDefinition(ELR), primaryItem=sourceConcept.qname)
                            dimDomRels = val.modelXbrl.relationshipSet(
                                 XbrlConst.dimensionDomain, domELR).fromModelObject(dim)   
                            if domTargetRequired and len(dimDomRels) == 0:
                                val.modelXbrl.error(("EFM.6.16.09", "GFM.1.08.09"),
                                    _("Axis %(dimension)s in DRS role %(linkrole)s, missing targetrole consecutive relationship"),
                                    modelObject=hcDimRel, dimension=dim.qname, fromConcept=hc.qname, toConcept=dim.qname, linkrole=ELR, linkroleDefinition=val.modelXbrl.roleTypeDefinition(ELR), arcroleURI=hcDimRel.arcrole, arcrole=os.path.basename(hcDimRel.arcrole))
                            if val.validateEFMorGFM:
                                # flatten DRS member relationships in ELR for undirected cycle detection
                                drsRelsFrom = defaultdict(list)
                                drsRelsTo = defaultdict(list)
                                getDrsRels(val, domELR, dimDomRels, ELR, drsRelsFrom, drsRelsTo)
                                # check for cycles
                                fromConceptELRs[hc].add(dimELR)
                                fromConceptELRs[dim].add(domELR)
                                cycleCausingConcept = undirectedFwdCycle(val, domELR, dimDomRels, ELR, drsRelsFrom, drsRelsTo, fromConceptELRs)
                                if cycleCausingConcept is not None:
                                    cycleCausingConcept.append(hcDimRel)
                                    val.modelXbrl.error(("EFM.6.16.04", "GFM.1.08.04"),
                                        _("Dimension relationships have an undirected cycle in DRS role %(linkrole)s \nstarting from table %(hypercube)s, \naxis %(dimension)s, \npath %(path)s"),
                                        modelObject=[hc, dim] + [rel for rel in cycleCausingConcept if not isinstance(rel, bool)], 
                                        linkrole=ELR, linkroleDefinition=val.modelXbrl.roleTypeDefinition(ELR), 
                                        hypercube=hc.qname, dimension=dim.qname, conceptFrom=dim.qname,
                                        path=cyclePath(hc,cycleCausingConcept))
                                fromConceptELRs.clear()
                            elif val.validateSBRNL:
                                checkSBRNLMembers(val, hc, dim, domELR, dimDomRels, ELR, True)
                                for dimDomRel in dimDomRels:
                                    dom = dimDomRel.toModelObject
                                    if isinstance(dom, ModelConcept):
                                        domainsInLinkrole[domELR].add(dom) # this is the elr containing the HC-dim relations
                if hasHypercubeArcrole == XbrlConst.all and len(hasHcRels) > 1:
                    val.modelXbrl.error(("EFM.6.16.05", "GFM.1.08.05"),
                        _("Multiple tables (%(hypercubeCount)s) DRS role %(linkrole)s, source %(concept)s, only 1 allowed"),
                        modelObject=[sourceConcept] + hasHcRels, 
                        hypercubeCount=len(hasHcRels), linkrole=ELR, linkroleDefinition=val.modelXbrl.roleTypeDefinition(ELR),
                        concept=sourceConcept.qname,
                        hypercubes=', '.join(str(r.toModelObject.qname) for r in hasHcRels if isinstance(r.toModelObject, ModelConcept)))
                    
        # check for primary item dimension-member graph undirected cycles
        fromRelationships = domainMemberRelationshipSet.fromModelObjects()
        for relFrom, rels in fromRelationships.items():
            if relFrom in primaryItems:
                drsRelsFrom = defaultdict(list)
                drsRelsTo = defaultdict(list)
                getDrsRels(val, ELR, rels, ELR, drsRelsFrom, drsRelsTo)
                fromConceptELRs[relFrom].add(ELR)
                cycleCausingConcept = undirectedFwdCycle(val, ELR, rels, ELR, drsRelsFrom, drsRelsTo, fromConceptELRs)
                if cycleCausingConcept is not None:
                    val.modelXbrl.error(("EFM.6.16.04", "GFM.1.08.04"),
                        _("Domain-member primary-item relationships have an undirected cycle in DRS role %(linkrole)s \nstarting from %(conceptFrom)s, \npath %(path)s"),
                        modelObject=[relFrom] + [rel for rel in cycleCausingConcept if not isinstance(rel, bool)], 
                        linkrole=ELR, conceptFrom=relFrom.qname, path=cyclePath(relFrom, cycleCausingConcept))
                fromConceptELRs.clear()
            for rel in rels:
                fromMbr = rel.fromModelObject
                toMbr = rel.toModelObject
                toELR = rel.targetRole
                if isinstance(toMbr, ModelConcept) and toELR and len(
                    val.modelXbrl.relationshipSet(
                         XbrlConst.domainMember, toELR).fromModelObject(toMbr)) == 0:
                    val.modelXbrl.error(("EFM.6.16.09", "GFM.1.08.09"),
                        _("Domain member %(concept)s in DRS role %(linkrole)s, missing targetrole consecutive relationship"),
                        modelObject=rel, concept=fromMbr.qname, fromConcept=toMbr.qname, toConcept=fromMbr.qname, linkrole=ELR, linkroleDefinition=val.modelXbrl.roleTypeDefinition(ELR), arcroleURI=rel.arcrole, arcrole=os.path.basename(rel.arcrole))
                    
    if val.validateSBRNL:
        # check hypercubes for unique set of members
        for hc in hypercubes:
            for priHcRel in val.modelXbrl.relationshipSet(XbrlConst.all).toModelObject(hc):
                priItem = priHcRel.fromModelObject
                ELR = priHcRel.linkrole
                checkSBRNLMembers(val, hc, priItem, ELR, 
                                  val.modelXbrl.relationshipSet(XbrlConst.domainMember, ELR).fromModelObject(priItem), 
                                  ELR, False)
                if priHcRel.contextElement == 'segment':  
                    val.modelXbrl.error("SBR.NL.2.3.5.06",
                        _("hypercube %(hypercube)s in segment not allowed, ELR role %(linkrole)s"),
                        modelObject=priHcRel, linkrole=ELR, hypercube=hc.qname)
        for notAllRel in val.modelXbrl.relationshipSet(XbrlConst.notAll).modelRelationships:
            val.modelXbrl.error("SBR.NL.2.3.5.05",
                _("Notall from primary item %(primaryItem)s in ELR role %(linkrole)s to %(hypercube)s"),
                modelObject=val.modelXbrl, primaryItem=notAllRel.fromModelObject.qname, linkrole=notAllRel.linkrole, hypercube=notAllRel.toModelObject.qname)
        for ELR, hypercubes in hypercubesInLinkrole.items():
            '''removed RH 2011-12-06
            for modelRel in val.modelXbrl.relationshipSet("XBRL-dimensions", ELR).modelRelationships:
                if modelRel.fromModelObject != hc:
                    val.modelXbrl.error("SBR.NL.2.3.5.03",
                        _("ELR role %(linkrole)s, is not dedicated to %(hypercube)s, but also has %(otherQname)s"),
                        modelObject=val.modelXbrl, linkrole=ELR, hypercube=hc.qname, otherQname=modelRel.fromModelObject.qname)
            '''
            domains = domainsInLinkrole.get(ELR, emptySet)
            for hc in hypercubes:  # only one member
                for arcrole in (XbrlConst.parentChild, "XBRL-dimensions"):
                    for modelRel in val.modelXbrl.relationshipSet(arcrole, ELR).modelRelationships:
                        if (modelRel.fromModelObject != hc and modelRel.toModelObject != hc and
                            modelRel.fromModelObject not in domains and modelRel.toModelObject not in domains):
                            val.modelXbrl.error("SBR.NL.2.2.3.05",
                                _("ELR role %(linkrole)s, has hypercube %(hypercube)s and a %(arcrole)s relationship not involving the hypercube or primary domain, from %(fromConcept)s to %(toConcept)s"),
                                modelObject=modelRel, linkrole=ELR, hypercube=hc.qname, arcrole=os.path.basename(modelRel.arcrole), 
                                fromConcept=modelRel.fromModelObject.qname, 
                                toConcept=(modelRel.toModelObject.qname if isinstance(modelRel.toModelObject, ModelConcept) else "unknown"))
        domainsInLinkrole = defaultdict(set)
        dimDomMemsByLinkrole = defaultdict(set)
        for rel in val.modelXbrl.relationshipSet(XbrlConst.dimensionDomain).modelRelationships:
            relFrom = rel.fromModelObject
            relTo = rel.toModelObject
            if isinstance(relFrom, ModelConcept) and isinstance(relTo, ModelConcept):
                domainsInLinkrole[rel.targetRole].add(relFrom)
                domMems = set() # determine usable dom and mems of dimension in this linkrole
                if rel.isUsable:
                    domMems.add(relTo)
                for relMem in val.modelXbrl.relationshipSet(XbrlConst.domainMember, (rel.targetRole or rel.linkrole)).fromModelObject(relTo):
                    if relMem.isUsable:
                        domMems.add(relMem.toModelObject)
                dimDomMemsByLinkrole[(rel.linkrole,relFrom)].update(domMems)
                if rel.isUsable and val.modelXbrl.relationshipSet(XbrlConst.domainMember, rel.targetRole).fromModelObject(relTo):
                    val.modelXbrl.error("SBR.NL.2.3.7.05",
                        _("Dimension %(dimension)s in DRS role %(linkrole)s, has usable domain with members %(domain)s"),
                        modelObject=rel, dimension=relFrom.qname, linkrole=rel.linkrole, domain=relTo.qname)
                if not relTo.isAbstract:
                    val.modelXbrl.error("SBR.NL.2.3.7.02",
                        _("Dimension %(dimension)s in DRS role %(linkrole)s, has nonAbsract domain %(domain)s"),
                        modelObject=rel, dimension=relFrom.qname, linkrole=rel.linkrole, domain=relTo.qname)
                if relTo.substitutionGroupQname.localName not in ("domainItem","domainMemberItem"):
                    val.modelXbrl.error("SBR.NL.2.2.2.19",
                        _("Domain item %(domain)s in DRS role %(linkrole)s, in dimension %(dimension)s is not a domainItem"),
                        modelObject=rel, domain=relTo.qname, linkrole=rel.linkrole, dimension=relFrom.qname)
                if not rel.targetRole and relTo.substitutionGroupQname.localName == "domainItem":
                    val.modelXbrl.error("SBR.NL.2.3.6.03",
                        _("Dimension %(dimension)s in DRS role %(linkrole)s, missing targetrole to consecutive domain relationship"),
                        modelObject=rel, dimension=relFrom.qname, linkrole=rel.linkrole)
        for linkrole, domains in domainsInLinkrole.items():
            if linkrole and len(domains) > 1:
                val.modelXbrl.error("SBR.NL.2.3.7.04",
                    _("Linkrole %(linkrole)s, has multiple domains %(domains)s"),
                    modelObject=val.modelXbrl, linkrole=linkrole, domains=", ".join([str(dom.qname) for dom in domains]))
        del domainsInLinkrole   # dereference
        linkrolesByDimDomMems = defaultdict(set)
        for linkroleDim, domMems in dimDomMemsByLinkrole.items():
            linkrole, dim = linkroleDim
            linkrolesByDimDomMems[(dim,tuple(domMems))].add(linkrole)
        for dimDomMems, linkroles in linkrolesByDimDomMems.items():
            if len(linkroles) > 1:
                val.modelXbrl.error("SBR.NL.2.3.6.02",
                    _("Dimension %(dimension)s  usable members same in linkroles %(linkroles)s"),
                    modelObject=val.modelXbrl, dimension=dimDomMems[0].qname, linkroles=', '.join(l for l in linkroles))
        del dimDomMemsByLinkrole, linkrolesByDimDomMems
        for rel in val.modelXbrl.relationshipSet(XbrlConst.domainMember).modelRelationships:
            if val.modelXbrl.relationshipSet(XbrlConst.domainMember, rel.targetRole).fromModelObject(rel.toModelObject):
                val.modelXbrl.error("SBR.NL.2.3.7.03",
                    _("Domain member %(member)s in DRS role %(linkrole)s, has nested members"),
                    modelObject=rel, member=(rel.toModelObject.qname if isinstance(rel.toModelObject, ModelConcept) else None), linkrole=rel.linkrole)
        for rel in val.modelXbrl.relationshipSet(XbrlConst.domainMember).modelRelationships:
            relFrom = rel.fromModelObject
            relTo = rel.toModelObject
            if isinstance(relTo, ModelConcept):
                # avoid primary item relationships in these tests
                if relFrom.substitutionGroupQname.localName == "domainItem":
                    if relTo.substitutionGroupQname.localName != "domainMemberItem":
                        val.modelXbrl.error("SBR.NL.2.2.2.19",
                            _("Domain member item %(member)s in DRS role %(linkrole)s is not a domainMemberItem"),
                            modelObject=rel, member=relTo.qname, linkrole=rel.linkrole)
                else:
                    if relTo.substitutionGroupQname.localName == "domainMemberItem":
                        val.modelXbrl.error("SBR.NL.2.2.2.19",
                            _("Domain item %(domain)s in DRS role %(linkrole)s is not a domainItem"),
                            modelObject=rel, domain=relFrom.qname, linkrole=rel.linkrole)
                        break # don't repeat parent's error on rest of child members
                    elif relFrom.isAbstract and relFrom.substitutionGroupQname.localName != "primaryDomainItem":
                        val.modelXbrl.error("SBR.NL.2.2.2.19",
                            _("Abstract domain item %(domain)s in DRS role %(linkrole)s is not a primaryDomainItem"),
                            modelObject=rel, domain=relFrom.qname, linkrole=rel.linkrole)
                        break # don't repeat parent's error on rest of child members
        hypercubeDRSDimensions = defaultdict(dict)
        for hcDimRel in val.modelXbrl.relationshipSet(XbrlConst.hypercubeDimension).modelRelationships:
            hc = hcDimRel.fromModelObject
            if isinstance(hc, ModelConcept):
                ELR = hcDimRel.linkrole
                try:
                    hcDRSdims = hypercubeDRSDimensions[hc][ELR]
                except KeyError:
                    hcDRSdims = set()
                    hypercubeDRSDimensions[hc][ELR] = hcDRSdims
                hcDRSdims.add(hcDimRel.toModelObject)
        for hc, DRSdims in hypercubeDRSDimensions.items():
            hcELRdimSets = {}
            for ELR, mutableDims in DRSdims.items():
                dims = frozenset(mutableDims)
                if dims not in hcELRdimSets:
                    hcELRdimSets[dims] = ELR
                else: 
                    val.modelXbrl.error("SBR.NL.2.3.5.02",
                        _("Hypercube %(hypercube)s has same dimensions in ELR roles %(linkrole)s and %(linkrole2)s: %(dimensions)s"),
                        modelObject=hc, hypercube=hc.qname, linkrole=ELR, linkrole2=hcELRdimSets[dims],
                        dimensions=", ".join([str(dim.qname) for dim in dims]))
        del hypercubeDRSDimensions # dereference
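
In checkDimensions, os.path.basename serves a different purpose: applied to an arcrole URI it yields the short arcrole name for the message text (arcrole=os.path.basename(hcDimRel.arcrole) and similar). A minimal sketch; the URI is the standard XBRL Dimensions hypercube-dimension arcrole, the rest is illustrative:

import os.path

arcrole = "http://xbrl.org/int/dim/arcrole/hypercube-dimension"
# basename splits on '/', so the last URI segment is the short arcrole name
print(os.path.basename(arcrole))  # prints: hypercube-dimension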

Example 40

Project: Arelle
Source File: Validate.py
View license
    def validateTestcase(self, testcase):
        self.modelXbrl.info("info", "Testcase", modelDocument=testcase)
        self.modelXbrl.viewModelObject(testcase.objectId())
        if hasattr(testcase, "testcaseVariations"):
            for modelTestcaseVariation in testcase.testcaseVariations:
                # update ui thread via modelManager (running in background here)
                self.modelXbrl.modelManager.viewModelObject(self.modelXbrl, modelTestcaseVariation.objectId())
                # is this a versioning report?
                resultIsVersioningReport = modelTestcaseVariation.resultIsVersioningReport
                resultIsXbrlInstance = modelTestcaseVariation.resultIsXbrlInstance
                resultIsTaxonomyPackage = modelTestcaseVariation.resultIsTaxonomyPackage
                formulaOutputInstance = None
                inputDTSes = defaultdict(list)
                baseForElement = testcase.baseForElement(modelTestcaseVariation)
                # try to load instance document
                self.modelXbrl.info("info", _("Variation %(id)s %(name)s: %(expected)s - %(description)s"),
                                    modelObject=modelTestcaseVariation, 
                                    id=modelTestcaseVariation.id, 
                                    name=modelTestcaseVariation.name, 
                                    expected=modelTestcaseVariation.expected, 
                                    description=modelTestcaseVariation.description)
                errorCaptureLevel = modelTestcaseVariation.severityLevel # default is INCONSISTENCY
                parameters = modelTestcaseVariation.parameters.copy()
                for readMeFirstUri in modelTestcaseVariation.readMeFirstUris:
                    if isinstance(readMeFirstUri,tuple):
                        # dtsName is for formula instances, but is from/to dts if versioning
                        dtsName, readMeFirstUri = readMeFirstUri
                    elif resultIsVersioningReport:
                        if inputDTSes: dtsName = "to"
                        else: dtsName = "from"
                    else:
                        dtsName = None
                    if resultIsVersioningReport and dtsName: # build multi-schemaRef containing document
                        if dtsName in inputDTSes:
                            dtsName = inputDTSes[dtsName]
                        else:
                            modelXbrl = ModelXbrl.create(self.modelXbrl.modelManager, 
                                         Type.DTSENTRIES,
                                         self.modelXbrl.modelManager.cntlr.webCache.normalizeUrl(readMeFirstUri[:-4] + ".dts", baseForElement),
                                         isEntry=True,
                                         errorCaptureLevel=errorCaptureLevel)
                        DTSdoc = modelXbrl.modelDocument
                        DTSdoc.inDTS = True
                        doc = modelDocumentLoad(modelXbrl, readMeFirstUri, base=baseForElement)
                        if doc is not None:
                            DTSdoc.referencesDocument[doc] = ModelDocumentReference("import", DTSdoc.xmlRootElement)  #fake import
                            doc.inDTS = True
                    elif resultIsTaxonomyPackage:
                        from arelle import PackageManager, PrototypeInstanceObject
                        dtsName = readMeFirstUri
                        modelXbrl = PrototypeInstanceObject.XbrlPrototype(self.modelXbrl.modelManager, readMeFirstUri)
                        PackageManager.packageInfo(self.modelXbrl.modelManager.cntlr, readMeFirstUri, reload=True, errors=modelXbrl.errors)
                    else: # not a multi-schemaRef versioning report
                        if self.useFileSource.isArchive:
                            modelXbrl = ModelXbrl.load(self.modelXbrl.modelManager, 
                                                       readMeFirstUri,
                                                       _("validating"), 
                                                       base=baseForElement,
                                                       useFileSource=self.useFileSource,
                                                       errorCaptureLevel=errorCaptureLevel)
                        else: # need own file source, may need instance discovery
                            filesource = FileSource.FileSource(readMeFirstUri, self.modelXbrl.modelManager.cntlr)
                            if filesource and not filesource.selection and filesource.isArchive:
                                for _archiveFile in filesource.dir: # find instance document in archive
                                    filesource.select(_archiveFile)
                                    if ModelDocument.Type.identify(filesource, filesource.url) in (ModelDocument.Type.INSTANCE, ModelDocument.Type.INLINEXBRL):
                                        break # use this selection
                            modelXbrl = ModelXbrl.load(self.modelXbrl.modelManager, 
                                                       filesource,
                                                       _("validating"), 
                                                       base=baseForElement,
                                                       errorCaptureLevel=errorCaptureLevel)
                        modelXbrl.isTestcaseVariation = True
                    if modelXbrl.modelDocument is None:
                        modelXbrl.error("arelle:notLoaded",
                             _("Testcase %(id)s %(name)s document not loaded: %(file)s"),
                             modelXbrl=testcase, id=modelTestcaseVariation.id, name=modelTestcaseVariation.name, file=os.path.basename(readMeFirstUri))
                        self.determineNotLoadedTestStatus(modelTestcaseVariation)
                        modelXbrl.close()
                    elif resultIsVersioningReport or resultIsTaxonomyPackage:
                        inputDTSes[dtsName] = modelXbrl
                    elif modelXbrl.modelDocument.type == Type.VERSIONINGREPORT:
                        ValidateVersReport.ValidateVersReport(self.modelXbrl).validate(modelXbrl)
                        self.determineTestStatus(modelTestcaseVariation, modelXbrl.errors)
                        modelXbrl.close()
                    elif testcase.type == Type.REGISTRYTESTCASE:
                        self.instValidator.validate(modelXbrl)  # required to set up dimensions, etc
                        self.instValidator.executeCallTest(modelXbrl, modelTestcaseVariation.id, 
                                   modelTestcaseVariation.cfcnCall, modelTestcaseVariation.cfcnTest)
                        self.determineTestStatus(modelTestcaseVariation, modelXbrl.errors)
                        self.instValidator.close()
                        modelXbrl.close()
                    else:
                        inputDTSes[dtsName].append(modelXbrl)
                        # validate except for formulas
                        _hasFormulae = modelXbrl.hasFormulae
                        modelXbrl.hasFormulae = False
                        try:
                            for pluginXbrlMethod in pluginClassMethods("TestcaseVariation.Xbrl.Loaded"):
                                pluginXbrlMethod(self.modelXbrl, modelXbrl, modelTestcaseVariation)
                            self.instValidator.validate(modelXbrl, parameters)
                            for pluginXbrlMethod in pluginClassMethods("TestcaseVariation.Xbrl.Validated"):
                                pluginXbrlMethod(self.modelXbrl, modelXbrl)
                        except Exception as err:
                            modelXbrl.error("exception:" + type(err).__name__,
                                _("Testcase variation validation exception: %(error)s, instance: %(instance)s"),
                                modelXbrl=modelXbrl, instance=modelXbrl.modelDocument.basename, error=err, exc_info=True)
                        modelXbrl.hasFormulae = _hasFormulae
                if resultIsVersioningReport and modelXbrl.modelDocument:
                    versReportFile = modelXbrl.modelManager.cntlr.webCache.normalizeUrl(
                        modelTestcaseVariation.versioningReportUri, baseForElement)
                    if os.path.exists(versReportFile): #validate existing
                        modelVersReport = ModelXbrl.load(self.modelXbrl.modelManager, versReportFile, _("validating existing version report"))
                        if modelVersReport and modelVersReport.modelDocument and modelVersReport.modelDocument.type == Type.VERSIONINGREPORT:
                            ValidateVersReport.ValidateVersReport(self.modelXbrl).validate(modelVersReport)
                            self.determineTestStatus(modelTestcaseVariation, modelVersReport.errors)
                            modelVersReport.close()
                    elif len(inputDTSes) == 2:
                        ModelVersReport.ModelVersReport(self.modelXbrl).diffDTSes(
                              versReportFile, inputDTSes["from"], inputDTSes["to"])
                        modelTestcaseVariation.status = "generated"
                    else:
                        modelXbrl.error("arelle:notLoaded",
                             _("Testcase %(id)s %(name)s DTSes not loaded, unable to generate versioning report: %(file)s"),
                             modelXbrl=testcase, id=modelTestcaseVariation.id, name=modelTestcaseVariation.name, file=os.path.basename(readMeFirstUri))
                        modelTestcaseVariation.status = "failed"
                    for inputDTS in inputDTSes.values():
                        inputDTS.close()
                    del inputDTSes # dereference
                elif resultIsTaxonomyPackage:
                    self.determineTestStatus(modelTestcaseVariation, modelXbrl.errors)
                    modelXbrl.close()
                elif inputDTSes:
                    # validate schema, linkbase, or instance
                    modelXbrl = inputDTSes[None][0]
                    for dtsName, inputDTS in inputDTSes.items():  # input instances are also parameters
                        if dtsName: # named instance
                            parameters[dtsName] = (None, inputDTS) #inputDTS is a list of modelXbrl's (instance DTSes)
                        elif len(inputDTS) > 1: # standard-input-instance with multiple instance documents
                            parameters[XbrlConst.qnStandardInputInstance] = (None, inputDTS) # allow error detection in validateFormula
                    if modelXbrl.hasTableRendering or modelTestcaseVariation.resultIsTable:
                        RenderingEvaluator.init(modelXbrl)
                    if modelXbrl.hasFormulae:
                        try:
                            # validate only formulae
                            self.instValidator.parameters = parameters
                            ValidateFormula.validate(self.instValidator)
                        except Exception as err:
                            modelXbrl.error("exception:" + type(err).__name__,
                                _("Testcase formula variation validation exception: %(error)s, instance: %(instance)s"),
                                modelXbrl=modelXbrl, instance=modelXbrl.modelDocument.basename, error=err, exc_info=True)
                    if modelTestcaseVariation.resultIsInfoset and self.modelXbrl.modelManager.validateInfoset:
                        for pluginXbrlMethod in pluginClassMethods("Validate.Infoset"):
                            pluginXbrlMethod(modelXbrl, modelTestcaseVariation.resultInfosetUri)
                        infoset = ModelXbrl.load(self.modelXbrl.modelManager, 
                                                 modelTestcaseVariation.resultInfosetUri,
                                                   _("loading result infoset"), 
                                                   base=baseForElement,
                                                   useFileSource=self.useFileSource,
                                                   errorCaptureLevel=errorCaptureLevel)
                        if infoset.modelDocument is None:
                            modelXbrl.error("arelle:notLoaded",
                                _("Testcase %(id)s %(name)s result infoset not loaded: %(file)s"),
                                modelXbrl=testcase, id=modelTestcaseVariation.id, name=modelTestcaseVariation.name, 
                                file=os.path.basename(modelTestcaseVariation.resultInfosetUri))
                            modelTestcaseVariation.status = "result infoset not loadable"
                        else:   # check infoset
                            ValidateInfoset.validate(self.instValidator, modelXbrl, infoset)
                        infoset.close()
                    if modelTestcaseVariation.resultIsTable: # and self.modelXbrl.modelManager.validateInfoset:
                        # diff (or generate) table infoset
                        resultTableUri = modelXbrl.modelManager.cntlr.webCache.normalizeUrl(modelTestcaseVariation.resultTableUri, baseForElement)
                        if not any(alternativeValidation(modelXbrl, resultTableUri)
                                   for alternativeValidation in pluginClassMethods("Validate.TableInfoset")):
                            ViewFileRenderedGrid.viewRenderedGrid(modelXbrl, resultTableUri, diffToFile=True)  # set diffToFile=False to save the infoset files instead of diffing
                    self.instValidator.close()
                    extraErrors = []
                    for pluginXbrlMethod in pluginClassMethods("TestcaseVariation.Validated"):
                        pluginXbrlMethod(self.modelXbrl, modelXbrl, extraErrors)
                    self.determineTestStatus(modelTestcaseVariation, [e for inputDTSlist in inputDTSes.values() for inputDTS in inputDTSlist for e in inputDTS.errors] + extraErrors) # include infoset errors in status
                    if modelXbrl.formulaOutputInstance and self.noErrorCodes(modelTestcaseVariation.actual): 
                        # if an output instance is created, and no string error codes, ignoring dict of assertion results, validate it
                        modelXbrl.formulaOutputInstance.hasFormulae = False #  block formulae on output instance (so assertion of input is not lost)
                        self.instValidator.validate(modelXbrl.formulaOutputInstance, modelTestcaseVariation.parameters)
                        self.determineTestStatus(modelTestcaseVariation, modelXbrl.formulaOutputInstance.errors)
                        if self.noErrorCodes(modelTestcaseVariation.actual): # if still 'clean' pass it forward for comparison to expected result instance
                            formulaOutputInstance = modelXbrl.formulaOutputInstance
                            modelXbrl.formulaOutputInstance = None # prevent it from being closed now
                        self.instValidator.close()
                    compareIxResultInstance = getattr(modelXbrl, "extractedInlineInstance", False) and modelTestcaseVariation.resultXbrlInstanceUri
                    if compareIxResultInstance:
                        formulaOutputInstance = modelXbrl # compare modelXbrl to generated output instance
                        errMsgPrefix = "ix"
                    else: # delete input instances before formula output comparison
                        for inputDTSlist in inputDTSes.values():
                            for inputDTS in inputDTSlist:
                                inputDTS.close()
                        del inputDTSes # dereference
                        errMsgPrefix = "formula"
                    if resultIsXbrlInstance and formulaOutputInstance and formulaOutputInstance.modelDocument:
                        expectedInstance = ModelXbrl.load(self.modelXbrl.modelManager, 
                                                   modelTestcaseVariation.resultXbrlInstanceUri,
                                                   _("loading expected result XBRL instance"), 
                                                   base=baseForElement,
                                                   useFileSource=self.useFileSource,
                                                   errorCaptureLevel=errorCaptureLevel)
                        if expectedInstance.modelDocument is None:
                            self.modelXbrl.error("{}:expectedResultNotLoaded".format(errMsgPrefix),
                                _("Testcase %(id)s %(name)s expected result instance not loaded: %(file)s"),
                                modelXbrl=testcase, id=modelTestcaseVariation.id, name=modelTestcaseVariation.name, 
                                file=os.path.basename(modelTestcaseVariation.resultXbrlInstanceUri),
                                messageCodes=("formula:expectedResultNotLoaded","ix:expectedResultNotLoaded"))
                            modelTestcaseVariation.status = "result not loadable"
                        else:   # compare facts
                            if len(expectedInstance.facts) != len(formulaOutputInstance.facts):
                                formulaOutputInstance.error("{}:resultFactCounts".format(errMsgPrefix),
                                    _("Formula output %(countFacts)s facts, expected %(expectedFacts)s facts"),
                                    modelXbrl=modelXbrl, countFacts=len(formulaOutputInstance.facts),
                                    expectedFacts=len(expectedInstance.facts),
                                    messageCodes=("formula:resultFactCounts","ix:resultFactCounts"))
                            else:
                                formulaOutputFootnotesRelSet = ModelRelationshipSet(formulaOutputInstance, "XBRL-footnotes")
                                expectedFootnotesRelSet = ModelRelationshipSet(expectedInstance, "XBRL-footnotes")
                                def factFootnotes(fact, footnotesRelSet):
                                    footnotes = []
                                    footnoteRels = footnotesRelSet.fromModelObject(fact)
                                    if footnoteRels:
                                        # must process rels in the same order between the two instances; use labels to sort
                                        for i, footnoteRel in enumerate(sorted(footnoteRels,
                                                                               key=lambda r: (r.fromLabel,r.toLabel))):
                                            modelObject = footnoteRel.toModelObject
                                            if isinstance(modelObject, ModelResource):
                                                footnotes.append("Footnote {}: {}".format(
                                                   i+1, # compare footnote with normalize-space
                                                   re.sub(r'\s+', ' ', collapseWhitespace(modelObject.stringValue))))
                                            elif isinstance(modelObject, ModelFact):
                                                footnotes.append("Footnoted fact {}: {} context: {} value: {}".format(
                                                    i+1,
                                                    modelObject.qname,
                                                    modelObject.contextID,
                                                    collapseWhitespace(modelObject.value)))
                                    return footnotes
                                for expectedInstanceFact in expectedInstance.facts:
                                    unmatchedFactsStack = []
                                    formulaOutputFact = formulaOutputInstance.matchFact(expectedInstanceFact, unmatchedFactsStack, deemP0inf=True)
                                    if formulaOutputFact is None:
                                        if unmatchedFactsStack: # get missing nested tuple fact, if possible
                                            missingFact = unmatchedFactsStack[-1]
                                        else:
                                            missingFact = expectedInstanceFact
                                        formulaOutputInstance.error("{}:expectedFactMissing".format(errMsgPrefix),
                                            _("Output missing expected fact %(fact)s"),
                                            modelXbrl=missingFact, fact=missingFact.qname,
                                            messageCodes=("formula:expectedFactMissing","ix:expectedFactMissing"))
                                    else: # compare footnotes
                                        expectedInstanceFactFootnotes = factFootnotes(expectedInstanceFact, expectedFootnotesRelSet)
                                        formulaOutputFactFootnotes = factFootnotes(formulaOutputFact, formulaOutputFootnotesRelSet)
                                        if expectedInstanceFactFootnotes != formulaOutputFactFootnotes:
                                            formulaOutputInstance.error("{}:expectedFactFootnoteDifference".format(errMsgPrefix),
                                                _("Output expected fact %(fact)s expected footnotes %(footnotes1)s produced footnotes %(footnotes2)s"),
                                                modelXbrl=(formulaOutputFact,expectedInstanceFact), fact=expectedInstanceFact.qname, footnotes1=expectedInstanceFactFootnotes, footnotes2=formulaOutputFactFootnotes,
                                                messageCodes=("formula:expectedFactFootnoteDifference","ix:expectedFactFootnoteDifference"))

                            # for debugging uncomment next line to save generated instance document
                            # formulaOutputInstance.saveInstance(r"c:\temp\test-out-inst.xml")
                        expectedInstance.close()
                        del expectedInstance # dereference
                        self.determineTestStatus(modelTestcaseVariation, formulaOutputInstance.errors)
                        formulaOutputInstance.close()
                        del formulaOutputInstance
                    if compareIxResultInstance:
                        for inputDTSlist in inputDTSes.values():
                            for inputDTS in inputDTSlist:
                                inputDTS.close()
                        del inputDTSes # dereference
                # update ui thread via modelManager (running in background here)
                self.modelXbrl.modelManager.viewModelObject(self.modelXbrl, modelTestcaseVariation.objectId())
                    
            self.modelXbrl.modelManager.showStatus(_("ready"), 2000)

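The footnote comparison in the example above collapses runs of whitespace before comparing strings (the re.sub(r'\s+', ' ', ...) call). A minimal standalone sketch of that normalize-space comparison, using hypothetical footnote strings:

import re

def normalize_space(text):
    # Collapse internal whitespace runs to single spaces and trim the ends,
    # mirroring XPath normalize-space() as used when diffing footnotes.
    return re.sub(r'\s+', ' ', text).strip()

# Hypothetical footnote texts that differ only in layout whitespace
expected = "Footnote 1:  reported\n  in thousands"
produced = "Footnote 1: reported in thousands"
assert normalize_space(expected) == normalize_space(produced)
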
Example 41

Project: Arelle
Source File: ValidateFilingDimensions.py
View license
def checkDimensions(val, drsELRs):
    
    fromConceptELRs = defaultdict(set)
    hypercubes = set()
    hypercubesInLinkrole = defaultdict(set)
    domainsInLinkrole = defaultdict(set)
    for ELR in drsELRs:
        domainMemberRelationshipSet = val.modelXbrl.relationshipSet( XbrlConst.domainMember, ELR)
                            
        # check Hypercubes in ELR, accumulate list of primary items
        positiveAxisTableSources = defaultdict(set)
        positiveHypercubes = set()
        primaryItems = set()
        for hasHypercubeArcrole in (XbrlConst.all, XbrlConst.notAll):
            hasHypercubeRelationships = val.modelXbrl.relationshipSet(
                             hasHypercubeArcrole, ELR).fromModelObjects()
            for hasHcRels in hasHypercubeRelationships.values():
                for hasHcRel in hasHcRels:
                    sourceConcept = hasHcRel.fromModelObject
                    primaryItems.add(sourceConcept)
                    hc = hasHcRel.toModelObject
                    hypercubes.add(hc)
                    if hasHypercubeArcrole == XbrlConst.all:
                        positiveHypercubes.add(hc)
                        if not hasHcRel.isClosed:
                            val.modelXbrl.error("SBR.NL.2.3.6.04",
                                _("All hypercube %(hypercube)s in DRS role %(linkrole)s, does not have closed='true'"),
                                modelObject=hasHcRel, hypercube=hc.qname, linkrole=ELR)
                    elif hasHypercubeArcrole == XbrlConst.notAll:
                        if hasHcRel.isClosed:
                            val.modelXbrl.error(("EFM.6.16.06", "GFM.1.08.06"),
                                _("Not all hypercube %(hypercube)s in DRS role %(linkrole)s, does not have closed='false'"),
                                modelObject=hasHcRel, hypercube=hc.qname, linkrole=ELR, primaryItem=sourceConcept.qname)
                        if hc in positiveHypercubes:
                            val.modelXbrl.error(("EFM.6.16.08", "GFM.1.08.08"),
                                _("Not all hypercube %(hypercube)s in DRS role %(linkrole)s, is also the target of a positive hypercube"),
                                modelObject=hasHcRel, hypercube=hc.qname, linkrole=ELR, linkroleDefinition=val.modelXbrl.roleTypeDefinition(ELR), primaryItem=sourceConcept.qname)
                    dimELR = hasHcRel.targetRole
                    dimTargetRequired = (dimELR is not None)
                    if not dimELR:
                        dimELR = ELR
                    hypercubesInLinkrole[dimELR].add(hc) # this is the elr containing the HC-dim relations
                    hcDimRels = val.modelXbrl.relationshipSet(
                             XbrlConst.hypercubeDimension, dimELR).fromModelObject(hc)
                    if dimTargetRequired and len(hcDimRels) == 0:
                        val.modelXbrl.error(("EFM.6.16.09", "GFM.1.08.09"),
                            _("Table %(hypercube)s in DRS role %(linkrole)s, missing targetrole consecutive relationship"),
                            modelObject=hasHcRel, hypercube=hc.qname, fromConcept=sourceConcept.qname, toConcept=hc.qname, 
                            linkrole=ELR, linkroleDefinition=val.modelXbrl.roleTypeDefinition(ELR),
                            arcroleURI=hasHcRel.arcrole, arcrole=os.path.basename(hasHcRel.arcrole))
                    for hcDimRel in hcDimRels:
                        dim = hcDimRel.toModelObject
                        if isinstance(dim, ModelConcept):
                            domELR = hcDimRel.targetRole
                            domTargetRequired = (domELR is not None)
                            if not domELR:
                                if dim.isExplicitDimension:
                                    domELR = dimELR
                                    if val.validateSBRNL:
                                        val.modelXbrl.error("SBR.NL.2.3.5.04",
                                            _("Hypercube %(hypercube)s in DRS role %(linkrole)s, missing targetrole to dimension %(dimension)s consecutive relationship"),
                                            modelObject=hcDimRel, hypercube=hc.qname, linkrole=ELR, dimension=dim.qname)
                            else:
                                if dim.isTypedDimension and val.validateSBRNL:
                                    val.modelXbrl.error("SBR.NL.2.3.5.07",
                                        _("Typed dimension %(dimension)s in DRS role %(linkrole)s, has targetrole consecutive relationship"),
                                        modelObject=hcDimRel, dimension=dim.qname, linkrole=ELR)
                            if hasHypercubeArcrole == XbrlConst.all:
                                positiveAxisTableSources[dim].add(sourceConcept)
                            elif hasHypercubeArcrole == XbrlConst.notAll and \
                                 (dim not in positiveAxisTableSources or \
                                  not commonAncestor(domainMemberRelationshipSet,
                                                  sourceConcept, positiveAxisTableSources[dim])):
                                val.modelXbrl.error(("EFM.6.16.07", "GFM.1.08.08"),
                                    _("Negative table axis %(dimension)s in DRS role %(linkrole)s, not in any positive table in same role"),
                                     modelObject=hcDimRel, dimension=dim.qname, linkrole=ELR, linkroleDefinition=val.modelXbrl.roleTypeDefinition(ELR), primaryItem=sourceConcept.qname)
                            dimDomRels = val.modelXbrl.relationshipSet(
                                 XbrlConst.dimensionDomain, domELR).fromModelObject(dim)   
                            if domTargetRequired and len(dimDomRels) == 0:
                                val.modelXbrl.error(("EFM.6.16.09", "GFM.1.08.09"),
                                    _("Axis %(dimension)s in DRS role %(linkrole)s, missing targetrole consecutive relationship"),
                                    modelObject=hcDimRel, dimension=dim.qname, fromConcept=hc.qname, toConcept=dim.qname, linkrole=ELR, linkroleDefinition=val.modelXbrl.roleTypeDefinition(ELR), arcroleURI=hcDimRel.arcrole, arcrole=os.path.basename(hcDimRel.arcrole))
                            if val.validateEFMorGFM:
                                # flatten DRS member relationships in ELR for undirected cycle detection
                                drsRelsFrom = defaultdict(list)
                                drsRelsTo = defaultdict(list)
                                getDrsRels(val, domELR, dimDomRels, ELR, drsRelsFrom, drsRelsTo)
                                # check for cycles
                                fromConceptELRs[hc].add(dimELR)
                                fromConceptELRs[dim].add(domELR)
                                cycleCausingConcept = undirectedFwdCycle(val, domELR, dimDomRels, ELR, drsRelsFrom, drsRelsTo, fromConceptELRs)
                                if cycleCausingConcept is not None:
                                    cycleCausingConcept.append(hcDimRel)
                                    val.modelXbrl.error(("EFM.6.16.04", "GFM.1.08.04"),
                                        _("Dimension relationships have an undirected cycle in DRS role %(linkrole)s \nstarting from table %(hypercube)s, \naxis %(dimension)s, \npath %(path)s"),
                                        modelObject=[hc, dim] + [rel for rel in cycleCausingConcept if not isinstance(rel, bool)], 
                                        linkrole=ELR, linkroleDefinition=val.modelXbrl.roleTypeDefinition(ELR), 
                                        hypercube=hc.qname, dimension=dim.qname, conceptFrom=dim.qname,
                                        path=cyclePath(hc,cycleCausingConcept))
                                fromConceptELRs.clear()
                            elif val.validateSBRNL:
                                checkSBRNLMembers(val, hc, dim, domELR, dimDomRels, ELR, True)
                                for dimDomRel in dimDomRels:
                                    dom = dimDomRel.toModelObject
                                    if isinstance(dom, ModelConcept):
                                        domainsInLinkrole[domELR].add(dom) # this is the elr containing the HC-dim relations
                if hasHypercubeArcrole == XbrlConst.all and len(hasHcRels) > 1:
                    val.modelXbrl.error(("EFM.6.16.05", "GFM.1.08.05"),
                        _("Multiple tables (%(hypercubeCount)s) DRS role %(linkrole)s, source %(concept)s, only 1 allowed"),
                        modelObject=[sourceConcept] + hasHcRels, 
                        hypercubeCount=len(hasHcRels), linkrole=ELR, linkroleDefinition=val.modelXbrl.roleTypeDefinition(ELR),
                        concept=sourceConcept.qname,
                        hypercubes=', '.join(str(r.toModelObject.qname) for r in hasHcRels if isinstance(r.toModelObject, ModelConcept)))
                    
        # check for primary item dimension-member graph undirected cycles
        fromRelationships = domainMemberRelationshipSet.fromModelObjects()
        for relFrom, rels in fromRelationships.items():
            if relFrom in primaryItems:
                drsRelsFrom = defaultdict(list)
                drsRelsTo = defaultdict(list)
                getDrsRels(val, ELR, rels, ELR, drsRelsFrom, drsRelsTo)
                fromConceptELRs[relFrom].add(ELR)
                cycleCausingConcept = undirectedFwdCycle(val, ELR, rels, ELR, drsRelsFrom, drsRelsTo, fromConceptELRs)
                if cycleCausingConcept is not None:
                    val.modelXbrl.error(("EFM.6.16.04", "GFM.1.08.04"),
                        _("Domain-member primary-item relationships have an undirected cycle in DRS role %(linkrole)s \nstarting from %(conceptFrom)s, \npath %(path)s"),
                        modelObject=[relFrom] + [rel for rel in cycleCausingConcept if not isinstance(rel, bool)], 
                        linkrole=ELR, conceptFrom=relFrom.qname, path=cyclePath(relFrom, cycleCausingConcept))
                fromConceptELRs.clear()
            for rel in rels:
                fromMbr = rel.fromModelObject
                toMbr = rel.toModelObject
                toELR = rel.targetRole
                if isinstance(toMbr, ModelConcept) and toELR and len(
                    val.modelXbrl.relationshipSet(
                         XbrlConst.domainMember, toELR).fromModelObject(toMbr)) == 0:
                    val.modelXbrl.error(("EFM.6.16.09", "GFM.1.08.09"),
                        _("Domain member %(concept)s in DRS role %(linkrole)s, missing targetrole consecutive relationship"),
                        modelObject=rel, concept=fromMbr.qname, fromConcept=toMbr.qname, toConcept=fromMbr.qname, linkrole=ELR, linkroleDefinition=val.modelXbrl.roleTypeDefinition(ELR), arcroleURI=rel.arcrole, arcrole=os.path.basename(rel.arcrole))
                    
    if val.validateSBRNL:
        # check hypercubes for unique set of members
        for hc in hypercubes:
            for priHcRel in val.modelXbrl.relationshipSet(XbrlConst.all).toModelObject(hc):
                priItem = priHcRel.fromModelObject
                ELR = priHcRel.linkrole
                checkSBRNLMembers(val, hc, priItem, ELR, 
                                  val.modelXbrl.relationshipSet(XbrlConst.domainMember, ELR).fromModelObject(priItem), 
                                  ELR, False)
                if priHcRel.contextElement == 'segment':  
                    val.modelXbrl.error("SBR.NL.2.3.5.06",
                        _("hypercube %(hypercube)s in segment not allowed, ELR role %(linkrole)s"),
                        modelObject=priHcRel, linkrole=ELR, hypercube=hc.qname)
        for notAllRel in val.modelXbrl.relationshipSet(XbrlConst.notAll).modelRelationships:
            val.modelXbrl.error("SBR.NL.2.3.5.05",
                _("Notall from primary item %(primaryItem)s in ELR role %(linkrole)s to %(hypercube)s"),
                modelObject=val.modelXbrl, primaryItem=notAllRel.fromModelObject.qname, linkrole=notAllRel.linkrole, hypercube=notAllRel.toModelObject.qname)
        for ELR, hypercubes in hypercubesInLinkrole.items():
            '''removed RH 2011-12-06
            for modelRel in val.modelXbrl.relationshipSet("XBRL-dimensions", ELR).modelRelationships:
                if modelRel.fromModelObject != hc:
                    val.modelXbrl.error("SBR.NL.2.3.5.03",
                        _("ELR role %(linkrole)s, is not dedicated to %(hypercube)s, but also has %(otherQname)s"),
                        modelObject=val.modelXbrl, linkrole=ELR, hypercube=hc.qname, otherQname=modelRel.fromModelObject.qname)
            '''
            domains = domainsInLinkrole.get(ELR, emptySet)
            for hc in hypercubes:  # only one member
                for arcrole in (XbrlConst.parentChild, "XBRL-dimensions"):
                    for modelRel in val.modelXbrl.relationshipSet(arcrole, ELR).modelRelationships:
                        if (modelRel.fromModelObject != hc and modelRel.toModelObject != hc and
                            modelRel.fromModelObject not in domains and modelRel.toModelObject not in domains):
                            val.modelXbrl.error("SBR.NL.2.2.3.05",
                                _("ELR role %(linkrole)s, has hypercube %(hypercube)s and a %(arcrole)s relationship not involving the hypercube or primary domain, from %(fromConcept)s to %(toConcept)s"),
                                modelObject=modelRel, linkrole=ELR, hypercube=hc.qname, arcrole=os.path.basename(modelRel.arcrole), 
                                fromConcept=modelRel.fromModelObject.qname, 
                                toConcept=(modelRel.toModelObject.qname if isinstance(modelRel.toModelObject, ModelConcept) else "unknown"))
        domainsInLinkrole = defaultdict(set)
        dimDomMemsByLinkrole = defaultdict(set)
        for rel in val.modelXbrl.relationshipSet(XbrlConst.dimensionDomain).modelRelationships:
            relFrom = rel.fromModelObject
            relTo = rel.toModelObject
            if isinstance(relFrom, ModelConcept) and isinstance(relTo, ModelConcept):
                domainsInLinkrole[rel.targetRole].add(relFrom)
                domMems = set() # determine usable dom and mems of dimension in this linkrole
                if rel.isUsable:
                    domMems.add(relTo)
                for relMem in val.modelXbrl.relationshipSet(XbrlConst.domainMember, (rel.targetRole or rel.linkrole)).fromModelObject(relTo):
                    if relMem.isUsable:
                        domMems.add(relMem.toModelObject)
                dimDomMemsByLinkrole[(rel.linkrole,relFrom)].update(domMems)
                if rel.isUsable and val.modelXbrl.relationshipSet(XbrlConst.domainMember, rel.targetRole).fromModelObject(relTo):
                    val.modelXbrl.error("SBR.NL.2.3.7.05",
                        _("Dimension %(dimension)s in DRS role %(linkrole)s, has usable domain with members %(domain)s"),
                        modelObject=rel, dimension=relFrom.qname, linkrole=rel.linkrole, domain=relTo.qname)
                if not relTo.isAbstract:
                    val.modelXbrl.error("SBR.NL.2.3.7.02",
                        _("Dimension %(dimension)s in DRS role %(linkrole)s, has nonAbsract domain %(domain)s"),
                        modelObject=rel, dimension=relFrom.qname, linkrole=rel.linkrole, domain=relTo.qname)
                if relTo.substitutionGroupQname.localName not in ("domainItem","domainMemberItem"):
                    val.modelXbrl.error("SBR.NL.2.2.2.19",
                        _("Domain item %(domain)s in DRS role %(linkrole)s, in dimension %(dimension)s is not a domainItem"),
                        modelObject=rel, domain=relTo.qname, linkrole=rel.linkrole, dimension=relFrom.qname)
                if not rel.targetRole and relTo.substitutionGroupQname.localName == "domainItem":
                    val.modelXbrl.error("SBR.NL.2.3.6.03",
                        _("Dimension %(dimension)s in DRS role %(linkrole)s, missing targetrole to consecutive domain relationship"),
                        modelObject=rel, dimension=relFrom.qname, linkrole=rel.linkrole)
        for linkrole, domains in domainsInLinkrole.items():
            if linkrole and len(domains) > 1:
                val.modelXbrl.error("SBR.NL.2.3.7.04",
                    _("Linkrole %(linkrole)s, has multiple domains %(domains)s"),
                    modelObject=val.modelXbrl, linkrole=linkrole, domains=", ".join([str(dom.qname) for dom in domains]))
        del domainsInLinkrole   # dereference
        linkrolesByDimDomMems = defaultdict(set)
        for linkroleDim, domMems in dimDomMemsByLinkrole.items():
            linkrole, dim = linkroleDim
            linkrolesByDimDomMems[(dim,tuple(domMems))].add(linkrole)
        for dimDomMems, linkroles in linkrolesByDimDomMems.items():
            if len(linkroles) > 1:
                val.modelXbrl.error("SBR.NL.2.3.6.02",
                    _("Dimension %(dimension)s  usable members same in linkroles %(linkroles)s"),
                    modelObject=val.modelXbrl, dimension=dimDomMems[0].qname, linkroles=', '.join(l for l in linkroles))
        del dimDomMemsByLinkrole, linkrolesByDimDomMems
        for rel in val.modelXbrl.relationshipSet(XbrlConst.domainMember).modelRelationships:
            if val.modelXbrl.relationshipSet(XbrlConst.domainMember, rel.targetRole).fromModelObject(rel.toModelObject):
                val.modelXbrl.error("SBR.NL.2.3.7.03",
                    _("Domain member %(member)s in DRS role %(linkrole)s, has nested members"),
                    modelObject=rel, member=(rel.toModelObject.qname if isinstance(rel.toModelObject, ModelConcept) else None), linkrole=rel.linkrole)
        for rel in val.modelXbrl.relationshipSet(XbrlConst.domainMember).modelRelationships:
            relFrom = rel.fromModelObject
            relTo = rel.toModelObject
            if isinstance(relTo, ModelConcept):
                # avoid primary item relationships in these tests
                if relFrom.substitutionGroupQname.localName == "domainItem":
                    if relTo.substitutionGroupQname.localName != "domainMemberItem":
                        val.modelXbrl.error("SBR.NL.2.2.2.19",
                            _("Domain member item %(member)s in DRS role %(linkrole)s is not a domainMemberItem"),
                            modelObject=rel, member=relTo.qname, linkrole=rel.linkrole)
                else:
                    if relTo.substitutionGroupQname.localName == "domainMemberItem":
                        val.modelXbrl.error("SBR.NL.2.2.2.19",
                            _("Domain item %(domain)s in DRS role %(linkrole)s is not a domainItem"),
                            modelObject=rel, domain=relFrom.qname, linkrole=rel.linkrole)
                        break # don't repeat parent's error on rest of child members
                    elif relFrom.isAbstract and relFrom.substitutionGroupQname.localName != "primaryDomainItem":
                        val.modelXbrl.error("SBR.NL.2.2.2.19",
                            _("Abstract domain item %(domain)s in DRS role %(linkrole)s is not a primaryDomainItem"),
                            modelObject=rel, domain=relFrom.qname, linkrole=rel.linkrole)
                        break # don't repeat parent's error on rest of child members
        hypercubeDRSDimensions = defaultdict(dict)
        for hcDimRel in val.modelXbrl.relationshipSet(XbrlConst.hypercubeDimension).modelRelationships:
            hc = hcDimRel.fromModelObject
            if isinstance(hc, ModelConcept):
                ELR = hcDimRel.linkrole
                try:
                    hcDRSdims = hypercubeDRSDimensions[hc][ELR]
                except KeyError:
                    hcDRSdims = set()
                    hypercubeDRSDimensions[hc][ELR] = hcDRSdims
                hcDRSdims.add(hcDimRel.toModelObject)
        for hc, DRSdims in hypercubeDRSDimensions.items():
            hcELRdimSets = {}
            for ELR, mutableDims in DRSdims.items():
                dims = frozenset(mutableDims)
                if dims not in hcELRdimSets:
                    hcELRdimSets[dims] = ELR
                else: 
                    val.modelXbrl.error("SBR.NL.2.3.5.02",
                        _("Hypercube %(hypercube)s has same dimensions in ELR roles %(linkrole)s and %(linkrole2)s: %(dimensions)s"),
                        modelObject=hc, hypercube=hc.qname, linkrole=ELR, linkrole2=hcELRdimSets[dims],
                        dimensions=", ".join([str(dim.qname) for dim in dims]))
        del hypercubeDRSDimensions # dereference

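A side note on the example above: os.path.basename is applied to arcrole URIs rather than filesystem paths, to shorten them for error messages. That works because basename simply returns everything after the last '/', which both POSIX and Windows path handling treat as a separator. A minimal sketch with a representative arcrole URI:

import os.path

# basename keeps only the last path segment, even for URI-shaped strings.
arcrole = 'http://xbrl.org/int/dim/arcrole/hypercube-dimension'
print(os.path.basename(arcrole))  # -> 'hypercube-dimension'
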
Example 42

Project: nansat
Source File: mapper_sentinel1_l1.py
View license
    def __init__(self, fileName, gdalDataset, gdalMetadata,
                 manifestonly=False, **kwargs):

        if zipfile.is_zipfile(fileName):
            zz = zipfile.PyZipFile(fileName)
            # Assuming the file names are consistent, the polarization
            # dependent data should be sorted equally such that we can use the
            # same indices consistently for all the following lists
            # THIS IS NOT THE CASE...
            mdsFiles = ['/vsizip/%s/%s' % (fileName, fn)
                        for fn in zz.namelist() if 'measurement/s1' in fn]
            calFiles = ['/vsizip/%s/%s' % (fileName, fn)
                        for fn in zz.namelist()
                        if 'annotation/calibration/calibration-s1' in fn]
            noiseFiles = ['/vsizip/%s/%s' % (fileName, fn)
                          for fn in zz.namelist()
                          if 'annotation/calibration/noise-s1' in fn]
            annotationFiles = ['/vsizip/%s/%s' % (fileName, fn)
                               for fn in zz.namelist()
                               if 'annotation/s1' in fn]
            manifestFile = ['/vsizip/%s/%s' % (fileName, fn)
                            for fn in zz.namelist()
                            if 'manifest.safe' in fn]
            zz.close()
        else:
            mdsFiles = glob.glob('%s/measurement/s1*' % fileName)
            calFiles = glob.glob('%s/annotation/calibration/calibration-s1*'
                                 % fileName)
            noiseFiles = glob.glob('%s/annotation/calibration/noise-s1*'
                                   % fileName)
            annotationFiles = glob.glob('%s/annotation/s1*'
                                        % fileName)
            manifestFile = glob.glob('%s/manifest.safe' % fileName)

        if (not mdsFiles or not calFiles or not noiseFiles or
                not annotationFiles or not manifestFile):
            raise WrongMapperError

        mdsDict = {}
        for ff in mdsFiles:
            mdsDict[
                os.path.splitext(os.path.basename(ff))[0].split('-')[3]] = ff

        self.calXMLDict = {}
        for ff in calFiles:
            self.calXMLDict[
                os.path.splitext(
                os.path.basename(ff))[0].split('-')[4]] = self.read_xml(ff)

        self.noiseXMLDict = {}
        for ff in noiseFiles:
            self.noiseXMLDict[
                os.path.splitext(
                os.path.basename(ff))[0].split('-')[4]] = self.read_xml(ff)

        self.annotationXMLDict = {}
        for ff in annotationFiles:
            self.annotationXMLDict[
                os.path.splitext(
                os.path.basename(ff))[0].split('-')[3]] = self.read_xml(ff)

        self.manifestXML = self.read_xml(manifestFile[0])

        if os.path.split(fileName)[1][:3] not in ['S1A', 'S1B']:
            raise WrongMapperError('Not Sentinel 1A or 1B')

        missionName = {'S1A': 'SENTINEL-1A', 'S1B': 'SENTINEL-1B'}[
            os.path.split(fileName)[1][:3]]

        # very fast constructor without any bands
        if manifestonly:
            self.init_from_manifest_only(self.manifestXML,
                                         self.annotationXMLDict[
                                         self.annotationXMLDict.keys()[0]],
                                         missionName)
            return

        gdalDatasets = {}
        for key in mdsDict.keys():
            # Open data files
            gdalDatasets[key] = gdal.Open(mdsDict[key])

        if not gdalDatasets:
            raise WrongMapperError('No Sentinel-1 datasets found')

        # Check metadata to confirm it is Sentinel-1 L1
        metadata = gdalDatasets[mdsDict.keys()[0]].GetMetadata()
        
        if 'TIFFTAG_IMAGEDESCRIPTION' not in metadata:
            raise WrongMapperError
        if ('Sentinel-1' not in metadata['TIFFTAG_IMAGEDESCRIPTION']
                and 'L1' not in metadata['TIFFTAG_IMAGEDESCRIPTION']):
            raise WrongMapperError

        warnings.warn('Sentinel-1 level-1 mapper is not yet adapted to '
                      'complex data. In addition, the band names should be '
                      'updated for multi-swath data - '
                      'and there might be other issues.')

        # create empty VRT dataset with geolocation only
        for key in gdalDatasets:
            VRT.__init__(self, gdalDatasets[key])
            break

        # Read annotation, noise and calibration xml-files
        pol = {}
        it = 0
        for key in self.annotationXMLDict:
            xml = Node.create(self.annotationXMLDict[key])
            pol[key] = (xml.node('product').
                        node('adsHeader')['polarisation'].upper())
            it += 1
            if it == 1:
                # Get incidence angle
                pi = xml.node('generalAnnotation').node('productInformation')

                self.dataset.SetMetadataItem('ORBIT_DIRECTION',
                                              str(pi['pass']))
                (X, Y, lon, lat, inc, ele, numberOfSamples,
                numberOfLines) = self.read_geolocation_lut(
                                                self.annotationXMLDict[key])

                X = np.unique(X)
                Y = np.unique(Y)

                lon = np.array(lon).reshape(len(Y), len(X))
                lat = np.array(lat).reshape(len(Y), len(X))
                inc = np.array(inc).reshape(len(Y), len(X))
                ele = np.array(ele).reshape(len(Y), len(X))

                incVRT = VRT(array=inc, lat=lat, lon=lon)
                eleVRT = VRT(array=ele, lat=lat, lon=lon)
                incVRT = incVRT.get_resized_vrt(self.dataset.RasterXSize,
                                                self.dataset.RasterYSize,
                                                eResampleAlg=2)
                eleVRT = eleVRT.get_resized_vrt(self.dataset.RasterXSize,
                                                self.dataset.RasterYSize,
                                                eResampleAlg=2)
                self.bandVRTs['incVRT'] = incVRT
                self.bandVRTs['eleVRT'] = eleVRT

        for key in self.calXMLDict:
            calibration_LUT_VRTs, longitude, latitude = (
                self.get_LUT_VRTs(self.calXMLDict[key],
                                  'calibrationVectorList',
                                  ['sigmaNought', 'betaNought',
                                   'gamma', 'dn']
                                  ))
            self.bandVRTs['LUT_sigmaNought_VRT_'+pol[key]] = (
                calibration_LUT_VRTs['sigmaNought'].
                get_resized_vrt(self.dataset.RasterXSize,
                                self.dataset.RasterYSize,
                                eResampleAlg=1))
            self.bandVRTs['LUT_betaNought_VRT_'+pol[key]] = (
                calibration_LUT_VRTs['betaNought'].
                get_resized_vrt(self.dataset.RasterXSize,
                                self.dataset.RasterYSize,
                                eResampleAlg=1))
            self.bandVRTs['LUT_gamma_VRT'] = calibration_LUT_VRTs['gamma']
            self.bandVRTs['LUT_dn_VRT'] = calibration_LUT_VRTs['dn']

        for key in self.noiseXMLDict:
            noise_LUT_VRT = self.get_LUT_VRTs(self.noiseXMLDict[key],
                                              'noiseVectorList',
                                              ['noiseLut'])[0]
            self.bandVRTs['LUT_noise_VRT_'+pol[key]] = (
                noise_LUT_VRT['noiseLut'].get_resized_vrt(
                    self.dataset.RasterXSize,
                    self.dataset.RasterYSize,
                    eResampleAlg=1))

        metaDict = []
        bandNumberDict = {}
        bnmax = 0
        for key in gdalDatasets.keys():
            dsPath, dsName = os.path.split(mdsDict[key])
            name = 'DN_%s' % pol[key]
            # A dictionary of band numbers is needed for the pixel function
            # bands further down. This is not the best solution. It would be
            # better to have a function in VRT that returns the number given a
            # band name. This function exists in Nansat but could perhaps be
            # moved to VRT? The existing nansat function could just call the
            # VRT one...
            bandNumberDict[name] = bnmax + 1
            bnmax = bandNumberDict[name]
            band = gdalDatasets[key].GetRasterBand(1)
            dtype = band.DataType
            metaDict.append({
                'src': {
                    'SourceFilename': mdsDict[key],
                    'SourceBand': 1,
                    'DataType': dtype,
                },
                'dst': {
                    'name': name,
                    #'SourceTransferType': gdal.GetDataTypeName(dtype),
                    #'dataType': 6,
                },
            })
        # add bands with metadata and corresponding values to the empty VRT
        self._create_bands(metaDict)

        '''
        Calibration should be performed as

        s0 = DN^2/sigmaNought^2,

        where sigmaNought is from e.g.
        annotation/calibration/calibration-s1a-iw-grd-hh-20140811t151231-20140811t151301-001894-001cc7-001.xml,
        and DN is the Digital Numbers in the tiff files.

        Also the noise should be subtracted.

        See
        https://sentinel.esa.int/web/sentinel/sentinel-1-sar-wiki/-/wiki/Sentinel%20One/Application+of+Radiometric+Calibration+LUT
        '''
        # Get look direction
        sat_heading = initial_bearing(longitude[:-1, :],
                                      latitude[:-1, :],
                                      longitude[1:, :],
                                      latitude[1:, :])
        look_direction = scipy.ndimage.interpolation.zoom(
            np.mod(sat_heading + 90, 360),
            (np.shape(longitude)[0] / (np.shape(longitude)[0]-1.), 1))

        # Decompose, to avoid interpolation errors around 0 <-> 360
        look_direction_u = np.sin(np.deg2rad(look_direction))
        look_direction_v = np.cos(np.deg2rad(look_direction))
        look_u_VRT = VRT(array=look_direction_u,
                         lat=latitude, lon=longitude)
        look_v_VRT = VRT(array=look_direction_v,
                         lat=latitude, lon=longitude)
        lookVRT = VRT(lat=latitude, lon=longitude)
        lookVRT._create_band([{'SourceFilename': look_u_VRT.fileName,
                               'SourceBand': 1},
                              {'SourceFilename': look_v_VRT.fileName,
                               'SourceBand': 1}],
                             {'PixelFunctionType': 'UVToDirectionTo'}
                             )

        # Blow up to full size
        lookVRT = lookVRT.get_resized_vrt(self.dataset.RasterXSize,
                                          self.dataset.RasterYSize,
                                          eResampleAlg=1)

        # Store VRTs so that they are accessible later
        self.bandVRTs['look_u_VRT'] = look_u_VRT
        self.bandVRTs['look_v_VRT'] = look_v_VRT
        self.bandVRTs['lookVRT'] = lookVRT

        metaDict = []
        # Add bands to full size VRT
        for key in pol:
            name = 'LUT_sigmaNought_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append(
                {'src': {'SourceFilename':
                         (self.bandVRTs['LUT_sigmaNought_VRT_' +
                          pol[key]].fileName),
                         'SourceBand': 1
                         },
                 'dst': {'name': name
                         }
                 })
            name = 'LUT_noise_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append({
                'src': {
                    'SourceFilename': self.bandVRTs['LUT_noise_VRT_' +
                                                   pol[key]].fileName,
                    'SourceBand': 1
                },
                'dst': {
                    'name': name
                }
            })

        name = 'look_direction'
        bandNumberDict[name] = bnmax+1
        bnmax = bandNumberDict[name]
        metaDict.append({
            'src': {
                'SourceFilename': self.bandVRTs['lookVRT'].fileName,
                'SourceBand': 1
            },
            'dst': {
                'wkv': 'sensor_azimuth_angle',
                'name': name
            }
        })

        for key in gdalDatasets.keys():
            dsPath, dsName = os.path.split(mdsDict[key])
            name = 'sigma0_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append(
                {'src': [{'SourceFilename': self.fileName,
                          'SourceBand': bandNumberDict['DN_%s' % pol[key]],
                          },
                         {'SourceFilename':
                          (self.bandVRTs['LUT_sigmaNought_VRT_%s'
                           % pol[key]].fileName),
                          'SourceBand': 1
                          }
                         ],
                 'dst': {'wkv': 'surface_backwards_scattering_coefficient_of_radar_wave',
                         'PixelFunctionType': 'Sentinel1Calibration',
                         'polarization': pol[key],
                         'suffix': pol[key],
                         },
                 })
            name = 'beta0_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append(
                {'src': [{'SourceFilename': self.fileName,
                          'SourceBand': bandNumberDict['DN_%s' % pol[key]]
                          },
                         {'SourceFilename':
                          (self.bandVRTs['LUT_betaNought_VRT_%s'
                           % pol[key]].fileName),
                          'SourceBand': 1
                          }
                         ],
                 'dst': {'wkv': 'surface_backwards_brightness_coefficient_of_radar_wave',
                         'PixelFunctionType': 'Sentinel1Calibration',
                         'polarization': pol[key],
                         'suffix': pol[key],
                         },
                 })

        self._create_bands(metaDict)

        # Add incidence angle as band
        name = 'incidence_angle'
        bandNumberDict[name] = bnmax+1
        bnmax = bandNumberDict[name]
        src = {'SourceFilename': self.bandVRTs['incVRT'].fileName,
               'SourceBand': 1}
        dst = {'wkv': 'angle_of_incidence',
               'name': name}
        self._create_band(src, dst)
        self.dataset.FlushCache()

        # Add elevation angle as band
        name = 'elevation_angle'
        bandNumberDict[name] = bnmax+1
        bnmax = bandNumberDict[name]
        src = {'SourceFilename': self.bandVRTs['eleVRT'].fileName,
               'SourceBand': 1}
        dst = {'wkv': 'angle_of_elevation',
               'name': name}
        self._create_band(src, dst)
        self.dataset.FlushCache()

        # Add sigma0_VV
        pp = [pol[key] for key in pol]
        if 'VV' not in pp and 'HH' in pp:
            name = 'sigma0_VV'
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            src = [{'SourceFilename': self.fileName,
                    'SourceBand': bandNumberDict['DN_HH'],
                    },
                   {'SourceFilename': (self.bandVRTs['LUT_noise_VRT_HH'].
                                       fileName),
                    'SourceBand': 1
                    },
                   {'SourceFilename': (self.bandVRTs['LUT_sigmaNought_VRT_HH'].
                                       fileName),
                    'SourceBand': 1,
                    },
                   {'SourceFilename': self.bandVRTs['incVRT'].fileName,
                    'SourceBand': 1}
                   ]
            dst = {'wkv': 'surface_backwards_scattering_coefficient_of_radar_wave',
                   'PixelFunctionType': 'Sentinel1Sigma0HHToSigma0VV',
                   'polarization': 'VV',
                   'suffix': 'VV'}
            self._create_band(src, dst)
            self.dataset.FlushCache()

        # set time as acquisition start time
        n = Node.create(self.manifestXML)
        meta = n.node('metadataSection')
        for nn in meta.children:
            if nn.getAttribute('ID') == u'acquisitionPeriod':
                # set valid time
                self.dataset.SetMetadataItem(
                    'time_coverage_start',
                    parse((nn.node('metadataWrap').
                           node('xmlData').
                           node('safe:acquisitionPeriod')['safe:startTime'])
                          ).isoformat())
                self.dataset.SetMetadataItem(
                    'time_coverage_end',
                    parse((nn.node('metadataWrap').
                           node('xmlData').
                           node('safe:acquisitionPeriod')['safe:stopTime'])
                          ).isoformat())

        # Get dictionary describing the instrument and platform according to
        # the GCMD keywords
        mm = pti.get_gcmd_instrument('sar')
        ee = pti.get_gcmd_platform(missionName)

        # TODO: Validate that the found instrument and platform are indeed what we
        # want....

        self.dataset.SetMetadataItem('instrument', json.dumps(mm))
        self.dataset.SetMetadataItem('platform', json.dumps(ee))

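
The mapper above keys its file dictionaries by a field sliced out of the file name, via os.path.splitext(os.path.basename(ff))[0].split('-'). A minimal sketch of that keying, using a hypothetical measurement file name patterned after the calibration file cited in the docstring:

import os.path

# Hypothetical Sentinel-1 measurement file; name fields are '-'-separated.
ff = '/data/measurement/s1a-iw-grd-hh-20140811t151231-20140811t151301-001894-001cc7-001.tiff'
stem = os.path.splitext(os.path.basename(ff))[0]
print(stem.split('-')[3])  # -> 'hh', the polarization used as the dict key

The calibration and noise file names carry one extra leading field ('calibration-' or 'noise-'), which is why the code indexes [4] for those and [3] for the measurement and annotation files.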
Example 43

Project: nansat
Source File: mapper_sentinel1_l1.py
View license
    def __init__(self, fileName, gdalDataset, gdalMetadata,
                 manifestonly=False, **kwargs):

        if zipfile.is_zipfile(fileName):
            zz = zipfile.PyZipFile(fileName)
            # Assuming the file names are consistent, the polarization
            # dependent data should be sorted equally such that we can use the
            # same indices consistently for all the following lists
            # THIS IS NOT THE CASE...
            mdsFiles = ['/vsizip/%s/%s' % (fileName, fn)
                        for fn in zz.namelist() if 'measurement/s1' in fn]
            calFiles = ['/vsizip/%s/%s' % (fileName, fn)
                        for fn in zz.namelist()
                        if 'annotation/calibration/calibration-s1' in fn]
            noiseFiles = ['/vsizip/%s/%s' % (fileName, fn)
                          for fn in zz.namelist()
                          if 'annotation/calibration/noise-s1' in fn]
            annotationFiles = ['/vsizip/%s/%s' % (fileName, fn)
                               for fn in zz.namelist()
                               if 'annotation/s1' in fn]
            manifestFile = ['/vsizip/%s/%s' % (fileName, fn)
                            for fn in zz.namelist()
                            if 'manifest.safe' in fn]
            zz.close()
        else:
            mdsFiles = glob.glob('%s/measurement/s1*' % fileName)
            calFiles = glob.glob('%s/annotation/calibration/calibration-s1*'
                                 % fileName)
            noiseFiles = glob.glob('%s/annotation/calibration/noise-s1*'
                                   % fileName)
            annotationFiles = glob.glob('%s/annotation/s1*'
                                        % fileName)
            manifestFile = glob.glob('%s/manifest.safe' % fileName)

        if (not mdsFiles or not calFiles or not noiseFiles or
                not annotationFiles or not manifestFile):
            raise WrongMapperError

        mdsDict = {}
        for ff in mdsFiles:
            mdsDict[
                os.path.splitext(os.path.basename(ff))[0].split('-')[3]] = ff

        self.calXMLDict = {}
        for ff in calFiles:
            self.calXMLDict[
                os.path.splitext(
                os.path.basename(ff))[0].split('-')[4]] = self.read_xml(ff)

        self.noiseXMLDict = {}
        for ff in noiseFiles:
            self.noiseXMLDict[
                os.path.splitext(
                os.path.basename(ff))[0].split('-')[4]] = self.read_xml(ff)

        self.annotationXMLDict = {}
        for ff in annotationFiles:
            self.annotationXMLDict[
                os.path.splitext(
                os.path.basename(ff))[0].split('-')[3]] = self.read_xml(ff)

        self.manifestXML = self.read_xml(manifestFile[0])

        if not os.path.split(fileName)[1][:3] in ['S1A', 'S1B']:
            raise WrongMapperError('Not Sentinel 1A or 1B')

        missionName = {'S1A': 'SENTINEL-1A', 'S1B': 'SENTINEL-1B'}[
            os.path.split(fileName)[1][:3]]

        # very fast constructor without any bands
        if manifestonly:
            self.init_from_manifest_only(self.manifestXML,
                                         self.annotationXMLDict[
                                         self.annotationXMLDict.keys()[0]],
                                         missionName)
            return

        gdalDatasets = {}
        for key in mdsDict.keys():
            # Open data files
            gdalDatasets[key] = gdal.Open(mdsDict[key])

        if not gdalDatasets:
            raise WrongMapperError('No Sentinel-1 datasets found')

        # Check metadata to confirm it is Sentinel-1 L1
        metadata = gdalDatasets[mdsDict.keys()[0]].GetMetadata()
        
        if not 'TIFFTAG_IMAGEDESCRIPTION' in metadata.keys():
            raise WrongMapperError
        if (not 'Sentinel-1' in metadata['TIFFTAG_IMAGEDESCRIPTION']
                and not 'L1' in metadata['TIFFTAG_IMAGEDESCRIPTION']):
            raise WrongMapperError

        warnings.warn('Sentinel-1 level-1 mapper is not yet adapted to '
                      'complex data. In addition, the band names should be '
                      'updated for multi-swath data - '
                      'and there might be other issues.')

        # create empty VRT dataset with geolocation only
        for key in gdalDatasets:
            VRT.__init__(self, gdalDatasets[key])
            break

        # Read annotation, noise and calibration xml-files
        pol = {}
        it = 0
        for key in self.annotationXMLDict:
            xml = Node.create(self.annotationXMLDict[key])
            pol[key] = (xml.node('product').
                        node('adsHeader')['polarisation'].upper())
            it += 1
            if it == 1:
                # Get incidence angle
                pi = xml.node('generalAnnotation').node('productInformation')

                self.dataset.SetMetadataItem('ORBIT_DIRECTION',
                                              str(pi['pass']))
                (X, Y, lon, lat, inc, ele, numberOfSamples,
                numberOfLines) = self.read_geolocation_lut(
                                                self.annotationXMLDict[key])

                X = np.unique(X)
                Y = np.unique(Y)

                lon = np.array(lon).reshape(len(Y), len(X))
                lat = np.array(lat).reshape(len(Y), len(X))
                inc = np.array(inc).reshape(len(Y), len(X))
                ele = np.array(ele).reshape(len(Y), len(X))

                incVRT = VRT(array=inc, lat=lat, lon=lon)
                eleVRT = VRT(array=ele, lat=lat, lon=lon)
                incVRT = incVRT.get_resized_vrt(self.dataset.RasterXSize,
                                                self.dataset.RasterYSize,
                                                eResampleAlg=2)
                eleVRT = eleVRT.get_resized_vrt(self.dataset.RasterXSize,
                                                self.dataset.RasterYSize,
                                                eResampleAlg=2)
                self.bandVRTs['incVRT'] = incVRT
                self.bandVRTs['eleVRT'] = eleVRT

        for key in self.calXMLDict:
            calibration_LUT_VRTs, longitude, latitude = (
                self.get_LUT_VRTs(self.calXMLDict[key],
                                  'calibrationVectorList',
                                  ['sigmaNought', 'betaNought',
                                   'gamma', 'dn']
                                  ))
            self.bandVRTs['LUT_sigmaNought_VRT_'+pol[key]] = (
                calibration_LUT_VRTs['sigmaNought'].
                get_resized_vrt(self.dataset.RasterXSize,
                                self.dataset.RasterYSize,
                                eResampleAlg=1))
            self.bandVRTs['LUT_betaNought_VRT_'+pol[key]] = (
                calibration_LUT_VRTs['betaNought'].
                get_resized_vrt(self.dataset.RasterXSize,
                                self.dataset.RasterYSize,
                                eResampleAlg=1))
            self.bandVRTs['LUT_gamma_VRT'] = calibration_LUT_VRTs['gamma']
            self.bandVRTs['LUT_dn_VRT'] = calibration_LUT_VRTs['dn']

        for key in self.noiseXMLDict:
            noise_LUT_VRT = self.get_LUT_VRTs(self.noiseXMLDict[key],
                                              'noiseVectorList',
                                              ['noiseLut'])[0]
            self.bandVRTs['LUT_noise_VRT_'+pol[key]] = (
                noise_LUT_VRT['noiseLut'].get_resized_vrt(
                    self.dataset.RasterXSize,
                    self.dataset.RasterYSize,
                    eResampleAlg=1))

        metaDict = []
        bandNumberDict = {}
        bnmax = 0
        for key in gdalDatasets.keys():
            dsPath, dsName = os.path.split(mdsDict[key])
            name = 'DN_%s' % pol[key]
            # A dictionary of band numbers is needed for the pixel function
            # bands further down. This is not the best solution. It would be
            # better to have a function in VRT that returns the number given a
            # band name. This function exists in Nansat but could perhaps be
            # moved to VRT? The existing nansat function could just call the
            # VRT one...
            bandNumberDict[name] = bnmax + 1
            bnmax = bandNumberDict[name]
            band = gdalDatasets[key].GetRasterBand(1)
            dtype = band.DataType
            metaDict.append({
                'src': {
                    'SourceFilename': mdsDict[key],
                    'SourceBand': 1,
                    'DataType': dtype,
                },
                'dst': {
                    'name': name,
                    #'SourceTransferType': gdal.GetDataTypeName(dtype),
                    #'dataType': 6,
                },
            })
        # add bands with metadata and corresponding values to the empty VRT
        self._create_bands(metaDict)

        '''
        Calibration should be performed as

        s0 = DN^2/sigmaNought^2,

        where sigmaNought is the calibration LUT from e.g.
        annotation/calibration/calibration-s1a-iw-grd-hh-20140811t151231-20140811t151301-001894-001cc7-001.xml,
        and DN are the digital numbers in the tiff files.

        The noise should also be subtracted (see the sketch after this
        example).

        See
        https://sentinel.esa.int/web/sentinel/sentinel-1-sar-wiki/-/wiki/Sentinel%20One/Application+of+Radiometric+Calibration+LUT
        '''
        # Get look direction
        sat_heading = initial_bearing(longitude[:-1, :],
                                      latitude[:-1, :],
                                      longitude[1:, :],
                                      latitude[1:, :])
        look_direction = scipy.ndimage.interpolation.zoom(
            np.mod(sat_heading + 90, 360),
            (np.shape(longitude)[0] / (np.shape(longitude)[0]-1.), 1))

        # Decompose, to avoid interpolation errors around 0 <-> 360
        look_direction_u = np.sin(np.deg2rad(look_direction))
        look_direction_v = np.cos(np.deg2rad(look_direction))
        look_u_VRT = VRT(array=look_direction_u,
                         lat=latitude, lon=longitude)
        look_v_VRT = VRT(array=look_direction_v,
                         lat=latitude, lon=longitude)
        lookVRT = VRT(lat=latitude, lon=longitude)
        lookVRT._create_band([{'SourceFilename': look_u_VRT.fileName,
                               'SourceBand': 1},
                              {'SourceFilename': look_v_VRT.fileName,
                               'SourceBand': 1}],
                             {'PixelFunctionType': 'UVToDirectionTo'}
                             )

        # Blow up to full size
        lookVRT = lookVRT.get_resized_vrt(self.dataset.RasterXSize,
                                          self.dataset.RasterYSize,
                                          eResampleAlg=1)

        # Store VRTs so that they are accessible later
        self.bandVRTs['look_u_VRT'] = look_u_VRT
        self.bandVRTs['look_v_VRT'] = look_v_VRT
        self.bandVRTs['lookVRT'] = lookVRT

        metaDict = []
        # Add bands to full size VRT
        for key in pol:
            name = 'LUT_sigmaNought_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append(
                {'src': {'SourceFilename':
                         (self.bandVRTs['LUT_sigmaNought_VRT_' +
                          pol[key]].fileName),
                         'SourceBand': 1
                         },
                 'dst': {'name': name
                         }
                 })
            name = 'LUT_noise_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append({
                'src': {
                    'SourceFilename': self.bandVRTs['LUT_noise_VRT_' +
                                                   pol[key]].fileName,
                    'SourceBand': 1
                },
                'dst': {
                    'name': name
                }
            })

        name = 'look_direction'
        bandNumberDict[name] = bnmax+1
        bnmax = bandNumberDict[name]
        metaDict.append({
            'src': {
                'SourceFilename': self.bandVRTs['lookVRT'].fileName,
                'SourceBand': 1
            },
            'dst': {
                'wkv': 'sensor_azimuth_angle',
                'name': name
            }
        })

        for key in gdalDatasets.keys():
            dsPath, dsName = os.path.split(mdsDict[key])
            name = 'sigma0_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append(
                {'src': [{'SourceFilename': self.fileName,
                          'SourceBand': bandNumberDict['DN_%s' % pol[key]],
                          },
                         {'SourceFilename':
                          (self.bandVRTs['LUT_sigmaNought_VRT_%s'
                           % pol[key]].fileName),
                          'SourceBand': 1
                          }
                         ],
                 'dst': {'wkv': 'surface_backwards_scattering_coefficient_of_radar_wave',
                         'PixelFunctionType': 'Sentinel1Calibration',
                         'polarization': pol[key],
                         'suffix': pol[key],
                         },
                 })
            name = 'beta0_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append(
                {'src': [{'SourceFilename': self.fileName,
                          'SourceBand': bandNumberDict['DN_%s' % pol[key]]
                          },
                         {'SourceFilename':
                          (self.bandVRTs['LUT_betaNought_VRT_%s'
                           % pol[key]].fileName),
                          'SourceBand': 1
                          }
                         ],
                 'dst': {'wkv': 'surface_backwards_brightness_coefficient_of_radar_wave',
                         'PixelFunctionType': 'Sentinel1Calibration',
                         'polarization': pol[key],
                         'suffix': pol[key],
                         },
                 })

        self._create_bands(metaDict)

        # Add incidence angle as band
        name = 'incidence_angle'
        bandNumberDict[name] = bnmax+1
        bnmax = bandNumberDict[name]
        src = {'SourceFilename': self.bandVRTs['incVRT'].fileName,
               'SourceBand': 1}
        dst = {'wkv': 'angle_of_incidence',
               'name': name}
        self._create_band(src, dst)
        self.dataset.FlushCache()

        # Add elevation angle as band
        name = 'elevation_angle'
        bandNumberDict[name] = bnmax+1
        bnmax = bandNumberDict[name]
        src = {'SourceFilename': self.bandVRTs['eleVRT'].fileName,
               'SourceBand': 1}
        dst = {'wkv': 'angle_of_elevation',
               'name': name}
        self._create_band(src, dst)
        self.dataset.FlushCache()

        # Add sigma0_VV
        pp = [pol[key] for key in pol]
        if 'VV' not in pp and 'HH' in pp:
            name = 'sigma0_VV'
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            src = [{'SourceFilename': self.fileName,
                    'SourceBand': bandNumberDict['DN_HH'],
                    },
                   {'SourceFilename': (self.bandVRTs['LUT_noise_VRT_HH'].
                                       fileName),
                    'SourceBand': 1
                    },
                   {'SourceFilename': (self.bandVRTs['LUT_sigmaNought_VRT_HH'].
                                       fileName),
                    'SourceBand': 1,
                    },
                   {'SourceFilename': self.bandVRTs['incVRT'].fileName,
                    'SourceBand': 1}
                   ]
            dst = {'wkv': 'surface_backwards_scattering_coefficient_of_radar_wave',
                   'PixelFunctionType': 'Sentinel1Sigma0HHToSigma0VV',
                   'polarization': 'VV',
                   'suffix': 'VV'}
            self._create_band(src, dst)
            self.dataset.FlushCache()

        # set time as acquisition start time
        n = Node.create(self.manifestXML)
        meta = n.node('metadataSection')
        for nn in meta.children:
            if nn.getAttribute('ID') == u'acquisitionPeriod':
                # set valid time
                self.dataset.SetMetadataItem(
                    'time_coverage_start',
                    parse((nn.node('metadataWrap').
                           node('xmlData').
                           node('safe:acquisitionPeriod')['safe:startTime'])
                          ).isoformat())
                self.dataset.SetMetadataItem(
                    'time_coverage_end',
                    parse((nn.node('metadataWrap').
                           node('xmlData').
                           node('safe:acquisitionPeriod')['safe:stopTime'])
                          ).isoformat())

        # Get dictionary describing the instrument and platform according to
        # the GCMD keywords
        mm = pti.get_gcmd_instrument('sar')
        ee = pti.get_gcmd_platform(missionName)

        # TODO: Validate that the found instrument and platform are indeed what we
        # want....

        self.dataset.SetMetadataItem('instrument', json.dumps(mm))
        self.dataset.SetMetadataItem('platform', json.dumps(ee))
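
The calibration note in the docstring above maps directly onto array arithmetic, and the look-direction code shows a standard trick for interpolating cyclic angles. Below is a minimal numpy/scipy sketch of both, assuming dn, sigma_nought_lut and noise_lut are co-registered 2-D arrays; the function and argument names are illustrative, not part of the mapper.

import numpy as np
from scipy.ndimage import zoom

def calibrate_sigma0(dn, sigma_nought_lut, noise_lut=None):
    # s0 = DN^2 / sigmaNought^2, with optional noise subtraction,
    # as described in the docstring above (names are illustrative)
    power = dn.astype('float32') ** 2
    if noise_lut is not None:
        power = np.clip(power - noise_lut, 0.0, None)
    return power / sigma_nought_lut ** 2

def zoom_angle(angle_deg, factor):
    # Interpolate a cyclic angle by resampling its sin/cos components,
    # mirroring the look_direction decomposition above; arctan2 of the
    # resampled components recovers the angle without 0 <-> 360 jumps
    u = zoom(np.sin(np.deg2rad(angle_deg)), factor)
    v = zoom(np.cos(np.deg2rad(angle_deg)), factor)
    return np.mod(np.rad2deg(np.arctan2(u, v)), 360.0)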

Example 44

Project: WiFi-Pumpkin
Source File: bdf_proxy.py
View license
    def binaryGrinder(self, binaryFile):
        """
        Feed potential binaries into this function;
        it will return the resulting PatchedBinary, False, or None
        (a magic-byte sketch follows this example)
        """
        with open(binaryFile, 'r+b') as f:
            binaryTMPHandle = f.read()

        binaryHeader = binaryTMPHandle[:4]
        result = None

        try:
            if binaryHeader[:2] == 'MZ':  # PE/COFF
                pe = pefile.PE(data=binaryTMPHandle, fast_load=True)
                magic = pe.OPTIONAL_HEADER.Magic
                machineType = pe.FILE_HEADER.Machine

                # update when supporting more than one arch
                if (magic == int('20B', 16) and machineType == 0x8664 and
                            self.WindowsType.lower() in ['all', 'x64']):
                    add_section = False
                    cave_jumping = False
                    if self.WindowsIntelx64['PATCH_TYPE'].lower() == 'append':
                        add_section = True
                    elif self.WindowsIntelx64['PATCH_TYPE'].lower() == 'jump':
                        cave_jumping = True

                    # if automatic override
                    if self.WindowsIntelx64['PATCH_METHOD'].lower() == 'automatic':
                        cave_jumping = True

                    targetFile = pebin.pebin(FILE=binaryFile,
                                             OUTPUT=os.path.basename(binaryFile),
                                             SHELL=self.WindowsIntelx64['SHELL'],
                                             HOST=self.WindowsIntelx64['HOST'],
                                             PORT=int(self.WindowsIntelx64['PORT']),
                                             ADD_SECTION=add_section,
                                             CAVE_JUMPING=cave_jumping,
                                             IMAGE_TYPE=self.WindowsType,
                                             PATCH_DLL=self.WindowsIntelx64.as_bool('PATCH_DLL'),
                                             SUPPLIED_SHELLCODE=self.WindowsIntelx64['SUPPLIED_SHELLCODE'],
                                             ZERO_CERT=self.WindowsIntelx64.as_bool('ZERO_CERT'),
                                             PATCH_METHOD=self.WindowsIntelx64['PATCH_METHOD'].lower()
                                             )

                    result = targetFile.run_this()

                elif (machineType == 0x14c and
                              self.WindowsType.lower() in ['all', 'x86']):
                    add_section = False
                    cave_jumping = False
                    # add_section wins for cave_jumping
                    # default is single for BDF
                    if self.WindowsIntelx86['PATCH_TYPE'].lower() == 'append':
                        add_section = True
                    elif self.WindowsIntelx86['PATCH_TYPE'].lower() == 'jump':
                        cave_jumping = True

                    # if automatic override
                    if self.WindowsIntelx86['PATCH_METHOD'].lower() == 'automatic':
                        cave_jumping = True

                    targetFile = pebin.pebin(FILE=binaryFile,
                                             OUTPUT=os.path.basename(binaryFile),
                                             SHELL=self.WindowsIntelx86['SHELL'],
                                             HOST=self.WindowsIntelx86['HOST'],
                                             PORT=int(self.WindowsIntelx86['PORT']),
                                             ADD_SECTION=add_section,
                                             CAVE_JUMPING=cave_jumping,
                                             IMAGE_TYPE=self.WindowsType,
                                             PATCH_DLL=self.WindowsIntelx86.as_bool('PATCH_DLL'),
                                             SUPPLIED_SHELLCODE=self.WindowsIntelx86['SUPPLIED_SHELLCODE'],
                                             ZERO_CERT=self.WindowsIntelx86.as_bool('ZERO_CERT'),
                                             PATCH_METHOD=self.WindowsIntelx86['PATCH_METHOD'].lower()
                                             )

                    result = targetFile.run_this()

            elif binaryHeader[:4].encode('hex') == '7f454c46':  # ELF

                targetFile = elfbin.elfbin(FILE=binaryFile, SUPPORT_CHECK=False)
                targetFile.support_check()

                if targetFile.class_type == 0x1:
                    # x86CPU Type
                    targetFile = elfbin.elfbin(FILE=binaryFile,
                                               OUTPUT=os.path.basename(binaryFile),
                                               SHELL=self.LinuxIntelx86['SHELL'],
                                               HOST=self.LinuxIntelx86['HOST'],
                                               PORT=int(self.LinuxIntelx86['PORT']),
                                               SUPPLIED_SHELLCODE=self.LinuxIntelx86['SUPPLIED_SHELLCODE'],
                                               IMAGE_TYPE=self.LinuxType
                                               )
                    result = targetFile.run_this()
                elif targetFile.class_type == 0x2:
                    # x64
                    targetFile = elfbin.elfbin(FILE=binaryFile,
                                               OUTPUT=os.path.basename(binaryFile),
                                               SHELL=self.LinuxIntelx64['SHELL'],
                                               HOST=self.LinuxIntelx64['HOST'],
                                               PORT=int(self.LinuxIntelx64['PORT']),
                                               SUPPLIED_SHELLCODE=self.LinuxIntelx64['SUPPLIED_SHELLCODE'],
                                               IMAGE_TYPE=self.LinuxType
                                               )
                    result = targetFile.run_this()

            elif binaryHeader[:4].encode('hex') in ['cefaedfe', 'cffaedfe', 'cafebabe']:  # Macho
                targetFile = machobin.machobin(FILE=binaryFile, SUPPORT_CHECK=False)
                targetFile.support_check()

                # ONE CHIP SET MUST HAVE PRIORITY in FAT FILE

                if targetFile.FAT_FILE is True:
                    if self.FatPriority == 'x86':
                        targetFile = machobin.machobin(FILE=binaryFile,
                                                       OUTPUT=os.path.basename(binaryFile),
                                                       SHELL=self.MachoIntelx86['SHELL'],
                                                       HOST=self.MachoIntelx86['HOST'],
                                                       PORT=int(self.MachoIntelx86['PORT']),
                                                       SUPPLIED_SHELLCODE=self.MachoIntelx86['SUPPLIED_SHELLCODE'],
                                                       FAT_PRIORITY=self.FatPriority
                                                       )
                        result = targetFile.run_this()

                    elif self.FatPriority == 'x64':
                        targetFile = machobin.machobin(FILE=binaryFile,
                                                       OUTPUT=os.path.basename(binaryFile),
                                                       SHELL=self.MachoIntelx64['SHELL'],
                                                       HOST=self.MachoIntelx64['HOST'],
                                                       PORT=int(self.MachoIntelx64['PORT']),
                                                       SUPPLIED_SHELLCODE=self.MachoIntelx64['SUPPLIED_SHELLCODE'],
                                                       FAT_PRIORITY=self.FatPriority
                                                       )
                        result = targetFile.run_this()

                elif targetFile.mach_hdrs[0]['CPU Type'] == '0x7':
                    targetFile = machobin.machobin(FILE=binaryFile,
                                                   OUTPUT=os.path.basename(binaryFile),
                                                   SHELL=self.MachoIntelx86['SHELL'],
                                                   HOST=self.MachoIntelx86['HOST'],
                                                   PORT=int(self.MachoIntelx86['PORT']),
                                                   SUPPLIED_SHELLCODE=self.MachoIntelx86['SUPPLIED_SHELLCODE'],
                                                   FAT_PRIORITY=self.FatPriority
                                                   )
                    result = targetFile.run_this()

                elif targetFile.mach_hdrs[0]['CPU Type'] == '0x1000007':
                    targetFile = machobin.machobin(FILE=binaryFile,
                                                   OUTPUT=os.path.basename(binaryFile),
                                                   SHELL=self.MachoIntelx64['SHELL'],
                                                   HOST=self.MachoIntelx64['HOST'],
                                                   PORT=int(self.MachoIntelx64['PORT']),
                                                   SUPPLIED_SHELLCODE=self.MachoIntelx64['SUPPLIED_SHELLCODE'],
                                                   FAT_PRIORITY=self.FatPriority
                                                   )
                    result = targetFile.run_this()

            return result

        except Exception as e:
            EnhancedOutput.print_error('binaryGrinder: {0}'.format(e))
            EnhancedOutput.logging_warning("Exception in binaryGrinder {0}".format(e))
            return None
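
The dispatch above keys entirely off the file's first four bytes. Here is a compact sketch of the same magic-byte detection, written with Python 3 bytes literals rather than the Python 2 string slicing and encode('hex') used in the example; sniff_binary_format is a hypothetical helper, not part of bdf_proxy.

import os

MACHO_MAGICS = (b'\xce\xfa\xed\xfe', b'\xcf\xfa\xed\xfe', b'\xca\xfe\xba\xbe')

def sniff_binary_format(path):
    # Read the first four bytes and match them against the same
    # signatures binaryGrinder checks: MZ (PE/COFF), \x7fELF, Mach-O
    with open(path, 'rb') as f:
        header = f.read(4)
    if header[:2] == b'MZ':
        return 'pe'
    if header == b'\x7fELF':
        return 'elf'
    if header in MACHO_MAGICS:
        return 'macho'
    return None

# The patched output keeps only the input file's name, as in the
# calls above: OUTPUT=os.path.basename(path)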

Example 45

Project: AstroBox
Source File: printjob.py
View license
	def run(self):
		profile = self._printer._profile

		self._printer._heatingUp = True
		self._printer.mcHeatingUpUpdate(True)
		self._heatupWaitStartTime = time.time()
		self._heatingTool = True
		self._lastLayerHeight = 0.0
		self._currentLayer = 0

		try:
			self._printer._comm.reset()
			self._printer._comm.build_start_notification(os.path.basename(self._file['filename'])[:15])
			self._printer._comm.set_build_percent(0)

			self._file['start_time'] = time.time()
			self._file['progress'] = 0

			lastProgressReport = 0
			lastProgressValueSentToPrinter = 0
			lastHeatingCheck = self._file['start_time']

			with open(self._file['filename'], 'rb') as f:
				while True:
					packet = bytearray()

					try:
						command = f.read(1)

						if self._canceled or len(command) == 0:
							break

						packet.append(ord(command))

						command = struct.unpack("B",command)
						try:
							parse = self.commandTable[command[0]]

						except KeyError:
							raise Exception("Unexpected packet type: 0x%x" % command[0])

						if type(parse) == type(""):
							packetLen = struct.calcsize(parse)
							packetData = f.read(packetLen)
							if len(packetData) != packetLen:
								raise Exception("Packet incomplete")
						else:
							packetData = parse(f)

						for c in packetData:
							packet.append(ord(c))

						if self.send_packet(packet):
							self._serialLoggerEnabled and self._serialLogger.debug('{"event":"packet_sent", "data": "%s"}' % ' '.join('0x{:02x}'.format(x) for x in packet) )

							now = time.time()
							if now - lastProgressReport > self.UPDATE_INTERVAL_SECS:
								position = f.tell()
								self._file['position'] = position
								self._file['progress'] = float(position) / float(self._file['size'])
								self._printer.mcProgress()

								printerProgress = int(self._file['progress'] * 100.0)

								if lastProgressValueSentToPrinter != printerProgress:
									try:
										self._printer._comm.set_build_percent(printerProgress)
										lastProgressValueSentToPrinter = printerProgress
										lastProgressReport = now

									except BufferOverflowError:
										time.sleep(.2)

							if self._printer._heatingUp and now - lastHeatingCheck > self.UPDATE_INTERVAL_SECS:
								lastHeatingCheck = now

								if (not self._heatingPlatform or self._printer._comm.is_platform_ready(0)) \
									and (not self._heatingTool or self._printer._comm.is_tool_ready(0)):
									self._heatingTool = False
									self._heatingPlatform = False
									self._printer._heatingUp = False
									self._printer.mcHeatingUpUpdate(False)
									self._heatupWaitTimeLost = now - self._heatupWaitStartTime
									self._heatupWaitStartTime = now
									self._file['start_time'] += self._heatupWaitTimeLost

					except ProtocolError as e:
						self._logger.warn('ProtocolError: %s' % e)

			self._printer._comm.build_end_notification()

			if self._canceled:
				self._printer._comm.clear_buffer()

			# It is possible to get a BufferOverflowError when successfully completing a job
			continueEndSequence = False
			findAxesTries = 0
			while not continueEndSequence:
				try:
					self._printer._comm.find_axes_maximums(['x', 'y'], 200, 10)
					continueEndSequence = True

				except BufferOverflowError:
					if findAxesTries < 3:
						time.sleep(.2)
						findAxesTries += 1
					else:
						continueEndSequence = True

			#find current z:
			moveToPosition, endStopsStates = self._printer._comm.get_extended_position()
			zEndStopReached = endStopsStates & 16 == 16 #endstop is a bitfield. See https://github.com/makerbot/s3g/blob/master/doc/s3gProtocol.md (command 21)

			if zEndStopReached:

				#calculate Z max
				moveToPosition[2] = self._printer._profile.values['axes']['Z']['platform_length'] *  self._printer._profile.values['axes']['Z']['steps_per_mm']

				#move to the bottom:
				self._printer._comm.queue_extended_point_classic(moveToPosition, 100)

			self._printer._comm.toggle_axes(['x','y','z','a','b'], False)

			self._printer._changeState(self._printer.STATE_OPERATIONAL)

			payload = {
				"file": self._file['filename'],
				"filename": os.path.basename(self._file['filename']),
				"origin": self._file['origin'],
				"time": self._printer.getPrintTime(),
				"layerCount": self._currentLayer
			}

			if self._canceled:
				self._printer.printJobCancelled()
				eventManager().fire(Events.PRINT_FAILED, payload)
				self._printer._fileManager.printFailed(payload['filename'], payload['time'])

			else:
				self._printer.mcPrintjobDone()
				self._printer._fileManager.printSucceeded(payload['filename'], payload['time'], payload['layerCount'])
				eventManager().fire(Events.PRINT_DONE, payload)

		except BuildCancelledError:
			self._logger.warn('Build Cancel detected')
			self._printer.printJobCancelled()
			payload = {
				"file": self._file['filename'],
				"filename": os.path.basename(self._file['filename']),
				"origin": self._file['origin'],
				"time": self._printer.getPrintTime()
			}
			eventManager().fire(Events.PRINT_FAILED, payload)
			self._printer._fileManager.printFailed(payload['filename'], payload['time'])
			self._printer._changeState(self._printer.STATE_OPERATIONAL)

		except ExternalStopError:
			self._logger.warn('External Stop detected')
			self._printer._comm.writer.set_external_stop(False)
			self._printer.printJobCancelled()
			payload = {
				"file": self._file['filename'],
				"filename": os.path.basename(self._file['filename']),
				"origin": self._file['origin'],
				"time": self._printer.getPrintTime()
			}
			eventManager().fire(Events.PRINT_FAILED, payload)
			self._printer._fileManager.printFailed(payload['filename'], payload['time'])
			self._printer._changeState(self._printer.STATE_OPERATIONAL)			

		except Exception as e:
			self._errorValue = getExceptionString()
			self._printer._changeState(self._printer.STATE_ERROR)
			eventManager().fire(Events.ERROR, {"error": self._errorValue })
			self._logger.error(self._errorValue)
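
Note the call to build_start_notification() near the top: it passes os.path.basename(...)[:15], stripping the directory first and then truncating the name to 15 characters, presumably to fit the printer's display. A tiny sketch of the idiom (display_name is a hypothetical helper):

import os

def display_name(path, limit=15):
    # basename first (drop the directory), then truncate for the display
    return os.path.basename(path)[:limit]

# display_name('/uploads/benchy_high_quality.x3g') -> 'benchy_high_qua'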

Example 46

Project: python-netsnmpagent
Source File: netsnmpagent.py
View license
	def __init__(self, **args):
		"""Initializes a new netsnmpAgent instance.
		
		"args" is a dictionary that can contain the following
		optional parameters:
		
		- AgentName     : The agent's name used for registration with net-snmp.
		- MasterSocket  : The transport specification of the AgentX socket of
		                  the running snmpd instance to connect to (see the
		                  "LISTENING ADDRESSES" section in the snmpd(8) manpage).
		                  Change this if you want to use e.g. a TCP transport or
		                  access a custom snmpd instance, e.g. as shown in
		                  run_simple_agent.sh, or for automatic testing.
		- PersistenceDir: The directory to use to store persistence information.
		                  Change this if you want to use a custom snmpd
		                  instance, e.g. for automatic testing.
		- MIBFiles      : A list of filenames of MIBs to be loaded. Required if
		                  the OIDs, for which variables will be registered, do
		                  not belong to standard MIBs and the custom MIBs are not
		                  located in net-snmp's default MIB path
		                  (/usr/share/snmp/mibs).
		- UseMIBFiles   : Whether to use MIB files at all or not. When False,
		                  the parser for MIB files will not be initialized, so
		                  neither system-wide MIB files nor the ones provided
		                  in the MIBFiles argument will be in use.
		- LogHandler    : An optional Python function that will be registered
		                  with net-snmp as a custom log handler. If specified,
		                  this function will be called for every log message
		                  net-snmp itself generates, with parameters as follows:
		                  1. a string indicating the message's priority: one of
		                  "Emergency", "Alert", "Critical", "Error", "Warning",
		                  "Notice", "Info" or "Debug".
		                  2. the actual log message. Note that leading strings
		                  such as "Warning: " and "Error: " will be stripped off
		                  since the priority level is explicitly known and can
		                  be used to prefix the log message, if desired.
		                  Trailing linefeeds will also have been stripped off.
		                  If undefined, log messages will be written to stderr
		                  instead. """

		# Default settings
		defaults = {
			"AgentName"     : os.path.splitext(os.path.basename(sys.argv[0]))[0],
			"MasterSocket"  : None,
			"PersistenceDir": None,
			"UseMIBFiles"   : True,
			"MIBFiles"      : None,
			"LogHandler"    : None,
		}
		for key in defaults:
			setattr(self, key, args.get(key, defaults[key]))
		if self.UseMIBFiles and self.MIBFiles is not None and type(self.MIBFiles) not in (list, tuple):
			self.MIBFiles = (self.MIBFiles,)

		# Initialize status attribute -- until start() is called we will accept
		# SNMP object registrations
		self._status = netsnmpAgentStatus.REGISTRATION

		# Unfortunately net-snmp does not give callers of init_snmp() (used
		# in the start() method) any feedback about success or failure of
		# connection establishment. But for AgentX clients this information is
		# quite essential, thus we need to implement some more or less ugly
		# workarounds.

		# For net-snmp 5.7.x, we can derive success and failure from the log
		# messages it generates. Normally these go to stderr, in the absence
		# of other so-called log handlers. Hence we define a callback function
		# that we will register with net-snmp as a custom log handler later on,
		# hereby effectively gaining access to the desired information.
		def _py_log_handler(majorID, minorID, serverarg, clientarg):
			# "majorID" and "minorID" are the callback IDs with which this
			# callback function was registered. They are useful if the same
			# callback was registered multiple times.
			# Both "serverarg" and "clientarg" are pointers that can be used to
			# convey information from the calling context to the callback
			# function: "serverarg" gets passed individually to every call of
			# snmp_call_callbacks() while "clientarg" was initially passed to
			# snmp_register_callback().

			# In this case, "majorID" and "minorID" are always the same (see the
			# registration code below). "serverarg" needs to be cast back to
			# become a pointer to a "snmp_log_message" C structure (passed by
			# net-snmp's log_handler_callback() in snmplib/snmp_logging.c) while
			# "clientarg" will be None (see the registration code below).
			logmsg = ctypes.cast(serverarg, snmp_log_message_p)

			# Generate textual description of priority level
			priorities = {
				LOG_EMERG: "Emergency",
				LOG_ALERT: "Alert",
				LOG_CRIT: "Critical",
				LOG_ERR: "Error",
				LOG_WARNING: "Warning",
				LOG_NOTICE: "Notice",
				LOG_INFO: "Info",
				LOG_DEBUG: "Debug"
			}
			msgprio = priorities[logmsg.contents.priority]

			# Strip trailing linefeeds and in addition "Warning: " and "Error: "
			# from msgtext as these conditions are already indicated through
			# msgprio
			msgtext = re.sub(
				"^(Warning|Error): *",
				"",
				u(logmsg.contents.msg.rstrip(b"\n"))
			)

			# Intercept log messages related to connection establishment and
			# failure to update the status of this netsnmpAgent object. This is
			# really an ugly hack, introducing a dependency on the particular
			# text of log messages -- hopefully the net-snmp guys won't
			# translate them one day.
			if (msgprio == "Warning" or msgprio == "Error") \
			and re.match("Failed to .* the agentx master agent.*", msgtext):
				# If this was the first connection attempt, we consider the
				# condition fatal: it is more likely that an invalid
				# "MasterSocket" was specified than that we've got concurrency
				# issues with our agent being erroneously started before snmpd.
				if self._status == netsnmpAgentStatus.FIRSTCONNECT:
					self._status = netsnmpAgentStatus.CONNECTFAILED

					# No need to log this message -- we'll generate our own when
					# throwing a netsnmpAgentException as consequence of the
					# ECONNECT
					return 0

				# Otherwise we'll stay at status RECONNECTING and log net-snmp's
				# message like any other. net-snmp code will keep retrying to
				# connect.
			elif msgprio == "Info" \
			and  re.match("AgentX subagent connected", msgtext):
				self._status = netsnmpAgentStatus.CONNECTED
			elif msgprio == "Info" \
			and  re.match("AgentX master disconnected us.*", msgtext):
				self._status = netsnmpAgentStatus.RECONNECTING

			# If "LogHandler" was defined, call it to take care of logging.
			# Otherwise print all log messages to stderr to resemble net-snmp
			# standard behavior (but add log message's associated priority in
			# plain text as well)
			if self.LogHandler:
				self.LogHandler(msgprio, msgtext)
			else:
				print("[{0}] {1}".format(msgprio, msgtext))

			return 0

		# We defined a Python function that needs a ctypes conversion so it can
		# be called by C code such as net-snmp. That's what SNMPCallback() is
		# used for. However we also need to store the reference in "self" as it
		# will otherwise be lost at the exit of this function so that net-snmp's
		# attempt to call it would end in nirvana...
		self._log_handler = SNMPCallback(_py_log_handler)

		# Now register our custom log handler with majorID SNMP_CALLBACK_LIBRARY
		# and minorID SNMP_CALLBACK_LOGGING.
		if libnsa.snmp_register_callback(
			SNMP_CALLBACK_LIBRARY,
			SNMP_CALLBACK_LOGGING,
			self._log_handler,
			None
		) != SNMPERR_SUCCESS:
			raise netsnmpAgentException(
				"snmp_register_callback() failed for _netsnmp_log_handler!"
			)

		# Finally the net-snmp logging system needs to be told to enable
		# logging through callback functions. This will actually register a
		# NETSNMP_LOGHANDLER_CALLBACK log handler that will call out to any
		# callback functions with the majorID and minorID shown above, such as
		# ours.
		libnsa.snmp_enable_calllog()

		# Unfortunately our custom log handler above is still not enough: in
		# net-snmp 5.4.x there were no "AgentX master disconnected" log
		# messages yet. So we need another workaround to be able to detect
		# disconnects for this release. Both net-snmp 5.4.x and 5.7.x support
		# a callback mechanism using the "majorID" SNMP_CALLBACK_APPLICATION and
		# the "minorID" SNMPD_CALLBACK_INDEX_STOP, which we can abuse for our
		# purposes. Again, we start by defining a callback function.
		def _py_index_stop_callback(majorID, minorID, serverarg, clientarg):
			# For "majorID" and "minorID" see our log handler above.
			# "serverarg" is a disguised pointer to a "netsnmp_session"
			# structure (passed by net-snmp's subagent_open_master_session() and
			# agentx_check_session() in agent/mibgroup/agentx/subagent.c). We
			# can ignore it here since we have a single session only anyway.
			# "clientarg" will be None again (see the registration code below).

			# We only care about SNMPD_CALLBACK_INDEX_STOP as our custom log
			# handler above already took care of all other events.
			if minorID == SNMPD_CALLBACK_INDEX_STOP:
				self._status = netsnmpAgentStatus.RECONNECTING

			return 0

		# Convert it to a C callable function and store its reference
		self._index_stop_callback = SNMPCallback(_py_index_stop_callback)

		# Register it with net-snmp
		if libnsa.snmp_register_callback(
			SNMP_CALLBACK_APPLICATION,
			SNMPD_CALLBACK_INDEX_STOP,
			self._index_stop_callback,
			None
		) != SNMPERR_SUCCESS:
			raise netsnmpAgentException(
				"snmp_register_callback() failed for _netsnmp_index_callback!"
			)

		# No enabling necessary here

		# Make us an AgentX client
		if libnsa.netsnmp_ds_set_boolean(
			NETSNMP_DS_APPLICATION_ID,
			NETSNMP_DS_AGENT_ROLE,
			1
		) != SNMPERR_SUCCESS:
			raise netsnmpAgentException(
				"netsnmp_ds_set_boolean() failed for NETSNMP_DS_AGENT_ROLE!"
			)

		# Use an alternative transport specification to connect to the master?
		# Defaults to "/var/run/agentx/master".
		# (See the "LISTENING ADDRESSES" section in the snmpd(8) manpage)
		if self.MasterSocket:
			if libnsa.netsnmp_ds_set_string(
				NETSNMP_DS_APPLICATION_ID,
				NETSNMP_DS_AGENT_X_SOCKET,
				b(self.MasterSocket)
			) != SNMPERR_SUCCESS:
				raise netsnmpAgentException(
					"netsnmp_ds_set_string() failed for NETSNMP_DS_AGENT_X_SOCKET!"
				)

		# Use an alternative persistence directory?
		if self.PersistenceDir:
			if libnsa.netsnmp_ds_set_string(
				NETSNMP_DS_LIBRARY_ID,
				NETSNMP_DS_LIB_PERSISTENT_DIR,
				b(self.PersistenceDir)
			) != SNMPERR_SUCCESS:
				raise netsnmpAgentException(
					"netsnmp_ds_set_string() failed for NETSNMP_DS_LIB_PERSISTENT_DIR!"
				)

		# Initialize net-snmp library (see netsnmp_agent_api(3))
		if libnsa.init_agent(b(self.AgentName)) != 0:
			raise netsnmpAgentException("init_agent() failed!")

		# Initialize MIB parser
		if self.UseMIBFiles:
			libnsa.netsnmp_init_mib()

		# If MIBFiles were specified (ie. MIBs that can not be found in
		# net-snmp's default MIB directory /usr/share/snmp/mibs), read
		# them in so we can translate OID strings to net-snmp's internal OID
		# format.
		if self.UseMIBFiles and self.MIBFiles:
			for mib in self.MIBFiles:
				if libnsa.read_mib(b(mib)) == 0:
					raise netsnmpAgentException(
						"netsnmp_read_module({0}) failed!".format(mib))

		# Initialize our SNMP object registry
		self._objs = defaultdict(dict)
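
The default AgentName in this constructor is a common idiom: take the running script's path from sys.argv[0], drop the directory with os.path.basename(), then drop the extension with os.path.splitext(). A minimal sketch:

import os
import sys

# '/usr/local/bin/my_agent.py' -> 'my_agent.py' -> ('my_agent', '.py')
agent_name = os.path.splitext(os.path.basename(sys.argv[0]))[0]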

Example 47

Project: hellanzb
Source File: NZBSegmentQueue.py
View license
    def parseNZB(self, nzb, verbose = True):
        """ Initialize the queue from the specified nzb file """
        # Create a parser
        parser = make_parser()
        
        # No XML namespaces here
        parser.setFeature(feature_namespaces, 0)
        parser.setFeature(feature_external_ges, 0)

        # Create the handler
        fileName = nzb.nzbFileName
        self.nzbAdd(nzb)
        needWorkFiles = []
        needWorkSegments = []
        nzbp = NZBParser(nzb, needWorkFiles, needWorkSegments)
        
        # Tell the parser to use it
        parser.setContentHandler(nzbp)

        nzb.calculatingBytes = True
        # Parse the input
        try:
            parser.parse(fileName)
        except SAXParseException, saxpe:
            nzb.calculatingBytes = False
            self.nzbDone(nzb)
            msg = 'Unable to parse invalid NZB file: %s: %s' % \
                (os.path.basename(fileName), saxpe.getException())
            raise FatalError(msg)
        nzb.calculatingBytes = False

        # We trust the NZB XML's <segment number="111"> attribute, but if the sequence of
        # segments does not begin at "1", the parser wouldn't have found the
        # nzbFile.firstSegment
        for needWorkFile in nzbp.needWorkFiles:
            if needWorkFile.firstSegment is None and len(needWorkFile.nzbSegments):
                # Set the firstSegment to the smallest segment number
                sortedSegments = [(nzbSegment.number, nzbSegment) for nzbSegment in \
                                  needWorkFile.nzbSegments]
                sortedSegments.sort()
                needWorkFile.firstSegment = sortedSegments[0][1]
                needWorkFile.firstSegment.priority = NZBSegmentQueue.NZB_CONTENT_P

        s = time.time()
        # The parser will add all the segments of all the NZBFiles that have not already
        # been downloaded. After the parsing, we'll check whether each of those segments
        # has already been downloaded; it's faster to check all segments at one time
        needDlFiles, needDlSegments, onDiskSegments = segmentsNeedDownload(needWorkSegments,
                                                                           overwriteZeroByteSegments = \
                                                                           nzb.overwriteZeroByteFiles)
        e = time.time() - s

        # firstSegmentsDownloaded needs to be tweaked if isSkippedPar and no segments were
        # found on disk by segmentsNeedDownload. i.e. first segments have ALWAYS already
        # been downloaded in isParRecovery mode
        fauxFirstSegmentsDownloaded = 0
        if Hellanzb.SMART_PAR and nzb.isParRecovery:
            for nzbFile in nzb.nzbFiles:
                if nzbFile.isSkippedPar and nzbFile.firstSegment not in onDiskSegments:
                    nzb.firstSegmentsDownloaded += 1
                    fauxFirstSegmentsDownloaded += 1
                    
        # Calculate and print parsed/skipped/queued statistics
        skippedPars = 0
        queuedParBlocks = 0
        for nzbFile in needDlFiles:
            if nzbFile.isSkippedPar:
                skippedPars += 1
            elif nzb.isParRecovery and nzbFile.isExtraPar and \
                    not nzbFile.isSkippedPar and len(nzbFile.todoNzbSegments) and \
                    nzbFile.filename is not None and not isHellaTemp(nzbFile.filename):
                queuedParBlocks += getParSize(nzbFile.filename)

        onDiskBytes = 0
        for nzbSegment in onDiskSegments:
            onDiskBytes += nzbSegment.bytes
        for nzbFile in nzb.nzbFiles:
            if nzbFile not in needDlFiles:
                onDiskBytes += nzbFile.totalBytes
        onDiskFilesCount = nzbp.fileCount - len(needWorkFiles)
        onDiskSegmentsCount = len(onDiskSegments)
        info('Parsed: %i files (%i posts), %s' % (nzbp.fileCount, nzbp.segmentCount,
                                                  prettySize(nzb.totalBytes)))
        if onDiskFilesCount or onDiskSegmentsCount:
            filesMsg = segmentsMsg = separator = ''
            if onDiskFilesCount:
                filesMsg = '%i files' % onDiskFilesCount
            if onDiskSegmentsCount:
                segmentsMsg = '%i segments' % onDiskSegmentsCount
            if onDiskFilesCount and onDiskSegmentsCount:
                separator = ' and '
            info('Skipped (on disk): %s%s%s, %s' % (filesMsg, separator, segmentsMsg,
                                                    prettySize(onDiskBytes)))

        # Tally what was skipped for correct percentages in the UI
        for nzbSegment in onDiskSegments:
            nzbSegment.nzbFile.totalSkippedBytes += nzbSegment.bytes
            nzbSegment.nzbFile.nzb.totalSkippedBytes += nzbSegment.bytes

        # The needWorkFiles will tell us what nzbFiles are missing from the
        # FS. segmentsNeedDownload will further tell us what files need to be
        # downloaded. files missing from the FS (needWorkFiles) but not needing to be
        # downloaded (in needDlFiles) simply need to be assembled
        for nzbFile in needWorkFiles:
            if nzbFile not in needDlFiles:
                # Don't automatically 'finish' the NZB, we'll take care of that in this
                # function if necessary
                if verbose:
                    info(nzbFile.getFilename() + ': Assembling -- all segments were on disk')
                
                # NOTE: this function is destructive to the passed in nzbFile! And is only
                # called on occasion (might bite you in the ass one day)
                try:
                    assembleNZBFile(nzbFile, autoFinish = False)
                except OutOfDiskSpace:
                    self.nzbDone(nzb)
                    # FIXME: Shouldn't exit here
                    error('Cannot assemble ' + nzbFile.getFilename() + ': No space left on device! Exiting..')
                    Hellanzb.Core.shutdown(True)

        for nzbSegment in needDlSegments:
            # smartDequeue called from segmentsNeedDownload would have set
            # isSkippedParFile for us
            if not nzbSegment.nzbFile.isSkippedPar:
                self.put((nzbSegment.priority, nzbSegment))
            else:
                # This would need to be downloaded if we didn't skip the segment, they are
                # officially dequeued, and can be requeued later
                nzbSegment.nzbFile.dequeuedSegments.add(nzbSegment)
                
        # Requeue files in certain situations
        if nzb.firstSegmentsDownloaded == len(nzb.nzbFiles):
            # NOTE: This block of code does not commonly happen with newzbin.com NZBs: due
            # to how the DupeHandler handles .NFO files. newzbin.com seems to always
            # duplicate the .NFO file in their NZBs
            smartRequeue(nzb)
            logSkippedPars(nzb)
                
        if nzb.isParRecovery and nzb.skippedParSubjects and len(nzb.skippedParSubjects) and \
                not len(self):
            # FIXME: This recovering ALL pars should be a mode (with a flag on the NZB
            # object). No par skipping would occur in this mode -- for the incredibly rare
            # case that first segments are lost prior to this mode taking place. What will
            # happen doesn't make sense: hellanzb will say 'recovering ALL pars', then
            # SmartPar will later skip pars
            msg = 'Par recovery download: No pars with prefix: %s -- recovering ALL pars' % \
                nzb.parPrefix
            if skippedPars:
                msg = '%s (%i par files)' % (msg, skippedPars)
            if verbose:
                warn(msg)
            for nzbSegment in needDlSegments:
                if nzbSegment.nzbFile.isSkippedPar:
                    self.put((nzbSegment.priority, nzbSegment))
                    nzbSegment.nzbFile.todoNzbSegments.add(nzbSegment)

            # Only reset the isSkippedPar flag after queueing
            for nzbSegment in needDlSegments:
                if nzbSegment.nzbFile.isSkippedPar:
                    nzbSegment.nzbFile.isSkippedPar = False

            # We might have faked the value of this: reset it
            nzb.firstSegmentsDownloaded -= fauxFirstSegmentsDownloaded
                    
        if not len(self):
            self.nzbDone(nzb)
            if verbose:
                info(nzb.archiveName + ': Assembled archive!')
            
            reactor.callLater(0, Hellanzb.Daemon.handleNZBDone, nzb)

            # True == the archive is complete
            return True

        # Finally tally the size of the queue
        self.calculateTotalQueuedBytes()
        dlMsg = 'Queued: %s' % prettySize(self.totalQueuedBytes)
        if nzb.isParRecovery and queuedParBlocks:
            dlMsg += ' (recovering %i %s)' % (queuedParBlocks, getParRecoveryName(nzb.parType))
        info(dlMsg)

        # Archive not complete
        return False
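
One aside on parseNZB(): the firstSegment fix-up builds and sorts a list of (number, segment) tuples just to find the smallest segment number. The same selection can be written more directly with min() and a key function; a sketch, assuming each segment exposes a numeric number attribute as NZBSegment does:

def smallest_segment(segments):
    # Equivalent to sorting (number, segment) tuples and taking the first
    return min(segments, key=lambda segment: segment.number)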

Example 48

Project: tp-libvirt
Source File: virsh_blockcommit.py
View license
def run(test, params, env):
    """
    Test command: virsh blockcommit <domain> <path>

    1) Prepare test environment.
    2) Commit changes from a snapshot down to its backing image.
    3) Recover test environment.
    4) Check result.
    """

    def make_disk_snapshot(postfix_n):
        # Add all disks into commandline.
        disks = vm.get_disk_devices()

        # Make three external snapshots for disks only
        for count in range(1, 4):
            options = "%s_%s %s%s-desc " % (postfix_n, count,
                                            postfix_n, count)
            options += "--disk-only --atomic --no-metadata"
            if needs_agent:
                options += " --quiesce"

            for disk in disks:
                disk_detail = disks[disk]
                basename = os.path.basename(disk_detail['source'])

                # Remove the original suffix if any, appending
                # ".postfix_n[0-9]"
                diskname = basename.split(".")[0]
                snap_name = "%s.%s%s" % (diskname, postfix_n, count)
                disk_external = os.path.join(tmp_dir, snap_name)

                snapshot_external_disks.append(disk_external)
                options += " %s,snapshot=external,file=%s" % (disk,
                                                              disk_external)

            cmd_result = virsh.snapshot_create_as(vm_name, options,
                                                  ignore_status=True,
                                                  debug=True)
            status = cmd_result.exit_status
            if status != 0:
                raise error.TestFail("Failed to make snapshots for disks!")

            # Create a file flag in VM after each snapshot
            flag_file = tempfile.NamedTemporaryFile(prefix=("snapshot_test_"),
                                                    dir="/tmp")
            file_path = flag_file.name
            flag_file.close()

            status, output = session.cmd_status_output("touch %s" % file_path)
            if status:
                raise error.TestFail("Touch file in vm failed. %s" % output)
            snapshot_flag_files.append(file_path)

    # MAIN TEST CODE ###
    # Process cartesian parameters
    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)
    vm_state = params.get("vm_state", "running")
    needs_agent = "yes" == params.get("needs_agent", "yes")
    replace_vm_disk = "yes" == params.get("replace_vm_disk", "no")
    top_inactive = ("yes" == params.get("top_inactive"))
    with_timeout = ("yes" == params.get("with_timeout_option", "no"))
    status_error = ("yes" == params.get("status_error", "no"))
    base_option = params.get("base_option", "none")
    middle_base = "yes" == params.get("middle_base", "no")
    pivot_opt = "yes" == params.get("pivot_opt", "no")
    snap_in_mirror = "yes" == params.get("snap_in_mirror", "no")
    snap_in_mirror_err = "yes" == params.get("snap_in_mirror_err", "no")
    with_active_commit = "yes" == params.get("with_active_commit", "no")
    multiple_chain = "yes" == params.get("multiple_chain", "no")
    virsh_dargs = {'debug': True}

    # Process domain disk device parameters
    disk_type = params.get("disk_type")
    disk_src_protocol = params.get("disk_source_protocol")
    restart_tgtd = params.get("restart_tgtd", 'no')
    vol_name = params.get("vol_name")
    tmp_dir = data_dir.get_tmp_dir()
    pool_name = params.get("pool_name", "gluster-pool")
    brick_path = os.path.join(tmp_dir, pool_name)

    if not top_inactive:
        if not libvirt_version.version_compare(1, 2, 4):
            raise error.TestNAError("live active block commit is not supported"
                                    " in current libvirt version.")

    # A backup of original vm
    vmxml_backup = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)

    # Abort the test if there are snapshots already
    existing_snaps = virsh.snapshot_list(vm_name)
    if len(existing_snaps) != 0:
        raise error.TestFail("There are snapshots created for %s already" %
                             vm_name)

    snapshot_external_disks = []
    cmd_session = None
    try:
        if disk_src_protocol == 'iscsi' and disk_type == 'network':
            if not libvirt_version.version_compare(1, 0, 4):
                raise error.TestNAError("'iscsi' disk doesn't support in"
                                        " current libvirt version.")

        # Set vm xml and guest agent
        if replace_vm_disk:
            if disk_src_protocol == "rbd" and disk_type == "network":
                src_host = params.get("disk_source_host", "EXAMPLE_HOSTS")
                mon_host = params.get("mon_host", "EXAMPLE_MON_HOST")
                if src_host.count("EXAMPLE") or mon_host.count("EXAMPLE"):
                    raise error.TestNAError("Please provide rbd host first.")
            libvirt.set_vm_disk(vm, params, tmp_dir)

        if needs_agent:
            vm.prepare_guest_agent()

        # The first disk is supposed to include OS
        # We will perform blockcommit operation for it.
        first_disk = vm.get_first_disk_devices()
        blk_source = first_disk['source']
        blk_target = first_disk['target']
        snapshot_flag_files = []

        # get a vm session before snapshot
        session = vm.wait_for_login()
        # do snapshot
        postfix_n = 'snap'
        make_disk_snapshot(postfix_n)

        basename = os.path.basename(blk_source)
        diskname = basename.split(".")[0]
        snap_src_lst = [blk_source]
        if multiple_chain:
            snap_name = "%s.%s1" % (diskname, postfix_n)
            snap_top = os.path.join(tmp_dir, snap_name)
            top_index = snapshot_external_disks.index(snap_top) + 1
            omit_list = snapshot_external_disks[top_index:]
            vm.destroy(gracefully=False)
            vmxml = vm_xml.VMXML.new_from_dumpxml(vm_name)
            disk_xml = vmxml.get_devices(device_type="disk")[0]
            vmxml.del_device(disk_xml)
            disk_dict = {'attrs': {'file': snap_top}}
            disk_xml.source = disk_xml.new_disk_source(**disk_dict)
            vmxml.add_device(disk_xml)
            vmxml.sync()
            vm.start()
            session = vm.wait_for_login()
            postfix_n = 'new_snap'
            make_disk_snapshot(postfix_n)
            snap_src_lst = [blk_source]
            snap_src_lst += snapshot_external_disks
            logging.debug("omit list is %s", omit_list)
            for i in omit_list:
                snap_src_lst.remove(i)
        else:
            # snapshot src file list
            snap_src_lst += snapshot_external_disks
        backing_chain = ''
        for i in reversed(range(4)):
            if i == 0:
                backing_chain += "%s" % snap_src_lst[i]
            else:
                backing_chain += "%s -> " % snap_src_lst[i]

        logging.debug("The backing chain is: %s" % backing_chain)

        # check snapshot disk xml backingStore is expected
        vmxml = vm_xml.VMXML.new_from_dumpxml(vm_name)
        disks = vmxml.devices.by_device_tag('disk')
        disk_xml = None
        for disk in disks:
            if disk.target['dev'] != blk_target:
                continue
            else:
                disk_xml = disk.xmltreefile
                logging.debug("the target disk xml after snapshot is %s",
                              disk_xml)
                break

        if not disk_xml:
            raise error.TestFail("Can't find disk xml with target %s" %
                                 blk_target)
        elif libvirt_version.version_compare(1, 2, 4):
            # backingStore element introduced in 1.2.4
            chain_lst = snap_src_lst[::-1]
            ret = check_chain_xml(disk_xml, chain_lst)
            if not ret:
                raise error.TestFail("Domain image backing chain check failed")

        # set blockcommit_options
        top_image = None
        blockcommit_options = "--wait --verbose"

        if with_timeout:
            blockcommit_options += " --timeout 1"

        if base_option == "shallow":
            blockcommit_options += " --shallow"
        elif base_option == "base":
            if middle_base:
                snap_name = "%s.%s1" % (diskname, postfix_n)
                blk_source = os.path.join(tmp_dir, snap_name)
            blockcommit_options += " --base %s" % blk_source

        if top_inactive:
            snap_name = "%s.%s2" % (diskname, postfix_n)
            top_image = os.path.join(tmp_dir, snap_name)
            blockcommit_options += " --top %s" % top_image
        else:
            blockcommit_options += " --active"
            if pivot_opt:
                blockcommit_options += " --pivot"

        if vm_state == "shut off":
            vm.destroy(gracefully=True)

        if with_active_commit:
            # inactive commit follow active commit will fail with bug 1135339
            cmd = "virsh blockcommit %s %s --active --pivot" % (vm_name,
                                                                blk_target)
            cmd_session = aexpect.ShellSession(cmd)

        # Run test case
        # Active commit does not support on rbd based disk with bug 1200726
        result = virsh.blockcommit(vm_name, blk_target,
                                   blockcommit_options, **virsh_dargs)

        # Check status_error
        libvirt.check_exit_status(result, status_error)
        if result.exit_status and status_error:
            return

        while True:
            vmxml = vm_xml.VMXML.new_from_dumpxml(vm_name)

            disks = vmxml.devices.by_device_tag('disk')
            for disk in disks:
                if disk.target['dev'] != blk_target:
                    continue
                else:
                    disk_xml = disk.xmltreefile
                    break

            if not top_inactive:
                disk_mirror = disk_xml.find('mirror')
                if '--pivot' not in blockcommit_options:
                    if disk_mirror is not None:
                        job_type = disk_mirror.get('job')
                        job_ready = disk_mirror.get('ready')
                        src_element = disk_mirror.find('source')
                        disk_src_file = None
                        for elem in ('file', 'name', 'dev'):
                            elem_val = src_element.get(elem)
                            if elem_val:
                                disk_src_file = elem_val
                                break
                        err_msg = "blockcommit base source "
                        err_msg += "%s not expected" % disk_src_file
                        if '--shallow' in blockcommit_options:
                            if not multiple_chain:
                                if disk_src_file != snap_src_lst[2]:
                                    raise error.TestFail(err_msg)
                            else:
                                if disk_src_file != snap_src_lst[3]:
                                    raise error.TestFail(err_msg)
                        else:
                            if disk_src_file != blk_source:
                                raise error.TestFail(err_msg)
                        if libvirt_version.version_compare(1, 2, 7):
                            # Since libvirt 1.2.7 the job attribute records
                            # which API started the operation.
                            if job_type != 'active-commit':
                                raise error.TestFail("blockcommit job type '%s'"
                                                     " not expected" % job_type)
                            if job_ready != 'yes':
                                # The attribute ready, if present, tracks
                                # progress of the job: yes if the disk is known
                                # to be ready to pivot, or, since 1.2.7, abort
                                # or pivot if the job is in the process of
                                # completing.
                                continue
                            else:
                                logging.debug("after active block commit job "
                                              "ready for pivot, the target disk"
                                              " xml is %s", disk_xml)
                                break
                        else:
                            break
                    else:
                        break
                else:
                    if disk_mirror is None:
                        logging.debug(disk_xml)
                        if "--shallow" in blockcommit_options:
                            chain_lst = snap_src_lst[::-1]
                            chain_lst.pop(0)
                            ret = check_chain_xml(disk_xml, chain_lst)
                            if not ret:
                                raise error.TestFail("Domain image backing "
                                                     "chain check failed")
                        elif "--base" in blockcommit_options:
                            chain_lst = snap_src_lst[::-1]
                            base_index = chain_lst.index(blk_source)
                            chain_lst = chain_lst[base_index:]
                            ret = check_chain_xml(disk_xml, chain_lst)
                            if not ret:
                                raise error.TestFail("Domain image backing "
                                                     "chain check failed")
                        break
                    else:
                        # wait for the pivot once the commit is synced
                        continue
            else:
                logging.debug("after inactive commit the disk xml is: %s"
                              % disk_xml)
                if libvirt_version.version_compare(1, 2, 4):
                    if "--shallow" in blockcommit_options:
                        chain_lst = snap_src_lst[::-1]
                        chain_lst.remove(top_image)
                        ret = check_chain_xml(disk_xml, chain_lst)
                        if not ret:
                            raise error.TestFail("Domain image backing chain "
                                                 "check failed")
                    elif "--base" in blockcommit_options:
                        chain_lst = snap_src_lst[::-1]
                        top_index = chain_lst.index(top_image)
                        base_index = chain_lst.index(blk_source)
                        val_tmp = []
                        for i in range(top_index, base_index):
                            val_tmp.append(chain_lst[i])
                        for i in val_tmp:
                            chain_lst.remove(i)
                        ret = check_chain_xml(disk_xml, chain_lst)
                        if not ret:
                            raise error.TestFail("Domain image backing chain "
                                                 "check failed")
                    break
                else:
                    break

        # Check flag files
        if vm_state != "shut off" and not multiple_chain:
            for flag in snapshot_flag_files:
                status, output = session.cmd_status_output("cat %s" % flag)
                if status:
                    raise error.TestFail("blockcommit failed: %s" % output)

        if not pivot_opt and snap_in_mirror:
            # take a snapshot during the mirror phase
            snap_path = "%s/%s.snap" % (tmp_dir, vm_name)
            snap_opt = "--disk-only --atomic --no-metadata "
            snap_opt += "vda,snapshot=external,file=%s" % snap_path
            snapshot_external_disks.append(snap_path)
            cmd_result = virsh.snapshot_create_as(vm_name, snap_opt,
                                                  ignore_status=True,
                                                  debug=True)
            libvirt.check_exit_status(cmd_result, snap_in_mirror_err)
    finally:
        if vm.is_alive():
            vm.destroy(gracefully=False)
        # Recover xml of vm.
        vmxml_backup.sync("--snapshots-metadata")
        if cmd_session:
            cmd_session.close()
        for disk in snapshot_external_disks:
            if os.path.exists(disk):
                os.remove(disk)

        if disk_src_protocol == 'iscsi':
            libvirt.setup_or_cleanup_iscsi(is_setup=False,
                                           restart_tgtd=restart_tgtd)
        elif disk_src_protocol == 'gluster':
            libvirt.setup_or_cleanup_gluster(False, vol_name, brick_path)
            libvirtd = utils_libvirtd.Libvirtd()
            libvirtd.restart()
        elif disk_src_protocol == 'netfs':
            restore_selinux = params.get('selinux_status_bak')
            libvirt.setup_or_cleanup_nfs(is_setup=False,
                                         restore_selinux=restore_selinux)
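
The backing-chain bookkeeping in this example boils down to rendering the snapshot sources newest-first. A minimal standalone sketch of that rendering, with hypothetical image paths; applying os.path.basename to shorten the log line is a choice of this sketch, not of the test above, which logs full paths:

import os

snap_src_lst = ["/var/lib/libvirt/images/base.qcow2",
                "/tmp/disk.snap1",
                "/tmp/disk.snap2",
                "/tmp/disk.snap3"]
# Newest image first, base image last, mirroring the test's debug output.
backing_chain = " -> ".join(os.path.basename(src)
                            for src in reversed(snap_src_lst))
print(backing_chain)  # disk.snap3 -> disk.snap2 -> disk.snap1 -> base.qcow2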

Example 49

View license
def run(test, params, env):
    """
    Test virsh migrate when disks are virtio-scsi.
    """

    def check_vm_state(vm, state):
        """
        Return True if vm is in the correct state.
        """
        try:
            actual_state = vm.state()
        except process.CmdError:
            return False
        return actual_state == state

    def check_disks_in_vm(vm, vm_ip, disks_list=None, runner=None):
        """
        Check disks attached to vm.
        """
        # Avoid a mutable default argument; the list is consumed below.
        disks_list = disks_list or []
        fail_list = []
        while len(disks_list):
            disk = disks_list.pop()
            if runner:
                check_cmd = ("ssh %s \"dd if=/dev/urandom of=%s bs=1 "
                             "count=1024\"" % (vm_ip, disk))
                try:
                    logging.debug(runner.run(check_cmd))
                    continue
                except process.CmdError, detail:
                    logging.debug("Remote checking failed:%s", detail)
                    fail_list.append(disk)
            else:
                check_cmd = "dd if=/dev/urandom of=%s bs=1 count=1024"
                session = vm.wait_for_login()
                cs = session.cmd_status(check_cmd)
                if cs:
                    fail_list.append(disk)
                session.close()
        if len(fail_list):
            raise error.TestFail("Checking attached devices failed:%s"
                                 % fail_list)

    def get_disk_id(device):
        """
        Show disk by id.
        """
        output = process.run("ls /dev/disk/by-id/", shell=True).stdout
        for line in output.splitlines():
            disk_ids = line.split()
            for disk_id in disk_ids:
                disk = os.path.basename(
                    process.run("readlink %s" % disk_id, shell=True).stdout)
                if disk == os.path.basename(device):
                    return disk_id
        return None

    def cleanup_ssh_config(vm):
        session = vm.wait_for_login()
        session.cmd("rm -f ~/.ssh/authorized_keys")
        session.cmd("rm -f ~/.ssh/id_rsa*")
        session.close()

    vm = env.get_vm(params.get("migrate_main_vm"))
    source_type = params.get("disk_source_type", "file")
    device_type = params.get("disk_device_type", "disk")
    disk_format = params.get("disk_format_type", "raw")
    if source_type == "file":
        params['added_disk_type'] = "file"
    else:
        params['added_disk_type'] = "block"
        block_device = params.get("disk_block_device", "/dev/EXAMPLE")
        if block_device.count("EXAMPLE"):
            # Prepare host parameters
            local_host = params.get("migrate_source_host", "LOCAL.EXAMPLE")
            remote_host = params.get("migrate_dest_host", "REMOTE.EXAMPLE")
            remote_user = params.get("migrate_dest_user", "root")
            remote_passwd = params.get("migrate_dest_pwd")
            if remote_host.count("EXAMPLE") or local_host.count("EXAMPLE"):
                raise error.TestNAError("Config remote or local host first.")
            rdm_params = {'remote_ip': remote_host,
                          'remote_user': remote_user,
                          'remote_pwd': remote_passwd}
            rdm = utils_test.RemoteDiskManager(rdm_params)
            # Try to build an iscsi device
            # For local, target is a device name
            target = utlv.setup_or_cleanup_iscsi(is_setup=True, is_login=True,
                                                 emulated_image="emulated-iscsi")
            logging.debug("Created target: %s", target)
            try:
                # Attach this iscsi device both local and remote
                remote_device = rdm.iscsi_login_setup(local_host, target)
            except Exception, detail:
                utlv.setup_or_cleanup_iscsi(is_setup=False)
                raise error.TestError("Attach iscsi device on remote failed:%s"
                                      % detail)

            # Use id to get same path on local and remote
            block_device = get_disk_id(target)
            if block_device is None:
                rdm.iscsi_login_setup(local_host, target, is_login=False)
                utlv.setup_or_cleanup_iscsi(is_setup=False)
                raise error.TestError("Set iscsi device couldn't find id?")

    srcuri = params.get("virsh_migrate_srcuri")
    dsturi = params.get("virsh_migrate_dsturi")
    remote_ip = params.get("remote_ip")
    username = params.get("remote_user", "root")
    host_pwd = params.get("remote_pwd")
    # Connection to remote, init here for cleanup
    runner = None
    # Identify easy configuration mistakes early
    warning_text = ("Migration VM %s URI %s appears problematic; "
                    "this may lead to migration problems. "
                    "Consider specifying vm.connect_uri using a "
                    "fully-qualified network-based style.")

    if srcuri.count('///') or srcuri.count('EXAMPLE'):
        raise error.TestNAError(warning_text % ('source', srcuri))

    if dsturi.count('///') or dsturi.count('EXAMPLE'):
        raise error.TestNAError(warning_text % ('destination', dsturi))

    # Config auto-login to remote host for migration
    ssh_key.setup_ssh_key(remote_ip, username, host_pwd)

    sys_image = vm.get_first_disk_devices()
    sys_image_source = sys_image["source"]
    sys_image_info = utils_misc.get_image_info(sys_image_source)
    logging.debug("System image information:\n%s", sys_image_info)
    sys_image_fmt = sys_image_info["format"]
    created_img_path = os.path.join(os.path.dirname(sys_image_source),
                                    "vsmimages")

    migrate_in_advance = "yes" == params.get("migrate_in_advance", "no")

    status_error = "yes" == params.get("status_error", "no")
    if source_type == "file" and device_type == "lun":
        status_error = True

    try:
        # For safety and ease of cleanup, define a new vm for the test
        new_vm_name = "%s_vsmtest" % vm.name
        mig = utlv.MigrationTest()
        if vm.is_alive():
            vm.destroy()
        utlv.define_new_vm(vm.name, new_vm_name)
        vm = libvirt_vm.VM(new_vm_name, vm.params, vm.root_dir,
                           vm.address_cache)

        # Change the vm's disks to shared disks
        # Detach existing devices first
        devices = vm.get_blk_devices()
        for device in devices:
            s_detach = virsh.detach_disk(vm.name, device, "--config",
                                         debug=True)
            if not s_detach:
                raise error.TestError("Detach %s failed before test.", device)

        # Attach system image as vda
        # Then added scsi disks will be sda,sdb...
        attach_args = "--subdriver %s --config" % sys_image_fmt
        virsh.attach_disk(vm.name, sys_image_source, "vda",
                          attach_args, debug=True)

        vms = [vm]

        def start_check_vm(vm):
            try:
                vm.start()
            except virt_vm.VMStartError, detail:
                if status_error:
                    logging.debug("Expected failure:%s", detail)
                    return None, None
                else:
                    raise
            vm.wait_for_login()

            # Confirm the VM can be accessed over the network;
            # this ip will also be used on the remote host after migration
            vm_ip = vm.get_address()
            vm_pwd = params.get("password")
            s_ping, o_ping = utils_test.ping(vm_ip, count=2, timeout=60)
            logging.info(o_ping)
            if s_ping != 0:
                raise error.TestFail("%s did not respond after several "
                                     "seconds with attaching new devices."
                                     % vm.name)
            return vm_ip, vm_pwd

        options = "--live --unsafe"
        # Do migration before attaching new devices
        if migrate_in_advance:
            vm_ip, vm_pwd = start_check_vm(vm)
            cleanup_ssh_config(vm)
            mig_thread = threading.Thread(target=mig.thread_func_migration,
                                          args=(vm, dsturi, options))
            mig_thread.start()
            # Make sure migration is running
            time.sleep(2)

        # Attach other disks
        params['added_disk_target'] = "scsi"
        params['target_bus'] = "scsi"
        params['device_type'] = device_type
        params['type_name'] = source_type
        params['added_disk_format'] = disk_format
        if migrate_in_advance:
            params["attach_disk_config"] = "no"
            attach_disk_config = False
        else:
            params["attach_disk_config"] = "yes"
            attach_disk_config = True
        try:
            if source_type == "file":
                utlv.attach_disks(vm, "%s/image" % created_img_path,
                                  None, params)
            else:
                ret = utlv.attach_additional_device(vm.name, "sda", block_device,
                                                    params, config=attach_disk_config)
                if ret.exit_status:
                    raise error.TestFail(ret)
        except (error.TestFail, process.CmdError), detail:
            if status_error:
                logging.debug("Expected failure:%s", detail)
                return
            else:
                raise

        if migrate_in_advance:
            mig_thread.join(60)
            if mig_thread.isAlive():
                mig.RET_LOCK.acquire()
                mig.MIGRATION = False
                mig.RET_LOCK.release()
        else:
            vm_ip, vm_pwd = start_check_vm(vm)

        # Got the expected failure when starting the vm; end the test
        if vm_ip is None and status_error:
            return

        # Start checking before migration and go on checking after migration
        disks = []
        for target in vm.get_disk_devices().keys():
            if target != "vda":
                disks.append("/dev/%s" % target)

        checked_count = int(params.get("checked_count", 0))
        disks_before = disks[:(checked_count / 2)]
        disks_after = disks[(checked_count / 2):checked_count]
        logging.debug("Disks to be checked:\nBefore migration:%s\n"
                      "After migration:%s", disks_before, disks_after)

        options = "--live --unsafe"
        if not migrate_in_advance:
            cleanup_ssh_config(vm)
            mig.do_migration(vms, None, dsturi, "orderly", options, 120)

        if mig.RET_MIGRATION:
            utils_test.check_dest_vm_network(vm, vm_ip, remote_ip,
                                             username, host_pwd)
            runner = remote.RemoteRunner(host=remote_ip, username=username,
                                         password=host_pwd)
            # After migration, config autologin to vm
            ssh_key.setup_remote_ssh_key(vm_ip, "root", vm_pwd)
            check_disks_in_vm(vm, vm_ip, disks_after, runner)

            if migrate_in_advance:
                raise error.TestFail("Migration before attaching successfully, "
                                     "but not expected.")

    finally:
        # Cleanup remote vm
        if srcuri != dsturi:
            mig.cleanup_dest_vm(vm, srcuri, dsturi)
        # Cleanup created vm anyway
        if vm.is_alive():
            vm.destroy(gracefully=False)
        virsh.undefine(new_vm_name)

        # Clean up the iscsi device for block disks if necessary
        if source_type == "block":
            if params.get("disk_block_device",
                          "/dev/EXAMPLE").count("EXAMPLE"):
                rdm.iscsi_login_setup(local_host, target, is_login=False)
                utlv.setup_or_cleanup_iscsi(is_setup=False,
                                            emulated_image="emulated-iscsi")

        if runner:
            runner.session.close()
        process.run("rm -f %s/*vsmtest" % created_img_path, shell=True)

Example 50

View license
def run(test, params, env):
    """
    Test command: virsh schedinfo.

    This version provides a base test of the virsh schedinfo command:
    virsh schedinfo <vm> [--set <set_ref>]
    TODO: support more parameters.

    1) Get parameters and prepare vm's state
    2) Prepare test options.
    3) Run schedinfo command to set or get parameters.
    4) Get schedinfo in cgroup
    5) Recover environment like vm's state
    6) Check result.
    """
    def get_parameter_in_cgroup(vm, cgroup_type, parameter):
        """
        Get vm's cgroup value.

        :param vm: the vm object
        :param cgroup_type: type of cgroup we want, vcpu or emulator.
        :param parameter: the cgroup parameter of the vm which we need to get.
        :return: the stripped value read from the matching cgroup file.
        """
        cgroup_path = \
            utils_cgroup.resolve_task_cgroup_path(vm.get_pid(), "cpu")

        if cgroup_type != "emulator":
            # When a VM has an 'emulator' child cgroup present, we must
            # strip off that suffix when detecting the cgroup for a machine
            if os.path.basename(cgroup_path) == "emulator":
                cgroup_path = os.path.dirname(cgroup_path)
        cgroup_file = os.path.join(cgroup_path, parameter)

        cg_file = None
        try:
            try:
                cg_file = open(cgroup_file)
                result = cg_file.read()
            except IOError:
                raise error.TestError("Failed to open cgroup file %s"
                                      % cgroup_file)
        finally:
            if cg_file is not None:
                cg_file.close()
        return result.strip()

    def schedinfo_output_analyse(result, set_ref, scheduler="posix"):
        """
        Get the value of set_ref.

        :param result: CmdResult struct
        :param set_ref: the parameter has been set
        :param scheduler: the scheduler of qemu (default: posix)
        """
        output = result.stdout.strip()
        if not re.search("Scheduler", output):
            raise error.TestFail("Output is not standard:\n%s" % output)

        result_lines = output.splitlines()
        set_value = None
        for line in result_lines:
            key_value = line.split(":")
            key = key_value[0].strip()
            value = key_value[1].strip()
            if key == "Scheduler":
                if value != scheduler:
                    raise error.TestNAError("This test do not support"
                                            " %s scheduler." % scheduler)
            elif key == set_ref:
                set_value = value
                break
        return set_value

    # Prepare test options
    vm_ref = params.get("schedinfo_vm_ref", "domname")
    options_ref = params.get("schedinfo_options_ref", "")
    options_suffix = params.get("schedinfo_options_suffix", "")
    schedinfo_param = params.get("schedinfo_param", "vcpu")
    set_ref = params.get("schedinfo_set_ref", "")
    cgroup_ref = params.get("schedinfo_cgroup_ref", "cpu.shares")
    set_value = params.get("schedinfo_set_value", "")
    set_method = params.get("schedinfo_set_method", "cmd")
    set_value_expected = params.get("schedinfo_set_value_expected", "")
    # The default scheduler on qemu/kvm is posix
    scheduler_value = "posix"
    status_error = params.get("status_error", "no")

    # Prepare vm test environment
    vm_name = params.get("main_vm")

    if set_ref == "none":
        options_ref = "--set"
        set_ref = None
    elif set_ref:
        if set_method == 'cmd':
            if set_value:
                options_ref = "--set %s=%s" % (set_ref, set_value)
            else:
                options_ref = "--set %s" % set_ref
        elif set_method == 'xml':
            xml = vm_xml.VMXML.new_from_dumpxml(vm_name)
            try:
                cputune = xml.cputune
            except xcepts.LibvirtXMLNotFoundError:
                cputune = vm_xml.VMCPUTuneXML()
            name_map = {
                'cpu_shares': 'shares',
                'vcpu_period': 'period',
                'vcpu_quota': 'quota',
                'emulator_period': 'emulator_period',
                'emulator_quota': 'emulator_quota',
            }
            cputune[name_map[set_ref]] = int(set_value)
            xml.cputune = cputune
            xml.sync()

    vm = env.get_vm(vm_name)
    if vm.is_dead():
        vm.start()
    domid = vm.get_id()
    domuuid = vm.get_uuid()

    if vm_ref == "domid":
        vm_ref = domid
    elif vm_ref == "domname":
        vm_ref = vm_name
    elif vm_ref == "domuuid":
        vm_ref = domuuid
    elif vm_ref == "hex_id":
        if domid == '-':
            vm_ref = domid
        else:
            vm_ref = hex(int(domid))

    options_ref += " %s " % options_suffix

    # Run command
    result = virsh.schedinfo(vm_ref, options_ref,
                             ignore_status=True, debug=True)
    status = result.exit_status

    # VM must be running to get cgroup parameters.
    if not vm.is_alive():
        vm.start()

    if options_ref.count("config"):
        vm.destroy()
        vm.start()

    set_value_of_cgroup = get_parameter_in_cgroup(vm, cgroup_type=schedinfo_param,
                                                  parameter=cgroup_ref)
    vm.destroy()

    if set_ref:
        set_value_of_output = schedinfo_output_analyse(result, set_ref,
                                                       scheduler_value)

    # Check result
    if status_error == "no":
        if status:
            raise error.TestFail("Run failed with right command.")
        else:
            if set_ref and set_value_expected:
                logging.info("value will be set:%s\n"
                             "set value in output:%s\n"
                             "set value in cgroup:%s\n"
                             "expected value:%s" % (
                                 set_value, set_value_of_output,
                                 set_value_of_cgroup, set_value_expected))
                if set_value_of_output is None:
                    raise error.TestFail("Get parameter %s failed." % set_ref)
                # The value in the output of virsh schedinfo is not guaranteed
                # to be 'correct' when --config is used. This is my attempt to
                # fix it:
                # http://www.redhat.com/archives/libvir-list/2014-May/msg00466.html
                # but the patch did not make it into upstream libvirt. Libvirt
                # only guarantees that the value is correct after the next boot
                # when --config is used, so skip checking the output in this case.
                if (not (set_value_expected == set_value_of_output) and
                        not (options_ref.count("config"))):
                    raise error.TestFail("Run successful but value "
                                         "in output is not expected.")
                if not (set_value_expected == set_value_of_cgroup):
                    raise error.TestFail("Run successful but value "
                                         "in cgroup is not expected.")
    else:
        if not status:
            raise error.TestFail("Run successfully with wrong command.")