glob.glob

Here are examples of the Python API glob.glob, taken from open-source projects. By voting up, you can indicate which examples are most useful and appropriate.

200 Examples 7

Example 1

View license
def run(cpac_outdir, outfile_name, qap_type, session_format):
    """Collect CPAC output file paths into a QAP subject-list YAML file.

    Walks ``cpac_outdir`` (one sub-folder per subject), locates the
    resources relevant for the requested QA type, and writes a nested
    ``{subject: {session: {resource: {scan: path}}}}`` mapping to
    ``<outfile_name>.yml`` in the current working directory.

    Parameters
    ----------
    cpac_outdir : str
        Path to the CPAC pipeline output directory.
    outfile_name : str
        Basename (without extension) of the YAML file to write.
    qap_type : str
        Either ``"anat"`` or ``"func"`` - selects which resources to gather.
    session_format : int
        1 = ``sub_id/session_id/scan_id/...`` folder structure,
        2 = no session level in the folder structure,
        3 = session ID embedded in the subject folder name (``sub_session``).
    """

    import os
    import glob
    import yaml

    if qap_type == "anat":
        outputs = ["anatomical_reorient", "anatomical_csf_mask",
                   "anatomical_gm_mask", "anatomical_wm_mask",
                   "anatomical_to_mni_linear_xfm"]
    elif qap_type == "func":
        outputs = ["mean_functional", "functional_brain_mask",
                   "motion_correct", "coordinate_transformation"]
    else:
        # Previously an unknown qap_type fell through silently and crashed
        # later with a NameError on "outputs" - fail fast instead.
        raise ValueError("qap_type must be 'anat' or 'func', got %r"
                         % qap_type)

    outputs_dict = {}

    for sub_dir in os.listdir(cpac_outdir):

        if not os.path.isdir(os.path.join(cpac_outdir, sub_dir)):
            continue

        sessions = []

        # if the folder structure is sub_id/session_id/scan_id/...
        if session_format == 1:
            for session in os.listdir(os.path.join(cpac_outdir, sub_dir)):
                if os.path.isdir(os.path.join(cpac_outdir, sub_dir, session)):
                    sessions.append(session)

        # if there is no session label in the folder structure
        if session_format == 2:
            # since there is no session, let's assign one
            sessions = ["session_1"]

        # if the session is embedded in the subject ID
        if session_format == 3:
            subject_session = sub_dir

            if "_" not in sub_dir:
                err = "\n\n[!] You said your CPAC output directory had the " \
                      "session IDs embedded in the subject IDs, but it " \
                      "doesn't seem that way for subject ID %s!\n\nIs it " \
                      " formatted like this?   ../pipeline_output/subject_" \
                      "session/output/..\n\nIf not, you're using the wrong " \
                      "option for session_format! Use the -h flag to see " \
                      "the documentation.\n\n%s not being included in the " \
                      "subject list.\n\n" % (sub_dir, sub_dir)
                print(err)
                continue

            session_id = sub_dir.split("_", 1)[1]
            sub_dir = sub_dir.split("_", 1)[0]
            sessions = [session_id]

        for session in sessions:

            for resource in outputs:

                resource_path = ""

                if session_format == 1:
                    resource_folder = os.path.join(cpac_outdir, sub_dir,
                                                   session, resource)
                elif session_format == 2:
                    resource_folder = os.path.join(cpac_outdir, sub_dir,
                                                   resource)
                elif session_format == 3:
                    resource_folder = os.path.join(cpac_outdir,
                                                   subject_session,
                                                   resource)

                # if this current resource/output does not exist for this
                # subject, go to next resource in list
                if not os.path.isdir(resource_folder):
                    continue

                if qap_type == "anat":
                    # until CPAC writes multiple anat scans in the
                    # output folder structure
                    scans = ["anat_1"]

                if qap_type == "func":
                    # one scan per sub-directory; strip the CPAC naming
                    # decoration to recover the bare scan ID
                    scans = []
                    for item in os.listdir(resource_folder):
                        if os.path.isdir(os.path.join(resource_folder, item)):
                            item = item.replace("_scan_", "")
                            item = item.replace("_rest", "")
                            scans.append(item)

                for scan in scans:

                    if qap_type == "anat":

                        # mask outputs live one directory level deeper
                        if "mask" in resource:
                            resource_paths = glob.glob(
                                os.path.join(resource_folder, "*", "*"))
                        else:
                            resource_paths = glob.glob(
                                os.path.join(resource_folder, "*"))

                        if len(resource_paths) == 1:
                            resource_path = resource_paths[0]
                        else:
                            print("\nMultiple files for %s for subject %s!!"
                                  % (resource, sub_dir))
                            print("Check the directory: %s"
                                  % resource_folder)
                            print("%s for %s has not been included in the "
                                  "subject list.\n" % (resource, sub_dir))
                            continue

                    if qap_type == "func":

                        fullscan = "_scan_" + scan + "_rest"

                        resource_paths = glob.glob(
                            os.path.join(resource_folder, fullscan, "*"))

                        if len(resource_paths) == 1:
                            resource_path = resource_paths[0]
                        else:
                            print("\nMultiple files for %s for subject %s!!"
                                  % (resource, sub_dir))
                            print("Check the directory: %s"
                                  % resource_folder)
                            print("%s for %s has not been included in the "
                                  "subject list.\n" % (resource, sub_dir))
                            continue

                    # build up the nested subject -> session -> resource ->
                    # scan mapping, never overwriting an existing entry
                    if sub_dir not in outputs_dict:
                        outputs_dict[sub_dir] = {}

                    if session not in outputs_dict[sub_dir]:
                        outputs_dict[sub_dir][session] = {}

                    if resource not in outputs_dict[sub_dir][session]:
                        outputs_dict[sub_dir][session][resource] = {}

                    if scan not in outputs_dict[sub_dir][session][resource]:
                        outputs_dict[sub_dir][session][resource][scan] = \
                            resource_path

    # make up for QAP - CPAC resource naming discrepancy
    # NOTE: iterate over snapshots (list(...)) because keys are added and
    # deleted inside the loop - mutating a live keys() view raises a
    # RuntimeError on Python 3.
    for subid in list(outputs_dict.keys()):

        for session in list(outputs_dict[subid].keys()):

            for resource in list(outputs_dict[subid][session].keys()):

                if resource == "motion_correct":
                    value = outputs_dict[subid][session]["motion_correct"]
                    outputs_dict[subid][session]["func_motion_correct"] = \
                        value
                    del outputs_dict[subid][session]["motion_correct"]

                if resource == "anatomical_to_mni_linear_xfm":
                    value = outputs_dict[subid][session][
                        "anatomical_to_mni_linear_xfm"]
                    outputs_dict[subid][session]["flirt_affine_xfm"] = \
                        value
                    del outputs_dict[subid][session][
                        "anatomical_to_mni_linear_xfm"]

    outfile = os.path.join(os.getcwd(), outfile_name + ".yml")

    with open(outfile, 'w') as f:
        f.write(yaml.dump(outputs_dict, default_flow_style=True))

Example 2

Project: auto-sklearn
Source File: ensemble_builder.py
View license
    def main(self):
        """Main ensemble-building loop.

        Repeatedly scans the prediction directories written by the
        optimization processes, scores each model's ensemble-set
        predictions, selects candidate models (optionally only the
        ``self.ensemble_nbest`` best ones, falling back to the
        random-guessing models if nothing scores above chance), fits an
        ``EnsembleSelection`` on them, and persists the ensemble plus its
        valid/test predictions.  Runs until ``self.limit`` seconds are used
        up or ``self.max_iterations`` iterations have been performed.
        """

        watch = StopWatch()
        watch.start_task('ensemble_builder')

        used_time = 0
        time_iter = 0
        index_run = 0
        num_iteration = 0
        current_num_models = 0
        last_hash = None
        current_hash = None

        dir_ensemble = os.path.join(self.backend.temporary_directory,
                                    '.auto-sklearn',
                                    'predictions_ensemble')
        dir_valid = os.path.join(self.backend.temporary_directory,
                                 '.auto-sklearn',
                                 'predictions_valid')
        dir_test = os.path.join(self.backend.temporary_directory,
                                '.auto-sklearn',
                                'predictions_test')
        paths_ = [dir_ensemble, dir_valid, dir_test]

        dir_ensemble_list_mtimes = []

        self.logger.debug('Starting main loop with %f seconds and %d iterations '
                          'left.' % (self.limit - used_time, num_iteration))
        while used_time < self.limit or (self.max_iterations > 0 and
                                         self.max_iterations >= num_iteration):
            num_iteration += 1
            self.logger.debug('Time left: %f', self.limit - used_time)
            self.logger.debug('Time last ensemble building: %f', time_iter)

            # Reload the ensemble targets every iteration, important, because cv may
            # update the ensemble targets in the cause of running auto-sklearn
            # TODO update cv in order to not need this any more!
            targets_ensemble = self.backend.load_targets_ensemble()

            # Load the predictions from the models
            exists = [os.path.isdir(dir_) for dir_ in paths_]
            if not exists[0]:  # all(exists):
                self.logger.debug('Prediction directory %s does not exist!' %
                              dir_ensemble)
                time.sleep(2)
                used_time = watch.wall_elapsed('ensemble_builder')
                continue

            if self.shared_mode is False:
                dir_ensemble_list = sorted(glob.glob(os.path.join(
                    dir_ensemble, 'predictions_ensemble_%s_*.npy' % self.seed)))
                if exists[1]:
                    dir_valid_list = sorted(glob.glob(os.path.join(
                        dir_valid, 'predictions_valid_%s_*.npy' % self.seed)))
                else:
                    dir_valid_list = []
                if exists[2]:
                    dir_test_list = sorted(glob.glob(os.path.join(
                        dir_test, 'predictions_test_%s_*.npy' % self.seed)))
                else:
                    dir_test_list = []
            else:
                dir_ensemble_list = sorted(os.listdir(dir_ensemble))
                dir_valid_list = sorted(os.listdir(dir_valid)) if exists[1] else []
                dir_test_list = sorted(os.listdir(dir_test)) if exists[2] else []

            # Check the modification times because predictions can be updated
            # over time!
            old_dir_ensemble_list_mtimes = dir_ensemble_list_mtimes
            dir_ensemble_list_mtimes = []
            # The ensemble dir can contain non-model files. We filter them and
            # use the following list instead
            dir_ensemble_model_files = []

            for dir_ensemble_file in dir_ensemble_list:
                if dir_ensemble_file.endswith("/"):
                    dir_ensemble_file = dir_ensemble_file[:-1]
                if not dir_ensemble_file.endswith(".npy"):
                    self.logger.warning('Error loading file (not .npy): %s', dir_ensemble_file)
                    continue

                dir_ensemble_model_files.append(dir_ensemble_file)
                basename = os.path.basename(dir_ensemble_file)
                dir_ensemble_file = os.path.join(dir_ensemble, basename)
                mtime = os.path.getmtime(dir_ensemble_file)
                dir_ensemble_list_mtimes.append(mtime)

            if len(dir_ensemble_model_files) == 0:
                self.logger.debug('Directories are empty')
                time.sleep(2)
                used_time = watch.wall_elapsed('ensemble_builder')
                continue

            if len(dir_ensemble_model_files) <= current_num_models and \
                    old_dir_ensemble_list_mtimes == dir_ensemble_list_mtimes:
                self.logger.debug('Nothing has changed since the last time')
                time.sleep(2)
                used_time = watch.wall_elapsed('ensemble_builder')
                continue

            with warnings.catch_warnings():
                warnings.simplefilter('ignore')
                # TODO restructure time management in the ensemble builder,
                # what is the time of index_run actually needed for?
                watch.start_task('index_run' + str(index_run))
            watch.start_task('ensemble_iter_' + str(num_iteration))

            # List of num_runs (which are in the filename) which will be included
            #  later
            include_num_runs = []
            backup_num_runs = []
            model_and_automl_re = re.compile(r'_([0-9]*)_([0-9]*)\.npy$')
            if self.ensemble_nbest is not None:
                # Keeps track of the single scores of each model in our ensemble
                scores_nbest = []
                # The indices of the model that are currently in our ensemble
                indices_nbest = []
                # The names of the models
                model_names = []

            model_names_to_scores = dict()

            model_idx = 0
            for model_name in dir_ensemble_model_files:
                if model_name.endswith("/"):
                    model_name = model_name[:-1]
                basename = os.path.basename(model_name)

                try:
                    # BUG FIX: these comparisons previously used "is", which
                    # tests object identity, not string equality - whether
                    # they matched depended on interning, so the requested
                    # precision could be silently ignored.
                    if self.precision == "16":
                        predictions = np.load(os.path.join(dir_ensemble, basename)).astype(dtype=np.float16)
                    elif self.precision == "32":
                        predictions = np.load(os.path.join(dir_ensemble, basename)).astype(dtype=np.float32)
                    elif self.precision == "64":
                        predictions = np.load(os.path.join(dir_ensemble, basename)).astype(dtype=np.float64)
                    else:
                        predictions = np.load(os.path.join(dir_ensemble, basename))

                    score = calculate_score(targets_ensemble, predictions,
                                            self.task_type, self.metric,
                                            predictions.shape[1])

                except Exception as e:
                    self.logger.warning('Error loading %s: %s - %s',
                                        basename, type(e), e)
                    score = -1

                model_names_to_scores[model_name] = score
                match = model_and_automl_re.search(model_name)
                if match is None:
                    # File name does not end in _<seed>_<num_run>.npy;
                    # previously this crashed with an AttributeError.
                    self.logger.warning('Could not parse model file name: %s',
                                        model_name)
                    model_idx += 1
                    continue
                automl_seed = int(match.group(1))
                num_run = int(match.group(2))

                if self.ensemble_nbest is not None:
                    if score <= 0.001:
                        self.logger.info('Model only predicts at random: ' +
                                         model_name + ' has score: ' + str(score))
                        backup_num_runs.append((automl_seed, num_run))
                    # If we have less models in our ensemble than ensemble_nbest add
                    # the current model if it is better than random
                    elif len(scores_nbest) < self.ensemble_nbest:
                        scores_nbest.append(score)
                        indices_nbest.append(model_idx)
                        include_num_runs.append((automl_seed, num_run))
                        model_names.append(model_name)
                    else:
                        # Take the worst performing model in our ensemble so far
                        idx = np.argmin(np.array([scores_nbest]))

                        # If the current model is better than the worst model in
                        # our ensemble replace it by the current model
                        if scores_nbest[idx] < score:
                            self.logger.info(
                                'Worst model in our ensemble: %s with score %f '
                                'will be replaced by model %s with score %f',
                                model_names[idx], scores_nbest[idx], model_name,
                                score)
                            # Exclude the old model
                            del scores_nbest[idx]
                            scores_nbest.append(score)
                            del include_num_runs[idx]
                            del indices_nbest[idx]
                            indices_nbest.append(model_idx)
                            include_num_runs.append((automl_seed, num_run))
                            del model_names[idx]
                            model_names.append(model_name)

                        # Otherwise exclude the current model from the ensemble
                        else:
                            # include_num_runs.append(True)
                            pass

                else:
                    # Load all predictions that are better than random
                    if score <= 0.001:
                        # include_num_runs.append(True)
                        self.logger.info('Model only predicts at random: ' +
                                         model_name + ' has score: ' +
                                         str(score))
                        backup_num_runs.append((automl_seed, num_run))
                    else:
                        include_num_runs.append((automl_seed, num_run))

                model_idx += 1

            # If there is no model better than random guessing, we have to use
            # all models which do random guessing
            if len(include_num_runs) == 0:
                include_num_runs = backup_num_runs

            indices_to_model_names = dict()
            indices_to_run_num = dict()
            for i, model_name in enumerate(dir_ensemble_model_files):
                match = model_and_automl_re.search(model_name)
                if match is None:
                    # Unparseable names were already logged above.
                    continue
                automl_seed = int(match.group(1))
                num_run = int(match.group(2))
                if (automl_seed, num_run) in include_num_runs:
                    num_indices = len(indices_to_model_names)
                    indices_to_model_names[num_indices] = model_name
                    indices_to_run_num[num_indices] = (automl_seed, num_run)

            try:
                all_predictions_train, all_predictions_valid, all_predictions_test =\
                    self.get_all_predictions(dir_ensemble,
                                             dir_ensemble_model_files,
                                             dir_valid, dir_valid_list,
                                             dir_test, dir_test_list,
                                             include_num_runs,
                                             model_and_automl_re,
                                             self.precision)
            except IOError:
                self.logger.error('Could not load the predictions.')
                continue

            if len(include_num_runs) == 0:
                self.logger.error('All models do just random guessing')
                time.sleep(2)
                continue

            else:
                ensemble = EnsembleSelection(ensemble_size=self.ensemble_size,
                                             task_type=self.task_type,
                                             metric=self.metric)

                try:
                    ensemble.fit(all_predictions_train, targets_ensemble,
                                 include_num_runs)
                    self.logger.info(ensemble)

                except ValueError as e:
                    self.logger.error('Caught ValueError: ' + str(e))
                    used_time = watch.wall_elapsed('ensemble_builder')
                    time.sleep(2)
                    continue
                except IndexError as e:
                    self.logger.error('Caught IndexError: ' + str(e))
                    used_time = watch.wall_elapsed('ensemble_builder')
                    time.sleep(2)
                    continue
                except Exception as e:
                    self.logger.error('Caught error! %s', str(e))
                    used_time = watch.wall_elapsed('ensemble_builder')
                    time.sleep(2)
                    continue

                # Output the score
                self.logger.info('Training performance: %f' % ensemble.train_score_)

                self.logger.info('Building the ensemble took %f seconds' %
                            watch.wall_elapsed('ensemble_iter_' + str(num_iteration)))

            # Set this variable here to avoid re-running the ensemble builder
            # every two seconds in case the ensemble did not change
            current_num_models = len(dir_ensemble_model_files)

            ensemble_predictions = ensemble.predict(all_predictions_train)
            if sys.version_info[0] == 2:
                # Python 2 cannot hash a writable buffer
                ensemble_predictions.flags.writeable = False
                current_hash = hash(ensemble_predictions.data)
            else:
                current_hash = hash(ensemble_predictions.data.tobytes())

            # Only output a new ensemble and new predictions if the output of the
            # ensemble would actually change!
            # TODO this is neither safe (collisions, tests only with the ensemble
            #  prediction, but not the ensemble), implement a hash function for
            # each possible ensemble builder.
            if last_hash is not None:
                if current_hash == last_hash:
                    self.logger.info('Ensemble output did not change.')
                    time.sleep(2)
                    continue
                else:
                    last_hash = current_hash
            else:
                last_hash = current_hash

            # Save the ensemble for later use in the main auto-sklearn module!
            self.backend.save_ensemble(ensemble, index_run, self.seed)

            # Save predictions for valid and test data set
            if len(dir_valid_list) == len(dir_ensemble_model_files):
                all_predictions_valid = np.array(all_predictions_valid)
                ensemble_predictions_valid = ensemble.predict(all_predictions_valid)
                if self.task_type == BINARY_CLASSIFICATION:
                    ensemble_predictions_valid = ensemble_predictions_valid[:, 1]
                # BUG FIX: "precision" was previously unbound here whenever
                # self.low_precision was falsy, so save_predictions_as_txt
                # crashed with a NameError.  6 is the full-precision fallback
                # used below for small regression outputs - TODO confirm this
                # matches the backend's intended default.
                precision = 6
                if self.low_precision:
                    if self.task_type in [BINARY_CLASSIFICATION, MULTICLASS_CLASSIFICATION, MULTILABEL_CLASSIFICATION]:
                        ensemble_predictions_valid[ensemble_predictions_valid < 1e-4] = 0.
                    if self.metric in [BAC_METRIC, F1_METRIC]:
                        bin_array = np.zeros(ensemble_predictions_valid.shape, dtype=np.int32)
                        if (self.task_type != MULTICLASS_CLASSIFICATION) or (
                            ensemble_predictions_valid.shape[1] == 1):
                            bin_array[ensemble_predictions_valid >= 0.5] = 1
                        else:
                            sample_num = ensemble_predictions_valid.shape[0]
                            for i in range(sample_num):
                                j = np.argmax(ensemble_predictions_valid[i, :])
                                bin_array[i, j] = 1
                        ensemble_predictions_valid = bin_array
                    if self.task_type in CLASSIFICATION_TASKS:
                        if ensemble_predictions_valid.size < (20000 * 20):
                            precision = 3
                        else:
                            precision = 2
                    else:
                        if ensemble_predictions_valid.size > 1000000:
                            precision = 4
                        else:
                            # File size maximally 2.1MB
                            precision = 6

                self.backend.save_predictions_as_txt(ensemble_predictions_valid,
                                                'valid', index_run, prefix=self.dataset_name,
                                                precision=precision)
            else:
                self.logger.info('Could not find as many validation set predictions (%d)'
                             'as ensemble predictions (%d)!.',
                            len(dir_valid_list), len(dir_ensemble_model_files))

            del all_predictions_valid

            if len(dir_test_list) == len(dir_ensemble_model_files):
                all_predictions_test = np.array(all_predictions_test)
                ensemble_predictions_test = ensemble.predict(all_predictions_test)
                if self.task_type == BINARY_CLASSIFICATION:
                    ensemble_predictions_test = ensemble_predictions_test[:, 1]
                # Same unbound-"precision" fix as for the validation set above.
                precision = 6
                if self.low_precision:
                    if self.task_type in [BINARY_CLASSIFICATION, MULTICLASS_CLASSIFICATION, MULTILABEL_CLASSIFICATION]:
                        ensemble_predictions_test[ensemble_predictions_test < 1e-4] = 0.
                    if self.metric in [BAC_METRIC, F1_METRIC]:
                        bin_array = np.zeros(ensemble_predictions_test.shape,
                                             dtype=np.int32)
                        if (self.task_type != MULTICLASS_CLASSIFICATION) or (
                                    ensemble_predictions_test.shape[1] == 1):
                            bin_array[ensemble_predictions_test >= 0.5] = 1
                        else:
                            sample_num = ensemble_predictions_test.shape[0]
                            for i in range(sample_num):
                                j = np.argmax(ensemble_predictions_test[i, :])
                                bin_array[i, j] = 1
                        ensemble_predictions_test = bin_array
                    if self.task_type in CLASSIFICATION_TASKS:
                        if ensemble_predictions_test.size < (20000 * 20):
                            precision = 3
                        else:
                            precision = 2
                    else:
                        if ensemble_predictions_test.size > 1000000:
                            precision = 4
                        else:
                            precision = 6

                self.backend.save_predictions_as_txt(ensemble_predictions_test,
                                                     'test', index_run, prefix=self.dataset_name,
                                                     precision=precision)
            else:
                self.logger.info('Could not find as many test set predictions (%d) as '
                             'ensemble predictions (%d)!',
                            len(dir_test_list), len(dir_ensemble_model_files))

            del all_predictions_test

            current_num_models = len(dir_ensemble_model_files)
            watch.stop_task('index_run' + str(index_run))
            time_iter = watch.get_wall_dur('index_run' + str(index_run))
            used_time = watch.wall_elapsed('ensemble_builder')
            index_run += 1
        return

Example 3

Project: mpop
Source File: h5_pps_l2.py
View license
    def load(self, satscene, *args, **kwargs):
        """Read data from file and load it into *satscene*.
        """
        lonlat_is_loaded = False

        geofilename = kwargs.get('geofilename')
        prodfilename = kwargs.get('filename')

        products = []
        if "CTTH" in satscene.channels_to_load:
            products.append("ctth")
        if "CT" in satscene.channels_to_load:
            products.append("cloudtype")
        if "CMA" in satscene.channels_to_load:
            products.append("cloudmask")
        if "PC" in satscene.channels_to_load:
            products.append("precipclouds")
        if "CPP" in satscene.channels_to_load:
            products.append("cpp")

        if len(products) == 0:
            return

        try:
            area_name = satscene.area_id or satscene.area.area_id
        except AttributeError:
            area_name = "satproj_?????_?????"

        # Looking for geolocation file

        conf = ConfigParser()
        conf.read(os.path.join(CONFIG_PATH, satscene.fullname + ".cfg"))

        try:
            geodir = conf.get(satscene.instrument_name + "-level3",
                              "cloud_product_geodir",
                              vars=os.environ)
        except NoOptionError:
            LOG.warning("No option 'geodir' in level3 section")
            geodir = None

        if not geofilename and geodir:
            # Load geo file from config file:
            try:
                if not satscene.orbit:
                    orbit = ""
                else:
                    orbit = satscene.orbit
                geoname_tmpl = conf.get(satscene.instrument_name + "-level3",
                                        "cloud_product_geofilename",
                                        raw=True,
                                        vars=os.environ)
                filename_tmpl = (satscene.time_slot.strftime(geoname_tmpl)
                                 % {"orbit": str(orbit).zfill(5) or "*",
                                    "area": area_name,
                                    "satellite": satscene.satname + satscene.number})

                file_list = glob.glob(os.path.join(geodir, filename_tmpl))
                if len(file_list) > 1:
                    LOG.warning("More than 1 file matching for geoloaction: "
                                + str(file_list))
                elif len(file_list) == 0:
                    LOG.warning(
                        "No geolocation file matching!: "
                        + os.path.join(geodir, filename_tmpl))
                else:
                    geofilename = file_list[0]
            except NoOptionError:
                geofilename = None

        # Reading the products

        classes = {"ctth": CloudTopTemperatureHeight,
                   "cloudtype": CloudType,
                   "cloudmask": CloudMask,
                   "precipclouds": PrecipitationClouds,
                   "cpp": CloudPhysicalProperties
                   }

        nodata_mask = False

        area = None
        lons = None
        lats = None
        chn = None
        shape = None
        read_external_geo = {}
        for product in products:
            LOG.debug("Loading " + product)

            if isinstance(prodfilename, (list, tuple, set)):
                for fname in prodfilename:
                    kwargs['filename'] = fname
                    self.load(satscene, *args, **kwargs)
                return
            elif (prodfilename and
                  os.path.basename(prodfilename).startswith('S_NWC')):
                if os.path.basename(prodfilename).split("_")[2] == NEW_PRODNAMES[product]:
                    filename = prodfilename
                else:
                    continue
            else:
                filename = conf.get(satscene.instrument_name + "-level3",
                                    "cloud_product_filename",
                                    raw=True,
                                    vars=os.environ)
                directory = conf.get(satscene.instrument_name + "-level3",
                                     "cloud_product_dir",
                                     vars=os.environ)
                pathname_tmpl = os.path.join(directory, filename)
                LOG.debug("Path = " + str(pathname_tmpl))

                if not satscene.orbit:
                    orbit = ""
                else:
                    orbit = satscene.orbit

                filename_tmpl = (satscene.time_slot.strftime(pathname_tmpl)
                                 % {"orbit": str(orbit).zfill(5) or "*",
                                    "area": area_name,
                                    "satellite": satscene.satname + satscene.number,
                                    "product": product})

                file_list = glob.glob(filename_tmpl)
                if len(file_list) == 0:
                    product_name = NEW_PRODNAMES.get(product, product)
                    LOG.info("No " + str(product) +
                             " product in old format matching")
                    filename_tmpl = (satscene.time_slot.strftime(pathname_tmpl)
                                     % {"orbit": str(orbit).zfill(5) or "*",
                                        "area": area_name,
                                        "satellite": satscene.satname + satscene.number,
                                        "product": product_name})

                    file_list = glob.glob(filename_tmpl)

                if len(file_list) > 1:
                    LOG.warning("More than 1 file matching for " + product + "! "
                                + str(file_list))
                    continue
                elif len(file_list) == 0:
                    LOG.warning(
                        "No " + product + " matching!: " + filename_tmpl)
                    continue
                else:
                    filename = file_list[0]

            chn = classes[product]()
            chn.read(filename, lonlat_is_loaded == False)
            satscene.channels.append(chn)
            # Check if geolocation is loaded:
            if not chn.area:
                read_external_geo[product] = chn
                shape = chn.shape

        # Check if some 'channel'/product needs geolocation. If some product does
        # not have geolocation, get it from the geofilename:
        if not read_external_geo:
            LOG.info("Loading PPS parameters done.")
            return

        # Load geolocation
        interpolate = False
        if geofilename:
            geodict = get_lonlat(geofilename)
            lons, lats = geodict['lon'], geodict['lat']
            if lons.shape != shape or lats.shape != shape:
                interpolate = True
                row_indices = geodict['row_indices']
                column_indices = geodict['col_indices']

            lonlat_is_loaded = True
        else:
            LOG.warning("No Geo file specified: " +
                        "Geolocation will be loaded from product")

        if lonlat_is_loaded:
            if interpolate:
                from geotiepoints import SatelliteInterpolator

                cols_full = np.arange(shape[1])
                rows_full = np.arange(shape[0])

                satint = SatelliteInterpolator((lons, lats),
                                               (row_indices,
                                                column_indices),
                                               (rows_full, cols_full))
                # satint.fill_borders("y", "x")
                lons, lats = satint.interpolate()

            try:
                from pyresample import geometry
                lons = np.ma.masked_array(lons, nodata_mask)
                lats = np.ma.masked_array(lats, nodata_mask)
                area = geometry.SwathDefinition(lons=lons,
                                                lats=lats)
            except ImportError:
                area = None

        for chn in read_external_geo.values():
            if area:
                chn.area = area
            else:
                chn.lat = lats
                chn.lon = lons

        LOG.info("Loading PPS parameters done.")

        return

Example 4

Project: mpop
Source File: viirs_sdr.py
View license
    def load(self, satscene, calibrate=1, time_interval=None,
             area=None, filename=None, **kwargs):
        """Read viirs SDR reflectances and Tbs from file and load it into
        *satscene*.

        File discovery works in two modes: if *filename* is given, the
        listed files are partitioned into SDR band files (basenames starting
        with ``SV``) and geolocation files (starting with ``G``); otherwise
        the files are located by globbing templates from the instrument
        config, restricted to *time_interval* (or the scene's time slot) and
        optionally to *area*.

        :param satscene: scene object to fill; must have instrument 'viirs'.
        :param calibrate: calibration flag passed through to ViirsBandData.
        :param time_interval: optional (start, end) pair used to select
            swath segments; defaults to (satscene.time_slot, None).
        :param area: optional area used for swath-segment selection.
        :param filename: optional explicit file name or list of file names.
        :raises ValueError: if the scene's instrument is not viirs.
        :raises IOError: if no filename template is configured, no matching
            directory/files are found, or geolocation files are missing.
        """
        if satscene.instrument_name != "viirs":
            raise ValueError("Wrong instrument, expecting viirs")

        # Any extra keyword arguments are ignored (with a warning) rather
        # than raising, so callers sharing a generic reader API still work.
        if kwargs:
            logger.warning(
                "Unsupported options for viirs reader: %s", str(kwargs))

        # Read the per-satellite config; only the "<instrument>-level2"
        # section is consulted (raw=True: no interpolation of % templates).
        conf = ConfigParser()
        conf.read(os.path.join(CONFIG_PATH, satscene.fullname + ".cfg"))
        options = {}
        for option, value in conf.items(satscene.instrument_name + "-level2",
                                        raw=True):
            options[option] = value

        # Nothing to do if none of the requested channels is a known band.
        band_list = [s.name for s in satscene.channels]
        chns = satscene.channels_to_load & set(band_list)
        if len(chns) == 0:
            return

        if time_interval:
            time_start, time_end = time_interval
        else:
            time_start, time_end = satscene.time_slot, None

        import glob

        if "filename" not in options:
            raise IOError("No filename given, cannot load")

        # Substitution values for the filename/directory templates.
        values = {"orbit": satscene.orbit,
                  "satname": satscene.satname,
                  "instrument": satscene.instrument_name,
                  "satellite": satscene.satname
                  #"satellite": satscene.fullname
                  }

        file_list = []
        if filename is not None:
            # Explicit file mode: split the given names into band files
            # ("SV…") and geolocation files ("G…"); anything else is logged
            # and ignored.
            if not isinstance(filename, (list, set, tuple)):
                filename = [filename]
            geofile_list = []
            for fname in filename:
                if os.path.basename(fname).startswith("SV"):
                    file_list.append(fname)
                elif os.path.basename(fname).startswith("G"):
                    geofile_list.append(fname)
                else:
                    logger.info("Unrecognized SDR file: %s", fname)
            if file_list:
                directory = os.path.dirname(file_list[0])
            if geofile_list:
                geodirectory = os.path.dirname(geofile_list[0])
            # NOTE(review): if explicit filenames contain band files but no
            # geo files, 'geodirectory' may be unbound when used further
            # down (metadata-derived geo path) — confirm callers always
            # pass the matching G* files.

        if not file_list:
            # Config-template mode: expand the time slot into the filename
            # and directory templates, then glob.
            filename_tmpl = strftime(
                satscene.time_slot, options["filename"]) % values

            directory = strftime(satscene.time_slot, options["dir"]) % values

            if not os.path.exists(directory):
                # The directory template itself may contain glob wildcards;
                # exactly one match is required.
                #directory = globify(options["dir"]) % values
                directory = globify(
                    strftime(satscene.time_slot, options["dir"])) % values
                logger.debug(
                    "Looking for files in directory " + str(directory))
                directories = glob.glob(directory)
                if len(directories) > 1:
                    raise IOError("More than one directory for npp scene... " +
                                  "\nSearch path = %s\n\tPlease check npp.cfg file!" % directory)
                elif len(directories) == 0:
                    raise IOError("No directory found for npp scene. " +
                                  "\nSearch path = %s\n\tPlease check npp.cfg file!" % directory)
                else:
                    directory = directories[0]

            file_list = glob.glob(os.path.join(directory, filename_tmpl))

            # Only take the files in the interval given:
            logger.debug("Number of files before segment selection: "
                         + str(len(file_list)))
            for fname in file_list:
                if os.path.basename(fname).startswith("SVM14"):
                    logger.debug("File before segmenting: "
                                 + os.path.basename(fname))
            file_list = _get_swathsegment(
                file_list, time_start, time_end, area)
            logger.debug("Number of files after segment selection: "
                         + str(len(file_list)))

            for fname in file_list:
                if os.path.basename(fname).startswith("SVM14"):
                    logger.debug("File after segmenting: "
                                 + os.path.basename(fname))

            logger.debug("Template = " + str(filename_tmpl))

            # 22 VIIRS bands (16 M-bands + 5 I-bands + DNB)
            if len(file_list) % 22 != 0:
                logger.warning("Number of SDR files is not divisible by 22!")
            if len(file_list) == 0:
                logger.debug(
                    "File template = " + str(os.path.join(directory, filename_tmpl)))
                raise IOError("No VIIRS SDR file matching!: " +
                              "Start time = " + str(time_start) +
                              "  End time = " + str(time_end))

            # Geolocation files live either in their own configured
            # directory or alongside the band files.
            geo_dir_string = options.get("geo_dir", None)
            if geo_dir_string:
                geodirectory = strftime(
                    satscene.time_slot, geo_dir_string) % values
            else:
                geodirectory = directory
            logger.debug("Geodir = " + str(geodirectory))

            geofile_list = []
            geo_filenames_string = options.get("geo_filenames", None)
            if geo_filenames_string:
                geo_filenames_tmpl = strftime(satscene.time_slot,
                                              geo_filenames_string) % values
                geofile_list = glob.glob(os.path.join(geodirectory,
                                                      geo_filenames_tmpl))
                logger.debug("List of geo-files: " + str(geofile_list))
                # Only take the files in the interval given:
                geofile_list = _get_swathsegment(
                    geofile_list, time_start, time_end)

            logger.debug("List of geo-files (after time interval selection): "
                         + str(geofile_list))

        filenames = [os.path.basename(s) for s in file_list]

        glob_info = {}

        self.geofiles = geofile_list

        logger.debug("Channels to load: " + str(satscene.channels_to_load))
        for chn in satscene.channels_to_load:
            # Take only those files in the list matching the band:
            # (Filename starts with 'SV' and then the band-name)
            fnames_band = []

            # A TypeError here means the channel identifier is not a string
            # (e.g. a frequency was requested instead of a band name).
            try:
                fnames_band = [s for s in filenames if s.find('SV' + chn) >= 0]
            except TypeError:
                logger.warning('Band frequency not available from VIIRS!')
                logger.info('Asking for channel' + str(chn) + '!')

            if len(fnames_band) == 0:
                continue

            filename_band = [
                os.path.join(directory, fname) for fname in fnames_band]
            logger.debug("fnames_band = " + str(filename_band))

            band = ViirsBandData(filename_band, calibrate=calibrate).read()
            logger.debug('Band id = ' + band.band_id)

            # If the list of geo-files is not specified in the config file or
            # some of them do not exist, we rely on what is written in the
            # band-data metadata header:
            if len(geofile_list) < len(filename_band):
                geofilenames_band = [os.path.join(geodirectory, gfile) for
                                     gfile in band.geo_filenames]
                logger.debug("Geolocation filenames for band: " +
                             str(geofilenames_band))
                # Check if the geo-filenames found from the metadata actually
                # exist and issue a warning if they do not:
                for geofilename in geofilenames_band:
                    if not os.path.exists(geofilename):
                        logger.warning("Geo file defined in metadata header " +
                                       "does not exist: " + str(geofilename))

            # Otherwise pick the geo files by band family: M-bands use
            # GMTCO (terrain-corrected) with GMODO (geoid) as fallback,
            # I-bands use GITCO/GIMGO, and the Day/Night band uses GDNBO.
            elif band.band_id.startswith('M'):
                geofilenames_band = [geofile for geofile in geofile_list
                                     if os.path.basename(geofile).startswith('GMTCO')]
                if len(geofilenames_band) != len(filename_band):
                    # Try the geoid instead:
                    geofilenames_band = [geofile for geofile in geofile_list
                                         if os.path.basename(geofile).startswith('GMODO')]
                    if len(geofilenames_band) != len(filename_band):
                        raise IOError("Not all geo location files " +
                                      "for this scene are present for band " +
                                      band.band_id + "!")
            elif band.band_id.startswith('I'):
                geofilenames_band = [geofile for geofile in geofile_list
                                     if os.path.basename(geofile).startswith('GITCO')]
                if len(geofilenames_band) != len(filename_band):
                    # Try the geoid instead:
                    geofilenames_band = [geofile for geofile in geofile_list
                                         if os.path.basename(geofile).startswith('GIMGO')]
                    if len(geofilenames_band) != len(filename_band):
                        raise IOError("Not all geo location files " +
                                      "for this scene are present for band " +
                                      band.band_id + "!")
            elif band.band_id.startswith('D'):
                geofilenames_band = [geofile for geofile in geofile_list
                                     if os.path.basename(geofile).startswith('GDNBO')]
                if len(geofilenames_band) != len(filename_band):
                    raise IOError("Not all geo-location files " +
                                  "for this scene are present for " +
                                  "the Day Night Band!")

            band.read_lonlat(geofilepaths=geofilenames_band)

            if not band.band_desc:
                logger.warning('Band name = ' + band.band_id)
                raise AttributeError('Band description not supported!')

            # Copy data and per-channel metadata into the scene.
            satscene[chn].data = band.data
            satscene[chn].info['units'] = band.units
            satscene[chn].info['band_id'] = band.band_id
            satscene[chn].info['start_time'] = band.begin_time
            satscene[chn].info['end_time'] = band.end_time
            if chn in ['M01', 'M02', 'M03', 'M04', 'M05', 'M06', 'M07', 'M08', 'M09', 'M10', 'M11',
                       'I01', 'I02', 'I03']:
                satscene[chn].info['sun_zen_correction_applied'] = True

            # We assume the same geolocation should apply to all M-bands!
            # ...and the same to all I-bands:

            from pyresample import geometry

            # Masked lon/lat arrays share the band-data mask so invalid
            # pixels are excluded from the swath definition.
            satscene[chn].area = geometry.SwathDefinition(
                lons=np.ma.masked_where(band.data.mask,
                                        band.geolocation.longitudes,
                                        copy=False),
                lats=np.ma.masked_where(band.data.mask,
                                        band.geolocation.latitudes,
                                        copy=False))
            area_name = ("swath_" + satscene.fullname + "_" +
                         str(satscene.time_slot) + "_"
                         + str(satscene[chn].data.shape) + "_" +
                         band.band_uid)

            satscene[chn].area.area_id = area_name
            satscene[chn].area_id = area_name

            # Remember the first band's shape as the reader's nominal shape.
            if self.shape is None:
                self.shape = band.data.shape

            # except ImportError:
            #    satscene[chn].area = None
            #    satscene[chn].lat = np.ma.array(band.latitude, mask=band.data.mask)
            #    satscene[chn].lon = np.ma.array(band.longitude, mask=band.data.mask)

            # if 'institution' not in glob_info:
            ##    glob_info['institution'] = band.global_info['N_Dataset_Source']
            # if 'mission_name' not in glob_info:
            ##    glob_info['mission_name'] = band.global_info['Mission_Name']

        # Drop cached geolocation arrays now that all channels are loaded.
        ViirsGeolocationData.clear_cache()

        # Compulsory global attribudes
        satscene.info["title"] = (satscene.satname.capitalize() +
                                  " satellite, " +
                                  satscene.instrument_name.capitalize() +
                                  " instrument.")
        if 'institution' in glob_info:
            satscene.info["institution"] = glob_info['institution']

        if 'mission_name' in glob_info:
            satscene.add_to_history(glob_info['mission_name'] +
                                    " VIIRS SDR read by mpop")
        else:
            satscene.add_to_history("NPP/JPSS VIIRS SDR read by mpop")

        satscene.info["references"] = "No reference."
        satscene.info["comments"] = "No comment."

        # Scene-wide time span is the envelope of all loaded channels.
        satscene.info["start_time"] = min([chn.info["start_time"]
                                           for chn in satscene
                                           if chn.is_loaded()])
        satscene.info["end_time"] = max([chn.info["end_time"]
                                         for chn in satscene
                                         if chn.is_loaded()])

Example 5

Project: emogenerator
Source File: emogenerator.py
View license
def emogenerator(options, inArguments):
	"""Generate Objective-C class files from a Core Data model.

	When *options.input* is unset, the current directory is searched for a
	model file (preferring .xcdatamodel, then .xcdatamodeld, then .mom).
	.xcdatamodel(d) inputs are compiled to a .mom with momc; one .h/.m pair
	is then rendered per entity via genshi templates, with generated
	accessor sections merged into any pre-existing output files.

	:param options: parsed command-line options (input, output, template,
		momcpath, ...); mutated in place as defaults are filled in.
	:param inArguments: positional command-line arguments (currently unused).
	:raises Exception: when no model file can be found, the input type is
		unsupported, momc is missing, or momc fails.
	"""
	# If we don't have an input file lets try and find one in the cwd
	if options.input is None:
		files = glob.glob('*.xcdatamodel')
		if files:
			options.input = files[0]
	if options.input is None:
		files = glob.glob('*.xcdatamodeld')
		if files:
			options.input = files[0]
	if options.input is None:
		files = glob.glob('*.mom')
		if files:
			options.input = files[0]
	if options.input is None:
		raise Exception('Could not find a data model file.')

	# Sanitize input directories with a trailing slash
	if options.input[-1] == '/':
		options.input = options.input[:-1]

	# If we still don't have an input file we need to bail.
	if not os.path.exists(options.input):
		raise Exception('Input file doesnt exist at %s' % options.input)

	logger.info('Using \'%s\'', options.input)

	options.input_type = os.path.splitext(options.input)[1][1:]
	if options.input_type not in ['mom', 'xcdatamodel', 'xcdatamodeld']:
		raise Exception('Input file is not a .mom or a .xcdatamodel. Why are you trying to trick me?')

	logger.info('Processing \'%s\'', options.input)

	# Set up a list of CoreData attribute types to Cocoa classes/C types. In theory this could be user configurable, but I don't see the need.
	theTypenamesByAttributeType = {
		CoreData.NSStringAttributeType: dict(cocoaType = 'NSString *'),
		CoreData.NSDateAttributeType: dict(cocoaType = 'NSDate *'),
		CoreData.NSBinaryDataAttributeType: dict(cocoaType = 'NSData *'),
		CoreData.NSDecimalAttributeType: dict(cocoaType = 'NSDecimalNumber *'),
		CoreData.NSInteger16AttributeType: dict(cocoaType = 'NSNumber *', ctype = 'short', toCTypeConverter = 'shortValue', toCocoaTypeConverter = 'numberWithShort'),
		CoreData.NSInteger32AttributeType: dict(cocoaType = 'NSNumber *', ctype = 'int', toCTypeConverter = 'intValue', toCocoaTypeConverter = 'numberWithInt'),
		CoreData.NSInteger64AttributeType: dict(cocoaType = 'NSNumber *', ctype = 'long long', toCTypeConverter = 'longLongValue', toCocoaTypeConverter = 'numberWithLongLong'),
		CoreData.NSDoubleAttributeType: dict(cocoaType = 'NSNumber *', ctype = 'double', toCTypeConverter = 'doubleValue', toCocoaTypeConverter = 'numberWithDouble'),
		CoreData.NSFloatAttributeType: dict(cocoaType = 'NSNumber *', ctype = 'float', toCTypeConverter = 'floatValue', toCocoaTypeConverter = 'numberWithFloat'),
		CoreData.NSBooleanAttributeType: dict(cocoaType = 'NSNumber *', ctype = 'BOOL', toCTypeConverter = 'boolValue', toCocoaTypeConverter = 'numberWithBool'),
		CoreData.NSTransformableAttributeType: dict(cocoaType = 'id '),
		}

	if options.input_type in ['xcdatamodel', 'xcdatamodeld']:
		if not os.path.exists(options.momcpath):
			raise Exception('Cannot find momc at \'%s\'' % options.momcpath)
		logger.info('Using momc at \'%s\'', options.momcpath)
		# Create a place to put the generated mom file
		theTempDirectory = tempfile.mkdtemp()
		theObjectModelPath = os.path.join(theTempDirectory, 'Output.mom')

		# Tell momc to compile our xcdatamodel into a managed object model
		theResult = subprocess.call([options.momcpath, options.input, theObjectModelPath])
		if theResult != 0:
			# BUG FIX: the return code was previously passed as a second
			# argument to Exception(), so the message was a (fmt, code)
			# tuple instead of an interpolated string.
			raise Exception('momc failed with %d' % theResult)
	else:
		theObjectModelPath = options.input

	# No? Ok, let's fall back to the cwd
	if options.template is None:
		options.template = 'templates'

	logger.info('Using input mom file \'%s\'', theObjectModelPath)
	logger.info('Using output directory \'%s\'', options.output)
	logger.info('Using template directory \'%s\'', options.template)

	# Load the managed object model.
	theObjectModelURL = Foundation.NSURL.fileURLWithPath_(theObjectModelPath)
	theObjectModel = CoreData.NSManagedObjectModel.alloc().initWithContentsOfURL_(theObjectModelURL)

	# Start up genshi..
	theLoader = genshi.template.TemplateLoader(options.template)

	theContext = dict(
		C = lambda X:X[0].upper() + X[1:],
		author = Foundation.NSFullUserName(),
		date = datetime.datetime.now().strftime('%x'),
		year = datetime.datetime.now().year,
		organizationName = '__MyCompanyName__',
		options = dict(
			suppressAccessorDeclarations = True,
			suppressAccessorDefinitions = True,
			),
		)

	# Pick up the organization name from the user's Xcode preferences, if set.
	theXcodePrefs = Foundation.NSDictionary.dictionaryWithContentsOfFile_(os.path.expanduser('~/Library/Preferences/com.apple.xcode.plist'))
	if theXcodePrefs:
		if 'PBXCustomTemplateMacroDefinitions' in theXcodePrefs:
			if 'ORGANIZATIONNAME' in theXcodePrefs['PBXCustomTemplateMacroDefinitions']:
				theContext['organizationName'] = theXcodePrefs['PBXCustomTemplateMacroDefinitions']['ORGANIZATIONNAME']

	# Process each entity...
	for theEntityDescription in theObjectModel.entities():
		# Create a dictionary describing the entity, we'll be passing this to the genshi template.
		theEntityDict = {
			'entity': theEntityDescription,
			'name': theEntityDescription.name(),
			'className': theEntityDescription.managedObjectClassName(),
			'superClassName': 'NSManagedObject',
			'properties': [],
			'relatedEntityClassNames': [],
			}

		if theEntityDict['className'] == 'NSManagedObject':
			logger.info('Skipping entity \'%s\', no custom subclass specified.', theEntityDescription.name())
			continue

		if theEntityDescription.superentity():
			theEntityDict['superClassName'] = theEntityDescription.superentity().managedObjectClassName()

		# Process each property of the entity.
		for thePropertyDescription in theEntityDescription.properties():
			# Skip properties inherited from a superentity.
			if theEntityDescription != thePropertyDescription.entity():
				continue

			# This dictionary describes the property and is appended to the entity dictionary we created earlier.
			thePropertyDict = {
				'property': thePropertyDescription,
				'name': MyString(thePropertyDescription.name()),
				'type': MyString(thePropertyDescription.className()),
				'CType': None,
				}

			if thePropertyDescription.className() == 'NSAttributeDescription':
				if thePropertyDescription.attributeType() not in theTypenamesByAttributeType:
					logger.warning('Did not understand the property type: %d', thePropertyDescription.attributeType())
					continue

				theTypenameByAttributeType = theTypenamesByAttributeType[thePropertyDescription.attributeType()]

				# The cocoaType entry may be a plain string or a callable
				# taking the property description.
				theCocoaType = theTypenameByAttributeType['cocoaType']
				if not isinstance(theCocoaType, str):
					theCocoaType = theCocoaType(thePropertyDescription)
				thePropertyDict['CocoaType'] = theCocoaType

				if 'ctype' in theTypenameByAttributeType:
					thePropertyDict['CType'] = theTypenameByAttributeType['ctype']
					thePropertyDict['toCTypeConverter'] = theTypenameByAttributeType['toCTypeConverter']
					thePropertyDict['toCocoaTypeConverter'] = theTypenameByAttributeType['toCocoaTypeConverter']

			elif thePropertyDescription.className() == 'NSRelationshipDescription':
				thePropertyDict['isToMany'] = thePropertyDescription.isToMany()
				thePropertyDict['destinationEntityClassNames'] = thePropertyDescription.destinationEntity().managedObjectClassName()
				theEntityDict['relatedEntityClassNames'].append(thePropertyDescription.destinationEntity().managedObjectClassName())
			else:
				continue

			theEntityDict['properties'].append(thePropertyDict)

		theEntityDict['attributes'] = [x for x in theEntityDict['properties'] if x['type'] == 'NSAttributeDescription']
		theEntityDict['relationships'] = [x for x in theEntityDict['properties'] if x['type'] == 'NSRelationshipDescription']

		theTemplateNames = ['classname.h.genshi', 'classname.m.genshi']
		for theTemplateName in theTemplateNames:

			theTemplate = theLoader.load(theTemplateName, cls=genshi.template.NewTextTemplate)

			theContext['entity'] = theEntityDict

			theStream = theTemplate.generate(**theContext)
			theNewContent = theStream.render()

			# Output filename: <ClassName>.<h|m>, taken from the template name.
			theFilename = theEntityDescription.managedObjectClassName() + '.' + re.match(r'.+\.(.+)\.genshi', theTemplateName).group(1)

			theOutputPath = os.path.join(options.output, theFilename)

			if not os.path.exists(theOutputPath):
				# New file: write it outright.  Context managers replace the
				# previous unclosed file() handles (resource leak, and the
				# file() builtin no longer exists in Python 3).
				with open(theOutputPath, 'w') as theOutputFile:
					theOutputFile.write(theNewContent)
			else:
				with open(theOutputPath) as theOutputFile:
					theCurrentContent = theOutputFile.read()
				# Existing file: merge only the generator-owned sections so
				# hand-written code outside the pragma markers is preserved.
				theNewContent = merge(theNewContent, theCurrentContent, [
					('#pragma mark begin emogenerator accessors', '#pragma mark end emogenerator accessors'),
					('#pragma mark begin emogenerator forward declarations', '#pragma mark end emogenerator forward declarations'),
					('#pragma mark begin emogenerator relationship accessors', '#pragma mark end emogenerator relationship accessors'),
					])
				if theNewContent != theCurrentContent:
					with open(theOutputPath, 'w') as theOutputFile:
						theOutputFile.write(theNewContent)

Example 6

Project: cgat
Source File: tophat_segment_juncs.py
View license
def main( argv = None ):
    """script main.

    Parallelizes tophat's ``segment_juncs`` step: splits the per-read map
    files by chromosome, runs one ``segment_juncs`` process per contig in a
    thread pool, then merges the per-contig junction/insertion/deletion
    outputs and removes the split temporary files.

    parses command line options in sys.argv, unless *argv* is given.

    NOTE(review): despite the docstring, *argv* is used unconditionally
    (``argv[0]``, ``argv[1:]``) with no fallback to ``sys.argv`` when it is
    None — confirm callers always pass argv explicitly.
    """

    # Escape hatch: run the original binary untouched when parallelization
    # is disabled.
    if DISABLE:
        print "# tophat_segment_juncs.py disabled"
        argv[0] = "segment_juncs.original"
        runCommand( argv , "segment_juncs.log" )
        return 0

    E.Start( no_parsing = True )

    # collect arguments
    parser = argparse.ArgumentParser(description='Process tophat options.')
    parser.add_argument('-p', '--num-threads', metavar='N', type=int, dest='nthreads',
                         help='number of threads')
    parser.add_argument('--version', action='version', version='%(prog)s')
    options, args = parser.parse_known_args( argv[1:] )

    E.info( "parallelizing segment juncs with %i threads" % options.nthreads )
    
    # Positional arguments for segment_juncs start right after "--ium-reads";
    # everything before that point is passed through verbatim.
    x = argv.index("--ium-reads") + 1
    
    all_options = argv[1:x]

    (input_missing_reads, input_genome, 
     output_junctions, 
     output_insertions, output_deletions,
     input_left_all_reads,
     input_left_all_map,
     input_left_segments_maps ) = argv[x:x + 8]

    input_left_segments_maps = input_left_segments_maps.split(",")

    # The right-read arguments are only present for paired-end runs.
    if len(argv) > x + 8:
        ( input_right_all_reads,
          input_right_all_map,
          input_right_segments_maps ) = argv[x+8:x+11]
        input_right_segments_maps = input_right_segments_maps.split(",")
    else:
        input_right_all_reads = ""
        input_right_all_map = ""
        input_right_segments_maps = []

    keys = set()
    
    # some filenames might appear multiple times
    files_to_split = set([input_left_all_map, \
                              input_right_all_map ] +\
                             input_left_segments_maps +\
                             input_right_segments_maps )

    E.info( "splitting %i files" % len(files_to_split))

    ## split all map files by chromosome
    for filename in files_to_split:
        if filename == "": continue
        E.info("splitting %s" % filename )
        base, ext = os.path.splitext( filename )

        # If split files already exist (e.g. from a previous run), recover
        # the contig names from their filenames instead of re-splitting.
        f = glob.glob( "%s.input.*%s" % (filename, ext) )
        if f:
            E.info("files already exist - skipping" )
            keys.update( [ re.match("%s.input.(\S+)%s" % (filename,ext), x ).groups()[0] for x in f ] )
            continue
        
        infile = IOTools.openFile( filename )

        # FilePool routes each record to "<filename>.input.<contig><ext>";
        # the contig is taken from column 3 of the map file.
        outfiles = IOTools.FilePool( filename + ".input.%s" + ext )

        for line in infile:
            key = line.split("\t")[2]
            keys.add( key )
            outfiles.write( key, line )

        outfiles.close()

    # keys = set( ["chr1", "chr2", "chr3", "chr4", "chr5",
    #              "chr6", "chr7", "chr8", "chr9", "chr10",
    #              "chr11", "chr12", "chr13", "chr14", "chr15",
    #              "chr16", "chr17", "chr18", "chr19", "chr20",
    #              "chr21", "chr22", "chrX", "chrY", "chrM" ] )

    E.info( "working on %i contigs: %s" % (len(keys), list(keys)))

    # Thread (not process) pool: the work items are external subprocesses,
    # so the GIL is not a bottleneck here.
    pool = multiprocessing.pool.ThreadPool( options.nthreads )
    #pool = threadpool.ThreadPool( THREADS )

    tmpdir = os.path.dirname( input_left_all_reads )
    logdir = os.path.join( tmpdir[:-len("tmp")], "logs" )

    if not os.path.exists(logdir):
        raise IOError( "can not find logdir %s" % logdir )

    # Build one segment_juncs command line per contig, rewriting the
    # input/output/genome paths to their per-contig variants.
    args = []
    for key in keys:

        # Per-contig output path: "<old>.output.<contig><ext>".
        def modout( old, key ):
            if not old:return ""
            _, ext = os.path.splitext( old )
            return old + ".output.%s%s" % (key, ext)

        # Per-contig input path: "<old>.input.<contig><ext>" (as written by
        # the splitting loop above).
        def modin( old, key ):
            if not old:return ""
            _, ext = os.path.splitext( old )
            return old + ".input.%s%s" % (key,ext)

        # Per-chromosome genome file: "<genome>.perchrom/<contig><ext>",
        # with a trailing "_cs" (colorspace) suffix stripped from the name.
        def modgenome( old, key ):
            dirname, filename = os.path.split(old)
            genome, ext = os.path.splitext( filename )
            if genome.lower().endswith("_cs"): genome = genome[:-3]
            new = os.path.join( dirname, genome + ".perchrom", key + ext )
            if not os.path.exists(new):
                raise ValueError( "can not find chromoseme file %s" % new )
            return new

        cmd = ["segment_juncs"] +\
            all_options +\
            [input_missing_reads,  \
                 modgenome(input_genome,key), \
                 modout(output_junctions,key),\
                 modout(output_insertions,key),\
                 modout(output_deletions,key),\
                 input_left_all_reads,\
                 modin( input_left_all_map, key ),\
                 ",".join( [ modin( x, key ) for x in input_left_segments_maps ] ),\
                 input_right_all_reads,\
                 modin( input_right_all_map, key ),\
                 ",".join( [ modin( x, key ) for x in input_right_segments_maps ] ) ]


        logfile = os.path.join(logdir, "segment_juncs_%s.log" % key )
        args.append( (cmd,logfile) )

    E.info( "submitting %i jobs" % len(keys) )

    pool.map( runCommand, args, chunksize = 1 )
    pool.close()
    pool.join()

    E.info("all jobs finished successfully" )

    E.info("merging results")
    ## merge results
    # Concatenate the per-contig outputs back into the single files that
    # tophat expects.
    for filename in (output_junctions, output_insertions, output_deletions):
        outfile = open(filename, "w")
        for inf in glob.glob( filename + ".output.*" ):
            infile = open( inf, "r" )
            outfile.write( infile.read() )
            infile.close()
        outfile.close()
        
    E.info("results merged")

    ## cleaning up is done automatically by tophat
    E.info("cleaning up" )
    for f in glob.glob( os.path.join( tmpdir, "*.output.*") ) +\
            glob.glob( os.path.join( tmpdir, "*.input.*") ):
        os.remove(f)

    ## write footer and output benchmark information.
    E.Stop()

Example 7

Project: pygr
Source File: pairwise_hg18_megatest.py
View license
    def test_build(self):
        '''Build an NLMSA from pairwise axt alignment files, query splice
        sites against it, and verify the result against a stored reference
        output via md5; then round-trip the NLMSA through its text-dump
        format and repeat the same verification.

        NOTE(review): Python 2 code — relies on ``xreadlines`` and
        ``string.split``; module-level names (msaSpeciesList, smallSampleKey,
        axtDir, testInputDir, smallSamplePostfix) are defined elsewhere in
        the test module.
        '''
        from pygr import seqdb, cnestedlist
        # Load each species' genome resource and union them under one
        # prefix dictionary so sequences can be addressed as "org.chr".
        genomedict = {}
        for orgstr in msaSpeciesList:
            genomedict[orgstr] = pygr.Data.getResource('TEST.Seq.Genome.'
                                                       + orgstr)
        uniondict = seqdb.PrefixUnionDict(genomedict)
        # Collect the *.net.axt alignment inputs; smallSampleKey restricts
        # the run to a single chromosome for a faster test.
        if smallSampleKey:
            axtlist = glob.glob(os.path.join(axtDir, '*' + os.sep
                                             + smallSampleKey + '.*.net.axt'))
        else:
            axtlist = glob.glob(os.path.join(axtDir, '*' + os.sep
                                             + '*.*.net.axt'))
        axtlist.sort()
        msaname = os.path.join(self.path, 'hg18_pairwise5way')
        # 500MB VERSION (maxlen/maxint size the on-disk nested-list DB)
        msa1 = cnestedlist.NLMSA(msaname, 'w', uniondict, axtFiles=axtlist,
                                 maxlen=536870912, maxint=22369620)
        msa1.__doc__ = 'TEST NLMSA for hg18 pairwise5way'
        pygr.Data.addResource('TEST.MSA.UCSC.hg18_pairwise5way', msa1)
        pygr.Data.save()
        msa = pygr.Data.getResource('TEST.MSA.UCSC.hg18_pairwise5way')
        # Input: splice-site intervals; output: expected alignment edges.
        outfileName = os.path.join(testInputDir, 'splicesite_hg18%s.txt'
                                   % smallSamplePostfix)
        outputName = os.path.join(testInputDir,
                                  'splicesite_hg18%s_pairwise5way.txt'
                                  % smallSamplePostfix)
        newOutputName = 'splicesite_new1.txt'
        tmpInputName = self.copyFile(outfileName)
        tmpOutputName = self.copyFile(outputName)
        tmpNewOutputName = os.path.join(self.path, newOutputName)
        outfile = open(tmpNewOutputName, 'w')
        # For each interval, query the 2-bp donor (site1) and acceptor
        # (site2) splice sites and record every aligned edge.
        for lines in open(tmpInputName, 'r').xreadlines():
            chrid, intstart, intend, nobs = string.split(lines.strip(), '\t')
            intstart, intend, nobs = int(intstart), int(intend), int(nobs)
            site1 = msa.seqDict['hg18' + '.' + chrid][intstart:intstart + 2]
            site2 = msa.seqDict['hg18' + '.' + chrid][intend - 2:intend]
            edges1 = msa[site1].edges()
            edges2 = msa[site2].edges()
            if len(edges1) == 0: # EMPTY EDGES
                wlist = str(site1), 'hg18', chrid, intstart, intstart + 2, \
                        '', '', '', '', ''
                outfile.write('\t'.join(map(str, wlist)) + '\n')
            if len(edges2) == 0: # EMPTY EDGES
                wlist = str(site2), 'hg18', chrid, intend - 2, intend, '', \
                        '', '', '', ''
                outfile.write('\t'.join(map(str, wlist)) + '\n')
            saveList = []
            # Only keep 2-bp/2-bp alignments; split "species.chrom" IDs at
            # the first dot to recover species and chromosome names.
            for src, dest, e in edges1:
                if len(str(src)) != 2 or len(str(dest)) != 2:
                    continue
                dotindex = (~msa.seqDict)[src].index('.')
                srcspecies, src1 = (~msa.seqDict)[src][:dotindex], \
                        (~msa.seqDict)[src][dotindex + 1:]
                dotindex = (~msa.seqDict)[dest].index('.')
                destspecies, dest1 = (~msa.seqDict)[dest][:dotindex], \
                        (~msa.seqDict)[dest][dotindex + 1:]
                wlist = str(src), srcspecies, src1, src.start, src.stop, \
                        str(dest), destspecies, dest1, dest.start, dest.stop
                saveList.append('\t'.join(map(str, wlist)) + '\n')
            for src, dest, e in edges2:
                if len(str(src)) != 2 or len(str(dest)) != 2:
                    continue
                dotindex = (~msa.seqDict)[src].index('.')
                srcspecies, src1 = (~msa.seqDict)[src][:dotindex], \
                        (~msa.seqDict)[src][dotindex + 1:]
                dotindex = (~msa.seqDict)[dest].index('.')
                destspecies, dest1 = (~msa.seqDict)[dest][:dotindex], \
                        (~msa.seqDict)[dest][dotindex + 1:]
                wlist = str(src), srcspecies, src1, src.start, src.stop, \
                        str(dest), destspecies, dest1, dest.start, dest.stop
                saveList.append('\t'.join(map(str, wlist)) + '\n')
            saveList.sort() # SORTED IN ORDER TO COMPARE WITH PREVIOUS RESULTS
            for saveline in saveList:
                outfile.write(saveline)
        outfile.close()
        # Compare generated output against the stored reference via md5.
        md5old = hashlib.md5()
        md5old.update(open(tmpNewOutputName, 'r').read())
        md5new = hashlib.md5()
        md5new.update(open(tmpOutputName, 'r').read())
        assert md5old.digest() == md5new.digest()

        # TEXT<->BINARY TEST: dump the NLMSA to text, delete the binary
        # files, rebuild binaries from text, and re-run the same queries.
        msafilelist = glob.glob(msaname + '*')
        msa.save_seq_dict()
        cnestedlist.dump_textfile(msaname, os.path.join(self.path,
                                                      'hg18_pairwise5way.txt'))
        for filename in msafilelist:
            os.remove(filename)
        runPath = os.path.realpath(os.curdir)
        os.chdir(self.path)
        cnestedlist.textfile_to_binaries('hg18_pairwise5way.txt')
        os.chdir(runPath)

        msa1 = cnestedlist.NLMSA(msaname, 'r')
        msa1.__doc__ = 'TEST NLMSA for hg18 pairwise5way'
        pygr.Data.addResource('TEST.MSA.UCSC.hg18_pairwise5way', msa1)
        pygr.Data.save()
        msa = pygr.Data.getResource('TEST.MSA.UCSC.hg18_pairwise5way')
        newOutputName = 'splicesite_new2.txt'
        tmpInputName = self.copyFile(outfileName)
        tmpOutputName = self.copyFile(outputName)
        tmpNewOutputName = os.path.join(self.path, newOutputName)
        outfile = open(tmpNewOutputName, 'w')
        # Same query loop as above, now against the rebuilt binary MSA.
        for lines in open(tmpInputName, 'r').xreadlines():
            chrid, intstart, intend, nobs = string.split(lines.strip(), '\t')
            intstart, intend, nobs = int(intstart), int(intend), int(nobs)
            site1 = msa.seqDict['hg18' + '.' + chrid][intstart:intstart + 2]
            site2 = msa.seqDict['hg18' + '.' + chrid][intend - 2:intend]
            edges1 = msa[site1].edges()
            edges2 = msa[site2].edges()
            if len(edges1) == 0: # EMPTY EDGES
                wlist = str(site1), 'hg18', chrid, intstart, intstart + 2, \
                        '', '', '', '', ''
                outfile.write('\t'.join(map(str, wlist)) + '\n')
            if len(edges2) == 0: # EMPTY EDGES
                wlist = str(site2), 'hg18', chrid, intend - 2, intend, '', \
                        '', '', '', ''
                outfile.write('\t'.join(map(str, wlist)) + '\n')
            saveList = []
            for src, dest, e in edges1:
                if len(str(src)) != 2 or len(str(dest)) != 2:
                    continue
                dotindex = (~msa.seqDict)[src].index('.')
                srcspecies, src1 = (~msa.seqDict)[src][:dotindex], \
                        (~msa.seqDict)[src][dotindex + 1:]
                dotindex = (~msa.seqDict)[dest].index('.')
                destspecies, dest1 = (~msa.seqDict)[dest][:dotindex], \
                        (~msa.seqDict)[dest][dotindex + 1:]
                wlist = str(src), srcspecies, src1, src.start, src.stop, \
                        str(dest), destspecies, dest1, dest.start, dest.stop
                saveList.append('\t'.join(map(str, wlist)) + '\n')
            for src, dest, e in edges2:
                if len(str(src)) != 2 or len(str(dest)) != 2:
                    continue
                dotindex = (~msa.seqDict)[src].index('.')
                srcspecies, src1 = (~msa.seqDict)[src][:dotindex], \
                        (~msa.seqDict)[src][dotindex + 1:]
                dotindex = (~msa.seqDict)[dest].index('.')
                destspecies, dest1 = (~msa.seqDict)[dest][:dotindex], \
                        (~msa.seqDict)[dest][dotindex + 1:]
                wlist = str(src), srcspecies, src1, src.start, src.stop, \
                        str(dest), destspecies, dest1, dest.start, dest.stop
                saveList.append('\t'.join(map(str, wlist)) + '\n')
            saveList.sort() # SORTED IN ORDER TO COMPARE WITH PREVIOUS RESULTS
            for saveline in saveList:
                outfile.write(saveline)
        outfile.close()
        # Rebuilt MSA must produce byte-identical query results.
        md5old = hashlib.md5()
        md5old.update(open(tmpNewOutputName, 'r').read())
        md5new = hashlib.md5()
        md5new.update(open(tmpOutputName, 'r').read())
        assert md5old.digest() == md5new.digest()

Example 8

Project: stonix
Source File: ConfigureLDAPServer.py
View license
    def report(self):
        '''Report whether the LDAP server's configuration files and
        directories have the expected permissions and ownership.

        Expected modes are decimal stat.S_IMODE values: 416 == 0o640,
        420 == 0o644, 384 == 0o600.  Every mismatch is appended to
        self.detailedresults, logged at DEBUG priority, and marks the
        rule non-compliant.

        Fixes over the previous revision: os.path.isdir() was called
        without an argument in the /etc/pki/tls checks (TypeError at
        run time, swallowed by the broad except); an unreachable
        statement followed the final return; several result/debug
        messages disagreed about the expected mode or owner name.

        @return: self.compliant - True when every checked path complies
        '''
        try:
            compliant = True
            self.ph = Pkghelper(self.logger, self.environ)
            self.detailedresults = ""
            # The LDAP server package name differs per package manager.
            if self.ph.manager == "apt-get":
                self.ldap = "slapd"
            elif self.ph.manager == "zypper":
                self.ldap = "openldap2"
            else:
                self.ldap = "openldap-servers"
            if self.ph.check(self.ldap):
                # Main server config (openldap layout): 640 root:ldap.
                slapd = "/etc/openldap/slapd.conf"
                if os.path.exists(slapd):
                    if not self._reportfile(slapd, 416, "root", "ldap",
                                            "640"):
                        compliant = False
                # apt-get systems: 644 root:root.
                slapd = "/etc/ldap/ldap.conf"
                if os.path.exists(slapd):
                    if not self._reportfile(slapd, 420, "root", "root",
                                            "644"):
                        compliant = False
                # slapd.d trees, cn=config trees, and the TLS key
                # directory: every regular file directly inside must
                # match the expected mode/owner/group.
                dirchecks = [
                    ("/etc/openldap/slapd.d/", 416, "ldap", "ldap", "640"),
                    # apt-get systems
                    ("/etc/ldap/slapd.d/", 384, "openldap", "openldap",
                     "600"),
                    ("/etc/openldap/slapd.d/cn=config/", 416, "ldap",
                     "ldap", "640"),
                    # apt-get systems
                    ("/etc/ldap/slapd.d/cn=config/", 384, "openldap",
                     "openldap", "600"),
                    ("/etc/pki/tls/ldap/", 416, "root", "ldap", "640")]
                for directory, mode, owner, group, octstr in dirchecks:
                    if os.path.exists(directory):
                        if not self._reportdir(directory, mode, owner,
                                               group, octstr):
                            compliant = False
                # CA certificates: 644 root:root, via the checkPerms helper.
                if os.path.exists("/etc/pki/tls/CA/"):
                    for loc in glob.glob("/etc/pki/tls/CA/*"):
                        # Previously os.path.isdir() lacked its argument.
                        if not os.path.isdir(loc):
                            if not checkPerms(loc, [0, 0, 420], self.logger):
                                compliant = False
                                self.detailedresults += "Permissions " + \
                                    "aren't correct on " + loc + " file\n"
                                debug = "Permissions aren't correct on " + \
                                    loc + " file\n"
                                self.logger.log(LogPriority.DEBUG, debug)
            self.compliant = compliant
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception:
            self.rulesuccess = False
            self.detailedresults += "\n" + traceback.format_exc()
            self.logdispatch.log(LogPriority.ERROR, self.detailedresults)
        self.formatDetailedResults("report", self.compliant,
                                   self.detailedresults)
        self.logdispatch.log(LogPriority.INFO, self.detailedresults)
        return self.compliant

    def _reportfile(self, path, mode, owner, group, octstr):
        '''Check one configuration file's mode, owner, and group,
        recording any mismatch in self.detailedresults and the debug log.

        @param path: file to check (must exist)
        @param mode: expected mode as a decimal stat.S_IMODE value
        @param owner: expected owning user name
        @param group: expected owning group name
        @param octstr: octal rendering of mode, used in messages
        @return: True if the file complies
        '''
        compliant = True
        statdata = os.stat(path)
        curmode = stat.S_IMODE(statdata.st_mode)
        ownergrp = getUserGroupName(path)
        curowner = ownergrp[0]
        curgroup = ownergrp[1]
        if curmode != mode:
            self.detailedresults += "permissions on " + path + \
                " aren't " + octstr + "\n"
            debug = "permissions on " + path + " aren't " + octstr + "\n"
            self.logger.log(LogPriority.DEBUG, debug)
            compliant = False
        if curowner != owner:
            self.detailedresults += "Owner of " + path + \
                " isn't " + owner + "\n"
            debug = "Owner of " + path + " isn't " + owner + "\n"
            self.logger.log(LogPriority.DEBUG, debug)
            compliant = False
        if curgroup != group:
            self.detailedresults += "Group owner of " + path + \
                " isn't " + group + "\n"
            debug = "Group owner of " + path + " isn't " + group + "\n"
            self.logger.log(LogPriority.DEBUG, debug)
            compliant = False
        return compliant

    def _reportdir(self, directory, mode, owner, group, octstr):
        '''Check every regular (non-directory) entry directly under
        directory for the expected mode, owner, and group, recording any
        mismatch in self.detailedresults and the debug log.

        @param directory: directory path ending in os.sep
        @param mode: expected mode as a decimal stat.S_IMODE value
        @param owner: expected owning user name
        @param group: expected owning group name
        @param octstr: octal rendering of mode, used in messages
        @return: True if every entry complies
        '''
        compliant = True
        for loc in glob.glob(directory + "*"):
            if os.path.isdir(loc):
                continue
            statdata = os.stat(loc)
            curmode = stat.S_IMODE(statdata.st_mode)
            ownergrp = getUserGroupName(loc)
            curowner = ownergrp[0]
            curgroup = ownergrp[1]
            if curmode != mode:
                self.detailedresults += "Permissions aren't " + octstr + \
                    " on " + loc + "\n"
                debug = "Permissions aren't " + octstr + " on " + loc + "\n"
                self.logger.log(LogPriority.DEBUG, debug)
                compliant = False
            if curowner != owner:
                self.detailedresults += "Owner of " + loc + \
                    " isn't " + owner + "\n"
                debug = "Owner of " + loc + " isn't " + owner + "\n"
                self.logger.log(LogPriority.DEBUG, debug)
                compliant = False
            if curgroup != group:
                self.detailedresults += "Group of " + loc + \
                    " isn't " + group + "\n"
                debug = "Group of " + loc + " isn't " + group + "\n"
                self.logger.log(LogPriority.DEBUG, debug)
                compliant = False
        return compliant

Example 9

Project: stonix
Source File: ConfigureLDAPServer.py
View license
    def fix(self):
        try:
            if not self.ci.getcurrvalue():
                return
            success = True
            self.iditerator = 0
            eventlist = self.statechglogger.findrulechanges(self.rulenumber)
            for event in eventlist:
                self.statechglogger.deleteentry(event)
            if self.ph.check(self.ldap):
                #is ldap configured for tls?

                #do the ldap files have the correct permissions?
                slapd = "/etc/openldap/slapd.conf"
                if os.path.exists(slapd):
                    statdata = os.stat(slapd)
                    mode = stat.S_IMODE(statdata.st_mode)
                    ownergrp = getUserGroupName(slapd)
                    owner = ownergrp[0]
                    group = ownergrp[1]
                    if mode != 416 or owner != "root" or group != "ldap":
                        origuid = statdata.st_uid
                        origgid = statdata.st_gid
                        if grp.getgrnam("ldap")[2]:
                            gid = grp.getgrnam("ldap")[2]
                            self.iditerator += 1
                            myid = iterate(self.iditerator, self.rulenumber)
                            event = {"eventtype": "perm",
                                     "startstate": [origuid, origgid, mode],
                                     "endstate": [0, gid, 416],
                                     "filepath": slapd}
                            self.statechglogger.recordchgevent(myid, event)
                            os.chmod(slapd, 416)
                            os.chown(slapd, 0, gid)
                            resetsecon(slapd)
                        else:
                            success = False
                            debug = "Unable to determine the id " + \
                                "number of ldap group.  Will not change " + \
                                "permissions on " + slapd + " file\n"
                            self.logger.log(LogPriority.DEBUG, debug)
                #apt-get systems
                slapd = "/etc/ldap/ldap.conf"
                if os.path.exists(slapd):
                    statdata = os.stat(slapd)
                    mode = stat.S_IMODE(statdata.st_mode)
                    ownergrp = getUserGroupName(slapd)
                    owner = ownergrp[0]
                    group = ownergrp[1]
                    if mode != 420 or owner != "root" or group != "root":
                        origuid = statdata.st_uid
                        origgid = statdata.st_gid
                        if grp.getgrnam("root")[2] != "":
                            gid = grp.getgrnam("root")[2]
                            self.iditerator += 1
                            myid = iterate(self.iditerator, self.rulenumber)
                            event = {"eventtype": "perm",
                                     "startstate": [origuid, origgid, mode],
                                     "endstate": [0, gid, 420],
                                     "filepath": slapd}
                            self.statechglogger.recordchgevent(myid, event)
                            os.chmod(slapd, 420)
                            os.chown(slapd, 0, gid)
                            resetsecon(slapd)
                        else:
                            success = False
                            debug = "Unable to determine the id " + \
                                "number of root group.  Will not change " + \
                                "permissions on " + slapd + " file\n"
                            self.logger.log(LogPriority.DEBUG, debug)
                slapdd = "/etc/openldap/slapd.d/"
                if os.path.exists(slapdd):
                    dirs = glob.glob(slapdd + "*")
                    for loc in dirs:
                        if not os.path.isdir(loc):
                            statdata = os.stat(loc)
                            mode = stat.S_IMODE(statdata.st_mode)
                            ownergrp = getUserGroupName(loc)
                            owner = ownergrp[0]
                            group = ownergrp[1]
                            if mode != 416 or owner != "ldap" or group \
                                    != "ldap":
                                origuid = statdata.st_uid
                                origgid = statdata.st_gid
                                if grp.getgrnam("ldap")[2] != "":
                                    if pwd.getpwnam("ldap")[2] != "":
                                        gid = grp.getgrnam("ldap")[2]
                                        uid = pwd.getpwnam("ldap")[2]
                                        self.iditerator += 1
                                        myid = iterate(self.iditerator,
                                                       self.rulenumber)
                                        event = {"eventtype": "perm",
                                                 "startstate": [origuid,
                                                                origgid, mode],
                                                 "endstate": [uid, gid, 416],
                                                 "filepath": loc}
                                        self.statechglogger.recordchgevent(myid, event)
                                        os.chmod(loc, 416)
                                        os.chown(loc, uid, gid)
                                        resetsecon(loc)
                                    else:
                                        debug = "Unable to determine the " + \
                                            "id number of ldap user.  " + \
                                            "Will not change permissions " + \
                                            "on " + loc + " file\n"
                                        self.logger.log(LogPriority.DEBUG, debug)
                                        success = False
                                else:
                                    success = False
                                    debug = "Unable to determine the id " + \
                                        "number of ldap group.  Will not " + \
                                        "change permissions on " + loc + \
                                        " file\n"
                                    self.logger.log(LogPriority.DEBUG, debug)
                #apt-get systems
                slapdd = "/etc/ldap/slapd.d/"
                if os.path.exists(slapdd):
                    dirs = glob.glob(slapdd + "*")
                    for loc in dirs:
                        if not os.path.isdir(loc):
                            statdata = os.stat(loc)
                            mode = stat.S_IMODE(statdata.st_mode)
                            ownergrp = getUserGroupName(loc)
                            owner = ownergrp[0]
                            group = ownergrp[1]
                            if mode != 384 or owner != "openldap" or group \
                                    != "openldap":
                                origuid = statdata.st_uid
                                origgid = statdata.st_gid
                                if grp.getgrnam("openldap")[2] != "":
                                    if pwd.getpwnam("openldap")[2] != "":
                                        gid = grp.getgrnam("openldap")[2]
                                        uid = pwd.getpwnam("openldap")[2]
                                        self.iditerator += 1
                                        myid = iterate(self.iditerator,
                                                       self.rulenumber)
                                        event = {"eventtype": "perm",
                                                 "startstate": [origuid,
                                                                origgid, mode],
                                                 "endstate": [uid, gid, 384],
                                                 "filepath": loc}
                                        self.statechglogger.recordchgevent(myid, event)
                                        os.chmod(loc, 384)
                                        os.chown(loc, uid, gid)
                                        resetsecon(loc)
                                    else:
                                        debug = "Unable to determine the " + \
                                            "id number of ldap user.  " + \
                                            "Will not change permissions " + \
                                            "on " + loc + " file\n"
                                        self.logger.log(LogPriority.DEBUG, debug)
                                        success = False
                                else:
                                    success = False
                                    debug = "Unable to determine the id " + \
                                        "number of ldap group.  Will not " + \
                                        "change permissions on " + loc + \
                                        " file\n"
                                    self.logger.log(LogPriority.DEBUG, debug)
                cnconfig = "/etc/openldap/slapd.d/cn=config/"
                if os.path.exists(cnconfig):
                    dirs = glob.glob(cnconfig + "*")
                    for loc in dirs:
                        if not os.path.isdir(loc):
                            statdata = os.stat(loc)
                            mode = stat.S_IMODE(statdata.st_mode)
                            ownergrp = getUserGroupName(loc)
                            owner = ownergrp[0]
                            group = ownergrp[1]
                            if mode != 416 or owner != "ldap" or group != "ldap":
                                origuid = statdata.st_uid
                                origgid = statdata.st_gid
                                if grp.getgrnam("ldap")[2] != "":
                                    if pwd.getpwnam("ldap")[2] != "":
                                        gid = grp.getgrnam("ldap")[2]
                                        uid = pwd.getpwnam("ldap")[2]
                                        self.iditerator += 1
                                        myid = iterate(self.iditerator,
                                                       self.rulenumber)
                                        event = {"eventtype": "perm",
                                                 "startstate": [origuid,
                                                                origgid, mode],
                                                 "endstate": [uid, gid, 416],
                                                 "filepath": loc}
                                        self.statechglogger.recordchgevent(myid, event)
                                        os.chmod(loc, 416)
                                        os.chown(loc, uid, gid)
                                        resetsecon(loc)
                                    else:
                                        debug = "Unable to determine the " + \
                                            "id number of ldap user.  " + \
                                            "Will not change permissions " + \
                                            "on " + loc + " file\n"
                                        self.logger.log(LogPriority.DEBUG, debug)
                                        success = False
                                else:
                                    success = False
                                    debug = "Unable to determine the id " + \
                                        "number of ldap group.  Will not " + \
                                        "change permissions on " + loc + \
                                        " file\n"
                                    self.logger.log(LogPriority.DEBUG, debug)
                #apt-get systems
                cnconfig = "/etc/ldap/slapd.d/cn=config/"
                if os.path.exists(cnconfig):
                    dirs = glob.glob(cnconfig + "*")
                    for loc in dirs:
                        if not os.path.isdir(loc):
                            statdata = os.stat(loc)
                            mode = stat.S_IMODE(statdata.st_mode)
                            ownergrp = getUserGroupName(loc)
                            owner = ownergrp[0]
                            group = ownergrp[1]
                            if mode != 384 or owner != "openldap" or group != "openldap":
                                origuid = statdata.st_uid
                                origgid = statdata.st_gid
                                if grp.getgrnam("openldap")[2] != "":
                                    if pwd.getpwnam("openldap")[2] != "":
                                        gid = grp.getgrnam("openldap")[2]
                                        uid = pwd.getpwnam("openldap")[2]
                                        self.iditerator += 1
                                        myid = iterate(self.iditerator,
                                                       self.rulenumber)
                                        event = {"eventtype": "perm",
                                                 "startstate": [origuid,
                                                                origgid, mode],
                                                 "endstate": [uid, gid, 384],
                                                 "filepath": loc}
                                        self.statechglogger.recordchgevent(myid, event)
                                        os.chmod(loc, 384)
                                        os.chown(loc, uid, gid)
                                        resetsecon(loc)
                                    else:
                                        debug = "Unable to determine the " + \
                                            "id number of ldap user.  " + \
                                            "Will not change permissions " + \
                                            "on " + loc + " file\n"
                                        self.logger.log(LogPriority.DEBUG, debug)
                                        success = False
                                else:
                                    success = False
                                    debug = "Unable to determine the id " + \
                                        "number of ldap group.  Will not " + \
                                        "change permissions on " + loc + \
                                        " file\n"
                                    self.logger.log(LogPriority.DEBUG, debug)
                pki = "/etc/pki/tls/ldap/"
                if os.path.exists(pki):
                    dirs = glob.glob(pki + "*")
                    for loc in dirs:
                        if not os.path.isdir():
                            statdata = os.stat(loc)
                            mode = stat.S_IMODE(statdata.st_mode)
                            ownergrp = getUserGroupName(loc)
                            owner = ownergrp[0]
                            group = ownergrp[1]
                            if mode != 416 or owner != "root" or group != "ldap":
                                origuid = statdata.st_uid
                                origgid = statdata.st_gid
                                if grp.getgrnam("ldap")[2] != "":
                                    gid = grp.getgrnam("ldap")[2]
                                    self.iditerator += 1
                                    myid = iterate(self.iditerator, self.rulenumber)
                                    event = {"eventtype": "perm",
                                             "startstate": [origuid, origgid, mode],
                                             "endstate": [0, gid, 416],
                                             "filepath": loc}
                                    self.statechglogger.recordchgevent(myid, event)
                                    os.chmod(slapd, 416)
                                    os.chown(slapd, 0, gid)
                                    resetsecon(slapd)
                                else:
                                    success = False
                                    debug = "Unable to determine the id " + \
                                        "number of ldap group.  Will not change " + \
                                        "permissions on " + loc + " file\n"
                                    self.logger.log(LogPriority.DEBUG, debug)
                if os.path.exists("/etc/pki/tls/CA/"):
                    dirs = glob.glob("/etc/pki/tls/CA/*")
                    for loc in dirs:
                        if not os.path.isdir():
                            if not checkPerms(loc, [0, 0, 420], self.logger):
                                self.iditerator += 1
                                myid = iterate(self.iditerator, self.rulenumber)
                                if not setPerms(loc, [0, 0, 420], self.logger,
                                                self.statechglogger, myid):
                                    debug = "Unable to set permissions on " + \
                                        loc + " file\n"
                                    self.logger.log(LogPriority.DEBUG, debug)
                                    success = False
                                else:
                                    resetsecon(loc)
            self.rulesuccess = success
        except (KeyboardInterrupt, SystemExit):
            # User initiated exit
            raise
        except Exception:
            self.rulesuccess = False
            self.detailedresults += "\n" + traceback.format_exc()
            self.logdispatch.log(LogPriority.ERROR, self.detailedresults)
        self.formatDetailedResults("fix", self.rulesuccess,
                                   self.detailedresults)
        self.logdispatch.log(LogPriority.INFO, self.detailedresults)
        return self.rulesuccess

Example 10

Project: C-PAC
Source File: extract_data.py
View license
def extract_data(c, param_map):
    """
    Method to generate a CPAC input subject list
    python file. The method extracts anatomical
    and functional data for each site( if multiple site)
    and/or scan parameters for each site and put it into
    a data structure read by python

    Example:
    subjects_list =[
       {
        'subject_id' : '0050386',
        'unique_id' : 'session_1',
        'anat': '/Users/home/data/NYU/0050386/session_1/anat_1/anat.nii.gz',
        'rest':{
            'rest_1_rest' : '/Users/home/data/NYU/0050386/session_1/rest_1/rest.nii.gz',
            'rest_2_rest' : '/Users/home/data/NYU/0050386/session_1/rest_2/rest.nii.gz',
            }
        'scan_parameters':{
            'tr': '2',
            'acquisition': 'alt+z2',
            'reference': '17',
            'first_tr': '',
            'last_tr': '',
            }
        },
    ]

    or

    subjects_list =[
       {
        'subject_id' : '0050386',
        'unique_id' : 'session_1',
        'anat': '/Users/home/data/NYU/0050386/session_1/anat_1/anat.nii.gz',
        'rest':{
            'rest_1_rest' : '/Users/home/data/NYU/0050386/session_1/rest_1/rest.nii.gz',
            'rest_2_rest' : '/Users/home/data/NYU/0050386/session_1/rest_2/rest.nii.gz',
            }
          },
    ]

    """
    # NOTE(review): Python 2 code throughout (`print >> f`, `string.join`,
    # `except Exception, e`); it will not run under Python 3 as written.
    # `c` is presumably a CPAC config object (attributes like
    # c.subjectList, c.anatomicalTemplate) -- confirm against caller.

    #method to read each line of the file into list
    #returns list
    def get_list(arg):
        if isinstance(arg, list):
            ret_list = arg
        else:
            ret_list = [fline.rstrip('\r\n') for fline in open(arg, 'r').readlines()]

        return ret_list

    exclusion_list = []
    if c.exclusionSubjectList is not None:
        exclusion_list = get_list(c.exclusionSubjectList)

    subject_list = []
    if c.subjectList is not None:
        subject_list = get_list(c.subjectList)

    #check if Template is correct
    def checkTemplate(template):
        # Template must contain exactly two '%s' placeholders (site, subject)
        # and end in .nii or .nii.gz.

        if template.count('%s') != 2:
            # NOTE(review): the message itself contains a literal '%s' that is
            # never interpolated; the three string pieces also concatenate
            # without separating spaces.
            msg = "Please provide '%s' in the template" \
                  "where your site and subjects are present"\
                  "Please see examples"
            logging.exception(msg)
            raise Exception(msg)

        filename, ext = os.path.splitext(os.path.basename(template))
        ext = os.path.splitext(filename)[1] + ext

        if ext not in [".nii", ".nii.gz"]:
            # NOTE(review): the comma builds a tuple, not a formatted string;
            # `msg` is ("Invalid file name", basename).
            msg = "Invalid file name", os.path.basename(template)
            logging.exception(msg)
            raise Exception(msg)

    def get_site_list(path):
        # Everything before the remaining '%s' is the directory whose
        # entries are treated as site names.
        base, relative = path.split('%s')
        sites = os.listdir(base)
        return sites
    
    def check_length(scan_name, file_name):
        # Guard against over-long file/scan names in the generated YAML.
        
        if len(file_name) > 30:
            msg = "filename- %s is too long."\
                   "It should not be more than 30 characters."%(file_name)
            logging.exception(msg)
            raise Exception(msg)
        
        # NOTE(review): threshold here is 40 but the error message says 20,
        # and the sibling multiscan script uses 20 -- confirm intended limit.
        if len(scan_name) - len(os.path.splitext(os.path.splitext(file_name)[0])[0])>= 40:
            msg = "scan name %s is too long."\
                  "It should not be more than 20 characters"\
                  %(scan_name.replace("_"+os.path.splitext(os.path.splitext(file_name)[0])[0], ''))
            logging.exception(msg)
            raise Exception(msg)

    def create_site_subject_mapping(base, relative):
        # Returns (list of expanded site base paths, {subject: site} map).

        #mapping between site and subject
        site_subject_map = {}
        base_path_list = []

        if c.siteList is not None:
            site_list = get_list(c.siteList)
        else:
            site_list = get_site_list(base)

        for site in site_list:
            # Python 2 `string` module call; equivalent to
            # base.replace('%s', site).
            paths = glob.glob(string.replace(base, '%s', site))
            base_path_list.extend(paths)
            for path in paths:
                for sub in os.listdir(path):
                    #check if subject is present in subject_list
                    if subject_list:
                        if sub in subject_list and sub not in exclusion_list:
                            site_subject_map[sub] = site
                    elif sub not in exclusion_list:
                        # NOTE(review): substring test -- this is True for any
                        # name that is not a substring of '.DS_Store'; likely
                        # meant `sub != '.DS_Store'`.
                        if sub not in '.DS_Store':
                            site_subject_map[sub] = site

        return base_path_list, site_subject_map

    #method to split the input template path
    #into base, path before subject directory
    #and relative, path after subject directory
    def getPath(template):

        checkTemplate(template)
        base, relative = template.rsplit("%s", 1)
        base, subject_map = create_site_subject_mapping(base, relative)
        base.sort()
        relative = relative.lstrip("/")
        return base, relative, subject_map

    #get anatomical base path and anatomical relative path
    anat_base, anat_relative = getPath(c.anatomicalTemplate)[:2]

    #get functional base path, functional relative path and site-subject map
    func_base, func_relative, subject_map = getPath(c.functionalTemplate)

    if not anat_base:
        # NOTE(review): the comma builds a tuple, not a formatted string.
        msg = "Anatomical Data template incorrect. No such file or directory %s", anat_base
        logging.exception(msg)
        raise Exception(msg)

    if not func_base:
        # NOTE(review): '%s, func_base' is inside the string literal, so the
        # actual path is never interpolated into the message.
        msg = "Functional Data template incorrect. No such file or directory %s, func_base"
        logging.exception(msg)
        raise Exception(msg)
        
    if len(anat_base) != len(func_base):
        # NOTE(review): msg1 is a tuple (trailing comma continuation), not a
        # string.
        msg1 = "Some sites are missing, Please check your template"\
              , anat_base, "!=", func_base
        logging.exception(msg1)
        
        msg2 = " Base length Unequal. Some sites are missing."\
               "extract_data doesn't script support this.Please" \
               "Provide your own subjects_list file"
        logging.exception(msg2)
        raise Exception(msg2)

    #calculate the length of relative paths(path after subject directory)
    func_relative_len = len(func_relative.split('/'))
    anat_relative_len = len(anat_relative.split('/'))

    def check_for_sessions(relative_path, path_length):
        """
        Method to check if there are sessions present

        """
        #default
        session_present = False
        session_path = 'session_1'

        #session present if path_length is equal to 3
        if path_length == 3:
            relative_path_list = relative_path.split('/')
            session_path = relative_path_list[0]
            # Python 2 string.join; equivalent to "/".join(...).
            relative_path = string.join(relative_path_list[1:], "/")
            session_present = True
        elif path_length > 3:
            msg = "extract_data script currently doesn't support this directory structure."\
                  "Please provide the subjects_list file to run CPAC."\
                  "For more information refer to manual"
            logging.exception(msg)
            raise Exception(msg)
        return session_present, session_path, relative_path

    func_session_present, func_session_path, func_relative = \
        check_for_sessions(func_relative, func_relative_len)

    anat_session_present, anat_session_path, anat_relative = \
        check_for_sessions(anat_relative, anat_relative_len)

    # Output YAML subject list; opened in binary mode (fine on Py2, where
    # `print >> f` writes str). Closed in the `finally` at the bottom.
    f = open(os.path.join(c.outputSubjectListLocation, "CPAC_subject_list_%s.yml" % c.subjectListName[0]), 'wb')



    def fetch_path(i, anat_sub, func_sub, session_id):
        """
        Method to extract anatomical and functional
        path for a session and print to file

        Parameters
        ----------
        i : int
            index of site
        anat_sub : string
            string containing subject/ concatenated
            subject-session path for anatomical file
        func_sub: string
            string containing subject/ concatenated
            subject-session path for functional file
        session_id: string
            session

        Raises
        ------
        Exception
        """

        try:

            def print_begin_of_file(sub, session_id):
                print >> f, "-"
                print >> f, "    subject_id: '" + sub + "'"
                print >> f, "    unique_id: '" + session_id + "'" 

            def print_end_of_file(sub):
                if param_map is not None:
                    try:
                        logging.debug("site for sub %s -> %s" %(sub, subject_map.get(sub)))
                        logging.debug("scan parameters for the above site %s"%param_map.get(subject_map.get(sub)))
                        print >> f, "    scan_parameters:"
                        print >> f, "        tr: '" + param_map.get(subject_map.get(sub))[4] + "'" 
                        print >> f, "        acquisition: '" + param_map.get(subject_map.get(sub))[0] + "'" 
                        print >> f, "        reference: '" + param_map.get(subject_map.get(sub))[3] + "'"
                        print >> f, "        first_tr: '" + param_map.get(subject_map.get(sub))[1] + "'"
                        print >> f, "        last_tr: '" + param_map.get(subject_map.get(sub))[2] + "'"
                    # NOTE(review): bare except -- converts ANY failure above
                    # (including typos/IO errors) into this ValueError.
                    except:
                        msg = " No Parameter values for the %s site is defined in the scan"\
                              " parameters csv file" %subject_map.get(sub)
                        raise ValueError(msg)

            #get anatomical file
            anat_base_path = os.path.join(anat_base[i], anat_sub)
            func_base_path = os.path.join(func_base[i], func_sub)

            anat = None
            func = None

            anat = glob.glob(os.path.join(anat_base_path, anat_relative))
            func = glob.glob(os.path.join(func_base_path, func_relative))

            if anat and func:
                print_begin_of_file(anat_sub.split("/")[0], session_id)
                print >> f, "    anat: '" + os.path.realpath(anat[0]) + "'"
                print >> f, "    rest: "

                #iterate for each rest session
                # NOTE(review): loop variable `iter` shadows the builtin.
                for iter in func:
                    #get scan_id
                    iterable = os.path.splitext(os.path.splitext(iter.replace(func_base_path, '').lstrip("/"))[0])[0]
                    iterable = iterable.replace("/", "_")
                    check_length(iterable, os.path.basename(os.path.realpath(iter)))
                    print>>f, "      " + iterable + ": '" + os.path.realpath(iter) + "'"
                
                print_end_of_file(anat_sub.split("/")[0])
                
            else:
                logging.debug("skipping subject %s"%anat_sub.split("/")[0])
        
        except ValueError:

            # NOTE(review): `ValueError.message` is looked up on the CLASS,
            # not the caught instance -- this logs a descriptor object, not
            # the error text. Should bind the exception and log str(e).
            logging.exception(ValueError.message)
            raise

        except Exception, e:

            # NOTE(review): 'felching' is a typo for 'fetching' in the
            # runtime message (left unchanged here).
            err_msg = 'Exception while felching anatomical and functional ' \
                      'paths: \n' + str(e)

            logging.exception(err_msg)
            raise Exception(err_msg)



    def walk(index, sub):
        """
        Method which walks across each subject
        path in the data site path

        Parameters
        ----------
        index : int
            index of site
        sub : string
            subject_id
        """
        try:

            if func_session_present:
                #if there are sessions
                if "*" in func_session_path:
                    session_list = glob.glob(os.path.join(func_base[index], os.path.join(sub, func_session_path)))
                else:
                    session_list = [func_session_path]

                if session_list:
                    for session in session_list:
                        session_id = os.path.basename(session)
                        if anat_session_present:
                            if func_session_path == anat_session_path:
                                fetch_path(index, os.path.join(sub, session_id), os.path.join(sub, session_id), session_id)
                            else:
                                fetch_path(index, os.path.join(sub, anat_session_path), os.path.join(sub, session_id), session_id)
                        else:
                            fetch_path(index, sub, os.path.join(sub, session_id), session_id)
                else:
                    logging.debug("Skipping subject %s", sub)

            else:
                logging.debug("No sessions")
                session_id = ''
                fetch_path(index, sub, sub, session_id)

        except Exception:

            # NOTE(review): `Exception.message` is accessed on the class
            # (logs a descriptor, not the error), and because this handler
            # catches Exception first, the bare `except:` below is
            # unreachable for ordinary exceptions.
            logging.exception(Exception.message)
            raise

        except:

            err_msg = 'Please make sessions are consistent across all ' \
                      'subjects.\n\n'

            logging.exception(err_msg)
            raise Exception(err_msg)


    try:
        for i in range(len(anat_base)):
            for sub in os.listdir(anat_base[i]):
                #check if subject is present in subject_list
                if subject_list:
                    if sub in subject_list and sub not in exclusion_list:
                        logging.debug("extracting data for subject: %s", sub)
                        walk(i, sub)
                #check that subject is not in exclusion list
                # NOTE(review): `sub not in '.DS_Store'` is a substring test;
                # likely meant inequality.
                elif sub not in exclusion_list and sub not in '.DS_Store':
                    logging.debug("extracting data for subject: %s", sub)
                    walk(i, sub)

        
        # NOTE(review): this success message names 'CPAC_subject_list.yml'
        # but the file actually written above includes c.subjectListName[0].
        name = os.path.join(c.outputSubjectListLocation, 'CPAC_subject_list.yml')
        print "Extraction Successfully Completed...Input Subjects_list for CPAC - %s" % name

    except Exception:

        # NOTE(review): class-level `.message` again -- see walk() above.
        logging.exception(Exception.message)
        raise

    finally:

        f.close()

Example 11

Project: C-PAC
Source File: extract_data_multiscan.py
View license
def extract_data(c, param_map):
    """
    Method to generate a CPAC input subject list
    python file. The method extracts anatomical
    functional data and scan parameters for each 
    site( if multiple site) and for each scan 
    and put it into a data structure read by python

    Note:
    -----
    Use this tool only if the scan parameters are different 
    for each scan as shown in the example below.
    
    Example:
    --------
    subjects_list = [
        {
            'subject_id': '0021001',
            'unique_id': 'session2',
            'anat': '/home/data/multiband_data/NKITRT/0021001/anat/mprage.nii.gz',
            'rest':{
              'RfMRI_mx_1400_rest': '/home/data/multiband_data/NKITRT/0021001/session2/RfMRI_mx_1400/rest.nii.gz',
              'RfMRI_mx_645_rest': '/home/data/multiband_data/NKITRT/0021001/session2/RfMRI_mx_645/rest.nii.gz',
              'RfMRI_std_2500_rest': '/home/data/multiband_data/NKITRT/0021001/session2/RfMRI_std_2500/rest.nii.gz',
              },
            'scan_parameters':{
                'TR':{
                    'RfMRI_mx_1400_rest': '1.4',
                    'RfMRI_mx_645_rest': '1.4',
                    'RfMRI_std_2500_rest': '2.5',
                    },
                'Acquisition':{
                    'RfMRI_mx_1400_rest': '/home/data/1400.txt',
                    'RfMRI_mx_645_rest': '/home/data/645.txt',
                    'RfMRI_std_2500_rest': '/home/data/2500.txt',
                    },
                'Reference':{
                    'RfMRI_mx_1400_rest': '32',
                    'RfMRI_mx_645_rest': '20',
                    'RfMRI_std_2500_rest': '19',
                    },
                'FirstTR':{
                    'RfMRI_mx_1400_rest': '7',
                    'RfMRI_mx_645_rest': '15',
                    'RfMRI_std_2500_rest': '4',
                    },
                'LastTR':{
                    'RfMRI_mx_1400_rest': '440',
                    'RfMRI_mx_645_rest': '898',
                    'RfMRI_std_2500_rest': 'None',
                    },
                }
        },

    ]
    """
    # NOTE(review): Python 2 code (`print` statements, `string.join`);
    # will not run under Python 3 as written. Multiscan variant of the
    # extract_data tool: per-scan parameters are emitted into the YAML.

    #method to read each line of the file into list
    #returns list
    def get_list(arg):
        if isinstance(arg, list):
            ret_list = arg
        else:
            ret_list = [fline.rstrip('\r\n') for fline in open(arg, 'r').readlines()]

        return ret_list

    exclusion_list = []
    if c.exclusionSubjectList is not None:
        exclusion_list = get_list(c.exclusionSubjectList)

    subject_list = []
    if c.subjectList is not None:
        subject_list = get_list(c.subjectList)

    #check if Template is correct
    def checkTemplate(template):
        # Template must contain exactly two '%s' placeholders (site, subject)
        # and end in .nii or .nii.gz.

        if template.count('%s') != 2:
            # NOTE(review): the message contains a literal '%s' that is never
            # interpolated, and the pieces concatenate without spaces.
            raise Exception("Please provide '%s' in the template" \
                            "where your site and subjects are present"\
                            "Please see examples")

        filename, ext = os.path.splitext(os.path.basename(template))
        ext = os.path.splitext(filename)[1] + ext

        if ext not in [".nii", ".nii.gz"]:
            raise Exception("Invalid file name", os.path.basename(template))

    def get_site_list(path):
        # Everything before the first '%s' is the directory whose entries
        # are treated as site names.
        base = path.split('%s')[0]
        sites = os.listdir(base)
        return sites

    def check_length(scan_name, file_name):
        # Guard against over-long file/scan names in the generated YAML.
               
        if len(file_name) > 30:
            msg = "filename- %s is too long."\
                   "It should not be more than 30 characters."%(file_name)
            raise Exception(msg)
        
        if len(scan_name) - len(os.path.splitext(os.path.splitext(file_name)[0])[0])>= 20:
            msg = "scan name %s is too long."\
                  "It should not be more than 20 characters"\
                  %(scan_name.replace("_"+os.path.splitext(os.path.splitext(file_name)[0])[0], ''))
            raise Exception(msg)
        


    def create_site_subject_mapping(base, relative):
        # Returns (list of expanded site base paths, {subject: site} map).

        #mapping between site and subject
        site_subject_map = {}
        base_path_list = []

        if c.siteList is not None:
            site_list = get_list(c.siteList)
        else:
            site_list = get_site_list(base)

        for site in site_list:
            # Python 2 `string` module call; equivalent to
            # base.replace('%s', site).
            paths = glob.glob(string.replace(base, '%s', site))
            base_path_list.extend(paths)
            for path in paths:
                for sub in os.listdir(path):
                    #check if subject is present in subject_list
                    if subject_list:
                        if sub in subject_list and sub not in exclusion_list:
                            site_subject_map[sub] = site
                    elif sub not in exclusion_list:
                        # NOTE(review): substring test, not equality -- likely
                        # meant `sub != '.DS_Store'`.
                        if sub not in '.DS_Store':
                            site_subject_map[sub] = site

        return base_path_list, site_subject_map

    #method to split the input template path
    #into base, path before subject directory
    #and relative, path after subject directory
    def getPath(template):

        checkTemplate(template)
        base, relative = template.rsplit("%s", 1)
        base, subject_map = create_site_subject_mapping(base, relative)
        base.sort()
        relative = relative.lstrip("/")
        return base, relative, subject_map

    #get anatomical base path and anatomical relative path
    anat_base, anat_relative = getPath(c.anatomicalTemplate)[:2]

    #get functional base path, functional relative path and site-subject map
    func_base, func_relative, subject_map = getPath(c.functionalTemplate)

    if not anat_base:
        print "No such file or directory ", anat_base
        raise Exception("Anatomical Data template incorrect")

    if not func_base:
        print "No such file or directory", func_base
        raise Exception("Functional Data template incorrect")

    if len(anat_base) != len(func_base):
        print "Some sites are missing, Please check your"\
              "template", anat_base, "!=", func_base
        raise Exception(" Base length Unequal. Some sites are missing."\
                           "extract_data doesn't script support this.Please" \
                           "Provide your own subjects_list file")

    #calculate the length of relative paths(path after subject directory)
    func_relative_len = len(func_relative.split('/'))
    anat_relative_len = len(anat_relative.split('/'))

    def check_for_sessions(relative_path, path_length):
        """
        Method to check if there are sessions present

        """
        #default
        session_present = False
        session_path = 'session_1'

        #session present if path_length is equal to 3
        if path_length == 3:
            relative_path_list = relative_path.split('/')
            session_path = relative_path_list[0]
            # Python 2 string.join; equivalent to "/".join(...).
            relative_path = string.join(relative_path_list[1:], "/")
            session_present = True
        elif path_length > 3:
            raise Exception("extract_data script currently doesn't support"\
                             "this directory structure.Please provide the"\
                             "subjects_list file to run CPAC." \
                             "For more information refer to manual")

        return session_present, session_path, relative_path

#    if func_relative_len!= anat_relative_len:
#        raise Exception(" extract_data script currently doesn't"\
#                          "support different relative paths for"\
#                          "Anatomical and functional files")

    func_session_present, func_session_path, func_relative = \
        check_for_sessions(func_relative, func_relative_len)

    anat_session_present, anat_session_path, anat_relative = \
        check_for_sessions(anat_relative, anat_relative_len)

    # Output YAML subject list; opened in binary mode (fine on Py2, where
    # `print >> f` writes str). Closed in the `finally` at the bottom.
    f = open(os.path.join(c.outputSubjectListLocation, "CPAC_subject_list.yml"), 'wb')

    def fetch_path(i, anat_sub, func_sub, session_id):
        """
        Method to extract anatomical and functional
        path for a session and print to file

        Parameters
        ----------
        i : int
            index of site
        anat_sub : string
            string containing subject/ concatenated
            subject-session path for anatomical file
        func_sub: string
            string containing subject/ concatenated
            subject-session path for functional file
        session_id: string
            session

        Raises
        ------
        Exception
        """

        try:

            def print_begin_of_file(sub, session_id):
                print >> f, "-"
                print >> f, "    subject_id: '" + sub + "'"
                print >> f, "    unique_id: '" + session_id + "'"

            def print_end_of_file(sub, scan_list):
                # Emits the per-scan parameter sections; `index` selects the
                # column of the param_map row (tr/acq/ref/first_tr/last_tr).
                if param_map is not None:
                    def print_scan_param(index):
                        try:
                            for scan in scan_list:
                                print>>f,  "            " + scan[1] + ": '" + \
                                param_map.get((subject_map.get(sub), scan[0]))[index] + "'"
                        
                        # NOTE(review): bare except -- any failure above
                        # (not only a missing key) becomes this message.
                        except:
                            raise Exception(" No Parameter values for the %s site and %s scan is defined in the scan"\
                                            " parameters csv file" % (subject_map.get(sub), scan[0]))

                    print "site for sub", sub, "->", subject_map.get(sub)
                    print >>f, "    scan_parameters: "
                    print >> f, "        tr:" 
                    print_scan_param(4) 
                    print >> f, "        acquisition:" 
                    print_scan_param(0) 
                    print >> f, "        reference:" 
                    print_scan_param(3) 
                    print >> f, "        first_tr:" 
                    print_scan_param(1) 
                    print >> f, "        last_tr:" 
                    print_scan_param(2) 

 
            #get anatomical file
            anat_base_path = os.path.join(anat_base[i], anat_sub)
            func_base_path = os.path.join(func_base[i], func_sub)

            anat = None
            func = None

            anat = glob.glob(os.path.join(anat_base_path, anat_relative))
            func = glob.glob(os.path.join(func_base_path, func_relative))
            scan_list = []
            if anat and func:
                print_begin_of_file(anat_sub.split("/")[0], session_id)
                print >> f, "    anat: '" + anat[0] + "'" 
                print >>f, "    rest: "

                #iterate for each rest session
                # NOTE(review): loop variable `iter` shadows the builtin.
                for iter in func:
                    #get scan_id
                    iterable = os.path.splitext(os.path.splitext(iter.replace(func_base_path,'').lstrip("/"))[0])[0]
                    scan_name = iterable.replace("/", "_")
                    scan_list.append((os.path.dirname(iterable), scan_name))
                    check_length(scan_name, os.path.basename(iter))
                    print>>f,  "      " + scan_name + ": '" + iter +  "'"
                print_end_of_file(anat_sub.split("/")[0], scan_list)

        except Exception:
            raise

    def walk(index, sub):
        """
        Method which walks across each subject
        path in the data site path

        Parameters
        ----------
        index : int
            index of site
        sub : string
            subject_id
        """
        try:

            if func_session_present:
                #if there are sessions
                if "*" in func_session_path:
                    session_list = glob.glob(os.path.join(func_base[index], os.path.join(sub, func_session_path)))
                else:
                    session_list = [func_session_path]

                for session in session_list:
                    session_id = os.path.basename(session)
                    if anat_session_present:
                        if func_session_path == anat_session_path:
                            fetch_path(index, os.path.join(sub, session_id), os.path.join(sub, session_id), session_id)
                        else:
                            fetch_path(index, os.path.join(sub, anat_session_path), os.path.join(sub, session_id), session_id)
                    else:
                        fetch_path(index, sub, os.path.join(sub, session_id), session_id)
            else:
                print "No sessions"
                session_id = ''
                fetch_path(index, sub, sub, session_id)

        # NOTE(review): because `except Exception:` comes first, the bare
        # `except:` below is unreachable for ordinary exceptions, so its
        # advisory message is effectively dead code.
        except Exception:
            raise
        except:
            print "Please make sessions are consistent across all subjects"
            raise

    try:
        for i in range(len(anat_base)):
            for sub in os.listdir(anat_base[i]):
                #check if subject is present in subject_list
                if subject_list:
                    if sub in subject_list and sub not in exclusion_list:
                        print "extracting data for subject: ", sub
                        walk(i, sub)
                #check that subject is not in exclusion list
                # NOTE(review): `sub not in '.DS_Store'` is a substring test;
                # likely meant inequality.
                elif sub not in exclusion_list and sub not in '.DS_Store':
                    print "extracting data for subject: ", sub
                    walk(i, sub)

        
        name = os.path.join(c.outputSubjectListLocation, 'CPAC_subject_list.yml')
        print "Extraction Complete...Input Subjects_list for CPAC - %s" % name
    except Exception:
        raise
    finally:
        f.close()

Example 12

Project: tools-iuc
Source File: raxml.py
View license
def __main__():
    usage = "usage: %prog -T <threads> -s <input> -n <output> -m <model> [optional arguments]"

    # Parse the primary wrapper's command line options
    parser = optparse.OptionParser(usage=usage)
    # raxml binary name, hardcoded in the xml file
    parser.add_option("--binary", action="store", type="string", dest="binary", help="Command to run")
    # (-a)
    parser.add_option("--weightfile", action="store", type="string", dest="weightfile", help="Column weight file")
    # (-A)
    parser.add_option("--secondary_structure_model", action="store", type="string", dest="secondary_structure_model", help="Secondary structure model")
    # (-b)
    parser.add_option("--bootseed", action="store", type="int", dest="bootseed", help="Bootstrap random number seed")
    # (-c)
    parser.add_option("--numofcats", action="store", type="int", dest="numofcats", help="Number of distinct rate categories")
    # (-d)
    parser.add_option("--search_complete_random_tree", action="store_true", dest="search_complete_random_tree", help="Search with a complete random starting tree")
    # (-D)
    parser.add_option("--ml_search_convergence", action="store_true", dest="ml_search_convergence", help="ML search onvergence criterion")
    # (-e)
    parser.add_option("--model_opt_precision", action="store", type="float", dest="model_opt_precision", help="Model Optimization Precision (-e)")
    # (-E)
    parser.add_option("--excludefile", action="store", type="string", dest="excludefile", help="Exclude File Name")
    # (-f)
    parser.add_option("--search_algorithm", action="store", type="string", dest="search_algorithm", help="Search Algorithm")
    # (-F)
    parser.add_option("--save_memory_cat_model", action="store_true", dest="save_memory_cat_model", help="Save memory under CAT and GTRGAMMA models")
    # (-g)
    parser.add_option("--groupingfile", action="store", type="string", dest="groupingfile", help="Grouping File Name")
    # (-G)
    parser.add_option("--enable_evol_heuristics", action="store_true", dest="enable_evol_heuristics", help="Enable evol algo heuristics")
    # (-i)
    parser.add_option("--initial_rearrangement_setting", action="store", type="int", dest="initial_rearrangement_setting", help="Initial Rearrangement Setting")
    # (-I)
    parser.add_option("--posterior_bootstopping_analysis", action="store", type="string", dest="posterior_bootstopping_analysis", help="Posterior bootstopping analysis")
    # (-J)
    parser.add_option("--majority_rule_consensus", action="store", type="string", dest="majority_rule_consensus", help="Majority rule consensus")
    # (-k)
    parser.add_option("--print_branch_lengths", action="store_true", dest="print_branch_lengths", help="Print branch lengths")
    # (-K)
    parser.add_option("--multistate_sub_model", action="store", type="string", dest="multistate_sub_model", help="Multistate substitution model")
    # (-m)
    parser.add_option("--model_type", action="store", type="string", dest="model_type", help="Model Type")
    parser.add_option("--base_model", action="store", type="string", dest="base_model", help="Base Model")
    parser.add_option("--aa_empirical_freq", action="store_true", dest="aa_empirical_freq", help="Use AA Empirical base frequences")
    parser.add_option("--aa_search_matrix", action="store", type="string", dest="aa_search_matrix", help="AA Search Matrix")
    # (-n)
    parser.add_option("--name", action="store", type="string", dest="name", help="Run Name")
    # (-N/#)
    parser.add_option("--number_of_runs", action="store", type="int", dest="number_of_runs", help="Number of alternative runs")
    parser.add_option("--number_of_runs_bootstop", action="store", type="string", dest="number_of_runs_bootstop", help="Number of alternative runs based on the bootstop criteria")
    # (-M)
    parser.add_option("--estimate_individual_branch_lengths", action="store_true", dest="estimate_individual_branch_lengths", help="Estimate individual branch lengths")
    # (-o)
    parser.add_option("--outgroup_name", action="store", type="string", dest="outgroup_name", help="Outgroup Name")
    # (-O)
    parser.add_option("--disable_undetermined_seq_check", action="store_true", dest="disable_undetermined_seq_check", help="Disable undetermined sequence check")
    # (-p)
    parser.add_option("--random_seed", action="store", type="int", dest="random_seed", help="Random Number Seed")
    # (-P)
    parser.add_option("--external_protein_model", action="store", type="string", dest="external_protein_model", help="External Protein Model")
    # (-q)
    parser.add_option("--multiple_model", action="store", type="string", dest="multiple_model", help="Multiple Model File")
    # (-r)
    parser.add_option("--constraint_file", action="store", type="string", dest="constraint_file", help="Constraint File")
    # (-R)
    parser.add_option("--bin_model_parameter_file", action="store", type="string", dest="bin_model_parameter_file", help="Constraint File")
    # (-s)
    parser.add_option("--source", action="store", type="string", dest="source", help="Input file")
    # (-S)
    parser.add_option("--secondary_structure_file", action="store", type="string", dest="secondary_structure_file", help="Secondary structure file")
    # (-t)
    parser.add_option("--starting_tree", action="store", type="string", dest="starting_tree", help="Starting Tree")
    # (-T)
    parser.add_option("--threads", action="store", type="int", dest="threads", help="Number of threads to use")
    # (-u)
    parser.add_option("--use_median_approximation", action="store_true", dest="use_median_approximation", help="Use median approximation")
    # (-U)
    parser.add_option("--save_memory_gappy_alignments", action="store_true", dest="save_memory_gappy_alignments", help="Save memory in large gapped alignments")
    # (-V)
    parser.add_option("--disable_rate_heterogeneity", action="store_true", dest="disable_rate_heterogeneity", help="Disable rate heterogeneity")
    # (-W)
    parser.add_option("--sliding_window_size", action="store", type="string", dest="sliding_window_size", help="Sliding window size")
    # (-x)
    parser.add_option("--rapid_bootstrap_random_seed", action="store", type="int", dest="rapid_bootstrap_random_seed", help="Rapid Boostrap Random Seed")
    # (-y)
    parser.add_option("--parsimony_starting_tree_only", action="store_true", dest="parsimony_starting_tree_only", help="Generate a parsimony starting tree only")
    # (-z)
    parser.add_option("--file_multiple_trees", action="store", type="string", dest="file_multiple_trees", help="Multiple Trees File")

    (options, args) = parser.parse_args()
    cmd = []

    # Required parameters
    binary = options.binary
    cmd.append(binary)
    # Threads
    if options.threads > 1:
        threads = "-T %d" % options.threads
        cmd.append(threads)
    # Source
    source = "-s %s" % options.source
    cmd.append(source)
    # Hardcode to "galaxy" first to simplify the output part of the wrapper
    # name = "-n %s" % options.name
    name = "-n galaxy"
    cmd.append(name)
    # Model
    model_type = options.model_type
    base_model = options.base_model
    aa_search_matrix = options.aa_search_matrix
    aa_empirical_freq = options.aa_empirical_freq
    if model_type == 'aminoacid':
        model = "-m %s%s" % (base_model, aa_search_matrix)
        if aa_empirical_freq:
            model = "-m %s%s%s" % (base_model, aa_search_matrix, 'F')
        # (-P)
        if options.external_protein_model:
            external_protein_model = "-P %s" % options.external_protein_model
            cmd.append(external_protein_model)
    else:
        model = "-m %s" % base_model
    cmd.append(model)
    if model == "GTRCAT":
        # (-c)
        if options.numofcats:
            numofcats = "-c %d" % options.numofcats
            cmd.append(numofcats)
    # Optional parameters
    if options.number_of_runs_bootstop:
        number_of_runs_bootstop = "-N %s" % options.number_of_runs_bootstop
        cmd.append(number_of_runs_bootstop)
    else:
        number_of_runs_bootstop = ''
    if options.number_of_runs:
        number_of_runs_opt = "-N %d" % options.number_of_runs
        cmd.append(number_of_runs_opt)
    else:
        number_of_runs_opt = 0
    # (-a)
    if options.weightfile:
        weightfile = "-a %s" % options.weightfile
        cmd.append(weightfile)
    # (-A)
    if options.secondary_structure_model:
        secondary_structure_model = "-A %s" % options.secondary_structure_model
        cmd.append(secondary_structure_model )
    # (-b)
    if options.bootseed:
        bootseed = "-b %d" % options.bootseed
        cmd.append(bootseed)
    else:
        bootseed = 0
    # -C - doesn't work in pthreads version, skipped
    if options.search_complete_random_tree:
        cmd.append("-d")
    if options.ml_search_convergence:
        cmd.append("-D" )
    if options.model_opt_precision:
        model_opt_precision = "-e %f" % options.model_opt_precision
        cmd.append(model_opt_precision)
    if options.excludefile:
        excludefile = "-E %s" % options.excludefile
        cmd.append(excludefile)
    if options.search_algorithm:
        search_algorithm = "-f %s" % options.search_algorithm
        cmd.append(search_algorithm)
    if options.save_memory_cat_model:
        cmd.append("-F")
    if options.groupingfile:
        groupingfile = "-g %s" % options.groupingfile
        cmd.append(groupingfile)
    if options.enable_evol_heuristics:
        enable_evol_heuristics = "-G %f" % options.enable_evol_heuristics
        cmd.append(enable_evol_heuristics )
    if options.initial_rearrangement_setting:
        initial_rearrangement_setting = "-i %s" % options.initial_rearrangement_setting
        cmd.append(initial_rearrangement_setting)
    if options.posterior_bootstopping_analysis:
        posterior_bootstopping_analysis = "-I %s" % options.posterior_bootstopping_analysis
        cmd.append(posterior_bootstopping_analysis)
    if options.majority_rule_consensus:
        majority_rule_consensus = "-J %s" % options.majority_rule_consensus
        cmd.append(majority_rule_consensus)
    if options.print_branch_lengths:
        cmd.append("-k")
    if options.multistate_sub_model:
        multistate_sub_model = "-K %s" % options.multistate_sub_model
        cmd.append(multistate_sub_model)
    if options.estimate_individual_branch_lengths:
        cmd.append("-M")
    if options.outgroup_name:
        outgroup_name = "-o %s" % options.outgroup_name
        cmd.append(outgroup_name)
    if options.disable_undetermined_seq_check:
        cmd.append("-O")
    if options.random_seed:
        random_seed = "-p %d" % options.random_seed
        cmd.append(random_seed)
    multiple_model = None
    if options.multiple_model:
        multiple_model = "-q %s" % options.multiple_model
        cmd.append(multiple_model)
    if options.constraint_file:
        constraint_file = "-r %s" % options.constraint_file
        cmd.append(constraint_file)
    if options.bin_model_parameter_file:
        bin_model_parameter_file_name = "RAxML_binaryModelParameters.galaxy"
        os.symlink(options.bin_model_parameter_file, bin_model_parameter_file_name )
        bin_model_parameter_file = "-R %s" % options.bin_model_parameter_file
        # Needs testing. Is the hardcoded name or the real path needed?
        cmd.append(bin_model_parameter_file)
    if options.secondary_structure_file:
        secondary_structure_file = "-S %s" % options.secondary_structure_file
        cmd.append(secondary_structure_file)
    if options.starting_tree:
        starting_tree = "-t %s" % options.starting_tree
        cmd.append(starting_tree)
    if options.use_median_approximation:
        cmd.append("-u")
    if options.save_memory_gappy_alignments:
        cmd.append("-U")
    if options.disable_rate_heterogeneity:
        cmd.append("-V")
    if options.sliding_window_size:
        sliding_window_size = "-W %d" % options.sliding_window_size
        cmd.append(sliding_window_size)
    if options.rapid_bootstrap_random_seed:
        rapid_bootstrap_random_seed = "-x %d" % options.rapid_bootstrap_random_seed
        cmd.append(rapid_bootstrap_random_seed)
    else:
        rapid_bootstrap_random_seed = 0
    if options.parsimony_starting_tree_only:
        cmd.append("-y")
    if options.file_multiple_trees:
        file_multiple_trees = "-z %s" % options.file_multiple_trees
        cmd.append(file_multiple_trees)

    print "cmd list: ", cmd, "\n"

    full_cmd = " ".join(cmd)
    print "Command string: %s" % full_cmd

    try:
        proc = subprocess.Popen(args=full_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as err:
        sys.stderr.write("Error invoking command: \n%s\n\n%s\n" % (cmd, err))
        sys.exit(1)
    stdout, stderr = proc.communicate()
    return_code = proc.returncode
    if return_code:
        sys.stdout.write(stdout)
        sys.stderr.write(stderr)
        sys.stderr.write("Return error code %i from command:\n" % return_code)
        sys.stderr.write("%s\n" % cmd)
    else:
        sys.stdout.write(stdout)
        sys.stdout.write(stderr)

    # Multiple runs - concatenate
    if number_of_runs_opt > 0:
        if (bootseed == 0) and (rapid_bootstrap_random_seed == 0 ):
            runfiles = glob.glob('RAxML*RUN*')
            runfiles.sort(key=getint)
        # Logs
            outfile = open('RAxML_log.galaxy', 'w')
            for filename in runfiles:
                if fnmatch.fnmatch(filename, 'RAxML_log.galaxy.RUN.*'):
                    infile = open(filename, 'r')
                    filename_line = "%s\n" % filename
                    outfile.write(filename_line)
                    for line in infile:
                        outfile.write(line)
                    infile.close()
            outfile.close()
        # Parsimony Trees
            outfile = open('RAxML_parsimonyTree.galaxy', 'w')
            for filename in runfiles:
                if fnmatch.fnmatch(filename, 'RAxML_parsimonyTree.galaxy.RUN.*'):
                    infile = open(filename, 'r')
                    filename_line = "%s\n" % filename
                    outfile.write(filename_line)
                    for line in infile:
                        outfile.write(line)
                    infile.close()
            outfile.close()
        # Results
            outfile = open('RAxML_result.galaxy', 'w')
            for filename in runfiles:
                if fnmatch.fnmatch(filename, 'RAxML_result.galaxy.RUN.*'):
                    infile = open(filename, 'r')
                    filename_line = "%s\n" % filename
                    outfile.write(filename_line)
                    for line in infile:
                        outfile.write(line)
                    infile.close()
            outfile.close()
    # Multiple Model Partition Files
    if multiple_model:
        files = glob.glob('RAxML_bestTree.galaxy.PARTITION.*')
        if len(files) > 0:
            files.sort(key=getint)
            outfile = open('RAxML_bestTreePartitions.galaxy', 'w')
            # Best Tree Partitions
            for filename in files:
                if fnmatch.fnmatch(filename, 'RAxML_bestTree.galaxy.PARTITION.*'):
                    infile = open(filename, 'r')
                    filename_line = "%s\n" % filename
                    outfile.write(filename_line)
                    for line in infile:
                        outfile.write(line)
                    infile.close()
            outfile.close()
        else:
            outfile = open('RAxML_bestTreePartitions.galaxy', 'w')
            outfile.write("No partition files were produced.\n")
            outfile.close()

        # Result Partitions
        files = glob.glob('RAxML_result.galaxy.PARTITION.*')
        if len(files) > 0:
            files.sort(key=getint)
            outfile = open('RAxML_resultPartitions.galaxy', 'w')
            for filename in files:
                if fnmatch.fnmatch(filename, 'RAxML_result.galaxy.PARTITION.*'):
                    infile = open(filename, 'r')
                    filename_line = "%s\n" % filename
                    outfile.write(filename_line)
                    for line in infile:
                        outfile.write(line)
                    infile.close()
            outfile.close()
        else:
            outfile = open('RAxML_resultPartitions.galaxy', 'w')
            outfile.write("No partition files were produced.\n")
            outfile.close()

    # DEBUG options
    infof = open('RAxML_info.galaxy', 'a')
    infof.write('\nOM: CLI options DEBUG START:\n')
    infof.write(options.__repr__())
    infof.write('\nOM: CLI options DEBUG END\n')

Example 13

Project: WikiDAT
Source File: tasks.py
View license
    def execute(self, page_fan, rev_fan, page_cache_size, rev_cache_size,
                mirror, download_files, base_ports, control_ports,
                dumps_dir=None, debug=False):
        """
        Run data retrieval and loading actions.
        Arguments:
            - page_fan = Number of workers to fan out page elements parsing
            - rev_fan = Number of workers to fan out rev elements parsing
            - page_cache_size = Pages cached per ETL line before a DB flush
            - rev_cache_size = Revisions cached per ETL line before a DB flush
            - mirror = Base URL of site hosting XML dumps
            - download_files = If True, download fresh dumps from the mirror;
              otherwise look for dump files on local disk
            - base_ports / control_ports = Base port numbers for each ETL
              line's communication channels (offset per line below)
            - dumps_dir = Optional explicit folder holding local dump files
            - debug = If True, print the list of dump file paths found
        """
        def find_dumps(folder):
            # Prefer compressed .7z dumps; fall back to uncompressed .xml.
            found = glob.glob(os.path.join(folder, '*pages-meta-history*.7z'))
            if not found:
                found = glob.glob(os.path.join(folder,
                                               '*pages-meta-history*.xml'))
            return found

        print("----------------------------------------------------------")
        print(("""Executing ETL:RevHistory on lang: {0} date: {1}"""
               .format(self.lang, self.date)))
        print(("ETL lines = {0} page_fan = {1} rev_fan = {2}"
               .format(self.etl_lines, page_fan, rev_fan)))
        print("Download files =", download_files)
        print("Start time is {0}".format(time.strftime("%Y-%m-%d %H:%M:%S %Z",
                                                       time.localtime())))
        print("----------------------------------------------------------")
        print()
        if download_files:
            # TODO: Use proper logging module to track execution progress
            # Choose corresponding file downloader and etl wrapper
            print("Downloading new dump files from %s, for language %s" % (
                  mirror, self.lang))
            self.down = RevHistDownloader(mirror, self.lang, dumps_dir)
            # Download latest set of dump files
            self.paths, self.date = self.down.download(self.date)
            if not self.paths:
                print("Error: dump files with pages-logging info not found.")
                print("Program will exit now.")
                sys.exit()

            print("Retrieved dump files for lang %s, date: %s" % (self.lang,
                                                                  self.date))
            print()

        else:
            print("Looking for revision-history dump file(s) in data dir")
            # Case of dumps folder provided explicitly
            if dumps_dir:
                # Allow specifying relative paths, as well
                abs_dumps_path = os.path.expanduser(dumps_dir)
                dumps_path = os.path.join(abs_dumps_path,
                                          self.lang + '_dumps', self.date)
                if not os.path.exists(dumps_path):
                    print("No dump files will be downloaded and local folder with dump files not found.")
                    print("Please, specify a valid path to local folder containing dump files.")
                    print("Program will exit now.")
                    sys.exit()
                else:
                    # Retrieve path to all available files to feed ETL lines
                    self.paths = find_dumps(dumps_path)
                    if not self.paths:
                        print("Directory %s does not contain any valid dump file." % dumps_path)
                        print("Program will exit now.")
                        sys.exit()
            # If not provided explicitly, look for default location of
            # dumps directory
            else:
                dumps_dir = os.path.join("data", self.lang + '_dumps',
                                         self.date)
                # Look up dump files in default directory name
                if not os.path.exists(dumps_dir):
                    print("Default directory %s containing dump files not found." % dumps_dir)
                    print("Program will exit now.")
                    sys.exit()
                else:
                    self.paths = find_dumps(dumps_dir)
                    if not self.paths:
                        print("Directory %s does not contain any valid dump file." % dumps_dir)
                        print("Program will exit now.")
                        sys.exit()
            print("Found revision-history dump file(s) to process.")
            print()
        # Print list of file paths in debug mode
        if debug:
            print("paths: ", str(self.paths))
            print()

        # Create database
        # TODO: Empty corresponding tables if DB already exists
        # or let the user select behaviour with config argument
        if self.DB_exists():
            self.create_DB(complete=False)
        else:
            self.create_DB(complete=True)

        # First insert namespace info in DB
        dump = DumpFile(self.paths[0])
        db_schema = MySQLDB(host=self.host, port=self.port, user=self.db_user,
                            passwd=self.db_passw, db=self.db_name)
        db_schema.connect()
        db_schema.insert_namespaces(nsdict=dump.get_namespaces())
        db_schema.close()

        # Complete the queue of paths to be processed and STOP flags for
        # each ETL subprocess
        paths_queue = mp.JoinableQueue()
        for path in self.paths:
            paths_queue.put(path)

        # One sentinel per ETL line tells each worker when to stop
        for x in range(self.etl_lines):
            paths_queue.put('STOP')

        for x in range(self.etl_lines):
            new_etl = RevisionHistoryETL(
                name="[ETL:RevHistory-%s]" % x,
                paths_queue=paths_queue, lang=self.lang,
                page_fan=page_fan, rev_fan=rev_fan,
                page_cache_size=page_cache_size,
                rev_cache_size=rev_cache_size,
                db_name=self.db_name,
                db_user=self.db_user, db_passw=self.db_passw,
                base_port=base_ports[x]+(20*x),
                control_port=control_ports[x]+(20*x)
                )
            self.etl_list.append(new_etl)

        print("ETL:RevHistory task defined OK.")
        print("Proceeding with ETL workflows. This may take time...")
        print()
        # Extract, process and load information in local DB
        for etl in self.etl_list:
            etl.start()
            # Wait a second for new ETL process to start all subprocesses
            time.sleep(1)

        # Wait for ETL lines to finish
        for etl in self.etl_list:
            etl.join()

        # Insert user info after all ETL lines have finished
        # to ensure that all metadata are stored in Redis cache
        # disregarding of the execution order
        data_dir = os.path.join(os.getcwd(), os.path.split(self.paths[0])[0])
        db_users = MySQLDB(host=self.host, port=self.port, user=self.db_user,
                           passwd=self.db_passw, db=self.db_name)
        db_users.connect()
        users_file_to_db(con=db_users, lang=self.lang,
                         log_file=os.path.join(data_dir, 'logs', 'users.log'),
                         tmp_dir=os.path.join(data_dir, 'tmp')
                         )
        db_users.close()
        # TODO: logger; ETL step completed, proceeding with data
        # analysis and visualization
        print("ETL:RevHistory task finished for language %s and date %s" % (
              self.lang, self.date))
        print()
        # Create primary keys for all tables
        # TODO: This must also be tracked by main logging module
        print("Now creating primary key indexes in database tables.")
        print("This may take a while...")
        print()
        # BUGFIX: use the configured host/port (was hardcoded to
        # 'localhost':3306, inconsistent with the two connections above).
        db_pks = MySQLDB(host=self.host, port=self.port, user=self.db_user,
                         passwd=self.db_passw, db=self.db_name)
        db_pks.connect()
        db_pks.create_pks_revhist()
        db_pks.close()

Example 14

Project: pymetawear
Source File: setup.py
View license
def build_solution():
    """Build the MetaWear-CppAPI shared library and vendor it into pymetawear.

    Fetches the Metawear-CppAPI git submodule when building from a source
    checkout, compiles the native library for the current platform (make on
    Linux, MSBuild on Windows), copies the built artifacts into the
    ``pymetawear`` package directory and copies in the Mbientlab Python
    wrapper modules.  Raises NotImplementedError on unsupported platforms.
    """
    # Establish source paths.
    basedir = os.path.abspath(os.path.dirname(__file__))
    pkg_dir = os.path.join(basedir, 'pymetawear')
    path_to_metawear_python_wrappers = os.path.join(
        pkg_dir, 'Metawear-CppAPI', 'wrapper', 'python')

    if os.path.exists(os.path.join(basedir, '.git')):
        # The package was cloned from Github and the submodule can
        # therefore be brought in by Git methods.

        # Git submodule init
        p = subprocess.Popen(['git', 'submodule', 'init'],
                             cwd=basedir, stdout=sys.stdout, stderr=sys.stderr)
        p.communicate()

        # Git submodule update
        p = subprocess.Popen(['git', 'submodule', 'update'],
                             cwd=basedir, stdout=sys.stdout, stderr=sys.stderr)
        p.communicate()
    else:
        # The package was downloaded as zip or tar.gz from PyPI. It should
        # have the MetaWear-CppAPI folder bundled and the building can be done immediately.
        pass

    if platform.uname()[0] == 'Linux':
        arch = os.uname()[-1]
        if arch in ('x86_64', 'amd64'):
            dist_dir = 'x64'
        elif 'arm' in arch:
            dist_dir = 'arm'
        else:
            dist_dir = 'x86'

        # Run make file for MetaWear-CppAPI
        p = subprocess.Popen(
            ['make', 'clean'],
            cwd=os.path.join(pkg_dir, 'Metawear-CppAPI'),
            stdout=sys.stdout, stderr=sys.stderr)
        p.communicate()
        p = subprocess.Popen(
            ['make', 'build'],
            cwd=os.path.join(pkg_dir, 'Metawear-CppAPI'),
            stdout=sys.stdout, stderr=sys.stderr)
        p.communicate()

        path_to_dist_dir = os.path.join(
            pkg_dir, 'Metawear-CppAPI', 'dist', 'release', 'lib', dist_dir)

        # Remove any previously vendored copies of the shared library.
        for f in [s for s in os.listdir(pkg_dir) if s.startswith('libmetawear')]:
            os.remove(os.path.join(pkg_dir, f))

        symlinks_to_create = []
        # Copy the built shared library to pymetawear folder.
        for dist_file in glob.glob(path_to_dist_dir + "/libmetawear.*"):
            if os.path.islink(dist_file):
                # Defer symlinks so their targets are copied first.
                symlinks_to_create.append(
                    (os.path.basename(os.readlink(dist_file)),
                     os.path.basename(dist_file)))
            else:
                destination_file = os.path.join(
                    pkg_dir, os.path.basename(dist_file))
                shutil.copy(dist_file, destination_file)

        # Create symlinks for the libmetawear shared library.
        for symlink_src, symlink_dest in symlinks_to_create:
            destination_symlink = os.path.join(pkg_dir, symlink_dest)
            os.symlink(symlink_src, destination_symlink)

    elif platform.uname()[0] == 'Windows':
        arch = platform.architecture()[0]
        if arch == '32bit':
            dist_dir = 'Win32'
            msbuild_file = 'MetaWear.Win32.vcxproj'
            build_options = '/p:Configuration=Release;Platform=Win32'
        elif 'arm' in arch:
            # NOTE(review): platform.architecture()[0] is '32bit'/'64bit',
            # so this ARM branch looks unreachable -- confirm intent.
            dist_dir = 'ARM'
            msbuild_file = 'MetaWear.WinRT.vcxproj'
            build_options = '/p:Configuration=Release;Platform=ARM'
        else:
            dist_dir = 'x64'
            msbuild_file = 'MetaWear.Win32.vcxproj'
            build_options = '/p:Configuration=Release;Platform=x64'

        # Run msbuild file for MetaWear-CppAPI
        vsvars_file = glob.glob('c:\\Progr*/**/**/Too*/vsvars32.bat')[0]
        p = subprocess.Popen('"{0}" & MSBuild.exe {1} {2}'.format(
            vsvars_file, msbuild_file, build_options),
            cwd=os.path.join(pkg_dir, 'Metawear-CppAPI'),
            stdout=sys.stdout, stderr=sys.stderr)
        p.communicate()

        # Remove any previously vendored copies of the DLL.
        for f in [s for s in os.listdir(pkg_dir) if (s.startswith('MetaWear') and s.endswith('.dll'))]:
            os.remove(os.path.join(pkg_dir, f))

        path_to_dist_dir = os.path.join(
            pkg_dir, 'Metawear-CppAPI', 'dist', 'Release', 'lib', dist_dir)

        # Copy the built shared library to pymetawear folder.
        for dist_file in glob.glob(path_to_dist_dir + "/MetaWear.*.dll"):
            destination_file = os.path.join(
                pkg_dir, os.path.basename(dist_file))
            shutil.copy(dist_file, destination_file)
    else:
        raise NotImplementedError("Building on this platform is not implemented.")

    # Copy the Mbientlab Python wrappers to pymetawear folder.
    # First create folders if needed.
    try:
        os.makedirs(os.path.join(pkg_dir, 'mbientlab', 'metawear'))
    except OSError:
        # BUGFIX(narrowed): was a bare except; only ignore "already exists"
        # style filesystem errors instead of swallowing everything.
        pass

    init_files_to_create = [
        os.path.join(pkg_dir, 'mbientlab', '__init__.py'),
        os.path.join(pkg_dir, 'mbientlab', 'metawear', '__init__.py')
    ]
    for init_file in init_files_to_create:
        with open(init_file, 'w') as f:
            f.write("#!/usr/bin/env python\n# -*- coding: utf-8 -*-")

    # Copy all Python files from the MetWear C++ API Python wrapper
    for pth, _, pyfiles in os.walk(
            os.path.join(path_to_metawear_python_wrappers,
                         'mbientlab', 'metawear')):
        for py_file in filter(lambda x: os.path.splitext(x)[1] == '.py', pyfiles):
            try:
                shutil.copy(
                    os.path.join(pth, py_file),
                    os.path.join(pkg_dir, 'mbientlab', 'metawear', py_file))
            except (IOError, OSError):
                # BUGFIX(narrowed): was a bare except; only ignore file I/O
                # errors on a best-effort copy.
                pass

Example 15

Project: LS-BSR
Source File: ls_bsr.py
View license
def main(directory,id,filter,processors,genes,cluster_method,blast,length,
         max_plog,min_hlog,f_plog,keep,filter_peps,filter_scaffolds,prefix,temp_dir,debug):
    """Run the full LS-BSR pipeline and write the BSR matrix plus reports.

    Two modes, selected by ``genes``:
      * ``genes`` contains "null": predict genes from the input genomes with
        Prodigal, merge in any GenBank locus tags, cluster the predicted
        genes (usearch / vsearch / cd-hit), then BLAST/BLAT each genome
        against the cluster consensus sequences.
      * otherwise: treat ``genes`` as a pre-compiled FASTA/peptide file of
        reference genes and BLAST/BLAT each genome against it directly.
    Afterwards the per-genome hit tables are merged into a BSR matrix and
    all outputs are renamed with either ``prefix`` or a timestamp.

    NOTE(review): ``id`` and ``filter`` shadow Python builtins, and the
    ``filter_scaffolds`` parameter (a "T"/"F" string flag) shadows a helper
    function of the same name used below -- presumably names inherited from
    the CLI options; see inline notes.
    """
    start_dir = os.getcwd()
    ap=os.path.abspath("%s" % start_dir)
    dir_path=os.path.abspath("%s" % directory)
    logging.logPrint("Testing paths of dependencies")
    # Fail fast if the requested BLAST binary is not on PATH; print the
    # citation when it is (project convention for every external tool).
    if blast=="blastn" or blast=="tblastn":
        ab = subprocess.call(['which', 'blastn'])
        if ab == 0:
            print "citation: Altschul SF, Madden TL, Schaffer AA, Zhang J, Zhang Z, Miller W, and Lipman DJ. 1997. Gapped BLAST and PSI-BLAST: a new generation of protein database search programs. Nucleic Acids Res 25:3389-3402"
        else:
            print "blastn isn't in your path, but needs to be!"
            sys.exit()
    # Working directory for intermediate files: a fresh system tempdir unless
    # the caller supplied one, which must not already exist.
    if "NULL" in temp_dir:
        fastadir = tempfile.mkdtemp()
    else:
        fastadir = os.path.abspath("%s" % temp_dir)
        if os.path.exists('%s' % temp_dir):
            print "old run directory exists in your genomes directory (%s).  Delete and run again" % temp_dir
            sys.exit()
        else:
            os.makedirs('%s' % temp_dir)
    # Hard-link each input genome into the working dir with a ".new" suffix
    # (hard links avoid copying; requires same filesystem).
    for infile in glob.glob(os.path.join(dir_path, '*.fasta')):
        name=get_seq_name(infile)
        os.link("%s" % infile, "%s/%s.new" % (fastadir,name))
    if "null" in genes:
        # --- Mode 1: no pre-compiled gene set; predict and cluster genes ---
        rc = subprocess.call(['which', 'prodigal'])
        if rc == 0:
            pass
        else:
            print "prodigal is not in your path, but needs to be!"
            sys.exit()
        print "citation: Hyatt D, Chen GL, Locascio PF, Land ML, Larimer FW, and Hauser LJ. 2010. Prodigal: prokaryotic gene recognition and translation initiation site identification. BMC Bioinformatics 11:119"
        if "usearch" in cluster_method:
            print "citation: Edgar RC. 2010. Search and clustering orders of magnitude faster than BLAST. Bioinformatics 26:2460-2461"
        elif "cd-hit" in cluster_method:
            print "citation: Li, W., Godzik, A. 2006. Cd-hit: a fast program for clustering and comparing large sets of protein or nuceltodie sequences. Bioinformatics 22(13):1658-1659"
        elif "vsearch" in cluster_method:
            print "citation: Rognes, T., Flouri, T., Nichols, B., Qunice, C., Mahe, Frederic. 2016. VSEARCH: a versatile open source tool for metagenomics. PeerJ Preprints. DOI: https://doi.org/10.7287/peerj.preprints.2409v1"
        if blast=="blat":
            ac = subprocess.call(['which', 'blat'])
            if ac == 0:
                print "citation: W.James Kent. 2002. BLAT - The BLAST-Like Alignment Tool.  Genome Research 12:656-664"
            else:
                print "You have requested blat, but it is not in your PATH"
                sys.exit()
        logging.logPrint("predicting genes with Prodigal")
        predict_genes(fastadir, processors)
        logging.logPrint("Prodigal done")
        """This function produces locus tags"""
        genbank_hits = process_genbank_files(dir_path)
        if genbank_hits == None or len(genbank_hits) == 0:
            # No GenBank input: just concatenate all Prodigal predictions.
            os.system("cat *genes.seqs > all_gene_seqs.out")
            if filter_scaffolds == "T":
                # NOTE(review): ``filter_scaffolds`` here is the "T"/"F"
                # string parameter shadowing the helper function of the same
                # name, so this call invokes a string and would raise
                # TypeError -- confirm the intended module-level helper.
                filter_scaffolds("all_gene_seqs.out")
                os.system("mv tmp.out all_gene_seqs.out")
            else:
                pass
        else:
            logging.logPrint("Converting genbank files")
            """First combine all of the prodigal files into one file"""
            os.system("cat *genes.seqs > all_gene_seqs.out")
            if filter_scaffolds == "T":
                # NOTE(review): same parameter/function shadowing as above.
                filter_scaffolds("all_gene_seqs.out")
                os.system("mv tmp.out all_gene_seqs.out")
            else:
                pass
            """This combines the locus tags with the Prodigal prediction"""
            os.system("cat *locus_tags.fasta all_gene_seqs.out > tmp.out")
            os.system("mv tmp.out all_gene_seqs.out")
            """I also need to convert the GenBank file to a FASTA file"""
            for hit in genbank_hits:
                reduced_hit = hit.replace(".gbk","")
                SeqIO.convert("%s/%s" % (dir_path, hit), "genbank", "%s.fasta.new" % reduced_hit, "fasta")
        # Cluster the combined gene set into consensus.fasta with the chosen
        # tool; each branch verifies the binary is on PATH first.
        if "NULL" in cluster_method:
            print "Clustering chosen, but no method selected...exiting"
            sys.exit()
        elif "usearch" in cluster_method:
            ac = subprocess.call(['which', 'usearch'])
            if ac == 0:
                # usearch cannot handle large inputs in one go, so the file
                # is split, clustered per-chunk, then re-clustered.
                os.system("mkdir split_files")
                os.system("cp all_gene_seqs.out split_files/all_sorted.txt")
                os.chdir("split_files/")
                logging.logPrint("Splitting FASTA file for use with USEARCH")
                split_files("all_sorted.txt")
                logging.logPrint("clustering with USEARCH at an ID of %s" % id)
                run_usearch(id)
                os.system("cat *.usearch.out > all_sorted.txt")
                os.system("mv all_sorted.txt %s" % fastadir)
                os.chdir("%s" % fastadir)
                uclust_cluster(id)
                logging.logPrint("USEARCH clustering finished")
            else:
                print "usearch must be in your path as usearch...exiting"
                sys.exit()
        elif "vsearch" in cluster_method:
            ac = subprocess.call(['which', 'vsearch'])
            if ac == 0:
                logging.logPrint("clustering with VSEARCH at an ID of %s, using %s processors" % (id,processors))
                run_vsearch(id, processors)
                os.system("mv vsearch.out consensus.fasta")
                logging.logPrint("VSEARCH clustering finished")
            else:
                print "vsearch must be in your path as vsearch...exiting"
                sys.exit()
        elif "cd-hit" in cluster_method:
            ac = subprocess.call(['which', 'cd-hit-est'])
            if ac == 0:
                logging.logPrint("clustering with cd-hit at an ID of %s, using %s processors" % (id,processors))
                subprocess.check_call("cd-hit-est -i all_gene_seqs.out -o consensus.fasta -M 0 -T %s -c %s > /dev/null 2>&1" % (processors, id), shell=True)
            else:
                print "cd-hit must be in your path as cd-hit-est...exiting"
                sys.exit()
        """need to check for dups here"""
        # NOTE(review): test_duplicate_header_ids appears to return the
        # *string* "True" when headers are unique and "False" when
        # duplicates exist (the "False" branch renames) -- the naming reads
        # inverted; confirm against the helper.
        dup_ids = test_duplicate_header_ids("consensus.fasta")
        if dup_ids == "True":
            pass
        elif dup_ids == "False":
            print "duplicate headers identified, renaming.."
            rename_fasta_header("consensus.fasta", "tmp.txt")
            os.system("mv tmp.txt consensus.fasta")
        else:
            pass
        # Self-BLAST the consensus set to obtain per-cluster reference
        # scores (the BSR denominator).
        if "tblastn" == blast:
            subprocess.check_call("makeblastdb -in consensus.fasta -dbtype nucl > /dev/null 2>&1", shell=True)
            translate_consensus("consensus.fasta")
            if filter_peps == "T":
                filter_seqs("tmp.pep")
                os.system("rm tmp.pep")
            else:
                os.system("mv tmp.pep consensus.pep")
            clusters = get_cluster_ids("consensus.pep")
            blast_against_self_tblastn("tblastn", "consensus.fasta", "consensus.pep", "tmp_blast.out", processors, filter)
        elif "blastn" == blast:
            subprocess.check_call("makeblastdb -in consensus.fasta -dbtype nucl > /dev/null 2>&1", shell=True)
            blast_against_self_blastn("blastn", "consensus.fasta", "consensus.fasta", "tmp_blast.out", filter, processors)
            clusters = get_cluster_ids("consensus.fasta")
        elif "blat" == blast:
            blat_against_self("consensus.fasta", "consensus.fasta", "tmp_blast.out", processors)
            clusters = get_cluster_ids("consensus.fasta")
        else:
            pass
        # Keep only the best self-hit per query ("U" = universal newlines,
        # Python 2 open mode).
        subprocess.check_call("sort -u -k 1,1 tmp_blast.out > self_blast.out", shell=True)
        ref_scores=parse_self_blast(open("self_blast.out", "U"))
        subprocess.check_call("rm tmp_blast.out self_blast.out", shell=True)
        os.system("rm *new_genes.*")
        if blast == "tblastn" or blast == "blastn":
            logging.logPrint("starting BLAST")
        else:
            logging.logPrint("starting BLAT")
        # Search every genome against the consensus set.
        if "tblastn" == blast:
            blast_against_each_genome_tblastn(dir_path, processors, "consensus.pep", filter)
        elif "blastn" == blast:
            blast_against_each_genome_blastn(dir_path, processors, filter, "consensus.fasta")
        elif "blat" == blast:
            blat_against_each_genome(dir_path, "consensus.fasta",processors)
        else:
            pass
    else:
        # --- Mode 2: user supplied a pre-compiled gene set in ``genes`` ---
        logging.logPrint("Using pre-compiled set of predicted genes")
        files = glob.glob(os.path.join(dir_path, "*.fasta"))
        if len(files)==0:
            print "no usable reference genomes found!"
            sys.exit()
        else:
            pass
        gene_path=os.path.abspath("%s" % genes)
        # Unlike mode 1, duplicate headers here are fatal (no auto-rename).
        dup_ids = test_duplicate_header_ids(gene_path)
        if dup_ids == "True":
            pass
        elif dup_ids == "False":
            print "duplicate headers identified, exiting.."
            sys.exit()
        clusters = get_cluster_ids(gene_path)
        os.system("cp %s %s" % (gene_path,fastadir))
        os.chdir("%s" % fastadir)
        if gene_path.endswith(".pep"):
            # Peptide input forces a tblastn/blastp workflow.
            logging.logPrint("using tblastn on peptides")
            try:
                subprocess.check_call("makeblastdb -in %s -dbtype prot > /dev/null 2>&1" % gene_path, shell=True)
            except:
                logging.logPrint("problem encountered with BLAST database")
                sys.exit()
            blast_against_self_tblastn("blastp", gene_path, gene_path, "tmp_blast.out", processors, filter)
            subprocess.check_call("sort -u -k 1,1 tmp_blast.out > self_blast.out", shell=True)
            ref_scores=parse_self_blast(open("self_blast.out", "U"))
            subprocess.check_call("rm tmp_blast.out self_blast.out", shell=True)
            logging.logPrint("starting BLAST")
            blast_against_each_genome_tblastn(dir_path, processors, gene_path, filter)
        elif gene_path.endswith(".fasta"):
            if "tblastn" == blast:
                logging.logPrint("using tblastn")
                translate_genes(gene_path)
                try:
                    subprocess.check_call("makeblastdb -in %s -dbtype nucl > /dev/null 2>&1" % gene_path, shell=True)
                except:
                    logging.logPrint("problem encountered with BLAST database")
                    sys.exit()
                blast_against_self_tblastn("tblastn", gene_path, "genes.pep", "tmp_blast.out", processors, filter)
                subprocess.check_call("sort -u -k 1,1 tmp_blast.out > self_blast.out", shell=True)
                ref_scores=parse_self_blast(open("self_blast.out", "U"))
                subprocess.check_call("rm tmp_blast.out self_blast.out", shell=True)
                logging.logPrint("starting BLAST")
                blast_against_each_genome_tblastn(dir_path, processors, "genes.pep", filter)
                os.system("cp genes.pep %s" % start_dir)
            elif "blastn" == blast:
                logging.logPrint("using blastn")
                try:
                    subprocess.check_call("makeblastdb -in %s -dbtype nucl > /dev/null 2>&1" % gene_path, shell=True)
                except:
                    logging.logPrint("Database not formatted correctly...exiting")
                    sys.exit()
                try:
                    blast_against_self_blastn("blastn", gene_path, gene_path, "tmp_blast.out", filter, processors)
                except:
                    print "problem with blastn, exiting"
                    sys.exit()
                subprocess.check_call("sort -u -k 1,1 tmp_blast.out > self_blast.out", shell=True)
                os.system("cp self_blast.out tmp.out")
                ref_scores=parse_self_blast(open("self_blast.out", "U"))
                subprocess.check_call("rm tmp_blast.out self_blast.out", shell=True)
                logging.logPrint("starting BLAST")
                try:
                    blast_against_each_genome_blastn(dir_path, processors, filter, gene_path)
                except:
                    print "problem with blastn, exiting"
                    sys.exit()
            elif "blat" == blast:
                logging.logPrint("using blat")
                blat_against_self(gene_path, gene_path, "tmp_blast.out", processors)
                subprocess.check_call("sort -u -k 1,1 tmp_blast.out > self_blast.out", shell=True)
                ref_scores=parse_self_blast(open("self_blast.out", "U"))
                subprocess.check_call("rm tmp_blast.out self_blast.out", shell=True)
                logging.logPrint("starting BLAT")
                blat_against_each_genome(dir_path,gene_path,processors)
            else:
                pass
        else:
            print "input file format not supported"
            sys.exit()
    # --- Common tail: duplicate detection, matrix assembly, renaming ---
    find_dups_dev(ref_scores, length, max_plog, min_hlog, clusters, processors)
    if blast=="blat":
        logging.logPrint("BLAT done")
    else:
        logging.logPrint("BLAST done")
    parse_blast_report("false")
    get_unique_lines()
    curr_dir=os.getcwd()
    # One ".filtered.unique" table per genome; pair each with a stable
    # index used as a temp-file name by the matrix builder.
    table_files = glob.glob(os.path.join(curr_dir, "*.filtered.unique"))
    files_and_temp_names = [(str(idx), os.path.join(curr_dir, f))
                            for idx, f in enumerate(table_files)]
    names=[]
    table_list = []
    nr_sorted=sorted(clusters)
    centroid_list = []
    # Leading blank cell so the header row aligns with per-genome columns.
    centroid_list.append(" ")
    for x in nr_sorted:
        centroid_list.append(x)
    table_list.append(centroid_list)
    logging.logPrint("starting matrix building")
    new_names,new_table = new_loop(files_and_temp_names, processors, clusters, debug)
    new_table_list = table_list+new_table
    logging.logPrint("matrix built")
    open("ref.list", "a").write("\n")
    for x in nr_sorted:
        open("ref.list", "a").write("%s\n" % x)
    names_out = open("names.txt", "w")
    # Flatten the nested name lists before writing one per line (Py2 print).
    names_redux = [val for subl in new_names for val in subl]
    for x in names_redux: print >> names_out, "".join(x)
    names_out.close()
    create_bsr_matrix_dev(new_table_list)
    # Divide raw bit scores by the self-hit reference scores -> BSR values.
    divide_values("bsr_matrix", ref_scores)
    subprocess.check_call("paste ref.list BSR_matrix_values.txt > %s/bsr_matrix_values.txt" % start_dir, shell=True)
    if "T" in f_plog:
        filter_paralogs("%s/bsr_matrix_values.txt" % start_dir, "paralog_ids.txt")
        os.system("cp bsr_matrix_values_filtered.txt %s" % start_dir)
    else:
        pass
    # Best-effort copy of side outputs; some may not exist depending on mode.
    try:
        subprocess.check_call("cp dup_matrix.txt names.txt consensus.pep consensus.fasta duplicate_ids.txt paralog_ids.txt %s" % ap, shell=True, stderr=open(os.devnull, 'w'))
    except:
        sys.exc_clear()
    """new code to rename files according to a prefix"""
    import datetime
    timestamp = datetime.datetime.now()
    # Timestamp tuple joined into e.g. "201661512030" when no prefix given.
    rename = str(timestamp.year), str(timestamp.month), str(timestamp.day), str(timestamp.hour), str(timestamp.minute), str(timestamp.second)
    os.chdir("%s" % ap)
    if "NULL" in prefix:
        os.system("mv dup_matrix.txt %s_dup_matrix.txt" % "".join(rename))
        os.system("mv names.txt %s_names.txt" % "".join(rename))
        os.system("mv duplicate_ids.txt %s_duplicate_ids.txt" % "".join(rename))
        os.system("mv paralog_ids.txt %s_paralog_ids.txt" % "".join(rename))
        os.system("mv bsr_matrix_values.txt %s_bsr_matrix.txt" % "".join(rename))
        if os.path.isfile("consensus.fasta"):
            os.system("mv consensus.fasta %s_consensus.fasta" % "".join(rename))
        if os.path.isfile("consensus.pep"):
            os.system("mv consensus.pep %s_consensus.pep" % "".join(rename))
    else:
        os.system("mv dup_matrix.txt %s_dup_matrix.txt" % prefix)
        os.system("mv names.txt %s_names.txt" % prefix)
        os.system("mv duplicate_ids.txt %s_duplicate_ids.txt" % prefix)
        os.system("mv paralog_ids.txt %s_paralog_ids.txt" % prefix)
        os.system("mv bsr_matrix_values.txt %s_bsr_matrix.txt" % prefix)
        if os.path.isfile("consensus.fasta"):
            os.system("mv consensus.fasta %s_consensus.fasta" % prefix)
        if os.path.isfile("consensus.pep"):
            os.system("mv consensus.pep %s_consensus.pep" % prefix)
    # Record the effective command line for reproducibility.
    if "NULL" in prefix:
        outfile = open("%s_run_parameters.txt" % "".join(rename), "w")
    else:
        outfile = open("%s_run_parameters.txt" % prefix, "w")
    print >> outfile, "-d %s \\" % directory
    print >> outfile, "-i %s \\" % id
    print >> outfile, "-f %s \\" % filter
    print >> outfile, "-p %s \\" % processors
    print >> outfile, "-g %s \\" % genes
    print >> outfile, "-c %s \\" % cluster_method
    print >> outfile, "-b %s \\" % blast
    print >> outfile, "-l %s \\" % length
    print >> outfile, "-m %s \\" % max_plog
    print >> outfile, "-n %s \\" % min_hlog
    print >> outfile, "-t %s \\" % f_plog
    print >> outfile, "-k %s \\" % keep
    print >> outfile, "-s %s \\" % filter_peps
    print >> outfile, "-e %s \\" % filter_scaffolds
    print >> outfile, "-x %s \\" % prefix
    print >> outfile, "-z %s" % debug
    print >> outfile, "temp data stored here if kept: %s" % fastadir
    outfile.close()
    logging.logPrint("all Done")
    # Remove the working directory unless the caller asked to keep it.
    if "T" == keep:
        pass
    else:
        os.system("rm -rf %s" % fastadir)
    os.chdir("%s" % ap)

Example 16

Project: python-for-android
Source File: build.py
View license
    def prepare_build_environment(self, user_sdk_dir, user_ndk_dir,
                                  user_android_api, user_ndk_ver):
        '''Checks that build dependencies exist and sets internal variables
        for the Android SDK etc.

        Resolution order for each setting is: explicit user argument, then
        environment variables, then (for SDK/NDK dirs) a debug-only scan of
        the buildozer download directory.  On success this sets
        ``self.sdk_dir``, ``self.android_api``, ``self.ndk_dir``,
        ``self.ndk`` ('google' or 'crystax'), ``self.ndk_ver``,
        ``self.virtualenv``, ``self.ccache``, ``self.cython``,
        ``self.toolchain_prefix``/``toolchain_version`` and prepends the
        toolchain/SDK tool dirs to ``PATH``.  Any unrecoverable problem
        calls ``exit(1)``.

        ..warning:: This *must* be called before trying any build stuff

        '''

        self.ensure_dirs()

        # Idempotent: a second call is a no-op.
        if self._build_env_prepared:
            return

        # AND: This needs revamping to carefully check each dependency
        # in turn
        # ``ok`` accumulates soft failures; checked once at the end.
        ok = True

        # Work out where the Android SDK is
        sdk_dir = None
        if user_sdk_dir:
            sdk_dir = user_sdk_dir
        if sdk_dir is None:  # This is the old P4A-specific var
            sdk_dir = environ.get('ANDROIDSDK', None)
        if sdk_dir is None:  # This seems used more conventionally
            sdk_dir = environ.get('ANDROID_HOME', None)
        if sdk_dir is None:  # Checks in the buildozer SDK dir, useful
            #                # for debug tests of p4a
            possible_dirs = glob.glob(expanduser(join(
                '~', '.buildozer', 'android', 'platform', 'android-sdk-*')))
            if possible_dirs:
                # Arbitrarily takes the first glob match (glob order is
                # filesystem-dependent) -- debug convenience only.
                info('Found possible SDK dirs in buildozer dir: {}'.format(
                    ', '.join([d.split(os.sep)[-1] for d in possible_dirs])))
                info('Will attempt to use SDK at {}'.format(possible_dirs[0]))
                warning('This SDK lookup is intended for debug only, if you '
                        'use python-for-android much you should probably '
                        'maintain your own SDK download.')
                sdk_dir = possible_dirs[0]
        if sdk_dir is None:
            warning('Android SDK dir was not specified, exiting.')
            exit(1)
        self.sdk_dir = realpath(sdk_dir)

        # Check what Android API we're using
        android_api = None
        if user_android_api:
            android_api = user_android_api
            if android_api is not None:
                info('Getting Android API version from user argument')
        if android_api is None:
            android_api = environ.get('ANDROIDAPI', None)
            if android_api is not None:
                info('Found Android API target in $ANDROIDAPI')
        if android_api is None:
            info('Android API target was not set manually, using '
                 'the default of {}'.format(DEFAULT_ANDROID_API))
            android_api = DEFAULT_ANDROID_API
        android_api = int(android_api)
        self.android_api = android_api

        # armeabi (pre-v7) was dropped by the platform at API 21.
        if self.android_api >= 21 and self.archs[0].arch == 'armeabi':
            error('Asked to build for armeabi architecture with API '
                  '{}, but API 21 or greater does not support armeabi'.format(
                      self.android_api))
            error('You probably want to build with --arch=armeabi-v7a instead')
            exit(1)

        # Scrape the installed API levels from `android list` output and
        # require the requested one to be present.
        android = sh.Command(join(sdk_dir, 'tools', 'android'))
        targets = android('list').stdout.decode('utf-8').split('\n')
        apis = [s for s in targets if re.match(r'^ *API level: ', s)]
        apis = [re.findall(r'[0-9]+', s) for s in apis]
        apis = [int(s[0]) for s in apis if s]
        info('Available Android APIs are ({})'.format(
            ', '.join(map(str, apis))))
        if android_api in apis:
            info(('Requested API target {} is available, '
                  'continuing.').format(android_api))
        else:
            warning(('Requested API target {} is not available, install '
                     'it with the SDK android tool.').format(android_api))
            warning('Exiting.')
            exit(1)

        # Find the Android NDK
        # Could also use ANDROID_NDK, but doesn't look like many tools use this
        ndk_dir = None
        if user_ndk_dir:
            ndk_dir = user_ndk_dir
            if ndk_dir is not None:
                info('Getting NDK dir from from user argument')
        if ndk_dir is None:  # The old P4A-specific dir
            ndk_dir = environ.get('ANDROIDNDK', None)
            if ndk_dir is not None:
                info('Found NDK dir in $ANDROIDNDK')
        if ndk_dir is None:  # Apparently the most common convention
            ndk_dir = environ.get('NDK_HOME', None)
            if ndk_dir is not None:
                info('Found NDK dir in $NDK_HOME')
        if ndk_dir is None:  # Another convention (with maven?)
            ndk_dir = environ.get('ANDROID_NDK_HOME', None)
            if ndk_dir is not None:
                info('Found NDK dir in $ANDROID_NDK_HOME')
        if ndk_dir is None:  # Checks in the buildozer NDK dir, useful
            #                # for debug tests of p4a
            possible_dirs = glob.glob(expanduser(join(
                '~', '.buildozer', 'android', 'platform', 'android-ndk-r*')))
            if possible_dirs:
                info('Found possible NDK dirs in buildozer dir: {}'.format(
                    ', '.join([d.split(os.sep)[-1] for d in possible_dirs])))
                info('Will attempt to use NDK at {}'.format(possible_dirs[0]))
                warning('This NDK lookup is intended for debug only, if you '
                        'use python-for-android much you should probably '
                        'maintain your own NDK download.')
                ndk_dir = possible_dirs[0]
        if ndk_dir is None:
            warning('Android NDK dir was not specified, exiting.')
            exit(1)
        self.ndk_dir = realpath(ndk_dir)

        # Find the NDK version, and check it against what the NDK dir
        # seems to report
        ndk_ver = None
        if user_ndk_ver:
            ndk_ver = user_ndk_ver
            # NOTE(review): this guard tests ``ndk_dir`` but logs about the
            # NDK *version* -- presumably a copy-paste from the ndk_dir
            # block above and should test ``ndk_ver``; confirm upstream.
            if ndk_dir is not None:
                info('Got NDK version from from user argument')
        if ndk_ver is None:
            ndk_ver = environ.get('ANDROIDNDKVER', None)
            # NOTE(review): same suspected copy-paste -- should likely be
            # ``ndk_ver is not None``.
            if ndk_dir is not None:
                info('Got NDK version from $ANDROIDNDKVER')

        # Default NDK flavour; switched to 'crystax' below if RELEASE.TXT
        # identifies a CrystaX NDK.
        self.ndk = 'google'

        try:
            with open(join(ndk_dir, 'RELEASE.TXT')) as fileh:
                reported_ndk_ver = fileh.read().split(' ')[0].strip()
        except IOError:
            # No RELEASE.TXT: keep whatever version info we already have.
            pass
        else:
            if reported_ndk_ver.startswith('crystax-ndk-'):
                # Strip the 'crystax-ndk-' prefix (12 chars) to get the
                # bare version string.
                reported_ndk_ver = reported_ndk_ver[12:]
                self.ndk = 'crystax'
            if ndk_ver is None:
                ndk_ver = reported_ndk_ver
                info(('Got Android NDK version from the NDK dir: '
                      'it is {}').format(ndk_ver))
            else:
                if ndk_ver != reported_ndk_ver:
                    warning('NDK version was set as {}, but checking '
                            'the NDK dir claims it is {}.'.format(
                                ndk_ver, reported_ndk_ver))
                    warning('The build will try to continue, but it may '
                            'fail and you should check '
                            'that your setting is correct.')
                    warning('If the NDK dir result is correct, you don\'t '
                            'need to manually set the NDK ver.')
        if ndk_ver is None:
            warning('Android NDK version could not be found, exiting.')
            exit(1)
        self.ndk_ver = ndk_ver

        info('Using {} NDK {}'.format(self.ndk.capitalize(), self.ndk_ver))

        # Locate a virtualenv executable, preferring py2-specific names.
        virtualenv = None
        if virtualenv is None:
            virtualenv = sh.which('virtualenv2')
        if virtualenv is None:
            virtualenv = sh.which('virtualenv-2.7')
        if virtualenv is None:
            virtualenv = sh.which('virtualenv')
        if virtualenv is None:
            raise IOError('Couldn\'t find a virtualenv executable, '
                          'you must install this to use p4a.')
        self.virtualenv = virtualenv
        info('Found virtualenv at {}'.format(virtualenv))

        # path to some tools
        self.ccache = sh.which("ccache")
        if not self.ccache:
            info('ccache is missing, the build will not be optimized in the '
                 'future.')
        for cython_fn in ("cython2", "cython-2.7", "cython"):
            cython = sh.which(cython_fn)
            if cython:
                self.cython = cython
                break
        else:
            # for/else: no candidate found at all.
            error('No cython binary found. Exiting.')
            exit(1)
        # NOTE(review): unreachable in practice -- the for/else above exits
        # when no cython is found, and ``self.cython`` is truthy after a
        # break; left as-is.
        if not self.cython:
            ok = False
            warning("Missing requirement: cython is not installed")

        # AND: need to change if supporting multiple archs at once
        arch = self.archs[0]
        platform_dir = arch.platform_dir
        toolchain_prefix = arch.toolchain_prefix
        toolchain_version = None
        self.ndk_platform = join(
            self.ndk_dir,
            'platforms',
            'android-{}'.format(self.android_api),
            platform_dir)
        if not exists(self.ndk_platform):
            warning('ndk_platform doesn\'t exist: {}'.format(
                self.ndk_platform))
            ok = False

        # Normalise linux2/linux3 (Python 2 values of sys.platform).
        py_platform = sys.platform
        if py_platform in ['linux2', 'linux3']:
            py_platform = 'linux'

        # Discover installed toolchain versions for this arch's prefix by
        # listing <ndk>/toolchains/<prefix>-*.
        toolchain_versions = []
        toolchain_path = join(self.ndk_dir, 'toolchains')
        if os.path.isdir(toolchain_path):
            toolchain_contents = glob.glob('{}/{}-*'.format(toolchain_path,
                                                            toolchain_prefix))
            toolchain_versions = [split(path)[-1][len(toolchain_prefix) + 1:]
                                  for path in toolchain_contents]
        else:
            warning('Could not find toolchain subdirectory!')
            ok = False
        toolchain_versions.sort()

        toolchain_versions_gcc = []
        for toolchain_version in toolchain_versions:
            if toolchain_version[0].isdigit():
                # GCC toolchains begin with a number
                toolchain_versions_gcc.append(toolchain_version)

        if toolchain_versions:
            # NOTE(review): indexes toolchain_versions_gcc[-1] even though
            # only toolchain_versions was checked non-empty -- would raise
            # IndexError if only non-GCC (e.g. clang) toolchains exist.
            info('Found the following toolchain versions: {}'.format(
                toolchain_versions))
            info('Picking the latest gcc toolchain, here {}'.format(
                toolchain_versions_gcc[-1]))
            toolchain_version = toolchain_versions_gcc[-1]
        else:
            warning('Could not find any toolchain for {}!'.format(
                toolchain_prefix))
            ok = False

        self.toolchain_prefix = toolchain_prefix
        self.toolchain_version = toolchain_version
        # Modify the path so that sh finds modules appropriately
        environ['PATH'] = (
            '{ndk_dir}/toolchains/{toolchain_prefix}-{toolchain_version}/'
            'prebuilt/{py_platform}-x86/bin/:{ndk_dir}/toolchains/'
            '{toolchain_prefix}-{toolchain_version}/prebuilt/'
            '{py_platform}-x86_64/bin/:{ndk_dir}:{sdk_dir}/'
            'tools:{path}').format(
                sdk_dir=self.sdk_dir, ndk_dir=self.ndk_dir,
                toolchain_prefix=toolchain_prefix,
                toolchain_version=toolchain_version,
                py_platform=py_platform, path=environ.get('PATH'))

        # Missing basic build tools only warn; they don't flip ``ok``.
        for executable in ("pkg-config", "autoconf", "automake", "libtoolize",
                           "tar", "bzip2", "unzip", "make", "gcc", "g++"):
            if not sh.which(executable):
                warning("Missing executable: {} is not installed".format(
                    executable))

        if not ok:
            error('{}python-for-android cannot continue; aborting{}'.format(
                Err_Fore.RED, Err_Fore.RESET))
            sys.exit(1)
Example 17

Project: python-for-android
Source File: build.py
View license
    def prepare_build_environment(self, user_sdk_dir, user_ndk_dir,
                                  user_android_api, user_ndk_ver):
        '''Checks that build dependencies exist and sets internal variables
        for the Android SDK etc.

        ..warning:: This *must* be called before trying any build stuff

        '''

        self.ensure_dirs()

        if self._build_env_prepared:
            return

        # AND: This needs revamping to carefully check each dependency
        # in turn
        ok = True

        # Work out where the Android SDK is
        sdk_dir = None
        if user_sdk_dir:
            sdk_dir = user_sdk_dir
        if sdk_dir is None:  # This is the old P4A-specific var
            sdk_dir = environ.get('ANDROIDSDK', None)
        if sdk_dir is None:  # This seems used more conventionally
            sdk_dir = environ.get('ANDROID_HOME', None)
        if sdk_dir is None:  # Checks in the buildozer SDK dir, useful
            #                # for debug tests of p4a
            possible_dirs = glob.glob(expanduser(join(
                '~', '.buildozer', 'android', 'platform', 'android-sdk-*')))
            if possible_dirs:
                info('Found possible SDK dirs in buildozer dir: {}'.format(
                    ', '.join([d.split(os.sep)[-1] for d in possible_dirs])))
                info('Will attempt to use SDK at {}'.format(possible_dirs[0]))
                warning('This SDK lookup is intended for debug only, if you '
                        'use python-for-android much you should probably '
                        'maintain your own SDK download.')
                sdk_dir = possible_dirs[0]
        if sdk_dir is None:
            warning('Android SDK dir was not specified, exiting.')
            exit(1)
        self.sdk_dir = realpath(sdk_dir)

        # Check what Android API we're using
        android_api = None
        if user_android_api:
            android_api = user_android_api
            if android_api is not None:
                info('Getting Android API version from user argument')
        if android_api is None:
            android_api = environ.get('ANDROIDAPI', None)
            if android_api is not None:
                info('Found Android API target in $ANDROIDAPI')
        if android_api is None:
            info('Android API target was not set manually, using '
                 'the default of {}'.format(DEFAULT_ANDROID_API))
            android_api = DEFAULT_ANDROID_API
        android_api = int(android_api)
        self.android_api = android_api

        if self.android_api >= 21 and self.archs[0].arch == 'armeabi':
            error('Asked to build for armeabi architecture with API '
                  '{}, but API 21 or greater does not support armeabi'.format(
                      self.android_api))
            error('You probably want to build with --arch=armeabi-v7a instead')
            exit(1)

        android = sh.Command(join(sdk_dir, 'tools', 'android'))
        targets = android('list').stdout.decode('utf-8').split('\n')
        apis = [s for s in targets if re.match(r'^ *API level: ', s)]
        apis = [re.findall(r'[0-9]+', s) for s in apis]
        apis = [int(s[0]) for s in apis if s]
        info('Available Android APIs are ({})'.format(
            ', '.join(map(str, apis))))
        if android_api in apis:
            info(('Requested API target {} is available, '
                  'continuing.').format(android_api))
        else:
            warning(('Requested API target {} is not available, install '
                     'it with the SDK android tool.').format(android_api))
            warning('Exiting.')
            exit(1)

        # Find the Android NDK
        # Could also use ANDROID_NDK, but doesn't look like many tools use this
        ndk_dir = None
        if user_ndk_dir:
            ndk_dir = user_ndk_dir
            if ndk_dir is not None:
                info('Getting NDK dir from from user argument')
        if ndk_dir is None:  # The old P4A-specific dir
            ndk_dir = environ.get('ANDROIDNDK', None)
            if ndk_dir is not None:
                info('Found NDK dir in $ANDROIDNDK')
        if ndk_dir is None:  # Apparently the most common convention
            ndk_dir = environ.get('NDK_HOME', None)
            if ndk_dir is not None:
                info('Found NDK dir in $NDK_HOME')
        if ndk_dir is None:  # Another convention (with maven?)
            ndk_dir = environ.get('ANDROID_NDK_HOME', None)
            if ndk_dir is not None:
                info('Found NDK dir in $ANDROID_NDK_HOME')
        if ndk_dir is None:  # Checks in the buildozer NDK dir, useful
            #                # for debug tests of p4a
            possible_dirs = glob.glob(expanduser(join(
                '~', '.buildozer', 'android', 'platform', 'android-ndk-r*')))
            if possible_dirs:
                info('Found possible NDK dirs in buildozer dir: {}'.format(
                    ', '.join([d.split(os.sep)[-1] for d in possible_dirs])))
                info('Will attempt to use NDK at {}'.format(possible_dirs[0]))
                warning('This NDK lookup is intended for debug only, if you '
                        'use python-for-android much you should probably '
                        'maintain your own NDK download.')
                ndk_dir = possible_dirs[0]
        if ndk_dir is None:
            warning('Android NDK dir was not specified, exiting.')
            exit(1)
        self.ndk_dir = realpath(ndk_dir)

        # Find the NDK version, and check it against what the NDK dir
        # seems to report
        ndk_ver = None
        if user_ndk_ver:
            ndk_ver = user_ndk_ver
            if ndk_dir is not None:
                info('Got NDK version from from user argument')
        if ndk_ver is None:
            ndk_ver = environ.get('ANDROIDNDKVER', None)
            if ndk_dir is not None:
                info('Got NDK version from $ANDROIDNDKVER')

        self.ndk = 'google'

        try:
            with open(join(ndk_dir, 'RELEASE.TXT')) as fileh:
                reported_ndk_ver = fileh.read().split(' ')[0].strip()
        except IOError:
            pass
        else:
            if reported_ndk_ver.startswith('crystax-ndk-'):
                reported_ndk_ver = reported_ndk_ver[12:]
                self.ndk = 'crystax'
            if ndk_ver is None:
                ndk_ver = reported_ndk_ver
                info(('Got Android NDK version from the NDK dir: '
                      'it is {}').format(ndk_ver))
            else:
                if ndk_ver != reported_ndk_ver:
                    warning('NDK version was set as {}, but checking '
                            'the NDK dir claims it is {}.'.format(
                                ndk_ver, reported_ndk_ver))
                    warning('The build will try to continue, but it may '
                            'fail and you should check '
                            'that your setting is correct.')
                    warning('If the NDK dir result is correct, you don\'t '
                            'need to manually set the NDK ver.')
        if ndk_ver is None:
            warning('Android NDK version could not be found, exiting.')
            exit(1)
        self.ndk_ver = ndk_ver

        info('Using {} NDK {}'.format(self.ndk.capitalize(), self.ndk_ver))

        virtualenv = None
        if virtualenv is None:
            virtualenv = sh.which('virtualenv2')
        if virtualenv is None:
            virtualenv = sh.which('virtualenv-2.7')
        if virtualenv is None:
            virtualenv = sh.which('virtualenv')
        if virtualenv is None:
            raise IOError('Couldn\'t find a virtualenv executable, '
                          'you must install this to use p4a.')
        self.virtualenv = virtualenv
        info('Found virtualenv at {}'.format(virtualenv))

        # path to some tools
        self.ccache = sh.which("ccache")
        if not self.ccache:
            info('ccache is missing, the build will not be optimized in the '
                 'future.')
        for cython_fn in ("cython2", "cython-2.7", "cython"):
            cython = sh.which(cython_fn)
            if cython:
                self.cython = cython
                break
        else:
            error('No cython binary found. Exiting.')
            exit(1)
        if not self.cython:
            ok = False
            warning("Missing requirement: cython is not installed")

        # AND: need to change if supporting multiple archs at once
        arch = self.archs[0]
        platform_dir = arch.platform_dir
        toolchain_prefix = arch.toolchain_prefix
        toolchain_version = None
        self.ndk_platform = join(
            self.ndk_dir,
            'platforms',
            'android-{}'.format(self.android_api),
            platform_dir)
        if not exists(self.ndk_platform):
            warning('ndk_platform doesn\'t exist: {}'.format(
                self.ndk_platform))
            ok = False

        py_platform = sys.platform
        if py_platform in ['linux2', 'linux3']:
            py_platform = 'linux'

        toolchain_versions = []
        toolchain_path = join(self.ndk_dir, 'toolchains')
        if os.path.isdir(toolchain_path):
            toolchain_contents = glob.glob('{}/{}-*'.format(toolchain_path,
                                                            toolchain_prefix))
            toolchain_versions = [split(path)[-1][len(toolchain_prefix) + 1:]
                                  for path in toolchain_contents]
        else:
            warning('Could not find toolchain subdirectory!')
            ok = False
        toolchain_versions.sort()

        toolchain_versions_gcc = []
        for toolchain_version in toolchain_versions:
            if toolchain_version[0].isdigit():
                # GCC toolchains begin with a number
                toolchain_versions_gcc.append(toolchain_version)

        if toolchain_versions:
            info('Found the following toolchain versions: {}'.format(
                toolchain_versions))
            info('Picking the latest gcc toolchain, here {}'.format(
                toolchain_versions_gcc[-1]))
            toolchain_version = toolchain_versions_gcc[-1]
        else:
            warning('Could not find any toolchain for {}!'.format(
                toolchain_prefix))
            ok = False

        self.toolchain_prefix = toolchain_prefix
        self.toolchain_version = toolchain_version
        # Modify the path so that sh finds modules appropriately
        environ['PATH'] = (
            '{ndk_dir}/toolchains/{toolchain_prefix}-{toolchain_version}/'
            'prebuilt/{py_platform}-x86/bin/:{ndk_dir}/toolchains/'
            '{toolchain_prefix}-{toolchain_version}/prebuilt/'
            '{py_platform}-x86_64/bin/:{ndk_dir}:{sdk_dir}/'
            'tools:{path}').format(
                sdk_dir=self.sdk_dir, ndk_dir=self.ndk_dir,
                toolchain_prefix=toolchain_prefix,
                toolchain_version=toolchain_version,
                py_platform=py_platform, path=environ.get('PATH'))

        for executable in ("pkg-config", "autoconf", "automake", "libtoolize",
                           "tar", "bzip2", "unzip", "make", "gcc", "g++"):
            if not sh.which(executable):
                warning("Missing executable: {} is not installed".format(
                    executable))

        if not ok:
            error('{}python-for-android cannot continue; aborting{}'.format(
                Err_Fore.RED, Err_Fore.RESET))
            sys.exit(1)

Example 18

Project: calibre
Source File: freeze.py
View license
    def freeze(self):
        """Assemble the frozen calibre distribution tree under ``self.base``.

        Copies in the license, the MS CRT, resources, the Python and Qt
        DLLs, the python standard library (byte-compiled, sources removed),
        the calibre sources, Qt plugins and miscellaneous third-party
        binary dependencies. Windows-only; relies on hard-coded
        ``C:\\Python<ver>`` install paths.
        """
        shutil.copy2(self.j(self.src_root, 'LICENSE'), self.base)

        self.info('Adding CRT')
        shutil.copytree(CRT, self.j(self.base, os.path.basename(CRT)))

        self.info('Adding resources...')
        tgt = self.j(self.base, 'resources')
        if os.path.exists(tgt):
            shutil.rmtree(tgt)
        shutil.copytree(self.j(self.src_root, 'resources'), tgt)

        self.info('Adding Qt and python...')
        shutil.copytree(r'C:\Python%s\DLLs'%self.py_ver, self.dll_dir,
                ignore=shutil.ignore_patterns('msvc*.dll', 'Microsoft.*'))
        for x in glob.glob(self.j(OPENSSL_DIR, 'bin', '*.dll')):
            shutil.copy2(x, self.dll_dir)
        for x in glob.glob(self.j(ICU_DIR, 'source', 'lib', '*.dll')):
            shutil.copy2(x, self.dll_dir)

        for x in QT_DLLS:
            shutil.copy2(os.path.join(QT_DIR, 'bin', x), self.dll_dir)
        shutil.copy2(r'C:\windows\system32\python%s.dll'%self.py_ver,
                    self.dll_dir)
        # Collect every DLL shipped with the Python install, except the
        # pythonwin directory which is excluded from the build entirely.
        for dirpath, dirnames, filenames in os.walk(r'C:\Python%s\Lib'%self.py_ver):
            if os.path.basename(dirpath) == 'pythonwin':
                continue
            for f in filenames:
                if f.lower().endswith('.dll'):
                    f = self.j(dirpath, f)
                    shutil.copy2(f, self.dll_dir)
        shutil.copy2(
            r'C:\Python%(v)s\Lib\site-packages\pywin32_system32\pywintypes%(v)s.dll'
            % dict(v=self.py_ver), self.dll_dir)

        def ignore_lib(root, items):
            # copytree ignore callback: skip demo/test directories and
            # non-python payload files when copying the stdlib tree.
            ans = []
            for x in items:
                ext = os.path.splitext(x)[1]
                if (not ext and (x in ('demos', 'tests'))) or \
                    (ext in ('.dll', '.chm', '.htm', '.txt')):
                    ans.append(x)
            return ans

        shutil.copytree(r'C:\Python%s\Lib'%self.py_ver, self.lib_dir,
                ignore=ignore_lib)

        # Fix win32com: move win32comext/shell under win32com and drop the rest.
        sp_dir = self.j(self.lib_dir, 'site-packages')
        comext = self.j(sp_dir, 'win32comext')
        shutil.copytree(self.j(comext, 'shell'), self.j(sp_dir, 'win32com', 'shell'))
        shutil.rmtree(comext)

        # Fix PyCrypto and Pillow, removing the bootstrap .py modules that load
        # the .pyd modules, since they do not work when in a zip file
        for folder in os.listdir(sp_dir):
            folder = self.j(sp_dir, folder)
            if os.path.isdir(folder):
                self.fix_pyd_bootstraps_in(folder)

        # Deliberately fail loudly (IndexError) if an expected PyQt5
        # directory is missing.
        for pat in (r'PyQt5\uic\port_v3', ):
            x = glob.glob(self.j(self.lib_dir, 'site-packages', pat))[0]
            shutil.rmtree(x)
        pyqt = self.j(self.lib_dir, 'site-packages', 'PyQt5')
        for x in {x for x in os.listdir(pyqt) if x.endswith('.pyd')} - PYQT_MODULES:
            os.remove(self.j(pyqt, x))

        self.info('Adding calibre sources...')
        for x in glob.glob(self.j(self.SRC, '*')):
            if os.path.isdir(x):
                shutil.copytree(x, self.j(sp_dir, self.b(x)))
            else:
                shutil.copy(x, self.j(sp_dir, self.b(x)))

        for x in (r'calibre\manual', r'calibre\trac', 'pythonwin'):
            deld = self.j(sp_dir, x)
            if os.path.exists(deld):
                shutil.rmtree(deld)

        # Only .py files are kept in the calibre source tree.
        for x in os.walk(self.j(sp_dir, 'calibre')):
            for f in x[-1]:
                if not f.endswith('.py'):
                    os.remove(self.j(x[0], f))

        self.info('Byte-compiling all python modules...')
        for x in ('test', 'lib2to3', 'distutils'):
            shutil.rmtree(self.j(self.lib_dir, x))
        for x in os.walk(self.lib_dir):
            root = x[0]
            for f in x[-1]:
                if f.endswith('.py'):
                    y = self.j(root, f)
                    rel = os.path.relpath(y, self.lib_dir)
                    try:
                        py_compile.compile(y, dfile=rel, doraise=True)
                        os.remove(y)
                    except Exception:
                        # FIX: narrowed from a bare ``except:`` so that
                        # KeyboardInterrupt/SystemExit are not swallowed here.
                        self.warn('Failed to byte-compile', y)
                    # Keep only the compiled form; never leave both .pyc and .pyo.
                    pyc, pyo = y+'c', y+'o'
                    epyc, epyo, epy = map(os.path.exists, (pyc,pyo,y))
                    if (epyc or epyo) and epy:
                        os.remove(y)
                    if epyo and epyc:
                        os.remove(pyc)

        self.info('\nAdding Qt plugins...')
        qt_prefix = QT_DIR
        plugdir = self.j(qt_prefix, 'plugins')
        tdir = self.j(self.base, 'qt_plugins')
        for d in QT_PLUGINS:
            self.info('\t', d)
            imfd = os.path.join(plugdir, d)
            tg = os.path.join(tdir, d)
            if os.path.exists(tg):
                shutil.rmtree(tg)
            shutil.copytree(imfd, tg)

        # Only DLLs survive in the plugin tree.
        for dirpath, dirnames, filenames in os.walk(tdir):
            for x in filenames:
                if not x.endswith('.dll'):
                    os.remove(self.j(dirpath, x))

        self.info('\nAdding third party dependencies')

        self.info('\tAdding misc binary deps')
        bindir = os.path.join(SW, 'bin')
        for x in ('pdftohtml', 'pdfinfo', 'pdftoppm', 'jpegtran-calibre', 'cjpeg-calibre'):
            shutil.copy2(os.path.join(bindir, x+'.exe'), self.base)
        # optipng.exe and its manifest are renamed with a -calibre suffix.
        for x in ('', '.manifest'):
            fname = 'optipng.exe' + x
            src = os.path.join(bindir, fname)
            shutil.copy2(src, self.base)
            src = os.path.join(self.base, fname)
            os.rename(src, src.replace('.exe', '-calibre.exe'))
        for pat in ('*.dll',):
            for f in glob.glob(os.path.join(bindir, pat)):
                # Skip known-unwanted DLLs by substring match.
                ok = True
                for ex in ('expatw', 'testplug'):
                    if ex in f.lower():
                        ok = False
                if not ok:
                    continue
                dest = self.dll_dir
                shutil.copy2(f, dest)
        for x in ('zlib1.dll', 'libxml2.dll', 'libxslt.dll', 'libexslt.dll'):
            msrc = self.j(bindir, x+'.manifest')
            if os.path.exists(msrc):
                shutil.copy2(msrc, self.dll_dir)

Example 19

Project: calibre
Source File: freeze.py
View license
    def freeze(self):
        """Assemble the frozen calibre distribution tree under ``self.base``.

        Copies in the license, the MS CRT, resources, the Python and Qt
        DLLs, the python standard library (byte-compiled, sources removed),
        the calibre sources, Qt plugins and miscellaneous third-party
        binary dependencies. Windows-only; relies on hard-coded
        ``C:\\Python<ver>`` install paths.
        """
        shutil.copy2(self.j(self.src_root, 'LICENSE'), self.base)

        self.info('Adding CRT')
        shutil.copytree(CRT, self.j(self.base, os.path.basename(CRT)))

        self.info('Adding resources...')
        tgt = self.j(self.base, 'resources')
        if os.path.exists(tgt):
            shutil.rmtree(tgt)
        shutil.copytree(self.j(self.src_root, 'resources'), tgt)

        self.info('Adding Qt and python...')
        shutil.copytree(r'C:\Python%s\DLLs'%self.py_ver, self.dll_dir,
                ignore=shutil.ignore_patterns('msvc*.dll', 'Microsoft.*'))
        for x in glob.glob(self.j(OPENSSL_DIR, 'bin', '*.dll')):
            shutil.copy2(x, self.dll_dir)
        for x in glob.glob(self.j(ICU_DIR, 'source', 'lib', '*.dll')):
            shutil.copy2(x, self.dll_dir)

        for x in QT_DLLS:
            shutil.copy2(os.path.join(QT_DIR, 'bin', x), self.dll_dir)
        shutil.copy2(r'C:\windows\system32\python%s.dll'%self.py_ver,
                    self.dll_dir)
        # Collect every DLL shipped with the Python install, except the
        # pythonwin directory which is excluded from the build entirely.
        for dirpath, dirnames, filenames in os.walk(r'C:\Python%s\Lib'%self.py_ver):
            if os.path.basename(dirpath) == 'pythonwin':
                continue
            for f in filenames:
                if f.lower().endswith('.dll'):
                    f = self.j(dirpath, f)
                    shutil.copy2(f, self.dll_dir)
        shutil.copy2(
            r'C:\Python%(v)s\Lib\site-packages\pywin32_system32\pywintypes%(v)s.dll'
            % dict(v=self.py_ver), self.dll_dir)

        def ignore_lib(root, items):
            # copytree ignore callback: skip demo/test directories and
            # non-python payload files when copying the stdlib tree.
            ans = []
            for x in items:
                ext = os.path.splitext(x)[1]
                if (not ext and (x in ('demos', 'tests'))) or \
                    (ext in ('.dll', '.chm', '.htm', '.txt')):
                    ans.append(x)
            return ans

        shutil.copytree(r'C:\Python%s\Lib'%self.py_ver, self.lib_dir,
                ignore=ignore_lib)

        # Fix win32com: move win32comext/shell under win32com and drop the rest.
        sp_dir = self.j(self.lib_dir, 'site-packages')
        comext = self.j(sp_dir, 'win32comext')
        shutil.copytree(self.j(comext, 'shell'), self.j(sp_dir, 'win32com', 'shell'))
        shutil.rmtree(comext)

        # Fix PyCrypto and Pillow, removing the bootstrap .py modules that load
        # the .pyd modules, since they do not work when in a zip file
        for folder in os.listdir(sp_dir):
            folder = self.j(sp_dir, folder)
            if os.path.isdir(folder):
                self.fix_pyd_bootstraps_in(folder)

        # Deliberately fail loudly (IndexError) if an expected PyQt5
        # directory is missing.
        for pat in (r'PyQt5\uic\port_v3', ):
            x = glob.glob(self.j(self.lib_dir, 'site-packages', pat))[0]
            shutil.rmtree(x)
        pyqt = self.j(self.lib_dir, 'site-packages', 'PyQt5')
        for x in {x for x in os.listdir(pyqt) if x.endswith('.pyd')} - PYQT_MODULES:
            os.remove(self.j(pyqt, x))

        self.info('Adding calibre sources...')
        for x in glob.glob(self.j(self.SRC, '*')):
            if os.path.isdir(x):
                shutil.copytree(x, self.j(sp_dir, self.b(x)))
            else:
                shutil.copy(x, self.j(sp_dir, self.b(x)))

        for x in (r'calibre\manual', r'calibre\trac', 'pythonwin'):
            deld = self.j(sp_dir, x)
            if os.path.exists(deld):
                shutil.rmtree(deld)

        # Only .py files are kept in the calibre source tree.
        for x in os.walk(self.j(sp_dir, 'calibre')):
            for f in x[-1]:
                if not f.endswith('.py'):
                    os.remove(self.j(x[0], f))

        self.info('Byte-compiling all python modules...')
        for x in ('test', 'lib2to3', 'distutils'):
            shutil.rmtree(self.j(self.lib_dir, x))
        for x in os.walk(self.lib_dir):
            root = x[0]
            for f in x[-1]:
                if f.endswith('.py'):
                    y = self.j(root, f)
                    rel = os.path.relpath(y, self.lib_dir)
                    try:
                        py_compile.compile(y, dfile=rel, doraise=True)
                        os.remove(y)
                    except Exception:
                        # FIX: narrowed from a bare ``except:`` so that
                        # KeyboardInterrupt/SystemExit are not swallowed here.
                        self.warn('Failed to byte-compile', y)
                    # Keep only the compiled form; never leave both .pyc and .pyo.
                    pyc, pyo = y+'c', y+'o'
                    epyc, epyo, epy = map(os.path.exists, (pyc,pyo,y))
                    if (epyc or epyo) and epy:
                        os.remove(y)
                    if epyo and epyc:
                        os.remove(pyc)

        self.info('\nAdding Qt plugins...')
        qt_prefix = QT_DIR
        plugdir = self.j(qt_prefix, 'plugins')
        tdir = self.j(self.base, 'qt_plugins')
        for d in QT_PLUGINS:
            self.info('\t', d)
            imfd = os.path.join(plugdir, d)
            tg = os.path.join(tdir, d)
            if os.path.exists(tg):
                shutil.rmtree(tg)
            shutil.copytree(imfd, tg)

        # Only DLLs survive in the plugin tree.
        for dirpath, dirnames, filenames in os.walk(tdir):
            for x in filenames:
                if not x.endswith('.dll'):
                    os.remove(self.j(dirpath, x))

        self.info('\nAdding third party dependencies')

        self.info('\tAdding misc binary deps')
        bindir = os.path.join(SW, 'bin')
        for x in ('pdftohtml', 'pdfinfo', 'pdftoppm', 'jpegtran-calibre', 'cjpeg-calibre'):
            shutil.copy2(os.path.join(bindir, x+'.exe'), self.base)
        # optipng.exe and its manifest are renamed with a -calibre suffix.
        for x in ('', '.manifest'):
            fname = 'optipng.exe' + x
            src = os.path.join(bindir, fname)
            shutil.copy2(src, self.base)
            src = os.path.join(self.base, fname)
            os.rename(src, src.replace('.exe', '-calibre.exe'))
        for pat in ('*.dll',):
            for f in glob.glob(os.path.join(bindir, pat)):
                # Skip known-unwanted DLLs by substring match.
                ok = True
                for ex in ('expatw', 'testplug'):
                    if ex in f.lower():
                        ok = False
                if not ok:
                    continue
                dest = self.dll_dir
                shutil.copy2(f, dest)
        for x in ('zlib1.dll', 'libxml2.dll', 'libxslt.dll', 'libexslt.dll'):
            msrc = self.j(bindir, x+'.manifest')
            if os.path.exists(msrc):
                shutil.copy2(msrc, self.dll_dir)

Example 20

Project: LASIF
Source File: ses3d_models.py
View license
    def __init__(self, directory, domain, model_type="earth_model"):
        """
        The init function.

        :param directory: The directory where the earth model or kernel is
            located.
        :param domain: The domain object the model belongs to.
        :param model_type: Determines the type of model loaded. Currently
            three are supported:
                * earth_model - The standard SES3D model files (default)
                * kernel - The kernels. Identified by lots of grad_* files.
                * wavefield - The raw wavefields.
        :raises ValueError: If no boxfile is found, the model_type is
            unknown, or the components do not have a consistent file size.
        """
        self.directory = directory
        self.boxfile = os.path.join(self.directory, "boxfile")
        if not os.path.exists(self.boxfile):
            msg = "boxfile not found. Wrong directory?"
            raise ValueError(msg)

        # Read the boxfile.
        self.setup = self._read_boxfile()

        self.domain = domain
        self.model_type = model_type

        self.one_d_model = OneDimensionalModel("ak135-f")

        if model_type == "earth_model":
            # Now check what different models are available in the directory.
            # This information is also used to infer the degree of the used
            # lagrange polynomial.
            components = ["A", "B", "C", "lambda", "mu", "rhoinv", "Q"]
            self.available_derived_components = ["vp", "vsh", "vsv", "rho"]
            self.components = {}
            self.parsed_components = {}
            for component in components:
                files = glob.glob(
                    os.path.join(directory, "%s[0-9]*" % component))
                if len(files) != len(self.setup["subdomains"]):
                    continue
                # Check that the naming is continuous.
                all_good = True
                for _i in xrange(len(self.setup["subdomains"])):
                    if os.path.join(directory,
                                    "%s%i" % (component, _i)) in files:
                        continue
                    all_good = False
                    break
                if all_good is False:
                    msg = "Naming for component %s is off. It will be skipped."
                    # FIX: interpolate the component name; previously the
                    # literal "%s" placeholder was emitted in the warning.
                    warnings.warn(msg % component)
                    continue
                # They also all need to have the same size.
                if len(set([os.path.getsize(_i) for _i in files])) != 1:
                    msg = ("Component %s has the right number of model files "
                           "but they are not of equal size") % component
                    warnings.warn(msg)
                    continue
                # Sort the files by ascending number.
                files.sort(key=lambda x: int(re.findall(r"\d+$",
                                             (os.path.basename(x)))[0]))
                self.components[component] = {"filenames": files}
        elif model_type == "wavefield":
            # FIX: "vz" was listed twice ("vz", "vx", "vy", "vz"), causing
            # the vz files to be globbed and parsed a second time for no
            # effect. Each component is now listed exactly once.
            components = ["vx", "vy", "vz"]
            self.available_derived_components = []
            self.components = {}
            self.parsed_components = {}
            for component in components:
                files = glob.glob(os.path.join(directory, "%s_*_*" %
                                  component))
                if not files:
                    continue
                # Group the snapshot files by their timestep (the final
                # underscore-separated token of the filename).
                timesteps = collections.defaultdict(list)
                for filename in files:
                    timestep = int(os.path.basename(filename).split("_")[-1])
                    timesteps[timestep].append(filename)

                for timestep, filenames in timesteps.iteritems():
                    self.components["%s %s" % (component, timestep)] = \
                        {"filenames": sorted(
                            filenames,
                            key=lambda x: int(
                                os.path.basename(x).split("_")[1]))}
        elif model_type == "kernel":
            # Now check what different models are available in the directory.
            # This information is also used to infer the degree of the used
            # lagrange polynomial.
            components = ["grad_cp", "grad_csh", "grad_csv", "grad_rho"]
            self.available_derived_components = []
            self.components = {}
            self.parsed_components = {}
            for component in components:
                files = glob.glob(
                    os.path.join(directory, "%s_[0-9]*" % component))
                if len(files) != len(self.setup["subdomains"]):
                    continue
                if len(set([os.path.getsize(_i) for _i in files])) != 1:
                    msg = ("Component %s has the right number of model files "
                           "but they are not of equal size") % component
                    warnings.warn(msg)
                    continue
                # Sort the files by ascending number.
                files.sort(key=lambda x: int(
                    re.findall(r"\d+$",
                               (os.path.basename(x)))[0]))
                self.components[component] = {"filenames": files}
        else:
            msg = "model_type '%s' not known." % model_type
            raise ValueError(msg)

        # All files for a single component have the same size. Now check that
        # all files have the same size.
        unique_filesizes = len(list(set([
            os.path.getsize(_i["filenames"][0])
            for _i in self.components.itervalues()])))
        if unique_filesizes != 1:
            msg = ("The different components in the folder do not have the "
                   "same number of samples")
            raise ValueError(msg)

        # Now calculate the lagrange polynomial degree. All necessary
        # information is present.
        size = os.path.getsize(self.components.values()[0]["filenames"][0])
        sd = self.setup["subdomains"][0]
        x, y, z = sd["index_x_count"], sd["index_y_count"], sd["index_z_count"]
        # 4 bytes per float sample; cube root gives points per dimension.
        self.lagrange_polynomial_degree = \
            int(round(((size * 0.25) / (x * y * z)) ** (1.0 / 3.0) - 1))

        self._calculate_final_dimensions()

        # Setup the boundaries.
        self.lat_bounds = [
            rotations.colat2lat(_i)
            for _i in self.setup["physical_boundaries_x"][::-1]]
        self.lng_bounds = self.setup["physical_boundaries_y"]
        self.depth_bounds = [
            6371 - _i / 1000.0 for _i in self.setup["physical_boundaries_z"]]

        self.collocation_points_lngs = self._get_collocation_points_along_axis(
            self.lng_bounds[0], self.lng_bounds[1],
            self.setup["point_count_in_y"])
        self.collocation_points_lats = self._get_collocation_points_along_axis(
            self.lat_bounds[0], self.lat_bounds[1],
            self.setup["point_count_in_x"])
        self.collocation_points_depth = \
            self._get_collocation_points_along_axis(
                self.depth_bounds[1], self.depth_bounds[0],
                self.setup["point_count_in_z"])[::-1]

Example 21

Project: LASIF
Source File: ses3d_models.py
View license
    def __init__(self, directory, domain, model_type="earth_model"):
        """
        The init function.

        :param directory: The directory where the earth model or kernel is
            located.
        :param domain: The domain object the model belongs to.
        :param model_type: Determines the type of model loaded. Currently
            three are supported:
                * earth_model - The standard SES3D model files (default)
                * kernel - The kernels. Identified by lots of grad_* files.
                * wavefield - The raw wavefields.
        :raises ValueError: If no boxfile is found, the model_type is
            unknown, or the components do not have a consistent file size.
        """
        self.directory = directory
        self.boxfile = os.path.join(self.directory, "boxfile")
        if not os.path.exists(self.boxfile):
            msg = "boxfile not found. Wrong directory?"
            raise ValueError(msg)

        # Read the boxfile.
        self.setup = self._read_boxfile()

        self.domain = domain
        self.model_type = model_type

        self.one_d_model = OneDimensionalModel("ak135-f")

        if model_type == "earth_model":
            # Now check what different models are available in the directory.
            # This information is also used to infer the degree of the used
            # lagrange polynomial.
            components = ["A", "B", "C", "lambda", "mu", "rhoinv", "Q"]
            self.available_derived_components = ["vp", "vsh", "vsv", "rho"]
            self.components = {}
            self.parsed_components = {}
            for component in components:
                files = glob.glob(
                    os.path.join(directory, "%s[0-9]*" % component))
                if len(files) != len(self.setup["subdomains"]):
                    continue
                # Check that the naming is continuous.
                all_good = True
                for _i in xrange(len(self.setup["subdomains"])):
                    if os.path.join(directory,
                                    "%s%i" % (component, _i)) in files:
                        continue
                    all_good = False
                    break
                if all_good is False:
                    msg = "Naming for component %s is off. It will be skipped."
                    # FIX: interpolate the component name; previously the
                    # literal "%s" placeholder was emitted in the warning.
                    warnings.warn(msg % component)
                    continue
                # They also all need to have the same size.
                if len(set([os.path.getsize(_i) for _i in files])) != 1:
                    msg = ("Component %s has the right number of model files "
                           "but they are not of equal size") % component
                    warnings.warn(msg)
                    continue
                # Sort the files by ascending number.
                files.sort(key=lambda x: int(re.findall(r"\d+$",
                                             (os.path.basename(x)))[0]))
                self.components[component] = {"filenames": files}
        elif model_type == "wavefield":
            # FIX: "vz" was listed twice ("vz", "vx", "vy", "vz"), causing
            # the vz files to be globbed and parsed a second time for no
            # effect. Each component is now listed exactly once.
            components = ["vx", "vy", "vz"]
            self.available_derived_components = []
            self.components = {}
            self.parsed_components = {}
            for component in components:
                files = glob.glob(os.path.join(directory, "%s_*_*" %
                                  component))
                if not files:
                    continue
                # Group the snapshot files by their timestep (the final
                # underscore-separated token of the filename).
                timesteps = collections.defaultdict(list)
                for filename in files:
                    timestep = int(os.path.basename(filename).split("_")[-1])
                    timesteps[timestep].append(filename)

                for timestep, filenames in timesteps.iteritems():
                    self.components["%s %s" % (component, timestep)] = \
                        {"filenames": sorted(
                            filenames,
                            key=lambda x: int(
                                os.path.basename(x).split("_")[1]))}
        elif model_type == "kernel":
            # Now check what different models are available in the directory.
            # This information is also used to infer the degree of the used
            # lagrange polynomial.
            components = ["grad_cp", "grad_csh", "grad_csv", "grad_rho"]
            self.available_derived_components = []
            self.components = {}
            self.parsed_components = {}
            for component in components:
                files = glob.glob(
                    os.path.join(directory, "%s_[0-9]*" % component))
                if len(files) != len(self.setup["subdomains"]):
                    continue
                if len(set([os.path.getsize(_i) for _i in files])) != 1:
                    msg = ("Component %s has the right number of model files "
                           "but they are not of equal size") % component
                    warnings.warn(msg)
                    continue
                # Sort the files by ascending number.
                files.sort(key=lambda x: int(
                    re.findall(r"\d+$",
                               (os.path.basename(x)))[0]))
                self.components[component] = {"filenames": files}
        else:
            msg = "model_type '%s' not known." % model_type
            raise ValueError(msg)

        # All files for a single component have the same size. Now check that
        # all files have the same size.
        unique_filesizes = len(list(set([
            os.path.getsize(_i["filenames"][0])
            for _i in self.components.itervalues()])))
        if unique_filesizes != 1:
            msg = ("The different components in the folder do not have the "
                   "same number of samples")
            raise ValueError(msg)

        # Now calculate the lagrange polynomial degree. All necessary
        # information is present.
        size = os.path.getsize(self.components.values()[0]["filenames"][0])
        sd = self.setup["subdomains"][0]
        x, y, z = sd["index_x_count"], sd["index_y_count"], sd["index_z_count"]
        # 4 bytes per float sample; cube root gives points per dimension.
        self.lagrange_polynomial_degree = \
            int(round(((size * 0.25) / (x * y * z)) ** (1.0 / 3.0) - 1))

        self._calculate_final_dimensions()

        # Setup the boundaries.
        self.lat_bounds = [
            rotations.colat2lat(_i)
            for _i in self.setup["physical_boundaries_x"][::-1]]
        self.lng_bounds = self.setup["physical_boundaries_y"]
        self.depth_bounds = [
            6371 - _i / 1000.0 for _i in self.setup["physical_boundaries_z"]]

        self.collocation_points_lngs = self._get_collocation_points_along_axis(
            self.lng_bounds[0], self.lng_bounds[1],
            self.setup["point_count_in_y"])
        self.collocation_points_lats = self._get_collocation_points_along_axis(
            self.lat_bounds[0], self.lat_bounds[1],
            self.setup["point_count_in_x"])
        self.collocation_points_depth = \
            self._get_collocation_points_along_axis(
                self.depth_bounds[1], self.depth_bounds[0],
                self.setup["point_count_in_z"])[::-1]

Example 22

Project: weboob
Source File: setup.py
View license
def install_weboob():
    """Configure and invoke setuptools' setup() for weboob.

    Builds the script and package sets according to the global ``options``
    flags (hildon / qt / xdg / deps), computes data files and install
    requirements, then calls ``setup()``.  If the first command-line
    argument is ``requirements``, the computed requirement list is printed
    and the process exits instead.
    """
    scripts = set(os.listdir('scripts'))
    packages = set(find_packages(exclude=['modules']))

    hildon_scripts = set(('masstransit',))
    qt_scripts = set(('qboobmsg',
                      'qhavedate',
                      'qvideoob',
                      'weboob-config-qt',
                      'qwebcontentedit',
                      'qflatboob',
                      'qcineoob',
                      'qcookboob',
                      'qbooblyrics',
                      'qhandjoob'))

    if not options.hildon:
        scripts = scripts - hildon_scripts
    if options.qt:
        build_qt()
    else:
        scripts = scripts - qt_scripts

    hildon_packages = set((
        'weboob.applications.masstransit',
    ))
    qt_packages = set((
        'weboob.applications.qboobmsg',
        'weboob.applications.qboobmsg.ui',
        'weboob.applications.qcineoob',
        'weboob.applications.qcineoob.ui',
        'weboob.applications.qcookboob',
        'weboob.applications.qcookboob.ui',
        'weboob.applications.qbooblyrics',
        'weboob.applications.qbooblyrics.ui',
        'weboob.applications.qhandjoob',
        'weboob.applications.qhandjoob.ui',
        'weboob.applications.qhavedate',
        'weboob.applications.qhavedate.ui',
        'weboob.applications.qvideoob',
        'weboob.applications.qvideoob.ui',
        'weboob.applications.qweboobcfg',
        'weboob.applications.qweboobcfg.ui',
        'weboob.applications.qwebcontentedit',
        # FIX: commas were missing after the next two entries, so adjacent
        # string literals were implicitly concatenated into bogus package
        # names ('...qwebcontentedit.uiweboob.applications.qflatboob') and
        # the qflatboob packages were never excluded from non-Qt installs.
        'weboob.applications.qwebcontentedit.ui',
        'weboob.applications.qflatboob',
        'weboob.applications.qflatboob.ui',
    ))

    if not options.hildon:
        packages = packages - hildon_packages
    if not options.qt:
        packages = packages - qt_packages

    data_files = [
        ('share/man/man1', glob.glob('man/*')),
    ]
    if options.xdg:
        data_files.extend([
            ('share/applications', glob.glob('desktop/*')),
            ('share/icons/hicolor/64x64/apps', glob.glob('icons/*')),
        ])

    # Do not put PyQt, it does not work properly.
    requirements = [
        'lxml',
        'feedparser',
        'requests>=2.0.0',
        'python-dateutil',
        'PyYAML',
        'prettytable',
        'google-api-python-client',
    ]
    try:
        import Image
    except ImportError:
        requirements.append('Pillow')
    else:
        # detect Pillow-only feature, or weird Debian stuff
        if hasattr(Image, 'alpha_composite') or 'PILcompat' in Image.__file__:
            requirements.append('Pillow')
        else:
            requirements.append('PIL')

    # Backports needed on older interpreters only.
    if sys.version_info < (3, 0):
        requirements.append('mechanize')

    if sys.version_info < (3, 2):
        requirements.append('futures')

    if sys.version_info < (2, 6):
        print('Python older than 2.6 is not supported.', file=sys.stderr)
        sys.exit(1)

    if not options.deps:
        requirements = []

    # "setup.py requirements" mode: just dump the requirement list and exit.
    try:
        if sys.argv[1] == 'requirements':
            print('\n'.join(requirements))
            sys.exit(0)
    except IndexError:
        pass

    setup(
        name='weboob',
        version='1.2',
        description='Weboob, Web Outside Of Browsers',
        long_description=open('README').read(),
        author='Romain Bignon',
        author_email='[email protected]',
        maintainer='Romain Bignon',
        maintainer_email='[email protected]',
        url='http://weboob.org/',
        license='GNU AGPL 3',
        classifiers=[
            'Environment :: Console',
            'Environment :: X11 Applications :: Qt',
            'License :: OSI Approved :: GNU Affero General Public License v3',
            'Programming Language :: Python :: 2.6',
            'Programming Language :: Python :: 2.7',
            'Programming Language :: Python',
            'Topic :: Communications :: Email',
            'Topic :: Internet :: WWW/HTTP',
        ],

        packages=packages,
        scripts=[os.path.join('scripts', script) for script in scripts],
        data_files=data_files,

        install_requires=requirements,
    )

Example 23

Project: CRISPResso
Source File: CRISPRessoPooledCORE.py
View license
def main():
    try:
        print '  \n~~~CRISPRessoPooled~~~'
        print '-Analysis of CRISPR/Cas9 outcomes from POOLED deep sequencing data-'
        print r'''
              )                                            )
             (           _______________________          (
            __)__       | __  __  __     __ __  |        __)__
         C\|     \      ||__)/  \/  \|  |_ |  \ |     C\|     \
           \     /      ||   \__/\__/|__|__|__/ |       \     /
            \___/       |_______________________|        \___/
        '''
    
    
        print'\n[Luca Pinello 2015, send bugs, suggestions or *green coffee* to lucapinello AT gmail DOT com]\n\n',
    
        __version__ = re.search(
            '^__version__\s*=\s*"(.*)"',
            open(os.path.join(_ROOT,'CRISPRessoCORE.py')).read(),
            re.M
            ).group(1)
        print 'Version %s\n' % __version__
    
        parser = argparse.ArgumentParser(description='CRISPRessoPooled Parameters',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('-r1','--fastq_r1', type=str,  help='First fastq file', required=True,default='Fastq filename' )
        parser.add_argument('-r2','--fastq_r2', type=str,  help='Second fastq file for paired end reads',default='')
        parser.add_argument('-f','--amplicons_file', type=str,  help='Amplicons description file. In particular, this file, is a tab delimited text file with up to 5 columns (2 required):\
        \nAMPLICON_NAME:  an identifier for the amplicon (must be unique)\nAMPLICON_SEQUENCE:  amplicon sequence used in the design of the experiment\n\
        \nsgRNA_SEQUENCE (OPTIONAL):  sgRNA sequence used for this amplicon without the PAM sequence. If more than one separate them by commas and not spaces. If not available enter NA.\
        \nEXPECTED_AMPLICON_AFTER_HDR (OPTIONAL): expected amplicon sequence in case of HDR. If not available enter NA.\
        \nCODING_SEQUENCE (OPTIONAL): Subsequence(s) of the amplicon corresponding to coding sequences. If more than one separate them by commas and not spaces. If not available enter NA.', default='')
        parser.add_argument('-x','--bowtie2_index', type=str, help='Basename of Bowtie2 index for the reference genome', default='')
    
        #tool specific optional
        parser.add_argument('--gene_annotations', type=str, help='Gene Annotation Table from UCSC Genome Browser Tables (http://genome.ucsc.edu/cgi-bin/hgTables?command=start), \
        please select as table "knowGene", as output format "all fields from selected table" and as file returned "gzip compressed"', default='')
        parser.add_argument('-p','--n_processes',type=int, help='Specify the number of processes to use for the quantification.\
        Please use with caution since increasing this parameter will increase significantly the memory required to run CRISPResso.',default=1)        
        parser.add_argument('--botwie2_options_string', type=str, help='Override options for the Bowtie2 alignment command',default=' -k 1 --end-to-end -N 0 --np 0 ')
        parser.add_argument('--min_reads_to_use_region',  type=float, help='Minimum number of reads that align to a region to perform the CRISPResso analysis', default=1000)
    
        #general CRISPResso optional
        parser.add_argument('-q','--min_average_read_quality', type=int, help='Minimum average quality score (phred33) to keep a read', default=0)
        parser.add_argument('-s','--min_single_bp_quality', type=int, help='Minimum single bp score (phred33) to keep a read', default=0)
        parser.add_argument('--min_identity_score', type=float, help='Min identity score for the alignment', default=60.0)
        parser.add_argument('-n','--name',  help='Output name', default='')
        parser.add_argument('-o','--output_folder',  help='', default='')
        parser.add_argument('--trim_sequences',help='Enable the trimming of Illumina adapters with Trimmomatic',action='store_true')
        parser.add_argument('--trimmomatic_options_string', type=str, help='Override options for Trimmomatic',default=' ILLUMINACLIP:%s:0:90:10:0:true MINLEN:40' % get_data('NexteraPE-PE.fa'))
        parser.add_argument('--min_paired_end_reads_overlap',  type=int, help='Minimum required overlap length between two reads to provide a confident overlap. ', default=4)
        parser.add_argument('--max_paired_end_reads_overlap',  type=int, help='parameter for the flash merging step, this parameter  is the maximum overlap length expected in approximately 90%% of read pairs. Please see the flash manual for more information.', default=100)    
        parser.add_argument('--hide_mutations_outside_window_NHEJ',help='This parameter allows to visualize only the mutations overlapping the cleavage site and used to classify a read as NHEJ. This parameter has no effect on the quanitification of the NHEJ. It  may be helpful to mask a pre-existing and known mutations or sequencing errors outside the window used for quantification of NHEJ events.',action='store_true')
        parser.add_argument('-w','--window_around_sgrna', type=int, help='Window(s) in bp around the cleavage position (half on on each side) as determined by the provide guide RNA sequence to quantify the indels. Any indels outside this window are excluded. A value of 0 disables this filter.', default=1)
        parser.add_argument('--cleavage_offset', type=int, help="Cleavage offset to use within respect to the 3' end of the provided sgRNA sequence. Remember that the sgRNA sequence must be entered without the PAM. The default is -3 and is suitable for the SpCas9 system. For alternate nucleases, other cleavage offsets may be appropriate, for example, if using Cpf1 this parameter would be set to 1.", default=-3)    
        parser.add_argument('--exclude_bp_from_left', type=int, help='Exclude bp from the left side of the amplicon sequence for the quantification of the indels', default=15)
        parser.add_argument('--exclude_bp_from_right', type=int, help='Exclude bp from the right side of the amplicon sequence for the quantification of the indels', default=15)
        parser.add_argument('--hdr_perfect_alignment_threshold',  type=float, help='Sequence homology %% for an HDR occurrence', default=98.0)
        parser.add_argument('--ignore_substitutions',help='Ignore substitutions events for the quantification and visualization',action='store_true')    
        parser.add_argument('--ignore_insertions',help='Ignore insertions events for the quantification and visualization',action='store_true')  
        parser.add_argument('--ignore_deletions',help='Ignore deletions events for the quantification and visualization',action='store_true')  
        parser.add_argument('--needle_options_string',type=str,help='Override options for the Needle aligner',default=' -gapopen=10 -gapextend=0.5  -awidth3=5000')
        parser.add_argument('--keep_intermediate',help='Keep all the  intermediate files',action='store_true')
        parser.add_argument('--dump',help='Dump numpy arrays and pandas dataframes to file for debugging purposes',action='store_true')
        parser.add_argument('--save_also_png',help='Save also .png images additionally to .pdf files',action='store_true')
        
         
    
        args = parser.parse_args()
        
     
    
        crispresso_options=['window_around_sgrna','cleavage_offset','min_average_read_quality','min_single_bp_quality','min_identity_score',
                                   'min_single_bp_quality','exclude_bp_from_left',
                                   'exclude_bp_from_right',
                                   'hdr_perfect_alignment_threshold','ignore_substitutions','ignore_insertions','ignore_deletions',
                                  'needle_options_string',
                                  'keep_intermediate',
                                  'dump',
                                  'save_also_png','hide_mutations_outside_window_NHEJ','n_processes',]
    
        
        def propagate_options(cmd,options,args):
        
            for option in options :
                if option:
                    val=eval('args.%s' % option )
      
                    if type(val)==str:
                        cmd+=' --%s "%s"' % (option,str(val)) # this is for options with space like needle...
                    elif type(val)==bool:
                        if val:
                            cmd+=' --%s' % option
                    else:
                        cmd+=' --%s %s' % (option,str(val))
                
            return cmd
        
        info('Checking dependencies...')
    
        if check_samtools() and check_bowtie2():
            info('\n All the required dependencies are present!')
        else:
            sys.exit(1)
    
        #check files
        check_file(args.fastq_r1)
        if args.fastq_r2:
            check_file(args.fastq_r2)
    
        if args.bowtie2_index:
            check_file(args.bowtie2_index+'.1.bt2')
    
        if args.amplicons_file:
            check_file(args.amplicons_file)
    
        if args.gene_annotations:
            check_file(args.gene_annotations)
    
        if args.amplicons_file and not args.bowtie2_index:
            RUNNING_MODE='ONLY_AMPLICONS'
            info('Only the Amplicon description file was provided. The analysis will be perfomed using only the provided amplicons sequences.')
    
        elif args.bowtie2_index and not args.amplicons_file:
            RUNNING_MODE='ONLY_GENOME'
            info('Only the bowtie2 reference genome index file was provided. The analysis will be perfomed using only genomic regions where enough reads align.')
        elif args.bowtie2_index and args.amplicons_file:
            RUNNING_MODE='AMPLICONS_AND_GENOME'
            info('Amplicon description file and bowtie2 reference genome index files provided. The analysis will be perfomed using the reads that are aligned ony to the amplicons provided and not to other genomic regions.')
        else:
            error('Please provide the amplicons description file (-f or --amplicons_file option) or the bowtie2 reference genome index file (-x or --bowtie2_index option) or both.')
            sys.exit(1)
    
    
    
        ####TRIMMING AND MERGING
        get_name_from_fasta=lambda  x: os.path.basename(x).replace('.fastq','').replace('.gz','')
    
        if not args.name:
                 if args.fastq_r2!='':
                         database_id='%s_%s' % (get_name_from_fasta(args.fastq_r1),get_name_from_fasta(args.fastq_r2))
                 else:
                         database_id='%s' % get_name_from_fasta(args.fastq_r1)
    
        else:
                 database_id=args.name
                
    
    
        OUTPUT_DIRECTORY='CRISPRessoPooled_on_%s' % database_id
    
        if args.output_folder:
                 OUTPUT_DIRECTORY=os.path.join(os.path.abspath(args.output_folder),OUTPUT_DIRECTORY)
    
        _jp=lambda filename: os.path.join(OUTPUT_DIRECTORY,filename) #handy function to put a file in the output directory
    
        try:
                 info('Creating Folder %s' % OUTPUT_DIRECTORY)
                 os.makedirs(OUTPUT_DIRECTORY)
                 info('Done!')
        except:
                 warn('Folder %s already exists.' % OUTPUT_DIRECTORY)
    
        log_filename=_jp('CRISPRessoPooled_RUNNING_LOG.txt')
        logging.getLogger().addHandler(logging.FileHandler(log_filename))
    
        with open(log_filename,'w+') as outfile:
                  outfile.write('[Command used]:\nCRISPRessoPooled %s\n\n[Execution log]:\n' % ' '.join(sys.argv))
    
        if args.fastq_r2=='': #single end reads
    
             #check if we need to trim
             if not args.trim_sequences:
                 #create a symbolic link
                 symlink_filename=_jp(os.path.basename(args.fastq_r1))
                 force_symlink(os.path.abspath(args.fastq_r1),symlink_filename)
                 output_forward_filename=symlink_filename
             else:
                 output_forward_filename=_jp('reads.trimmed.fq.gz')
                 #Trimming with trimmomatic
                 cmd='java -jar %s SE -phred33 %s  %s %s >>%s 2>&1'\
                 % (get_data('trimmomatic-0.33.jar'),args.fastq_r1,
                    output_forward_filename,
                    args.trimmomatic_options_string.replace('NexteraPE-PE.fa','TruSeq3-SE.fa'),
                    log_filename)
                 #print cmd
                 TRIMMOMATIC_STATUS=sb.call(cmd,shell=True)
    
                 if TRIMMOMATIC_STATUS:
                         raise TrimmomaticException('TRIMMOMATIC failed to run, please check the log file.')
    
    
             processed_output_filename=output_forward_filename
    
        else:#paired end reads case
    
             if not args.trim_sequences:
                 output_forward_paired_filename=args.fastq_r1
                 output_reverse_paired_filename=args.fastq_r2
             else:
                 info('Trimming sequences with Trimmomatic...')
                 output_forward_paired_filename=_jp('output_forward_paired.fq.gz')
                 output_forward_unpaired_filename=_jp('output_forward_unpaired.fq.gz')
                 output_reverse_paired_filename=_jp('output_reverse_paired.fq.gz')
                 output_reverse_unpaired_filename=_jp('output_reverse_unpaired.fq.gz')
    
                 #Trimming with trimmomatic
                 cmd='java -jar %s PE -phred33 %s  %s %s  %s  %s  %s %s >>%s 2>&1'\
                 % (get_data('trimmomatic-0.33.jar'),
                         args.fastq_r1,args.fastq_r2,output_forward_paired_filename,
                         output_forward_unpaired_filename,output_reverse_paired_filename,
                         output_reverse_unpaired_filename,args.trimmomatic_options_string,log_filename)
                 #print cmd
                 TRIMMOMATIC_STATUS=sb.call(cmd,shell=True)
                 if TRIMMOMATIC_STATUS:
                         raise TrimmomaticException('TRIMMOMATIC failed to run, please check the log file.')
    
                 info('Done!')
    
    
             #Merging with Flash
             info('Merging paired sequences with Flash...')
             cmd='flash %s %s --min-overlap %d --max-overlap %d -z -d %s >>%s 2>&1' %\
             (output_forward_paired_filename,
              output_reverse_paired_filename,
              args.min_paired_end_reads_overlap,
              args.max_paired_end_reads_overlap,
              OUTPUT_DIRECTORY,log_filename)
    
             FLASH_STATUS=sb.call(cmd,shell=True)
             if FLASH_STATUS:
                 raise FlashException('Flash failed to run, please check the log file.')
    
             info('Done!')
    
             flash_hist_filename=_jp('out.hist')
             flash_histogram_filename=_jp('out.histogram')
             flash_not_combined_1_filename=_jp('out.notCombined_1.fastq.gz')
             flash_not_combined_2_filename=_jp('out.notCombined_2.fastq.gz')
    
             processed_output_filename=_jp('out.extendedFrags.fastq.gz')
    
    
        #count reads 
        N_READS_INPUT=get_n_reads_fastq(args.fastq_r1)
        N_READS_AFTER_PREPROCESSING=get_n_reads_fastq(processed_output_filename)
    
            
        #load gene annotation
        if args.gene_annotations:
            info('Loading gene coordinates from annotation file: %s...' % args.gene_annotations)
            try:
                df_genes=pd.read_table(args.gene_annotations,compression='gzip')
                df_genes.txEnd=df_genes.txEnd.astype(int)
                df_genes.txStart=df_genes.txStart.astype(int)
                df_genes.head()
            except:
               info('Failed to load the gene annotations file.')
        
    
        if RUNNING_MODE=='ONLY_AMPLICONS' or  RUNNING_MODE=='AMPLICONS_AND_GENOME':
    
            #load and validate template file
            df_template=pd.read_csv(args.amplicons_file,names=[
                    'Name','Amplicon_Sequence','sgRNA',
                    'Expected_HDR','Coding_sequence'],comment='#',sep='\t',dtype={'Name':str})
    
    
            #remove empty amplicons/lines
            df_template.dropna(subset=['Amplicon_Sequence'],inplace=True)
            df_template.dropna(subset=['Name'],inplace=True)
    
            df_template.Amplicon_Sequence=df_template.Amplicon_Sequence.apply(capitalize_sequence)
            df_template.Expected_HDR=df_template.Expected_HDR.apply(capitalize_sequence)
            df_template.sgRNA=df_template.sgRNA.apply(capitalize_sequence)
            df_template.Coding_sequence=df_template.Coding_sequence.apply(capitalize_sequence)
    
            if not len(df_template.Amplicon_Sequence.unique())==df_template.shape[0]:
                raise Exception('The amplicons should be all distinct!')
    
            if not len(df_template.Name.unique())==df_template.shape[0]:
                raise Exception('The amplicon names should be all distinct!')
    
            df_template=df_template.set_index('Name')
            df_template.index=df_template.index.to_series().str.replace(' ','_')
    
            for idx,row in df_template.iterrows():
    
                wrong_nt=find_wrong_nt(row.Amplicon_Sequence)
                if wrong_nt:
                     raise NTException('The amplicon sequence %s contains wrong characters:%s' % (idx,' '.join(wrong_nt)))
    
                if not pd.isnull(row.sgRNA):
                    
                    cut_points=[]
    
                    for current_guide_seq in row.sgRNA.strip().upper().split(','):
                    
                        wrong_nt=find_wrong_nt(current_guide_seq)
                        if wrong_nt:
                            raise NTException('The sgRNA sequence %s contains wrong characters:%s'  % (current_guide_seq, ' '.join(wrong_nt)))
                    
                        offset_fw=args.cleavage_offset+len(current_guide_seq)-1
                        offset_rc=(-args.cleavage_offset)-1
                        cut_points+=[m.start() + offset_fw for \
                                    m in re.finditer(current_guide_seq,  row.Amplicon_Sequence)]+[m.start() + offset_rc for m in re.finditer(reverse_complement(current_guide_seq),  row.Amplicon_Sequence)]
                    
                    if not cut_points:
                        warn('\nThe guide sequence/s provided: %s is(are) not present in the amplicon sequence:%s! \nNOTE: The guide will be ignored for the analysis. Please check your input!' % (row.sgRNA,row.Amplicon_Sequence))
                        df_template.ix[idx,'sgRNA']=''
                        
                        
    
        if RUNNING_MODE=='ONLY_AMPLICONS':
            #create a fasta file with all the amplicons
            amplicon_fa_filename=_jp('AMPLICONS.fa')
            fastq_gz_amplicon_filenames=[]
            with open(amplicon_fa_filename,'w+') as outfile:
                for idx,row in df_template.iterrows():
                    if row['Amplicon_Sequence']:
                        outfile.write('>%s\n%s\n' %(clean_filename('AMPL_'+idx),row['Amplicon_Sequence']))
    
                        #create place-holder fastq files
                        fastq_gz_amplicon_filenames.append(_jp('%s.fastq.gz' % clean_filename('AMPL_'+idx)))
                        open(fastq_gz_amplicon_filenames[-1], 'w+').close()
    
            df_template['Demultiplexed_fastq.gz_filename']=fastq_gz_amplicon_filenames
            info('Creating a custom index file with all the amplicons...')
            custom_index_filename=_jp('CUSTOM_BOWTIE2_INDEX')
            sb.call('bowtie2-build %s %s >>%s 2>&1' %(amplicon_fa_filename,custom_index_filename,log_filename), shell=True)
    
    
            #align the file to the amplicons (MODE 1)
            info('Align reads to the amplicons...')
            bam_filename_amplicons= _jp('CRISPResso_AMPLICONS_ALIGNED.bam')
            aligner_command= 'bowtie2 -x %s -p %s -k 1 --end-to-end -N 0 --np 0 -U %s 2>>%s | samtools view -bS - > %s' %(custom_index_filename,args.n_processes,processed_output_filename,log_filename,bam_filename_amplicons)
    
            sb.call(aligner_command,shell=True)
    
            N_READS_ALIGNED=get_n_aligned_bam(bam_filename_amplicons)
            
            s1=r"samtools view -F 4 %s 2>>%s | grep -v ^'@'" % (bam_filename_amplicons,log_filename)
            s2=r'''|awk '{ gzip_filename=sprintf("gzip >> OUTPUTPATH%s.fastq.gz",$3);\
            print "@"$1"\n"$10"\n+\n"$11  | gzip_filename;}' '''
    
            cmd=s1+s2.replace('OUTPUTPATH',_jp(''))
            sb.call(cmd,shell=True)
            
            info('Demultiplex reads and run CRISPResso on each amplicon...')
            n_reads_aligned_amplicons=[]
            for idx,row in df_template.iterrows():
                info('\n Processing:%s' %idx)
                n_reads_aligned_amplicons.append(get_n_reads_fastq(row['Demultiplexed_fastq.gz_filename']))
                crispresso_cmd='CRISPResso -r1 %s -a %s -o %s --name %s' % (row['Demultiplexed_fastq.gz_filename'],row['Amplicon_Sequence'],OUTPUT_DIRECTORY,idx)
    
                if n_reads_aligned_amplicons[-1]>args.min_reads_to_use_region:
                    if row['sgRNA'] and not pd.isnull(row['sgRNA']):
                        crispresso_cmd+=' -g %s' % row['sgRNA']
    
                    if row['Expected_HDR'] and not pd.isnull(row['Expected_HDR']):
                        crispresso_cmd+=' -e %s' % row['Expected_HDR']
    
                    if row['Coding_sequence'] and not pd.isnull(row['Coding_sequence']):
                        crispresso_cmd+=' -c %s' % row['Coding_sequence']
                    
                    crispresso_cmd=propagate_options(crispresso_cmd,crispresso_options,args)
                    info('Running CRISPResso:%s' % crispresso_cmd)
                    sb.call(crispresso_cmd,shell=True)
                else:
                    warn('Skipping amplicon [%s] since no reads are aligning to it\n'% idx)
    
            df_template['n_reads']=n_reads_aligned_amplicons
            df_template['n_reads_aligned_%']=df_template['n_reads']/float(N_READS_ALIGNED)*100
            df_template.fillna('NA').to_csv(_jp('REPORT_READS_ALIGNED_TO_AMPLICONS.txt'),sep='\t')
    
    
    
        if RUNNING_MODE=='AMPLICONS_AND_GENOME':
            print 'Mapping amplicons to the reference genome...'
            #find the locations of the amplicons on the genome and their strand and check if there are mutations in the reference genome
            additional_columns=[]
            for idx,row in df_template.iterrows():
                fields_to_append=list(np.take(get_align_sequence(row.Amplicon_Sequence, args.bowtie2_index).split('\t'),[0,1,2,3,5]))
                if fields_to_append[0]=='*':
                    info('The amplicon [%s] is not mappable to the reference genome provided!' % idx )
                    additional_columns.append([idx,'NOT_ALIGNED',0,-1,'+',''])
                else:
                    additional_columns.append([idx]+fields_to_append)
                    info('The amplicon [%s] was mapped to: %s ' % (idx,' '.join(fields_to_append[:3]) ))
        
        
            df_template=df_template.join(pd.DataFrame(additional_columns,columns=['Name','chr_id','bpstart','bpend','strand','Reference_Sequence']).set_index('Name'))
            
            df_template.bpstart=df_template.bpstart.astype(int)
            df_template.bpend=df_template.bpend.astype(int)
            
            #Check reference is the same otherwise throw a warning
            for idx,row in df_template.iterrows():
                if row.Amplicon_Sequence != row.Reference_Sequence and row.Amplicon_Sequence != reverse_complement(row.Reference_Sequence):
                    warn('The amplicon sequence %s provided:\n%s\n\nis different from the reference sequence(both strand):\n\n%s\n\n%s\n' %(row.name,row.Amplicon_Sequence,row.Amplicon_Sequence,reverse_complement(row.Amplicon_Sequence)))
     
    
        if RUNNING_MODE=='ONLY_GENOME' or RUNNING_MODE=='AMPLICONS_AND_GENOME':
    
            ###HERE we recreate the uncompressed genome file if not available###
    
            #check you have all the files for the genome and create a fa idx for samtools
            
            uncompressed_reference=args.bowtie2_index+'.fa'
            
            #if not os.path.exists(GENOME_LOCAL_FOLDER):
            #    os.mkdir(GENOME_LOCAL_FOLDER)
    
            if os.path.exists(uncompressed_reference):
                info('The uncompressed reference fasta file for %s is already present! Skipping generation.' % args.bowtie2_index)
            else:
                #uncompressed_reference=os.path.join(GENOME_LOCAL_FOLDER,'UNCOMPRESSED_REFERENCE_FROM_'+args.bowtie2_index.replace('/','_')+'.fa')
                info('Extracting uncompressed reference from the provided bowtie2 index since it is not available... Please be patient!')
    
                cmd_to_uncompress='bowtie2-inspect %s > %s 2>>%s' % (args.bowtie2_index,uncompressed_reference,log_filename)
                sb.call(cmd_to_uncompress,shell=True)
    
                info('Indexing fasta file with samtools...')
                #!samtools faidx {uncompressed_reference}
                sb.call('samtools faidx %s 2>>%s ' % (uncompressed_reference,log_filename),shell=True)
    
    
        #####CORRECT ONE####
        #align in unbiased way the reads to the genome
        if RUNNING_MODE=='ONLY_GENOME' or RUNNING_MODE=='AMPLICONS_AND_GENOME':
            info('Aligning reads to the provided genome index...')
            bam_filename_genome = _jp('%s_GENOME_ALIGNED.bam' % database_id)
            aligner_command= 'bowtie2 -x %s -p %s -k 1 --end-to-end -N 0 --np 0 -U %s 2>>%s| samtools view -bS - > %s' %(args.bowtie2_index,args.n_processes,processed_output_filename,log_filename,bam_filename_genome)
            sb.call(aligner_command,shell=True)
            
            N_READS_ALIGNED=get_n_aligned_bam(bam_filename_genome)
            
            #REDISCOVER LOCATIONS and DEMULTIPLEX READS
            MAPPED_REGIONS=_jp('MAPPED_REGIONS/')
            if not os.path.exists(MAPPED_REGIONS):
                os.mkdir(MAPPED_REGIONS)
    
            s1=r'''samtools view -F 0x0004 %s 2>>%s |''' % (bam_filename_genome,log_filename)+\
            r'''awk '{OFS="\t"; bpstart=$4;  bpend=bpstart; split ($6,a,"[MIDNSHP]"); n=0;\
            for (i=1; i in a; i++){\
                n+=1+length(a[i]);\
                if (substr($6,n,1)=="S"){\
                    if (bpend==$4)\
                        bpstart-=a[i];\
                    else
                        bpend+=a[i];
                    }\
                else if( (substr($6,n,1)!="I")  && (substr($6,n,1)!="H") )\
                        bpend+=a[i];\
                }\
                if ( ($2 % 32)>=16)\
                    print $3,bpstart,bpend,"-",$1,$10,$11;\
                else\
                    print $3,bpstart,bpend,"+",$1,$10,$11;}' | ''' 
        
            s2=r'''  sort -k1,1 -k2,2n  | awk \
            'BEGIN{chr_id="NA";bpstart=-1;bpend=-1; fastq_filename="NA"}\
            { if ( (chr_id!=$1) || (bpstart!=$2) || (bpend!=$3) )\
                {\
                if (fastq_filename!="NA") {close(fastq_filename); system("gzip "fastq_filename)}\
                chr_id=$1; bpstart=$2; bpend=$3;\
                fastq_filename=sprintf("__OUTPUTPATH__REGION_%s_%s_%s.fastq",$1,$2,$3);\
                }\
            print "@"$5"\n"$6"\n+\n"$7 >> fastq_filename;\
            }' '''
            cmd=s1+s2.replace('__OUTPUTPATH__',MAPPED_REGIONS)
            
            info('Demultiplexing reads by location...')
            sb.call(cmd,shell=True)
            
            #gzip the missing ones 
            sb.call('gzip %s/*.fastq' % MAPPED_REGIONS,shell=True)
    
        '''
        The most common use case, where many different target sites are pooled into a single 
        high-throughput sequencing library for quantification, is not directly addressed by this implementation. 
        Potential users of CRISPResso would need to write their own code to generate separate input files for processing. 
        Importantly, this preprocessing code would need to remove any PCR amplification artifacts 
        (such as amplification of sequences from a gene and a highly similar pseudogene) 
        which may confound the interpretation of results. 
        This can be done by mapping of input sequences to a reference genome and removing 
        those that do not map to the expected genomic location, but is non-trivial for an end-user to implement.
        '''
        
    
        
        if RUNNING_MODE=='AMPLICONS_AND_GENOME':
            files_to_match=glob.glob(os.path.join(MAPPED_REGIONS,'REGION*'))
            n_reads_aligned_genome=[]
            fastq_region_filenames=[]
        
            for idx,row in df_template.iterrows():
                
                info('Processing amplicon:%s' % idx )
        
                #check if we have reads
                fastq_filename_region=os.path.join(MAPPED_REGIONS,'REGION_%s_%s_%s.fastq.gz' % (row['chr_id'],row['bpstart'],row['bpend']))
        
                if os.path.exists(fastq_filename_region):
                    
                    N_READS=get_n_reads_fastq(fastq_filename_region)
                    n_reads_aligned_genome.append(N_READS)
                    fastq_region_filenames.append(fastq_filename_region)
                    files_to_match.remove(fastq_filename_region)
                    if N_READS>=args.min_reads_to_use_region:
                        info('\nThe amplicon [%s] has enough reads (%d) mapped to it! Running CRISPResso!\n' % (idx,N_READS))
        
                        crispresso_cmd='CRISPResso -r1 %s -a %s -o %s --name %s' % (fastq_filename_region,row['Amplicon_Sequence'],OUTPUT_DIRECTORY,idx)
        
                        if row['sgRNA'] and not pd.isnull(row['sgRNA']):
                            crispresso_cmd+=' -g %s' % row['sgRNA']
        
                        if row['Expected_HDR'] and not pd.isnull(row['Expected_HDR']):
                            crispresso_cmd+=' -e %s' % row['Expected_HDR']
        
                        if row['Coding_sequence'] and not pd.isnull(row['Coding_sequence']):
                            crispresso_cmd+=' -c %s' % row['Coding_sequence']
                        
                        crispresso_cmd=propagate_options(crispresso_cmd,crispresso_options,args)
                        info('Running CRISPResso:%s' % crispresso_cmd)
                        sb.call(crispresso_cmd,shell=True)
         
                    else:
                         warn('The amplicon [%s] has not enough reads (%d) mapped to it! Skipping the execution of CRISPResso!' % (idx,N_READS))
                else:
                    fastq_region_filenames.append('')
                    n_reads_aligned_genome.append(0)
                    warn("The amplicon %s doesn't have any read mapped to it!\n Please check your amplicon sequence." %  idx)
        
            df_template['Amplicon_Specific_fastq.gz_filename']=fastq_region_filenames
            df_template['n_reads']=n_reads_aligned_genome
            df_template['n_reads_aligned_%']=df_template['n_reads']/float(N_READS_ALIGNED)*100
            
            if args.gene_annotations:
                df_template=df_template.apply(lambda row: find_overlapping_genes(row, df_genes),axis=1)
            
            df_template.fillna('NA').to_csv(_jp('REPORT_READS_ALIGNED_TO_GENOME_AND_AMPLICONS.txt'),sep='\t')
            
            #write another file with the not amplicon regions
            
            info('Reporting problematic regions...')  
            coordinates=[]
            for region in files_to_match:
                coordinates.append(os.path.basename(region).replace('.fastq.gz','').replace('.fastq','').split('_')[1:4]+[region,get_n_reads_fastq(region)])
        
            df_regions=pd.DataFrame(coordinates,columns=['chr_id','bpstart','bpend','fastq_file','n_reads'])
    
            df_regions=df_regions.convert_objects(convert_numeric=True)
            df_regions.dropna(inplace=True) #remove regions in chrUn
            df_regions.bpstart=df_regions.bpstart.astype(int)
            df_regions.bpend=df_regions.bpend.astype(int)
    
            df_regions['n_reads_aligned_%']=df_regions['n_reads']/float(N_READS_ALIGNED)*100
    
            df_regions['Reference_sequence']=df_regions.apply(lambda row: get_region_from_fa(row.chr_id,row.bpstart,row.bpend,uncompressed_reference),axis=1)
    
            
            if args.gene_annotations:
                info('Checking overlapping genes...')   
                df_regions=df_regions.apply(lambda row: find_overlapping_genes(row, df_genes),axis=1)
            
            if np.sum(np.array(map(int,pd.__version__.split('.')))*(100,10,1))< 170:
                df_regions.sort('n_reads',ascending=False,inplace=True)
            else:
                df_regions.sort_values(by='n_reads',ascending=False,inplace=True)


            df_regions.fillna('NA').to_csv(_jp('REPORTS_READS_ALIGNED_TO_GENOME_NOT_MATCHING_AMPLICONS.txt'),sep='\t',index=None)
    
    
        if RUNNING_MODE=='ONLY_GENOME' :
            #Load regions and build REFERENCE TABLES 
            info('Parsing the demultiplexed files and extracting locations and reference sequences...')
            coordinates=[]
            for region in glob.glob(os.path.join(MAPPED_REGIONS,'REGION*.fastq.gz')):
                coordinates.append(os.path.basename(region).replace('.fastq.gz','').split('_')[1:4]+[region,get_n_reads_fastq(region)])
            
            print 'C:',coordinates
            df_regions=pd.DataFrame(coordinates,columns=['chr_id','bpstart','bpend','fastq_file','n_reads'])
            
            print 'D:', df_regions
            df_regions=df_regions.convert_objects(convert_numeric=True)
            df_regions.dropna(inplace=True) #remove regions in chrUn
            df_regions.bpstart=df_regions.bpstart.astype(int)
            df_regions.bpend=df_regions.bpend.astype(int)
            print df_regions
            df_regions['sequence']=df_regions.apply(lambda row: get_region_from_fa(row.chr_id,row.bpstart,row.bpend,uncompressed_reference),axis=1)
    
            df_regions['n_reads_aligned_%']=df_regions['n_reads']/float(N_READS_ALIGNED)*100
                 
            if args.gene_annotations:
                info('Checking overlapping genes...')   
                df_regions=df_regions.apply(lambda row: find_overlapping_genes(row, df_genes),axis=1)
            
            if np.sum(np.array(map(int,pd.__version__.split('.')))*(100,10,1))< 170:
                df_regions.sort('n_reads',ascending=False,inplace=True)
            else:
                df_regions.sort_values(by='n_reads',ascending=False,inplace=True)


            df_regions.fillna('NA').to_csv(_jp('REPORT_READS_ALIGNED_TO_GENOME_ONLY.txt'),sep='\t',index=None)
            
            
            #run CRISPResso
            #demultiplex reads in the amplicons and call crispresso!
            info('Running CRISPResso on the regions discovered...')
            for idx,row in df_regions.iterrows():
        
                if row.n_reads > args.min_reads_to_use_region:
                    info('\nRunning CRISPResso on: %s-%d-%d...'%(row.chr_id,row.bpstart,row.bpend ))
                    crispresso_cmd='CRISPResso -r1 %s -a %s -o %s' %(row.fastq_file,row.sequence,OUTPUT_DIRECTORY)  
                    crispresso_cmd=propagate_options(crispresso_cmd,crispresso_options,args)
                    info('Running CRISPResso:%s' % crispresso_cmd)
                    sb.call(crispresso_cmd,shell=True)
                else:
                    info('Skipping region: %s-%d-%d , not enough reads (%d)' %(row.chr_id,row.bpstart,row.bpend, row.n_reads))
    
    
        #write alignment statistics
        with open(_jp('MAPPING_STATISTICS.txt'),'w+') as outfile:
            outfile.write('READS IN INPUTS:%d\nREADS AFTER PREPROCESSING:%d\nREADS ALIGNED:%d' % (N_READS_INPUT,N_READS_AFTER_PREPROCESSING,N_READS_ALIGNED))
    
        #write a file with basic quantification info for each sample
        def check_output_folder(output_folder):
            """Locate the editing-frequency report of a CRISPResso run.

            Returns the full path of 'Quantification_of_editing_frequency.txt'
            inside *output_folder* when that file exists, otherwise None.
            """
            report_path = os.path.join(output_folder,
                                       'Quantification_of_editing_frequency.txt')
            return report_path if os.path.exists(report_path) else None
    
        def parse_quantification(quantification_file):
            """Extract read counts from a CRISPResso quantification report.

            The report layout is positional: line 1 is a header, lines 2-5
            carry the Unmodified/NHEJ/HDR/Mixed counts, line 6 is a separator
            and line 7 reports the total of aligned reads.  Returns the five
            counts as floats in the order
            (unmodified, nhej, hdr, mixed_hdr_nhej, total).
            """
            with open(quantification_file) as report:
                lines = report.readlines()
            # Pull the first integer matching `pattern` out of `text`.
            extract = lambda pattern, text: float(re.findall(pattern, text)[0])
            n_unmodified = extract("Unmodified:(\d+)", lines[1])
            n_modified = extract("NHEJ:(\d+)", lines[2])
            n_repaired = extract("HDR:(\d+)", lines[3])
            n_mixed_hdr_nhej = extract("Mixed HDR-NHEJ:(\d+)", lines[4])
            n_total = extract("Total Aligned:(\d+) reads", lines[6])
            return n_unmodified, n_modified, n_repaired, n_mixed_hdr_nhej, n_total
    
        quantification_summary=[]
    
        if RUNNING_MODE=='ONLY_AMPLICONS' or RUNNING_MODE=='AMPLICONS_AND_GENOME':
            df_final_data=df_template
        else:
            df_final_data=df_regions
    
        for idx,row in df_final_data.iterrows():
    
                if RUNNING_MODE=='ONLY_AMPLICONS' or RUNNING_MODE=='AMPLICONS_AND_GENOME':
                    folder_name='CRISPResso_on_%s' % idx
                else:
                    folder_name='CRISPResso_on_REGION_%s_%d_%d' %(row.chr_id,row.bpstart,row.bpend )
    
                quantification_file=check_output_folder(_jp(folder_name))
    
                if quantification_file:
                    N_UNMODIFIED,N_MODIFIED,N_REPAIRED,N_MIXED_HDR_NHEJ,N_TOTAL=parse_quantification(quantification_file)
                    quantification_summary.append([idx,N_UNMODIFIED/N_TOTAL*100,N_MODIFIED/N_TOTAL*100,N_REPAIRED/N_TOTAL*100,N_MIXED_HDR_NHEJ/N_TOTAL*100,N_TOTAL,row.n_reads])
                else:
                    quantification_summary.append([idx,np.nan,np.nan,np.nan,np.nan,np.nan,row.n_reads])
                    warn('Skipping the folder %s, not enough reads or empty folder.'% folder_name)
    
    
        df_summary_quantification=pd.DataFrame(quantification_summary,columns=['Name','Unmodified%','NHEJ%','HDR%', 'Mixed_HDR-NHEJ%','Reads_aligned','Reads_total'])        
        df_summary_quantification.fillna('NA').to_csv(_jp('SAMPLES_QUANTIFICATION_SUMMARY.txt'),sep='\t',index=None)        
                    
        #cleaning up
        if not args.keep_intermediate:
             info('Removing Intermediate files...')
        
             if args.fastq_r2!='':
                 files_to_remove=[processed_output_filename,flash_hist_filename,flash_histogram_filename,\
                              flash_not_combined_1_filename,flash_not_combined_2_filename] 
             else:
                 files_to_remove=[processed_output_filename] 
        
             if args.trim_sequences and args.fastq_r2!='':
                 files_to_remove+=[output_forward_paired_filename,output_reverse_paired_filename,\
                                                   output_forward_unpaired_filename,output_reverse_unpaired_filename]
        
             if RUNNING_MODE=='ONLY_GENOME' or RUNNING_MODE=='AMPLICONS_AND_GENOME':
                     files_to_remove+=[bam_filename_genome]
                 
             if RUNNING_MODE=='ONLY_AMPLICONS':  
                files_to_remove+=[bam_filename_amplicons,amplicon_fa_filename]
                for bowtie2_file in glob.glob(_jp('CUSTOM_BOWTIE2_INDEX.*')):
                    files_to_remove.append(bowtie2_file)
        
             for file_to_remove in files_to_remove:
                 try:
                         if os.path.islink(file_to_remove):
                             #print 'LINK',file_to_remove
                             os.unlink(file_to_remove)
                         else:                             
                             os.remove(file_to_remove)
                 except:
                         warn('Skipping:%s' %file_to_remove)
    
    
           
        info('All Done!')
        print r'''
              )                                            )
             (           _______________________          (
            __)__       | __  __  __     __ __  |        __)__
         C\|     \      ||__)/  \/  \|  |_ |  \ |     C\|     \
           \     /      ||   \__/\__/|__|__|__/ |       \     /
            \___/       |_______________________|        \___/
        '''
        sys.exit(0)
    
    except Exception as e:
        error('\n\nERROR: %s' % e)
        sys.exit(-1)

Example 24

Project: CRISPResso
Source File: CRISPRessoPooledCORE.py
View license
def main():
    try:
        print '  \n~~~CRISPRessoPooled~~~'
        print '-Analysis of CRISPR/Cas9 outcomes from POOLED deep sequencing data-'
        print r'''
              )                                            )
             (           _______________________          (
            __)__       | __  __  __     __ __  |        __)__
         C\|     \      ||__)/  \/  \|  |_ |  \ |     C\|     \
           \     /      ||   \__/\__/|__|__|__/ |       \     /
            \___/       |_______________________|        \___/
        '''
    
    
        print'\n[Luca Pinello 2015, send bugs, suggestions or *green coffee* to lucapinello AT gmail DOT com]\n\n',
    
        __version__ = re.search(
            '^__version__\s*=\s*"(.*)"',
            open(os.path.join(_ROOT,'CRISPRessoCORE.py')).read(),
            re.M
            ).group(1)
        print 'Version %s\n' % __version__
    
        parser = argparse.ArgumentParser(description='CRISPRessoPooled Parameters',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('-r1','--fastq_r1', type=str,  help='First fastq file', required=True,default='Fastq filename' )
        parser.add_argument('-r2','--fastq_r2', type=str,  help='Second fastq file for paired end reads',default='')
        parser.add_argument('-f','--amplicons_file', type=str,  help='Amplicons description file. In particular, this file, is a tab delimited text file with up to 5 columns (2 required):\
        \nAMPLICON_NAME:  an identifier for the amplicon (must be unique)\nAMPLICON_SEQUENCE:  amplicon sequence used in the design of the experiment\n\
        \nsgRNA_SEQUENCE (OPTIONAL):  sgRNA sequence used for this amplicon without the PAM sequence. If more than one separate them by commas and not spaces. If not available enter NA.\
        \nEXPECTED_AMPLICON_AFTER_HDR (OPTIONAL): expected amplicon sequence in case of HDR. If not available enter NA.\
        \nCODING_SEQUENCE (OPTIONAL): Subsequence(s) of the amplicon corresponding to coding sequences. If more than one separate them by commas and not spaces. If not available enter NA.', default='')
        parser.add_argument('-x','--bowtie2_index', type=str, help='Basename of Bowtie2 index for the reference genome', default='')
    
        #tool specific optional
        parser.add_argument('--gene_annotations', type=str, help='Gene Annotation Table from UCSC Genome Browser Tables (http://genome.ucsc.edu/cgi-bin/hgTables?command=start), \
        please select as table "knowGene", as output format "all fields from selected table" and as file returned "gzip compressed"', default='')
        parser.add_argument('-p','--n_processes',type=int, help='Specify the number of processes to use for the quantification.\
        Please use with caution since increasing this parameter will increase significantly the memory required to run CRISPResso.',default=1)        
        parser.add_argument('--botwie2_options_string', type=str, help='Override options for the Bowtie2 alignment command',default=' -k 1 --end-to-end -N 0 --np 0 ')
        parser.add_argument('--min_reads_to_use_region',  type=float, help='Minimum number of reads that align to a region to perform the CRISPResso analysis', default=1000)
    
        #general CRISPResso optional
        parser.add_argument('-q','--min_average_read_quality', type=int, help='Minimum average quality score (phred33) to keep a read', default=0)
        parser.add_argument('-s','--min_single_bp_quality', type=int, help='Minimum single bp score (phred33) to keep a read', default=0)
        parser.add_argument('--min_identity_score', type=float, help='Min identity score for the alignment', default=60.0)
        parser.add_argument('-n','--name',  help='Output name', default='')
        parser.add_argument('-o','--output_folder',  help='', default='')
        parser.add_argument('--trim_sequences',help='Enable the trimming of Illumina adapters with Trimmomatic',action='store_true')
        parser.add_argument('--trimmomatic_options_string', type=str, help='Override options for Trimmomatic',default=' ILLUMINACLIP:%s:0:90:10:0:true MINLEN:40' % get_data('NexteraPE-PE.fa'))
        parser.add_argument('--min_paired_end_reads_overlap',  type=int, help='Minimum required overlap length between two reads to provide a confident overlap. ', default=4)
        parser.add_argument('--max_paired_end_reads_overlap',  type=int, help='parameter for the flash merging step, this parameter  is the maximum overlap length expected in approximately 90%% of read pairs. Please see the flash manual for more information.', default=100)    
        parser.add_argument('--hide_mutations_outside_window_NHEJ',help='This parameter allows to visualize only the mutations overlapping the cleavage site and used to classify a read as NHEJ. This parameter has no effect on the quanitification of the NHEJ. It  may be helpful to mask a pre-existing and known mutations or sequencing errors outside the window used for quantification of NHEJ events.',action='store_true')
        parser.add_argument('-w','--window_around_sgrna', type=int, help='Window(s) in bp around the cleavage position (half on on each side) as determined by the provide guide RNA sequence to quantify the indels. Any indels outside this window are excluded. A value of 0 disables this filter.', default=1)
        parser.add_argument('--cleavage_offset', type=int, help="Cleavage offset to use within respect to the 3' end of the provided sgRNA sequence. Remember that the sgRNA sequence must be entered without the PAM. The default is -3 and is suitable for the SpCas9 system. For alternate nucleases, other cleavage offsets may be appropriate, for example, if using Cpf1 this parameter would be set to 1.", default=-3)    
        parser.add_argument('--exclude_bp_from_left', type=int, help='Exclude bp from the left side of the amplicon sequence for the quantification of the indels', default=15)
        parser.add_argument('--exclude_bp_from_right', type=int, help='Exclude bp from the right side of the amplicon sequence for the quantification of the indels', default=15)
        parser.add_argument('--hdr_perfect_alignment_threshold',  type=float, help='Sequence homology %% for an HDR occurrence', default=98.0)
        parser.add_argument('--ignore_substitutions',help='Ignore substitutions events for the quantification and visualization',action='store_true')    
        parser.add_argument('--ignore_insertions',help='Ignore insertions events for the quantification and visualization',action='store_true')  
        parser.add_argument('--ignore_deletions',help='Ignore deletions events for the quantification and visualization',action='store_true')  
        parser.add_argument('--needle_options_string',type=str,help='Override options for the Needle aligner',default=' -gapopen=10 -gapextend=0.5  -awidth3=5000')
        parser.add_argument('--keep_intermediate',help='Keep all the  intermediate files',action='store_true')
        parser.add_argument('--dump',help='Dump numpy arrays and pandas dataframes to file for debugging purposes',action='store_true')
        parser.add_argument('--save_also_png',help='Save also .png images additionally to .pdf files',action='store_true')
        
         
    
        args = parser.parse_args()
        
     
    
        crispresso_options=['window_around_sgrna','cleavage_offset','min_average_read_quality','min_single_bp_quality','min_identity_score',
                                   'min_single_bp_quality','exclude_bp_from_left',
                                   'exclude_bp_from_right',
                                   'hdr_perfect_alignment_threshold','ignore_substitutions','ignore_insertions','ignore_deletions',
                                  'needle_options_string',
                                  'keep_intermediate',
                                  'dump',
                                  'save_also_png','hide_mutations_outside_window_NHEJ','n_processes',]
    
        
        def propagate_options(cmd,options,args):
            """Append the given argparse option values to a command string.

            For every truthy option name in *options* the value is read from
            *args* and appended to *cmd*:
              * strings are appended quoted, since values may contain spaces
                (e.g. the needle options string),
              * booleans become bare flags when True and are dropped when
                False,
              * any other value (int, float, ...) is appended verbatim.
            Returns the extended command string; *cmd* itself is not mutated.
            """
            for option in options:
                if not option:
                    # skip empty placeholder entries in the options list
                    continue
                # getattr replaces the original eval('args.%s' % option):
                # same attribute lookup, no dynamic code execution.
                val = getattr(args, option)
                if isinstance(val, str):
                    cmd += ' --%s "%s"' % (option, str(val))
                elif isinstance(val, bool):
                    if val:
                        cmd += ' --%s' % option
                else:
                    cmd += ' --%s %s' % (option, str(val))
            return cmd
        
        info('Checking dependencies...')
    
        if check_samtools() and check_bowtie2():
            info('\n All the required dependencies are present!')
        else:
            sys.exit(1)
    
        #check files
        check_file(args.fastq_r1)
        if args.fastq_r2:
            check_file(args.fastq_r2)
    
        if args.bowtie2_index:
            check_file(args.bowtie2_index+'.1.bt2')
    
        if args.amplicons_file:
            check_file(args.amplicons_file)
    
        if args.gene_annotations:
            check_file(args.gene_annotations)
    
        if args.amplicons_file and not args.bowtie2_index:
            RUNNING_MODE='ONLY_AMPLICONS'
            info('Only the Amplicon description file was provided. The analysis will be perfomed using only the provided amplicons sequences.')
    
        elif args.bowtie2_index and not args.amplicons_file:
            RUNNING_MODE='ONLY_GENOME'
            info('Only the bowtie2 reference genome index file was provided. The analysis will be perfomed using only genomic regions where enough reads align.')
        elif args.bowtie2_index and args.amplicons_file:
            RUNNING_MODE='AMPLICONS_AND_GENOME'
            info('Amplicon description file and bowtie2 reference genome index files provided. The analysis will be perfomed using the reads that are aligned ony to the amplicons provided and not to other genomic regions.')
        else:
            error('Please provide the amplicons description file (-f or --amplicons_file option) or the bowtie2 reference genome index file (-x or --bowtie2_index option) or both.')
            sys.exit(1)
    
    
    
        ####TRIMMING AND MERGING
        get_name_from_fasta=lambda  x: os.path.basename(x).replace('.fastq','').replace('.gz','')
    
        if not args.name:
                 if args.fastq_r2!='':
                         database_id='%s_%s' % (get_name_from_fasta(args.fastq_r1),get_name_from_fasta(args.fastq_r2))
                 else:
                         database_id='%s' % get_name_from_fasta(args.fastq_r1)
    
        else:
                 database_id=args.name
                
    
    
        OUTPUT_DIRECTORY='CRISPRessoPooled_on_%s' % database_id
    
        if args.output_folder:
                 OUTPUT_DIRECTORY=os.path.join(os.path.abspath(args.output_folder),OUTPUT_DIRECTORY)
    
        _jp=lambda filename: os.path.join(OUTPUT_DIRECTORY,filename) #handy function to put a file in the output directory
    
        try:
                 info('Creating Folder %s' % OUTPUT_DIRECTORY)
                 os.makedirs(OUTPUT_DIRECTORY)
                 info('Done!')
        except:
                 warn('Folder %s already exists.' % OUTPUT_DIRECTORY)
    
        log_filename=_jp('CRISPRessoPooled_RUNNING_LOG.txt')
        logging.getLogger().addHandler(logging.FileHandler(log_filename))
    
        with open(log_filename,'w+') as outfile:
                  outfile.write('[Command used]:\nCRISPRessoPooled %s\n\n[Execution log]:\n' % ' '.join(sys.argv))
    
        if args.fastq_r2=='': #single end reads
    
             #check if we need to trim
             if not args.trim_sequences:
                 #create a symbolic link
                 symlink_filename=_jp(os.path.basename(args.fastq_r1))
                 force_symlink(os.path.abspath(args.fastq_r1),symlink_filename)
                 output_forward_filename=symlink_filename
             else:
                 output_forward_filename=_jp('reads.trimmed.fq.gz')
                 #Trimming with trimmomatic
                 cmd='java -jar %s SE -phred33 %s  %s %s >>%s 2>&1'\
                 % (get_data('trimmomatic-0.33.jar'),args.fastq_r1,
                    output_forward_filename,
                    args.trimmomatic_options_string.replace('NexteraPE-PE.fa','TruSeq3-SE.fa'),
                    log_filename)
                 #print cmd
                 TRIMMOMATIC_STATUS=sb.call(cmd,shell=True)
    
                 if TRIMMOMATIC_STATUS:
                         raise TrimmomaticException('TRIMMOMATIC failed to run, please check the log file.')
    
    
             processed_output_filename=output_forward_filename
    
        else:#paired end reads case
    
             if not args.trim_sequences:
                 output_forward_paired_filename=args.fastq_r1
                 output_reverse_paired_filename=args.fastq_r2
             else:
                 info('Trimming sequences with Trimmomatic...')
                 output_forward_paired_filename=_jp('output_forward_paired.fq.gz')
                 output_forward_unpaired_filename=_jp('output_forward_unpaired.fq.gz')
                 output_reverse_paired_filename=_jp('output_reverse_paired.fq.gz')
                 output_reverse_unpaired_filename=_jp('output_reverse_unpaired.fq.gz')
    
                 #Trimming with trimmomatic
                 cmd='java -jar %s PE -phred33 %s  %s %s  %s  %s  %s %s >>%s 2>&1'\
                 % (get_data('trimmomatic-0.33.jar'),
                         args.fastq_r1,args.fastq_r2,output_forward_paired_filename,
                         output_forward_unpaired_filename,output_reverse_paired_filename,
                         output_reverse_unpaired_filename,args.trimmomatic_options_string,log_filename)
                 #print cmd
                 TRIMMOMATIC_STATUS=sb.call(cmd,shell=True)
                 if TRIMMOMATIC_STATUS:
                         raise TrimmomaticException('TRIMMOMATIC failed to run, please check the log file.')
    
                 info('Done!')
    
    
             #Merging with Flash
             info('Merging paired sequences with Flash...')
             cmd='flash %s %s --min-overlap %d --max-overlap %d -z -d %s >>%s 2>&1' %\
             (output_forward_paired_filename,
              output_reverse_paired_filename,
              args.min_paired_end_reads_overlap,
              args.max_paired_end_reads_overlap,
              OUTPUT_DIRECTORY,log_filename)
    
             FLASH_STATUS=sb.call(cmd,shell=True)
             if FLASH_STATUS:
                 raise FlashException('Flash failed to run, please check the log file.')
    
             info('Done!')
    
             flash_hist_filename=_jp('out.hist')
             flash_histogram_filename=_jp('out.histogram')
             flash_not_combined_1_filename=_jp('out.notCombined_1.fastq.gz')
             flash_not_combined_2_filename=_jp('out.notCombined_2.fastq.gz')
    
             processed_output_filename=_jp('out.extendedFrags.fastq.gz')
    
    
        #count reads 
        N_READS_INPUT=get_n_reads_fastq(args.fastq_r1)
        N_READS_AFTER_PREPROCESSING=get_n_reads_fastq(processed_output_filename)
    
            
        #load gene annotation
        if args.gene_annotations:
            info('Loading gene coordinates from annotation file: %s...' % args.gene_annotations)
            try:
                df_genes=pd.read_table(args.gene_annotations,compression='gzip')
                df_genes.txEnd=df_genes.txEnd.astype(int)
                df_genes.txStart=df_genes.txStart.astype(int)
                df_genes.head()
            except:
               info('Failed to load the gene annotations file.')
        
    
        if RUNNING_MODE=='ONLY_AMPLICONS' or  RUNNING_MODE=='AMPLICONS_AND_GENOME':
    
            #load and validate template file
            df_template=pd.read_csv(args.amplicons_file,names=[
                    'Name','Amplicon_Sequence','sgRNA',
                    'Expected_HDR','Coding_sequence'],comment='#',sep='\t',dtype={'Name':str})
    
    
            #remove empty amplicons/lines
            df_template.dropna(subset=['Amplicon_Sequence'],inplace=True)
            df_template.dropna(subset=['Name'],inplace=True)
    
            df_template.Amplicon_Sequence=df_template.Amplicon_Sequence.apply(capitalize_sequence)
            df_template.Expected_HDR=df_template.Expected_HDR.apply(capitalize_sequence)
            df_template.sgRNA=df_template.sgRNA.apply(capitalize_sequence)
            df_template.Coding_sequence=df_template.Coding_sequence.apply(capitalize_sequence)
    
            if not len(df_template.Amplicon_Sequence.unique())==df_template.shape[0]:
                raise Exception('The amplicons should be all distinct!')
    
            if not len(df_template.Name.unique())==df_template.shape[0]:
                raise Exception('The amplicon names should be all distinct!')
    
            df_template=df_template.set_index('Name')
            df_template.index=df_template.index.to_series().str.replace(' ','_')
    
            for idx,row in df_template.iterrows():
    
                wrong_nt=find_wrong_nt(row.Amplicon_Sequence)
                if wrong_nt:
                     raise NTException('The amplicon sequence %s contains wrong characters:%s' % (idx,' '.join(wrong_nt)))
    
                if not pd.isnull(row.sgRNA):
                    
                    cut_points=[]
    
                    for current_guide_seq in row.sgRNA.strip().upper().split(','):
                    
                        wrong_nt=find_wrong_nt(current_guide_seq)
                        if wrong_nt:
                            raise NTException('The sgRNA sequence %s contains wrong characters:%s'  % (current_guide_seq, ' '.join(wrong_nt)))
                    
                        offset_fw=args.cleavage_offset+len(current_guide_seq)-1
                        offset_rc=(-args.cleavage_offset)-1
                        cut_points+=[m.start() + offset_fw for \
                                    m in re.finditer(current_guide_seq,  row.Amplicon_Sequence)]+[m.start() + offset_rc for m in re.finditer(reverse_complement(current_guide_seq),  row.Amplicon_Sequence)]
                    
                    if not cut_points:
                        warn('\nThe guide sequence/s provided: %s is(are) not present in the amplicon sequence:%s! \nNOTE: The guide will be ignored for the analysis. Please check your input!' % (row.sgRNA,row.Amplicon_Sequence))
                        df_template.ix[idx,'sgRNA']=''
                        
                        
    
        if RUNNING_MODE=='ONLY_AMPLICONS':
            #create a fasta file with all the amplicons
            amplicon_fa_filename=_jp('AMPLICONS.fa')
            fastq_gz_amplicon_filenames=[]
            with open(amplicon_fa_filename,'w+') as outfile:
                for idx,row in df_template.iterrows():
                    if row['Amplicon_Sequence']:
                        outfile.write('>%s\n%s\n' %(clean_filename('AMPL_'+idx),row['Amplicon_Sequence']))
    
                        #create place-holder fastq files
                        fastq_gz_amplicon_filenames.append(_jp('%s.fastq.gz' % clean_filename('AMPL_'+idx)))
                        open(fastq_gz_amplicon_filenames[-1], 'w+').close()
    
            df_template['Demultiplexed_fastq.gz_filename']=fastq_gz_amplicon_filenames
            info('Creating a custom index file with all the amplicons...')
            custom_index_filename=_jp('CUSTOM_BOWTIE2_INDEX')
            sb.call('bowtie2-build %s %s >>%s 2>&1' %(amplicon_fa_filename,custom_index_filename,log_filename), shell=True)
    
    
            #align the file to the amplicons (MODE 1)
            info('Align reads to the amplicons...')
            bam_filename_amplicons= _jp('CRISPResso_AMPLICONS_ALIGNED.bam')
            aligner_command= 'bowtie2 -x %s -p %s -k 1 --end-to-end -N 0 --np 0 -U %s 2>>%s | samtools view -bS - > %s' %(custom_index_filename,args.n_processes,processed_output_filename,log_filename,bam_filename_amplicons)
    
            sb.call(aligner_command,shell=True)
    
            N_READS_ALIGNED=get_n_aligned_bam(bam_filename_amplicons)
            
            s1=r"samtools view -F 4 %s 2>>%s | grep -v ^'@'" % (bam_filename_amplicons,log_filename)
            s2=r'''|awk '{ gzip_filename=sprintf("gzip >> OUTPUTPATH%s.fastq.gz",$3);\
            print "@"$1"\n"$10"\n+\n"$11  | gzip_filename;}' '''
    
            cmd=s1+s2.replace('OUTPUTPATH',_jp(''))
            sb.call(cmd,shell=True)
            
            info('Demultiplex reads and run CRISPResso on each amplicon...')
            n_reads_aligned_amplicons=[]
            for idx,row in df_template.iterrows():
                info('\n Processing:%s' %idx)
                n_reads_aligned_amplicons.append(get_n_reads_fastq(row['Demultiplexed_fastq.gz_filename']))
                crispresso_cmd='CRISPResso -r1 %s -a %s -o %s --name %s' % (row['Demultiplexed_fastq.gz_filename'],row['Amplicon_Sequence'],OUTPUT_DIRECTORY,idx)
    
                if n_reads_aligned_amplicons[-1]>args.min_reads_to_use_region:
                    if row['sgRNA'] and not pd.isnull(row['sgRNA']):
                        crispresso_cmd+=' -g %s' % row['sgRNA']
    
                    if row['Expected_HDR'] and not pd.isnull(row['Expected_HDR']):
                        crispresso_cmd+=' -e %s' % row['Expected_HDR']
    
                    if row['Coding_sequence'] and not pd.isnull(row['Coding_sequence']):
                        crispresso_cmd+=' -c %s' % row['Coding_sequence']
                    
                    crispresso_cmd=propagate_options(crispresso_cmd,crispresso_options,args)
                    info('Running CRISPResso:%s' % crispresso_cmd)
                    sb.call(crispresso_cmd,shell=True)
                else:
                    warn('Skipping amplicon [%s] since no reads are aligning to it\n'% idx)
    
            df_template['n_reads']=n_reads_aligned_amplicons
            df_template['n_reads_aligned_%']=df_template['n_reads']/float(N_READS_ALIGNED)*100
            df_template.fillna('NA').to_csv(_jp('REPORT_READS_ALIGNED_TO_AMPLICONS.txt'),sep='\t')
    
    
    
        if RUNNING_MODE=='AMPLICONS_AND_GENOME':
            print 'Mapping amplicons to the reference genome...'
            #find the locations of the amplicons on the genome and their strand and check if there are mutations in the reference genome
            additional_columns=[]
            for idx,row in df_template.iterrows():
                fields_to_append=list(np.take(get_align_sequence(row.Amplicon_Sequence, args.bowtie2_index).split('\t'),[0,1,2,3,5]))
                if fields_to_append[0]=='*':
                    info('The amplicon [%s] is not mappable to the reference genome provided!' % idx )
                    additional_columns.append([idx,'NOT_ALIGNED',0,-1,'+',''])
                else:
                    additional_columns.append([idx]+fields_to_append)
                    info('The amplicon [%s] was mapped to: %s ' % (idx,' '.join(fields_to_append[:3]) ))
        
        
            df_template=df_template.join(pd.DataFrame(additional_columns,columns=['Name','chr_id','bpstart','bpend','strand','Reference_Sequence']).set_index('Name'))
            
            df_template.bpstart=df_template.bpstart.astype(int)
            df_template.bpend=df_template.bpend.astype(int)
            
            #Check reference is the same otherwise throw a warning
            for idx,row in df_template.iterrows():
                if row.Amplicon_Sequence != row.Reference_Sequence and row.Amplicon_Sequence != reverse_complement(row.Reference_Sequence):
                    warn('The amplicon sequence %s provided:\n%s\n\nis different from the reference sequence(both strand):\n\n%s\n\n%s\n' %(row.name,row.Amplicon_Sequence,row.Amplicon_Sequence,reverse_complement(row.Amplicon_Sequence)))
     
    
        if RUNNING_MODE=='ONLY_GENOME' or RUNNING_MODE=='AMPLICONS_AND_GENOME':
    
            ###HERE we recreate the uncompressed genome file if not available###
    
            #check you have all the files for the genome and create a fa idx for samtools
            
            uncompressed_reference=args.bowtie2_index+'.fa'
            
            #if not os.path.exists(GENOME_LOCAL_FOLDER):
            #    os.mkdir(GENOME_LOCAL_FOLDER)
    
            if os.path.exists(uncompressed_reference):
                info('The uncompressed reference fasta file for %s is already present! Skipping generation.' % args.bowtie2_index)
            else:
                #uncompressed_reference=os.path.join(GENOME_LOCAL_FOLDER,'UNCOMPRESSED_REFERENCE_FROM_'+args.bowtie2_index.replace('/','_')+'.fa')
                info('Extracting uncompressed reference from the provided bowtie2 index since it is not available... Please be patient!')
    
                cmd_to_uncompress='bowtie2-inspect %s > %s 2>>%s' % (args.bowtie2_index,uncompressed_reference,log_filename)
                sb.call(cmd_to_uncompress,shell=True)
    
                info('Indexing fasta file with samtools...')
                #!samtools faidx {uncompressed_reference}
                sb.call('samtools faidx %s 2>>%s ' % (uncompressed_reference,log_filename),shell=True)
    
    
        #####CORRECT ONE####
        #align in unbiased way the reads to the genome
        if RUNNING_MODE=='ONLY_GENOME' or RUNNING_MODE=='AMPLICONS_AND_GENOME':
            info('Aligning reads to the provided genome index...')
            bam_filename_genome = _jp('%s_GENOME_ALIGNED.bam' % database_id)
            aligner_command= 'bowtie2 -x %s -p %s -k 1 --end-to-end -N 0 --np 0 -U %s 2>>%s| samtools view -bS - > %s' %(args.bowtie2_index,args.n_processes,processed_output_filename,log_filename,bam_filename_genome)
            sb.call(aligner_command,shell=True)
            
            N_READS_ALIGNED=get_n_aligned_bam(bam_filename_genome)
            
            #REDISCOVER LOCATIONS and DEMULTIPLEX READS
            MAPPED_REGIONS=_jp('MAPPED_REGIONS/')
            if not os.path.exists(MAPPED_REGIONS):
                os.mkdir(MAPPED_REGIONS)
    
            s1=r'''samtools view -F 0x0004 %s 2>>%s |''' % (bam_filename_genome,log_filename)+\
            r'''awk '{OFS="\t"; bpstart=$4;  bpend=bpstart; split ($6,a,"[MIDNSHP]"); n=0;\
            for (i=1; i in a; i++){\
                n+=1+length(a[i]);\
                if (substr($6,n,1)=="S"){\
                    if (bpend==$4)\
                        bpstart-=a[i];\
                    else
                        bpend+=a[i];
                    }\
                else if( (substr($6,n,1)!="I")  && (substr($6,n,1)!="H") )\
                        bpend+=a[i];\
                }\
                if ( ($2 % 32)>=16)\
                    print $3,bpstart,bpend,"-",$1,$10,$11;\
                else\
                    print $3,bpstart,bpend,"+",$1,$10,$11;}' | ''' 
        
            s2=r'''  sort -k1,1 -k2,2n  | awk \
            'BEGIN{chr_id="NA";bpstart=-1;bpend=-1; fastq_filename="NA"}\
            { if ( (chr_id!=$1) || (bpstart!=$2) || (bpend!=$3) )\
                {\
                if (fastq_filename!="NA") {close(fastq_filename); system("gzip "fastq_filename)}\
                chr_id=$1; bpstart=$2; bpend=$3;\
                fastq_filename=sprintf("__OUTPUTPATH__REGION_%s_%s_%s.fastq",$1,$2,$3);\
                }\
            print "@"$5"\n"$6"\n+\n"$7 >> fastq_filename;\
            }' '''
            cmd=s1+s2.replace('__OUTPUTPATH__',MAPPED_REGIONS)
            
            info('Demultiplexing reads by location...')
            sb.call(cmd,shell=True)
            
            #gzip the missing ones 
            sb.call('gzip %s/*.fastq' % MAPPED_REGIONS,shell=True)
    
        '''
        The most common use case, where many different target sites are pooled into a single 
        high-throughput sequencing library for quantification, is not directly addressed by this implementation. 
        Potential users of CRISPResso would need to write their own code to generate separate input files for processing. 
        Importantly, this preprocessing code would need to remove any PCR amplification artifacts 
        (such as amplification of sequences from a gene and a highly similar pseudogene ) 
        which may confound the interpretation of results. 
        This can be done by mapping of input sequences to a reference genome and removing 
        those that do not map to the expected genomic location, but is non-trivial for an end-user to implement.
        '''
        
    
        
        if RUNNING_MODE=='AMPLICONS_AND_GENOME':
            files_to_match=glob.glob(os.path.join(MAPPED_REGIONS,'REGION*'))
            n_reads_aligned_genome=[]
            fastq_region_filenames=[]
        
            for idx,row in df_template.iterrows():
                
                info('Processing amplicon:%s' % idx )
        
                #check if we have reads
                fastq_filename_region=os.path.join(MAPPED_REGIONS,'REGION_%s_%s_%s.fastq.gz' % (row['chr_id'],row['bpstart'],row['bpend']))
        
                if os.path.exists(fastq_filename_region):
                    
                    N_READS=get_n_reads_fastq(fastq_filename_region)
                    n_reads_aligned_genome.append(N_READS)
                    fastq_region_filenames.append(fastq_filename_region)
                    files_to_match.remove(fastq_filename_region)
                    if N_READS>=args.min_reads_to_use_region:
                        info('\nThe amplicon [%s] has enough reads (%d) mapped to it! Running CRISPResso!\n' % (idx,N_READS))
        
                        crispresso_cmd='CRISPResso -r1 %s -a %s -o %s --name %s' % (fastq_filename_region,row['Amplicon_Sequence'],OUTPUT_DIRECTORY,idx)
        
                        if row['sgRNA'] and not pd.isnull(row['sgRNA']):
                            crispresso_cmd+=' -g %s' % row['sgRNA']
        
                        if row['Expected_HDR'] and not pd.isnull(row['Expected_HDR']):
                            crispresso_cmd+=' -e %s' % row['Expected_HDR']
        
                        if row['Coding_sequence'] and not pd.isnull(row['Coding_sequence']):
                            crispresso_cmd+=' -c %s' % row['Coding_sequence']
                        
                        crispresso_cmd=propagate_options(crispresso_cmd,crispresso_options,args)
                        info('Running CRISPResso:%s' % crispresso_cmd)
                        sb.call(crispresso_cmd,shell=True)
         
                    else:
                         warn('The amplicon [%s] has not enough reads (%d) mapped to it! Skipping the execution of CRISPResso!' % (idx,N_READS))
                else:
                    fastq_region_filenames.append('')
                    n_reads_aligned_genome.append(0)
                    warn("The amplicon %s doesn't have any read mapped to it!\n Please check your amplicon sequence." %  idx)
        
            df_template['Amplicon_Specific_fastq.gz_filename']=fastq_region_filenames
            df_template['n_reads']=n_reads_aligned_genome
            df_template['n_reads_aligned_%']=df_template['n_reads']/float(N_READS_ALIGNED)*100
            
            if args.gene_annotations:
                df_template=df_template.apply(lambda row: find_overlapping_genes(row, df_genes),axis=1)
            
            df_template.fillna('NA').to_csv(_jp('REPORT_READS_ALIGNED_TO_GENOME_AND_AMPLICONS.txt'),sep='\t')
            
            #write another file with the not amplicon regions
            
            info('Reporting problematic regions...')  
            coordinates=[]
            for region in files_to_match:
                coordinates.append(os.path.basename(region).replace('.fastq.gz','').replace('.fastq','').split('_')[1:4]+[region,get_n_reads_fastq(region)])
        
            df_regions=pd.DataFrame(coordinates,columns=['chr_id','bpstart','bpend','fastq_file','n_reads'])
    
            df_regions=df_regions.convert_objects(convert_numeric=True)
            df_regions.dropna(inplace=True) #remove regions in chrUn
            df_regions.bpstart=df_regions.bpstart.astype(int)
            df_regions.bpend=df_regions.bpend.astype(int)
    
            df_regions['n_reads_aligned_%']=df_regions['n_reads']/float(N_READS_ALIGNED)*100
    
            df_regions['Reference_sequence']=df_regions.apply(lambda row: get_region_from_fa(row.chr_id,row.bpstart,row.bpend,uncompressed_reference),axis=1)
    
            
            if args.gene_annotations:
                info('Checking overlapping genes...')   
                df_regions=df_regions.apply(lambda row: find_overlapping_genes(row, df_genes),axis=1)
            
            if np.sum(np.array(map(int,pd.__version__.split('.')))*(100,10,1))< 170:
                df_regions.sort('n_reads',ascending=False,inplace=True)
            else:
                df_regions.sort_values(by='n_reads',ascending=False,inplace=True)


            df_regions.fillna('NA').to_csv(_jp('REPORTS_READS_ALIGNED_TO_GENOME_NOT_MATCHING_AMPLICONS.txt'),sep='\t',index=None)
    
    
        if RUNNING_MODE=='ONLY_GENOME' :
            #Load regions and build REFERENCE TABLES 
            info('Parsing the demultiplexed files and extracting locations and reference sequences...')
            coordinates=[]
            for region in glob.glob(os.path.join(MAPPED_REGIONS,'REGION*.fastq.gz')):
                coordinates.append(os.path.basename(region).replace('.fastq.gz','').split('_')[1:4]+[region,get_n_reads_fastq(region)])
            
            print 'C:',coordinates
            df_regions=pd.DataFrame(coordinates,columns=['chr_id','bpstart','bpend','fastq_file','n_reads'])
            
            print 'D:', df_regions
            df_regions=df_regions.convert_objects(convert_numeric=True)
            df_regions.dropna(inplace=True) #remove regions in chrUn
            df_regions.bpstart=df_regions.bpstart.astype(int)
            df_regions.bpend=df_regions.bpend.astype(int)
            print df_regions
            df_regions['sequence']=df_regions.apply(lambda row: get_region_from_fa(row.chr_id,row.bpstart,row.bpend,uncompressed_reference),axis=1)
    
            df_regions['n_reads_aligned_%']=df_regions['n_reads']/float(N_READS_ALIGNED)*100
                 
            if args.gene_annotations:
                info('Checking overlapping genes...')   
                df_regions=df_regions.apply(lambda row: find_overlapping_genes(row, df_genes),axis=1)
            
            if np.sum(np.array(map(int,pd.__version__.split('.')))*(100,10,1))< 170:
                df_regions.sort('n_reads',ascending=False,inplace=True)
            else:
                df_regions.sort_values(by='n_reads',ascending=False,inplace=True)


            df_regions.fillna('NA').to_csv(_jp('REPORT_READS_ALIGNED_TO_GENOME_ONLY.txt'),sep='\t',index=None)
            
            
            #run CRISPResso
            #demultiplex reads in the amplicons and call crispresso!
            info('Running CRISPResso on the regions discovered...')
            for idx,row in df_regions.iterrows():
        
                if row.n_reads > args.min_reads_to_use_region:
                    info('\nRunning CRISPResso on: %s-%d-%d...'%(row.chr_id,row.bpstart,row.bpend ))
                    crispresso_cmd='CRISPResso -r1 %s -a %s -o %s' %(row.fastq_file,row.sequence,OUTPUT_DIRECTORY)  
                    crispresso_cmd=propagate_options(crispresso_cmd,crispresso_options,args)
                    info('Running CRISPResso:%s' % crispresso_cmd)
                    sb.call(crispresso_cmd,shell=True)
                else:
                    info('Skipping region: %s-%d-%d , not enough reads (%d)' %(row.chr_id,row.bpstart,row.bpend, row.n_reads))
    
    
        #write alignment statistics
        with open(_jp('MAPPING_STATISTICS.txt'),'w+') as outfile:
            outfile.write('READS IN INPUTS:%d\nREADS AFTER PREPROCESSING:%d\nREADS ALIGNED:%d' % (N_READS_INPUT,N_READS_AFTER_PREPROCESSING,N_READS_ALIGNED))
    
        #write a file with basic quantification info for each sample
        def check_output_folder(output_folder):
            # Locate the per-sample CRISPResso quantification report.
            # Returns the full path to 'Quantification_of_editing_frequency.txt'
            # inside *output_folder* when the file exists, otherwise None
            # (e.g. when the run was skipped for lack of reads).
            candidate = os.path.join(output_folder, 'Quantification_of_editing_frequency.txt')
            return candidate if os.path.exists(candidate) else None
    
        def parse_quantification(quantification_file):
            # Parse CRISPResso's fixed-layout editing-frequency report.
            # Layout (one value per line, in order): a header line, then
            # Unmodified, NHEJ, HDR, Mixed HDR-NHEJ counts, a separator line,
            # and finally the 'Total Aligned:<n> reads' line.
            # Returns the tuple of floats:
            # (unmodified, nhej, hdr, mixed_hdr_nhej, total_aligned).
            count_patterns = ("Unmodified:(\d+)", "NHEJ:(\d+)",
                              "HDR:(\d+)", "Mixed HDR-NHEJ:(\d+)")
            with open(quantification_file) as report:
                report.readline()  # skip header line
                counts = [float(re.findall(pat, report.readline())[0])
                          for pat in count_patterns]
                report.readline()  # skip separator line before the total
                total = float(re.findall("Total Aligned:(\d+) reads",
                                         report.readline())[0])
            return counts[0], counts[1], counts[2], counts[3], total
    
        quantification_summary=[]
    
        if RUNNING_MODE=='ONLY_AMPLICONS' or RUNNING_MODE=='AMPLICONS_AND_GENOME':
            df_final_data=df_template
        else:
            df_final_data=df_regions
    
        for idx,row in df_final_data.iterrows():
    
                if RUNNING_MODE=='ONLY_AMPLICONS' or RUNNING_MODE=='AMPLICONS_AND_GENOME':
                    folder_name='CRISPResso_on_%s' % idx
                else:
                    folder_name='CRISPResso_on_REGION_%s_%d_%d' %(row.chr_id,row.bpstart,row.bpend )
    
                quantification_file=check_output_folder(_jp(folder_name))
    
                if quantification_file:
                    N_UNMODIFIED,N_MODIFIED,N_REPAIRED,N_MIXED_HDR_NHEJ,N_TOTAL=parse_quantification(quantification_file)
                    quantification_summary.append([idx,N_UNMODIFIED/N_TOTAL*100,N_MODIFIED/N_TOTAL*100,N_REPAIRED/N_TOTAL*100,N_MIXED_HDR_NHEJ/N_TOTAL*100,N_TOTAL,row.n_reads])
                else:
                    quantification_summary.append([idx,np.nan,np.nan,np.nan,np.nan,np.nan,row.n_reads])
                    warn('Skipping the folder %s, not enough reads or empty folder.'% folder_name)
    
    
        df_summary_quantification=pd.DataFrame(quantification_summary,columns=['Name','Unmodified%','NHEJ%','HDR%', 'Mixed_HDR-NHEJ%','Reads_aligned','Reads_total'])        
        df_summary_quantification.fillna('NA').to_csv(_jp('SAMPLES_QUANTIFICATION_SUMMARY.txt'),sep='\t',index=None)        
                    
        #cleaning up
        if not args.keep_intermediate:
             info('Removing Intermediate files...')
        
             if args.fastq_r2!='':
                 files_to_remove=[processed_output_filename,flash_hist_filename,flash_histogram_filename,\
                              flash_not_combined_1_filename,flash_not_combined_2_filename] 
             else:
                 files_to_remove=[processed_output_filename] 
        
             if args.trim_sequences and args.fastq_r2!='':
                 files_to_remove+=[output_forward_paired_filename,output_reverse_paired_filename,\
                                                   output_forward_unpaired_filename,output_reverse_unpaired_filename]
        
             if RUNNING_MODE=='ONLY_GENOME' or RUNNING_MODE=='AMPLICONS_AND_GENOME':
                     files_to_remove+=[bam_filename_genome]
                 
             if RUNNING_MODE=='ONLY_AMPLICONS':  
                files_to_remove+=[bam_filename_amplicons,amplicon_fa_filename]
                for bowtie2_file in glob.glob(_jp('CUSTOM_BOWTIE2_INDEX.*')):
                    files_to_remove.append(bowtie2_file)
        
             for file_to_remove in files_to_remove:
                 try:
                         if os.path.islink(file_to_remove):
                             #print 'LINK',file_to_remove
                             os.unlink(file_to_remove)
                         else:                             
                             os.remove(file_to_remove)
                 except:
                         warn('Skipping:%s' %file_to_remove)
    
    
           
        info('All Done!')
        print r'''
              )                                            )
             (           _______________________          (
            __)__       | __  __  __     __ __  |        __)__
         C\|     \      ||__)/  \/  \|  |_ |  \ |     C\|     \
           \     /      ||   \__/\__/|__|__|__/ |       \     /
            \___/       |_______________________|        \___/
        '''
        sys.exit(0)
    
    except Exception as e:
        error('\n\nERROR: %s' % e)
        sys.exit(-1)

Example 25

Project: Haystack
Source File: haystack_hotspots_CORE.py
View license
def main():

    print '\n[H A Y S T A C K   H O T S P O T]'
    print('\n-SELECTION OF VARIABLE REGIONS- [Luca Pinello - [email protected]]\n')
    print 'Version %s\n' % HAYSTACK_VERSION
    
    
    if which('samtools') is None:
            error('Haystack requires samtools free available at: http://sourceforge.net/projects/samtools/files/samtools/0.1.19/')
            sys.exit(1)
    
    if which('bedtools') is None:
            error('Haystack requires bedtools free available at: https://github.com/arq5x/bedtools2/releases/tag/v2.20.1')
            sys.exit(1)
    
    if which('bedGraphToBigWig') is None:
            info('To generate the bigwig files Haystack requires bedGraphToBigWig please download from here: http://hgdownload.cse.ucsc.edu/admin/exe/ and add to your PATH')
    
    #mandatory
    parser = argparse.ArgumentParser(description='HAYSTACK Parameters')
    parser.add_argument('samples_filename_or_bam_folder', type=str,  help='A tab delimeted file with in each row (1) a sample name, (2) the path to the corresponding bam filename. Alternatively it is possible to specify a folder containing some .bam files to analyze.' )
    parser.add_argument('genome_name', type=str,  help='Genome assembly to use from UCSC (for example hg19, mm9, etc.)')
    
    #optional
    parser.add_argument('--bin_size', type=int,help='bin size to use(default: 500bp)',default=500)
    parser.add_argument('--disable_quantile_normalization',help='Disable quantile normalization (default: False)',action='store_true')
    parser.add_argument('--th_rpm',type=float,help='Percentile on the signal intensity to consider for the hotspots (default: 99)', default=99)
    parser.add_argument('--transformation',type=str,help='Variance stabilizing transformation among: none, log2, angle (default: angle)',default='angle',choices=['angle', 'log2', 'none'])
    parser.add_argument('--recompute_all',help='Ignore any file previously precalculated',action='store_true')
    parser.add_argument('--z_score_high', type=float,help='z-score value to select the specific regions(default: 1.5)',default=1.5)
    parser.add_argument('--z_score_low', type=float,help='z-score value to select the not specific regions(default: 0.25)',default=0.25)
    parser.add_argument('--name',  help='Define a custom output filename for the report', default='')
    parser.add_argument('--output_directory',type=str, help='Output directory (default: current directory)',default='')
    parser.add_argument('--use_X_Y', help='Force to process the X and Y chromosomes (default: not processed)',action='store_true')
    parser.add_argument('--max_regions_percentage', type=float , help='Upper bound on the %% of the regions selected  (deafult: 0.1, 0.0=0%% 1.0=100%%)' , default=0.1)
    parser.add_argument('--depleted', help='Look for cell type specific regions with depletion of signal instead of enrichment',action='store_true')
    parser.add_argument('--input_is_bigwig', help='Use the bigwig format instead of the bam format for the input. Note: The files must have extension .bw',action='store_true')
    parser.add_argument('--version',help='Print version and exit.',action='version', version='Version %s' % HAYSTACK_VERSION)
    args = parser.parse_args()
    
    
    args_dict=vars(args)
    for key,value in args_dict.items():
            exec('%s=%s' %(key,repr(value)))
    
    
    if input_is_bigwig:
            extension_to_check='.bw'
            info('Input is set BigWig (.bw)')
    else:
            extension_to_check='.bam'
            info('Input is set compressed SAM (.bam)')
            
    #check folder or sample filename
    if os.path.isfile(samples_filename_or_bam_folder):
            BAM_FOLDER=False
            bam_filenames=[]
            sample_names=[]
            with open(samples_filename_or_bam_folder) as infile:
                for line in infile:
    
                    if not line.strip():
                            continue
                    
                    if line.startswith('#'): #skip optional header line or empty lines
                            info('Skipping header/comment line:%s' % line)
                            continue
    
                    fields=line.strip().split()
                    n_fields=len(fields)
                    
                    if n_fields==2: 
                        sample_names.append(fields[0])
                        bam_filenames.append(fields[1])
                    else:
                        error('The samples file format is wrong!')
                        sys.exit(1)
    
            
    else:
            if os.path.exists(samples_filename_or_bam_folder):
                    BAM_FOLDER=True
                    bam_filenames=glob.glob(os.path.join(samples_filename_or_bam_folder,'*'+extension_to_check))
    
                    if not bam_filenames:
                        error('No bam/bigwig  files to analyze in %s. Exiting.' % samples_filename_or_bam_folder)
                        sys.exit(1)
                    
                    sample_names=[os.path.basename(bam_filename).replace(extension_to_check,'') for bam_filename in bam_filenames]
            else:
                    error("The file or folder %s doesn't exist. Exiting." % samples_filename_or_bam_folder)
                    sys.exit(1)
                    
            
    #check all the files before starting
    info('Checking samples files location...')
    for bam_filename in bam_filenames:
            check_file(bam_filename)
    
    info('Initializing Genome:%s' %genome_name)
    
    genome_directory=determine_path('genomes')
    genome_2bit=os.path.join(genome_directory,genome_name+'.2bit')
    
    if os.path.exists(genome_2bit):
            genome=Genome_2bit(genome_2bit)
    else:
            info("\nIt seems you don't have the required genome file.")
            if query_yes_no('Should I download it for you?'):
                    sb.call('haystack_download_genome %s' %genome_name,shell=True,env=system_env)
                    if os.path.exists(genome_2bit):
                            info('Genome correctly downloaded!')
                            genome=Genome_2bit(genome_2bit)
                    else:
                            error('Sorry I cannot download the required file for you. Check your Internet connection.')
                            sys.exit(1)
            else:
                    error('Sorry I need the genome file to perform the analysis. Exiting...')
                    sys.exit(1)
    
    chr_len_filename=os.path.join(genome_directory, "%s_chr_lengths.txt" % genome_name)
    check_file(chr_len_filename)
    
    
    if name:
            directory_name='HAYSTACK_HOTSPOTS_on_%s' % name
    
    else:
            directory_name='HAYSTACK_HOTSPOTS'
    
    
    if output_directory:
            output_directory=os.path.join(output_directory,directory_name)
    else:
            output_directory=directory_name
            
    
    if not os.path.exists(output_directory):
    	os.makedirs(output_directory)    
    
                    
    genome_sorted_bins_file=os.path.join(output_directory,'%s.%dbp.bins.sorted.bed' %(os.path.basename(genome_name),bin_size))
    
    
    tracks_directory=os.path.join(output_directory,'TRACKS')
    if not os.path.exists(tracks_directory):
            os.makedirs(tracks_directory)   
    
    
    intermediate_directory=os.path.join(output_directory,'INTERMEDIATE')
    if not os.path.exists(intermediate_directory):
            os.makedirs(intermediate_directory)   
    
    if not os.path.exists(genome_sorted_bins_file) or recompute_all:
            info('Creating bins of %dbp for %s in %s' %(bin_size,chr_len_filename,genome_sorted_bins_file))
            sb.call('bedtools makewindows -g %s -w %s |  bedtools sort -i stdin |' %  (chr_len_filename, bin_size)+ "perl -nle 'print "+'"$_\t$.";'+"' /dev/stdin> %s" % genome_sorted_bins_file,shell=True,env=system_env)
            
    
    #convert bam files to genome-wide rpm tracks
    for base_name,bam_filename in zip(sample_names,bam_filenames):
    
        info('Processing:%s' %bam_filename)
        
        rpm_filename=os.path.join(tracks_directory,'%s.bedgraph' % base_name)
        sorted_rpm_filename=os.path.join(tracks_directory,'%s_sorted.bedgraph' % base_name)
        mapped_sorted_rpm_filename=os.path.join(tracks_directory,'%s_mapped_sorted.bedgraph' % base_name)
        binned_rpm_filename=os.path.join(intermediate_directory,'%s.%dbp.rpm' % (base_name,bin_size))
        bigwig_filename=os.path.join(tracks_directory,'%s.bw' %base_name)
    
        if  input_is_bigwig and which('bigWigAverageOverBed'):
                        if not os.path.exists(binned_rpm_filename) or recompute_all:
                                cmd='bigWigAverageOverBed %s %s  /dev/stdout | sort -s -n -k 1,1 | cut -f5 > %s' % (bam_filename,genome_sorted_bins_file,binned_rpm_filename)
                                sb.call(cmd,shell=True,env=system_env)
                                shutil.copy2(bam_filename,bigwig_filename)
    
        else:    
                if not os.path.exists(binned_rpm_filename) or recompute_all:
                        info('Computing Scaling Factor...')
                        cmd='samtools view -c -F 512 %s' % bam_filename
                        #print cmd
                        proc=sb.Popen(cmd, stdout=sb.PIPE,shell=True,env=system_env)
                        (stdout, stderr) = proc.communicate()
                        #print stdout,stderr
                        scaling_factor=(1.0/float(stdout.strip()))*1000000
    
                        info('Scaling Factor: %e' %scaling_factor)
    
                        info('Building BedGraph RPM track...')
                        cmd='samtools view -b -F 512 %s | bamToBed | slopBed  -r %s -l 0 -s -i stdin -g %s | genomeCoverageBed -g  %s -i stdin -bg -scale %.32f > %s'  %(bam_filename,bin_size,chr_len_filename,chr_len_filename,scaling_factor,rpm_filename)
                        #print cmd
    
                
                        proc=sb.call(cmd,shell=True,env=system_env)
    
                if which('bedGraphToBigWig'):
                    if not os.path.exists(bigwig_filename) or recompute_all:
                            info('Converting BedGraph to BigWig')
                            cmd='bedGraphToBigWig %s %s %s' %(rpm_filename,chr_len_filename,bigwig_filename)
                            proc=sb.call(cmd,shell=True,env=system_env)
    
                else:
                    info('Sorry I cannot create the bigwig file.\nPlease download and install bedGraphToBigWig from here: http://hgdownload.cse.ucsc.edu/admin/exe/ and add to your PATH')
    
                if not os.path.exists(binned_rpm_filename) or recompute_all:      
                        info('Make constant binned (%dbp) rpm values file' % bin_size)
                        #cmd='bedtools sort -i %s |  bedtools map -a %s -b stdin -c 4 -o mean -null 0.0 | cut -f5 > %s'   %(rpm_filename,genome_sorted_bins_file,binned_rpm_filename)
                        #proc=sb.call(cmd,shell=True,env=system_env)
                        
                        cmd='sort -k1,1 -k2,2n  %s  > %s'   %(rpm_filename,sorted_rpm_filename)
                        proc=sb.call(cmd,shell=True,env=system_env)

                        cmd='bedtools map -a %s -b %s -c 4 -o mean -null 0.0  > %s'   % (genome_sorted_bins_file,sorted_rpm_filename,mapped_sorted_rpm_filename)
                        proc=sb.call(cmd,shell=True,env=system_env)
                        
                        cmd='cut -f5 %s  > %s'   %(mapped_sorted_rpm_filename,binned_rpm_filename)
                        proc=sb.call(cmd,shell=True,env=system_env)

                
                try:    
                        os.remove(rpm_filename)
                        os.remove(sorted_rpm_filename)
                        os.remove(mapped_sorted_rpm_filename)
                except:
                        pass
    
    
    #load coordinates of bins
    coordinates_bin=pd.read_csv(genome_sorted_bins_file,names=['chr_id','bpstart','bpend'],sep='\t',header=None,usecols=[0,1,2])
    N_BINS=coordinates_bin.shape[0]
    if not use_X_Y:
        coordinates_bin=coordinates_bin.ix[(coordinates_bin['chr_id']!='chrX') & (coordinates_bin['chr_id']!='chrY')]  
    
    #load all the tracks
    info('Loading the processed tracks') 
    df_chip={}
    for state_file in  glob.glob(os.path.join(intermediate_directory,'*.rpm')):
            col_name=os.path.basename(state_file).replace('.rpm','')
            df_chip[col_name]=pd.read_csv(state_file,squeeze=True,header=None)
            info('Loading:%s' % col_name)
    
    df_chip=pd.DataFrame(df_chip)
    
    if disable_quantile_normalization:
            info('Skipping quantile normalization...')
    else:
            info('Normalizing the data...')
            df_chip=pd.DataFrame(quantile_normalization(df_chip.values),columns=df_chip.columns,index=df_chip.index)
    
    
    if which('bedGraphToBigWig'):        
            #write quantile normalized tracks
            coord_quantile=coordinates_bin.copy()
            for col in df_chip:
    
                if disable_quantile_normalization:
                        normalized_output_filename=os.path.join(tracks_directory,'%s.bedgraph' % os.path.basename(col))
                else:
                        normalized_output_filename=os.path.join(tracks_directory,'%s_quantile_normalized.bedgraph' % os.path.basename(col))
                        
                normalized_output_filename_bigwig=normalized_output_filename.replace('.bedgraph','.bw')
      
                if not os.path.exists(normalized_output_filename_bigwig) or recompute_all:         
                        info('Writing binned track: %s' % normalized_output_filename_bigwig )    
                        coord_quantile['rpm_normalized']=df_chip.ix[:,col]
                        coord_quantile.dropna().to_csv(normalized_output_filename,sep='\t',header=False,index=False)
                
                        cmd='bedGraphToBigWig %s %s %s' %(normalized_output_filename,chr_len_filename,normalized_output_filename_bigwig)
                        proc=sb.call(cmd,shell=True,env=system_env)
                        try:
                                os.remove(normalized_output_filename)
                        except:
                                pass
    else:
            info('Sorry I cannot creat the bigwig file.\nPlease download and install bedGraphToBigWig from here: http://hgdownload.cse.ucsc.edu/admin/exe/ and add to your PATH')
         
            
    #th_rpm=np.min(df_chip.apply(lambda x: np.percentile(x,th_rpm)))
    th_rpm=find_th_rpm(df_chip,th_rpm)
    info('Estimated th_rpm:%s' % th_rpm)
    
    df_chip_not_empty=df_chip.ix[(df_chip>th_rpm).any(1),:]
    

    
    if transformation=='log2':
            df_chip_not_empty=df_chip_not_empty.applymap(log2_transform)
            info('Using log2 transformation')
    
    elif transformation =='angle':     
            df_chip_not_empty=df_chip_not_empty.applymap(angle_transform )
            info('Using angle transformation')
    
    else:
            info('Using no transformation')
            
    iod_values=df_chip_not_empty.var(1)/df_chip_not_empty.mean(1)
    
    ####calculate the inflation point a la superenhancers
    scores=iod_values
    min_s=np.min(scores)
    max_s=np.max(scores)
    
    N_POINTS=len(scores)
    x=np.linspace(0,1,N_POINTS)
    y=sorted((scores-min_s)/(max_s-min_s))
    m=smooth((np.diff(y)/np.diff(x)),50)
    m=m-1
    m[m<=0]=np.inf
    m[:int(len(m)*(1-max_regions_percentage))]=np.inf
    idx_th=np.argmin(m)+1
    
    #print idx_th,
    th_iod=sorted(iod_values)[idx_th]
    #print th_iod
    
    
    hpr_idxs=iod_values>th_iod
    #print len(iod_values),len(hpr_idxs),sum(hpr_idxs), sum(hpr_idxs)/float(len(hpr_idxs)),
    
    info('Selected %f%% regions (%d)' %( sum(hpr_idxs)/float(len(hpr_idxs))*100, sum(hpr_idxs)))
    coordinates_bin['iod']=iod_values
    
    #we remove the regions "without" signal in any of the cell types
    coordinates_bin.dropna(inplace=True)
    
    
    #create a track for IGV
    bedgraph_iod_track_filename=os.path.join(tracks_directory,'VARIABILITY.bedgraph')
    bw_iod_track_filename=os.path.join(tracks_directory,'VARIABILITY.bw')
    
    if not os.path.exists(bw_iod_track_filename) or recompute_all:   
    
            info('Generating variability track in bigwig format in:%s' % bw_iod_track_filename)
    
            coordinates_bin.to_csv(bedgraph_iod_track_filename,sep='\t',header=False,index=False)
            sb.call('bedGraphToBigWig %s %s %s' % (bedgraph_iod_track_filename,chr_len_filename,bw_iod_track_filename ),shell=True,env=system_env)
            try:
                    os.remove(bedgraph_iod_track_filename)
            except:
                    pass
    
    
    #Write the HPRs
    bedgraph_hpr_filename=os.path.join(tracks_directory,'SELECTED_VARIABILITY_HOTSPOT.bedgraph')
    
    to_write=coordinates_bin.ix[hpr_idxs[hpr_idxs].index]
    to_write.dropna(inplace=True)
    to_write['bpstart']=to_write['bpstart'].astype(int)
    to_write['bpend']=to_write['bpend'].astype(int)
    
    to_write.to_csv(bedgraph_hpr_filename,sep='\t',header=False,index=False)
    
    bed_hpr_fileaname=os.path.join(output_directory,'SELECTED_VARIABILITY_HOTSPOT.bed')
    
    if not os.path.exists(bed_hpr_fileaname) or recompute_all:  
            info('Writing the HPRs in: %s' % bed_hpr_fileaname)
            sb.call('sort -k1,1 -k2,2n %s | bedtools merge -i stdin >  %s' %(bedgraph_hpr_filename,bed_hpr_fileaname),shell=True,env=system_env)
    
    #os.remove(bedgraph_hpr_filename)
    
    df_chip_hpr=df_chip_not_empty.ix[hpr_idxs,:]
    df_chip_hpr_zscore=df_chip_hpr.apply(zscore,axis=1)
    
    
    specific_regions_directory=os.path.join(output_directory,'SPECIFIC_REGIONS')
    if not os.path.exists(specific_regions_directory):
            os.makedirs(specific_regions_directory)   
    
    if depleted:
            z_score_high=-z_score_high
            z_score_low=-z_score_low
    
    
    #write target
    info('Writing Specific Regions for each cell line...')
    coord_zscore=coordinates_bin.copy()
    for col in df_chip_hpr_zscore:
    
            regions_specific_filename='Regions_specific_for_%s_z_%.2f.bedgraph' % (os.path.basename(col).replace('.rpm',''),z_score_high)
            specific_output_filename=os.path.join(specific_regions_directory,regions_specific_filename)
            specific_output_bed_filename=specific_output_filename.replace('.bedgraph','.bed')
    
            if not os.path.exists(specific_output_bed_filename) or recompute_all:  
                    if depleted:
                            coord_zscore['z-score']=df_chip_hpr_zscore.ix[df_chip_hpr_zscore.ix[:,col]<z_score_high,col]
                    else:
                            coord_zscore['z-score']=df_chip_hpr_zscore.ix[df_chip_hpr_zscore.ix[:,col]>z_score_high,col]
                    coord_zscore.dropna().to_csv(specific_output_filename,sep='\t',header=False,index=False)
    
                    info('Writing:%s' % specific_output_bed_filename )
                    sb.call('sort -k1,1 -k2,2n %s | bedtools merge -i stdin >  %s' %(specific_output_filename,specific_output_bed_filename),shell=True,env=system_env)
    
    
    #write background
    info('Writing Background Regions for each cell line...')
    coord_zscore=coordinates_bin.copy()
    for col in df_chip_hpr_zscore:
    
            regions_bg_filename='Background_for_%s_z_%.2f.bedgraph' % (os.path.basename(col).replace('.rpm',''),z_score_low)
            bg_output_filename=os.path.join(specific_regions_directory,'Background_for_%s_z_%.2f.bedgraph' % (os.path.basename(col).replace('.rpm',''),z_score_low))
            bg_output_bed_filename=bg_output_filename.replace('.bedgraph','.bed')
    
            if not os.path.exists(bg_output_bed_filename) or recompute_all:
    
                    if depleted:
                            coord_zscore['z-score']=df_chip_hpr_zscore.ix[df_chip_hpr_zscore.ix[:,col]>z_score_low,col]
                    else:
                            coord_zscore['z-score']=df_chip_hpr_zscore.ix[df_chip_hpr_zscore.ix[:,col]<z_score_low,col]
                    coord_zscore.dropna().to_csv(bg_output_filename,sep='\t',header=False,index=False)
    
                    info('Writing:%s' % bg_output_bed_filename )
                    sb.call('sort -k1,1 -k2,2n -i %s | bedtools merge -i stdin >  %s' %(bg_output_filename,bg_output_bed_filename),shell=True,env=system_env)    
    
    
    ###plot selection
    pl.figure()
    pl.title('Selection of the HPRs')
    pl.plot(x,y,'r',lw=3)
    pl.plot(x[idx_th],y[idx_th],'*',markersize=20)
    pl.hold(True)
    x_ext=np.linspace(-0.1,1.2,N_POINTS)
    y_line=(m[idx_th]+1.0)*(x_ext -x[idx_th])+ y[idx_th];
    pl.plot(x_ext,y_line,'--k',lw=3)
    pl.xlim(0,1.1)
    pl.ylim(0,1)
    pl.xlabel('Fraction of bins')
    pl.ylabel('Score normalized')
    pl.savefig(os.path.join(output_directory,'SELECTION_OF_VARIABILITY_HOTSPOT.pdf'))
    pl.close()
    
    
    
    igv_session_filename=os.path.join(output_directory,'OPEN_ME_WITH_IGV.xml')
    info('Creating an IGV session file (.xml) in: %s' %igv_session_filename)
    
    session = ET.Element("Session")
    session.set("genome",genome_name)
    session.set("hasGeneTrack","true")
    session.set("version","7")
    resources = ET.SubElement(session, "Resources")
    panel= ET.SubElement(session, "Panel")
    
    resource_items=[]
    track_items=[]
    
    hpr_iod_scores=scores[scores>th_iod]
    min_h=np.mean(hpr_iod_scores)-2*np.std(hpr_iod_scores)
    max_h=np.mean(hpr_iod_scores)+2*np.std(hpr_iod_scores)
    mid_h=np.mean(hpr_iod_scores)
    #write the tracks
    for sample_name in sample_names:
        if disable_quantile_normalization:
                track_full_path=os.path.join(output_directory,'TRACKS','%s.%dbp.bw' % (sample_name,bin_size))
        else:
                track_full_path=os.path.join(output_directory,'TRACKS','%s.%dbp_quantile_normalized.bw' % (sample_name,bin_size))
    
        track_filename=rem_base_path(track_full_path,output_directory)        
    
        if os.path.exists(track_full_path):    
                resource_items.append( ET.SubElement(resources, "Resource"))
                resource_items[-1].set("path",track_filename)
                track_items.append(ET.SubElement(panel, "Track" ))
                track_items[-1].set('color',"0,0,178")
                track_items[-1].set('id',track_filename)
                track_items[-1].set("name",sample_name)
    
    resource_items.append(ET.SubElement(resources, "Resource"))
    resource_items[-1].set("path",rem_base_path(bw_iod_track_filename,output_directory))
    
    track_items.append(ET.SubElement(panel, "Track" ))
    track_items[-1].set('color',"178,0,0")
    track_items[-1].set('id',rem_base_path(bw_iod_track_filename,output_directory))
    track_items[-1].set('renderer',"HEATMAP")
    track_items[-1].set("colorScale","ContinuousColorScale;%e;%e;%e;%e;0,153,255;255,255,51;204,0,0" % (mid_h,min_h,mid_h,max_h))
    track_items[-1].set("name",'VARIABILITY')
    
    resource_items.append(ET.SubElement(resources, "Resource"))
    resource_items[-1].set("path",rem_base_path(bed_hpr_fileaname,output_directory))
    track_items.append(ET.SubElement(panel, "Track" ))
    track_items[-1].set('color',"178,0,0")
    track_items[-1].set('id',rem_base_path(bed_hpr_fileaname,output_directory))
    track_items[-1].set('renderer',"HEATMAP")
    track_items[-1].set("colorScale","ContinuousColorScale;%e;%e;%e;%e;0,153,255;255,255,51;204,0,0" % (mid_h,min_h,mid_h,max_h))
    track_items[-1].set("name",'HOTSPOTS')
    
    for sample_name in sample_names:
        track_full_path=glob.glob(os.path.join(output_directory,'SPECIFIC_REGIONS','Regions_specific_for_%s*.bedgraph' %sample_name))[0]    
        specific_track_filename=rem_base_path(track_full_path,output_directory)
        if os.path.exists(track_full_path):
                resource_items.append( ET.SubElement(resources, "Resource"))
                resource_items[-1].set("path",specific_track_filename)
    
                track_items.append(ET.SubElement(panel, "Track" ))
                track_items[-1].set('color',"178,0,0")
                track_items[-1].set('id',specific_track_filename)
                track_items[-1].set('renderer',"HEATMAP")
                track_items[-1].set("colorScale","ContinuousColorScale;%e;%e;%e;%e;0,153,255;255,255,51;204,0,0" % (mid_h,min_h,mid_h,max_h))
                track_items[-1].set("name",'REGION SPECIFIC FOR %s' % sample_name)
    
    tree = ET.ElementTree(session)
    tree.write(igv_session_filename,xml_declaration=True)
    
    info('All done! Ciao!')
    sys.exit(0)

Example 26

Project: Haystack — Source File: haystack_hotspots_CORE.py (view license)
def main():
def main():

    print '\n[H A Y S T A C K   H O T S P O T]'
    print('\n-SELECTION OF VARIABLE REGIONS- [Luca Pinello - [email protected]]\n')
    print 'Version %s\n' % HAYSTACK_VERSION
    
    
    if which('samtools') is None:
            error('Haystack requires samtools free available at: http://sourceforge.net/projects/samtools/files/samtools/0.1.19/')
            sys.exit(1)
    
    if which('bedtools') is None:
            error('Haystack requires bedtools free available at: https://github.com/arq5x/bedtools2/releases/tag/v2.20.1')
            sys.exit(1)
    
    if which('bedGraphToBigWig') is None:
            info('To generate the bigwig files Haystack requires bedGraphToBigWig please download from here: http://hgdownload.cse.ucsc.edu/admin/exe/ and add to your PATH')
    
    #mandatory
    parser = argparse.ArgumentParser(description='HAYSTACK Parameters')
    parser.add_argument('samples_filename_or_bam_folder', type=str,  help='A tab delimeted file with in each row (1) a sample name, (2) the path to the corresponding bam filename. Alternatively it is possible to specify a folder containing some .bam files to analyze.' )
    parser.add_argument('genome_name', type=str,  help='Genome assembly to use from UCSC (for example hg19, mm9, etc.)')
    
    #optional
    parser.add_argument('--bin_size', type=int,help='bin size to use(default: 500bp)',default=500)
    parser.add_argument('--disable_quantile_normalization',help='Disable quantile normalization (default: False)',action='store_true')
    parser.add_argument('--th_rpm',type=float,help='Percentile on the signal intensity to consider for the hotspots (default: 99)', default=99)
    parser.add_argument('--transformation',type=str,help='Variance stabilizing transformation among: none, log2, angle (default: angle)',default='angle',choices=['angle', 'log2', 'none'])
    parser.add_argument('--recompute_all',help='Ignore any file previously precalculated',action='store_true')
    parser.add_argument('--z_score_high', type=float,help='z-score value to select the specific regions(default: 1.5)',default=1.5)
    parser.add_argument('--z_score_low', type=float,help='z-score value to select the not specific regions(default: 0.25)',default=0.25)
    parser.add_argument('--name',  help='Define a custom output filename for the report', default='')
    parser.add_argument('--output_directory',type=str, help='Output directory (default: current directory)',default='')
    parser.add_argument('--use_X_Y', help='Force to process the X and Y chromosomes (default: not processed)',action='store_true')
    parser.add_argument('--max_regions_percentage', type=float , help='Upper bound on the %% of the regions selected  (deafult: 0.1, 0.0=0%% 1.0=100%%)' , default=0.1)
    parser.add_argument('--depleted', help='Look for cell type specific regions with depletion of signal instead of enrichment',action='store_true')
    parser.add_argument('--input_is_bigwig', help='Use the bigwig format instead of the bam format for the input. Note: The files must have extension .bw',action='store_true')
    parser.add_argument('--version',help='Print version and exit.',action='version', version='Version %s' % HAYSTACK_VERSION)
    args = parser.parse_args()
    
    
    args_dict=vars(args)
    for key,value in args_dict.items():
            exec('%s=%s' %(key,repr(value)))
    
    
    if input_is_bigwig:
            extension_to_check='.bw'
            info('Input is set BigWig (.bw)')
    else:
            extension_to_check='.bam'
            info('Input is set compressed SAM (.bam)')
            
    #check folder or sample filename
    if os.path.isfile(samples_filename_or_bam_folder):
            BAM_FOLDER=False
            bam_filenames=[]
            sample_names=[]
            with open(samples_filename_or_bam_folder) as infile:
                for line in infile:
    
                    if not line.strip():
                            continue
                    
                    if line.startswith('#'): #skip optional header line or empty lines
                            info('Skipping header/comment line:%s' % line)
                            continue
    
                    fields=line.strip().split()
                    n_fields=len(fields)
                    
                    if n_fields==2: 
                        sample_names.append(fields[0])
                        bam_filenames.append(fields[1])
                    else:
                        error('The samples file format is wrong!')
                        sys.exit(1)
    
            
    else:
            if os.path.exists(samples_filename_or_bam_folder):
                    BAM_FOLDER=True
                    bam_filenames=glob.glob(os.path.join(samples_filename_or_bam_folder,'*'+extension_to_check))
    
                    if not bam_filenames:
                        error('No bam/bigwig  files to analyze in %s. Exiting.' % samples_filename_or_bam_folder)
                        sys.exit(1)
                    
                    sample_names=[os.path.basename(bam_filename).replace(extension_to_check,'') for bam_filename in bam_filenames]
            else:
                    error("The file or folder %s doesn't exist. Exiting." % samples_filename_or_bam_folder)
                    sys.exit(1)
                    
            
    #check all the files before starting
    info('Checking samples files location...')
    for bam_filename in bam_filenames:
            check_file(bam_filename)
    
    info('Initializing Genome:%s' %genome_name)
    
    genome_directory=determine_path('genomes')
    genome_2bit=os.path.join(genome_directory,genome_name+'.2bit')
    
    if os.path.exists(genome_2bit):
            genome=Genome_2bit(genome_2bit)
    else:
            info("\nIt seems you don't have the required genome file.")
            if query_yes_no('Should I download it for you?'):
                    sb.call('haystack_download_genome %s' %genome_name,shell=True,env=system_env)
                    if os.path.exists(genome_2bit):
                            info('Genome correctly downloaded!')
                            genome=Genome_2bit(genome_2bit)
                    else:
                            error('Sorry I cannot download the required file for you. Check your Internet connection.')
                            sys.exit(1)
            else:
                    error('Sorry I need the genome file to perform the analysis. Exiting...')
                    sys.exit(1)
    
    chr_len_filename=os.path.join(genome_directory, "%s_chr_lengths.txt" % genome_name)
    check_file(chr_len_filename)
    
    
    if name:
            directory_name='HAYSTACK_HOTSPOTS_on_%s' % name
    
    else:
            directory_name='HAYSTACK_HOTSPOTS'
    
    
    if output_directory:
            output_directory=os.path.join(output_directory,directory_name)
    else:
            output_directory=directory_name
            
    
    if not os.path.exists(output_directory):
    	os.makedirs(output_directory)    
    
                    
    genome_sorted_bins_file=os.path.join(output_directory,'%s.%dbp.bins.sorted.bed' %(os.path.basename(genome_name),bin_size))
    
    
    tracks_directory=os.path.join(output_directory,'TRACKS')
    if not os.path.exists(tracks_directory):
            os.makedirs(tracks_directory)   
    
    
    intermediate_directory=os.path.join(output_directory,'INTERMEDIATE')
    if not os.path.exists(intermediate_directory):
            os.makedirs(intermediate_directory)   
    
    if not os.path.exists(genome_sorted_bins_file) or recompute_all:
            info('Creating bins of %dbp for %s in %s' %(bin_size,chr_len_filename,genome_sorted_bins_file))
            sb.call('bedtools makewindows -g %s -w %s |  bedtools sort -i stdin |' %  (chr_len_filename, bin_size)+ "perl -nle 'print "+'"$_\t$.";'+"' /dev/stdin> %s" % genome_sorted_bins_file,shell=True,env=system_env)
            
    
    #convert bam files to genome-wide rpm tracks
    for base_name,bam_filename in zip(sample_names,bam_filenames):
    
        info('Processing:%s' %bam_filename)
        
        rpm_filename=os.path.join(tracks_directory,'%s.bedgraph' % base_name)
        sorted_rpm_filename=os.path.join(tracks_directory,'%s_sorted.bedgraph' % base_name)
        mapped_sorted_rpm_filename=os.path.join(tracks_directory,'%s_mapped_sorted.bedgraph' % base_name)
        binned_rpm_filename=os.path.join(intermediate_directory,'%s.%dbp.rpm' % (base_name,bin_size))
        bigwig_filename=os.path.join(tracks_directory,'%s.bw' %base_name)
    
        if  input_is_bigwig and which('bigWigAverageOverBed'):
                        if not os.path.exists(binned_rpm_filename) or recompute_all:
                                cmd='bigWigAverageOverBed %s %s  /dev/stdout | sort -s -n -k 1,1 | cut -f5 > %s' % (bam_filename,genome_sorted_bins_file,binned_rpm_filename)
                                sb.call(cmd,shell=True,env=system_env)
                                shutil.copy2(bam_filename,bigwig_filename)
    
        else:    
                if not os.path.exists(binned_rpm_filename) or recompute_all:
                        info('Computing Scaling Factor...')
                        cmd='samtools view -c -F 512 %s' % bam_filename
                        #print cmd
                        proc=sb.Popen(cmd, stdout=sb.PIPE,shell=True,env=system_env)
                        (stdout, stderr) = proc.communicate()
                        #print stdout,stderr
                        scaling_factor=(1.0/float(stdout.strip()))*1000000
    
                        info('Scaling Factor: %e' %scaling_factor)
    
                        info('Building BedGraph RPM track...')
                        cmd='samtools view -b -F 512 %s | bamToBed | slopBed  -r %s -l 0 -s -i stdin -g %s | genomeCoverageBed -g  %s -i stdin -bg -scale %.32f > %s'  %(bam_filename,bin_size,chr_len_filename,chr_len_filename,scaling_factor,rpm_filename)
                        #print cmd
    
                
                        proc=sb.call(cmd,shell=True,env=system_env)
    
                if which('bedGraphToBigWig'):
                    if not os.path.exists(bigwig_filename) or recompute_all:
                            info('Converting BedGraph to BigWig')
                            cmd='bedGraphToBigWig %s %s %s' %(rpm_filename,chr_len_filename,bigwig_filename)
                            proc=sb.call(cmd,shell=True,env=system_env)
    
                else:
                    info('Sorry I cannot create the bigwig file.\nPlease download and install bedGraphToBigWig from here: http://hgdownload.cse.ucsc.edu/admin/exe/ and add to your PATH')
    
                if not os.path.exists(binned_rpm_filename) or recompute_all:      
                        info('Make constant binned (%dbp) rpm values file' % bin_size)
                        #cmd='bedtools sort -i %s |  bedtools map -a %s -b stdin -c 4 -o mean -null 0.0 | cut -f5 > %s'   %(rpm_filename,genome_sorted_bins_file,binned_rpm_filename)
                        #proc=sb.call(cmd,shell=True,env=system_env)
                        
                        cmd='sort -k1,1 -k2,2n  %s  > %s'   %(rpm_filename,sorted_rpm_filename)
                        proc=sb.call(cmd,shell=True,env=system_env)

                        cmd='bedtools map -a %s -b %s -c 4 -o mean -null 0.0  > %s'   % (genome_sorted_bins_file,sorted_rpm_filename,mapped_sorted_rpm_filename)
                        proc=sb.call(cmd,shell=True,env=system_env)
                        
                        cmd='cut -f5 %s  > %s'   %(mapped_sorted_rpm_filename,binned_rpm_filename)
                        proc=sb.call(cmd,shell=True,env=system_env)

                
                try:    
                        os.remove(rpm_filename)
                        os.remove(sorted_rpm_filename)
                        os.remove(mapped_sorted_rpm_filename)
                except:
                        pass
    
    
    #load coordinates of bins
    coordinates_bin=pd.read_csv(genome_sorted_bins_file,names=['chr_id','bpstart','bpend'],sep='\t',header=None,usecols=[0,1,2])
    N_BINS=coordinates_bin.shape[0]
    if not use_X_Y:
        coordinates_bin=coordinates_bin.ix[(coordinates_bin['chr_id']!='chrX') & (coordinates_bin['chr_id']!='chrY')]  
    
    #load all the tracks
    info('Loading the processed tracks') 
    df_chip={}
    for state_file in  glob.glob(os.path.join(intermediate_directory,'*.rpm')):
            col_name=os.path.basename(state_file).replace('.rpm','')
            df_chip[col_name]=pd.read_csv(state_file,squeeze=True,header=None)
            info('Loading:%s' % col_name)
    
    df_chip=pd.DataFrame(df_chip)
    
    if disable_quantile_normalization:
            info('Skipping quantile normalization...')
    else:
            info('Normalizing the data...')
            df_chip=pd.DataFrame(quantile_normalization(df_chip.values),columns=df_chip.columns,index=df_chip.index)
    
    
    if which('bedGraphToBigWig'):        
            #write quantile normalized tracks
            coord_quantile=coordinates_bin.copy()
            for col in df_chip:
    
                if disable_quantile_normalization:
                        normalized_output_filename=os.path.join(tracks_directory,'%s.bedgraph' % os.path.basename(col))
                else:
                        normalized_output_filename=os.path.join(tracks_directory,'%s_quantile_normalized.bedgraph' % os.path.basename(col))
                        
                normalized_output_filename_bigwig=normalized_output_filename.replace('.bedgraph','.bw')
      
                if not os.path.exists(normalized_output_filename_bigwig) or recompute_all:         
                        info('Writing binned track: %s' % normalized_output_filename_bigwig )    
                        coord_quantile['rpm_normalized']=df_chip.ix[:,col]
                        coord_quantile.dropna().to_csv(normalized_output_filename,sep='\t',header=False,index=False)
                
                        cmd='bedGraphToBigWig %s %s %s' %(normalized_output_filename,chr_len_filename,normalized_output_filename_bigwig)
                        proc=sb.call(cmd,shell=True,env=system_env)
                        try:
                                os.remove(normalized_output_filename)
                        except:
                                pass
    else:
            info('Sorry I cannot creat the bigwig file.\nPlease download and install bedGraphToBigWig from here: http://hgdownload.cse.ucsc.edu/admin/exe/ and add to your PATH')
         
            
    #th_rpm=np.min(df_chip.apply(lambda x: np.percentile(x,th_rpm)))
    th_rpm=find_th_rpm(df_chip,th_rpm)
    info('Estimated th_rpm:%s' % th_rpm)
    
    df_chip_not_empty=df_chip.ix[(df_chip>th_rpm).any(1),:]
    

    
    if transformation=='log2':
            df_chip_not_empty=df_chip_not_empty.applymap(log2_transform)
            info('Using log2 transformation')
    
    elif transformation =='angle':     
            df_chip_not_empty=df_chip_not_empty.applymap(angle_transform )
            info('Using angle transformation')
    
    else:
            info('Using no transformation')
            
    iod_values=df_chip_not_empty.var(1)/df_chip_not_empty.mean(1)
    
    ####calculate the inflation point a la superenhancers
    scores=iod_values
    min_s=np.min(scores)
    max_s=np.max(scores)
    
    N_POINTS=len(scores)
    x=np.linspace(0,1,N_POINTS)
    y=sorted((scores-min_s)/(max_s-min_s))
    m=smooth((np.diff(y)/np.diff(x)),50)
    m=m-1
    m[m<=0]=np.inf
    m[:int(len(m)*(1-max_regions_percentage))]=np.inf
    idx_th=np.argmin(m)+1
    
    #print idx_th,
    th_iod=sorted(iod_values)[idx_th]
    #print th_iod
    
    
    hpr_idxs=iod_values>th_iod
    #print len(iod_values),len(hpr_idxs),sum(hpr_idxs), sum(hpr_idxs)/float(len(hpr_idxs)),
    
    info('Selected %f%% regions (%d)' %( sum(hpr_idxs)/float(len(hpr_idxs))*100, sum(hpr_idxs)))
    coordinates_bin['iod']=iod_values
    
    #we remove the regions "without" signal in any of the cell types
    coordinates_bin.dropna(inplace=True)
    
    
    #create a track for IGV
    bedgraph_iod_track_filename=os.path.join(tracks_directory,'VARIABILITY.bedgraph')
    bw_iod_track_filename=os.path.join(tracks_directory,'VARIABILITY.bw')
    
    if not os.path.exists(bw_iod_track_filename) or recompute_all:   
    
            info('Generating variability track in bigwig format in:%s' % bw_iod_track_filename)
    
            coordinates_bin.to_csv(bedgraph_iod_track_filename,sep='\t',header=False,index=False)
            sb.call('bedGraphToBigWig %s %s %s' % (bedgraph_iod_track_filename,chr_len_filename,bw_iod_track_filename ),shell=True,env=system_env)
            try:
                    os.remove(bedgraph_iod_track_filename)
            except:
                    pass
    
    
    #Write the HPRs
    bedgraph_hpr_filename=os.path.join(tracks_directory,'SELECTED_VARIABILITY_HOTSPOT.bedgraph')
    
    to_write=coordinates_bin.ix[hpr_idxs[hpr_idxs].index]
    to_write.dropna(inplace=True)
    to_write['bpstart']=to_write['bpstart'].astype(int)
    to_write['bpend']=to_write['bpend'].astype(int)
    
    to_write.to_csv(bedgraph_hpr_filename,sep='\t',header=False,index=False)
    
    bed_hpr_fileaname=os.path.join(output_directory,'SELECTED_VARIABILITY_HOTSPOT.bed')
    
    if not os.path.exists(bed_hpr_fileaname) or recompute_all:  
            info('Writing the HPRs in: %s' % bed_hpr_fileaname)
            sb.call('sort -k1,1 -k2,2n %s | bedtools merge -i stdin >  %s' %(bedgraph_hpr_filename,bed_hpr_fileaname),shell=True,env=system_env)
    
    #os.remove(bedgraph_hpr_filename)
    
    df_chip_hpr=df_chip_not_empty.ix[hpr_idxs,:]
    df_chip_hpr_zscore=df_chip_hpr.apply(zscore,axis=1)
    
    
    specific_regions_directory=os.path.join(output_directory,'SPECIFIC_REGIONS')
    if not os.path.exists(specific_regions_directory):
            os.makedirs(specific_regions_directory)   
    
    if depleted:
            z_score_high=-z_score_high
            z_score_low=-z_score_low
    
    
    #write target
    info('Writing Specific Regions for each cell line...')
    coord_zscore=coordinates_bin.copy()
    for col in df_chip_hpr_zscore:
    
            regions_specific_filename='Regions_specific_for_%s_z_%.2f.bedgraph' % (os.path.basename(col).replace('.rpm',''),z_score_high)
            specific_output_filename=os.path.join(specific_regions_directory,regions_specific_filename)
            specific_output_bed_filename=specific_output_filename.replace('.bedgraph','.bed')
    
            if not os.path.exists(specific_output_bed_filename) or recompute_all:  
                    if depleted:
                            coord_zscore['z-score']=df_chip_hpr_zscore.ix[df_chip_hpr_zscore.ix[:,col]<z_score_high,col]
                    else:
                            coord_zscore['z-score']=df_chip_hpr_zscore.ix[df_chip_hpr_zscore.ix[:,col]>z_score_high,col]
                    coord_zscore.dropna().to_csv(specific_output_filename,sep='\t',header=False,index=False)
    
                    info('Writing:%s' % specific_output_bed_filename )
                    sb.call('sort -k1,1 -k2,2n %s | bedtools merge -i stdin >  %s' %(specific_output_filename,specific_output_bed_filename),shell=True,env=system_env)
    
    
    #write background
    info('Writing Background Regions for each cell line...')
    coord_zscore=coordinates_bin.copy()
    for col in df_chip_hpr_zscore:
    
            regions_bg_filename='Background_for_%s_z_%.2f.bedgraph' % (os.path.basename(col).replace('.rpm',''),z_score_low)
            bg_output_filename=os.path.join(specific_regions_directory,'Background_for_%s_z_%.2f.bedgraph' % (os.path.basename(col).replace('.rpm',''),z_score_low))
            bg_output_bed_filename=bg_output_filename.replace('.bedgraph','.bed')
    
            if not os.path.exists(bg_output_bed_filename) or recompute_all:
    
                    if depleted:
                            coord_zscore['z-score']=df_chip_hpr_zscore.ix[df_chip_hpr_zscore.ix[:,col]>z_score_low,col]
                    else:
                            coord_zscore['z-score']=df_chip_hpr_zscore.ix[df_chip_hpr_zscore.ix[:,col]<z_score_low,col]
                    coord_zscore.dropna().to_csv(bg_output_filename,sep='\t',header=False,index=False)
    
                    info('Writing:%s' % bg_output_bed_filename )
                    sb.call('sort -k1,1 -k2,2n -i %s | bedtools merge -i stdin >  %s' %(bg_output_filename,bg_output_bed_filename),shell=True,env=system_env)    
    
    
    ###plot selection
    pl.figure()
    pl.title('Selection of the HPRs')
    pl.plot(x,y,'r',lw=3)
    pl.plot(x[idx_th],y[idx_th],'*',markersize=20)
    pl.hold(True)
    x_ext=np.linspace(-0.1,1.2,N_POINTS)
    y_line=(m[idx_th]+1.0)*(x_ext -x[idx_th])+ y[idx_th];
    pl.plot(x_ext,y_line,'--k',lw=3)
    pl.xlim(0,1.1)
    pl.ylim(0,1)
    pl.xlabel('Fraction of bins')
    pl.ylabel('Score normalized')
    pl.savefig(os.path.join(output_directory,'SELECTION_OF_VARIABILITY_HOTSPOT.pdf'))
    pl.close()
    
    
    
    igv_session_filename=os.path.join(output_directory,'OPEN_ME_WITH_IGV.xml')
    info('Creating an IGV session file (.xml) in: %s' %igv_session_filename)
    
    session = ET.Element("Session")
    session.set("genome",genome_name)
    session.set("hasGeneTrack","true")
    session.set("version","7")
    resources = ET.SubElement(session, "Resources")
    panel= ET.SubElement(session, "Panel")
    
    resource_items=[]
    track_items=[]
    
    hpr_iod_scores=scores[scores>th_iod]
    min_h=np.mean(hpr_iod_scores)-2*np.std(hpr_iod_scores)
    max_h=np.mean(hpr_iod_scores)+2*np.std(hpr_iod_scores)
    mid_h=np.mean(hpr_iod_scores)
    #write the tracks
    for sample_name in sample_names:
        if disable_quantile_normalization:
                track_full_path=os.path.join(output_directory,'TRACKS','%s.%dbp.bw' % (sample_name,bin_size))
        else:
                track_full_path=os.path.join(output_directory,'TRACKS','%s.%dbp_quantile_normalized.bw' % (sample_name,bin_size))
    
        track_filename=rem_base_path(track_full_path,output_directory)        
    
        if os.path.exists(track_full_path):    
                resource_items.append( ET.SubElement(resources, "Resource"))
                resource_items[-1].set("path",track_filename)
                track_items.append(ET.SubElement(panel, "Track" ))
                track_items[-1].set('color',"0,0,178")
                track_items[-1].set('id',track_filename)
                track_items[-1].set("name",sample_name)
    
    resource_items.append(ET.SubElement(resources, "Resource"))
    resource_items[-1].set("path",rem_base_path(bw_iod_track_filename,output_directory))
    
    track_items.append(ET.SubElement(panel, "Track" ))
    track_items[-1].set('color',"178,0,0")
    track_items[-1].set('id',rem_base_path(bw_iod_track_filename,output_directory))
    track_items[-1].set('renderer',"HEATMAP")
    track_items[-1].set("colorScale","ContinuousColorScale;%e;%e;%e;%e;0,153,255;255,255,51;204,0,0" % (mid_h,min_h,mid_h,max_h))
    track_items[-1].set("name",'VARIABILITY')
    
    resource_items.append(ET.SubElement(resources, "Resource"))
    resource_items[-1].set("path",rem_base_path(bed_hpr_fileaname,output_directory))
    track_items.append(ET.SubElement(panel, "Track" ))
    track_items[-1].set('color',"178,0,0")
    track_items[-1].set('id',rem_base_path(bed_hpr_fileaname,output_directory))
    track_items[-1].set('renderer',"HEATMAP")
    track_items[-1].set("colorScale","ContinuousColorScale;%e;%e;%e;%e;0,153,255;255,255,51;204,0,0" % (mid_h,min_h,mid_h,max_h))
    track_items[-1].set("name",'HOTSPOTS')
    
    for sample_name in sample_names:
        track_full_path=glob.glob(os.path.join(output_directory,'SPECIFIC_REGIONS','Regions_specific_for_%s*.bedgraph' %sample_name))[0]    
        specific_track_filename=rem_base_path(track_full_path,output_directory)
        if os.path.exists(track_full_path):
                resource_items.append( ET.SubElement(resources, "Resource"))
                resource_items[-1].set("path",specific_track_filename)
    
                track_items.append(ET.SubElement(panel, "Track" ))
                track_items[-1].set('color',"178,0,0")
                track_items[-1].set('id',specific_track_filename)
                track_items[-1].set('renderer',"HEATMAP")
                track_items[-1].set("colorScale","ContinuousColorScale;%e;%e;%e;%e;0,153,255;255,255,51;204,0,0" % (mid_h,min_h,mid_h,max_h))
                track_items[-1].set("name",'REGION SPECIFIC FOR %s' % sample_name)
    
    tree = ET.ElementTree(session)
    tree.write(igv_session_filename,xml_declaration=True)
    
    info('All done! Ciao!')
    sys.exit(0)

Example 27

Project: Haystack
Source File: haystack_pipeline_CORE.py
View license
def main():
    """Top-level driver for the HAYSTACK pipeline (Python 2 script).

    Parses the command line, validates every input file/folder, writes the
    per-stage configuration files into the output directory, then runs the
    pipeline stages as external commands via shell calls:

      1. haystack_hotspots          - variability hotspot detection
      2. haystack_motifs            - motif enrichment per sample
      3. haystack_tf_activity_plane - TF activity planes (only when gene
                                      expression files were provided)

    Exits the process with status 1 on any validation failure.
    """

    print '\n[H A Y S T A C K   P I P E L I N E]'
    print('\n-SELECTION OF HOTSPOTS OF VARIABILITY AND ENRICHED MOTIFS- [Luca Pinello - [email protected]]\n')
    print 'Version %s\n' % HAYSTACK_VERSION
    
    #mandatory
    parser = argparse.ArgumentParser(description='HAYSTACK Parameters')
    parser.add_argument('samples_filename_or_bam_folder', type=str,  help='A tab delimeted file with in each row (1) a sample name, (2) the path to the corresponding bam filename, (3 optional) the path to the corresponding gene expression filaneme. Alternatively it is possible to specify a folder containing some .bam files to analyze.')
    parser.add_argument('genome_name', type=str,  help='Genome assembly to use from UCSC (for example hg19, mm9, etc.)')
    
    #optional
    parser.add_argument('--name',  help='Define a custom output filename for the report', default='')
    parser.add_argument('--output_directory',type=str, help='Output directory (default: current directory)',default='')
    parser.add_argument('--bin_size', type=int,help='bin size to use(default: 500bp)',default=500)
    parser.add_argument('--recompute_all',help='Ignore any file previously precalculated fot the command haystack_hotstpot',action='store_true')
    parser.add_argument('--depleted', help='Look for cell type specific regions with depletion of signal instead of enrichment',action='store_true')
    parser.add_argument('--input_is_bigwig', help='Use the bigwig format instead of the bam format for the input. Note: The files must have extension .bw',action='store_true')
    parser.add_argument('--disable_quantile_normalization',help='Disable quantile normalization (default: False)',action='store_true')
    parser.add_argument('--transformation',type=str,help='Variance stabilizing transformation among: none, log2, angle (default: angle)',default='angle',choices=['angle', 'log2', 'none'])
    parser.add_argument('--z_score_high', type=float,help='z-score value to select the specific regions(default: 1.5)',default=1.5)
    parser.add_argument('--z_score_low', type=float,help='z-score value to select the not specific regions(default: 0.25)',default=0.25)
    parser.add_argument('--th_rpm',type=float,help='Percentile on the signal intensity to consider for the hotspots (default: 99)', default=99)
    parser.add_argument('--meme_motifs_filename', type=str, help='Motifs database in MEME format (default JASPAR CORE 2016)')
    parser.add_argument('--motif_mapping_filename', type=str, help='Custom motif to gene mapping file (the default is for JASPAR CORE 2016 database)')
    parser.add_argument('--plot_all',  help='Disable the filter on the TF activity and correlation (default z-score TF>0 and rho>0.3)',action='store_true')
    parser.add_argument('--n_processes',type=int, help='Specify the number of processes to use. The default is #cores available.',default=multiprocessing.cpu_count())
    parser.add_argument('--temp_directory',  help='Directory to store temporary files  (default: /tmp)', default='/tmp')
    parser.add_argument('--version',help='Print version and exit.',action='version', version='Version %s' % HAYSTACK_VERSION)
    
    args = parser.parse_args()
    args_dict=vars(args)
    # Promote every parsed argument to a local variable of the same name
    # (bin_size, name, output_directory, ...).  The plain-string exec form
    # mutates the enclosing local namespace, which only works in Python 2;
    # the rest of this function relies on those injected locals.
    for key,value in args_dict.items():
            exec('%s=%s' %(key,repr(value)))
            
            
            
    if meme_motifs_filename:
        check_file(meme_motifs_filename)
    
    if motif_mapping_filename:
        check_file(motif_mapping_filename)
        
    if not os.path.exists(temp_directory):
        error('The folder specified with --temp_directory: %s does not exist!' % temp_directory)
        sys.exit(1)
    
    # Input files are matched by extension: raw alignments (.bam) by
    # default, pre-computed coverage tracks (.bw) with --input_is_bigwig.
    if input_is_bigwig:
            extension_to_check='.bw'
            info('Input is set BigWig (.bw)')
    else:
            extension_to_check='.bam'
            info('Input is set compressed SAM (.bam)')
    
    if name:
            directory_name='HAYSTACK_PIPELINE_RESULTS_on_%s' % name
    
    else:
            directory_name='HAYSTACK_PIPELINE_RESULTS'
    
    if output_directory:
            output_directory=os.path.join(output_directory,directory_name)
    else:
            output_directory=directory_name
    
    #check folder or sample filename
    
    # Gene-expression-dependent stages run only when EVERY row of the
    # samples file carries a third (gene expression filename) column.
    USE_GENE_EXPRESSION=True
    
    if os.path.isfile(samples_filename_or_bam_folder):
            # Samples-file mode: each non-comment, non-blank line is
            #   <sample_name> <bam_filename> [<gene_expression_filename>]
            BAM_FOLDER=False
            bam_filenames=[]
            gene_expression_filenames=[]
            sample_names=[]
    
            with open(samples_filename_or_bam_folder) as infile:
                for line in infile:
    
                    if not line.strip():
                            continue
                    
                    if line.startswith('#'): #skip optional header line or empty lines
                            info('Skipping header/comment line:%s' % line)
                            continue
    
                    fields=line.strip().split()
                    n_fields=len(fields)
    
                    if n_fields==2:
    
                        USE_GENE_EXPRESSION=False
                        
                        sample_names.append(fields[0])
                        bam_filenames.append(fields[1])
    
                    elif n_fields==3:
    
                        USE_GENE_EXPRESSION=USE_GENE_EXPRESSION and True
    
                        sample_names.append(fields[0])
                        bam_filenames.append(fields[1])
                        gene_expression_filenames.append(fields[2])
                    else:
                        error('The samples file format is wrong!')
                        sys.exit(1)
            
    else:
            # Folder mode: analyze every *.bam / *.bw file in the folder;
            # sample names are derived by stripping the extension.
            if os.path.exists(samples_filename_or_bam_folder):
                    BAM_FOLDER=True
                    USE_GENE_EXPRESSION=False
                    bam_filenames=glob.glob(os.path.join(samples_filename_or_bam_folder,'*'+extension_to_check))
    
                    if not bam_filenames:
                        error('No bam/bigwig  files to analyze in %s. Exiting.' % samples_filename_or_bam_folder)
                        sys.exit(1)
                    
                    sample_names=[os.path.basename(bam_filename).replace(extension_to_check,'') for bam_filename in bam_filenames]
            else:
                    error("The file or folder %s doesn't exist. Exiting." % samples_filename_or_bam_folder)
                    sys.exit(1)
    
    
    #check all the files before starting
    info('Checking samples files location...')
    for bam_filename in bam_filenames:
            check_file(bam_filename)
    
    if USE_GENE_EXPRESSION:
        for gene_expression_filename in gene_expression_filenames:
                check_file(gene_expression_filename)
    
    if not os.path.exists(output_directory):
            os.makedirs(output_directory)
    
    #copy back the file used
    if not BAM_FOLDER:
            shutil.copy2(samples_filename_or_bam_folder,output_directory)
    
    #write hotspots conf files
    # Tab-separated <sample_name>\t<bam_filename> config consumed by the
    # haystack_hotspots command invoked below.
    sample_names_hotspots_filename=os.path.join(output_directory,'sample_names_hotspots.txt')
    
    with open(sample_names_hotspots_filename,'w+') as outfile:
        for sample_name,bam_filename in zip(sample_names,bam_filenames):
            outfile.write('%s\t%s\n' % (sample_name, bam_filename))
    
    #write tf activity  conf files
    if USE_GENE_EXPRESSION:
            sample_names_tf_activity_filename=os.path.join(output_directory,'sample_names_tf_activity.txt')
    
            with open(sample_names_tf_activity_filename,'w+') as outfile:
                    for sample_name,gene_expression_filename in zip(sample_names,gene_expression_filenames):
                            outfile.write('%s\t%s\n' % (sample_name, gene_expression_filename))
    
            tf_activity_directory=os.path.join(output_directory,'HAYSTACK_TFs_ACTIVITY_PLANES')
    
    
    #CALL HAYSTACK HOTSPOTS
    cmd_to_run='haystack_hotspots %s %s --output_directory %s --bin_size %d %s %s %s %s %s %s %s %s' % \
                (sample_names_hotspots_filename, genome_name,output_directory,bin_size,
                 ('--recompute_all' if recompute_all else ''),
                 ('--depleted' if depleted else ''),
                 ('--input_is_bigwig' if input_is_bigwig else ''),
                 ('--disable_quantile_normalization' if disable_quantile_normalization else ''),
                 '--transformation %s' % transformation,
                 '--z_score_high %f' % z_score_high,
                 '--z_score_low %f' % z_score_low,
                 '--th_rpm %f' % th_rpm)
    print cmd_to_run
    # NOTE(review): sb is presumably an alias for subprocess; the return
    # code of each stage is never checked, so a failed stage does not stop
    # the pipeline -- confirm this is intentional.
    sb.call(cmd_to_run ,shell=True,env=system_env)        
    
    #CALL HAYSTACK MOTIFS
    motif_directory=os.path.join(output_directory,'HAYSTACK_MOTIFS')
    for sample_name in sample_names:
        # NOTE(review): the '*' pattern is passed to haystack_motifs
        # unexpanded (the z-score embedded in the filename is unknown
        # here); presumably the downstream command or shell expands it.
        specific_regions_filename=os.path.join(output_directory,'HAYSTACK_HOTSPOTS','SPECIFIC_REGIONS','Regions_specific_for_%s*.bed' %sample_name)
        # NOTE(review): [0] assumes at least one Background_*.bed match
        # exists; raises IndexError if the hotspots step produced none.
        bg_regions_filename=glob.glob(os.path.join(output_directory,'HAYSTACK_HOTSPOTS','SPECIFIC_REGIONS','Background_for_%s*.bed' %sample_name))[0]
        #bg_regions_filename=glob.glob(specific_regions_filename.replace('Regions_specific','Background')[:-11]+'*.bed')[0] #the z-score is different...
        #print specific_regions_filename,bg_regions_filename
        cmd_to_run='haystack_motifs %s %s --bed_bg_filename %s --output_directory %s --name %s' % (specific_regions_filename,genome_name, bg_regions_filename,motif_directory, sample_name)
        
        if meme_motifs_filename:
             cmd_to_run+=' --meme_motifs_filename %s' % meme_motifs_filename
             
             
        if n_processes:
            cmd_to_run+=' --n_processes %d' % n_processes
            
        if temp_directory:
            cmd_to_run+=' --temp_directory %s' % temp_directory
            
            
        
        print cmd_to_run
        sb.call(cmd_to_run,shell=True,env=system_env)
    
        if USE_GENE_EXPRESSION:
                #CALL HAYSTACK TF ACTIVITY 
                # Runs only if the motif stage actually produced output
                # for this sample.
                motifs_output_folder=os.path.join(motif_directory,'HAYSTACK_MOTIFS_on_%s' % sample_name) 
                if os.path.exists(motifs_output_folder):
                    cmd_to_run='haystack_tf_activity_plane %s %s %s --output_directory %s'  %(motifs_output_folder,sample_names_tf_activity_filename,sample_name,tf_activity_directory)
                    
                    if motif_mapping_filename:
                        cmd_to_run+=' --motif_mapping_filename %s' %  motif_mapping_filename       
                    
                    if plot_all:
                        cmd_to_run+=' --plot_all'
                        
                    
                    print cmd_to_run
                    sb.call(cmd_to_run,shell=True,env=system_env) 

Example 28

Project: Haystack
Source File: haystack_pipeline_CORE.py
View license
def main():
    """Top-level driver for the HAYSTACK pipeline (Python 2 script).

    Parses the command line, validates every input file/folder, writes the
    per-stage configuration files into the output directory, then runs the
    pipeline stages as external commands via shell calls:

      1. haystack_hotspots          - variability hotspot detection
      2. haystack_motifs            - motif enrichment per sample
      3. haystack_tf_activity_plane - TF activity planes (only when gene
                                      expression files were provided)

    Exits the process with status 1 on any validation failure.
    """

    print '\n[H A Y S T A C K   P I P E L I N E]'
    print('\n-SELECTION OF HOTSPOTS OF VARIABILITY AND ENRICHED MOTIFS- [Luca Pinello - [email protected]]\n')
    print 'Version %s\n' % HAYSTACK_VERSION
    
    #mandatory
    parser = argparse.ArgumentParser(description='HAYSTACK Parameters')
    parser.add_argument('samples_filename_or_bam_folder', type=str,  help='A tab delimeted file with in each row (1) a sample name, (2) the path to the corresponding bam filename, (3 optional) the path to the corresponding gene expression filaneme. Alternatively it is possible to specify a folder containing some .bam files to analyze.')
    parser.add_argument('genome_name', type=str,  help='Genome assembly to use from UCSC (for example hg19, mm9, etc.)')
    
    #optional
    parser.add_argument('--name',  help='Define a custom output filename for the report', default='')
    parser.add_argument('--output_directory',type=str, help='Output directory (default: current directory)',default='')
    parser.add_argument('--bin_size', type=int,help='bin size to use(default: 500bp)',default=500)
    parser.add_argument('--recompute_all',help='Ignore any file previously precalculated fot the command haystack_hotstpot',action='store_true')
    parser.add_argument('--depleted', help='Look for cell type specific regions with depletion of signal instead of enrichment',action='store_true')
    parser.add_argument('--input_is_bigwig', help='Use the bigwig format instead of the bam format for the input. Note: The files must have extension .bw',action='store_true')
    parser.add_argument('--disable_quantile_normalization',help='Disable quantile normalization (default: False)',action='store_true')
    parser.add_argument('--transformation',type=str,help='Variance stabilizing transformation among: none, log2, angle (default: angle)',default='angle',choices=['angle', 'log2', 'none'])
    parser.add_argument('--z_score_high', type=float,help='z-score value to select the specific regions(default: 1.5)',default=1.5)
    parser.add_argument('--z_score_low', type=float,help='z-score value to select the not specific regions(default: 0.25)',default=0.25)
    parser.add_argument('--th_rpm',type=float,help='Percentile on the signal intensity to consider for the hotspots (default: 99)', default=99)
    parser.add_argument('--meme_motifs_filename', type=str, help='Motifs database in MEME format (default JASPAR CORE 2016)')
    parser.add_argument('--motif_mapping_filename', type=str, help='Custom motif to gene mapping file (the default is for JASPAR CORE 2016 database)')
    parser.add_argument('--plot_all',  help='Disable the filter on the TF activity and correlation (default z-score TF>0 and rho>0.3)',action='store_true')
    parser.add_argument('--n_processes',type=int, help='Specify the number of processes to use. The default is #cores available.',default=multiprocessing.cpu_count())
    parser.add_argument('--temp_directory',  help='Directory to store temporary files  (default: /tmp)', default='/tmp')
    parser.add_argument('--version',help='Print version and exit.',action='version', version='Version %s' % HAYSTACK_VERSION)
    
    args = parser.parse_args()
    args_dict=vars(args)
    # Promote every parsed argument to a local variable of the same name
    # (bin_size, name, output_directory, ...).  The plain-string exec form
    # mutates the enclosing local namespace, which only works in Python 2;
    # the rest of this function relies on those injected locals.
    for key,value in args_dict.items():
            exec('%s=%s' %(key,repr(value)))
            
            
            
    if meme_motifs_filename:
        check_file(meme_motifs_filename)
    
    if motif_mapping_filename:
        check_file(motif_mapping_filename)
        
    if not os.path.exists(temp_directory):
        error('The folder specified with --temp_directory: %s does not exist!' % temp_directory)
        sys.exit(1)
    
    # Input files are matched by extension: raw alignments (.bam) by
    # default, pre-computed coverage tracks (.bw) with --input_is_bigwig.
    if input_is_bigwig:
            extension_to_check='.bw'
            info('Input is set BigWig (.bw)')
    else:
            extension_to_check='.bam'
            info('Input is set compressed SAM (.bam)')
    
    if name:
            directory_name='HAYSTACK_PIPELINE_RESULTS_on_%s' % name
    
    else:
            directory_name='HAYSTACK_PIPELINE_RESULTS'
    
    if output_directory:
            output_directory=os.path.join(output_directory,directory_name)
    else:
            output_directory=directory_name
    
    #check folder or sample filename
    
    # Gene-expression-dependent stages run only when EVERY row of the
    # samples file carries a third (gene expression filename) column.
    USE_GENE_EXPRESSION=True
    
    if os.path.isfile(samples_filename_or_bam_folder):
            # Samples-file mode: each non-comment, non-blank line is
            #   <sample_name> <bam_filename> [<gene_expression_filename>]
            BAM_FOLDER=False
            bam_filenames=[]
            gene_expression_filenames=[]
            sample_names=[]
    
            with open(samples_filename_or_bam_folder) as infile:
                for line in infile:
    
                    if not line.strip():
                            continue
                    
                    if line.startswith('#'): #skip optional header line or empty lines
                            info('Skipping header/comment line:%s' % line)
                            continue
    
                    fields=line.strip().split()
                    n_fields=len(fields)
    
                    if n_fields==2:
    
                        USE_GENE_EXPRESSION=False
                        
                        sample_names.append(fields[0])
                        bam_filenames.append(fields[1])
    
                    elif n_fields==3:
    
                        USE_GENE_EXPRESSION=USE_GENE_EXPRESSION and True
    
                        sample_names.append(fields[0])
                        bam_filenames.append(fields[1])
                        gene_expression_filenames.append(fields[2])
                    else:
                        error('The samples file format is wrong!')
                        sys.exit(1)
            
    else:
            # Folder mode: analyze every *.bam / *.bw file in the folder;
            # sample names are derived by stripping the extension.
            if os.path.exists(samples_filename_or_bam_folder):
                    BAM_FOLDER=True
                    USE_GENE_EXPRESSION=False
                    bam_filenames=glob.glob(os.path.join(samples_filename_or_bam_folder,'*'+extension_to_check))
    
                    if not bam_filenames:
                        error('No bam/bigwig  files to analyze in %s. Exiting.' % samples_filename_or_bam_folder)
                        sys.exit(1)
                    
                    sample_names=[os.path.basename(bam_filename).replace(extension_to_check,'') for bam_filename in bam_filenames]
            else:
                    error("The file or folder %s doesn't exist. Exiting." % samples_filename_or_bam_folder)
                    sys.exit(1)
    
    
    #check all the files before starting
    info('Checking samples files location...')
    for bam_filename in bam_filenames:
            check_file(bam_filename)
    
    if USE_GENE_EXPRESSION:
        for gene_expression_filename in gene_expression_filenames:
                check_file(gene_expression_filename)
    
    if not os.path.exists(output_directory):
            os.makedirs(output_directory)
    
    #copy back the file used
    if not BAM_FOLDER:
            shutil.copy2(samples_filename_or_bam_folder,output_directory)
    
    #write hotspots conf files
    # Tab-separated <sample_name>\t<bam_filename> config consumed by the
    # haystack_hotspots command invoked below.
    sample_names_hotspots_filename=os.path.join(output_directory,'sample_names_hotspots.txt')
    
    with open(sample_names_hotspots_filename,'w+') as outfile:
        for sample_name,bam_filename in zip(sample_names,bam_filenames):
            outfile.write('%s\t%s\n' % (sample_name, bam_filename))
    
    #write tf activity  conf files
    if USE_GENE_EXPRESSION:
            sample_names_tf_activity_filename=os.path.join(output_directory,'sample_names_tf_activity.txt')
    
            with open(sample_names_tf_activity_filename,'w+') as outfile:
                    for sample_name,gene_expression_filename in zip(sample_names,gene_expression_filenames):
                            outfile.write('%s\t%s\n' % (sample_name, gene_expression_filename))
    
            tf_activity_directory=os.path.join(output_directory,'HAYSTACK_TFs_ACTIVITY_PLANES')
    
    
    #CALL HAYSTACK HOTSPOTS
    cmd_to_run='haystack_hotspots %s %s --output_directory %s --bin_size %d %s %s %s %s %s %s %s %s' % \
                (sample_names_hotspots_filename, genome_name,output_directory,bin_size,
                 ('--recompute_all' if recompute_all else ''),
                 ('--depleted' if depleted else ''),
                 ('--input_is_bigwig' if input_is_bigwig else ''),
                 ('--disable_quantile_normalization' if disable_quantile_normalization else ''),
                 '--transformation %s' % transformation,
                 '--z_score_high %f' % z_score_high,
                 '--z_score_low %f' % z_score_low,
                 '--th_rpm %f' % th_rpm)
    print cmd_to_run
    # NOTE(review): sb is presumably an alias for subprocess; the return
    # code of each stage is never checked, so a failed stage does not stop
    # the pipeline -- confirm this is intentional.
    sb.call(cmd_to_run ,shell=True,env=system_env)        
    
    #CALL HAYSTACK MOTIFS
    motif_directory=os.path.join(output_directory,'HAYSTACK_MOTIFS')
    for sample_name in sample_names:
        # NOTE(review): the '*' pattern is passed to haystack_motifs
        # unexpanded (the z-score embedded in the filename is unknown
        # here); presumably the downstream command or shell expands it.
        specific_regions_filename=os.path.join(output_directory,'HAYSTACK_HOTSPOTS','SPECIFIC_REGIONS','Regions_specific_for_%s*.bed' %sample_name)
        # NOTE(review): [0] assumes at least one Background_*.bed match
        # exists; raises IndexError if the hotspots step produced none.
        bg_regions_filename=glob.glob(os.path.join(output_directory,'HAYSTACK_HOTSPOTS','SPECIFIC_REGIONS','Background_for_%s*.bed' %sample_name))[0]
        #bg_regions_filename=glob.glob(specific_regions_filename.replace('Regions_specific','Background')[:-11]+'*.bed')[0] #the z-score is different...
        #print specific_regions_filename,bg_regions_filename
        cmd_to_run='haystack_motifs %s %s --bed_bg_filename %s --output_directory %s --name %s' % (specific_regions_filename,genome_name, bg_regions_filename,motif_directory, sample_name)
        
        if meme_motifs_filename:
             cmd_to_run+=' --meme_motifs_filename %s' % meme_motifs_filename
             
             
        if n_processes:
            cmd_to_run+=' --n_processes %d' % n_processes
            
        if temp_directory:
            cmd_to_run+=' --temp_directory %s' % temp_directory
            
            
        
        print cmd_to_run
        sb.call(cmd_to_run,shell=True,env=system_env)
    
        if USE_GENE_EXPRESSION:
                #CALL HAYSTACK TF ACTIVITY 
                # Runs only if the motif stage actually produced output
                # for this sample.
                motifs_output_folder=os.path.join(motif_directory,'HAYSTACK_MOTIFS_on_%s' % sample_name) 
                if os.path.exists(motifs_output_folder):
                    cmd_to_run='haystack_tf_activity_plane %s %s %s --output_directory %s'  %(motifs_output_folder,sample_names_tf_activity_filename,sample_name,tf_activity_directory)
                    
                    if motif_mapping_filename:
                        cmd_to_run+=' --motif_mapping_filename %s' %  motif_mapping_filename       
                    
                    if plot_all:
                        cmd_to_run+=' --plot_all'
                        
                    
                    print cmd_to_run
                    sb.call(cmd_to_run,shell=True,env=system_env) 

Example 29

Project: edx2bigquery
Source File: edx2course_axis.py
View license
def make_axis(dir):
    '''
    Walk an edX course export directory and build a "course axis" for each
    course run found there.

    return dict of {course_id : { policy, xbundle, axis (as list of Axel elements) }}

    Each per-course value dict also carries the serialized XBundle XML
    ('bundle'), the grading policy, and the accumulated log messages
    ('log_msg').

    NOTE(review): relies on names defined elsewhere in this module
    (path, CourseInfo, xbundle, Axel, date_parse, fix_bad_unicode,
    FORCE_NO_HIDE, VERBOSE_WARNINGS, org/course via CourseInfo) -- this
    function is not self-contained.
    '''
    
    courses = []
    log_msg = []   # accumulated log lines; returned per-course under 'log_msg'

    def logit(msg, nolog=False):
        # Print a message; unless nolog is set, also record it in log_msg.
        if not nolog:
            log_msg.append(msg)
        print msg

    dir = path(dir)

    if os.path.exists(dir / 'roots'):	# if roots directory exists, use that for different course versions
        # get roots
        roots = glob.glob(dir / 'roots/*.xml')
        courses = [ CourseInfo(fn, '', dir) for fn in roots ]

    else:	# single course.xml file - use different policy files in policy directory, though

        fn = dir / 'course.xml'
    
        # get semesters: one CourseInfo per policy file found
        policies = glob.glob(dir/'policies/*.json')
        # policies/assets.json is not a course-run policy; drop it if present
        assetsfn = dir / 'policies/assets.json'
        if str(assetsfn) in policies:
            policies.remove(assetsfn)
        if not policies:
            # newer export layout: policies/<run>/policy.json
            policies = glob.glob(dir/'policies/*/policy.json')
        if not policies:
            logit("Error: no policy files found!")
        
        courses = [ CourseInfo(fn, pfn) for pfn in policies ]


    logit("%d course runs found: %s" % (len(courses), [c.url_name for c in courses]))
    
    ret = {}

    # construct axis for each policy
    for cinfo in courses:
        policy = cinfo.policy
        semester = policy.semester
        org = cinfo.org
        course = cinfo.course
        # course_id in the legacy org/course/run form
        cid = '%s/%s/%s' % (org, course, semester)
        logit('course_id=%s' %  cid)
    
        cfn = dir / ('course/%s.xml' % semester)
        
        # generate XBundle for course
        xml = etree.parse(cfn).getroot()
        xb = xbundle.XBundle(keep_urls=True, skip_hidden=True, keep_studio_urls=True)
        xb.policy = policy.policy
        cxml = xb.import_xml_removing_descriptor(dir, xml)

        # append metadata: embed the (grading) policy JSON into the bundle XML
        metadata = etree.Element('metadata')
        cxml.append(metadata)
        policy_xml = etree.Element('policy')
        metadata.append(policy_xml)
        policy_xml.text = json.dumps(policy.policy)
        grading_policy_xml = etree.Element('grading_policy')
        metadata.append(grading_policy_xml)
        grading_policy_xml.text = json.dumps(policy.grading_policy)
    
        bundle = etree.tostring(cxml, pretty_print=True)
        #print bundle[:500]
        # index is a one-element list so the nested walk() closure can mutate
        # the running axis-entry counter (Python 2 has no 'nonlocal')
        index = [1]
        caxis = []   # accumulated Axel entries for this course run
    
        def walk(x, seq_num=1, path=[], seq_type=None, parent_start=None, parent=None, chapter=None,
                 parent_url_name=None, split_url_name=None):
            # NOTE(review): mutable default path=[] is harmless here only
            # because walk always rebinds path (path[:] + ...) and never
            # mutates it in place.
            '''
            Recursively traverse course tree.  
            
            x        = current etree element
            seq_num  = sequence of current element in its parent, starting from 1
            path     = list of url_name's to current element, following edX's hierarchy conventions
            seq_type = problemset, sequential, or videosequence
            parent_start = start date of parent of current etree element
            parent   = parent module
            chapter  = the last chapter module_id seen while walking through the tree
            parent_url_name = url_name of parent
            split_url_name   = url_name of split_test element if this subtree is in a split_test, otherwise None
            '''
            url_name = x.get('url_name',x.get('url_name_orig',''))
            if not url_name:
                dn = x.get('display_name')
                if dn is not None:
                    url_name = dn.strip().replace(' ','_')     # 2012 convention for converting display_name to url_name
                    url_name = url_name.replace(':','_')
                    url_name = url_name.replace('.','_')
                    url_name = url_name.replace('(','_').replace(')','_').replace('__','_')
            
            data = None   # element-specific payload (JSON string) stored on the Axel entry
            start = None

            if not FORCE_NO_HIDE:
                # skip subtrees hidden from the table of contents
                hide = policy.get_metadata(x, 'hide_from_toc')
                if hide is not None and not hide=="false":
                    logit('[edx2course_axis] Skipping %s (%s), it has hide_from_toc=%s' % (x.tag, x.get('display_name','<noname>'), hide))
                    return

            if x.tag=='video':	# special: for video, let data = youtube ID(s)
                data = x.get('youtube','')
                if data:
                    # old ytid format - extract just the 1.0 part of this 
                    # 0.75:JdL1Vo0Hru0,1.0:lbaG3uiQ6IY,1.25:Lrj0G8RWHKw,1.50:54fs3-WxqLs
                    ytid = data.replace(' ','').split(',')
                    ytid = [z[1] for z in [y.split(':') for y in ytid] if z[0]=='1.0']
                    # print "   ytid: %s -> %s" % (x.get('youtube',''), ytid)
                    if ytid:
                        data = ytid
                if not data:
                    data = x.get('youtube_id_1_0', '')
                if data:
                    data = '{"ytid": "%s"}' % data

            if x.tag=="split_test":
                # record A/B test configuration attributes
                data = {}
                to_copy = ['group_id_to_child', 'user_partition_id']
                for tc in to_copy:
                    data[tc] = x.get(tc, None)

            if x.tag=='problem' and x.get('weight') is not None and x.get('weight'):
                try:
                    # Changed from string to dict. In next code block.
                    data = {"weight": "%f" % float(x.get('weight'))}
                except Exception as err:
                    logit("    Error converting weight %s" % x.get('weight'))

            ### Had a hard time making my code work within the try/except for weight. Happy to improve
            ### Also note, weight is typically missing in problems. So I find it weird that we throw an exception.
            if x.tag=='problem':
                # Initialize data if no weight
                if not data:
                    data = {}

                # meta will store all problem related metadata, then be used to update data
                meta = {}
                # Items is meant to help debug - an ordered list of encountered problem types with url names
                # Likely should not be pulled to Big Query 
                meta['items'] = []
                # Known Problem Types
                known_problem_types = ['multiplechoiceresponse','numericalresponse','choiceresponse',
                                       'optionresponse','stringresponse','formularesponse',
                                       'customresponse','fieldset']

                # Loop through all child nodes in a problem. If encountering a known problem type, add metadata.
                for a in x:
                    if a.tag in known_problem_types:
                        meta['items'].append({'itype':a.tag,'url_name':a.get('url_name')})

                ### Check for accompanying image
                images = x.findall('.//img')
                # meta['has_image'] = False
                
                if images and len(images)>0:
                    meta['has_image'] = True #Note, one can use a.get('src'), but needs to account for multiple images
                    # print meta['img'],len(images)

                ### Search for all solution tags in a problem
                solutions = x.findall('.//solution')
                # meta['has_solution'] = False

                if solutions and len(solutions)>0:
                    text = ''
                    # NOTE(review): text.join(...) uses the accumulated text as
                    # the separator, not plain concatenation -- looks
                    # suspicious, but preserved as-is; verify against intent.
                    for sol in solutions:
                        text = text.join(html.tostring(e, pretty_print=False) for e in sol)
                        # This if statement checks each solution. Note, many MITx problems have multiple solution tags.
                        # In 8.05x, common to put image in one solution tag, and the text in a second. So we are checking each tag.
                        # If we find one solution with > 65 char, or one solution with an image, we set meta['solution'] = True
                        if len(text) > 65 or 'img src' in text:
                            meta['has_solution'] = True

                ### If meta is empty, log all tags for debugging later. 
                if len(meta)==0:
                    logit('item type not found - here is the list of tags:['+','.join(a.tag if a else ' ' for a in x)+']')
                    # print 'problem type not found - here is the list of tags:['+','.join(a.tag for a in x)+']'

                ### Add easily accessible metadata for problems
                # num_items: number of items
                # itype: problem type - note, mixed is used when items are not of same type
                if len(meta['items']) > 0:
                    # Number of Items
                    meta['num_items'] = len(meta['items'])

                    # Problem Type
                    if all(meta['items'][0]['itype'] == item['itype'] for item in meta['items']):
                        meta['itype'] = meta['items'][0]['itype']
                        # print meta['items'][0]['itype']
                    else:
                        meta['itype'] = 'mixed'

                # Update data field
                ### ! For now, removing the items field. 
                del meta["items"]               

                data.update(meta)
                data = json.dumps(data)

            if x.tag=='html':
                # html blocks may embed a video via an iframe; extract its ytid
                iframe = x.find('.//iframe')
                if iframe is not None:
                    logit("   found iframe in html %s" % url_name)
                    src = iframe.get('src','')
                    if 'https://www.youtube.com/embed/' in src:
                        m = re.search('embed/([^"/?]+)', src)
                        if m:
                            data = '{"ytid": "%s"}' % m.group(1)
                            logit("    data=%s" % data)
                
            if url_name:              # url_name is mandatory if we are to do anything with this element
                # url_name = url_name.replace(':','_')
                dn = x.get('display_name', url_name)
                try:
                    #dn = dn.decode('utf-8')
                    dn = unicode(dn)
                    dn = fix_bad_unicode(dn)
                except Exception as err:
                    logit('unicode error, type(dn)=%s'  % type(dn))
                    raise
                pdn = policy.get_metadata(x, 'display_name')      # policy display_name - if given, let that override default
                if pdn is not None:
                    dn = pdn

                #start = date_parse(x.get('start', policy.get_metadata(x, 'start', '')))
                start = date_parse(policy.get_metadata(x, 'start', '', parent=True))
                
                # a child cannot start before its parent: clamp to parent start
                if parent_start is not None and start < parent_start:
                    if VERBOSE_WARNINGS:
                        logit("    Warning: start of %s element %s happens before start %s of parent: using parent start" % (start, x.tag, parent_start), nolog=True)
                    start = parent_start
                #print "start for %s = %s" % (x, start)
                
                # drop bad due date strings
                if date_parse(x.get('due',None), retbad=True)=='Bad':
                    x.set('due', '')

                due = date_parse(policy.get_metadata(x, 'due', '', parent=True))
                if x.tag=="problem":
                    logit("    setting problem due date: for %s due=%s" % (url_name, due), nolog=True)

                gformat = x.get('format', policy.get_metadata(x, 'format', ''))
                if url_name=='hw0':
                    logit( "gformat for hw0 = %s" % gformat)

                graded = x.get('graded', policy.get_metadata(x, 'graded', ''))
                if not (type(graded) in [unicode, str]):
                    graded = str(graded)

                # compute path
                # The hierarchy goes: `course > chapter > (problemset | sequential | videosequence)`
                if x.tag=='chapter':
                    path = [url_name]
                elif x.tag in ['problemset', 'sequential', 'videosequence', 'proctor', 'randomize']:
                    seq_type = x.tag
                    path = [path[0], url_name]
                else:
                    path = path[:] + [str(seq_num)]      # note arrays are passed by reference, so copy, don't modify
                    
                # compute module_id
                if x.tag=='html':
                    module_id = '%s/%s/%s/%s' % (org, course, seq_type, '/'.join(path[1:3]))  # module_id which appears in tracking log
                else:
                    module_id = '%s/%s/%s/%s' % (org, course, x.tag, url_name)
                
                # debugging
                # print "     module %s gformat=%s" % (module_id, gformat)

                # done with getting all info for this axis element; save it
                path_str = '/' + '/'.join(path)
                ae = Axel(cid, index[0], url_name, x.tag, gformat, start, due, dn, path_str, module_id, data, chapter, graded,
                          parent_url_name,
                          not split_url_name==None,
                          split_url_name)
                caxis.append(ae)
                index[0] += 1
            else:
                if VERBOSE_WARNINGS:
                    # these tags routinely lack url_name; don't warn on them
                    if x.tag in ['transcript', 'wiki', 'metadata']:
                        pass
                    else:
                        logit("Missing url_name for element %s (attrib=%s, parent_tag=%s)" % (x, x.attrib, (parent.tag if parent is not None else '')))

            # chapter?  propagate the enclosing chapter's module_id downward
            if x.tag=='chapter':
                the_chapter = module_id
            else:
                the_chapter = chapter

            # done processing this element, now process all its children
            # (leaf-like tags below have no axis-relevant children)
            if (not x.tag in ['html', 'problem', 'discussion', 'customtag', 'poll_question', 'combinedopenended', 'metadata']):
                inherit_seq_num = (x.tag=='vertical' and not url_name)    # if <vertical> with no url_name then keep seq_num for children
                if not inherit_seq_num:
                    seq_num = 1
                for y in x:
                    # str(y) starting with '<!--' filters out XML comment nodes
                    if (not str(y).startswith('<!--')) and (not y.tag in ['discussion', 'source']):
                        if not split_url_name and x.tag=="split_test":
                            split_url_name = url_name
                                
                        walk(y, seq_num, path, seq_type, parent_start=start, parent=x, chapter=the_chapter,
                             parent_url_name=url_name,
                             split_url_name=split_url_name,
                        )
                        if not inherit_seq_num:
                            seq_num += 1
                
        walk(cxml)
        ret[cid] = dict(policy=policy.policy, 
                        bundle=bundle, 
                        axis=caxis, 
                        grading_policy=policy.grading_policy,
                        log_msg=log_msg,
                        )
    
    return ret

Example 30

Project: edx2bigquery
Source File: edx2course_axis.py
View license
def make_axis(dir):
    '''
    Walk an edX course export directory and build a "course axis" for each
    course run found there.

    return dict of {course_id : { policy, xbundle, axis (as list of Axel elements) }}

    Each per-course value dict also carries the serialized XBundle XML
    ('bundle'), the grading policy, and the accumulated log messages
    ('log_msg').

    NOTE(review): relies on names defined elsewhere in this module
    (path, CourseInfo, xbundle, Axel, date_parse, fix_bad_unicode,
    FORCE_NO_HIDE, VERBOSE_WARNINGS, org/course via CourseInfo) -- this
    function is not self-contained.
    '''
    
    courses = []
    log_msg = []   # accumulated log lines; returned per-course under 'log_msg'

    def logit(msg, nolog=False):
        # Print a message; unless nolog is set, also record it in log_msg.
        if not nolog:
            log_msg.append(msg)
        print msg

    dir = path(dir)

    if os.path.exists(dir / 'roots'):	# if roots directory exists, use that for different course versions
        # get roots
        roots = glob.glob(dir / 'roots/*.xml')
        courses = [ CourseInfo(fn, '', dir) for fn in roots ]

    else:	# single course.xml file - use different policy files in policy directory, though

        fn = dir / 'course.xml'
    
        # get semesters: one CourseInfo per policy file found
        policies = glob.glob(dir/'policies/*.json')
        # policies/assets.json is not a course-run policy; drop it if present
        assetsfn = dir / 'policies/assets.json'
        if str(assetsfn) in policies:
            policies.remove(assetsfn)
        if not policies:
            # newer export layout: policies/<run>/policy.json
            policies = glob.glob(dir/'policies/*/policy.json')
        if not policies:
            logit("Error: no policy files found!")
        
        courses = [ CourseInfo(fn, pfn) for pfn in policies ]


    logit("%d course runs found: %s" % (len(courses), [c.url_name for c in courses]))
    
    ret = {}

    # construct axis for each policy
    for cinfo in courses:
        policy = cinfo.policy
        semester = policy.semester
        org = cinfo.org
        course = cinfo.course
        # course_id in the legacy org/course/run form
        cid = '%s/%s/%s' % (org, course, semester)
        logit('course_id=%s' %  cid)
    
        cfn = dir / ('course/%s.xml' % semester)
        
        # generate XBundle for course
        xml = etree.parse(cfn).getroot()
        xb = xbundle.XBundle(keep_urls=True, skip_hidden=True, keep_studio_urls=True)
        xb.policy = policy.policy
        cxml = xb.import_xml_removing_descriptor(dir, xml)

        # append metadata: embed the (grading) policy JSON into the bundle XML
        metadata = etree.Element('metadata')
        cxml.append(metadata)
        policy_xml = etree.Element('policy')
        metadata.append(policy_xml)
        policy_xml.text = json.dumps(policy.policy)
        grading_policy_xml = etree.Element('grading_policy')
        metadata.append(grading_policy_xml)
        grading_policy_xml.text = json.dumps(policy.grading_policy)
    
        bundle = etree.tostring(cxml, pretty_print=True)
        #print bundle[:500]
        # index is a one-element list so the nested walk() closure can mutate
        # the running axis-entry counter (Python 2 has no 'nonlocal')
        index = [1]
        caxis = []   # accumulated Axel entries for this course run
    
        def walk(x, seq_num=1, path=[], seq_type=None, parent_start=None, parent=None, chapter=None,
                 parent_url_name=None, split_url_name=None):
            # NOTE(review): mutable default path=[] is harmless here only
            # because walk always rebinds path (path[:] + ...) and never
            # mutates it in place.
            '''
            Recursively traverse course tree.  
            
            x        = current etree element
            seq_num  = sequence of current element in its parent, starting from 1
            path     = list of url_name's to current element, following edX's hierarchy conventions
            seq_type = problemset, sequential, or videosequence
            parent_start = start date of parent of current etree element
            parent   = parent module
            chapter  = the last chapter module_id seen while walking through the tree
            parent_url_name = url_name of parent
            split_url_name   = url_name of split_test element if this subtree is in a split_test, otherwise None
            '''
            url_name = x.get('url_name',x.get('url_name_orig',''))
            if not url_name:
                dn = x.get('display_name')
                if dn is not None:
                    url_name = dn.strip().replace(' ','_')     # 2012 convention for converting display_name to url_name
                    url_name = url_name.replace(':','_')
                    url_name = url_name.replace('.','_')
                    url_name = url_name.replace('(','_').replace(')','_').replace('__','_')
            
            data = None   # element-specific payload (JSON string) stored on the Axel entry
            start = None

            if not FORCE_NO_HIDE:
                # skip subtrees hidden from the table of contents
                hide = policy.get_metadata(x, 'hide_from_toc')
                if hide is not None and not hide=="false":
                    logit('[edx2course_axis] Skipping %s (%s), it has hide_from_toc=%s' % (x.tag, x.get('display_name','<noname>'), hide))
                    return

            if x.tag=='video':	# special: for video, let data = youtube ID(s)
                data = x.get('youtube','')
                if data:
                    # old ytid format - extract just the 1.0 part of this 
                    # 0.75:JdL1Vo0Hru0,1.0:lbaG3uiQ6IY,1.25:Lrj0G8RWHKw,1.50:54fs3-WxqLs
                    ytid = data.replace(' ','').split(',')
                    ytid = [z[1] for z in [y.split(':') for y in ytid] if z[0]=='1.0']
                    # print "   ytid: %s -> %s" % (x.get('youtube',''), ytid)
                    if ytid:
                        data = ytid
                if not data:
                    data = x.get('youtube_id_1_0', '')
                if data:
                    data = '{"ytid": "%s"}' % data

            if x.tag=="split_test":
                # record A/B test configuration attributes
                data = {}
                to_copy = ['group_id_to_child', 'user_partition_id']
                for tc in to_copy:
                    data[tc] = x.get(tc, None)

            if x.tag=='problem' and x.get('weight') is not None and x.get('weight'):
                try:
                    # Changed from string to dict. In next code block.
                    data = {"weight": "%f" % float(x.get('weight'))}
                except Exception as err:
                    logit("    Error converting weight %s" % x.get('weight'))

            ### Had a hard time making my code work within the try/except for weight. Happy to improve
            ### Also note, weight is typically missing in problems. So I find it weird that we throw an exception.
            if x.tag=='problem':
                # Initialize data if no weight
                if not data:
                    data = {}

                # meta will store all problem related metadata, then be used to update data
                meta = {}
                # Items is meant to help debug - an ordered list of encountered problem types with url names
                # Likely should not be pulled to Big Query 
                meta['items'] = []
                # Known Problem Types
                known_problem_types = ['multiplechoiceresponse','numericalresponse','choiceresponse',
                                       'optionresponse','stringresponse','formularesponse',
                                       'customresponse','fieldset']

                # Loop through all child nodes in a problem. If encountering a known problem type, add metadata.
                for a in x:
                    if a.tag in known_problem_types:
                        meta['items'].append({'itype':a.tag,'url_name':a.get('url_name')})

                ### Check for accompanying image
                images = x.findall('.//img')
                # meta['has_image'] = False
                
                if images and len(images)>0:
                    meta['has_image'] = True #Note, one can use a.get('src'), but needs to account for multiple images
                    # print meta['img'],len(images)

                ### Search for all solution tags in a problem
                solutions = x.findall('.//solution')
                # meta['has_solution'] = False

                if solutions and len(solutions)>0:
                    text = ''
                    # NOTE(review): text.join(...) uses the accumulated text as
                    # the separator, not plain concatenation -- looks
                    # suspicious, but preserved as-is; verify against intent.
                    for sol in solutions:
                        text = text.join(html.tostring(e, pretty_print=False) for e in sol)
                        # This if statement checks each solution. Note, many MITx problems have multiple solution tags.
                        # In 8.05x, common to put image in one solution tag, and the text in a second. So we are checking each tag.
                        # If we find one solution with > 65 char, or one solution with an image, we set meta['solution'] = True
                        if len(text) > 65 or 'img src' in text:
                            meta['has_solution'] = True

                ### If meta is empty, log all tags for debugging later. 
                if len(meta)==0:
                    logit('item type not found - here is the list of tags:['+','.join(a.tag if a else ' ' for a in x)+']')
                    # print 'problem type not found - here is the list of tags:['+','.join(a.tag for a in x)+']'

                ### Add easily accessible metadata for problems
                # num_items: number of items
                # itype: problem type - note, mixed is used when items are not of same type
                if len(meta['items']) > 0:
                    # Number of Items
                    meta['num_items'] = len(meta['items'])

                    # Problem Type
                    if all(meta['items'][0]['itype'] == item['itype'] for item in meta['items']):
                        meta['itype'] = meta['items'][0]['itype']
                        # print meta['items'][0]['itype']
                    else:
                        meta['itype'] = 'mixed'

                # Update data field
                ### ! For now, removing the items field. 
                del meta["items"]               

                data.update(meta)
                data = json.dumps(data)

            if x.tag=='html':
                # html blocks may embed a video via an iframe; extract its ytid
                iframe = x.find('.//iframe')
                if iframe is not None:
                    logit("   found iframe in html %s" % url_name)
                    src = iframe.get('src','')
                    if 'https://www.youtube.com/embed/' in src:
                        m = re.search('embed/([^"/?]+)', src)
                        if m:
                            data = '{"ytid": "%s"}' % m.group(1)
                            logit("    data=%s" % data)
                
            if url_name:              # url_name is mandatory if we are to do anything with this element
                # url_name = url_name.replace(':','_')
                dn = x.get('display_name', url_name)
                try:
                    #dn = dn.decode('utf-8')
                    dn = unicode(dn)
                    dn = fix_bad_unicode(dn)
                except Exception as err:
                    logit('unicode error, type(dn)=%s'  % type(dn))
                    raise
                pdn = policy.get_metadata(x, 'display_name')      # policy display_name - if given, let that override default
                if pdn is not None:
                    dn = pdn

                #start = date_parse(x.get('start', policy.get_metadata(x, 'start', '')))
                start = date_parse(policy.get_metadata(x, 'start', '', parent=True))
                
                # a child cannot start before its parent: clamp to parent start
                if parent_start is not None and start < parent_start:
                    if VERBOSE_WARNINGS:
                        logit("    Warning: start of %s element %s happens before start %s of parent: using parent start" % (start, x.tag, parent_start), nolog=True)
                    start = parent_start
                #print "start for %s = %s" % (x, start)
                
                # drop bad due date strings
                if date_parse(x.get('due',None), retbad=True)=='Bad':
                    x.set('due', '')

                due = date_parse(policy.get_metadata(x, 'due', '', parent=True))
                if x.tag=="problem":
                    logit("    setting problem due date: for %s due=%s" % (url_name, due), nolog=True)

                gformat = x.get('format', policy.get_metadata(x, 'format', ''))
                if url_name=='hw0':
                    logit( "gformat for hw0 = %s" % gformat)

                graded = x.get('graded', policy.get_metadata(x, 'graded', ''))
                if not (type(graded) in [unicode, str]):
                    graded = str(graded)

                # compute path
                # The hierarchy goes: `course > chapter > (problemset | sequential | videosequence)`
                if x.tag=='chapter':
                    path = [url_name]
                elif x.tag in ['problemset', 'sequential', 'videosequence', 'proctor', 'randomize']:
                    seq_type = x.tag
                    path = [path[0], url_name]
                else:
                    path = path[:] + [str(seq_num)]      # note arrays are passed by reference, so copy, don't modify
                    
                # compute module_id
                if x.tag=='html':
                    module_id = '%s/%s/%s/%s' % (org, course, seq_type, '/'.join(path[1:3]))  # module_id which appears in tracking log
                else:
                    module_id = '%s/%s/%s/%s' % (org, course, x.tag, url_name)
                
                # debugging
                # print "     module %s gformat=%s" % (module_id, gformat)

                # done with getting all info for this axis element; save it
                path_str = '/' + '/'.join(path)
                ae = Axel(cid, index[0], url_name, x.tag, gformat, start, due, dn, path_str, module_id, data, chapter, graded,
                          parent_url_name,
                          not split_url_name==None,
                          split_url_name)
                caxis.append(ae)
                index[0] += 1
            else:
                if VERBOSE_WARNINGS:
                    # these tags routinely lack url_name; don't warn on them
                    if x.tag in ['transcript', 'wiki', 'metadata']:
                        pass
                    else:
                        logit("Missing url_name for element %s (attrib=%s, parent_tag=%s)" % (x, x.attrib, (parent.tag if parent is not None else '')))

            # chapter?  propagate the enclosing chapter's module_id downward
            if x.tag=='chapter':
                the_chapter = module_id
            else:
                the_chapter = chapter

            # done processing this element, now process all its children
            # (leaf-like tags below have no axis-relevant children)
            if (not x.tag in ['html', 'problem', 'discussion', 'customtag', 'poll_question', 'combinedopenended', 'metadata']):
                inherit_seq_num = (x.tag=='vertical' and not url_name)    # if <vertical> with no url_name then keep seq_num for children
                if not inherit_seq_num:
                    seq_num = 1
                for y in x:
                    # str(y) starting with '<!--' filters out XML comment nodes
                    if (not str(y).startswith('<!--')) and (not y.tag in ['discussion', 'source']):
                        if not split_url_name and x.tag=="split_test":
                            split_url_name = url_name
                                
                        walk(y, seq_num, path, seq_type, parent_start=start, parent=x, chapter=the_chapter,
                             parent_url_name=url_name,
                             split_url_name=split_url_name,
                        )
                        if not inherit_seq_num:
                            seq_num += 1
                
        walk(cxml)
        ret[cid] = dict(policy=policy.policy, 
                        bundle=bundle, 
                        axis=caxis, 
                        grading_policy=policy.grading_policy,
                        log_msg=log_msg,
                        )
    
    return ret

Example 31

Project: nansat
Source File: mapper_sentinel1_l1.py
View license
    def __init__(self, fileName, gdalDataset, gdalMetadata,
                 manifestonly=False, **kwargs):
        """Create a VRT for a Sentinel-1A/1B level-1 (SAFE) product.

        Parameters
        ----------
        fileName : str
            Path to the unpacked SAFE directory, or to the zipped product.
        gdalDataset, gdalMetadata :
            Part of the common mapper interface; not used directly here.
        manifestonly : bool, optional
            If True, initialize from manifest.safe and one annotation file
            only and return without creating any bands (fast constructor).

        Raises
        ------
        WrongMapperError
            If any expected SAFE component files are missing, the base file
            name does not start with 'S1A'/'S1B', no measurement dataset
            opens, or the TIFF metadata does not identify Sentinel-1 L1.

        NOTE(review): the dict.keys()[0]-style indexing used below only
        works on Python 2 (dict views are not indexable on Python 3) -
        confirm the target interpreter before porting.
        """
        # Collect the component file lists either from inside the zip
        # archive (exposed to GDAL via the /vsizip/ virtual file system)
        # or by globbing an unpacked SAFE directory.
        if zipfile.is_zipfile(fileName):
            zz = zipfile.PyZipFile(fileName)
            # Assuming the file names are consistent, the polarization
            # dependent data should be sorted equally such that we can use the
            # same indices consistently for all the following lists
            # THIS IS NOT THE CASE...
            mdsFiles = ['/vsizip/%s/%s' % (fileName, fn)
                        for fn in zz.namelist() if 'measurement/s1' in fn]
            calFiles = ['/vsizip/%s/%s' % (fileName, fn)
                        for fn in zz.namelist()
                        if 'annotation/calibration/calibration-s1' in fn]
            noiseFiles = ['/vsizip/%s/%s' % (fileName, fn)
                          for fn in zz.namelist()
                          if 'annotation/calibration/noise-s1' in fn]
            annotationFiles = ['/vsizip/%s/%s' % (fileName, fn)
                               for fn in zz.namelist()
                               if 'annotation/s1' in fn]
            manifestFile = ['/vsizip/%s/%s' % (fileName, fn)
                            for fn in zz.namelist()
                            if 'manifest.safe' in fn]
            zz.close()
        else:
            mdsFiles = glob.glob('%s/measurement/s1*' % fileName)
            calFiles = glob.glob('%s/annotation/calibration/calibration-s1*'
                                 % fileName)
            noiseFiles = glob.glob('%s/annotation/calibration/noise-s1*'
                                   % fileName)
            annotationFiles = glob.glob('%s/annotation/s1*'
                                        % fileName)
            manifestFile = glob.glob('%s/manifest.safe' % fileName)

        # All five component types must be present for a valid product.
        if (not mdsFiles or not calFiles or not noiseFiles or
                not annotationFiles or not manifestFile):
            raise WrongMapperError

        # Key each file by a token of its '-'-separated base name
        # (presumably the polarization, e.g. 'hh'/'vv'; the calibration and
        # noise names carry an extra leading word, hence index 4 instead of
        # 3) - TODO confirm against the SAFE file-naming convention.
        mdsDict = {}
        for ff in mdsFiles:
            mdsDict[
                os.path.splitext(os.path.basename(ff))[0].split('-')[3]] = ff

        self.calXMLDict = {}
        for ff in calFiles:
            self.calXMLDict[
                os.path.splitext(
                os.path.basename(ff))[0].split('-')[4]] = self.read_xml(ff)

        self.noiseXMLDict = {}
        for ff in noiseFiles:
            self.noiseXMLDict[
                os.path.splitext(
                os.path.basename(ff))[0].split('-')[4]] = self.read_xml(ff)

        self.annotationXMLDict = {}
        for ff in annotationFiles:
            self.annotationXMLDict[
                os.path.splitext(
                os.path.basename(ff))[0].split('-')[3]] = self.read_xml(ff)

        self.manifestXML = self.read_xml(manifestFile[0])

        # Only Sentinel-1A/1B products are handled by this mapper.
        if not os.path.split(fileName)[1][:3] in ['S1A', 'S1B']:
            raise WrongMapperError('Not Sentinel 1A or 1B')

        missionName = {'S1A': 'SENTINEL-1A', 'S1B': 'SENTINEL-1B'}[
            os.path.split(fileName)[1][:3]]

        # very fast constructor without any bands
        if manifestonly:
            self.init_from_manifest_only(self.manifestXML,
                                         self.annotationXMLDict[
                                         self.annotationXMLDict.keys()[0]],
                                         missionName)
            return

        gdalDatasets = {}
        for key in mdsDict.keys():
            # Open data files
            gdalDatasets[key] = gdal.Open(mdsDict[key])

        if not gdalDatasets:
            raise WrongMapperError('No Sentinel-1 datasets found')

        # Check metadata to confirm it is Sentinel-1 L1
        metadata = gdalDatasets[mdsDict.keys()[0]].GetMetadata()

        if not 'TIFFTAG_IMAGEDESCRIPTION' in metadata.keys():
            raise WrongMapperError
        # NOTE(review): with 'and', this only rejects files whose
        # description mentions neither 'Sentinel-1' nor 'L1'; an 'or' may
        # have been intended - verify against real products.
        if (not 'Sentinel-1' in metadata['TIFFTAG_IMAGEDESCRIPTION']
                and not 'L1' in metadata['TIFFTAG_IMAGEDESCRIPTION']):
            raise WrongMapperError

        warnings.warn('Sentinel-1 level-1 mapper is not yet adapted to '
                      'complex data. In addition, the band names should be '
                      'updated for multi-swath data - '
                      'and there might be other issues.')

        # create empty VRT dataset with geolocation only
        # (initialized from an arbitrary one of the measurement datasets)
        for key in gdalDatasets:
            VRT.__init__(self, gdalDatasets[key])
            break

        # Read annotation, noise and calibration xml-files
        pol = {}
        it = 0
        for key in self.annotationXMLDict:
            xml = Node.create(self.annotationXMLDict[key])
            pol[key] = (xml.node('product').
                        node('adsHeader')['polarisation'].upper())
            it += 1
            if it == 1:
                # Get incidence angle
                # (the geolocation grid is read from the first annotation
                # file only; assumed identical across polarizations -
                # TODO confirm)
                pi = xml.node('generalAnnotation').node('productInformation')

                self.dataset.SetMetadataItem('ORBIT_DIRECTION',
                                              str(pi['pass']))
                (X, Y, lon, lat, inc, ele, numberOfSamples,
                numberOfLines) = self.read_geolocation_lut(
                                                self.annotationXMLDict[key])

                X = np.unique(X)
                Y = np.unique(Y)

                lon = np.array(lon).reshape(len(Y), len(X))
                lat = np.array(lat).reshape(len(Y), len(X))
                inc = np.array(inc).reshape(len(Y), len(X))
                ele = np.array(ele).reshape(len(Y), len(X))

                # Resample incidence/elevation grids to full raster size
                incVRT = VRT(array=inc, lat=lat, lon=lon)
                eleVRT = VRT(array=ele, lat=lat, lon=lon)
                incVRT = incVRT.get_resized_vrt(self.dataset.RasterXSize,
                                                self.dataset.RasterYSize,
                                                eResampleAlg=2)
                eleVRT = eleVRT.get_resized_vrt(self.dataset.RasterXSize,
                                                self.dataset.RasterYSize,
                                                eResampleAlg=2)
                self.bandVRTs['incVRT'] = incVRT
                self.bandVRTs['eleVRT'] = eleVRT

        # Build calibration look-up-table VRTs per polarization
        for key in self.calXMLDict:
            calibration_LUT_VRTs, longitude, latitude = (
                self.get_LUT_VRTs(self.calXMLDict[key],
                                  'calibrationVectorList',
                                  ['sigmaNought', 'betaNought',
                                   'gamma', 'dn']
                                  ))
            self.bandVRTs['LUT_sigmaNought_VRT_'+pol[key]] = (
                calibration_LUT_VRTs['sigmaNought'].
                get_resized_vrt(self.dataset.RasterXSize,
                                self.dataset.RasterYSize,
                                eResampleAlg=1))
            self.bandVRTs['LUT_betaNought_VRT_'+pol[key]] = (
                calibration_LUT_VRTs['betaNought'].
                get_resized_vrt(self.dataset.RasterXSize,
                                self.dataset.RasterYSize,
                                eResampleAlg=1))
            # NOTE(review): gamma/dn keys are not polarization-suffixed, so
            # each loop iteration overwrites the previous one - verify
            # whether that is intended.
            self.bandVRTs['LUT_gamma_VRT'] = calibration_LUT_VRTs['gamma']
            self.bandVRTs['LUT_dn_VRT'] = calibration_LUT_VRTs['dn']

        # Noise look-up-table VRTs per polarization
        for key in self.noiseXMLDict:
            noise_LUT_VRT = self.get_LUT_VRTs(self.noiseXMLDict[key],
                                              'noiseVectorList',
                                              ['noiseLut'])[0]
            self.bandVRTs['LUT_noise_VRT_'+pol[key]] = (
                noise_LUT_VRT['noiseLut'].get_resized_vrt(
                    self.dataset.RasterXSize,
                    self.dataset.RasterYSize,
                    eResampleAlg=1))

        # Raw digital-number (DN) bands, one per polarization
        metaDict = []
        bandNumberDict = {}
        bnmax = 0
        for key in gdalDatasets.keys():
            dsPath, dsName = os.path.split(mdsDict[key])
            name = 'DN_%s' % pol[key]
            # A dictionary of band numbers is needed for the pixel function
            # bands further down. This is not the best solution. It would be
            # better to have a function in VRT that returns the number given a
            # band name. This function exists in Nansat but could perhaps be
            # moved to VRT? The existing nansat function could just call the
            # VRT one...
            bandNumberDict[name] = bnmax + 1
            bnmax = bandNumberDict[name]
            band = gdalDatasets[key].GetRasterBand(1)
            dtype = band.DataType
            metaDict.append({
                'src': {
                    'SourceFilename': mdsDict[key],
                    'SourceBand': 1,
                    'DataType': dtype,
                },
                'dst': {
                    'name': name,
                    #'SourceTransferType': gdal.GetDataTypeName(dtype),
                    #'dataType': 6,
                },
            })
        # add bands with metadata and corresponding values to the empty VRT
        self._create_bands(metaDict)

        '''
        Calibration should be performed as

        s0 = DN^2/sigmaNought^2,

        where sigmaNought is from e.g.
        annotation/calibration/calibration-s1a-iw-grd-hh-20140811t151231-20140811t151301-001894-001cc7-001.xml,
        and DN is the Digital Numbers in the tiff files.

        Also the noise should be subtracted.

        See
        https://sentinel.esa.int/web/sentinel/sentinel-1-sar-wiki/-/wiki/Sentinel%20One/Application+of+Radiometric+Calibration+LUT
        '''
        # Get look direction
        # (longitude/latitude here are left over from the last iteration of
        # the calibration-LUT loop above)
        sat_heading = initial_bearing(longitude[:-1, :],
                                      latitude[:-1, :],
                                      longitude[1:, :],
                                      latitude[1:, :])
        look_direction = scipy.ndimage.interpolation.zoom(
            np.mod(sat_heading + 90, 360),
            (np.shape(longitude)[0] / (np.shape(longitude)[0]-1.), 1))

        # Decompose, to avoid interpolation errors around 0 <-> 360
        look_direction_u = np.sin(np.deg2rad(look_direction))
        look_direction_v = np.cos(np.deg2rad(look_direction))
        look_u_VRT = VRT(array=look_direction_u,
                         lat=latitude, lon=longitude)
        look_v_VRT = VRT(array=look_direction_v,
                         lat=latitude, lon=longitude)
        lookVRT = VRT(lat=latitude, lon=longitude)
        lookVRT._create_band([{'SourceFilename': look_u_VRT.fileName,
                               'SourceBand': 1},
                              {'SourceFilename': look_v_VRT.fileName,
                               'SourceBand': 1}],
                             {'PixelFunctionType': 'UVToDirectionTo'}
                             )

        # Blow up to full size
        lookVRT = lookVRT.get_resized_vrt(self.dataset.RasterXSize,
                                          self.dataset.RasterYSize,
                                          eResampleAlg=1)

        # Store VRTs so that they are accessible later
        self.bandVRTs['look_u_VRT'] = look_u_VRT
        self.bandVRTs['look_v_VRT'] = look_v_VRT
        self.bandVRTs['lookVRT'] = lookVRT

        metaDict = []
        # Add bands to full size VRT
        for key in pol:
            name = 'LUT_sigmaNought_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append(
                {'src': {'SourceFilename':
                         (self.bandVRTs['LUT_sigmaNought_VRT_' +
                          pol[key]].fileName),
                         'SourceBand': 1
                         },
                 'dst': {'name': name
                         }
                 })
            name = 'LUT_noise_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append({
                'src': {
                    'SourceFilename': self.bandVRTs['LUT_noise_VRT_' +
                                                   pol[key]].fileName,
                    'SourceBand': 1
                },
                'dst': {
                    'name': name
                }
            })

        name = 'look_direction'
        bandNumberDict[name] = bnmax+1
        bnmax = bandNumberDict[name]
        metaDict.append({
            'src': {
                'SourceFilename': self.bandVRTs['lookVRT'].fileName,
                'SourceBand': 1
            },
            'dst': {
                'wkv': 'sensor_azimuth_angle',
                'name': name
            }
        })

        # Calibrated sigma0/beta0 pixel-function bands per polarization,
        # combining the DN band with the matching calibration LUT
        for key in gdalDatasets.keys():
            dsPath, dsName = os.path.split(mdsDict[key])
            name = 'sigma0_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append(
                {'src': [{'SourceFilename': self.fileName,
                          'SourceBand': bandNumberDict['DN_%s' % pol[key]],
                          },
                         {'SourceFilename':
                          (self.bandVRTs['LUT_sigmaNought_VRT_%s'
                           % pol[key]].fileName),
                          'SourceBand': 1
                          }
                         ],
                 'dst': {'wkv': 'surface_backwards_scattering_coefficient_of_radar_wave',
                         'PixelFunctionType': 'Sentinel1Calibration',
                         'polarization': pol[key],
                         'suffix': pol[key],
                         },
                 })
            name = 'beta0_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append(
                {'src': [{'SourceFilename': self.fileName,
                          'SourceBand': bandNumberDict['DN_%s' % pol[key]]
                          },
                         {'SourceFilename':
                          (self.bandVRTs['LUT_betaNought_VRT_%s'
                           % pol[key]].fileName),
                          'SourceBand': 1
                          }
                         ],
                 'dst': {'wkv': 'surface_backwards_brightness_coefficient_of_radar_wave',
                         'PixelFunctionType': 'Sentinel1Calibration',
                         'polarization': pol[key],
                         'suffix': pol[key],
                         },
                 })

        self._create_bands(metaDict)

        # Add incidence angle as band
        name = 'incidence_angle'
        bandNumberDict[name] = bnmax+1
        bnmax = bandNumberDict[name]
        src = {'SourceFilename': self.bandVRTs['incVRT'].fileName,
               'SourceBand': 1}
        dst = {'wkv': 'angle_of_incidence',
               'name': name}
        self._create_band(src, dst)
        self.dataset.FlushCache()

        # Add elevation angle as band
        name = 'elevation_angle'
        bandNumberDict[name] = bnmax+1
        bnmax = bandNumberDict[name]
        src = {'SourceFilename': self.bandVRTs['eleVRT'].fileName,
               'SourceBand': 1}
        dst = {'wkv': 'angle_of_elevation',
               'name': name}
        self._create_band(src, dst)
        self.dataset.FlushCache()

        # Add sigma0_VV
        # (synthesized from the HH band via a pixel function when the
        # product has HH but no VV polarization)
        pp = [pol[key] for key in pol]
        if 'VV' not in pp and 'HH' in pp:
            name = 'sigma0_VV'
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            src = [{'SourceFilename': self.fileName,
                    'SourceBand': bandNumberDict['DN_HH'],
                    },
                   {'SourceFilename': (self.bandVRTs['LUT_noise_VRT_HH'].
                                       fileName),
                    'SourceBand': 1
                    },
                   {'SourceFilename': (self.bandVRTs['LUT_sigmaNought_VRT_HH'].
                                       fileName),
                    'SourceBand': 1,
                    },
                   {'SourceFilename': self.bandVRTs['incVRT'].fileName,
                    'SourceBand': 1}
                   ]
            dst = {'wkv': 'surface_backwards_scattering_coefficient_of_radar_wave',
                   'PixelFunctionType': 'Sentinel1Sigma0HHToSigma0VV',
                   'polarization': 'VV',
                   'suffix': 'VV'}
            self._create_band(src, dst)
            self.dataset.FlushCache()

        # set time as acquisition start time
        n = Node.create(self.manifestXML)
        meta = n.node('metadataSection')
        for nn in meta.children:
            if nn.getAttribute('ID') == u'acquisitionPeriod':
                # set valid time
                self.dataset.SetMetadataItem(
                    'time_coverage_start',
                    parse((nn.node('metadataWrap').
                           node('xmlData').
                           node('safe:acquisitionPeriod')['safe:startTime'])
                          ).isoformat())
                self.dataset.SetMetadataItem(
                    'time_coverage_end',
                    parse((nn.node('metadataWrap').
                           node('xmlData').
                           node('safe:acquisitionPeriod')['safe:stopTime'])
                          ).isoformat())

        # Get dictionary describing the instrument and platform according to
        # the GCMD keywords
        mm = pti.get_gcmd_instrument('sar')
        ee = pti.get_gcmd_platform(missionName)

        # TODO: Validate that the found instrument and platform are indeed what we
        # want....

        self.dataset.SetMetadataItem('instrument', json.dumps(mm))
        self.dataset.SetMetadataItem('platform', json.dumps(ee))

Example 32

Project: nansat
Source File: mapper_sentinel1_l1.py
View license
    def __init__(self, fileName, gdalDataset, gdalMetadata,
                 manifestonly=False, **kwargs):
        """Create a VRT for a Sentinel-1A/1B level-1 (SAFE) product.

        Parameters
        ----------
        fileName : str
            Path to the unpacked SAFE directory, or to the zipped product.
        gdalDataset, gdalMetadata :
            Part of the common mapper interface; not used directly here.
        manifestonly : bool, optional
            If True, initialize from manifest.safe and one annotation file
            only and return without creating any bands (fast constructor).

        Raises
        ------
        WrongMapperError
            If any expected SAFE component files are missing, the base file
            name does not start with 'S1A'/'S1B', no measurement dataset
            opens, or the TIFF metadata does not identify Sentinel-1 L1.

        NOTE(review): the dict.keys()[0]-style indexing used below only
        works on Python 2 (dict views are not indexable on Python 3) -
        confirm the target interpreter before porting.
        """
        # Collect the component file lists either from inside the zip
        # archive (exposed to GDAL via the /vsizip/ virtual file system)
        # or by globbing an unpacked SAFE directory.
        if zipfile.is_zipfile(fileName):
            zz = zipfile.PyZipFile(fileName)
            # Assuming the file names are consistent, the polarization
            # dependent data should be sorted equally such that we can use the
            # same indices consistently for all the following lists
            # THIS IS NOT THE CASE...
            mdsFiles = ['/vsizip/%s/%s' % (fileName, fn)
                        for fn in zz.namelist() if 'measurement/s1' in fn]
            calFiles = ['/vsizip/%s/%s' % (fileName, fn)
                        for fn in zz.namelist()
                        if 'annotation/calibration/calibration-s1' in fn]
            noiseFiles = ['/vsizip/%s/%s' % (fileName, fn)
                          for fn in zz.namelist()
                          if 'annotation/calibration/noise-s1' in fn]
            annotationFiles = ['/vsizip/%s/%s' % (fileName, fn)
                               for fn in zz.namelist()
                               if 'annotation/s1' in fn]
            manifestFile = ['/vsizip/%s/%s' % (fileName, fn)
                            for fn in zz.namelist()
                            if 'manifest.safe' in fn]
            zz.close()
        else:
            mdsFiles = glob.glob('%s/measurement/s1*' % fileName)
            calFiles = glob.glob('%s/annotation/calibration/calibration-s1*'
                                 % fileName)
            noiseFiles = glob.glob('%s/annotation/calibration/noise-s1*'
                                   % fileName)
            annotationFiles = glob.glob('%s/annotation/s1*'
                                        % fileName)
            manifestFile = glob.glob('%s/manifest.safe' % fileName)

        # All five component types must be present for a valid product.
        if (not mdsFiles or not calFiles or not noiseFiles or
                not annotationFiles or not manifestFile):
            raise WrongMapperError

        # Key each file by a token of its '-'-separated base name
        # (presumably the polarization, e.g. 'hh'/'vv'; the calibration and
        # noise names carry an extra leading word, hence index 4 instead of
        # 3) - TODO confirm against the SAFE file-naming convention.
        mdsDict = {}
        for ff in mdsFiles:
            mdsDict[
                os.path.splitext(os.path.basename(ff))[0].split('-')[3]] = ff

        self.calXMLDict = {}
        for ff in calFiles:
            self.calXMLDict[
                os.path.splitext(
                os.path.basename(ff))[0].split('-')[4]] = self.read_xml(ff)

        self.noiseXMLDict = {}
        for ff in noiseFiles:
            self.noiseXMLDict[
                os.path.splitext(
                os.path.basename(ff))[0].split('-')[4]] = self.read_xml(ff)

        self.annotationXMLDict = {}
        for ff in annotationFiles:
            self.annotationXMLDict[
                os.path.splitext(
                os.path.basename(ff))[0].split('-')[3]] = self.read_xml(ff)

        self.manifestXML = self.read_xml(manifestFile[0])

        # Only Sentinel-1A/1B products are handled by this mapper.
        if not os.path.split(fileName)[1][:3] in ['S1A', 'S1B']:
            raise WrongMapperError('Not Sentinel 1A or 1B')

        missionName = {'S1A': 'SENTINEL-1A', 'S1B': 'SENTINEL-1B'}[
            os.path.split(fileName)[1][:3]]

        # very fast constructor without any bands
        if manifestonly:
            self.init_from_manifest_only(self.manifestXML,
                                         self.annotationXMLDict[
                                         self.annotationXMLDict.keys()[0]],
                                         missionName)
            return

        gdalDatasets = {}
        for key in mdsDict.keys():
            # Open data files
            gdalDatasets[key] = gdal.Open(mdsDict[key])

        if not gdalDatasets:
            raise WrongMapperError('No Sentinel-1 datasets found')

        # Check metadata to confirm it is Sentinel-1 L1
        metadata = gdalDatasets[mdsDict.keys()[0]].GetMetadata()

        if not 'TIFFTAG_IMAGEDESCRIPTION' in metadata.keys():
            raise WrongMapperError
        # NOTE(review): with 'and', this only rejects files whose
        # description mentions neither 'Sentinel-1' nor 'L1'; an 'or' may
        # have been intended - verify against real products.
        if (not 'Sentinel-1' in metadata['TIFFTAG_IMAGEDESCRIPTION']
                and not 'L1' in metadata['TIFFTAG_IMAGEDESCRIPTION']):
            raise WrongMapperError

        warnings.warn('Sentinel-1 level-1 mapper is not yet adapted to '
                      'complex data. In addition, the band names should be '
                      'updated for multi-swath data - '
                      'and there might be other issues.')

        # create empty VRT dataset with geolocation only
        # (initialized from an arbitrary one of the measurement datasets)
        for key in gdalDatasets:
            VRT.__init__(self, gdalDatasets[key])
            break

        # Read annotation, noise and calibration xml-files
        pol = {}
        it = 0
        for key in self.annotationXMLDict:
            xml = Node.create(self.annotationXMLDict[key])
            pol[key] = (xml.node('product').
                        node('adsHeader')['polarisation'].upper())
            it += 1
            if it == 1:
                # Get incidence angle
                # (the geolocation grid is read from the first annotation
                # file only; assumed identical across polarizations -
                # TODO confirm)
                pi = xml.node('generalAnnotation').node('productInformation')

                self.dataset.SetMetadataItem('ORBIT_DIRECTION',
                                              str(pi['pass']))
                (X, Y, lon, lat, inc, ele, numberOfSamples,
                numberOfLines) = self.read_geolocation_lut(
                                                self.annotationXMLDict[key])

                X = np.unique(X)
                Y = np.unique(Y)

                lon = np.array(lon).reshape(len(Y), len(X))
                lat = np.array(lat).reshape(len(Y), len(X))
                inc = np.array(inc).reshape(len(Y), len(X))
                ele = np.array(ele).reshape(len(Y), len(X))

                # Resample incidence/elevation grids to full raster size
                incVRT = VRT(array=inc, lat=lat, lon=lon)
                eleVRT = VRT(array=ele, lat=lat, lon=lon)
                incVRT = incVRT.get_resized_vrt(self.dataset.RasterXSize,
                                                self.dataset.RasterYSize,
                                                eResampleAlg=2)
                eleVRT = eleVRT.get_resized_vrt(self.dataset.RasterXSize,
                                                self.dataset.RasterYSize,
                                                eResampleAlg=2)
                self.bandVRTs['incVRT'] = incVRT
                self.bandVRTs['eleVRT'] = eleVRT

        # Build calibration look-up-table VRTs per polarization
        for key in self.calXMLDict:
            calibration_LUT_VRTs, longitude, latitude = (
                self.get_LUT_VRTs(self.calXMLDict[key],
                                  'calibrationVectorList',
                                  ['sigmaNought', 'betaNought',
                                   'gamma', 'dn']
                                  ))
            self.bandVRTs['LUT_sigmaNought_VRT_'+pol[key]] = (
                calibration_LUT_VRTs['sigmaNought'].
                get_resized_vrt(self.dataset.RasterXSize,
                                self.dataset.RasterYSize,
                                eResampleAlg=1))
            self.bandVRTs['LUT_betaNought_VRT_'+pol[key]] = (
                calibration_LUT_VRTs['betaNought'].
                get_resized_vrt(self.dataset.RasterXSize,
                                self.dataset.RasterYSize,
                                eResampleAlg=1))
            # NOTE(review): gamma/dn keys are not polarization-suffixed, so
            # each loop iteration overwrites the previous one - verify
            # whether that is intended.
            self.bandVRTs['LUT_gamma_VRT'] = calibration_LUT_VRTs['gamma']
            self.bandVRTs['LUT_dn_VRT'] = calibration_LUT_VRTs['dn']

        # Noise look-up-table VRTs per polarization
        for key in self.noiseXMLDict:
            noise_LUT_VRT = self.get_LUT_VRTs(self.noiseXMLDict[key],
                                              'noiseVectorList',
                                              ['noiseLut'])[0]
            self.bandVRTs['LUT_noise_VRT_'+pol[key]] = (
                noise_LUT_VRT['noiseLut'].get_resized_vrt(
                    self.dataset.RasterXSize,
                    self.dataset.RasterYSize,
                    eResampleAlg=1))

        # Raw digital-number (DN) bands, one per polarization
        metaDict = []
        bandNumberDict = {}
        bnmax = 0
        for key in gdalDatasets.keys():
            dsPath, dsName = os.path.split(mdsDict[key])
            name = 'DN_%s' % pol[key]
            # A dictionary of band numbers is needed for the pixel function
            # bands further down. This is not the best solution. It would be
            # better to have a function in VRT that returns the number given a
            # band name. This function exists in Nansat but could perhaps be
            # moved to VRT? The existing nansat function could just call the
            # VRT one...
            bandNumberDict[name] = bnmax + 1
            bnmax = bandNumberDict[name]
            band = gdalDatasets[key].GetRasterBand(1)
            dtype = band.DataType
            metaDict.append({
                'src': {
                    'SourceFilename': mdsDict[key],
                    'SourceBand': 1,
                    'DataType': dtype,
                },
                'dst': {
                    'name': name,
                    #'SourceTransferType': gdal.GetDataTypeName(dtype),
                    #'dataType': 6,
                },
            })
        # add bands with metadata and corresponding values to the empty VRT
        self._create_bands(metaDict)

        '''
        Calibration should be performed as

        s0 = DN^2/sigmaNought^2,

        where sigmaNought is from e.g.
        annotation/calibration/calibration-s1a-iw-grd-hh-20140811t151231-20140811t151301-001894-001cc7-001.xml,
        and DN is the Digital Numbers in the tiff files.

        Also the noise should be subtracted.

        See
        https://sentinel.esa.int/web/sentinel/sentinel-1-sar-wiki/-/wiki/Sentinel%20One/Application+of+Radiometric+Calibration+LUT
        '''
        # Get look direction
        # (longitude/latitude here are left over from the last iteration of
        # the calibration-LUT loop above)
        sat_heading = initial_bearing(longitude[:-1, :],
                                      latitude[:-1, :],
                                      longitude[1:, :],
                                      latitude[1:, :])
        look_direction = scipy.ndimage.interpolation.zoom(
            np.mod(sat_heading + 90, 360),
            (np.shape(longitude)[0] / (np.shape(longitude)[0]-1.), 1))

        # Decompose, to avoid interpolation errors around 0 <-> 360
        look_direction_u = np.sin(np.deg2rad(look_direction))
        look_direction_v = np.cos(np.deg2rad(look_direction))
        look_u_VRT = VRT(array=look_direction_u,
                         lat=latitude, lon=longitude)
        look_v_VRT = VRT(array=look_direction_v,
                         lat=latitude, lon=longitude)
        lookVRT = VRT(lat=latitude, lon=longitude)
        lookVRT._create_band([{'SourceFilename': look_u_VRT.fileName,
                               'SourceBand': 1},
                              {'SourceFilename': look_v_VRT.fileName,
                               'SourceBand': 1}],
                             {'PixelFunctionType': 'UVToDirectionTo'}
                             )

        # Blow up to full size
        lookVRT = lookVRT.get_resized_vrt(self.dataset.RasterXSize,
                                          self.dataset.RasterYSize,
                                          eResampleAlg=1)

        # Store VRTs so that they are accessible later
        self.bandVRTs['look_u_VRT'] = look_u_VRT
        self.bandVRTs['look_v_VRT'] = look_v_VRT
        self.bandVRTs['lookVRT'] = lookVRT

        metaDict = []
        # Add bands to full size VRT
        for key in pol:
            name = 'LUT_sigmaNought_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append(
                {'src': {'SourceFilename':
                         (self.bandVRTs['LUT_sigmaNought_VRT_' +
                          pol[key]].fileName),
                         'SourceBand': 1
                         },
                 'dst': {'name': name
                         }
                 })
            name = 'LUT_noise_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append({
                'src': {
                    'SourceFilename': self.bandVRTs['LUT_noise_VRT_' +
                                                   pol[key]].fileName,
                    'SourceBand': 1
                },
                'dst': {
                    'name': name
                }
            })

        name = 'look_direction'
        bandNumberDict[name] = bnmax+1
        bnmax = bandNumberDict[name]
        metaDict.append({
            'src': {
                'SourceFilename': self.bandVRTs['lookVRT'].fileName,
                'SourceBand': 1
            },
            'dst': {
                'wkv': 'sensor_azimuth_angle',
                'name': name
            }
        })

        # Calibrated sigma0/beta0 pixel-function bands per polarization,
        # combining the DN band with the matching calibration LUT
        for key in gdalDatasets.keys():
            dsPath, dsName = os.path.split(mdsDict[key])
            name = 'sigma0_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append(
                {'src': [{'SourceFilename': self.fileName,
                          'SourceBand': bandNumberDict['DN_%s' % pol[key]],
                          },
                         {'SourceFilename':
                          (self.bandVRTs['LUT_sigmaNought_VRT_%s'
                           % pol[key]].fileName),
                          'SourceBand': 1
                          }
                         ],
                 'dst': {'wkv': 'surface_backwards_scattering_coefficient_of_radar_wave',
                         'PixelFunctionType': 'Sentinel1Calibration',
                         'polarization': pol[key],
                         'suffix': pol[key],
                         },
                 })
            name = 'beta0_%s' % pol[key]
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            metaDict.append(
                {'src': [{'SourceFilename': self.fileName,
                          'SourceBand': bandNumberDict['DN_%s' % pol[key]]
                          },
                         {'SourceFilename':
                          (self.bandVRTs['LUT_betaNought_VRT_%s'
                           % pol[key]].fileName),
                          'SourceBand': 1
                          }
                         ],
                 'dst': {'wkv': 'surface_backwards_brightness_coefficient_of_radar_wave',
                         'PixelFunctionType': 'Sentinel1Calibration',
                         'polarization': pol[key],
                         'suffix': pol[key],
                         },
                 })

        self._create_bands(metaDict)

        # Add incidence angle as band
        name = 'incidence_angle'
        bandNumberDict[name] = bnmax+1
        bnmax = bandNumberDict[name]
        src = {'SourceFilename': self.bandVRTs['incVRT'].fileName,
               'SourceBand': 1}
        dst = {'wkv': 'angle_of_incidence',
               'name': name}
        self._create_band(src, dst)
        self.dataset.FlushCache()

        # Add elevation angle as band
        name = 'elevation_angle'
        bandNumberDict[name] = bnmax+1
        bnmax = bandNumberDict[name]
        src = {'SourceFilename': self.bandVRTs['eleVRT'].fileName,
               'SourceBand': 1}
        dst = {'wkv': 'angle_of_elevation',
               'name': name}
        self._create_band(src, dst)
        self.dataset.FlushCache()

        # Add sigma0_VV
        # (synthesized from the HH band via a pixel function when the
        # product has HH but no VV polarization)
        pp = [pol[key] for key in pol]
        if 'VV' not in pp and 'HH' in pp:
            name = 'sigma0_VV'
            bandNumberDict[name] = bnmax+1
            bnmax = bandNumberDict[name]
            src = [{'SourceFilename': self.fileName,
                    'SourceBand': bandNumberDict['DN_HH'],
                    },
                   {'SourceFilename': (self.bandVRTs['LUT_noise_VRT_HH'].
                                       fileName),
                    'SourceBand': 1
                    },
                   {'SourceFilename': (self.bandVRTs['LUT_sigmaNought_VRT_HH'].
                                       fileName),
                    'SourceBand': 1,
                    },
                   {'SourceFilename': self.bandVRTs['incVRT'].fileName,
                    'SourceBand': 1}
                   ]
            dst = {'wkv': 'surface_backwards_scattering_coefficient_of_radar_wave',
                   'PixelFunctionType': 'Sentinel1Sigma0HHToSigma0VV',
                   'polarization': 'VV',
                   'suffix': 'VV'}
            self._create_band(src, dst)
            self.dataset.FlushCache()

        # set time as acquisition start time
        n = Node.create(self.manifestXML)
        meta = n.node('metadataSection')
        for nn in meta.children:
            if nn.getAttribute('ID') == u'acquisitionPeriod':
                # set valid time
                self.dataset.SetMetadataItem(
                    'time_coverage_start',
                    parse((nn.node('metadataWrap').
                           node('xmlData').
                           node('safe:acquisitionPeriod')['safe:startTime'])
                          ).isoformat())
                self.dataset.SetMetadataItem(
                    'time_coverage_end',
                    parse((nn.node('metadataWrap').
                           node('xmlData').
                           node('safe:acquisitionPeriod')['safe:stopTime'])
                          ).isoformat())

        # Get dictionary describing the instrument and platform according to
        # the GCMD keywords
        mm = pti.get_gcmd_instrument('sar')
        ee = pti.get_gcmd_platform(missionName)

        # TODO: Validate that the found instrument and platform are indeed what we
        # want....

        self.dataset.SetMetadataItem('instrument', json.dumps(mm))
        self.dataset.SetMetadataItem('platform', json.dumps(ee))

Example 33

View license
def main():
    results_folder = param_default.results_folder
    methods_to_display = param_default.methods_to_display
    noise_std_to_display = param_default.noise_std_to_display
    tracts_std_to_display = param_default.tracts_std_to_display
    csf_value_to_display = param_default.csf_value_to_display
    nb_RL_labels = param_default.nb_RL_labels

    # Parameters for debug mode
    if param_default.debug:
        print '\n*** WARNING: DEBUG MODE ON ***\n'
        results_folder = "/Users/slevy_local/spinalcordtoolbox/dev/atlas/validate_atlas/results_20150210_200iter"#"C:/cygwin64/home/Simon_2/data_methods_comparison"
        path_sct = '/Users/slevy_local/spinalcordtoolbox' #'C:/cygwin64/home/Simon_2/spinalcordtoolbox'
    else:
        status, path_sct = commands.getstatusoutput('echo $SCT_DIR')

        # Check input parameters
        try:
            opts, args = getopt.getopt(sys.argv[1:], 'i:m:')  # define flags
        except getopt.GetoptError as err:  # check if the arguments are defined
            print str(err)  # error
            # usage() # display usage
        # if not opts:
        #     print 'Please enter the path to the result folder. Exit program.'
        #     sys.exit(1)
        #     # usage()
        for opt, arg in opts:  # explore flags
            if opt in '-i':
                results_folder = arg
            if opt in '-m':
                methods_to_display = arg

    # Append path that contains scripts, to be able to load modules
    sys.path.append(path_sct + '/scripts')
    import sct_utils as sct
    import isct_get_fractional_volume

    sct.printv("Working directory: " + os.getcwd())

    results_folder_noise = results_folder + '/noise'
    results_folder_tracts = results_folder + '/tracts'
    results_folder_csf = results_folder + '/csf'

    sct.printv('\n\nData will be extracted from folder ' + results_folder_noise + ' , ' + results_folder_tracts + ' and ' + results_folder_csf + '.', 'warning')
    sct.printv('\t\tCheck existence...')
    sct.check_folder_exist(results_folder_noise)
    sct.check_folder_exist(results_folder_tracts)
    sct.check_folder_exist(results_folder_csf)

    # Extract methods to display
    methods_to_display = methods_to_display.strip().split(',')

    # Extract file names of the results files
    fname_results_noise = glob.glob(results_folder_noise + '/*.txt')
    fname_results_tracts = glob.glob(results_folder_tracts + '/*.txt')
    fname_results_csf = glob.glob(results_folder_csf + '/*.txt')
    fname_results = fname_results_noise + fname_results_tracts + fname_results_csf
    # Remove doublons (due to the two folders)
    # for i_fname in range(0, len(fname_results)):
    #     for j_fname in range(0, len(fname_results)):
    #         if (i_fname != j_fname) & (os.path.basename(fname_results[i_fname]) == os.path.basename(fname_results[j_fname])):
    #             fname_results.remove(fname_results[j_fname])
    file_results = []
    for fname in fname_results:
        file_results.append(os.path.basename(fname))
    for file in file_results:
        if file_results.count(file) > 1:
            ind = file_results.index(file)
            fname_results.remove(fname_results[ind])
            file_results.remove(file)

    nb_results_file = len(fname_results)

    # 1st dim: SNR, 2nd dim: tract std, 3rd dim: mean abs error, 4th dim: std abs error
    # result_array = numpy.empty((nb_results_file, nb_results_file, 3), dtype=object)
    # SNR
    snr = numpy.zeros((nb_results_file))
    # Tracts std
    tracts_std = numpy.zeros((nb_results_file))
    # CSF value
    csf_values = numpy.zeros((nb_results_file))
    # methods' name
    methods_name = []  #numpy.empty((nb_results_file, nb_method), dtype=object)
    # labels
    error_per_label = []
    std_per_label = []
    labels_id = []
    # median
    median_results = numpy.zeros((nb_results_file, 5))
    # median std across bootstraps
    median_std = numpy.zeros((nb_results_file, 5))
    # min
    min_results = numpy.zeros((nb_results_file, 5))
    # max
    max_results = numpy.zeros((nb_results_file, 5))

    #
    for i_file in range(0, nb_results_file):

        # Open file
        f = open(fname_results[i_file])  # open file
        # Extract all lines in .txt file
        lines = [line for line in f.readlines() if line.strip()]

        # extract SNR
        # find all index of lines containing the string "sigma noise"
        ind_line_noise = [lines.index(line_noise) for line_noise in lines if "sigma noise" in line_noise]
        if len(ind_line_noise) != 1:
            sct.printv("ERROR: number of lines including \"sigma noise\" is different from 1. Exit program.", 'error')
            sys.exit(1)
        else:
            # result_array[:, i_file, i_file] = int(''.join(c for c in lines[ind_line_noise[0]] if c.isdigit()))
            snr[i_file] = int(''.join(c for c in lines[ind_line_noise[0]] if c.isdigit()))

        # extract tract std
        ind_line_tract_std = [lines.index(line_tract_std) for line_tract_std in lines if
                              "range tracts" in line_tract_std]
        if len(ind_line_tract_std) != 1:
            sct.printv("ERROR: number of lines including \"range tracts\" is different from 1. Exit program.", 'error')
            sys.exit(1)
        else:
            # result_array[i_file, i_file, :] = int(''.join(c for c in lines[ind_line_tract_std[0]].split(':')[1] if c.isdigit()))
            # regex = re.compile(''('(.*)':)  # re.I permet d'ignorer la case (majuscule/minuscule)
            # match = regex.search(lines[ind_line_tract_std[0]])
            # result_array[:, i_file, :, :] = match.group(1)  # le groupe 1 correspond a '.*'
            tracts_std[i_file] = int(''.join(c for c in lines[ind_line_tract_std[0]].split(':')[1] if c.isdigit()))

        # extract CSF value
        ind_line_csf_value = [lines.index(line_csf_value) for line_csf_value in lines if
                              "# value CSF" in line_csf_value]
        if len(ind_line_csf_value) != 1:
            sct.printv("ERROR: number of lines including \"range tracts\" is different from 1. Exit program.", 'error')
            sys.exit(1)
        else:
            # result_array[i_file, i_file, :] = int(''.join(c for c in lines[ind_line_tract_std[0]].split(':')[1] if c.isdigit()))
            # regex = re.compile(''('(.*)':)  # re.I permet d'ignorer la case (majuscule/minuscule)
            # match = regex.search(lines[ind_line_tract_std[0]])
            # result_array[:, i_file, :, :] = match.group(1)  # le groupe 1 correspond a '.*'
            csf_values[i_file] = int(''.join(c for c in lines[ind_line_csf_value[0]].split(':')[1] if c.isdigit()))


        # extract method name
        ind_line_label = [lines.index(line_label) for line_label in lines if "Label" in line_label]
        if len(ind_line_label) != 1:
            sct.printv("ERROR: number of lines including \"Label\" is different from 1. Exit program.", 'error')
            sys.exit(1)
        else:
            # methods_name[i_file, :] = numpy.array(lines[ind_line_label[0]].strip().split(',')[1:])
            methods_name.append(lines[ind_line_label[0]].strip().replace(' ', '').split(',')[1:])

        # extract median
        ind_line_median = [lines.index(line_median) for line_median in lines if "median" in line_median]
        if len(ind_line_median) != 1:
            sct.printv("WARNING: number of lines including \"median\" is different from 1. Exit program.", 'warning')
            # sys.exit(1)
        else:
            median = lines[ind_line_median[0]].strip().split(',')[1:]
            # result_array[i_file, i_file, 0] = [float(m.split('(')[0]) for m in median]
            median_results[i_file, :] = numpy.array([float(m.split('(')[0]) for m in median])
            median_std[i_file, :] = numpy.array([float(m.split('(')[1][:-1]) for m in median])

        # extract min
        ind_line_min = [lines.index(line_min) for line_min in lines if "min," in line_min]
        if len(ind_line_min) != 1:
            sct.printv("WARNING: number of lines including \"min\" is different from 1. Exit program.", 'warning')
            # sys.exit(1)
        else:
            min = lines[ind_line_min[0]].strip().split(',')[1:]
            # result_array[i_file, i_file, 1] = [float(m.split('(')[0]) for m in min]
            min_results[i_file, :] = numpy.array([float(m.split('(')[0]) for m in min])

        # extract max
        ind_line_max = [lines.index(line_max) for line_max in lines if "max" in line_max]
        if len(ind_line_max) != 1:
            sct.printv("WARNING: number of lines including \"max\" is different from 1. Exit program.", 'warning')
            # sys.exit(1)
        else:
            max = lines[ind_line_max[0]].strip().split(',')[1:]
            # result_array[i_file, i_file, 1] = [float(m.split('(')[0]) for m in max]
            max_results[i_file, :] = numpy.array([float(m.split('(')[0]) for m in max])

        # extract error for each label
        error_per_label_for_file_i = []
        std_per_label_for_file_i = []
        labels_id_for_file_i = []
        # Due to 2 different kind of file structure, the number of the last label line must be adapted
        if not ind_line_median:
            ind_line_median = [len(lines) + 1]
        for i_line in range(ind_line_label[0] + 1, ind_line_median[0] - 1):
            line_label_i = lines[i_line].strip().split(',')
            error_per_label_for_file_i.append([float(error.strip().split('(')[0]) for error in line_label_i[1:]])
            std_per_label_for_file_i.append([float(error.strip().split('(')[1][:-1]) for error in line_label_i[1:]])
            labels_id_for_file_i.append(int(line_label_i[0]))
        error_per_label.append(error_per_label_for_file_i)
        std_per_label.append(std_per_label_for_file_i)
        labels_id.append(labels_id_for_file_i)

        # close file
        f.close()

    # check if all the files in the result folder were generated with the same number of methods
    if not all(x == methods_name[0] for x in methods_name):
        sct.printv(
            'ERROR: All the generated files in folder ' + results_folder + ' have not been generated with the same number of methods. Exit program.',
            'error')
        sys.exit(1)
    # check if all the files in the result folder were generated with the same labels
    if not all(x == labels_id[0] for x in labels_id):
        sct.printv(
            'ERROR: All the generated files in folder ' + results_folder + ' have not been generated with the same labels. Exit program.',
            'error')
        sys.exit(1)

    # convert the list "error_per_label" into a numpy array to ease further manipulations
    error_per_label = numpy.array(error_per_label)
    std_per_label = numpy.array(std_per_label)
    # compute different stats
    abs_error_per_labels = numpy.absolute(error_per_label)
    max_abs_error_per_meth = numpy.amax(abs_error_per_labels, axis=1)
    min_abs_error_per_meth = numpy.amin(abs_error_per_labels, axis=1)
    mean_abs_error_per_meth = numpy.mean(abs_error_per_labels, axis=1)
    std_abs_error_per_meth = numpy.std(abs_error_per_labels, axis=1)

    # average error and std across sides
    meanRL_abs_error_per_labels = numpy.zeros((error_per_label.shape[0], nb_RL_labels, error_per_label.shape[2]))
    meanRL_std_abs_error_per_labels = numpy.zeros((std_per_label.shape[0], nb_RL_labels, std_per_label.shape[2]))
    for i_file in range(0, nb_results_file):
        for i_meth in range(0, len(methods_name[i_file])):
            for i_label in range(0, nb_RL_labels):
                # find indexes of corresponding labels
                ind_ID_first_side = labels_id[i_file].index(i_label)
                ind_ID_other_side = labels_id[i_file].index(i_label + nb_RL_labels)
                # compute mean across 2 sides
                meanRL_abs_error_per_labels[i_file, i_label, i_meth] = float(error_per_label[i_file, ind_ID_first_side, i_meth] + error_per_label[i_file, ind_ID_other_side, i_meth]) / 2
                meanRL_std_abs_error_per_labels[i_file, i_label, i_meth] = float(std_per_label[i_file, ind_ID_first_side, i_meth] + std_per_label[i_file, ind_ID_other_side, i_meth]) / 2

    nb_method = len(methods_to_display)

    sct.printv('Noise std of the ' + str(nb_results_file) + ' generated files:')
    print snr
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Tracts std of the ' + str(nb_results_file) + ' generated files:')
    print tracts_std
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('CSF value of the ' + str(nb_results_file) + ' generated files:')
    print csf_values
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Methods used to generate results for the ' + str(nb_results_file) + ' generated files:')
    print methods_name
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Median obtained with each method (in colons) for the ' + str(nb_results_file) + ' generated files (in lines):')
    print median_results
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Minimum obtained with each method (in colons) for the ' + str(
        nb_results_file) + ' generated files (in lines):')
    print min_results
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Maximum obtained with each method (in colons) for the ' + str(
        nb_results_file) + ' generated files (in lines):')
    print max_results
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Labels\' ID (in colons) for the ' + str(nb_results_file) + ' generated files (in lines):')
    print labels_id
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Errors obtained with each method (in colons) for the ' + str(nb_results_file) + ' generated files (in lines):')
    print error_per_label
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Mean errors across both sides obtained with each method (in colons) for the ' + str(nb_results_file) + ' generated files (in lines):')
    print meanRL_abs_error_per_labels


    # Compute fractional volume per label
    labels_id_FV, labels_name_FV, fract_vol_per_lab, labels_name_FV_RL_gathered, fract_vol_per_lab_RL_gathered = isct_get_fractional_volume.get_fractional_volume_per_label('./cropped_atlas/', 'info_label.txt')
    # # Get the number of voxels including at least one tract
    # nb_voxels_in_WM = isct_get_fractional_volume.get_nb_voxel_in_WM('./cropped_atlas/', 'info_label.txt')
    # normalize by the number of voxels in WM and express it as a percentage
    fract_vol_norm = numpy.divide(fract_vol_per_lab_RL_gathered, numpy.sum(fract_vol_per_lab_RL_gathered)/100)

    # NOT NECESSARY NOW WE AVERAGE ACROSS BOTH SIDES (which orders the labels)
    # # check if the order of the labels returned by the function computing the fractional volumes is the same (which should be the case)
    # if labels_id_FV != labels_id[0]:
    #     sct.printv('\n\nERROR: the labels IDs returned by the function \'i_sct_get_fractional_volume\' are different from the labels IDs of the results files\n\n', 'error')

    # # Remove labels #30 and #31
    # labels_id_FV_29, labels_name_FV_29, fract_vol_per_lab_29 = labels_id_FV[:-2], labels_name_FV[:-2], fract_vol_per_lab[:-2]

    # indexes of labels sort according to the fractional volume
    ind_labels_sort = numpy.argsort(fract_vol_norm)

    # Find index of the file generated with noise variance = 10 and tracts std = 10
    ind_file_to_display = numpy.where((snr == noise_std_to_display) & (tracts_std == tracts_std_to_display) & (csf_values == csf_value_to_display))

    # sort arrays in this order
    meanRL_abs_error_per_labels_sort = meanRL_abs_error_per_labels[ind_file_to_display[0], ind_labels_sort, :]
    meanRL_std_abs_error_per_labels_sort = meanRL_std_abs_error_per_labels[ind_file_to_display[0], ind_labels_sort, :]
    labels_name_sort = numpy.array(labels_name_FV_RL_gathered)[ind_labels_sort]

    # *********************************************** START PLOTTING HERE **********************************************

    # stringColor = Color()
    matplotlib.rcParams.update({'font.size': 50, 'font.family': 'trebuchet'})
    # plt.rcParams['xtick.major.pad'] = '11'
    plt.rcParams['ytick.major.pad'] = '15'

    fig = plt.figure(figsize=(60, 37))
    width = 1.0 / (nb_method + 1)
    ind_fig = numpy.arange(len(labels_name_sort)) * (1.0 + width)
    plt.ylabel('Absolute error (%)\n', fontsize=65)
    plt.xlabel('Fractional volume (% of the total number of voxels in WM)', fontsize=65)
    plt.title('Absolute error per tract as a function of their fractional volume\n\n', fontsize=30)
    plt.suptitle('(Noise std='+str(snr[ind_file_to_display[0]][0])+', Tracts std='+str(tracts_std[ind_file_to_display[0]][0])+', CSF value='+str(csf_values[ind_file_to_display[0]][0])+')', fontsize=30)

    # colors = plt.get_cmap('jet')(np.linspace(0, 1.0, nb_method))
    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    markers = ['o', 's', '^', 'D']
    errorbar_plots = []
    for meth, color, marker in zip(methods_to_display, colors, markers):
        i_meth = methods_name[0].index(meth)
        i_meth_to_display = methods_to_display.index(meth)

        plot_i = plt.errorbar(ind_fig + i_meth_to_display * width, meanRL_abs_error_per_labels_sort[:, i_meth], meanRL_std_abs_error_per_labels_sort[:, i_meth], color=color, marker=marker, markersize=35, lw=7, elinewidth=1, capthick=5, capsize=10)
        # plot_i = plt.boxplot(numpy.transpose(abs_error_per_labels[ind_files_csf_sort, :, i_meth]), positions=ind_fig + i_meth_to_display * width + (float(i_meth_to_display) * width) / (nb_method + 1), widths=width, boxprops=boxprops, medianprops=medianprops, flierprops=flierprops, whiskerprops=whiskerprops, capprops=capprops)
        errorbar_plots.append(plot_i)

    # add alternated vertical background colored bars
    for i_xtick in range(0, len(ind_fig), 2):
        plt.axvspan(ind_fig[i_xtick] - width - width / 2, ind_fig[i_xtick] + (nb_method + 1) * width - width / 2, facecolor='grey', alpha=0.1)

    # concatenate value of fractional volume to labels'name
    xtick_labels = [labels_name_sort[i_lab]+'\n'+r'$\bf{['+str(round(fract_vol_norm[ind_labels_sort][i_lab], 2))+']}$' for i_lab in range(0, len(labels_name_sort))]
    ind_lemniscus = numpy.where(labels_name_sort == 'spinal lemniscus (spinothalamic and spinoreticular tracts)')[0][0]
    xtick_labels[ind_lemniscus] = 'spinal lemniscus\n'+r'$\bf{['+str(round(fract_vol_norm[ind_labels_sort][ind_lemniscus], 2))+']}$'

    # plt.legend(box_plots, methods_to_display, bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
    plt.legend(errorbar_plots, methods_to_display, loc=1, fontsize=50, numpoints=1)
    plt.xticks(ind_fig + (numpy.floor(float(nb_method-1)/2)) * width, xtick_labels, fontsize=45)
    # Tweak spacing to prevent clipping of tick-labels
    plt.subplots_adjust(bottom=0, top=0.95, right=0.96)
    plt.gca().set_xlim([-width, numpy.max(ind_fig) + (nb_method + 0.5) * width])
    plt.gca().set_ylim([0, 17])
    plt.gca().yaxis.set_major_locator(plt.MultipleLocator(1.0))
    plt.gca().yaxis.set_minor_locator(plt.MultipleLocator(0.5))
    plt.grid(b=True, axis='y', which='both')
    fig.autofmt_xdate()

    plt.savefig(param_default.fname_folder_to_save_fig+'/absolute_error_vs_fractional_volume.pdf', format='PDF')

    plt.show(block=False)

Example 34

View license
def main():

    # Initialization
    fname_anat = ''
    fname_point = ''
    slice_gap = param.gap
    remove_tmp_files = param.remove_tmp_files
    gaussian_kernel = param.gaussian_kernel
    start_time = time.time()
    verbose = 1

    # get path of the toolbox
    status, path_sct = commands.getstatusoutput('echo $SCT_DIR')
    path_sct = sct.slash_at_the_end(path_sct, 1)

    # Parameters for debug mode
    if param.debug == 1:
        sct.printv('\n*** WARNING: DEBUG MODE ON ***\n\t\t\tCurrent working directory: '+os.getcwd(), 'warning')
        status, path_sct_testing_data = commands.getstatusoutput('echo $SCT_TESTING_DATA_DIR')
        fname_anat = path_sct_testing_data+'/t2/t2.nii.gz'
        fname_point = path_sct_testing_data+'/t2/t2_centerline_init.nii.gz'
        slice_gap = 5

    else:
        # Check input param
        try:
            opts, args = getopt.getopt(sys.argv[1:],'hi:p:g:r:k:')
        except getopt.GetoptError as err:
            print str(err)
            usage()
        if not opts:
            usage()
        for opt, arg in opts:
            if opt == '-h':
                usage()
            elif opt in ('-i'):
                fname_anat = arg
            elif opt in ('-p'):
                fname_point = arg
            elif opt in ('-g'):
                slice_gap = int(arg)
            elif opt in ('-r'):
                remove_tmp_files = int(arg)
            elif opt in ('-k'):
                gaussian_kernel = int(arg)

    # display usage if a mandatory argument is not provided
    if fname_anat == '' or fname_point == '':
        usage()

    # check existence of input files
    sct.check_file_exist(fname_anat)
    sct.check_file_exist(fname_point)

    # extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
    path_point, file_point, ext_point = sct.extract_fname(fname_point)

    # extract path of schedule file
    # TODO: include schedule file in sct
    # TODO: check existence of schedule file
    file_schedule = path_sct + param.schedule_file

    # Get input image orientation
    input_image_orientation = get_orientation(fname_anat)

    # Display arguments
    print '\nCheck input arguments...'
    print '  Anatomical image:     '+fname_anat
    print '  Orientation:          '+input_image_orientation
    print '  Point in spinal cord: '+fname_point
    print '  Slice gap:            '+str(slice_gap)
    print '  Gaussian kernel:      '+str(gaussian_kernel)
    print '  Degree of polynomial: '+str(param.deg_poly)

    # create temporary folder
    print('\nCreate temporary folder...')
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")
    sct.create_folder(path_tmp)
    print '\nCopy input data...'
    sct.run('cp '+fname_anat+ ' '+path_tmp+'/tmp.anat'+ext_anat)
    sct.run('cp '+fname_point+ ' '+path_tmp+'/tmp.point'+ext_point)

    # go to temporary folder
    os.chdir(path_tmp)

    # convert to nii
    convert('tmp.anat'+ext_anat, 'tmp.anat.nii')
    convert('tmp.point'+ext_point, 'tmp.point.nii')

    # Reorient input anatomical volume into RL PA IS orientation
    print '\nReorient input volume to RL PA IS orientation...'
    #sct.run(sct.fsloutput + 'fslswapdim tmp.anat RL PA IS tmp.anat_orient')
    set_orientation('tmp.anat.nii', 'RPI', 'tmp.anat_orient.nii')
    # Reorient binary point into RL PA IS orientation
    print '\nReorient binary point into RL PA IS orientation...'
    # sct.run(sct.fsloutput + 'fslswapdim tmp.point RL PA IS tmp.point_orient')
    set_orientation('tmp.point.nii', 'RPI', 'tmp.point_orient.nii')

    # Get image dimensions
    print '\nGet image dimensions...'
    nx, ny, nz, nt, px, py, pz, pt = Image('tmp.anat_orient.nii').dim
    print '.. matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz)
    print '.. voxel size:  '+str(px)+'mm x '+str(py)+'mm x '+str(pz)+'mm'

    # Split input volume
    print '\nSplit input volume...'
    split_data('tmp.anat_orient.nii', 2, '_z')
    file_anat_split = ['tmp.anat_orient_z'+str(z).zfill(4) for z in range(0, nz, 1)]
    split_data('tmp.point_orient.nii', 2, '_z')
    file_point_split = ['tmp.point_orient_z'+str(z).zfill(4) for z in range(0, nz, 1)]

    # Extract coordinates of input point
    # sct.printv('\nExtract the slice corresponding to z='+str(z_init)+'...', verbose)
    #
    data_point = Image('tmp.point_orient.nii').data
    x_init, y_init, z_init = unravel_index(data_point.argmax(), data_point.shape)
    sct.printv('Coordinates of input point: ('+str(x_init)+', '+str(y_init)+', '+str(z_init)+')', verbose)

    # Create 2D gaussian mask
    sct.printv('\nCreate gaussian mask from point...', verbose)
    xx, yy = mgrid[:nx, :ny]
    mask2d = zeros((nx, ny))
    radius = round(float(gaussian_kernel+1)/2)  # add 1 because the radius includes the center.
    sigma = float(radius)
    mask2d = exp(-(((xx-x_init)**2)/(2*(sigma**2)) + ((yy-y_init)**2)/(2*(sigma**2))))

    # Save mask to 2d file
    file_mask_split = ['tmp.mask_orient_z'+str(z).zfill(4) for z in range(0,nz,1)]
    nii_mask2d = Image('tmp.anat_orient_z0000.nii')
    nii_mask2d.data = mask2d
    nii_mask2d.setFileName(file_mask_split[z_init]+'.nii')
    nii_mask2d.save()
    #
    # # Get the coordinates of the input point
    # print '\nGet the coordinates of the input point...'
    # data_point = Image('tmp.point_orient.nii').data
    # x_init, y_init, z_init = unravel_index(data_point.argmax(), data_point.shape)
    # print '('+str(x_init)+', '+str(y_init)+', '+str(z_init)+')'

    # x_init, y_init, z_init = (data > 0).nonzero()
    # x_init = x_init[0]
    # y_init = y_init[0]
    # z_init = z_init[0]
    # print '('+str(x_init)+', '+str(y_init)+', '+str(z_init)+')'
    #
    # numpy.unravel_index(a.argmax(), a.shape)
    #
    # file = nibabel.load('tmp.point_orient.nii')
    # data = file.get_data()
    # x_init, y_init, z_init = (data > 0).nonzero()
    # x_init = x_init[0]
    # y_init = y_init[0]
    # z_init = z_init[0]
    # print '('+str(x_init)+', '+str(y_init)+', '+str(z_init)+')'
    #
    # # Extract the slice corresponding to z=z_init
    # print '\nExtract the slice corresponding to z='+str(z_init)+'...'
    # file_point_split = ['tmp.point_orient_z'+str(z).zfill(4) for z in range(0,nz,1)]
    # nii = Image('tmp.point_orient.nii')
    # data_crop = nii.data[:, :, z_init:z_init+1]
    # nii.data = data_crop
    # nii.setFileName(file_point_split[z_init]+'.nii')
    # nii.save()
    #
    # # Create gaussian mask from point
    # print '\nCreate gaussian mask from point...'
    # file_mask_split = ['tmp.mask_orient_z'+str(z).zfill(4) for z in range(0,nz,1)]
    # sct.run(sct.fsloutput+'fslmaths '+file_point_split[z_init]+' -s '+str(gaussian_kernel)+' '+file_mask_split[z_init])
    #
    # # Obtain max value from mask
    # print '\nFind maximum value from mask...'
    # file = nibabel.load(file_mask_split[z_init]+'.nii')
    # data = file.get_data()
    # max_value_mask = numpy.max(data)
    # print '..'+str(max_value_mask)
    #
    # # Normalize mask beween 0 and 1
    # print '\nNormalize mask beween 0 and 1...'
    # sct.run(sct.fsloutput+'fslmaths '+file_mask_split[z_init]+' -div '+str(max_value_mask)+' '+file_mask_split[z_init])

    ## Take the square of the mask
    #print '\nCalculate the square of the mask...'
    #sct.run(sct.fsloutput+'fslmaths '+file_mask_split[z_init]+' -mul '+file_mask_split[z_init]+' '+file_mask_split[z_init])

    # initialize variables
    file_mat = ['tmp.mat_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_mat_inv = ['tmp.mat_inv_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_mat_inv_cumul = ['tmp.mat_inv_cumul_z'+str(z).zfill(4) for z in range(0,nz,1)]

    # create identity matrix for initial transformation matrix
    fid = open(file_mat_inv_cumul[z_init], 'w')
    fid.write('%i %i %i %i\n' %(1, 0, 0, 0) )
    fid.write('%i %i %i %i\n' %(0, 1, 0, 0) )
    fid.write('%i %i %i %i\n' %(0, 0, 1, 0) )
    fid.write('%i %i %i %i\n' %(0, 0, 0, 1) )
    fid.close()

    # initialize centerline: give value corresponding to initial point
    x_centerline = [x_init]
    y_centerline = [y_init]
    z_centerline = [z_init]
    warning_count = 0

    # go up (1), then down (2) in reference to the binary point
    for iUpDown in range(1, 3):

        if iUpDown == 1:
            # z increases
            slice_gap_signed = slice_gap
        elif iUpDown == 2:
            # z decreases
            slice_gap_signed = -slice_gap
            # reverse centerline (because values will be appended at the end)
            x_centerline.reverse()
            y_centerline.reverse()
            z_centerline.reverse()

        # initialization before looping
        z_dest = z_init # point given by user
        z_src = z_dest + slice_gap_signed

        # continue looping if 0 < z < nz
        while 0 <= z_src and z_src <= nz-1:

            # print current z:
            print 'z='+str(z_src)+':'

            # estimate transformation
            sct.run(fsloutput+'flirt -in '+file_anat_split[z_src]+' -ref '+file_anat_split[z_dest]+' -schedule '+file_schedule+ ' -verbose 0 -omat '+file_mat[z_src]+' -cost normcorr -forcescaling -inweight '+file_mask_split[z_dest]+' -refweight '+file_mask_split[z_dest])

            # display transfo
            status, output = sct.run('cat '+file_mat[z_src])
            print output

            # check if transformation is bigger than 1.5x slice_gap
            tx = float(output.split()[3])
            ty = float(output.split()[7])
            norm_txy = linalg.norm([tx, ty],ord=2)
            if norm_txy > 1.5*slice_gap:
                print 'WARNING: Transformation is too large --> using previous one.'
                warning_count = warning_count + 1
                # if previous transformation exists, replace current one with previous one
                if os.path.isfile(file_mat[z_dest]):
                    sct.run('cp '+file_mat[z_dest]+' '+file_mat[z_src])

            # estimate inverse transformation matrix
            sct.run('convert_xfm -omat '+file_mat_inv[z_src]+' -inverse '+file_mat[z_src])

            # compute cumulative transformation
            sct.run('convert_xfm -omat '+file_mat_inv_cumul[z_src]+' -concat '+file_mat_inv[z_src]+' '+file_mat_inv_cumul[z_dest])

            # apply inverse cumulative transformation to initial gaussian mask (to put it in src space)
            sct.run(fsloutput+'flirt -in '+file_mask_split[z_init]+' -ref '+file_mask_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul[z_src]+' -out '+file_mask_split[z_src])

            # open inverse cumulative transformation file and generate centerline
            fid = open(file_mat_inv_cumul[z_src])
            mat = fid.read().split()
            x_centerline.append(x_init + float(mat[3]))
            y_centerline.append(y_init + float(mat[7]))
            z_centerline.append(z_src)
            #z_index = z_index+1

            # define new z_dest (target slice) and new z_src (moving slice)
            z_dest = z_dest + slice_gap_signed
            z_src = z_src + slice_gap_signed


    # Reconstruct centerline
    # ====================================================================================================

    # reverse back centerline (because it's been reversed once, so now all values are in the right order)
    x_centerline.reverse()
    y_centerline.reverse()
    z_centerline.reverse()

    # fit centerline in the Z-X plane using polynomial function
    print '\nFit centerline in the Z-X plane using polynomial function...'
    coeffsx = polyfit(z_centerline, x_centerline, deg=param.deg_poly)
    polyx = poly1d(coeffsx)
    x_centerline_fit = polyval(polyx, z_centerline)
    # calculate RMSE
    rmse = linalg.norm(x_centerline_fit-x_centerline)/sqrt( len(x_centerline) )
    # calculate max absolute error
    max_abs = max( abs(x_centerline_fit-x_centerline) )
    print '.. RMSE (in mm): '+str(rmse*px)
    print '.. Maximum absolute error (in mm): '+str(max_abs*px)

    # fit centerline in the Z-Y plane using polynomial function
    print '\nFit centerline in the Z-Y plane using polynomial function...'
    coeffsy = polyfit(z_centerline, y_centerline, deg=param.deg_poly)
    polyy = poly1d(coeffsy)
    y_centerline_fit = polyval(polyy, z_centerline)
    # calculate RMSE
    rmse = linalg.norm(y_centerline_fit-y_centerline)/sqrt( len(y_centerline) )
    # calculate max absolute error
    max_abs = max( abs(y_centerline_fit-y_centerline) )
    print '.. RMSE (in mm): '+str(rmse*py)
    print '.. Maximum absolute error (in mm): '+str(max_abs*py)

    # display
    if param.debug == 1:
        import matplotlib.pyplot as plt
        plt.figure()
        plt.plot(z_centerline,x_centerline,'.',z_centerline,x_centerline_fit,'r')
        plt.legend(['Data','Polynomial Fit'])
        plt.title('Z-X plane polynomial interpolation')
        plt.show()

        plt.figure()
        plt.plot(z_centerline,y_centerline,'.',z_centerline,y_centerline_fit,'r')
        plt.legend(['Data','Polynomial Fit'])
        plt.title('Z-Y plane polynomial interpolation')
        plt.show()

    # generate full range z-values for centerline
    z_centerline_full = [iz for iz in range(0, nz, 1)]

    # calculate X and Y values for the full centerline
    x_centerline_fit_full = polyval(polyx, z_centerline_full)
    y_centerline_fit_full = polyval(polyy, z_centerline_full)

    # Generate fitted transformation matrices and write centerline coordinates in text file
    print '\nGenerate fitted transformation matrices and write centerline coordinates in text file...'
    file_mat_inv_cumul_fit = ['tmp.mat_inv_cumul_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_mat_cumul_fit = ['tmp.mat_cumul_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    fid_centerline = open('tmp.centerline_coordinates.txt', 'w')
    for iz in range(0, nz, 1):
        # compute inverse cumulative fitted transformation matrix
        fid = open(file_mat_inv_cumul_fit[iz], 'w')
        fid.write('%i %i %i %f\n' %(1, 0, 0, x_centerline_fit_full[iz]-x_init) )
        fid.write('%i %i %i %f\n' %(0, 1, 0, y_centerline_fit_full[iz]-y_init) )
        fid.write('%i %i %i %i\n' %(0, 0, 1, 0) )
        fid.write('%i %i %i %i\n' %(0, 0, 0, 1) )
        fid.close()
        # compute forward cumulative fitted transformation matrix
        sct.run('convert_xfm -omat '+file_mat_cumul_fit[iz]+' -inverse '+file_mat_inv_cumul_fit[iz])
        # write centerline coordinates in x, y, z format
        fid_centerline.write('%f %f %f\n' %(x_centerline_fit_full[iz], y_centerline_fit_full[iz], z_centerline_full[iz]) )
    fid_centerline.close()


    # Prepare output data
    # ====================================================================================================

    # write centerline as text file
    for iz in range(0, nz, 1):
        # compute inverse cumulative fitted transformation matrix
        fid = open(file_mat_inv_cumul_fit[iz], 'w')
        fid.write('%i %i %i %f\n' %(1, 0, 0, x_centerline_fit_full[iz]-x_init) )
        fid.write('%i %i %i %f\n' %(0, 1, 0, y_centerline_fit_full[iz]-y_init) )
        fid.write('%i %i %i %i\n' %(0, 0, 1, 0) )
        fid.write('%i %i %i %i\n' %(0, 0, 0, 1) )
        fid.close()

    # write polynomial coefficients
    savetxt('tmp.centerline_polycoeffs_x.txt',coeffsx)
    savetxt('tmp.centerline_polycoeffs_y.txt',coeffsy)

    # apply transformations to data
    print '\nApply fitted transformation matrices...'
    file_anat_split_fit = ['tmp.anat_orient_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_mask_split_fit = ['tmp.mask_orient_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_point_split_fit = ['tmp.point_orient_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    for iz in range(0, nz, 1):
        # forward cumulative transformation to data
        sct.run(fsloutput+'flirt -in '+file_anat_split[iz]+' -ref '+file_anat_split[iz]+' -applyxfm -init '+file_mat_cumul_fit[iz]+' -out '+file_anat_split_fit[iz])
        # inverse cumulative transformation to mask
        sct.run(fsloutput+'flirt -in '+file_mask_split[z_init]+' -ref '+file_mask_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul_fit[iz]+' -out '+file_mask_split_fit[iz])
        # inverse cumulative transformation to point
        sct.run(fsloutput+'flirt -in '+file_point_split[z_init]+' -ref '+file_point_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul_fit[iz]+' -out '+file_point_split_fit[iz]+' -interp nearestneighbour')

    # Merge into 4D volume
    print '\nMerge into 4D volume...'
    # sct.run(fsloutput+'fslmerge -z tmp.anat_orient_fit tmp.anat_orient_fit_z*')
    # sct.run(fsloutput+'fslmerge -z tmp.mask_orient_fit tmp.mask_orient_fit_z*')
    # sct.run(fsloutput+'fslmerge -z tmp.point_orient_fit tmp.point_orient_fit_z*')
    concat_data(glob.glob('tmp.anat_orient_fit_z*.nii'), 'tmp.anat_orient_fit.nii', dim=2)
    concat_data(glob.glob('tmp.mask_orient_fit_z*.nii'), 'tmp.mask_orient_fit.nii', dim=2)
    concat_data(glob.glob('tmp.point_orient_fit_z*.nii'), 'tmp.point_orient_fit.nii', dim=2)

    # Copy header geometry from input data
    print '\nCopy header geometry from input data...'
    copy_header('tmp.anat_orient.nii', 'tmp.anat_orient_fit.nii')
    copy_header('tmp.anat_orient.nii', 'tmp.mask_orient_fit.nii')
    copy_header('tmp.anat_orient.nii', 'tmp.point_orient_fit.nii')

    # Reorient outputs into the initial orientation of the input image
    print '\nReorient the centerline into the initial orientation of the input image...'
    set_orientation('tmp.point_orient_fit.nii', input_image_orientation, 'tmp.point_orient_fit.nii')
    set_orientation('tmp.mask_orient_fit.nii', input_image_orientation, 'tmp.mask_orient_fit.nii')

    # Generate output file (in current folder)
    print '\nGenerate output file (in current folder)...'
    os.chdir('..')  # come back to parent folder
    #sct.generate_output_file('tmp.centerline_polycoeffs_x.txt','./','centerline_polycoeffs_x','.txt')
    #sct.generate_output_file('tmp.centerline_polycoeffs_y.txt','./','centerline_polycoeffs_y','.txt')
    #sct.generate_output_file('tmp.centerline_coordinates.txt','./','centerline_coordinates','.txt')
    #sct.generate_output_file('tmp.anat_orient.nii','./',file_anat+'_rpi',ext_anat)
    #sct.generate_output_file('tmp.anat_orient_fit.nii', file_anat+'_rpi_align'+ext_anat)
    #sct.generate_output_file('tmp.mask_orient_fit.nii', file_anat+'_mask'+ext_anat)
    fname_output_centerline = sct.generate_output_file(path_tmp+'/tmp.point_orient_fit.nii', file_anat+'_centerline'+ext_anat)

    # Delete temporary files
    if remove_tmp_files == 1:
        print '\nRemove temporary files...'
        sct.run('rm -rf '+path_tmp)

    # print number of warnings
    print '\nNumber of warnings: '+str(warning_count)+' (if >10, you should probably reduce the gap and/or increase the kernel size'

    # display elapsed time
    elapsed_time = time.time() - start_time
    print '\nFinished! \n\tGenerated file: '+fname_output_centerline+'\n\tElapsed time: '+str(int(round(elapsed_time)))+'s\n'

Example 35

View license
def main():
    results_folder = param_default.results_folder
    methods_to_display = param_default.methods_to_display
    noise_std_to_display = param_default.noise_std_to_display
    tracts_std_to_display = param_default.tracts_std_to_display
    csf_value_to_display = param_default.csf_value_to_display
    nb_RL_labels = param_default.nb_RL_labels

    # Parameters for debug mode
    if param_default.debug:
        print '\n*** WARNING: DEBUG MODE ON ***\n'
        results_folder = "/Users/slevy_local/spinalcordtoolbox/dev/atlas/validate_atlas/results_20150210_200iter"#"C:/cygwin64/home/Simon_2/data_methods_comparison"
        path_sct = '/Users/slevy_local/spinalcordtoolbox' #'C:/cygwin64/home/Simon_2/spinalcordtoolbox'
    else:
        status, path_sct = commands.getstatusoutput('echo $SCT_DIR')

        # Check input parameters
        try:
            opts, args = getopt.getopt(sys.argv[1:], 'i:m:')  # define flags
        except getopt.GetoptError as err:  # check if the arguments are defined
            print str(err)  # error
            # usage() # display usage
        # if not opts:
        #     print 'Please enter the path to the result folder. Exit program.'
        #     sys.exit(1)
        #     # usage()
        for opt, arg in opts:  # explore flags
            if opt in '-i':
                results_folder = arg
            if opt in '-m':
                methods_to_display = arg

    # Append path that contains scripts, to be able to load modules
    sys.path.append(path_sct + '/scripts')
    import sct_utils as sct
    import isct_get_fractional_volume

    sct.printv("Working directory: " + os.getcwd())

    results_folder_noise = results_folder + '/noise'
    results_folder_tracts = results_folder + '/tracts'
    results_folder_csf = results_folder + '/csf'

    sct.printv('\n\nData will be extracted from folder ' + results_folder_noise + ' , ' + results_folder_tracts + ' and ' + results_folder_csf + '.', 'warning')
    sct.printv('\t\tCheck existence...')
    sct.check_folder_exist(results_folder_noise)
    sct.check_folder_exist(results_folder_tracts)
    sct.check_folder_exist(results_folder_csf)

    # Extract methods to display
    methods_to_display = methods_to_display.strip().split(',')

    # Extract file names of the results files
    fname_results_noise = glob.glob(results_folder_noise + '/*.txt')
    fname_results_tracts = glob.glob(results_folder_tracts + '/*.txt')
    fname_results_csf = glob.glob(results_folder_csf + '/*.txt')
    fname_results = fname_results_noise + fname_results_tracts + fname_results_csf
    # Remove doublons (due to the two folders)
    # for i_fname in range(0, len(fname_results)):
    #     for j_fname in range(0, len(fname_results)):
    #         if (i_fname != j_fname) & (os.path.basename(fname_results[i_fname]) == os.path.basename(fname_results[j_fname])):
    #             fname_results.remove(fname_results[j_fname])
    file_results = []
    for fname in fname_results:
        file_results.append(os.path.basename(fname))
    for file in file_results:
        if file_results.count(file) > 1:
            ind = file_results.index(file)
            fname_results.remove(fname_results[ind])
            file_results.remove(file)

    nb_results_file = len(fname_results)

    # 1st dim: SNR, 2nd dim: tract std, 3rd dim: mean abs error, 4th dim: std abs error
    # result_array = numpy.empty((nb_results_file, nb_results_file, 3), dtype=object)
    # SNR
    snr = numpy.zeros((nb_results_file))
    # Tracts std
    tracts_std = numpy.zeros((nb_results_file))
    # CSF value
    csf_values = numpy.zeros((nb_results_file))
    # methods' name
    methods_name = []  #numpy.empty((nb_results_file, nb_method), dtype=object)
    # labels
    error_per_label = []
    std_per_label = []
    labels_id = []
    # median
    median_results = numpy.zeros((nb_results_file, 5))
    # median std across bootstraps
    median_std = numpy.zeros((nb_results_file, 5))
    # min
    min_results = numpy.zeros((nb_results_file, 5))
    # max
    max_results = numpy.zeros((nb_results_file, 5))

    #
    for i_file in range(0, nb_results_file):

        # Open file
        f = open(fname_results[i_file])  # open file
        # Extract all lines in .txt file
        lines = [line for line in f.readlines() if line.strip()]

        # extract SNR
        # find all index of lines containing the string "sigma noise"
        ind_line_noise = [lines.index(line_noise) for line_noise in lines if "sigma noise" in line_noise]
        if len(ind_line_noise) != 1:
            sct.printv("ERROR: number of lines including \"sigma noise\" is different from 1. Exit program.", 'error')
            sys.exit(1)
        else:
            # result_array[:, i_file, i_file] = int(''.join(c for c in lines[ind_line_noise[0]] if c.isdigit()))
            snr[i_file] = int(''.join(c for c in lines[ind_line_noise[0]] if c.isdigit()))

        # extract tract std
        ind_line_tract_std = [lines.index(line_tract_std) for line_tract_std in lines if
                              "range tracts" in line_tract_std]
        if len(ind_line_tract_std) != 1:
            sct.printv("ERROR: number of lines including \"range tracts\" is different from 1. Exit program.", 'error')
            sys.exit(1)
        else:
            # result_array[i_file, i_file, :] = int(''.join(c for c in lines[ind_line_tract_std[0]].split(':')[1] if c.isdigit()))
            # regex = re.compile(''('(.*)':)  # re.I permet d'ignorer la case (majuscule/minuscule)
            # match = regex.search(lines[ind_line_tract_std[0]])
            # result_array[:, i_file, :, :] = match.group(1)  # le groupe 1 correspond a '.*'
            tracts_std[i_file] = int(''.join(c for c in lines[ind_line_tract_std[0]].split(':')[1] if c.isdigit()))

        # extract CSF value
        ind_line_csf_value = [lines.index(line_csf_value) for line_csf_value in lines if
                              "# value CSF" in line_csf_value]
        if len(ind_line_csf_value) != 1:
            sct.printv("ERROR: number of lines including \"range tracts\" is different from 1. Exit program.", 'error')
            sys.exit(1)
        else:
            # result_array[i_file, i_file, :] = int(''.join(c for c in lines[ind_line_tract_std[0]].split(':')[1] if c.isdigit()))
            # regex = re.compile(''('(.*)':)  # re.I permet d'ignorer la case (majuscule/minuscule)
            # match = regex.search(lines[ind_line_tract_std[0]])
            # result_array[:, i_file, :, :] = match.group(1)  # le groupe 1 correspond a '.*'
            csf_values[i_file] = int(''.join(c for c in lines[ind_line_csf_value[0]].split(':')[1] if c.isdigit()))


        # extract method name
        ind_line_label = [lines.index(line_label) for line_label in lines if "Label" in line_label]
        if len(ind_line_label) != 1:
            sct.printv("ERROR: number of lines including \"Label\" is different from 1. Exit program.", 'error')
            sys.exit(1)
        else:
            # methods_name[i_file, :] = numpy.array(lines[ind_line_label[0]].strip().split(',')[1:])
            methods_name.append(lines[ind_line_label[0]].strip().replace(' ', '').split(',')[1:])

        # extract median
        ind_line_median = [lines.index(line_median) for line_median in lines if "median" in line_median]
        if len(ind_line_median) != 1:
            sct.printv("WARNING: number of lines including \"median\" is different from 1. Exit program.", 'warning')
            # sys.exit(1)
        else:
            median = lines[ind_line_median[0]].strip().split(',')[1:]
            # result_array[i_file, i_file, 0] = [float(m.split('(')[0]) for m in median]
            median_results[i_file, :] = numpy.array([float(m.split('(')[0]) for m in median])
            median_std[i_file, :] = numpy.array([float(m.split('(')[1][:-1]) for m in median])

        # extract min
        ind_line_min = [lines.index(line_min) for line_min in lines if "min," in line_min]
        if len(ind_line_min) != 1:
            sct.printv("WARNING: number of lines including \"min\" is different from 1. Exit program.", 'warning')
            # sys.exit(1)
        else:
            min = lines[ind_line_min[0]].strip().split(',')[1:]
            # result_array[i_file, i_file, 1] = [float(m.split('(')[0]) for m in min]
            min_results[i_file, :] = numpy.array([float(m.split('(')[0]) for m in min])

        # extract max
        ind_line_max = [lines.index(line_max) for line_max in lines if "max" in line_max]
        if len(ind_line_max) != 1:
            sct.printv("WARNING: number of lines including \"max\" is different from 1. Exit program.", 'warning')
            # sys.exit(1)
        else:
            max = lines[ind_line_max[0]].strip().split(',')[1:]
            # result_array[i_file, i_file, 1] = [float(m.split('(')[0]) for m in max]
            max_results[i_file, :] = numpy.array([float(m.split('(')[0]) for m in max])

        # extract error for each label
        error_per_label_for_file_i = []
        std_per_label_for_file_i = []
        labels_id_for_file_i = []
        # Due to 2 different kind of file structure, the number of the last label line must be adapted
        if not ind_line_median:
            ind_line_median = [len(lines) + 1]
        for i_line in range(ind_line_label[0] + 1, ind_line_median[0] - 1):
            line_label_i = lines[i_line].strip().split(',')
            error_per_label_for_file_i.append([float(error.strip().split('(')[0]) for error in line_label_i[1:]])
            std_per_label_for_file_i.append([float(error.strip().split('(')[1][:-1]) for error in line_label_i[1:]])
            labels_id_for_file_i.append(int(line_label_i[0]))
        error_per_label.append(error_per_label_for_file_i)
        std_per_label.append(std_per_label_for_file_i)
        labels_id.append(labels_id_for_file_i)

        # close file
        f.close()

    # check if all the files in the result folder were generated with the same number of methods
    if not all(x == methods_name[0] for x in methods_name):
        sct.printv(
            'ERROR: All the generated files in folder ' + results_folder + ' have not been generated with the same number of methods. Exit program.',
            'error')
        sys.exit(1)
    # check if all the files in the result folder were generated with the same labels
    if not all(x == labels_id[0] for x in labels_id):
        sct.printv(
            'ERROR: All the generated files in folder ' + results_folder + ' have not been generated with the same labels. Exit program.',
            'error')
        sys.exit(1)

    # convert the list "error_per_label" into a numpy array to ease further manipulations
    error_per_label = numpy.array(error_per_label)
    std_per_label = numpy.array(std_per_label)
    # compute different stats
    abs_error_per_labels = numpy.absolute(error_per_label)
    max_abs_error_per_meth = numpy.amax(abs_error_per_labels, axis=1)
    min_abs_error_per_meth = numpy.amin(abs_error_per_labels, axis=1)
    mean_abs_error_per_meth = numpy.mean(abs_error_per_labels, axis=1)
    std_abs_error_per_meth = numpy.std(abs_error_per_labels, axis=1)

    # average error and std across sides
    meanRL_abs_error_per_labels = numpy.zeros((error_per_label.shape[0], nb_RL_labels, error_per_label.shape[2]))
    meanRL_std_abs_error_per_labels = numpy.zeros((std_per_label.shape[0], nb_RL_labels, std_per_label.shape[2]))
    for i_file in range(0, nb_results_file):
        for i_meth in range(0, len(methods_name[i_file])):
            for i_label in range(0, nb_RL_labels):
                # find indexes of corresponding labels
                ind_ID_first_side = labels_id[i_file].index(i_label)
                ind_ID_other_side = labels_id[i_file].index(i_label + nb_RL_labels)
                # compute mean across 2 sides
                meanRL_abs_error_per_labels[i_file, i_label, i_meth] = float(error_per_label[i_file, ind_ID_first_side, i_meth] + error_per_label[i_file, ind_ID_other_side, i_meth]) / 2
                meanRL_std_abs_error_per_labels[i_file, i_label, i_meth] = float(std_per_label[i_file, ind_ID_first_side, i_meth] + std_per_label[i_file, ind_ID_other_side, i_meth]) / 2

    nb_method = len(methods_to_display)

    sct.printv('Noise std of the ' + str(nb_results_file) + ' generated files:')
    print snr
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Tracts std of the ' + str(nb_results_file) + ' generated files:')
    print tracts_std
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('CSF value of the ' + str(nb_results_file) + ' generated files:')
    print csf_values
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Methods used to generate results for the ' + str(nb_results_file) + ' generated files:')
    print methods_name
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Median obtained with each method (in colons) for the ' + str(nb_results_file) + ' generated files (in lines):')
    print median_results
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Minimum obtained with each method (in colons) for the ' + str(
        nb_results_file) + ' generated files (in lines):')
    print min_results
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Maximum obtained with each method (in colons) for the ' + str(
        nb_results_file) + ' generated files (in lines):')
    print max_results
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Labels\' ID (in colons) for the ' + str(nb_results_file) + ' generated files (in lines):')
    print labels_id
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Errors obtained with each method (in colons) for the ' + str(nb_results_file) + ' generated files (in lines):')
    print error_per_label
    print '----------------------------------------------------------------------------------------------------------------'
    sct.printv('Mean errors across both sides obtained with each method (in colons) for the ' + str(nb_results_file) + ' generated files (in lines):')
    print meanRL_abs_error_per_labels


    # Compute fractional volume per label
    labels_id_FV, labels_name_FV, fract_vol_per_lab, labels_name_FV_RL_gathered, fract_vol_per_lab_RL_gathered = isct_get_fractional_volume.get_fractional_volume_per_label('./cropped_atlas/', 'info_label.txt')
    # # Get the number of voxels including at least one tract
    # nb_voxels_in_WM = isct_get_fractional_volume.get_nb_voxel_in_WM('./cropped_atlas/', 'info_label.txt')
    # normalize by the number of voxels in WM and express it as a percentage
    fract_vol_norm = numpy.divide(fract_vol_per_lab_RL_gathered, numpy.sum(fract_vol_per_lab_RL_gathered)/100)

    # NOT NECESSARY NOW WE AVERAGE ACROSS BOTH SIDES (which orders the labels)
    # # check if the order of the labels returned by the function computing the fractional volumes is the same (which should be the case)
    # if labels_id_FV != labels_id[0]:
    #     sct.printv('\n\nERROR: the labels IDs returned by the function \'i_sct_get_fractional_volume\' are different from the labels IDs of the results files\n\n', 'error')

    # # Remove labels #30 and #31
    # labels_id_FV_29, labels_name_FV_29, fract_vol_per_lab_29 = labels_id_FV[:-2], labels_name_FV[:-2], fract_vol_per_lab[:-2]

    # indexes of labels sort according to the fractional volume
    ind_labels_sort = numpy.argsort(fract_vol_norm)

    # Find index of the file generated with noise variance = 10 and tracts std = 10
    ind_file_to_display = numpy.where((snr == noise_std_to_display) & (tracts_std == tracts_std_to_display) & (csf_values == csf_value_to_display))

    # sort arrays in this order
    meanRL_abs_error_per_labels_sort = meanRL_abs_error_per_labels[ind_file_to_display[0], ind_labels_sort, :]
    meanRL_std_abs_error_per_labels_sort = meanRL_std_abs_error_per_labels[ind_file_to_display[0], ind_labels_sort, :]
    labels_name_sort = numpy.array(labels_name_FV_RL_gathered)[ind_labels_sort]

    # *********************************************** START PLOTTING HERE **********************************************

    # stringColor = Color()
    matplotlib.rcParams.update({'font.size': 50, 'font.family': 'trebuchet'})
    # plt.rcParams['xtick.major.pad'] = '11'
    plt.rcParams['ytick.major.pad'] = '15'

    fig = plt.figure(figsize=(60, 37))
    width = 1.0 / (nb_method + 1)
    ind_fig = numpy.arange(len(labels_name_sort)) * (1.0 + width)
    plt.ylabel('Absolute error (%)\n', fontsize=65)
    plt.xlabel('Fractional volume (% of the total number of voxels in WM)', fontsize=65)
    plt.title('Absolute error per tract as a function of their fractional volume\n\n', fontsize=30)
    plt.suptitle('(Noise std='+str(snr[ind_file_to_display[0]][0])+', Tracts std='+str(tracts_std[ind_file_to_display[0]][0])+', CSF value='+str(csf_values[ind_file_to_display[0]][0])+')', fontsize=30)

    # colors = plt.get_cmap('jet')(np.linspace(0, 1.0, nb_method))
    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    markers = ['o', 's', '^', 'D']
    errorbar_plots = []
    for meth, color, marker in zip(methods_to_display, colors, markers):
        i_meth = methods_name[0].index(meth)
        i_meth_to_display = methods_to_display.index(meth)

        plot_i = plt.errorbar(ind_fig + i_meth_to_display * width, meanRL_abs_error_per_labels_sort[:, i_meth], meanRL_std_abs_error_per_labels_sort[:, i_meth], color=color, marker=marker, markersize=35, lw=7, elinewidth=1, capthick=5, capsize=10)
        # plot_i = plt.boxplot(numpy.transpose(abs_error_per_labels[ind_files_csf_sort, :, i_meth]), positions=ind_fig + i_meth_to_display * width + (float(i_meth_to_display) * width) / (nb_method + 1), widths=width, boxprops=boxprops, medianprops=medianprops, flierprops=flierprops, whiskerprops=whiskerprops, capprops=capprops)
        errorbar_plots.append(plot_i)

    # add alternated vertical background colored bars
    for i_xtick in range(0, len(ind_fig), 2):
        plt.axvspan(ind_fig[i_xtick] - width - width / 2, ind_fig[i_xtick] + (nb_method + 1) * width - width / 2, facecolor='grey', alpha=0.1)

    # concatenate value of fractional volume to labels'name
    xtick_labels = [labels_name_sort[i_lab]+'\n'+r'$\bf{['+str(round(fract_vol_norm[ind_labels_sort][i_lab], 2))+']}$' for i_lab in range(0, len(labels_name_sort))]
    ind_lemniscus = numpy.where(labels_name_sort == 'spinal lemniscus (spinothalamic and spinoreticular tracts)')[0][0]
    xtick_labels[ind_lemniscus] = 'spinal lemniscus\n'+r'$\bf{['+str(round(fract_vol_norm[ind_labels_sort][ind_lemniscus], 2))+']}$'

    # plt.legend(box_plots, methods_to_display, bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
    plt.legend(errorbar_plots, methods_to_display, loc=1, fontsize=50, numpoints=1)
    plt.xticks(ind_fig + (numpy.floor(float(nb_method-1)/2)) * width, xtick_labels, fontsize=45)
    # Tweak spacing to prevent clipping of tick-labels
    plt.subplots_adjust(bottom=0, top=0.95, right=0.96)
    plt.gca().set_xlim([-width, numpy.max(ind_fig) + (nb_method + 0.5) * width])
    plt.gca().set_ylim([0, 17])
    plt.gca().yaxis.set_major_locator(plt.MultipleLocator(1.0))
    plt.gca().yaxis.set_minor_locator(plt.MultipleLocator(0.5))
    plt.grid(b=True, axis='y', which='both')
    fig.autofmt_xdate()

    plt.savefig(param_default.fname_folder_to_save_fig+'/absolute_error_vs_fractional_volume.pdf', format='PDF')

    plt.show(block=False)

Example 36

View license
def main():
    """Extract a spinal cord centerline from an anatomical image and a seed point.

    Command-line entry point (Python 2).  Workflow:
      1. Parse options (-i anat image, -p point image, -g slice gap,
         -r remove tmp files, -k gaussian kernel size), or use hard-coded
         test data in debug mode.
      2. Copy inputs into a timestamped temporary folder, convert to NIfTI,
         and reorient both volumes to RPI.
      3. Split the anatomical volume slice-by-slice along z, build a 2D
         gaussian mask centred on the seed point, then register each slice
         to its neighbour with FSL `flirt` (going up then down from the
         seed slice) to track the cord centre across slices.
      4. Fit the tracked x(z) and y(z) trajectories with polynomials,
         regenerate per-slice transformation matrices from the fit, apply
         them, and merge the slices back into 3D volumes.
      5. Write <input>_centerline next to the input image and optionally
         delete the temporary folder.

    Exits via usage() when -i or -p is missing.
    NOTE(review): relies on module-level globals (param, fsloutput, sct,
    time, glob, numpy names like polyfit/linalg, ...) imported elsewhere
    in this file — confirm availability before reuse.
    """

    # Initialization
    fname_anat = ''
    fname_point = ''
    slice_gap = param.gap
    remove_tmp_files = param.remove_tmp_files
    gaussian_kernel = param.gaussian_kernel
    start_time = time.time()
    verbose = 1

    # get path of the toolbox
    status, path_sct = commands.getstatusoutput('echo $SCT_DIR')
    path_sct = sct.slash_at_the_end(path_sct, 1)

    # Parameters for debug mode
    if param.debug == 1:
        sct.printv('\n*** WARNING: DEBUG MODE ON ***\n\t\t\tCurrent working directory: '+os.getcwd(), 'warning')
        status, path_sct_testing_data = commands.getstatusoutput('echo $SCT_TESTING_DATA_DIR')
        fname_anat = path_sct_testing_data+'/t2/t2.nii.gz'
        fname_point = path_sct_testing_data+'/t2/t2_centerline_init.nii.gz'
        slice_gap = 5

    else:
        # Check input param
        try:
            opts, args = getopt.getopt(sys.argv[1:],'hi:p:g:r:k:')
        except getopt.GetoptError as err:
            print str(err)
            usage()
        if not opts:
            usage()
        for opt, arg in opts:
            if opt == '-h':
                usage()
            elif opt in ('-i'):
                fname_anat = arg
            elif opt in ('-p'):
                fname_point = arg
            elif opt in ('-g'):
                slice_gap = int(arg)
            elif opt in ('-r'):
                remove_tmp_files = int(arg)
            elif opt in ('-k'):
                gaussian_kernel = int(arg)

    # display usage if a mandatory argument is not provided
    if fname_anat == '' or fname_point == '':
        usage()

    # check existence of input files
    sct.check_file_exist(fname_anat)
    sct.check_file_exist(fname_point)

    # extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
    path_point, file_point, ext_point = sct.extract_fname(fname_point)

    # extract path of schedule file
    # TODO: include schedule file in sct
    # TODO: check existence of schedule file
    file_schedule = path_sct + param.schedule_file

    # Get input image orientation
    input_image_orientation = get_orientation(fname_anat)

    # Display arguments
    print '\nCheck input arguments...'
    print '  Anatomical image:     '+fname_anat
    print '  Orientation:          '+input_image_orientation
    print '  Point in spinal cord: '+fname_point
    print '  Slice gap:            '+str(slice_gap)
    print '  Gaussian kernel:      '+str(gaussian_kernel)
    print '  Degree of polynomial: '+str(param.deg_poly)

    # create temporary folder
    print('\nCreate temporary folder...')
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")
    sct.create_folder(path_tmp)
    print '\nCopy input data...'
    sct.run('cp '+fname_anat+ ' '+path_tmp+'/tmp.anat'+ext_anat)
    sct.run('cp '+fname_point+ ' '+path_tmp+'/tmp.point'+ext_point)

    # go to temporary folder
    os.chdir(path_tmp)

    # convert to nii
    convert('tmp.anat'+ext_anat, 'tmp.anat.nii')
    convert('tmp.point'+ext_point, 'tmp.point.nii')

    # Reorient input anatomical volume into RL PA IS orientation
    print '\nReorient input volume to RL PA IS orientation...'
    #sct.run(sct.fsloutput + 'fslswapdim tmp.anat RL PA IS tmp.anat_orient')
    set_orientation('tmp.anat.nii', 'RPI', 'tmp.anat_orient.nii')
    # Reorient binary point into RL PA IS orientation
    print '\nReorient binary point into RL PA IS orientation...'
    # sct.run(sct.fsloutput + 'fslswapdim tmp.point RL PA IS tmp.point_orient')
    set_orientation('tmp.point.nii', 'RPI', 'tmp.point_orient.nii')

    # Get image dimensions
    print '\nGet image dimensions...'
    nx, ny, nz, nt, px, py, pz, pt = Image('tmp.anat_orient.nii').dim
    print '.. matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz)
    print '.. voxel size:  '+str(px)+'mm x '+str(py)+'mm x '+str(pz)+'mm'

    # Split input volume
    print '\nSplit input volume...'
    split_data('tmp.anat_orient.nii', 2, '_z')
    file_anat_split = ['tmp.anat_orient_z'+str(z).zfill(4) for z in range(0, nz, 1)]
    split_data('tmp.point_orient.nii', 2, '_z')
    file_point_split = ['tmp.point_orient_z'+str(z).zfill(4) for z in range(0, nz, 1)]

    # Extract coordinates of input point
    # sct.printv('\nExtract the slice corresponding to z='+str(z_init)+'...', verbose)
    #
    # The seed point image is assumed to contain a single non-zero voxel;
    # argmax therefore locates it.
    data_point = Image('tmp.point_orient.nii').data
    x_init, y_init, z_init = unravel_index(data_point.argmax(), data_point.shape)
    sct.printv('Coordinates of input point: ('+str(x_init)+', '+str(y_init)+', '+str(z_init)+')', verbose)

    # Create 2D gaussian mask
    sct.printv('\nCreate gaussian mask from point...', verbose)
    xx, yy = mgrid[:nx, :ny]
    mask2d = zeros((nx, ny))
    radius = round(float(gaussian_kernel+1)/2)  # add 1 because the radius includes the center.
    sigma = float(radius)
    mask2d = exp(-(((xx-x_init)**2)/(2*(sigma**2)) + ((yy-y_init)**2)/(2*(sigma**2))))

    # Save mask to 2d file
    file_mask_split = ['tmp.mask_orient_z'+str(z).zfill(4) for z in range(0,nz,1)]
    nii_mask2d = Image('tmp.anat_orient_z0000.nii')
    nii_mask2d.data = mask2d
    nii_mask2d.setFileName(file_mask_split[z_init]+'.nii')
    nii_mask2d.save()
    #
    # # Get the coordinates of the input point
    # print '\nGet the coordinates of the input point...'
    # data_point = Image('tmp.point_orient.nii').data
    # x_init, y_init, z_init = unravel_index(data_point.argmax(), data_point.shape)
    # print '('+str(x_init)+', '+str(y_init)+', '+str(z_init)+')'

    # x_init, y_init, z_init = (data > 0).nonzero()
    # x_init = x_init[0]
    # y_init = y_init[0]
    # z_init = z_init[0]
    # print '('+str(x_init)+', '+str(y_init)+', '+str(z_init)+')'
    #
    # numpy.unravel_index(a.argmax(), a.shape)
    #
    # file = nibabel.load('tmp.point_orient.nii')
    # data = file.get_data()
    # x_init, y_init, z_init = (data > 0).nonzero()
    # x_init = x_init[0]
    # y_init = y_init[0]
    # z_init = z_init[0]
    # print '('+str(x_init)+', '+str(y_init)+', '+str(z_init)+')'
    #
    # # Extract the slice corresponding to z=z_init
    # print '\nExtract the slice corresponding to z='+str(z_init)+'...'
    # file_point_split = ['tmp.point_orient_z'+str(z).zfill(4) for z in range(0,nz,1)]
    # nii = Image('tmp.point_orient.nii')
    # data_crop = nii.data[:, :, z_init:z_init+1]
    # nii.data = data_crop
    # nii.setFileName(file_point_split[z_init]+'.nii')
    # nii.save()
    #
    # # Create gaussian mask from point
    # print '\nCreate gaussian mask from point...'
    # file_mask_split = ['tmp.mask_orient_z'+str(z).zfill(4) for z in range(0,nz,1)]
    # sct.run(sct.fsloutput+'fslmaths '+file_point_split[z_init]+' -s '+str(gaussian_kernel)+' '+file_mask_split[z_init])
    #
    # # Obtain max value from mask
    # print '\nFind maximum value from mask...'
    # file = nibabel.load(file_mask_split[z_init]+'.nii')
    # data = file.get_data()
    # max_value_mask = numpy.max(data)
    # print '..'+str(max_value_mask)
    #
    # # Normalize mask beween 0 and 1
    # print '\nNormalize mask beween 0 and 1...'
    # sct.run(sct.fsloutput+'fslmaths '+file_mask_split[z_init]+' -div '+str(max_value_mask)+' '+file_mask_split[z_init])

    ## Take the square of the mask
    #print '\nCalculate the square of the mask...'
    #sct.run(sct.fsloutput+'fslmaths '+file_mask_split[z_init]+' -mul '+file_mask_split[z_init]+' '+file_mask_split[z_init])

    # initialize variables
    # Per-slice transformation matrix filenames (flirt .mat text format).
    file_mat = ['tmp.mat_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_mat_inv = ['tmp.mat_inv_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_mat_inv_cumul = ['tmp.mat_inv_cumul_z'+str(z).zfill(4) for z in range(0,nz,1)]

    # create identity matrix for initial transformation matrix
    fid = open(file_mat_inv_cumul[z_init], 'w')
    fid.write('%i %i %i %i\n' %(1, 0, 0, 0) )
    fid.write('%i %i %i %i\n' %(0, 1, 0, 0) )
    fid.write('%i %i %i %i\n' %(0, 0, 1, 0) )
    fid.write('%i %i %i %i\n' %(0, 0, 0, 1) )
    fid.close()

    # initialize centerline: give value corresponding to initial point
    x_centerline = [x_init]
    y_centerline = [y_init]
    z_centerline = [z_init]
    warning_count = 0

    # go up (1), then down (2) in reference to the binary point
    for iUpDown in range(1, 3):

        if iUpDown == 1:
            # z increases
            slice_gap_signed = slice_gap
        elif iUpDown == 2:
            # z decreases
            slice_gap_signed = -slice_gap
            # reverse centerline (because values will be appended at the end)
            x_centerline.reverse()
            y_centerline.reverse()
            z_centerline.reverse()

        # initialization before looping
        z_dest = z_init # point given by user
        z_src = z_dest + slice_gap_signed

        # continue looping if 0 < z < nz
        while 0 <= z_src and z_src <= nz-1:

            # print current z:
            print 'z='+str(z_src)+':'

            # estimate transformation
            sct.run(fsloutput+'flirt -in '+file_anat_split[z_src]+' -ref '+file_anat_split[z_dest]+' -schedule '+file_schedule+ ' -verbose 0 -omat '+file_mat[z_src]+' -cost normcorr -forcescaling -inweight '+file_mask_split[z_dest]+' -refweight '+file_mask_split[z_dest])

            # display transfo
            status, output = sct.run('cat '+file_mat[z_src])
            print output

            # check if transformation is bigger than 1.5x slice_gap
            # (flirt .mat is a 4x4 row-major matrix: fields 3 and 7 are the
            # x and y translations)
            tx = float(output.split()[3])
            ty = float(output.split()[7])
            norm_txy = linalg.norm([tx, ty],ord=2)
            if norm_txy > 1.5*slice_gap:
                print 'WARNING: Transformation is too large --> using previous one.'
                warning_count = warning_count + 1
                # if previous transformation exists, replace current one with previous one
                if os.path.isfile(file_mat[z_dest]):
                    sct.run('cp '+file_mat[z_dest]+' '+file_mat[z_src])

            # estimate inverse transformation matrix
            sct.run('convert_xfm -omat '+file_mat_inv[z_src]+' -inverse '+file_mat[z_src])

            # compute cumulative transformation
            sct.run('convert_xfm -omat '+file_mat_inv_cumul[z_src]+' -concat '+file_mat_inv[z_src]+' '+file_mat_inv_cumul[z_dest])

            # apply inverse cumulative transformation to initial gaussian mask (to put it in src space)
            sct.run(fsloutput+'flirt -in '+file_mask_split[z_init]+' -ref '+file_mask_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul[z_src]+' -out '+file_mask_split[z_src])

            # open inverse cumulative transformation file and generate centerline
            fid = open(file_mat_inv_cumul[z_src])
            mat = fid.read().split()
            x_centerline.append(x_init + float(mat[3]))
            y_centerline.append(y_init + float(mat[7]))
            z_centerline.append(z_src)
            #z_index = z_index+1

            # define new z_dest (target slice) and new z_src (moving slice)
            z_dest = z_dest + slice_gap_signed
            z_src = z_src + slice_gap_signed


    # Reconstruct centerline
    # ====================================================================================================

    # reverse back centerline (because it's been reversed once, so now all values are in the right order)
    x_centerline.reverse()
    y_centerline.reverse()
    z_centerline.reverse()

    # fit centerline in the Z-X plane using polynomial function
    print '\nFit centerline in the Z-X plane using polynomial function...'
    coeffsx = polyfit(z_centerline, x_centerline, deg=param.deg_poly)
    polyx = poly1d(coeffsx)
    x_centerline_fit = polyval(polyx, z_centerline)
    # calculate RMSE
    rmse = linalg.norm(x_centerline_fit-x_centerline)/sqrt( len(x_centerline) )
    # calculate max absolute error
    max_abs = max( abs(x_centerline_fit-x_centerline) )
    print '.. RMSE (in mm): '+str(rmse*px)
    print '.. Maximum absolute error (in mm): '+str(max_abs*px)

    # fit centerline in the Z-Y plane using polynomial function
    print '\nFit centerline in the Z-Y plane using polynomial function...'
    coeffsy = polyfit(z_centerline, y_centerline, deg=param.deg_poly)
    polyy = poly1d(coeffsy)
    y_centerline_fit = polyval(polyy, z_centerline)
    # calculate RMSE
    rmse = linalg.norm(y_centerline_fit-y_centerline)/sqrt( len(y_centerline) )
    # calculate max absolute error
    max_abs = max( abs(y_centerline_fit-y_centerline) )
    print '.. RMSE (in mm): '+str(rmse*py)
    print '.. Maximum absolute error (in mm): '+str(max_abs*py)

    # display
    if param.debug == 1:
        import matplotlib.pyplot as plt
        plt.figure()
        plt.plot(z_centerline,x_centerline,'.',z_centerline,x_centerline_fit,'r')
        plt.legend(['Data','Polynomial Fit'])
        plt.title('Z-X plane polynomial interpolation')
        plt.show()

        plt.figure()
        plt.plot(z_centerline,y_centerline,'.',z_centerline,y_centerline_fit,'r')
        plt.legend(['Data','Polynomial Fit'])
        plt.title('Z-Y plane polynomial interpolation')
        plt.show()

    # generate full range z-values for centerline
    z_centerline_full = [iz for iz in range(0, nz, 1)]

    # calculate X and Y values for the full centerline
    x_centerline_fit_full = polyval(polyx, z_centerline_full)
    y_centerline_fit_full = polyval(polyy, z_centerline_full)

    # Generate fitted transformation matrices and write centerline coordinates in text file
    print '\nGenerate fitted transformation matrices and write centerline coordinates in text file...'
    file_mat_inv_cumul_fit = ['tmp.mat_inv_cumul_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_mat_cumul_fit = ['tmp.mat_cumul_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    fid_centerline = open('tmp.centerline_coordinates.txt', 'w')
    for iz in range(0, nz, 1):
        # compute inverse cumulative fitted transformation matrix
        # (pure translation: identity rotation, offsets from fitted centerline)
        fid = open(file_mat_inv_cumul_fit[iz], 'w')
        fid.write('%i %i %i %f\n' %(1, 0, 0, x_centerline_fit_full[iz]-x_init) )
        fid.write('%i %i %i %f\n' %(0, 1, 0, y_centerline_fit_full[iz]-y_init) )
        fid.write('%i %i %i %i\n' %(0, 0, 1, 0) )
        fid.write('%i %i %i %i\n' %(0, 0, 0, 1) )
        fid.close()
        # compute forward cumulative fitted transformation matrix
        sct.run('convert_xfm -omat '+file_mat_cumul_fit[iz]+' -inverse '+file_mat_inv_cumul_fit[iz])
        # write centerline coordinates in x, y, z format
        fid_centerline.write('%f %f %f\n' %(x_centerline_fit_full[iz], y_centerline_fit_full[iz], z_centerline_full[iz]) )
    fid_centerline.close()


    # Prepare output data
    # ====================================================================================================

    # write centerline as text file
    # NOTE(review): this loop re-writes the same matrices as the loop above
    # — it appears redundant; confirm before removing.
    for iz in range(0, nz, 1):
        # compute inverse cumulative fitted transformation matrix
        fid = open(file_mat_inv_cumul_fit[iz], 'w')
        fid.write('%i %i %i %f\n' %(1, 0, 0, x_centerline_fit_full[iz]-x_init) )
        fid.write('%i %i %i %f\n' %(0, 1, 0, y_centerline_fit_full[iz]-y_init) )
        fid.write('%i %i %i %i\n' %(0, 0, 1, 0) )
        fid.write('%i %i %i %i\n' %(0, 0, 0, 1) )
        fid.close()

    # write polynomial coefficients
    savetxt('tmp.centerline_polycoeffs_x.txt',coeffsx)
    savetxt('tmp.centerline_polycoeffs_y.txt',coeffsy)

    # apply transformations to data
    print '\nApply fitted transformation matrices...'
    file_anat_split_fit = ['tmp.anat_orient_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_mask_split_fit = ['tmp.mask_orient_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_point_split_fit = ['tmp.point_orient_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    for iz in range(0, nz, 1):
        # forward cumulative transformation to data
        sct.run(fsloutput+'flirt -in '+file_anat_split[iz]+' -ref '+file_anat_split[iz]+' -applyxfm -init '+file_mat_cumul_fit[iz]+' -out '+file_anat_split_fit[iz])
        # inverse cumulative transformation to mask
        sct.run(fsloutput+'flirt -in '+file_mask_split[z_init]+' -ref '+file_mask_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul_fit[iz]+' -out '+file_mask_split_fit[iz])
        # inverse cumulative transformation to point
        sct.run(fsloutput+'flirt -in '+file_point_split[z_init]+' -ref '+file_point_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul_fit[iz]+' -out '+file_point_split_fit[iz]+' -interp nearestneighbour')

    # Merge into 4D volume
    print '\nMerge into 4D volume...'
    # sct.run(fsloutput+'fslmerge -z tmp.anat_orient_fit tmp.anat_orient_fit_z*')
    # sct.run(fsloutput+'fslmerge -z tmp.mask_orient_fit tmp.mask_orient_fit_z*')
    # sct.run(fsloutput+'fslmerge -z tmp.point_orient_fit tmp.point_orient_fit_z*')
    concat_data(glob.glob('tmp.anat_orient_fit_z*.nii'), 'tmp.anat_orient_fit.nii', dim=2)
    concat_data(glob.glob('tmp.mask_orient_fit_z*.nii'), 'tmp.mask_orient_fit.nii', dim=2)
    concat_data(glob.glob('tmp.point_orient_fit_z*.nii'), 'tmp.point_orient_fit.nii', dim=2)

    # Copy header geometry from input data
    print '\nCopy header geometry from input data...'
    copy_header('tmp.anat_orient.nii', 'tmp.anat_orient_fit.nii')
    copy_header('tmp.anat_orient.nii', 'tmp.mask_orient_fit.nii')
    copy_header('tmp.anat_orient.nii', 'tmp.point_orient_fit.nii')

    # Reorient outputs into the initial orientation of the input image
    print '\nReorient the centerline into the initial orientation of the input image...'
    set_orientation('tmp.point_orient_fit.nii', input_image_orientation, 'tmp.point_orient_fit.nii')
    set_orientation('tmp.mask_orient_fit.nii', input_image_orientation, 'tmp.mask_orient_fit.nii')

    # Generate output file (in current folder)
    print '\nGenerate output file (in current folder)...'
    os.chdir('..')  # come back to parent folder
    #sct.generate_output_file('tmp.centerline_polycoeffs_x.txt','./','centerline_polycoeffs_x','.txt')
    #sct.generate_output_file('tmp.centerline_polycoeffs_y.txt','./','centerline_polycoeffs_y','.txt')
    #sct.generate_output_file('tmp.centerline_coordinates.txt','./','centerline_coordinates','.txt')
    #sct.generate_output_file('tmp.anat_orient.nii','./',file_anat+'_rpi',ext_anat)
    #sct.generate_output_file('tmp.anat_orient_fit.nii', file_anat+'_rpi_align'+ext_anat)
    #sct.generate_output_file('tmp.mask_orient_fit.nii', file_anat+'_mask'+ext_anat)
    fname_output_centerline = sct.generate_output_file(path_tmp+'/tmp.point_orient_fit.nii', file_anat+'_centerline'+ext_anat)

    # Delete temporary files
    if remove_tmp_files == 1:
        print '\nRemove temporary files...'
        sct.run('rm -rf '+path_tmp)

    # print number of warnings
    print '\nNumber of warnings: '+str(warning_count)+' (if >10, you should probably reduce the gap and/or increase the kernel size'

    # display elapsed time
    elapsed_time = time.time() - start_time
    print '\nFinished! \n\tGenerated file: '+fname_output_centerline+'\n\tElapsed time: '+str(int(round(elapsed_time)))+'s\n'

Example 37

View license
def get_centerline_from_point(input_image, point_file, gap=4, gaussian_kernel=4, remove_tmp_files=1):
    """Extract a spinal cord centerline from an anatomical image and a seed point.

    Library-callable variant of the centerline-from-point pipeline (Python 2):
    copies inputs to a timestamped temporary folder, reorients to RPI, splits
    the volume along z, then tracks the cord centre slice-by-slice with FSL
    `flirt` registrations seeded by a 2D gaussian mask centred on the point.
    The tracked trajectory is polynomial-fitted, per-slice transformation
    matrices are regenerated from the fit and applied, and the merged
    centerline volume is written next to the input image.

    Parameters:
      input_image      -- path to the anatomical volume.
      point_file       -- path to a volume containing one non-zero voxel
                          marking a point inside the cord.
      gap              -- slice gap used when stepping up/down in z.
      gaussian_kernel  -- size of the gaussian mask kernel (voxels).
      remove_tmp_files -- 1 to delete the temporary folder at the end.

    NOTE(review): relies on module-level globals (param, fsloutput, sct,
    time/strftime, numpy names, ...) imported elsewhere in this file.
    """

    # Initialization
    fname_anat = input_image
    fname_point = point_file
    slice_gap = gap
    remove_tmp_files = remove_tmp_files
    gaussian_kernel = gaussian_kernel
    start_time = time()
    verbose = 1

    # get path of the toolbox
    status, path_sct = commands.getstatusoutput('echo $SCT_DIR')
    path_sct = sct.slash_at_the_end(path_sct, 1)

    # Parameters for debug mode
    if param.debug == 1:
        sct.printv('\n*** WARNING: DEBUG MODE ON ***\n\t\t\tCurrent working directory: '+os.getcwd(), 'warning')
        status, path_sct_testing_data = commands.getstatusoutput('echo $SCT_TESTING_DATA_DIR')
        fname_anat = path_sct_testing_data+'/t2/t2.nii.gz'
        fname_point = path_sct_testing_data+'/t2/t2_centerline_init.nii.gz'
        slice_gap = 5

    # check existence of input files
    sct.check_file_exist(fname_anat)
    sct.check_file_exist(fname_point)

    # extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
    path_point, file_point, ext_point = sct.extract_fname(fname_point)

    # extract path of schedule file
    # TODO: include schedule file in sct
    # TODO: check existence of schedule file
    file_schedule = path_sct + param.schedule_file

    # Get input image orientation
    input_image_orientation = get_orientation_3d(fname_anat, filename=True)

    # Display arguments
    print '\nCheck input arguments...'
    print '  Anatomical image:     '+fname_anat
    print '  Orientation:          '+input_image_orientation
    print '  Point in spinal cord: '+fname_point
    print '  Slice gap:            '+str(slice_gap)
    print '  Gaussian kernel:      '+str(gaussian_kernel)
    print '  Degree of polynomial: '+str(param.deg_poly)

    # create temporary folder
    print('\nCreate temporary folder...')
    path_tmp = 'tmp.'+strftime('%y%m%d%H%M%S')
    sct.create_folder(path_tmp)
    print '\nCopy input data...'
    sct.run('cp '+fname_anat+ ' '+path_tmp+'/tmp.anat'+ext_anat)
    sct.run('cp '+fname_point+ ' '+path_tmp+'/tmp.point'+ext_point)

    # go to temporary folder
    os.chdir(path_tmp)

    # convert to nii
    im_anat = convert('tmp.anat'+ext_anat, 'tmp.anat.nii')
    im_point = convert('tmp.point'+ext_point, 'tmp.point.nii')

    # Reorient input anatomical volume into RL PA IS orientation
    print '\nReorient input volume to RL PA IS orientation...'
    set_orientation(im_anat, 'RPI')
    im_anat.setFileName('tmp.anat_orient.nii')
    # Reorient binary point into RL PA IS orientation
    print '\nReorient binary point into RL PA IS orientation...'
    # sct.run(sct.fsloutput + 'fslswapdim tmp.point RL PA IS tmp.point_orient')
    set_orientation(im_point, 'RPI')
    im_point.setFileName('tmp.point_orient.nii')

    # Get image dimensions
    print '\nGet image dimensions...'
    nx, ny, nz, nt, px, py, pz, pt = Image('tmp.anat_orient.nii').dim
    print '.. matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz)
    print '.. voxel size:  '+str(px)+'mm x '+str(py)+'mm x '+str(pz)+'mm'

    # Split input volume
    print '\nSplit input volume...'
    im_anat_split_list = split_data(im_anat, 2)
    file_anat_split = []
    for im in im_anat_split_list:
        file_anat_split.append(im.absolutepath)
        im.save()

    im_point_split_list = split_data(im_point, 2)
    file_point_split = []
    for im in im_point_split_list:
        file_point_split.append(im.absolutepath)
        im.save()

    # Extract coordinates of input point
    # (seed image is assumed to contain a single non-zero voxel, located
    # via argmax)
    data_point = Image('tmp.point_orient.nii').data
    x_init, y_init, z_init = unravel_index(data_point.argmax(), data_point.shape)
    sct.printv('Coordinates of input point: ('+str(x_init)+', '+str(y_init)+', '+str(z_init)+')', verbose)

    # Create 2D gaussian mask
    sct.printv('\nCreate gaussian mask from point...', verbose)
    xx, yy = mgrid[:nx, :ny]
    mask2d = zeros((nx, ny))
    radius = round(float(gaussian_kernel+1)/2)  # add 1 because the radius includes the center.
    sigma = float(radius)
    mask2d = exp(-(((xx-x_init)**2)/(2*(sigma**2)) + ((yy-y_init)**2)/(2*(sigma**2))))

    # Save mask to 2d file
    file_mask_split = ['tmp.mask_orient_Z'+str(z).zfill(4) for z in range(0, nz, 1)]
    nii_mask2d = Image('tmp.anat_orient_Z0000.nii')
    nii_mask2d.data = mask2d
    nii_mask2d.setFileName(file_mask_split[z_init]+'.nii')
    nii_mask2d.save()

    # initialize variables
    # Per-slice transformation matrix filenames (flirt .mat text format).
    file_mat = ['tmp.mat_Z'+str(z).zfill(4) for z in range(0, nz, 1)]
    file_mat_inv = ['tmp.mat_inv_Z'+str(z).zfill(4) for z in range(0, nz, 1)]
    file_mat_inv_cumul = ['tmp.mat_inv_cumul_Z'+str(z).zfill(4) for z in range(0, nz, 1)]

    # create identity matrix for initial transformation matrix
    fid = open(file_mat_inv_cumul[z_init], 'w')
    fid.write('%i %i %i %i\n' % (1, 0, 0, 0))
    fid.write('%i %i %i %i\n' % (0, 1, 0, 0))
    fid.write('%i %i %i %i\n' % (0, 0, 1, 0))
    fid.write('%i %i %i %i\n' % (0, 0, 0, 1))
    fid.close()

    # initialize centerline: give value corresponding to initial point
    x_centerline = [x_init]
    y_centerline = [y_init]
    z_centerline = [z_init]
    warning_count = 0

    # go up (1), then down (2) in reference to the binary point
    for iUpDown in range(1, 3):

        if iUpDown == 1:
            # z increases
            slice_gap_signed = slice_gap
        elif iUpDown == 2:
            # z decreases
            slice_gap_signed = -slice_gap
            # reverse centerline (because values will be appended at the end)
            x_centerline.reverse()
            y_centerline.reverse()
            z_centerline.reverse()

        # initialization before looping
        z_dest = z_init  # point given by user
        z_src = z_dest + slice_gap_signed

        # continue looping if 0 <= z < nz
        while 0 <= z_src < nz:

            # print current z:
            print 'z='+str(z_src)+':'

            # estimate transformation
            sct.run(fsloutput+'flirt -in '+file_anat_split[z_src]+' -ref '+file_anat_split[z_dest]+' -schedule ' +
                    file_schedule + ' -verbose 0 -omat ' + file_mat[z_src] +
                    ' -cost normcorr -forcescaling -inweight ' + file_mask_split[z_dest] +
                    ' -refweight '+file_mask_split[z_dest])

            # display transfo
            status, output = sct.run('cat '+file_mat[z_src])
            print output

            # check if transformation is bigger than 1.5x slice_gap
            # (flirt .mat is a 4x4 row-major matrix: fields 3 and 7 are the
            # x and y translations)
            tx = float(output.split()[3])
            ty = float(output.split()[7])
            norm_txy = linalg.norm([tx, ty], ord=2)
            if norm_txy > 1.5*slice_gap:
                print 'WARNING: Transformation is too large --> using previous one.'
                warning_count = warning_count + 1
                # if previous transformation exists, replace current one with previous one
                if os.path.isfile(file_mat[z_dest]):
                    sct.run('cp '+file_mat[z_dest]+' '+file_mat[z_src])

            # estimate inverse transformation matrix
            sct.run('convert_xfm -omat '+file_mat_inv[z_src]+' -inverse '+file_mat[z_src])

            # compute cumulative transformation
            sct.run('convert_xfm -omat '+file_mat_inv_cumul[z_src]+' -concat '+file_mat_inv[z_src]+' '+file_mat_inv_cumul[z_dest])

            # apply inverse cumulative transformation to initial gaussian mask (to put it in src space)
            sct.run(fsloutput+'flirt -in '+file_mask_split[z_init]+' -ref '+file_mask_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul[z_src]+' -out '+file_mask_split[z_src])

            # open inverse cumulative transformation file and generate centerline
            fid = open(file_mat_inv_cumul[z_src])
            mat = fid.read().split()
            x_centerline.append(x_init + float(mat[3]))
            y_centerline.append(y_init + float(mat[7]))
            z_centerline.append(z_src)
            #z_index = z_index+1

            # define new z_dest (target slice) and new z_src (moving slice)
            z_dest = z_dest + slice_gap_signed
            z_src = z_src + slice_gap_signed


    # Reconstruct centerline
    # ====================================================================================================

    # reverse back centerline (because it's been reversed once, so now all values are in the right order)
    x_centerline.reverse()
    y_centerline.reverse()
    z_centerline.reverse()

    # fit centerline in the Z-X plane using polynomial function
    print '\nFit centerline in the Z-X plane using polynomial function...'
    coeffsx = polyfit(z_centerline, x_centerline, deg=param.deg_poly)
    polyx = poly1d(coeffsx)
    x_centerline_fit = polyval(polyx, z_centerline)
    # calculate RMSE
    rmse = linalg.norm(x_centerline_fit-x_centerline)/sqrt( len(x_centerline) )
    # calculate max absolute error
    max_abs = max(abs(x_centerline_fit-x_centerline))
    print '.. RMSE (in mm): '+str(rmse*px)
    print '.. Maximum absolute error (in mm): '+str(max_abs*px)

    # fit centerline in the Z-Y plane using polynomial function
    print '\nFit centerline in the Z-Y plane using polynomial function...'
    coeffsy = polyfit(z_centerline, y_centerline, deg=param.deg_poly)
    polyy = poly1d(coeffsy)
    y_centerline_fit = polyval(polyy, z_centerline)
    # calculate RMSE
    rmse = linalg.norm(y_centerline_fit-y_centerline)/sqrt( len(y_centerline) )
    # calculate max absolute error
    max_abs = max( abs(y_centerline_fit-y_centerline) )
    print '.. RMSE (in mm): '+str(rmse*py)
    print '.. Maximum absolute error (in mm): '+str(max_abs*py)

    # display
    if param.debug == 1:
        import matplotlib.pyplot as plt
        plt.figure()
        plt.plot(z_centerline,x_centerline,'.',z_centerline,x_centerline_fit,'r')
        plt.legend(['Data','Polynomial Fit'])
        plt.title('Z-X plane polynomial interpolation')
        plt.show()

        plt.figure()
        plt.plot(z_centerline,y_centerline,'.',z_centerline,y_centerline_fit,'r')
        plt.legend(['Data','Polynomial Fit'])
        plt.title('Z-Y plane polynomial interpolation')
        plt.show()

    # generate full range z-values for centerline
    z_centerline_full = [iz for iz in range(0, nz, 1)]

    # calculate X and Y values for the full centerline
    x_centerline_fit_full = polyval(polyx, z_centerline_full)
    y_centerline_fit_full = polyval(polyy, z_centerline_full)

    # Generate fitted transformation matrices and write centerline coordinates in text file
    print '\nGenerate fitted transformation matrices and write centerline coordinates in text file...'
    file_mat_inv_cumul_fit = ['tmp.mat_inv_cumul_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_mat_cumul_fit = ['tmp.mat_cumul_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    fid_centerline = open('tmp.centerline_coordinates.txt', 'w')
    for iz in range(0, nz, 1):
        # compute inverse cumulative fitted transformation matrix
        # (pure translation: identity rotation, offsets from fitted centerline)
        fid = open(file_mat_inv_cumul_fit[iz], 'w')
        fid.write('%i %i %i %f\n' % (1, 0, 0, x_centerline_fit_full[iz]-x_init))
        fid.write('%i %i %i %f\n' % (0, 1, 0, y_centerline_fit_full[iz]-y_init))
        fid.write('%i %i %i %i\n' % (0, 0, 1, 0))
        fid.write('%i %i %i %i\n' % (0, 0, 0, 1))
        fid.close()
        # compute forward cumulative fitted transformation matrix
        sct.run('convert_xfm -omat '+file_mat_cumul_fit[iz]+' -inverse '+file_mat_inv_cumul_fit[iz])
        # write centerline coordinates in x, y, z format
        fid_centerline.write('%f %f %f\n' %(x_centerline_fit_full[iz], y_centerline_fit_full[iz], z_centerline_full[iz]) )
    fid_centerline.close()


    # Prepare output data
    # ====================================================================================================

    # write centerline as text file
    # NOTE(review): this loop re-writes the same matrices as the loop above
    # — it appears redundant; confirm before removing.
    for iz in range(0, nz, 1):
        # compute inverse cumulative fitted transformation matrix
        fid = open(file_mat_inv_cumul_fit[iz], 'w')
        fid.write('%i %i %i %f\n' % (1, 0, 0, x_centerline_fit_full[iz]-x_init))
        fid.write('%i %i %i %f\n' % (0, 1, 0, y_centerline_fit_full[iz]-y_init))
        fid.write('%i %i %i %i\n' % (0, 0, 1, 0))
        fid.write('%i %i %i %i\n' % (0, 0, 0, 1))
        fid.close()

    # write polynomial coefficients
    savetxt('tmp.centerline_polycoeffs_x.txt',coeffsx)
    savetxt('tmp.centerline_polycoeffs_y.txt',coeffsy)

    # apply transformations to data
    print '\nApply fitted transformation matrices...'
    file_anat_split_fit = ['tmp.anat_orient_fit_z'+str(z).zfill(4) for z in range(0, nz, 1)]
    file_mask_split_fit = ['tmp.mask_orient_fit_z'+str(z).zfill(4) for z in range(0, nz, 1)]
    file_point_split_fit = ['tmp.point_orient_fit_z'+str(z).zfill(4) for z in range(0, nz, 1)]
    for iz in range(0, nz, 1):
        # forward cumulative transformation to data
        sct.run(fsloutput+'flirt -in '+file_anat_split[iz]+' -ref '+file_anat_split[iz]+' -applyxfm -init '+file_mat_cumul_fit[iz]+' -out '+file_anat_split_fit[iz])
        # inverse cumulative transformation to mask
        sct.run(fsloutput+'flirt -in '+file_mask_split[z_init]+' -ref '+file_mask_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul_fit[iz]+' -out '+file_mask_split_fit[iz])
        # inverse cumulative transformation to point
        sct.run(fsloutput+'flirt -in '+file_point_split[z_init]+' -ref '+file_point_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul_fit[iz]+' -out '+file_point_split_fit[iz]+' -interp nearestneighbour')

    # Merge into 4D volume
    print '\nMerge into 4D volume...'
    # im_anat_list = [Image(fname) for fname in glob.glob('tmp.anat_orient_fit_z*.nii')]
    fname_anat_list = glob.glob('tmp.anat_orient_fit_z*.nii')
    im_anat_concat = concat_data(fname_anat_list, 2)
    im_anat_concat.setFileName('tmp.anat_orient_fit.nii')
    im_anat_concat.save()

    # im_mask_list = [Image(fname) for fname in glob.glob('tmp.mask_orient_fit_z*.nii')]
    fname_mask_list = glob.glob('tmp.mask_orient_fit_z*.nii')
    im_mask_concat = concat_data(fname_mask_list, 2)
    im_mask_concat.setFileName('tmp.mask_orient_fit.nii')
    im_mask_concat.save()

    # im_point_list = [Image(fname) for fname in 	glob.glob('tmp.point_orient_fit_z*.nii')]
    fname_point_list = glob.glob('tmp.point_orient_fit_z*.nii')
    im_point_concat = concat_data(fname_point_list, 2)
    im_point_concat.setFileName('tmp.point_orient_fit.nii')
    im_point_concat.save()

    # Copy header geometry from input data
    print '\nCopy header geometry from input data...'
    im_anat = Image('tmp.anat_orient.nii')
    im_anat_orient_fit = Image('tmp.anat_orient_fit.nii')
    im_mask_orient_fit = Image('tmp.mask_orient_fit.nii')
    im_point_orient_fit = Image('tmp.point_orient_fit.nii')
    im_anat_orient_fit = copy_header(im_anat, im_anat_orient_fit)
    im_mask_orient_fit = copy_header(im_anat, im_mask_orient_fit)
    im_point_orient_fit = copy_header(im_anat, im_point_orient_fit)
    for im in [im_anat_orient_fit, im_mask_orient_fit, im_point_orient_fit]:
        im.save()

    # Reorient outputs into the initial orientation of the input image
    print '\nReorient the centerline into the initial orientation of the input image...'
    set_orientation('tmp.point_orient_fit.nii', input_image_orientation, 'tmp.point_orient_fit.nii')
    set_orientation('tmp.mask_orient_fit.nii', input_image_orientation, 'tmp.mask_orient_fit.nii')

    # Generate output file (in current folder)
    print '\nGenerate output file (in current folder)...'
    os.chdir('..')  # come back to parent folder
    fname_output_centerline = sct.generate_output_file(path_tmp+'/tmp.point_orient_fit.nii', file_anat+'_centerline'+ext_anat)

    # Delete temporary files
    if remove_tmp_files == 1:
        print '\nRemove temporary files...'
        sct.run('rm -rf '+path_tmp, error_exit='warning')

    # print number of warnings
    print '\nNumber of warnings: '+str(warning_count)+' (if >10, you should probably reduce the gap and/or increase the kernel size'

    # display elapsed time
    elapsed_time = time() - start_time
    print '\nFinished! \n\tGenerated file: '+fname_output_centerline+'\n\tElapsed time: '+str(int(round(elapsed_time)))+'s\n'

Example 38

View license
def get_centerline_from_point(input_image, point_file, gap=4, gaussian_kernel=4, remove_tmp_files=1):
    """Estimate the spinal cord centerline from an anatomical image and a single seed point.

    Works slice-by-slice: starting from the slice containing the seed point, each
    slice is registered to the previous one with FSL ``flirt`` (propagated upward,
    then downward), the cumulative in-plane translations are read back from the
    transformation matrices to build a raw centerline, and the x(z) / y(z) curves
    are then smoothed with a polynomial fit of degree ``param.deg_poly``.
    All work happens inside a temporary folder; the fitted centerline volume is
    written next to the input as ``<file_anat>_centerline<ext>``.

    Note: this is Python 2 code (print statements) and shells out to FSL and SCT
    command-line tools; it returns nothing and communicates via files and stdout.

    :param input_image: path to the anatomical volume (e.g. NIfTI).
    :param point_file: path to a binary volume containing a single non-zero voxel
                       inside the spinal cord (the seed).
    :param gap: slice gap between successively registered slices.
    :param gaussian_kernel: size (in voxels) of the 2D Gaussian weighting mask
                            centered on the seed point.
    :param remove_tmp_files: if 1, delete the temporary folder at the end.
    """

    # Initialization
    fname_anat = input_image
    fname_point = point_file
    slice_gap = gap
    remove_tmp_files = remove_tmp_files
    gaussian_kernel = gaussian_kernel
    start_time = time()
    verbose = 1

    # get path of the toolbox
    status, path_sct = commands.getstatusoutput('echo $SCT_DIR')
    path_sct = sct.slash_at_the_end(path_sct, 1)

    # Parameters for debug mode
    # NOTE(review): debug mode silently overrides the caller's inputs with test data.
    if param.debug == 1:
        sct.printv('\n*** WARNING: DEBUG MODE ON ***\n\t\t\tCurrent working directory: '+os.getcwd(), 'warning')
        status, path_sct_testing_data = commands.getstatusoutput('echo $SCT_TESTING_DATA_DIR')
        fname_anat = path_sct_testing_data+'/t2/t2.nii.gz'
        fname_point = path_sct_testing_data+'/t2/t2_centerline_init.nii.gz'
        slice_gap = 5

    # check existence of input files
    sct.check_file_exist(fname_anat)
    sct.check_file_exist(fname_point)

    # extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
    path_point, file_point, ext_point = sct.extract_fname(fname_point)

    # extract path of schedule file
    # TODO: include schedule file in sct
    # TODO: check existence of schedule file
    file_schedule = path_sct + param.schedule_file

    # Get input image orientation (remembered so outputs can be reoriented back at the end)
    input_image_orientation = get_orientation_3d(fname_anat, filename=True)

    # Display arguments
    print '\nCheck input arguments...'
    print '  Anatomical image:     '+fname_anat
    print '  Orientation:          '+input_image_orientation
    print '  Point in spinal cord: '+fname_point
    print '  Slice gap:            '+str(slice_gap)
    print '  Gaussian kernel:      '+str(gaussian_kernel)
    print '  Degree of polynomial: '+str(param.deg_poly)

    # create temporary folder
    print('\nCreate temporary folder...')
    path_tmp = 'tmp.'+strftime('%y%m%d%H%M%S')
    sct.create_folder(path_tmp)
    print '\nCopy input data...'
    sct.run('cp '+fname_anat+ ' '+path_tmp+'/tmp.anat'+ext_anat)
    sct.run('cp '+fname_point+ ' '+path_tmp+'/tmp.point'+ext_point)

    # go to temporary folder
    os.chdir(path_tmp)

    # convert to nii
    im_anat = convert('tmp.anat'+ext_anat, 'tmp.anat.nii')
    im_point = convert('tmp.point'+ext_point, 'tmp.point.nii')

    # Reorient input anatomical volume into RL PA IS orientation
    print '\nReorient input volume to RL PA IS orientation...'
    set_orientation(im_anat, 'RPI')
    im_anat.setFileName('tmp.anat_orient.nii')
    # Reorient binary point into RL PA IS orientation
    print '\nReorient binary point into RL PA IS orientation...'
    # sct.run(sct.fsloutput + 'fslswapdim tmp.point RL PA IS tmp.point_orient')
    set_orientation(im_point, 'RPI')
    im_point.setFileName('tmp.point_orient.nii')

    # Get image dimensions
    print '\nGet image dimensions...'
    nx, ny, nz, nt, px, py, pz, pt = Image('tmp.anat_orient.nii').dim
    print '.. matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz)
    print '.. voxel size:  '+str(px)+'mm x '+str(py)+'mm x '+str(pz)+'mm'

    # Split input volume along z into one 2D file per slice
    print '\nSplit input volume...'
    im_anat_split_list = split_data(im_anat, 2)
    file_anat_split = []
    for im in im_anat_split_list:
        file_anat_split.append(im.absolutepath)
        im.save()

    im_point_split_list = split_data(im_point, 2)
    file_point_split = []
    for im in im_point_split_list:
        file_point_split.append(im.absolutepath)
        im.save()

    # Extract coordinates of input point (argmax of the binary point volume)
    data_point = Image('tmp.point_orient.nii').data
    x_init, y_init, z_init = unravel_index(data_point.argmax(), data_point.shape)
    sct.printv('Coordinates of input point: ('+str(x_init)+', '+str(y_init)+', '+str(z_init)+')', verbose)

    # Create 2D gaussian mask centered on the seed point (used as flirt in/ref weight)
    sct.printv('\nCreate gaussian mask from point...', verbose)
    xx, yy = mgrid[:nx, :ny]
    mask2d = zeros((nx, ny))
    radius = round(float(gaussian_kernel+1)/2)  # add 1 because the radius includes the center.
    sigma = float(radius)
    mask2d = exp(-(((xx-x_init)**2)/(2*(sigma**2)) + ((yy-y_init)**2)/(2*(sigma**2))))

    # Save mask to 2d file
    file_mask_split = ['tmp.mask_orient_Z'+str(z).zfill(4) for z in range(0, nz, 1)]
    # NOTE(review): assumes split_data produced files named 'tmp.anat_orient_Z0000.nii'
    # (uppercase Z, 4-digit index) — confirm against split_data's naming convention.
    nii_mask2d = Image('tmp.anat_orient_Z0000.nii')
    nii_mask2d.data = mask2d
    nii_mask2d.setFileName(file_mask_split[z_init]+'.nii')
    nii_mask2d.save()

    # initialize per-slice transformation matrix file names
    file_mat = ['tmp.mat_Z'+str(z).zfill(4) for z in range(0, nz, 1)]
    file_mat_inv = ['tmp.mat_inv_Z'+str(z).zfill(4) for z in range(0, nz, 1)]
    file_mat_inv_cumul = ['tmp.mat_inv_cumul_Z'+str(z).zfill(4) for z in range(0, nz, 1)]

    # create identity matrix for initial transformation matrix
    fid = open(file_mat_inv_cumul[z_init], 'w')
    fid.write('%i %i %i %i\n' % (1, 0, 0, 0))
    fid.write('%i %i %i %i\n' % (0, 1, 0, 0))
    fid.write('%i %i %i %i\n' % (0, 0, 1, 0))
    fid.write('%i %i %i %i\n' % (0, 0, 0, 1))
    fid.close()

    # initialize centerline: give value corresponding to initial point
    x_centerline = [x_init]
    y_centerline = [y_init]
    z_centerline = [z_init]
    warning_count = 0

    # go up (1), then down (2) in reference to the binary point
    for iUpDown in range(1, 3):

        if iUpDown == 1:
            # z increases
            slice_gap_signed = slice_gap
        elif iUpDown == 2:
            # z decreases
            slice_gap_signed = -slice_gap
            # reverse centerline (because values will be appended at the end)
            x_centerline.reverse()
            y_centerline.reverse()
            z_centerline.reverse()

        # initialization before looping
        z_dest = z_init  # point given by user
        z_src = z_dest + slice_gap_signed

        # continue looping if 0 <= z < nz
        while 0 <= z_src < nz:

            # print current z:
            print 'z='+str(z_src)+':'

            # estimate slice-to-slice rigid transformation (normalized correlation,
            # weighted by the gaussian mask so registration focuses on the cord)
            sct.run(fsloutput+'flirt -in '+file_anat_split[z_src]+' -ref '+file_anat_split[z_dest]+' -schedule ' +
                    file_schedule + ' -verbose 0 -omat ' + file_mat[z_src] +
                    ' -cost normcorr -forcescaling -inweight ' + file_mask_split[z_dest] +
                    ' -refweight '+file_mask_split[z_dest])

            # display transfo
            status, output = sct.run('cat '+file_mat[z_src])
            print output

            # check if transformation is bigger than 1.5x slice_gap
            # (tx/ty are the translation entries of the 4x4 flirt matrix)
            tx = float(output.split()[3])
            ty = float(output.split()[7])
            norm_txy = linalg.norm([tx, ty], ord=2)
            if norm_txy > 1.5*slice_gap:
                print 'WARNING: Transformation is too large --> using previous one.'
                warning_count = warning_count + 1
                # if previous transformation exists, replace current one with previous one
                if os.path.isfile(file_mat[z_dest]):
                    sct.run('cp '+file_mat[z_dest]+' '+file_mat[z_src])

            # estimate inverse transformation matrix
            sct.run('convert_xfm -omat '+file_mat_inv[z_src]+' -inverse '+file_mat[z_src])

            # compute cumulative transformation (chains this slice's inverse onto the previous cumulative)
            sct.run('convert_xfm -omat '+file_mat_inv_cumul[z_src]+' -concat '+file_mat_inv[z_src]+' '+file_mat_inv_cumul[z_dest])

            # apply inverse cumulative transformation to initial gaussian mask (to put it in src space)
            sct.run(fsloutput+'flirt -in '+file_mask_split[z_init]+' -ref '+file_mask_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul[z_src]+' -out '+file_mask_split[z_src])

            # open inverse cumulative transformation file and generate centerline
            # NOTE(review): file handle is never closed — relies on GC / interpreter exit.
            fid = open(file_mat_inv_cumul[z_src])
            mat = fid.read().split()
            x_centerline.append(x_init + float(mat[3]))
            y_centerline.append(y_init + float(mat[7]))
            z_centerline.append(z_src)
            #z_index = z_index+1

            # define new z_dest (target slice) and new z_src (moving slice)
            z_dest = z_dest + slice_gap_signed
            z_src = z_src + slice_gap_signed


    # Reconstruct centerline
    # ====================================================================================================

    # reverse back centerline (because it's been reversed once, so now all values are in the right order)
    x_centerline.reverse()
    y_centerline.reverse()
    z_centerline.reverse()

    # fit centerline in the Z-X plane using polynomial function
    print '\nFit centerline in the Z-X plane using polynomial function...'
    coeffsx = polyfit(z_centerline, x_centerline, deg=param.deg_poly)
    polyx = poly1d(coeffsx)
    x_centerline_fit = polyval(polyx, z_centerline)
    # calculate RMSE
    rmse = linalg.norm(x_centerline_fit-x_centerline)/sqrt( len(x_centerline) )
    # calculate max absolute error
    max_abs = max(abs(x_centerline_fit-x_centerline))
    print '.. RMSE (in mm): '+str(rmse*px)
    print '.. Maximum absolute error (in mm): '+str(max_abs*px)

    # fit centerline in the Z-Y plane using polynomial function
    print '\nFit centerline in the Z-Y plane using polynomial function...'
    coeffsy = polyfit(z_centerline, y_centerline, deg=param.deg_poly)
    polyy = poly1d(coeffsy)
    y_centerline_fit = polyval(polyy, z_centerline)
    # calculate RMSE
    rmse = linalg.norm(y_centerline_fit-y_centerline)/sqrt( len(y_centerline) )
    # calculate max absolute error
    max_abs = max( abs(y_centerline_fit-y_centerline) )
    print '.. RMSE (in mm): '+str(rmse*py)
    print '.. Maximum absolute error (in mm): '+str(max_abs*py)

    # display fits interactively (debug only)
    if param.debug == 1:
        import matplotlib.pyplot as plt
        plt.figure()
        plt.plot(z_centerline,x_centerline,'.',z_centerline,x_centerline_fit,'r')
        plt.legend(['Data','Polynomial Fit'])
        plt.title('Z-X plane polynomial interpolation')
        plt.show()

        plt.figure()
        plt.plot(z_centerline,y_centerline,'.',z_centerline,y_centerline_fit,'r')
        plt.legend(['Data','Polynomial Fit'])
        plt.title('Z-Y plane polynomial interpolation')
        plt.show()

    # generate full range z-values for centerline (the fit was only sampled every slice_gap)
    z_centerline_full = [iz for iz in range(0, nz, 1)]

    # calculate X and Y values for the full centerline
    x_centerline_fit_full = polyval(polyx, z_centerline_full)
    y_centerline_fit_full = polyval(polyy, z_centerline_full)

    # Generate fitted transformation matrices and write centerline coordinates in text file
    print '\nGenerate fitted transformation matrices and write centerline coordinates in text file...'
    file_mat_inv_cumul_fit = ['tmp.mat_inv_cumul_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_mat_cumul_fit = ['tmp.mat_cumul_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    fid_centerline = open('tmp.centerline_coordinates.txt', 'w')
    for iz in range(0, nz, 1):
        # compute inverse cumulative fitted transformation matrix
        # (pure translation: fitted offset from the seed point)
        fid = open(file_mat_inv_cumul_fit[iz], 'w')
        fid.write('%i %i %i %f\n' % (1, 0, 0, x_centerline_fit_full[iz]-x_init))
        fid.write('%i %i %i %f\n' % (0, 1, 0, y_centerline_fit_full[iz]-y_init))
        fid.write('%i %i %i %i\n' % (0, 0, 1, 0))
        fid.write('%i %i %i %i\n' % (0, 0, 0, 1))
        fid.close()
        # compute forward cumulative fitted transformation matrix
        sct.run('convert_xfm -omat '+file_mat_cumul_fit[iz]+' -inverse '+file_mat_inv_cumul_fit[iz])
        # write centerline coordinates in x, y, z format
        fid_centerline.write('%f %f %f\n' %(x_centerline_fit_full[iz], y_centerline_fit_full[iz], z_centerline_full[iz]) )
    fid_centerline.close()


    # Prepare output data
    # ====================================================================================================

    # write centerline as text file
    # NOTE(review): this loop rewrites the exact same matrices already written just
    # above — it appears redundant; left unchanged to preserve behavior.
    for iz in range(0, nz, 1):
        # compute inverse cumulative fitted transformation matrix
        fid = open(file_mat_inv_cumul_fit[iz], 'w')
        fid.write('%i %i %i %f\n' % (1, 0, 0, x_centerline_fit_full[iz]-x_init))
        fid.write('%i %i %i %f\n' % (0, 1, 0, y_centerline_fit_full[iz]-y_init))
        fid.write('%i %i %i %i\n' % (0, 0, 1, 0))
        fid.write('%i %i %i %i\n' % (0, 0, 0, 1))
        fid.close()

    # write polynomial coefficients
    savetxt('tmp.centerline_polycoeffs_x.txt',coeffsx)
    savetxt('tmp.centerline_polycoeffs_y.txt',coeffsy)

    # apply transformations to data
    print '\nApply fitted transformation matrices...'
    file_anat_split_fit = ['tmp.anat_orient_fit_z'+str(z).zfill(4) for z in range(0, nz, 1)]
    file_mask_split_fit = ['tmp.mask_orient_fit_z'+str(z).zfill(4) for z in range(0, nz, 1)]
    file_point_split_fit = ['tmp.point_orient_fit_z'+str(z).zfill(4) for z in range(0, nz, 1)]
    for iz in range(0, nz, 1):
        # forward cumulative transformation to data
        sct.run(fsloutput+'flirt -in '+file_anat_split[iz]+' -ref '+file_anat_split[iz]+' -applyxfm -init '+file_mat_cumul_fit[iz]+' -out '+file_anat_split_fit[iz])
        # inverse cumulative transformation to mask
        sct.run(fsloutput+'flirt -in '+file_mask_split[z_init]+' -ref '+file_mask_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul_fit[iz]+' -out '+file_mask_split_fit[iz])
        # inverse cumulative transformation to point (nearest-neighbour keeps the point binary)
        sct.run(fsloutput+'flirt -in '+file_point_split[z_init]+' -ref '+file_point_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul_fit[iz]+' -out '+file_point_split_fit[iz]+' -interp nearestneighbour')

    # Merge per-slice outputs back into a single volume along z
    print '\nMerge into 4D volume...'
    # im_anat_list = [Image(fname) for fname in glob.glob('tmp.anat_orient_fit_z*.nii')]
    fname_anat_list = glob.glob('tmp.anat_orient_fit_z*.nii')
    im_anat_concat = concat_data(fname_anat_list, 2)
    im_anat_concat.setFileName('tmp.anat_orient_fit.nii')
    im_anat_concat.save()

    # im_mask_list = [Image(fname) for fname in glob.glob('tmp.mask_orient_fit_z*.nii')]
    fname_mask_list = glob.glob('tmp.mask_orient_fit_z*.nii')
    im_mask_concat = concat_data(fname_mask_list, 2)
    im_mask_concat.setFileName('tmp.mask_orient_fit.nii')
    im_mask_concat.save()

    # im_point_list = [Image(fname) for fname in 	glob.glob('tmp.point_orient_fit_z*.nii')]
    fname_point_list = glob.glob('tmp.point_orient_fit_z*.nii')
    im_point_concat = concat_data(fname_point_list, 2)
    im_point_concat.setFileName('tmp.point_orient_fit.nii')
    im_point_concat.save()

    # Copy header geometry from input data
    print '\nCopy header geometry from input data...'
    im_anat = Image('tmp.anat_orient.nii')
    im_anat_orient_fit = Image('tmp.anat_orient_fit.nii')
    im_mask_orient_fit = Image('tmp.mask_orient_fit.nii')
    im_point_orient_fit = Image('tmp.point_orient_fit.nii')
    im_anat_orient_fit = copy_header(im_anat, im_anat_orient_fit)
    im_mask_orient_fit = copy_header(im_anat, im_mask_orient_fit)
    im_point_orient_fit = copy_header(im_anat, im_point_orient_fit)
    for im in [im_anat_orient_fit, im_mask_orient_fit, im_point_orient_fit]:
        im.save()

    # Reorient outputs into the initial orientation of the input image
    print '\nReorient the centerline into the initial orientation of the input image...'
    set_orientation('tmp.point_orient_fit.nii', input_image_orientation, 'tmp.point_orient_fit.nii')
    set_orientation('tmp.mask_orient_fit.nii', input_image_orientation, 'tmp.mask_orient_fit.nii')

    # Generate output file (in current folder)
    print '\nGenerate output file (in current folder)...'
    os.chdir('..')  # come back to parent folder
    fname_output_centerline = sct.generate_output_file(path_tmp+'/tmp.point_orient_fit.nii', file_anat+'_centerline'+ext_anat)

    # Delete temporary files
    if remove_tmp_files == 1:
        print '\nRemove temporary files...'
        sct.run('rm -rf '+path_tmp, error_exit='warning')

    # print number of warnings
    # NOTE(review): the warning message below is missing a closing parenthesis;
    # it is a runtime string so it is left unchanged here.
    print '\nNumber of warnings: '+str(warning_count)+' (if >10, you should probably reduce the gap and/or increase the kernel size'

    # display elapsed time
    elapsed_time = time() - start_time
    print '\nFinished! \n\tGenerated file: '+fname_output_centerline+'\n\tElapsed time: '+str(int(round(elapsed_time)))+'s\n'

Example 39

Project: PHEnix
Source File: vcf2fasta.py
View license
def main(args):
    """
    Process VCF files and merge them into a single fasta file.

    Reads one or more VCFs in parallel position-by-position, picks the best
    record per sample at each position, encodes each sample's base (ALT, 'N'
    for filtered/uncertain, '-' for uncallable gaps), applies optional
    include/exclude BED regions and per-column / per-sample N and gap
    thresholds, and writes the resulting alignment as FASTA to args["out"].
    Optionally writes per-position stats to args["with_stats"].

    Note: Python 2 code (iteritems/itervalues, .next(), print statement).

    :param args: dict of parsed CLI options ("input", "directory", "regexp",
                 "reference", "include", "exclude", "tmp", "out", "with_stats",
                 "with_mixtures", "column_Ns", "column_gaps", "sample_Ns",
                 "sample_gaps").
    :returns: 0 (also on "no VCFs found").
    """

    contigs = list()

    # shared empty tree used as a default so position lookups never KeyError
    empty_tree = FastRBTree()

    exclude = {}
    include = {}

    # Working directory for per-sample temp sequence files.
    if args["tmp"]:
        out_dir = os.path.join(args["tmp"])
        if not os.path.exists(out_dir):
            os.mkdir(out_dir)
    else:
        out_dir = tempfile.gettempdir()

    # Load reference genome as {contig: list-of-bases} for gap filling.
    if args["reference"]:
        ref_seq = OrderedDict()
        with open(args["reference"]) as fp:
            for record in SeqIO.parse(fp, "fasta"):
                ref_seq[record.id] = list(record.seq)

        args["reference"] = ref_seq

    # Build include/exclude position trees from a BED file.
    if args["exclude"] or args["include"]:
        pos = {}
        chr_pos = []
        bed_file = args["include"] if args["include"] is not None else args["exclude"]

        with open(bed_file) as fp:
            for line in fp:
                data = line.strip().split("\t")

                # NOTE(review): chr_pos is never reset between BED lines, so each
                # chromosome also accumulates positions from all previous lines —
                # looks like a bug; confirm intended behavior before changing.
                chr_pos += [ (i, False,) for i in xrange(int(data[1]), int(data[2]) + 1)]

                if data[0] not in pos:
                    pos[data[0]] = []

                pos[data[0]] += chr_pos

        pos = {chrom: FastRBTree(l) for chrom, l in pos.items()}

        if args["include"]:
            include = pos
        else:
            exclude = pos


    # Discover input VCFs from a directory when not given explicitly.
    if args["directory"] is not None and args["input"] is None:
        regexp = args["regexp"] if args["regexp"] else "*.vcf"
        args["input"] = glob.glob(os.path.join(args["directory"], regexp))

    if not args["input"]:
        logging.warn("No VCFs found.")
        return 0


    # If we can stats and asked to stats, then output the data
    if args["with_stats"] is not None:
        args["with_stats"] = open(args["with_stats"], "wb")
        args["with_stats"].write("contig,position,mutations,n_frac,n_gaps\n")


    parallel_reader = ParallelVCFReader(args["input"])

    # One temp file per sample accumulates that sample's sequence as a single line.
    sample_seqs = { sample_name: tempfile.NamedTemporaryFile(prefix=sample_name, dir=out_dir) for sample_name in parallel_reader.get_samples() }
    sample_seqs["reference"] = tempfile.NamedTemporaryFile(prefix="reference", dir=out_dir)

    samples = parallel_reader.get_samples() + ["reference"]
    sample_stats = {sample: BaseStats() for sample in samples }
    last_base = 0

    total_records = 0
    guesstimate_records = guess_total_records(args["input"])

    for chrom, pos, records in parallel_reader:
        total_records += 1

        log_progress(total_records, guesstimate_records)

        final_records = pick_best_records(records)
        reference = [ record.REF for record in final_records.itervalues() if record.REF != "N"]
        valid = not reference or reference.count(reference[0]) == len(reference)

        # Make sure reference is the same across all samples.
        assert valid, "Position %s is not valid as multiple references found: %s" % (pos, reference)

        if not reference:
            continue
        else:
            reference = reference[0]

        # SKIP (or include) any pre-specified regions.
        if include and pos not in include.get(chrom, empty_tree) or exclude and pos in exclude.get(chrom, empty_tree):
            continue

        position_data = {"reference": str(reference), "stats": BaseStats()}

        for sample_name, record in final_records.iteritems():

            position_data["stats"].total += 1

            # IF this is uncallable genotype, add gap "-"
            if record.is_uncallable:
                # TODO: Mentioned in issue: #7(gitlab)
                position_data[sample_name] = "-"

                # Update stats
                position_data["stats"].gap += 1


            elif not record.FILTER:
                # If filter PASSED!
                # Make sure the reference base is the same. Maybe a vcf from different species snuck in here?!
                assert str(record.REF) == position_data["reference"] or str(record.REF) == 'N' or position_data["reference"] == 'N', "SOMETHING IS REALLY WRONG because reference for the same position is DIFFERENT! %s in %s (%s, %s)" % (record.POS, sample_name, str(record.REF), position_data["reference"])
                # update position_data['reference'] to a real base if possible
                if position_data['reference'] == 'N' and str(record.REF) != 'N':
                    position_data['reference'] = str(record.REF)
                if record.is_snp:
                    if len(record.ALT) > 1:
                        logging.info("POS %s passed filters but has multiple alleles REF: %s, ALT: %s. Inserting N", record.POS, str(record.REF), str(record.ALT))
                        position_data[sample_name] = "N"
                        position_data["stats"].N += 1

                    else:
                        position_data[sample_name] = str(record.ALT[0])

                        position_data["stats"].mut += 1

            # Filter(s) failed
            elif record.is_snp and is_above_min_depth(record):
                # Deep enough to optionally call a mixture code instead of plain N.
                if args["with_mixtures"]:
                    extended_code = get_mixture(record, args["with_mixtures"])
                else:
                    extended_code = "N"

                if extended_code == "N":
                    position_data["stats"].N += 1
                elif extended_code in ["A", "C", "G", "T"]:
                    position_data["stats"].mut += 1
                else:
                    position_data["stats"].mix += 1

                position_data[sample_name] = extended_code

            else:
                # filter fail; code as N for consistency
                position_data[sample_name] = "N"
                position_data["stats"].N += 1

            # Filter columns when threashold reaches user specified value.
            # (break skips the for-else below, dropping the whole column)
            if isinstance(args["column_Ns"], float) and float(position_data["stats"].N) / len(args["input"]) > args["column_Ns"]:
                break
#                 del position_data[sample_name]

            if isinstance(args["column_gaps"], float) and float(position_data["stats"].gap) / len(args["input"]) > args["column_gaps"]:
                break
#                 del position_data[sample_name]

        # this is not an if-else it's a for-else, it really is!
        # (runs only when the column was NOT dropped by a break above)
        else:
            if args["reference"]:
                # Fill reference bases between the previous SNP and this one.
                seq = _make_ref_insert(last_base, pos, args["reference"][chrom], exclude.get(chrom, empty_tree))
                for sample in samples:
#                     sample_seqs[sample] += seq
                    sample_seqs[sample].write(''.join(seq))

            for i, sample_name in enumerate(samples):
                # Samples without a call at this position default to the reference base.
                sample_base = position_data.get(sample_name, reference)

#                 sample_seqs[sample_name] += [sample_base]
                sample_seqs[sample_name].write(sample_base)
                sample_stats[sample_name].update(position_data, sample_name, reference)

            if args["with_stats"] is not None:
                args["with_stats"].write("%s,%i,%0.5f,%0.5f,%0.5f\n" % (chrom,
                                             pos,
                                             float(position_data["stats"].mut) / len(args["input"]),
                                             float(position_data["stats"].N) / len(args["input"]),
                                             float(position_data["stats"].gap) / len(args["input"]))
                         )

            last_base = pos

    # Fill from last snp to the end of reference.
    # FIXME: A little naughty to use chrom outside the loop!
    if args["reference"]:
        seq = _make_ref_insert(last_base, None, args["reference"][chrom], exclude.get(chrom, empty_tree))
        for sample in samples:
#             sample_seqs[sample] += seq
            sample_seqs[sample].write(''.join(seq))

    # Sequences were written as one long line; .next() reads the whole thing.
    sample_seqs["reference"].seek(0)
    reference = sample_seqs["reference"].next()
    sample_seqs["reference"].close()
    del sample_seqs["reference"]

    bSamplesExcluded = False

    # Exclude any samples with high Ns or gaps
    if isinstance(args["sample_Ns"], float):
        for sample_name in samples:
            if sample_name == "reference":
                continue
            n_fraction = float(sample_stats[sample_name].N) / sample_stats[sample_name].total
            if n_fraction > args["sample_Ns"]:
                logging.info("Removing %s due to high sample Ns fraction %s", sample_name, n_fraction)

                sample_seqs[sample_name].close()
                del sample_seqs[sample_name]
                del sample_stats[sample_name]
                bSamplesExcluded = True

    # Exclude any samples with high gap fraction.
    # NOTE(review): 'gap_fractoin' is a typo for 'gap_fraction' throughout;
    # kept as-is since this is a code identifier.
    if isinstance(args["sample_gaps"], float):
        for sample_name in samples:
            if sample_name == "reference" or sample_name not in sample_stats:
                continue

            gap_fractoin = float(sample_stats[sample_name].gap) / sample_stats[sample_name].total
            if gap_fractoin > args["sample_gaps"]:
                logging.info("Removing %s due to high sample gaps fraction %s", sample_name, gap_fractoin)

                sample_seqs[sample_name].close()
                del sample_seqs[sample_name]
                del sample_stats[sample_name]
                bSamplesExcluded = True

    try:
        assert len(sample_seqs) > 0, "All samples have been filtered out."

        reference_length = len(reference)

        dAlign = {}
        dAlign['reference'] = reference
        for sample_name, tmp_iter in sample_seqs.iteritems():
            tmp_iter.seek(0)
            # These are dumped as single long string of data. Calling next() should read it all.
            snp_sequence = tmp_iter.next()
            assert len(snp_sequence) == reference_length, "Sample %s has length %s, but should be %s (reference)" % (sample_name, len(snp_sequence), reference_length)

            dAlign[sample_name] = snp_sequence

        # if samples were excluded we need to filter the alignment for all equal positions,
        # because we might just have removed the sequence with the difference
        while bSamplesExcluded:
            dFinalAlign = {} #  this is for the new alignment
            # initialise thoes as empty
            for sample_name in dAlign.keys():
                dFinalAlign[sample_name] = ''
                sample_stats[sample_name] = BaseStats()
            # for all positions in the current alignment
            for i in range(len(dAlign['reference'])):
                # initialise empty stats for this position
                pos_stats = BaseStats()
                # get list of all nucs at this position
                ith_nucs = [seq[i] for seq in dAlign.values()]
                # check if all elements in the list are the same
                if ith_nucs.count(ith_nucs[0]) != len(ith_nucs):
                    # they are not all the same
                    # for all samples and seqs update position stats
                    for sample_name, seq in dAlign.iteritems():
                        if seq[i] == 'N':
                            pos_stats.N +=1
                        elif seq[i] == '-':
                            pos_stats.gap +=1
                        elif seq[i] != dAlign['reference'][i]:
                            pos_stats.mut +=1
                        else:
                            pass
                        pos_stats.total += 1

                    # check if we need to remove this column
                    bRmCol = False
                    if isinstance(args["column_gaps"], float):
                        gap_fractoin = float(pos_stats.gap) / pos_stats.total
                        if gap_fractoin > args["column_gaps"]:
                            bRmCol = True
                    if isinstance(args["column_Ns"], float):
                        n_fraction = float(pos_stats.N) / pos_stats.total
                        if n_fraction > args["column_Ns"]:
                            bRmCol = True

                    # remove col if necessary
                    if bRmCol == False:
                        # we don't remove it
                        for sample_name, seq in dAlign.iteritems():
                            dFinalAlign[sample_name] += seq[i]
                            # only update sample stats now that we have decided to keep the column
                            sample_stats[sample_name].total += 1
                            if seq[i] == 'N':
                                sample_stats[sample_name].N += 1
                            elif seq[i] == '-':
                                sample_stats[sample_name].gap += 1
                            elif seq[i] != dAlign['reference'][i]:
                                sample_stats[sample_name].mut += 1
                            else:
                                pass
                    else:
                        # we are removing it
                        # NOTE(review): gap_fractoin/n_fraction may be unbound here if the
                        # corresponding threshold option is not a float — potential NameError.
                        logging.info("Removing column %i due to high Ns or gaps fraction, gaps: %s, Ns: %s", i, gap_fractoin, n_fraction)
                else:
                    # all positions they're all the same
                    pass

            # check all seqs are of the same lengths still
            seq_lens = [len(seq) for seq in dFinalAlign.values()]
            assert seq_lens.count(seq_lens[0]) == len(seq_lens), "ERROR: Not all samples in final alignment are equally long!"

            # check if additional samples need to be removed
            bSamplesExcluded = False
            for sample_name in dFinalAlign.keys():
                n_fraction = float(sample_stats[sample_name].N) / seq_lens[0]
                if n_fraction > args["sample_Ns"]:
                    logging.info("Removing %s due to high sample Ns fraction %s", sample_name, n_fraction)
                    bSamplesExcluded = True
                    del dFinalAlign[sample_name]
                    del sample_stats[sample_name]

            for sample_name in dFinalAlign.keys():
                gap_fractoin = float(sample_stats[sample_name].gap) / seq_lens[0]
                if gap_fractoin > args["sample_gaps"]:
                    logging.info("Removing %s due to high sample gaps fraction %s", sample_name, gap_fractoin)
                    bSamplesExcluded = True
                    del dFinalAlign[sample_name]
                    del sample_stats[sample_name]

            # in case we need to go again ...
            dAlign = dFinalAlign

        with open(args["out"], "w") as fp:
            # write seqs to file
            for name, seq in dAlign.iteritems():
                fp.write(">%s\n%s\n" % (name, seq))

    except AssertionError as e:
        logging.error(e.message)

        # Need to delete the malformed file.
        os.unlink(args["out"])

    finally:
        # Close all the tmp handles.
        for tmp_iter in sample_seqs.itervalues():
            tmp_iter.close()

        # Only remove tmp is it was specified.
        if args["tmp"]:
            shutil.rmtree(out_dir)

        if args["with_stats"] is not None:
            args["with_stats"].close()

    # Compute the stats.
    for sample in sample_stats:
        if sample != "reference":
            print "%s\t%s" % (sample, str(sample_stats[sample]))

#         if CAN_STATS:
#             plot_stats(avail_pos, len(samples) - 1, plots_dir=os.path.abspath(args["plots_dir"]))

    return 0

Example 40

Project: quality-assessment-protocol
Source File: cli.py
View license
def _run_workflow(args):
    """Build and run the QAP pipeline for one subject/session/scan.

    Parameters
    ----------
    args : tuple
        ``(resource_pool, config, subject_info, run_name, site_name)``
        packed into a single tuple because this function is invoked via
        ``pool.map`` style callers that pass one argument.

    Returns
    -------
    dict
        Runtime summary with keys ``id``, ``session``, ``scan`` and
        ``status`` ('started' -> 'finished' | 'failed' | 'cached'); on
        failure it also carries ``msg`` and ``traceback``.  Workflow
        exceptions are captured into this dict rather than re-raised,
        because the function runs inside a process pool.
    """

    # build pipeline for each subject, individually
    # ~ 5 min 20 sec per subject
    # (roughly 320 seconds)

    import os
    import os.path as op
    import sys

    import nipype.interfaces.io as nio
    import nipype.pipeline.engine as pe

    import nipype.interfaces.utility as util
    import nipype.interfaces.fsl.maths as fsl

    import glob

    import time
    from time import strftime
    from nipype import config as nyconfig

    resource_pool, config, subject_info, run_name, site_name = args
    sub_id = str(subject_info[0])

    qap_type = config['qap_type']

    # fall back to generic labels when session/scan IDs are absent
    if subject_info[1]:
        session_id = subject_info[1]
    else:
        session_id = "session_0"

    if subject_info[2]:
        scan_id = subject_info[2]
    else:
        scan_id = "scan_0"

    # Read and apply general settings in config
    keep_outputs = config.get('write_all_outputs', False)
    output_dir = op.join(config["output_directory"], run_name,
                         sub_id, session_id, scan_id)

    # Create the output directory.  Catch OSError specifically (the only
    # thing makedirs raises) instead of a bare except, so interrupts and
    # programming errors are not silently swallowed; a concurrent worker
    # creating the same directory is tolerated.
    try:
        os.makedirs(output_dir)
    except OSError:
        if not op.isdir(output_dir):
            err = "[!] Output directory unable to be created.\n" \
                  "Path: %s\n\n" % output_dir
            raise Exception(err)

    log_dir = output_dir

    # set up logging (nipype writes per-run logs into the output dir)
    nyconfig.update_config(
        {'logging': {'log_directory': log_dir, 'log_to_file': True}})
    logging.update_logging(nyconfig)

    # take date+time stamp for run identification purposes
    unique_pipeline_id = strftime("%Y%m%d%H%M%S")
    pipeline_start_stamp = strftime("%Y-%m-%d_%H:%M:%S")

    pipeline_start_time = time.time()

    logger.info("Pipeline start time: %s" % pipeline_start_stamp)
    logger.info("Contents of resource pool:\n" + str(resource_pool))
    logger.info("Configuration settings:\n" + str(config))

    # for QAP spreadsheet generation only
    config.update({"subject_id": sub_id, "session_id": session_id,
                   "scan_id": scan_id, "run_name": run_name})

    if site_name:
        config["site_name"] = site_name

    workflow = pe.Workflow(name=scan_id)
    workflow.base_dir = op.join(config["working_directory"], sub_id,
                                session_id)

    # set up crash directory
    workflow.config['execution'] = \
        {'crashdump_dir': config["output_directory"]}

    # update that resource pool with what's already in the output directory
    for resource in os.listdir(output_dir):
        if (op.isdir(op.join(output_dir, resource)) and
                resource not in resource_pool.keys()):
            resource_pool[resource] = glob.glob(op.join(output_dir,
                                                        resource, "*"))[0]

    # resource pool check: every entry provided by the subject list must
    # point at an existing file
    invalid_paths = []

    for resource in resource_pool.keys():
        if not op.isfile(resource_pool[resource]):
            invalid_paths.append((resource, resource_pool[resource]))

    if len(invalid_paths) > 0:
        err = "\n\n[!] The paths provided in the subject list to the " \
              "following resources are not valid:\n"

        for path_tuple in invalid_paths:
            err = err + path_tuple[0] + ": " + path_tuple[1] + "\n"

        err = err + "\n\n"
        raise Exception(err)

    # start connecting the pipeline: build the QAP workflow only if its
    # final output is not already in the resource pool
    if 'qap_' + qap_type not in resource_pool.keys():
        from qap import qap_workflows as qw
        wf_builder = getattr(qw, 'qap_' + qap_type + '_workflow')
        workflow, resource_pool = wf_builder(workflow, resource_pool, config)

    # set up the datasinks
    new_outputs = 0

    out_list = set(['qap_' + qap_type])

    # Save reports to out_dir if necessary
    if config.get('write_report', False):
        out_list.add('qap_mosaic')
        # The functional temporal also has an FD plot
        if 'functional_temporal' in qap_type:
            out_list.add('qap_fd')

    if keep_outputs:
        for k in resource_pool.keys():
            out_list.add(k)

    for output in list(out_list):
        # resource pool items which are (node, node_output) tuples are the
        # outputs created by this workflow because they were not present
        # in the subject list YML (the starting resource pool) and had to
        # be generated; everything else is a path string already on disk.
        # Check the type explicitly rather than len()==2 alone, so a
        # two-character path string cannot be mistaken for a tuple.
        if isinstance(resource_pool[output], tuple) and \
                len(resource_pool[output]) == 2:
            ds = pe.Node(nio.DataSink(), name='datasink_%s' % output)
            ds.inputs.base_directory = output_dir
            node, out_file = resource_pool[output]
            workflow.connect(node, out_file, ds, output)
            new_outputs += 1

    rt = {'id': sub_id, 'session': session_id, 'scan': scan_id,
          'status': 'started'}
    # run the pipeline (if there is anything to do)
    if new_outputs > 0:
        if config.get('write_graph', False):
            workflow.write_graph(
                dotfilename=op.join(output_dir, run_name + ".dot"),
                simple_form=False)

        nc_per_subject = config.get('num_cores_per_subject', 1)
        runargs = {'plugin': 'Linear', 'plugin_args': {}}
        if nc_per_subject > 1:
            runargs['plugin'] = 'MultiProc'
            runargs['plugin_args'] = {'n_procs': nc_per_subject}

        try:
            workflow.run(**runargs)
            rt['status'] = 'finished'
        except Exception as e:
            # ... however this is run inside a pool.map: do not raise Exception
            etype, evalue, etrace = sys.exc_info()
            tb = format_exception(etype, evalue, etrace)
            rt.update({'status': 'failed', 'msg': '%s' % e, 'traceback': tb})
            logger.error('An error occurred processing subject %s. '
                         'Runtime dict: %s\n%s' %
                         (rt['id'], rt, '\n'.join(rt['traceback'])))
    else:
        rt['status'] = 'cached'
        logger.info("\nEverything is already done for subject %s." % sub_id)

    # Remove working directory when done; best-effort, so only Exception
    # is caught (not a bare except) and the failure is logged.
    if not keep_outputs:
        try:
            work_dir = op.join(workflow.base_dir, scan_id)

            if op.exists(work_dir):
                import shutil
                shutil.rmtree(work_dir)
        except Exception:
            logger.warning("Couldn\'t remove the working directory!")

    pipeline_end_stamp = strftime("%Y-%m-%d_%H:%M:%S")
    pipeline_end_time = time.time()
    logger.info("Elapsed time (minutes) since last start: %s"
                % ((pipeline_end_time - pipeline_start_time) / 60))
    logger.info("Pipeline end time: %s" % pipeline_end_stamp)
    return rt

Example 41

Project: pupil
Source File: main.py
View license
def session(rec_dir):
    """Run one Pupil Player session for the recording at *rec_dir*.

    Builds the plugin registry, creates the GLFW window and registers its
    input callbacks, loads the recording (world video, timestamps,
    pupil/gaze data), restores persisted user settings, then blocks in
    the render/update loop until the window is closed.  On exit the
    session settings are written back to disk and plugins/GL resources
    are torn down.  Returns None.
    """

    # Plugin registry: system plugins always load; the rest are grouped
    # for the "Open plugin" selector menu built further below.
    # NOTE(review): Pupil_Angle_3D_Fixation_Detector appears twice in
    # this list — looks like a copy/paste duplicate; confirm upstream.
    system_plugins = [Log_Display,Seek_Bar,Trim_Marks]
    vis_plugins = sorted([Vis_Circle,Vis_Polyline,Vis_Light_Points,Vis_Cross,Vis_Watermark,Eye_Video_Overlay,Scan_Path], key=lambda x: x.__name__)
    analysis_plugins = sorted([Gaze_Position_2D_Fixation_Detector,Pupil_Angle_3D_Fixation_Detector,Pupil_Angle_3D_Fixation_Detector,Manual_Gaze_Correction,Video_Export_Launcher,Offline_Surface_Tracker,Raw_Data_Exporter,Batch_Exporter,Annotation_Player], key=lambda x: x.__name__)
    other_plugins = sorted([Show_Calibration,Log_History], key=lambda x: x.__name__)
    user_plugins = sorted(import_runtime_plugins(os.path.join(user_dir,'plugins')), key=lambda x: x.__name__)
    user_launchable_plugins = vis_plugins + analysis_plugins + other_plugins + user_plugins
    available_plugins = system_plugins + user_launchable_plugins
    name_by_index = [p.__name__ for p in available_plugins]
    index_by_name = dict(zip(name_by_index,range(len(name_by_index))))
    plugin_by_name = dict(zip(name_by_index,available_plugins))


    # Callback functions (registered on the GLFW window below)
    def on_resize(window,w, h):
        # propagate the new size to GUI, graphs, GL viewport and plugins
        g_pool.gui.update_window(w,h)
        g_pool.gui.collect_menus()
        graph.adjust_size(w,h)
        adjust_gl_view(w,h)
        for p in g_pool.plugins:
            p.on_window_resize(window,w,h)

    def on_key(window, key, scancode, action, mods):
        g_pool.gui.update_key(key,scancode,action,mods)

    def on_char(window,char):
        g_pool.gui.update_char(char)

    def on_button(window,button, action, mods):
        g_pool.gui.update_button(button,action,mods)
        pos = glfwGetCursorPos(window)
        pos = normalize(pos,glfwGetWindowSize(window))
        pos = denormalize(pos,(frame.img.shape[1],frame.img.shape[0]) ) # Position in img pixels
        for p in g_pool.plugins:
            p.on_click(pos,button,action)

    def on_pos(window,x, y):
        # convert window coords to framebuffer coords on HiDPI displays
        hdpi_factor = float(glfwGetFramebufferSize(window)[0]/glfwGetWindowSize(window)[0])
        g_pool.gui.update_mouse(x*hdpi_factor,y*hdpi_factor)

    def on_scroll(window,x,y):
        g_pool.gui.update_scroll(x,y*y_scroll_factor)


    def on_drop(window,count,paths):
        # dropping a valid recording dir restarts the session with it
        for x in range(count):
            new_rec_dir =  paths[x]
            if is_pupil_rec_dir(new_rec_dir):
                logger.debug("Starting new session with '%s'"%new_rec_dir)
                # NOTE(review): `global rec_dir` inside a function that
                # takes rec_dir as a parameter rebinds the *module-level*
                # name, not the local — presumably the caller re-reads the
                # module global to loop; verify against the call site.
                global rec_dir
                rec_dir = new_rec_dir
                glfwSetWindowShouldClose(window,True)
            else:
                logger.error("'%s' is not a valid pupil recording"%new_rec_dir)




    tick = delta_t()
    def get_dt():
        # seconds elapsed since the previous call (via the delta_t generator)
        return next(tick)

    update_recording_to_recent(rec_dir)

    # pick the first world video file with a known container extension
    video_path = [f for f in glob(os.path.join(rec_dir,"world.*")) if f[-3:] in ('mp4','mkv','avi')][0]
    timestamps_path = os.path.join(rec_dir, "world_timestamps.npy")
    pupil_data_path = os.path.join(rec_dir, "pupil_data")

    meta_info = load_meta_info(rec_dir)
    rec_version = read_rec_version(meta_info)
    app_version = get_version(version_file)

    # log info about Pupil Platform and Platform in player.log
    logger.info('Application Version: %s'%app_version)
    logger.info('System Info: %s'%get_system_info())

    timestamps = np.load(timestamps_path)

    # create container for globally scoped vars
    g_pool = Global_Container()
    g_pool.app = 'player'

    # Initialize capture
    cap = File_Source(g_pool,video_path,timestamps=list(timestamps))

    # load session persistent settings; discard them if they were written
    # by an older app version
    session_settings = Persistent_Dict(os.path.join(user_dir,"user_settings"))
    if session_settings.get("version",VersionFormat('0.0')) < get_version(version_file):
        logger.info("Session setting are from older version of this app. I will not use those.")
        session_settings.clear()

    width,height = session_settings.get('window_size',cap.frame_size)
    window_pos = session_settings.get('window_position',(0,0))
    main_window = glfwCreateWindow(width, height, "Pupil Player: "+meta_info["Recording Name"]+" - "+ rec_dir.split(os.path.sep)[-1], None, None)
    glfwSetWindowPos(main_window,window_pos[0],window_pos[1])
    glfwMakeContextCurrent(main_window)
    cygl.utils.init()

    # load pupil_positions, gaze_positions
    pupil_data = load_object(pupil_data_path)
    pupil_list = pupil_data['pupil_positions']
    gaze_list = pupil_data['gaze_positions']

    # publish shared state on the global container for plugins to use
    g_pool.binocular = meta_info.get('Eye Mode','monocular') == 'binocular'
    g_pool.version = app_version
    g_pool.capture = cap
    g_pool.timestamps = timestamps
    g_pool.play = False
    g_pool.new_seek = True
    g_pool.user_dir = user_dir
    g_pool.rec_dir = rec_dir
    g_pool.rec_version = rec_version
    g_pool.meta_info = meta_info
    g_pool.min_data_confidence = session_settings.get('min_data_confidence',0.6)
    g_pool.pupil_positions_by_frame = correlate_data(pupil_list,g_pool.timestamps)
    g_pool.gaze_positions_by_frame = correlate_data(gaze_list,g_pool.timestamps)
    g_pool.fixations_by_frame = [[] for x in g_pool.timestamps] #populated by the fixation detector plugin

    # GUI setter/handler helpers bound to menu widgets below
    def next_frame(_):
        try:
            cap.seek_to_frame(cap.get_frame_index())
        except FileSeekError:
            logger.warning("Could not seek to next frame.")
        else:
            g_pool.new_seek = True

    def prev_frame(_):
        try:
            cap.seek_to_frame(cap.get_frame_index()-2)
        except FileSeekError:
            logger.warning("Could not seek to previous frame.")
        else:
            g_pool.new_seek = True

    def toggle_play(new_state):
        if cap.get_frame_index() >= cap.get_frame_count()-5:
            cap.seek_to_frame(1) #avoid pause set by hitting trimmark pause.
            logger.warning("End of video - restart at beginning.")
        g_pool.play = new_state

    def set_scale(new_scale):
        g_pool.gui.scale = new_scale
        g_pool.gui.collect_menus()

    def set_data_confidence(new_confidence):
        # notify plugins slightly in the future so rapid slider changes
        # coalesce into one notification
        g_pool.min_data_confidence = new_confidence
        notification = {'subject':'min_data_confidence_changed'}
        notification['_notify_time_'] = time()+.8
        g_pool.delayed_notifications[notification['subject']] = notification

    def open_plugin(plugin):
        if plugin ==  "Select to load":
            return
        g_pool.plugins.add(plugin)

    def purge_plugins():
        # close every user-launchable plugin; system plugins stay alive
        for p in g_pool.plugins:
            if p.__class__ in user_launchable_plugins:
                p.alive = False
        g_pool.plugins.clean()

    def do_export(_):
        export_range = slice(g_pool.trim_marks.in_mark,g_pool.trim_marks.out_mark)
        export_dir = os.path.join(g_pool.rec_dir,'exports','%s-%s'%(export_range.start,export_range.stop))
        try:
            os.makedirs(export_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                logger.error("Could not create export dir")
                raise e
            else:
                logger.warning("Previous export for range [%s-%s] already exsits - overwriting."%(export_range.start,export_range.stop))
        else:
            logger.info('Created export dir at "%s"'%export_dir)

        notification = {'subject':'should_export','range':export_range,'export_dir':export_dir}
        g_pool.notifications.append(notification)

    # Build the main settings menu
    g_pool.gui = ui.UI()
    g_pool.gui.scale = session_settings.get('gui_scale',1)
    g_pool.main_menu = ui.Scrolling_Menu("Settings",pos=(-350,20),size=(300,500))
    g_pool.main_menu.append(ui.Button("Close Pupil Player",lambda:glfwSetWindowShouldClose(main_window,True)))
    g_pool.main_menu.append(ui.Slider('scale',g_pool.gui, setter=set_scale,step = .05,min=0.75,max=2.5,label='Interface Size'))
    g_pool.main_menu.append(ui.Info_Text('Player Version: %s'%g_pool.version))
    g_pool.main_menu.append(ui.Info_Text('Recording Version: %s'%rec_version))
    g_pool.main_menu.append(ui.Slider('min_data_confidence',g_pool, setter=set_data_confidence,step=.05 ,min=0.0,max=1.0,label='Confidence threshold'))

    selector_label = "Select to load"

    # indent labels so they read as sub-entries under the group headers
    vis_labels = ["   " + p.__name__.replace('_',' ') for p in vis_plugins]
    analysis_labels = ["   " + p.__name__.replace('_',' ') for p in analysis_plugins]
    other_labels = ["   " + p.__name__.replace('_',' ') for p in other_plugins]
    user_labels = ["   " + p.__name__.replace('_',' ') for p in user_plugins]

    # selection list and label list are kept index-aligned: selector_label
    # entries act as group headers and are no-ops in open_plugin
    plugins = [selector_label, selector_label] + vis_plugins + [selector_label] + analysis_plugins + [selector_label] + other_plugins + [selector_label] + user_plugins
    labels = [selector_label, "Visualization"] + vis_labels + ["Analysis"] + analysis_labels + ["Other"] + other_labels + ["User added"] + user_labels

    g_pool.main_menu.append(ui.Selector('Open plugin:',
                                        selection = plugins,
                                        labels    = labels,
                                        setter    = open_plugin,
                                        getter    = lambda: selector_label))

    g_pool.main_menu.append(ui.Button('Close all plugins',purge_plugins))
    g_pool.main_menu.append(ui.Button('Reset window size',lambda: glfwSetWindowSize(main_window,cap.frame_size[0],cap.frame_size[1])) )
    g_pool.quickbar = ui.Stretching_Menu('Quick Bar',(0,100),(120,-100))
    g_pool.play_button = ui.Thumb('play',g_pool,label=unichr(0xf04b).encode('utf-8'),setter=toggle_play,hotkey=GLFW_KEY_SPACE,label_font='fontawesome',label_offset_x=5,label_offset_y=0,label_offset_size=-24)
    g_pool.play_button.on_color[:] = (0,1.,.0,.8)
    g_pool.forward_button = ui.Thumb('forward',label=unichr(0xf04e).encode('utf-8'),getter = lambda: False,setter= next_frame, hotkey=GLFW_KEY_RIGHT,label_font='fontawesome',label_offset_x=5,label_offset_y=0,label_offset_size=-24)
    g_pool.backward_button = ui.Thumb('backward',label=unichr(0xf04a).encode('utf-8'),getter = lambda: False, setter = prev_frame, hotkey=GLFW_KEY_LEFT,label_font='fontawesome',label_offset_x=-5,label_offset_y=0,label_offset_size=-24)
    g_pool.export_button = ui.Thumb('export',label=unichr(0xf063).encode('utf-8'),getter = lambda: False, setter = do_export, hotkey='e',label_font='fontawesome',label_offset_x=0,label_offset_y=2,label_offset_size=-24)
    g_pool.quickbar.extend([g_pool.play_button,g_pool.forward_button,g_pool.backward_button,g_pool.export_button])
    g_pool.gui.append(g_pool.quickbar)
    g_pool.gui.append(g_pool.main_menu)


    #we always load these plugins
    system_plugins = [('Trim_Marks',{}),('Seek_Bar',{})]
    default_plugins = [('Log_Display',{}),('Scan_Path',{}),('Vis_Polyline',{}),('Vis_Circle',{}),('Video_Export_Launcher',{})]
    previous_plugins = session_settings.get('loaded_plugins',default_plugins)
    g_pool.notifications = []
    g_pool.delayed_notifications = {}
    g_pool.plugins = Plugin_List(g_pool,plugin_by_name,system_plugins+previous_plugins)


    # Register callbacks main_window
    glfwSetFramebufferSizeCallback(main_window,on_resize)
    glfwSetKeyCallback(main_window,on_key)
    glfwSetCharCallback(main_window,on_char)
    glfwSetMouseButtonCallback(main_window,on_button)
    glfwSetCursorPosCallback(main_window,on_pos)
    glfwSetScrollCallback(main_window,on_scroll)
    glfwSetDropCallback(main_window,on_drop)
    #trigger on_resize
    on_resize(main_window, *glfwGetFramebufferSize(main_window))

    g_pool.gui.configuration = session_settings.get('ui_config',{})

    # gl_state settings
    basic_gl_setup()
    g_pool.image_tex = Named_Texture()

    #set up performance graphs:
    pid = os.getpid()
    ps = psutil.Process(pid)
    ts = None

    cpu_graph = graph.Bar_Graph()
    cpu_graph.pos = (20,110)
    cpu_graph.update_fn = ps.cpu_percent
    cpu_graph.update_rate = 5
    cpu_graph.label = 'CPU %0.1f'

    fps_graph = graph.Bar_Graph()
    fps_graph.pos = (140,110)
    fps_graph.update_rate = 5
    fps_graph.label = "%0.0f REC FPS"

    pupil_graph = graph.Bar_Graph(max_val=1.0)
    pupil_graph.pos = (260,110)
    pupil_graph.update_rate = 5
    pupil_graph.label = "Confidence: %0.2f"

    # ---- main render/update loop ----
    while not glfwWindowShouldClose(main_window):


        #grab new frame
        if g_pool.play or g_pool.new_seek:
            g_pool.new_seek = False
            try:
                new_frame = cap.get_frame_nowait()
            except EndofVideoFileError:
                #end of video logic: pause at last frame.
                g_pool.play=False
                logger.warning("end of video")
            update_graph = True
        else:
            update_graph = False


        # NOTE(review): if EndofVideoFileError fires on the very first
        # iteration, new_frame is still unbound here (NameError) —
        # presumably the first grab never fails in practice; verify.
        frame = new_frame.copy()
        events = {}
        #report time between now and the last loop iteration
        events['dt'] = get_dt()
        #new positions we make a deepcopy just like the image is a copy.
        events['gaze_positions'] = deepcopy(g_pool.gaze_positions_by_frame[frame.index])
        events['pupil_positions'] = deepcopy(g_pool.pupil_positions_by_frame[frame.index])

        if update_graph:
            #update performance graphs
            for p in  events['pupil_positions']:
                pupil_graph.add(p['confidence'])

            t = new_frame.timestamp
            if ts and ts != t:
                dt,ts = t-ts,t
                fps_graph.add(1./dt)

            g_pool.play_button.status_text = str(frame.index)
        #always update the CPU graph
        cpu_graph.update()


        # publish delayed notifications when their time has come.
        # NOTE(review): this deletes from the dict while iterating
        # .values() — fine on Python 2 (list snapshot), RuntimeError on
        # Python 3; would need list(...) there.
        for n in g_pool.delayed_notifications.values():
            if n['_notify_time_'] < time():
                del n['_notify_time_']
                del g_pool.delayed_notifications[n['subject']]
                g_pool.notifications.append(n)

        # notify each plugin if there are new notifications:
        while g_pool.notifications:
            n = g_pool.notifications.pop(0)
            for p in g_pool.plugins:
                p.on_notify(n)

        # allow each Plugin to do its work.
        for p in g_pool.plugins:
            p.update(frame,events)

        #check if a plugin need to be destroyed
        g_pool.plugins.clean()

        # render camera image
        glfwMakeContextCurrent(main_window)
        make_coord_system_norm_based()
        g_pool.image_tex.update_from_frame(frame)
        g_pool.image_tex.draw()
        make_coord_system_pixel_based(frame.img.shape)
        # render visual feedback from loaded plugins
        for p in g_pool.plugins:
            p.gl_display()

        graph.push_view()
        fps_graph.draw()
        cpu_graph.draw()
        pupil_graph.draw()
        graph.pop_view()
        g_pool.gui.update()

        #present frames at appropriate speed
        cap.wait(frame)

        glfwSwapBuffers(main_window)
        glfwPollEvents()

    # ---- shutdown: persist settings, then tear everything down ----
    session_settings['loaded_plugins'] = g_pool.plugins.get_initializers()
    session_settings['min_data_confidence'] = g_pool.min_data_confidence
    session_settings['gui_scale'] = g_pool.gui.scale
    session_settings['ui_config'] = g_pool.gui.configuration
    session_settings['window_size'] = glfwGetWindowSize(main_window)
    session_settings['window_position'] = glfwGetWindowPos(main_window)
    session_settings['version'] = g_pool.version
    session_settings.close()

    # de-init all running plugins
    for p in g_pool.plugins:
        p.alive = False
    g_pool.plugins.clean()

    cap.cleanup()
    g_pool.gui.terminate()
    glfwDestroyWindow(main_window)
Example 42

Project: autotest
Source File: fence_apc_snmp.py
View license
def main():
    """Fence an APC power switch outlet via SNMP (Python 2 script).

    Reads options either from argv (interactive use) or from stdin
    key=value pairs (fence-agent protocol), then issues snmpget/snmpset
    shell commands to check status, power off, power on, or reboot the
    requested outlet.  Exits the process with 0 on success, non-zero on
    failure; when -v is given, a trace is written to /tmp/apclog.
    """
    # OID name fragments; full OID = apc_base + apc_outletctl + port
    apc_base = "enterprises.apc.products.hardware."
    apc_outletctl = "masterswitch.sPDUOutletControl.sPDUOutletControlTable.sPDUOutletControlEntry.sPDUOutletCtl."
    apc_outletstatus = "masterswitch.sPDUOutletStatus.sPDUOutletStatusMSPTable.sPDUOutletStatusMSPEntry.sPDUOutletStatusMSP."

    address = ""
    output = ""
    port = ""
    action = "outletReboot"
    status_check = False
    verbose = False

    # requires the APC PowerNet MIB so snmpget/snmpset can resolve names
    if not glob('/usr/share/snmp/mibs/powernet*.mib'):
        sys.stderr.write('This APC Fence script uses snmp to control the APC power switch. This script requires that net-snmp-utils be installed on all nodes in the cluster, and that the powernet369.mib file be located in /usr/share/snmp/mibs/\n')
        sys.exit(1)

    if len(sys.argv) > 1:
        # ---- argv mode ----
        try:
            opts, args = getopt.getopt(sys.argv[1:], "a:hl:p:n:o:vV", ["help", "output="])
        except getopt.GetoptError:
            # print help info and quit
            usage()
            sys.exit(2)

        for o, a in opts:
            if o == "-v":
                verbose = True
            if o == "-V":
                print "%s\n" % FENCE_RELEASE_NAME
                print "%s\n" % REDHAT_COPYRIGHT
                print "%s\n" % BUILD_DATE
                sys.exit(0)
            if o in ("-h", "--help"):
                usage()
                sys.exit(0)
            if o == "-n":
                port = a
            if o == "-o":
                # NOTE(review): this path stores raw MIB values
                # ("outletOff"/"outletOn"/"outletReboot") in `action`,
                # while the stdin path below uses the POWER_OFF/POWER_ON/
                # POWER_REBOOT constants that the dispatch at the bottom
                # compares against.  Unless those constants equal these
                # strings, argv-selected off/on/reboot never match any
                # branch — verify the constants' values.
                lcase = a.lower()  # Lower case string
                if lcase == "off":
                    action = "outletOff"
                elif lcase == "on":
                    action = "outletOn"
                elif lcase == "reboot":
                    action = "outletReboot"
                elif lcase == "status":
                    #action = "sPDUOutletStatusMSPOutletState"
                    action = ""
                    status_check = True
                else:
                    usage()
                    sys.exit()
            if o == "-a":
                address = a

        if address == "":
            usage()
            sys.exit(1)

        if port == "":
            usage()
            sys.exit(1)

    else:  # Get opts from stdin
        # ---- fence-agent protocol mode: key=value lines on stdin ----
        params = {}
        # place params in dict
        for line in sys.stdin:
            val = line.split("=")
            if len(val) == 2:
                params[val[0].strip()] = val[1].strip()

        try:
            address = params["ipaddr"]
        except KeyError, e:
            sys.stderr.write("FENCE: Missing ipaddr param for fence_apc...exiting")
            sys.exit(1)
        try:
            # NOTE(review): login/passwd are parsed but never used below;
            # the snmp commands authenticate with community "private".
            login = params["login"]
        except KeyError, e:
            sys.stderr.write("FENCE: Missing login param for fence_apc...exiting")
            sys.exit(1)

        try:
            passwd = params["passwd"]
        except KeyError, e:
            sys.stderr.write("FENCE: Missing passwd param for fence_apc...exiting")
            sys.exit(1)

        try:
            port = params["port"]
        except KeyError, e:
            sys.stderr.write("FENCE: Missing port param for fence_apc...exiting")
            sys.exit(1)

        try:
            a = params["option"]
            if a == "Off" or a == "OFF" or a == "off":
                action = POWER_OFF
            elif a == "On" or a == "ON" or a == "on":
                action = POWER_ON
            elif a == "Reboot" or a == "REBOOT" or a == "reboot":
                action = POWER_REBOOT
        except KeyError, e:
            action = POWER_REBOOT

        # End of stdin section

    apc_command = apc_base + apc_outletctl + port

    # Build the three shell command argument lists: status query (snmpget)
    # and the off/on writes (snmpset with integer value).
    args_status = list()
    args_off = list()
    args_on = list()

    args_status.append("/usr/bin/snmpget")
    args_status.append("-Oqu")  # sets printing options
    args_status.append("-v")
    args_status.append("1")
    args_status.append("-c")
    args_status.append("private")
    args_status.append("-m")
    args_status.append("ALL")
    args_status.append(address)
    args_status.append(apc_command)

    args_off.append("/usr/bin/snmpset")
    args_off.append("-Oqu")  # sets printing options
    args_off.append("-v")
    args_off.append("1")
    args_off.append("-c")
    args_off.append("private")
    args_off.append("-m")
    args_off.append("ALL")
    args_off.append(address)
    args_off.append(apc_command)
    args_off.append("i")
    args_off.append("outletOff")

    args_on.append("/usr/bin/snmpset")
    args_on.append("-Oqu")  # sets printing options
    args_on.append("-v")
    args_on.append("1")
    args_on.append("-c")
    args_on.append("private")
    args_on.append("-m")
    args_on.append("ALL")
    args_on.append(address)
    args_on.append(apc_command)
    args_on.append("i")
    args_on.append("outletOn")

    cmdstr_status = ' '.join(args_status)
    cmdstr_off = ' '.join(args_off)
    cmdstr_on = ' '.join(args_on)

# This section issues the actual commands. Reboot is split into
# Off, then On to make certain both actions work as planned.
#
# The status command just dumps the outlet status to stdout.
# The status checks that are made when turning an outlet on or off, though,
# use the execWithCaptureStatus so that the stdout from snmpget can be
# examined and the desired operation confirmed.

    if status_check:
        if verbose:
            fd = open("/tmp/apclog", "w")
            fd.write("Attempting the following command: %s\n" % cmdstr_status)
        strr = os.system(cmdstr_status)
        print strr
        if verbose:
            fd.write("Result: %s\n" % strr)
            fd.close()

    else:
        if action == POWER_OFF:
            if verbose:
                fd = open("/tmp/apclog", "w")
                fd.write("Attempting the following command: %s\n" % cmdstr_off)
            strr = os.system(cmdstr_off)
            time.sleep(1)
            # re-read status and confirm the outlet actually turned off
            strr, code = execWithCaptureStatus("/usr/bin/snmpget", args_status)
            if verbose:
                fd.write("Result: %s\n" % strr)
                fd.close()
            if strr.find(POWER_OFF) >= 0:
                print "Success. Outlet off"
                sys.exit(0)
            else:
                if verbose:
                    # NOTE(review): fd was already closed above in this
                    # branch — this write would raise ValueError; confirm.
                    fd.write("Unable to power off apc outlet")
                    fd.close()
                sys.exit(1)

        elif action == POWER_ON:
            if verbose:
                fd = open("/tmp/apclog", "w")
                fd.write("Attempting the following command: %s\n" % cmdstr_on)
            strr = os.system(cmdstr_on)
            time.sleep(1)
            strr, code = execWithCaptureStatus("/usr/bin/snmpget", args_status)
            #strr = os.system(cmdstr_status)
            if verbose:
                fd.write("Result: %s\n" % strr)
            if strr.find(POWER_ON) >= 0:
                if verbose:
                    fd.close()
                print "Success. Outlet On."
                sys.exit(0)
            else:
                print "Unable to power on apc outlet"
                if verbose:
                    fd.write("Unable to power on apc outlet")
                    fd.close()
                sys.exit(1)

        elif action == POWER_REBOOT:
            # reboot = confirmed off, then confirmed on
            if verbose:
                fd = open("/tmp/apclog", "w")
                fd.write("Attempting the following command: %s\n" % cmdstr_off)
            strr = os.system(cmdstr_off)
            time.sleep(1)
            strr, code = execWithCaptureStatus("/usr/bin/snmpget", args_status)
            #strr = os.system(cmdstr_status)
            if verbose:
                fd.write("Result: %s\n" % strr)
            if strr.find(POWER_OFF) < 0:
                print "Unable to power off apc outlet"
                if verbose:
                    fd.write("Unable to power off apc outlet")
                    fd.close()
                sys.exit(1)

            if verbose:
                fd.write("Attempting the following command: %s\n" % cmdstr_on)
            strr = os.system(cmdstr_on)
            time.sleep(1)
            strr, code = execWithCaptureStatus("/usr/bin/snmpget", args_status)
            #strr = os.system(cmdstr_status)
            if verbose:
                fd.write("Result: %s\n" % strr)
            if strr.find(POWER_ON) >= 0:
                if verbose:
                    fd.close()
                print "Success: Outlet Rebooted."
                sys.exit(0)
            else:
                print "Unable to power on apc outlet"
                if verbose:
                    fd.write("Unable to power on apc outlet")
                    fd.close()
                sys.exit(1)

Example 43

Project: tp-qemu
Source File: ntttcp.py
View license
def run(test, params, env):
    """
    Run NTttcp on Windows guest

    1) Install NTttcp in server/client side by Autoit
    2) Start NTttcp in server/client side
    3) Get test results

    :param test: kvm test object
    :param params: Dictionary with the test parameters
    :param env: Dictionary with test environment.
    """
    login_timeout = int(params.get("login_timeout", 360))
    timeout = int(params.get("timeout"))
    # Raw per-iteration output files; ".receiver"/".sender" suffixes are
    # appended below, and the ".RHS" summary is derived from the receiver file.
    results_path = os.path.join(test.resultsdir,
                                'raw_output_%s' % test.iteration)
    # NTttcp ships separate x86/x64 binaries; choose by guest arch name.
    platform = "x86"
    if "64" in params["vm_arch_name"]:
        platform = "x64"
    buffers = params.get("buffers").split()
    buf_num = params.get("buf_num", 200000)
    session_num = params.get("session_num")

    vm_sender = env.get_vm(params["main_vm"])
    vm_sender.verify_alive()
    vm_receiver = None
    receiver_addr = params.get("receiver_address")

    logging.debug(utils.system("numactl --hardware", ignore_status=True))
    logging.debug(utils.system("numactl --show", ignore_status=True))
    # pin guest vcpus/memory/vhost threads to last numa node of host by default
    if params.get('numa_node'):
        numa_node = int(params.get('numa_node'))
        node = utils_misc.NumaNode(numa_node)
        utils_test.qemu.pin_vm_threads(vm_sender, node)

    # No external receiver address configured: use a second guest ("vm2")
    # as the receive side and take its IP address.
    if not receiver_addr:
        vm_receiver = env.get_vm("vm2")
        vm_receiver.verify_alive()
        try:
            sess = None
            sess = vm_receiver.wait_for_login(timeout=login_timeout)
            receiver_addr = vm_receiver.get_address()
            if not receiver_addr:
                raise error.TestError("Can't get receiver(%s) ip address" %
                                      vm_sender.name)
            if params.get('numa_node'):
                utils_test.qemu.pin_vm_threads(vm_receiver, node)
        finally:
            if sess:
                sess.close()

    @error.context_aware
    def install_ntttcp(session):
        """ Install ntttcp through a remote session """
        logging.info("Installing NTttcp ...")
        try:
            # Don't install ntttcp if it's already installed
            error.context("NTttcp directory already exists")
            session.cmd(params.get("check_ntttcp_cmd"))
        except aexpect.ShellCmdError:
            ntttcp_install_cmd = params.get("ntttcp_install_cmd")
            error.context("Installing NTttcp on guest")
            session.cmd(ntttcp_install_cmd % (platform, platform), timeout=200)

    def receiver():
        """ Receive side """
        logging.info("Starting receiver process on %s", receiver_addr)
        if vm_receiver:
            session = vm_receiver.wait_for_login(timeout=login_timeout)
        else:
            # External (non-VM) receiver host: open a remote shell session.
            username = params.get("username", "")
            password = params.get("password", "")
            prompt = params.get("shell_prompt", "[\#\$]")
            # NOTE(review): eval() turns the escaped config string (r"\n")
            # into a real line separator; trusted test config is assumed.
            linesep = eval("'%s'" % params.get("shell_linesep", r"\n"))
            client = params.get("shell_client")
            port = int(params.get("shell_port"))
            log_filename = ("session-%s-%s.log" % (receiver_addr,
                                                   utils_misc.generate_random_string(4)))
            session = remote.remote_login(client, receiver_addr, port,
                                          username, password, prompt,
                                          linesep, log_filename, timeout)
            session.set_status_test_command("echo %errorlevel%")
        install_ntttcp(session)
        ntttcp_receiver_cmd = params.get("ntttcp_receiver_cmd")
        # _receiver_ready is a module-level flag (presumably defined at file
        # scope) used as a handshake with the sender thread — TODO confirm.
        global _receiver_ready
        f = open(results_path + ".receiver", 'a')
        for b in buffers:
            # Wait until the sender has cleared the flag from the previous
            # round before declaring this round's receiver ready.
            utils_misc.wait_for(lambda: not _wait(), timeout)
            _receiver_ready = True
            # fixed_rbuf pins the receive buffer size regardless of the
            # currently-tested send buffer size b.
            rbuf = params.get("fixed_rbuf", b)
            cmd = ntttcp_receiver_cmd % (
                session_num, receiver_addr, rbuf, buf_num)
            r = session.cmd_output(cmd, timeout=timeout,
                                   print_func=logging.debug)
            f.write("Send buffer size: %s\n%s\n%s" % (b, cmd, r))
        f.close()
        session.close()

    def _wait():
        """ Check if receiver is ready """
        global _receiver_ready
        if _receiver_ready:
            return _receiver_ready
        return False

    def sender():
        """ Send side """
        logging.info("Sarting sender process ...")
        session = vm_sender.wait_for_login(timeout=login_timeout)
        install_ntttcp(session)
        ntttcp_sender_cmd = params.get("ntttcp_sender_cmd")
        f = open(results_path + ".sender", 'a')
        try:
            global _receiver_ready
            for b in buffers:
                cmd = ntttcp_sender_cmd % (
                    session_num, receiver_addr, b, buf_num)
                # Wait until receiver ready
                utils_misc.wait_for(_wait, timeout)
                r = session.cmd_output(cmd, timeout=timeout,
                                       print_func=logging.debug)
                # Clear the flag so the receiver can arm the next round.
                _receiver_ready = False
                f.write("Send buffer size: %s\n%s\n%s" % (b, cmd, r))
        finally:
            f.close()
            session.close()

    def parse_file(resultfile):
        """ Parse raw result files and generate files with standard format """
        fileobj = open(resultfile, "r")
        lst = []
        found = False
        for line in fileobj.readlines():
            # Remember the send buffer size announced before each result table.
            o = re.findall("Send buffer size: (\d+)", line)
            if o:
                bfr = o[0]
            if "Total Throughput(Mbit/s)" in line:
                found = True
            if found:
                fields = line.split()
                if len(fields) == 0:
                    continue
                try:
                    # Keep only all-numeric rows; last column is throughput.
                    [float(i) for i in fields]
                    lst.append([bfr, fields[-1]])
                except ValueError:
                    continue
                found = False
        return lst

    try:
        # Receiver runs in a background thread; sender drives it from here.
        bg = utils.InterruptedThread(receiver, ())
        bg.start()
        if bg.isAlive():
            sender()
            bg.join(suppress_exception=True)
        else:
            raise error.TestError("Can't start backgroud receiver thread")
    finally:
        # Summarize the receiver's raw output into a standard ".RHS" report.
        for i in glob.glob("%s.receiver" % results_path):
            f = open("%s.RHS" % results_path, "w")
            raw = "  buf(k)| throughput(Mbit/s)"
            logging.info(raw)
            f.write("#ver# %s\n#ver# host kernel: %s\n" %
                    (commands.getoutput("rpm -q qemu-kvm"), os.uname()[2]))
            desc = """#desc# The tests are sessions of "NTttcp", send buf number is %s. 'throughput' was taken from ntttcp's report.
#desc# How to read the results:
#desc# - The Throughput is measured in Mbit/sec.
#desc#
""" % (buf_num)
            f.write(desc)
            f.write(raw + "\n")
            for j in parse_file(i):
                raw = "%8s| %8s" % (j[0], j[1])
                logging.info(raw)
                f.write(raw + "\n")
            f.close()

Example 44

Project: tp-qemu
Source File: ovs_mirror.py
View license
@error.context_aware
def run(test, params, env):
    """
    Test port mirror between guests in one ovs backend

    1) Boot the three vms.
    2) Set tap device of vm1 to mirror (input, output, input & output)
       of tap device of vm2 in openvswith.
    3) Start two tcpdump threads to dump icmp packet from vm2 and vm3.
    4) Ping host from vm2 and vm3.
    5) Stop ping in vm2 and vm3
    6) Check tcmpdump result in vm1.

    :param test: Kvm test object
    :param params: Dictionary with the test parameters
    :param env: Dictionary with test environment.
    """

    def make_mirror_cmd(
            mirror_port, target_port, direction="all", ovs="ovs0"):
        """
        Generate create ovs port mirror command.

        :param mirror_port: port name in ovs in mirror status.
        :param target_port: port name in ovs be mirroring.
        :param direction: mirror direction, all, only input or output.
        :param ovs: ovs port name.

        :return: string of ovs port mirror command.
        """
        # NOTE(review): the "[email protected]" fragments below look like
        # e-mail-obfuscation damage to the original ovs-vsctl syntax
        # (which uses "--id=@name" references, e.g. "mirrors=@m") — verify
        # against the upstream tp-qemu source before trusting these commands.
        cmd = ["ovs-vsctl set Bridge %s [email protected] " % ovs]
        for port in [mirror_port, target_port]:
            cmd.append("-- [email protected]%s get Port %s " % (port, port))
        if direction == "input":
            cmd.append(
                "-- [email protected] create Mirror name=input_of_%s" %
                target_port)
            cmd.append("[email protected]%s" % target_port)
        elif direction == "output":
            cmd.append(
                "-- [email protected] create Mirror name=output_of_%s" % target_port)
            cmd.append("[email protected]%s" % target_port)
        else:
            # "all": mirror both directions of the target port.
            cmd.append(
                "-- [email protected] create Mirror name=mirror_%s" % target_port)
            cmd.append("[email protected]%s" % target_port)
            cmd.append("[email protected]%s" % target_port)
        cmd.append("[email protected]%s" % mirror_port)
        return " ".join(cmd)

    def create_mirror_port(mirror_port, target_port, direction, ovs):
        """
        Execute ovs port mirror command and check port really in mirror status.

        :param mirror_port: port name in ovs in mirror status.
        :param target_port: port name in ovs be mirroring.
        :param direction: mirror direction, all, only input or output.
        :param ovs: ovs port name.
        """
        mirror_cmd = make_mirror_cmd(mirror_port, target_port, direction, ovs)
        uuid = utils.system_output(mirror_cmd)
        # The UUID printed by the create command must show up in the mirror
        # table, otherwise the mirror was not installed.
        output = utils.system_output("ovs-vsctl list mirror")
        if uuid not in output:
            logging.debug("Create OVS Mirror CMD: %s " % mirror_cmd)
            logging.debug("Ovs Info: %s " % output)
            raise error.TestFail("Setup mirorr port failed")

    def check_tcpdump(output, target_ip, host_ip, direction):
        """
        Check tcpdump result file and report unexpect packet to debug log.

        :param output: string of tcpdump output.
        :param target_ip: ip of port in ovs be mirroring.
        :param host_ip: ip of ovs port.
        :param direction: mirror direction, all, only input or output.

        :return: bool type result.
        """
        # Default ("all"): accept ICMP echo traffic in either direction.
        rex = r".*IP (%s|%s) > " % (host_ip, target_ip)
        rex += "(%s|%s).*ICMP echo.*" % (target_ip, host_ip)
        if direction == "input":
            rex = r".*IP %s > %s.*ICMP echo reply.*" % (host_ip, target_ip)
        if direction == "output":
            rex = r".*IP %s > %s.*ICMP echo request.*" % (target_ip, host_ip)
        # Every dumped line must match; one stray packet fails the check.
        for idx, _ in enumerate(output.splitlines()):
            if not re.match(rex, _):
                logging.debug("Unexpect packet in line %d: %s" % (idx, _))
                return False
        return True

    os_dep.command("ovs-vsctl")
    ovs_name = params.get("ovs_name", "ovs0")
    direction = params.get("direction", "all")
    mirror_vm = params.get("mirror_vm", "vm1")
    target_vm = params.get("target_vm", "vm2")
    refer_vm = params.get("refer_vm", "vm3")
    net_mask = params.get("net_mask", "24")
    host_ip = params.get("ip_ovs", "192.168.1.1")
    pre_guest_cmd = params.get("pre_guest_cmd")
    ovs_create_cmd = params.get("ovs_create_cmd")
    ovs_remove_cmd = params.get("ovs_remove_cmd")
    login_timeout = int(params.get("login_timeout", "600"))

    error.context("Create private ovs switch", logging.info)
    utils.system(ovs_create_cmd)
    params["start_vm"] = "yes"
    params["netdst"] = ovs_name
    # vms_info maps vm name -> [vm object, host tap name, guest ip,
    # serial session, guest nic name].
    vms_info = {}
    try:
        for p_vm in params.get("vms").split():
            env_process.preprocess_vm(test, params, env, p_vm)
            o_vm = env.get_vm(p_vm)
            o_vm.verify_alive()
            ip = params["ip_%s" % p_vm]
            mac = o_vm.get_mac_address()
            ses = o_vm.wait_for_serial_login(timeout=login_timeout)
            ses.cmd(pre_guest_cmd)
            nic_name = utils_net.get_linux_ifname(ses, mac)
            ifname = o_vm.get_ifname()
            vms_info[p_vm] = [o_vm, ifname, ip, ses, nic_name]

        mirror_ifname = vms_info[mirror_vm][1]
        mirror_ip = vms_info[mirror_vm][2]
        mirror_nic = vms_info[mirror_vm][4]
        target_ifname = vms_info[target_vm][1]
        target_ip = vms_info[target_vm][2]
        refer_ip = vms_info[refer_vm][2]
        session = vms_info[mirror_vm][3]

        error.context("Create mirror port in ovs", logging.info)
        create_mirror_port(mirror_ifname, target_ifname, direction, ovs_name)
        # A port configured as a mirror output should not pass its own
        # traffic, so pings from the mirror vm are expected to fail.
        ping_cmd = "ping -c 10 %s" % host_ip
        status, output = session.cmd_status_output(ping_cmd, timeout=60)
        if status == 0:
            ifcfg = session.cmd_output_safe("ifconfig")
            logging.debug("Guest network info: %s" % ifcfg)
            logging.debug("Ping results: %s" % output)
            raise error.TestFail("All packets from %s to host should lost"
                                 % mirror_vm)

        error.context("Start tcpdump threads in %s" % mirror_vm, logging.info)
        # Bring the mirror vm's nic up with no address so it only listens.
        ifup_cmd = "ifconfig %s 0 up" % mirror_nic
        session.cmd(ifup_cmd, timeout=60)
        for vm, ip in [(target_vm, target_ip), (refer_vm, refer_ip)]:
            tcpdump_cmd = "tcpdump -l -n host %s and icmp >" % ip
            tcpdump_cmd += "/tmp/tcpdump-%s.txt &" % vm
            logging.info("tcpdump command: %s" % tcpdump_cmd)
            session.sendline(tcpdump_cmd)

        error.context("Start ping threads in %s %s" % (target_vm, refer_vm),
                      logging.info)
        for vm in [target_vm, refer_vm]:
            ses = vms_info[vm][3]
            nic_name = vms_info[vm][4]
            ip = vms_info[vm][2]
            ifup_cmd = "ifconfig %s %s/%s up" % (nic_name, ip, net_mask)
            ses.cmd(ifup_cmd)
            time.sleep(0.5)
            logging.info("Ping host from %s" % vm)
            ses.cmd("ping %s -c 100" % host_ip, timeout=150)

        error.context("Check tcpdump results", logging.info)
        session.cmd_output_safe("pkill tcpdump")
        # Remove the mirror and restore the mirror vm's address so the
        # dump files can be copied out over the network.
        utils.system("ovs-vsctl clear bridge %s mirrors" % ovs_name)
        ifup_cmd = "ifconfig %s %s/%s up" % (mirror_nic, mirror_ip, net_mask)
        session.cmd(ifup_cmd, timeout=60)
        time.sleep(0.5)
        for vm in [target_vm, refer_vm]:
            src_file = "/tmp/tcpdump-%s.txt" % vm
            dst_file = os.path.join(test.resultsdir, "tcpdump-%s.txt" % vm)
            vms_info[mirror_vm][0].copy_files_from(src_file, dst_file)
            fd = open(dst_file, "r")
            content = fd.read().strip()
            fd.close()
            # Traffic of the reference vm must never reach the mirror port.
            if vm == refer_vm and content:
                raise error.TestFail(
                    "should not packet from %s dumped in %s" %
                    (refer_vm, mirror_vm))
            elif not check_tcpdump(content, target_ip, host_ip, direction):
                raise error.TestFail(
                    "Unexpect packages from %s dumped in %s" % (vm, mirror_vm))
    finally:
        for vm in vms_info:
            vms_info[vm][0].destroy(gracefully=False)
        # NOTE(review): "openvswith" looks like a typo for "openvswitch";
        # if so, no logs are ever collected here — confirm upstream.
        for f in glob.glob("/var/log/openvswith/*.log"):
            dst = os.path.join(test.resultsdir, os.path.basename(f))
            shutil.copy(f, dst)
        utils.system(ovs_remove_cmd, ignore_status=False)

Example 45

Project: tp-qemu
Source File: ovs_qos.py
View license
@error.context_aware
def run(test, params, env):
    """
     Test Qos between guests in one ovs backend

    1) Boot the vms
    2) Apply QoS limitation to 1Mbps on the tap of a guest.
    3) Start netperf server on another guest.
    4) Start netperf client on guest in step 1 with option -l 60.
    5) Stop netperf client and set QoS to 10Mbps.
    6) Run step 4 again.
    7) Verify vm through out.

    :param test: Kvm test object
    :param params: Dictionary with the test parameters.
    :param env: Dictionary with test environment.
    """

    def set_ovs_port_attr(iface, attribute, value):
        """
        Set OVS port attribute.

        :param iface: tap interface name in the ovs bridge.
        :param attribute: interface attribute name to set.
        :param value: value to assign to the attribute.
        :raise error.TestError: if ovs-vsctl exits with nonzero status.
        """
        cmd = "ovs-vsctl set interface %s %s=%s" % (iface, attribute, value)
        logging.info("execute host command: %s" % cmd)
        status = utils.system(cmd, ignore_status=True)
        if status != 0:
            err_msg = "set %s to %s for interface '%s' " % (
                attribute, value, iface)
            err_msg += "exited with nozero statu '%d'" % status
            # Bug fix: the original instantiated error.TestError without
            # raising it, so ovs-vsctl failures were silently ignored.
            raise error.TestError(err_msg)

    def set_port_qos(vm, rate, burst):
        """
        Set ingress_policing_rate and ingress_policing_burst for tap device
        used by vm.

        :param vm: netperf client vm object
        :param rate: value of ingress_policing_rate
        :param burst: value of ingress_policing_burst
        """
        iface = vm.get_ifname()
        error.context("Set QoS for tap '%s' use by vm '%s'" % (iface, vm.name),
                      logging.info)
        attributes = zip(['ingress_policing_rate',
                          'ingress_policing_burst'],
                         [rate, burst])
        for k, v in attributes:
            set_ovs_port_attr(iface, k, v)
            time.sleep(0.1)

    def get_throughout(netperf_server, server_vm, netperf_client,
                       client_vm, client_options=" -l 60"):
        """
        Get network throughout by netperf.

        :param netperf_server: utils_netperf.NetperfServer instance.
        :param server_vm: vm object acting as netperf server.
        :param netperf_client: utils_netperf.NetperfClient instance.
        :param client_vm: vm object acting as netperf client.
        :param client_options: netperf client start options.

        :return: float type throughout Kbps.
        """
        error.context("Set '%s' as netperf server" % server_vm.name,
                      logging.info)
        if not netperf_server.is_server_running():
            netperf_server.start()

        error.context("Set '%s' as netperf client" % client_vm.name,
                      logging.info)
        server_ip = server_vm.get_address()
        output = netperf_client.start(server_ip, client_options)
        logging.debug("netperf client output: %s" % output)
        # Last numeric column of the netperf result row is throughput
        # in Mbit/s; convert to Kbit/s for comparison with the QoS rate.
        regex = r"\d+\s+\d+\s+\d+\s+[\d.]+\s+([\d.]+)"
        try:
            throughout = float(re.search(regex, output, re.M).groups()[0])
            return throughout * 1000
        except Exception:
            raise error.TestError("Invaild output format of netperf client!")
        finally:
            netperf_client.stop()

    def is_test_pass(data):
        """
        Check throughout near gress_policing_rate set for tap device.
        """
        # data is [iface, throughput, rate, burst]; pass when throughput
        # does not exceed rate + burst.
        return data[1] <= data[2] + data[3]

    def report_test_results(datas):
        """
        Report failed test scenarios.
        """
        error.context("Analyze guest throughout", logging.info)
        fails = [_ for _ in datas if not is_test_pass(_)]
        if fails:
            msg = "OVS Qos test failed, "
            for tap, throughout, rate, burst in fails:
                msg += "netperf throughout(%s) on '%s' " % (throughout, tap)
                msg += "should be near ingress_policing_rate(%s), " % rate
                msg += "ingress_policing_burst is %s;\n" % burst
            raise error.TestFail(msg)

    def clear_qos_setting(iface):
        """Remove any QoS configuration from the given ovs port."""
        error.context("Clear qos setting for ovs port '%s'" % iface,
                      logging.info)
        clear_cmd = "ovs-vsctl clear Port %s qos" % iface
        utils.system(clear_cmd)
        logging.info("Clear ovs command: %s" % clear_cmd)

    def setup_netperf_env():
        """
        Setup netperf envrioments in vms

        Even-indexed vms become netperf servers, odd-indexed vms become
        clients.

        :return: tuple (netperf_clients, netperf_servers) of lists of
                 (Netperf{Client,Server} instance, vm object) pairs.
        """
        def __get_vminfo():
            """
            Get vms information;
            """
            login_timeout = float(params.get("login_timeout", 360))
            clear_iptables_cmd = "service iptables stop; iptables -F"
            guest_info = ["username", "password", "shell_client",
                          "shell_port", "os_type"]
            vms_info = []
            for _ in params.get("vms").split():
                info = map(
                    lambda x: params.object_params(_).get(x),
                    guest_info)
                vm = env.get_vm(_)
                vm.verify_alive()
                session = vm.wait_for_login(timeout=login_timeout)
                session.cmd(clear_iptables_cmd, ignore_all_errors=True)
                vms_info.append((vm, info))
            return vms_info

        netperf_link = params.get("netperf_link")
        netperf_link = os.path.join(
            data_dir.get_deps_dir("netperf"),
            netperf_link)
        md5sum = params.get("pkg_md5sum")
        netperf_server_link = params.get(
            "netperf_server_link_win",
            netperf_link)
        netperf_server_link = os.path.join(data_dir.get_deps_dir("netperf"),
                                           netperf_server_link)
        netperf_client_link = params.get(
            "netperf_client_link_win",
            netperf_link)
        netperf_client_link = os.path.join(data_dir.get_deps_dir("netperf"),
                                           netperf_client_link)

        server_path_linux = params.get("server_path", "/var/tmp")
        client_path_linux = params.get("client_path", "/var/tmp")
        server_path_win = params.get("server_path_win", "c:\\")
        client_path_win = params.get("client_path_win", "c:\\")
        compile_option_client = params.get("compile_option_client", "")
        compile_option_server = params.get("compile_option_server", "")

        netperf_servers, netperf_clients = [], []
        for idx, (vm, info) in enumerate(__get_vminfo()):
            if idx % 2 == 0:
                if info[-1] == "windows":
                    pkg_link = netperf_server_link
                    server_path = server_path_win
                else:
                    # Bug fix: the original did "netperf_link = netperf_link",
                    # a self-assignment that could carry over a Windows link
                    # assigned on an earlier iteration; always use the
                    # generic Linux package link here.
                    pkg_link = netperf_link
                    server_path = server_path_linux
                server = utils_netperf.NetperfServer(
                    vm.get_address(),
                    server_path,
                    md5sum,
                    pkg_link,
                    port=info[-2],
                    client=info[-3],
                    password=info[-4],
                    username=info[-5],
                    compile_option=compile_option_server)
                netperf_servers.append((server, vm))
                continue
            else:
                if info[-1] == "windows":
                    pkg_link = netperf_client_link
                    client_path = client_path_win
                else:
                    # Same fix as the server branch above.
                    pkg_link = netperf_link
                    client_path = client_path_linux
                client = utils_netperf.NetperfClient(
                    vm.get_address(),
                    client_path,
                    md5sum,
                    pkg_link,
                    port=info[-2],
                    client=info[-3],
                    password=info[-4],
                    username=info[-5],
                    compile_option=compile_option_client)
                netperf_clients.append((client, vm))
                continue
        return netperf_clients, netperf_servers

    os_dep.command("ovs-vsctl")
    if params.get("netdst") not in utils.system_output("ovs-vsctl show"):
        raise error.TestError("This is a openvswitch only test")
    extra_options = params.get("netperf_client_options", " -l 60")
    rate_brust_pairs = params.get("rate_brust_pairs").split()
    rate_brust_pairs = map(lambda x: map(int, x.split(',')), rate_brust_pairs)
    results = []
    try:
        netperf_clients, netperf_servers = setup_netperf_env()
        for idx in range(len(netperf_clients)):
            netperf_client, client_vm = netperf_clients[idx]
            # Pair each client with the server of the same index, falling
            # back to server 0 when there are fewer servers than clients
            # (py2 "cond and [a] or [b]" idiom for a conditional expression).
            idx = (idx < len(netperf_servers) and [idx] or [0])[0]
            netperf_server, server_vm = netperf_servers[idx]
            for rate, burst in rate_brust_pairs:
                set_port_qos(client_vm, rate, burst)
                time.sleep(3)
                throughout = get_throughout(netperf_server,
                                            server_vm,
                                            netperf_client,
                                            client_vm,
                                            extra_options)
                iface = client_vm.get_ifname()
                clear_qos_setting(iface)
                results.append([iface, throughout, rate, burst])
        report_test_results(results)
    finally:
        # NOTE(review): "openvswith" looks like a typo for "openvswitch";
        # if so, no logs are ever collected here — confirm upstream.
        for f in glob.glob("/var/log/openvswith/*.log"):
            dst = os.path.join(test.resultsdir, os.path.basename(f))
            shutil.copy(f, dst)

Example 46

Project: pyspace
Source File: trainer.py
View license
    def prepare_training(self, training_files, potentials, operation, nullmarker_stride_ms = None):
        """ Prepares pyspace live for training.

        Prepares everything for training of pyspace live,
        i.e. creates flows based on the dataflow specs
        and configures them.
        """
        online_logger.info( "Preparing Training")
        self.potentials = potentials
        self.operation = operation
        self.nullmarker_stride_ms = nullmarker_stride_ms
        if self.nullmarker_stride_ms == None:
            online_logger.warn( 'Nullmarker stride interval is %s. You can specify it in your parameter file.' % self.nullmarker_stride_ms)
        else:
            online_logger.info( 'Nullmarker stride interval is set to %s ms ' % self.nullmarker_stride_ms)

        online_logger.info( "Creating flows..")
        for key in self.potentials.keys():
            spec_base = self.potentials[key]["configuration"].spec_dir
            if self.operation == "train":
                self.potentials[key]["node_chain"] = os.path.join(spec_base, self.potentials[key]["node_chain"])
                online_logger.info( "node_chain_spec:" + self.potentials[key]["node_chain"])

            elif self.operation in ("prewindowing", "prewindowing_offline"):
                self.potentials[key]["prewindowing_flow"] = os.path.join(spec_base, self.potentials[key]["prewindowing_flow"])
                online_logger.info( "prewindowing_dataflow_spec: " + self.potentials[key]["prewindowing_flow"])

            elif self.operation == "prewindowed_train":
                self.potentials[key]["postprocess_flow"] = os.path.join(spec_base, self.potentials[key]["postprocess_flow"])
                online_logger.info( "postprocessing_dataflow_spec: " + self.potentials[key]["postprocess_flow"])

            self.training_active_potential[key] = multiprocessing.Value("b",False)

        online_logger.info("Path variables set for NodeChains")

        # check if multiple potentials are given for training
        if isinstance(training_files, list):
            self.training_data = training_files
        else:
            self.training_data = [training_files]

        # Training is done in separate processes, we send the time series
        # windows to these threads via two queues
        online_logger.info( "Initializing Queues")
        for key in self.potentials.keys():
            self.queue[key] = multiprocessing.Queue()


        def flow_generator(key):
            """create a generator to yield all the abri flow windows"""
            # Yield all windows until a None item is found in the queue
            while True:
                window = self.queue[key].get(block = True, timeout = None)
                if window == None: break
                yield window

        # Create the actual data flows
        for key in self.potentials.keys():

            if self.operation == "train":
                self.node_chains[key] = NodeChainFactory.flow_from_yaml(Flow_Class = NodeChain,
                                                         flow_spec = file(self.potentials[key]["node_chain"]))
                self.node_chains[key][0].set_generator(flow_generator(key))
                flow = open(self.potentials[key]["node_chain"])
            elif self.operation in ("prewindowing", "prewindowing_offline"):
                online_logger.info("loading prewindowing flow..")
                online_logger.info("file: " + str(self.potentials[key]["prewindowing_flow"]))

                self.node_chains[key] = NodeChainFactory.flow_from_yaml(Flow_Class = NodeChain,
                                                             flow_spec = file(self.potentials[key]["prewindowing_flow"]))
                self.node_chains[key][0].set_generator(flow_generator(key))
                flow = open(self.potentials[key]["prewindowing_flow"])
            elif self.operation == "prewindowed_train":
                self.node_chains[key] = NodeChainFactory.flow_from_yaml(Flow_Class = NodeChain, flow_spec = file(self.potentials[key]["postprocess_flow"]))
                replace_start_and_end_markers = False

                final_collection = TimeSeriesDataset()
                final_collection_path = os.path.join(self.prewindowed_data_directory, key, "all_train_data")
                # delete previous training collection
                if os.path.exists(final_collection_path):
                    online_logger.info("deleting old training data collection for " + key)
                    shutil.rmtree(final_collection_path)

                # load all prewindowed collections and
                # append data to the final collection
                prewindowed_sets = \
                    glob.glob(os.path.join(self.prewindowed_data_directory, key, "*"))
                if len(prewindowed_sets) == 0:
                    online_logger.error("Couldn't find data, please do prewindowing first!")
                    raise Exception
                online_logger.info("concatenating prewindowed data from " + str(prewindowed_sets))

                for s,d in enumerate(prewindowed_sets):
                    collection = BaseDataset.load(d)
                    data = collection.get_data(0, 0, "train")
                    for d,(sample,label) in enumerate(data):
                        if replace_start_and_end_markers:
                            # in case we concatenate multiple 'Window' labeled
                            # sets we have to remove every start- and endmarker
                            for k in sample.marker_name.keys():
                                # find '{S,s}  8' or '{S,s}  9'
                                m = re.match("^s\s{0,2}[8,9]{1}$", k, re.IGNORECASE)
                                if m is not None:
                                    online_logger.info(str("remove %s from %d %d" % (m.group(), s, d)))
                                    del(sample.marker_name[m.group()])

                            if s == len(prewindowed_sets)-1 and \
                                d == len(data)-1:
                                # insert endmarker
                                sample.marker_name["S  9"] = [0.0]
                                online_logger.info("added endmarker" + str(s) + " " + str(d))

                            if s == 0 and d == 0:
                                # insert startmarker
                                sample.marker_name["S  8"] = [0.0]
                                online_logger.info("added startmarker" + str(s) + " " + str(d))

                        final_collection.add_sample(sample, label, True)

                # save final collection (just for debugging)
                os.mkdir(final_collection_path)
                final_collection.store(final_collection_path)

                online_logger.info("stored final collection at " + final_collection_path)

                # load final collection again for training
                online_logger.info("loading data from " + final_collection_path)
                self.prewindowed_data[key] =  BaseDataset.load(final_collection_path)
                self.node_chains[key][0].set_input_dataset(self.prewindowed_data[key])

                flow = open(self.potentials[key]["postprocess_flow"])

            # create window_stream for every potential

            if self.operation in ("prewindowing"):
                window_spec_file = os.path.join(spec_base,"node_chains","windower",
                             self.potentials[key]["windower_spec_path_train"])

                self.window_stream[key] = \
                        self.stream_manager.request_window_stream(window_spec_file,
                                                              nullmarker_stride_ms = self.nullmarker_stride_ms)
            elif self.operation in ("prewindowing_offline"):
                pass
            elif self.operation in ("train"):
                pass

            self.node_chain_definitions[key] = yaml.load(flow)
            flow.close()

        # TODO: check if the prewindowing flow is still needed when using the stream mode!
        if self.operation in ("train"):
            online_logger.info( "Removing old flows...")
            try:
                shutil.rmtree(self.flow_storage)
            except:
                online_logger.info("Could not delete flow storage directory")
            os.mkdir(self.flow_storage)
        elif self.operation in ("prewindowing", "prewindowing_offline"):
            # follow this policy:
            # - delete prewindowed data older than 12 hours
            # - always delete trained/stored flows
            now = datetime.datetime.now()
            then = now - datetime.timedelta(hours=12)

            if not os.path.exists(self.prewindowed_data_directory):
                os.mkdir(self.prewindowed_data_directory)
            if not os.path.exists(self.flow_storage):
                os.mkdir(self.flow_storage)

            for key in self.potentials.keys():
                found = self.find_files_older_than(then, \
                        os.path.join(self.prewindowed_data_directory, key))
                if found is not None:
                    for f in found:
                        online_logger.info(str("recursively deleting files in \'%s\'" % f))
                        try:
                            shutil.rmtree(os.path.abspath(f))
                        except Exception as e:
                            # TODO: find a smart solution for this!
                            pass # dir was probably already deleted..

                if os.path.exists(os.path.join(self.prewindowed_data_directory, key, "all_train_data")):
                    shutil.rmtree(os.path.join(self.prewindowed_data_directory, key, "all_train_data"))
                    online_logger.info("deleted concatenated training data for " + key)


        online_logger.info( "Training preparations finished")
        return 0

Example 47

Project: pyspace
Source File: time_series.py
View license
    def get_data(self, run_nr, split_nr, train_test):
        """ Return the train or test data for the given split in the given run.

        While unloaded, ``self.data[(run_nr, split_nr, train_test)]`` holds
        the *file path* (a string); on first access it is replaced in-place
        by the loaded list of ``(TimeSeries, label)`` tuples.  Supported
        storage formats: "pickle", "mat"/"matlab"/"MATLAB" and the BCI
        competition formats ("bci_comp*").  (Python 2 code:
        ``basestring``/``cPickle``.)

        **Parameters**
          
          :run_nr: The number of the run whose data should be loaded.
          
          :split_nr: The number of the split whose data should be loaded.
          
          :train_test: "train" if the training data should be loaded.
                       "test" if the test data should be loaded.
    
        """
        # Do lazy loading of the time series objects.
        if isinstance(self.data[(run_nr, split_nr, train_test)], basestring):
            self._log("Lazy loading of %s time series windows from input "
                      "collection for run %s, split %s." % (train_test, run_nr, 
                                                            split_nr))
            # storage_format may be a list [file_format, inner_format]
            # (e.g. ["mat", "channelXtime"]) or a plain string.
            s_format = self.meta_data["storage_format"]
            if type(s_format) == list:
                f_format = s_format[0]
            else:
                f_format = s_format
            if f_format == "pickle":
                # Load the time series from a pickled file
                f = open(self.data[(run_nr, split_nr, train_test)], 'r')
                try:
                    self.data[(run_nr, split_nr, train_test)] = cPickle.load(f)
                except ImportError:
                    # code for backward compatibility
                    # redirection of old path
                    # (old pickles reference 'abri_dp.types.time_series';
                    # alias that module name temporarily so unpickling works)
                    f.seek(0)
                    self._log("Loading deprecated data. Please transfer it " +
                              "to new format.",level=logging.WARNING)
                    from pySPACE.resources.data_types import time_series
                    sys.modules['abri_dp.types.time_series'] = time_series
                    self.data[(run_nr, split_nr, train_test)] = cPickle.load(f)
                    del sys.modules['abri_dp.types.time_series']
                f.close()
            elif f_format in ["mat", "matlab", "MATLAB"]:
                # MATLAB file with 'data', 'channel_names', 'labels' and
                # 'sampling_frequency' variables.
                from scipy.io import loadmat
                from pySPACE.resources.data_types.time_series import TimeSeries
                ts_fname = self.data[(run_nr, split_nr, train_test)]
                dataset = loadmat(ts_fname)
                channel_names = [name.strip() for name in dataset['channel_names']]
                sf = dataset["sampling_frequency"][0][0]
                self.data[(run_nr, split_nr, train_test)] = []
                # assume third axis to be trial axis
                if "channelXtime" in s_format:
                    # data stored channel x time: transpose to time x channel
                    for i in range(dataset["data"].shape[2]):
                        self.data[(run_nr, split_nr, train_test)].append(\
                            (TimeSeries(dataset["data"][:,:,i].T, channel_names,
                                        sf), dataset["labels"][i].strip()))
                else:
                    for i in range(dataset["data"].shape[2]):
                        self.data[(run_nr, split_nr, train_test)].append(\
                            (TimeSeries(dataset["data"][:,:,i], channel_names,
                                        sf), dataset["labels"][i].strip()))
            elif f_format.startswith("bci_comp"):
                # BCI competition data; behavior depends on competition
                # number / data-set number configured on the dataset object.
                from scipy.io import loadmat
                from pySPACE.resources.data_types.time_series import TimeSeries
                if self.comp_number == "2":
                    if self.comp_set == "4":
                        ts_fname = self.data[(run_nr, split_nr, train_test)]
                        d = loadmat(ts_fname)
                        channel_names = [name[0].astype('|S3') for name in \
                                                                   d["clab"][0]]
                        if train_test == "train":
                            self.data[(run_nr, split_nr, train_test)] = []
                            input_d = d["x_train"]
                            input_l = d["y_train"][0]
                            for i in range(input_d.shape[2]):
                                self.data[(run_nr, split_nr, 
                                           train_test)].append(\
                                            (TimeSeries(input_d[:,:,i],
                                                 channel_names, float(self.sf)), 
                                        "Left" if input_l[i] == 0 else "Right"))
                        else:
                            # test labels live in a separate *.txt file next
                            # to the .mat file, one integer label per line
                            label_fname = glob.glob(os.path.join(
                                          os.path.dirname(ts_fname),"*.txt"))[0]
                            input_d = d["x_test"]
                            input_l = open(label_fname,'r')
                            self.data[(run_nr, split_nr, train_test)] = []
                            for i in range(input_d.shape[2]):
                                label = int(input_l.readline())
                                self.data[(run_nr, split_nr, 
                                           train_test)].append(\
                                            (TimeSeries(input_d[:,:,i],
                                                 channel_names, float(self.sf)), 
                                             "Left" if label == 0 else "Right"))
                elif self.comp_number == "3":
                    if self.comp_set == "2":
                        # P300 speller paradigm: cut epochs around the end of
                        # each row/column flash and average the repetitions.
                        data = loadmat(self.data[(run_nr, split_nr, train_test)])
                        signal = data['Signal']
                        flashing = data['Flashing']
                        stimulus_code = data['StimulusCode']
                        stimulus_type = data['StimulusType']
                
                        window = 240
                        Fs = 240
                        channels = 64
                        epochs = signal.shape[0]
                        self.data[(run_nr, split_nr, train_test)] = []
                        self.start_offset_ms = 1000.0
                        self.end_offset_ms = 1000.0
                        
                        # window plus 1s context before and after, in samples
                        whole_len = (self.start_offset_ms + self.end_offset_ms)*Fs/1000.0 + window
                        # 12 stimuli (6 rows + 6 columns) x 15 repetitions
                        responses = numpy.zeros((12, 15, whole_len, channels))
                        for epoch in range(epochs):
                            rowcolcnt=numpy.ones(12)
                            for n in range(1, signal.shape[1]):
                                # falling edge of 'Flashing' marks the end of
                                # a stimulus presentation
                                if (flashing[epoch,n]==0 and flashing[epoch,n-1]==1):
                                    rowcol=stimulus_code[epoch,n-1]
                                    # NOTE(review): float slice indices below
                                    # only work with old numpy versions that
                                    # still accept non-integer indices.
                                    if n-24-self.start_offset_ms*Fs/1000.0 < 0:
                                        # epoch starts before recording: pad
                                        # with zeros at the front
                                        temp = signal[epoch,0:n+window+self.end_offset_ms*Fs/1000.0-24,:]
                                        temp = numpy.vstack((numpy.zeros((whole_len - temp.shape[0], temp.shape[1])), temp))
                                    elif n+window+self.end_offset_ms*Fs/1000.0-24> signal.shape[1]:
                                        # epoch ends after recording: pad
                                        # with zeros at the back
                                        temp = signal[epoch,n-24-self.start_offset_ms*Fs/1000.0:signal.shape[1],:]
                                        temp = numpy.vstack((temp, numpy.zeros((whole_len-temp.shape[0], temp.shape[1]))))
                                    else:
                                        temp = signal[epoch, n-24-self.start_offset_ms*Fs/1000.0:n+window+self.end_offset_ms*Fs/1000.0-24, :]
                                    responses[rowcol-1,rowcolcnt[rowcol-1]-1,:,:]=temp
                                    rowcolcnt[rowcol-1]=rowcolcnt[rowcol-1]+1
                
                            # average over the repetitions of each stimulus
                            # NOTE(review): computed inside the epoch loop on
                            # the (re)filled 'responses' buffer — presumably a
                            # per-epoch running average; verify intent.
                            avgresp=numpy.mean(responses,1)
                
                            # nonzero only where the flashed row/column was
                            # the attended (target) one
                            targets = stimulus_code[epoch,:]*stimulus_type[epoch,:]
                            target_rowcol = []
                            for value in targets:
                                if value not in target_rowcol:
                                    target_rowcol.append(value)
                
                            target_rowcol.sort()
                
                            # NOTE(review): indices [1] and [2] assume exactly
                            # one target row and one target column besides the
                            # leading 0 entry — confirm for this data set.
                            for i in range(avgresp.shape[0]):
                                temp = avgresp[i,:,:]
                                data = TimeSeries(input_array = temp,
                                                  channel_names = range(64), 
                                                  sampling_frequency = window)
                                if i == target_rowcol[1]-1 or i == target_rowcol[2]-1:
                                    self.data[(run_nr, split_nr, train_test)].append((data,"Target"))
                                else:
                                    self.data[(run_nr, split_nr, train_test)].append((data,"Standard"))
        if self.stream_mode and not self.data[(run_nr, split_nr, train_test)] == []:
            # Create a connection to the TimeSeriesClient and return an iterator
            # that passes all received data through the windower.
            self.reader = TimeSeriesClient(self.data[(run_nr, split_nr, train_test)], blocksize=100)

            # Creates a windower that splits the training data into windows
            # based in the window definitions provided
            # and assigns correct labels to these windows
            self.reader.set_window_defs(self.window_definition)
            self.reader.connect()
            self.marker_windower = MarkerWindower(
                self.reader, self.window_definition,
                nullmarker_stride_ms=self.nullmarker_stride_ms,
                no_overlap=self.no_overlap,
                data_consistency_check=self.data_consistency_check)
            return self.marker_windower
        else:
            return self.data[(run_nr, split_nr, train_test)]

Example 48

Project: mpop
Source File: msg_seviri_hdf.py
View license
def load(satscene, calibrate=True, area_extent=None, **kwargs):
    """Load MSG SEVIRI data from hdf5 format.
    """

    # Read config file content
    conf = ConfigParser()
    conf.read(os.path.join(CONFIG_PATH, satscene.fullname + ".cfg"))
    values = {"orbit": satscene.orbit,
    "satname": satscene.satname,
    "number": satscene.number,
    "instrument": satscene.instrument_name,
    "satellite": satscene.fullname
    }

    LOG.info("assume seviri-level4")
    print "... assume seviri-level4"

    satscene.add_to_history("hdf5 data read by mpop/msg_seviri_hdf.py")


    if "reader_level" in kwargs.keys():
        reader_level = kwargs["reader_level"]
    else:
        reader_level = "seviri-level4"

    if "RSS" in kwargs.keys():
        if kwargs["RSS"]:
            dt_end =  4
        else:
            dt_end = 12
    else:
        from my_msg_module import check_RSS
        RSS = check_RSS(satscene.sat_nr(), satscene.time_slot)
        if RSS == None:
            print "*** Error in mpop/satin/msg_seviri_hdf.py"
            print "    satellite MSG", satscene.sat_nr() ," is not active yet"
            quit()
        else:
            if RSS:
                dt_end =  4
            else:
                dt_end = 12

    print "... hdf file name is specified by observation end time"
    print "    assume ", dt_end, " min between start and end time of observation"

    # end of scan time 4 min after start 
    end_time = satscene.time_slot + datetime.timedelta(minutes=dt_end)

    filename = os.path.join( end_time.strftime(conf.get(reader_level, "dir", raw=True)),
                             end_time.strftime(conf.get(reader_level, "filename", raw=True)) % values )
    
    print "... search for file: ", filename
    filenames=glob(str(filename))
    if len(filenames) == 0:
        print "*** Error, no file found"
        return # just return without exit the program 
    elif len(filenames) > 1:
        print "*** Warning, more than 1 datafile found: ", filenames 
    filename = filenames[0]
    print("... read data from %s" % str(filename))

    # read data from hdf5 file 
    data_folder='U-MARF/MSG/Level1.5/'

    # Load data from hdf file
    with h5py.File(filename,'r') as hf:

        subset_info=hf.get(data_folder+'METADATA/SUBSET')
        for i in range(subset_info.len()):
            #print subset_info[i]['EntryName'], subset_info[i]['Value']
            if subset_info[i]['EntryName'] == "VIS_IRSouthLineSelectedRectangle":
                VIS_IRSouthLine = int(subset_info[i]['Value'])
            if subset_info[i]['EntryName'] == "VIS_IRNorthLineSelectedRectangle":
                VIS_IRNorthLine = int(subset_info[i]['Value'])
            if subset_info[i]['EntryName'] == "VIS_IREastColumnSelectedRectangle":
                VIS_IREastColumn = int(subset_info[i]['Value'])
            if subset_info[i]['EntryName'] == "VIS_IRWestColumnSelectedRectangle":
                VIS_IRWestColumn = int(subset_info[i]['Value'])
            if subset_info[i]['EntryName'] == "HRVLowerNorthLineSelectedRectangle":
                HRVLowerNorthLine = int(subset_info[i]['Value'])
            if subset_info[i]['EntryName'] == "HRVLowerSouthLineSelectedRectangle":
                HRVLowerSouthLine = int(subset_info[i]['Value'])
            if subset_info[i]['EntryName'] == "HRVLowerEastColumnSelectedRectangle":
                HRVLowerEastColumn = int(subset_info[i]['Value'])
            if subset_info[i]['EntryName'] == "HRVLowerWestColumnSelectedRectangle":
                HRVLowerWestColumn = int(subset_info[i]['Value'])
            if subset_info[i]['EntryName'] == "HRVUpperSouthLineSelectedRectangle":
                HRVUpperSouthLine = int(subset_info[i]['Value'])  # 0
            if subset_info[i]['EntryName'] == "HRVUpperNorthLineSelectedRectangle":
                HRVUpperNorthLine = int(subset_info[i]['Value'])  # 0
            if subset_info[i]['EntryName'] == "HRVUpperEastColumnSelectedRectangle":
                HRVUpperEastColumn = int(subset_info[i]['Value']) # 0
            if subset_info[i]['EntryName'] == "HRVUpperWestColumnSelectedRectangle":
                HRVUpperWestColumn = int(subset_info[i]['Value']) # 0

        sat_status=hf.get(data_folder+'METADATA/HEADER/SatelliteStatus/SatelliteStatus_DESCR')
        for i in range(subset_info.len()):
            if sat_status[i]['EntryName']=="SatelliteDefinition-NominalLongitude":
                sat_lon = sat_status[i]['Value']
                break

        #print 'VIS_IRSouthLine', VIS_IRSouthLine
        #print 'VIS_IRNorthLine', VIS_IRNorthLine
        #print 'VIS_IREastColumn', VIS_IREastColumn
        #print 'VIS_IRWestColumn', VIS_IRWestColumn
        #print 'sat_longitude', sat_lon, type(sat_lon), 'GEOS<'+'{:+06.1f}'.format(sat_lon)+'>' 

        if 1 == 0:
            # works only if all pixels are on the disk 
            from msg_pixcoord2area import msg_pixcoord2area
            print "VIS_IRNorthLine, VIS_IRWestColumn, VIS_IRSouthLine, VIS_IREastColumn: ", VIS_IRNorthLine, VIS_IRWestColumn, VIS_IRSouthLine, VIS_IREastColumn
            area_def = msg_pixcoord2area ( VIS_IRNorthLine, VIS_IRWestColumn, VIS_IRSouthLine, VIS_IREastColumn, "vis", sat_lon )
        else:
            # works also for pixels outside of the disk 
            pname = 'GEOS<'+'{:+06.1f}'.format(sat_lon)+'>'  # "GEOS<+009.5>"
            proj = {'proj': 'geos', 'a': '6378169.0', 'b': '6356583.8', 'h': '35785831.0', 'lon_0': str(sat_lon)}
            aex=(-5570248.4773392612, -5567248.074173444, 5567248.074173444, 5570248.4773392612)

            # define full disk projection 
            from pyresample.geometry import AreaDefinition
            full_disk_def = AreaDefinition('full_disk',
                                           'full_disk',
                                           pname,
                                           proj,
                                           3712,
                                           3712,
                                           aex )

            # define name and calculate area for sub-demain 
            area_name= 'MSG_'+'{:04d}'.format(VIS_IRNorthLine)+'_'+'{:04d}'.format(VIS_IRWestColumn)+'_'+'{:04d}'.format(VIS_IRSouthLine)+'_'+'{:04d}'.format(VIS_IREastColumn)
            aex = full_disk_def.get_area_extent_for_subset(3712-VIS_IRSouthLine,3712-VIS_IRWestColumn,3712-VIS_IRNorthLine,3712-VIS_IREastColumn)

            area_def = AreaDefinition(area_name,
                                      area_name,
                                      pname,
                                      proj,
                                      (VIS_IRWestColumn-VIS_IREastColumn)+1,
                                      (VIS_IRNorthLine-VIS_IRSouthLine)+1,
                                      aex )

        #print area_def
        #print "REGION:", area_def.area_id, "{"
        #print "\tNAME:\t", area_def.name
        #print "\tPCS_ID:\t", area_def.proj_id
        #print ("\tPCS_DEF:\tproj="+area_def.proj_dict['proj']+", lon_0=" + area_def.proj_dict['lon_0'] + ", a="+area_def.proj_dict['a']+", b="+area_def.proj_dict['b']+", h="+area_def.proj_dict['h'])
        #print "\tXSIZE:\t", area_def.x_size
        #print "\tYSIZE:\t", area_def.y_size
        #print "\tAREA_EXTENT:\t", area_def.area_extent
        #print "};"

        # copy area to satscene 
        satscene.area = area_def

        # write information used by mipp.xrit.MSG._Calibrator in a fake header file
        hdr = dict()

        # satellite ID number 
        hdr["SatelliteDefinition"] = dict()
        hdr["SatelliteDefinition"]["SatelliteId"] = SatelliteIds[str(satscene.sat_nr())]
        
        # processing 
        hdr["Level 1_5 ImageProduction"] = dict()
        hdr["Level 1_5 ImageProduction"]["PlannedChanProcessing"] = np_array([2,2,2,2,2,2,2,2,2,2,2,2], int)
        
        # calibration factors  
        Level15ImageCalibration = hf.get(data_folder+'METADATA/HEADER/RadiometricProcessing/Level15ImageCalibration_ARRAY')
        hdr["Level1_5ImageCalibration"] = dict()

        for chn_name in channel_numbers.keys():
            chn_nb = channel_numbers[chn_name]-1
            hdr["Level1_5ImageCalibration"][chn_nb] = dict()
            #print chn_name, chn_nb, Level15ImageCalibration[chn_nb]['Cal_Slope'], Level15ImageCalibration[chn_nb]['Cal_Offset']
            hdr["Level1_5ImageCalibration"][chn_nb]['Cal_Slope']  = Level15ImageCalibration[chn_nb]['Cal_Slope']
            hdr["Level1_5ImageCalibration"][chn_nb]['Cal_Offset'] = Level15ImageCalibration[chn_nb]['Cal_Offset']

        # loop over channels to load 
        for chn_name in satscene.channels_to_load:

            dataset_name = data_folder+'DATA/'+dict_channel[chn_name]+'/IMAGE_DATA'
            if dataset_name in hf:
                data_tmp = hf.get(data_folder+'DATA/'+dict_channel[chn_name]+'/IMAGE_DATA')

                LOG.info('hdr["SatelliteDefinition"]["SatelliteId"]: '+str(hdr["SatelliteDefinition"]["SatelliteId"]))
                #LOG.info('hdr["Level 1_5 ImageProduction"]["PlannedChanProcessing"]', hdr["Level 1_5 ImageProduction"]["PlannedChanProcessing"])
                chn_nb = channel_numbers[chn_name]-1
                LOG.info('hdr["Level1_5ImageCalibration"][chn_nb]["Cal_Slope"]:  '+str(hdr["Level1_5ImageCalibration"][chn_nb]["Cal_Slope"]))
                LOG.info('hdr["Level1_5ImageCalibration"][chn_nb]["Cal_Offset"]: '+str(hdr["Level1_5ImageCalibration"][chn_nb]["Cal_Offset"]))

                if calibrate:
                    #Calibrator = _Calibrator(hdr, chn_name)
                    bits_per_pixel = 10   ### !!! I have no idea if this is correct !!!
                    Calibrator = _Calibrator(hdr, chn_name, bits_per_pixel) ## changed call in mipp/xrit/MSG.py
                    data, calibration_unit = Calibrator (data_tmp, calibrate=1)
                else:
                    data = data_tmp
                    calibration_unit = "counts"

                LOG.info(chn_name+ " min/max: "+str(data.min())+","+str(data.max())+" "+calibration_unit )

                satscene[chn_name] = ma.asarray(data)

                satscene[chn_name].info['units'] = calibration_unit
                satscene[chn_name].info['satname'] = satscene.satname
                satscene[chn_name].info['satnumber'] = satscene.number
                satscene[chn_name].info['instrument_name'] = satscene.instrument_name
                satscene[chn_name].info['time'] = satscene.time_slot
                satscene[chn_name].info['is_calibrated'] = True

            else: 
                print "*** Warning, no data for channel "+ chn_name+ " in file "+ filename
                data = np_nan
                calibration_unit = ""
                LOG.info("*** Warning, no data for channel "+ chn_name+" in file "+filename)

Example 49

Project: mpop
Source File: odyssey_radar.py
View license
def load(satscene, *args, **kwargs):
   """Loads the *channels* into the satellite *scene*.
   """
   #
   # Dataset information
   #
   # Read config file content
   conf = ConfigParser()
   conf.read(os.path.join(CONFIG_PATH, satscene.fullname + ".cfg"))

   values = {"orbit": satscene.orbit,
          "satname": satscene.satname,
          "number": satscene.number,
          "instrument": satscene.instrument_name,
          "satellite": satscene.fullname
          }

   # projection info
   projectionName = conf.get("radar-level2", "projection")
   projection = pyresample.utils.load_area(os.path.join(CONFIG_PATH, "areas.def"), projectionName)
   satscene.area = projection
   
   for chn_name in satscene.channels_to_load:
      filename = os.path.join(
         satscene.time_slot.strftime(conf.get("radar-level2", "dir", raw=True)) % values,
         satscene.time_slot.strftime(conf.get(chn_name,  "filename", raw=True)) % values )

      # Load data from the h5 file
      LOG.debug("filename: "+filename)
      filenames=glob.glob(str(filename))

      if len(filenames) == 0:
         LOG.debug("no input file found: "+filename)
         print "no input file found:"+filename
         quit()
      else:
         filename = glob.glob(str(filename))[0]
      
      # open the file
      hfile = h5py.File(filename, 'r')
      odim_object = hfile['what'].attrs['object']
      if odim_object != 'COMP':
         raise NotImplementedError('object: %s not implemented.' % (odim_object))
      else:
         # File structure
         
         #>>> hfile.keys()
         #[u'dataset1', u'dataset2', u'how', u'what', u'where']


         #>>> for f in hfile['what'].attrs.keys():
         #...  print "hfile['what'].attrs['",f,"']=", hfile['what'].attrs[f]
         #
         #hfile['what'].attrs[' object ']= COMP
         #hfile['what'].attrs[' version ']= H5rad 2.0
         #hfile['what'].attrs[' date ']= 20151201
         #hfile['what'].attrs[' time ']= 060000
         #hfile['what'].attrs[' source ']= ORG:247

         #>>> for f in hfile['where'].attrs.keys():
         #...  print "hfile['where'].attrs['",f,"']=", hfile['where'].attrs[f]
         #
         #hfile['where'].attrs[' projdef ']= +proj=laea +lat_0=55.0 +lon_0=10.0 +x_0=1950000.0 +y_0=-2100000.0 +units=m +ellps=WGS84
         #hfile['where'].attrs[' xsize ']= 1900
         #hfile['where'].attrs[' ysize ']= 2200
         #hfile['where'].attrs[' xscale ']= 2000.0
         #hfile['where'].attrs[' yscale ']= 2000.0
         #hfile['where'].attrs[' LL_lon ']= -10.4345768386
         #hfile['where'].attrs[' LL_lat ']= 31.7462153193
         #hfile['where'].attrs[' UL_lon ']= -39.5357864125
         #hfile['where'].attrs[' UL_lat ']= 67.0228327583
         #hfile['where'].attrs[' UR_lon ']= 57.8119647501
         #hfile['where'].attrs[' UR_lat ']= 67.6210371028
         #hfile['where'].attrs[' LR_lon ']= 29.4210386356
         #hfile['where'].attrs[' LR_lat ']= 31.9876502779

         # hfile['how'].attrs['nodes'] 
         # list of radar in composite

         #>>> for f in hfile['dataset1']['what'].attrs.keys():
         #...  print "hfile['dataset1'][what].attrs['",f,"']=", hfile['dataset1']['what'].attrs[f]
         #
         #hfile['dataset1'][what].attrs[' product ']= COMP
         #hfile['dataset1'][what].attrs[' startdate ']= 20151201
         #hfile['dataset1'][what].attrs[' starttime ']= 055000
         #hfile['dataset1'][what].attrs[' enddate ']= 20151201
         #hfile['dataset1'][what].attrs[' endtime ']= 060500
         #hfile['dataset1'][what].attrs[' quantity ']= RATE
         #hfile['dataset1'][what].attrs[' gain ']= 1.0
         #hfile['dataset1'][what].attrs[' offset ']= 0.0
         #hfile['dataset1'][what].attrs[' nodata ']= -9999000.0
         #hfile['dataset1'][what].attrs[' undetect ']= -8888000.0
         #>>> for f in hfile['dataset2']['what'].attrs.keys():
         #...  print "hfile['dataset2'][what].attrs['",f,"']=", hfile['dataset2']['what'].attrs[f]
         #
         #hfile['dataset2'][what].attrs[' product ']= COMP
         #hfile['dataset2'][what].attrs[' startdate ']= 20151201
         #hfile['dataset2'][what].attrs[' starttime ']= 055000
         #hfile['dataset2'][what].attrs[' enddate ']= 20151201
         #hfile['dataset2'][what].attrs[' endtime ']= 060500
         #hfile['dataset2'][what].attrs[' quantity ']= QIND
         #hfile['dataset2'][what].attrs[' gain ']= 1.0
         #hfile['dataset2'][what].attrs[' offset ']= 0.0
         #hfile['dataset2'][what].attrs[' nodata ']= -9999000.0
         #hfile['dataset2'][what].attrs[' undetect ']= -8888000.0

         _xsize = hfile['where'].attrs['xsize']
         _ysize = hfile['where'].attrs['ysize']
         #nbins= _xsize * _ysize

         #projection = hfile['where'].attrs['projdef']
         
         datasets = [k for k in hfile if k.startswith('dataset')]
         datasets.sort()
         nsweeps = len(datasets)
         
         try:
            ds1_what = hfile[datasets[0]]['what'].attrs
         except KeyError:
            # if no how group exists mock it with an empty dictionary
            ds1_what = {}
         
         _type = ''
         if 'product' in ds1_what:
            LOG.debug("product: "+ds1_what['product'])
            if ds1_what['product'] == 'COMP':
               if 'quantity' in ds1_what:
                  _type = ds1_what['quantity']
                  LOG.debug("product_type: "+_type)

                  #for chn_name in satscene.channels_to_load:
                  #   if chn_name == _type:

                  raw_data = hfile[datasets[0]]['data1']['data'][:]
                  raw_data = raw_data.reshape(_ysize,_xsize)
         
                  # flag no data
                  if 'nodata' in ds1_what:
                     nodata = ds1_what['nodata']
                     data = np.ma.masked_equal(raw_data, nodata)
                  else:
                     data = np.ma.masked_array(raw_data)
         
                  mask = np.ma.masked_array( raw_data == nodata )
                  mask = np.ma.masked_equal( mask, False)
            
                  # flag undetect data 
                  if 'undetect' in ds1_what:
                     undetect = ds1_what['undetect']
                     data[data == undetect] = np.ma.masked
                        
                  #from trollimage.image import Image as trollimage
                  #img = trollimage(mask, mode="L", fill_value=[1,1,1]) # [0,0,0] [1,1,1]
                  #from trollimage.colormap import rainbow
                  #img.colorize(rainbow)
                  #img.show()
                  #quit()

                  # gain/offset adjustment
                  if 'offset' in ds1_what:
                     offset = ds1_what['offset']
                  else:
                     offset = 0.0
                     
                  if 'gain' in ds1_what:
                     gain = ds1_what['gain']
                  else:
                     gain = 1.0

                  data *= gain + offset
                  
                  satscene[chn_name] = data
                  satscene[chn_name+'-MASK'] = mask

                  LOG.debug(" *** channel:"+chn_name)
                  
                  if _type == 'DBZH':
                     units = 'dBZ'
                  
                  if _type == 'RATE':
                     units = 'mm/h'

                  if _type == 'ACRR':
                     units = 'mm'
                     
                  satscene[chn_name].info["units"] = units
                  LOG.debug("channel:"+chn_name+" units:"+units)

Example 50

Project: r-bridge-install
Source File: install_package.py
View license
def install_package(overwrite=False, r_library_path=None):
    """Install ArcGIS R bindings (arcgisbinding) onto this machine.

    Args:
        overwrite: if (literally) True, allow reinstallation over an
            existing copy; any other value is treated as False.
        r_library_path: destination R package library. Defaults to the
            result of r_lib_path(), resolved at call time (the original
            evaluated it once at import time via a dynamic default).

    The binding is installed from a local ``arcgisbinding*.zip`` found
    next to this script when present, otherwise downloaded from the
    current release. Errors are reported through arcpy messaging and the
    function returns early on failure.
    """
    # Normalize to a strict boolean; only the literal True enables
    # overwrite (preserves the original `is True` semantics).
    overwrite = overwrite is True

    # Resolve the default library path lazily rather than at import time.
    if r_library_path is None:
        r_library_path = r_lib_path()

    (install_dir, arc_version, product) = arcgis_platform()
    arcmap_needs_link = False

    # check that we're in a sane installation environment
    validate_environment(overwrite)

    # detect if we have a 10.3.1 install that needs linking
    if product == 'Pro' and arcmap_exists("10.3"):
        arcmap_needs_link = True
        msg_base = "Pro side by side with 10.3 detected,"
        if arcmap_path() is not None:
            msg = "{} installing bridge for both environments.".format(msg_base)
            arcpy.AddMessage(msg)
        else:
            msg = "{} but unable to find install path.".format(msg_base) + \
                  "ArcGIS bridge must be manually installed in ArcGIS 10.3."
            arcpy.AddWarning(msg)

    # if we're going to install the bridge in 10.3.1, create the appropriate
    # directory before trying to install.
    if (arc_version == '10.3.1' and product == 'ArcMap') or arcmap_needs_link:
        r_integration_dir = os.path.join(arcmap_path(), "Rintegration")
        # TODO escalate privs here? test on non-admin user
        if not os.path.exists(r_integration_dir):
            try:
                # probe for write access before attempting the real mkdir,
                # so we can give an actionable "run as admin" message.
                write_test = os.path.join(install_dir, 'test.txt')
                with open(write_test, 'w') as f:
                    f.write('test')
                os.remove(write_test)
                os.makedirs(r_integration_dir)
            except IOError:
                arcpy.AddError(
                    "Insufficient privileges to create 10.3.1 bridge directory."
                    " Please start {} as an administrator, by right clicking"
                    " the icon, selecting \"Run as Administrator\", then run this"
                    " script again.".format(product))
                return

    # set an R-compatible temporary folder, if needed.
    orig_tmpdir = os.getenv("TMPDIR")
    if not orig_tmpdir:
        set_env_tmpdir()

    download_url = release_info()[0]
    if download_url is None:
        arcpy.AddWarning(
            "Unable to get current release information."
            " Trying offline installation.")

    local_install = False
    base_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), '..')
    zip_glob = glob.glob(os.path.join(base_path, "arcgisbinding*.zip"))
    # prefer a local copy of the binding when one sits next to this script
    if zip_glob and os.path.exists(zip_glob[0]):
        local_install = True
        zip_path = zip_glob[0]
        zip_name = os.path.basename(zip_path)
    elif not download_url:
        # no local zip and no release info -- nothing to install from.
        # (original also tested `not local_install` here, which is always
        # true at this point and was dead code)
        arcpy.AddError(
            "Unable to access online package, and no "
            "local copy of package found.")
        return
    else:
        zip_name = os.path.basename(download_url)

    # check for a network-based R installation; Rcmd can't be executed
    # reliably from a UNC path, so fall back to an in-session install.
    if r_path() and r_path()[0:2] == r'\\':
        arcpy.AddMessage(
            "R installed on a network path, using fallback installation method.")
        r_local_install = False
    else:
        r_local_install = True

    # we have a release, write it to disk for installation
    with mkdtemp() as temp_dir:
        package_path = os.path.join(temp_dir, zip_name)
        if local_install:
            arcpy.AddMessage("Found local copy of binding, installing from zip")
            shutil.copyfile(zip_path, package_path)
        else:
            save_url(download_url, package_path)
        if os.path.exists(package_path):
            # TODO -- need to do UAC escalation here?
            # call the R installation script
            rcmd_return = 0
            if r_local_install:
                rcmd_return = execute_r('Rcmd', 'INSTALL', package_path)
            if not r_local_install or rcmd_return != 0:
                # Can't execute Rcmd in this context, write out a temporary
                # script and run install.packages() from within an R session.
                install_script = os.path.join(temp_dir, 'install.R')
                with open(install_script, 'w') as f:
                    f.write("install.packages(\"{}\", repos=NULL)".format(
                        package_path.replace("\\", "/")))
                rcmd_return = execute_r("Rscript", install_script)
                if rcmd_return != 0:
                    arcpy.AddWarning("Fallback installation method failed.")
        else:
            arcpy.AddError("No package found at {}".format(package_path))
            return

    # return TMPDIR to its original value; only need it for Rcmd INSTALL
    set_env_tmpdir(orig_tmpdir)

    # at 10.4 and Pro <=1.2, if the user has installed a version with a non-
    # numeric patch level (e.g. 3.2.4revised), and the bridge is installed
    # into Program Files, the link will fail. In this case, set the
    # appropriate registry key so that the bridge will still work. Note that
    # this isn't ideal, because it will persist after updates, but it is
    # better than the bridge failing to work at all.
    if (arc_version == '10.4' and product == 'Desktop') or \
            (arc_version in ('1.1', '1.1.1', '1.2')
             and product == 'Pro'):

        if r_version():
            (r_major, r_minor, r_patchlevel) = r_version().split(".")
            # if we have a patchlevel like '4revised' or '3alpha', and
            # the global library path is used, then use the registry key.
            if len(r_patchlevel) > 1 and 'Program Files' in r_library_path:
                # create_registry_entry(product, arc_version)
                msg = ("Currently, the bridge doesn't support patched releases"
                       " (e.g. 3.2.4 Revised) in a global install. Please use"
                       " another version of R.")
                arcpy.AddError(msg)
                return

    # at 10.3.1, we _must_ have the bridge installed at the correct location.
    # create a symlink that connects back to the correct location on disk.
    if (arc_version == '10.3.1' and product == 'ArcMap') or arcmap_needs_link:
        link_dir = os.path.join(r_integration_dir, PACKAGE_NAME)

        if os.path.exists(link_dir):
            if junctions_supported(link_dir) or hardlinks_supported(link_dir):
                # os.rmdir uses RemoveDirectoryW, and can delete a junction
                os.rmdir(link_dir)
            else:
                shutil.rmtree(link_dir)

        # set up the link
        r_package_path = r_pkg_path()

        if r_package_path:
            arcpy.AddMessage("R package path: {}.".format(r_package_path))
        else:
            arcpy.AddError("Unable to locate R package library. Link failed.")
            return

        detect_msg = "ArcGIS 10.3.1 detected."
        if junctions_supported(link_dir) or hardlinks_supported(link_dir):
            arcpy.AddMessage("{} Creating link to package.".format(detect_msg))
            kdll.CreateSymbolicLinkW(link_dir, r_package_path, 1)
        else:
            # working on a non-NTFS volume, copy instead
            vol_info = getvolumeinfo(link_dir)
            arcpy.AddMessage("{} Drive type: {}. Copying package files.".format(
                detect_msg, vol_info[0]))
            # NOTE: this will need to be resynced when the package is updated,
            #       if installed from the R side.
            shutil.copytree(r_package_path, link_dir)