copy.deepcopy

Here are examples of the Python API copy.deepcopy taken from open source projects. By voting up you can indicate which examples are most useful and appropriate.
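
Before looking at the project code, a minimal sketch (not taken from any of the projects below) of what copy.deepcopy actually does: it recursively clones nested mutable objects, so mutating the copy never affects the original, whereas copy.copy duplicates only the top level and shares everything nested.

    import copy

    config = {"zones": {"public": ["ssh", "http"]}}

    shallow = copy.copy(config)       # top-level dict copied, nested list still shared
    deep = copy.deepcopy(config)      # nested dict and list recursively cloned

    shallow["zones"]["public"].append("https")
    deep["zones"]["public"].append("ftp")

    print(config["zones"]["public"])  # ['ssh', 'http', 'https'] -- the shallow copy shared the list
    print(deep["zones"]["public"])    # ['ssh', 'http', 'ftp']   -- the deep copy is independent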

200 Examples

Example 1

Project: AZOrange
Source File: getUnbiasedAccuracy.py
    def getAcc(self, callBack = None, callBackWithFoldModel = None):
        """ For regression problems, it returns the RMSE and the Q2 
            For Classification problems, it returns CA and the ConfMat
            The return is made in a Dict: {"RMSE":0.2,"Q2":0.1,"CA":0.98,"CM":[[TP, FP],[FN,TN]]}
            For the EvalResults not supported for a specific learner/dataset, the respective result will be None

            if the learner is a dict {"LearnerName":learner, ...} the results will be a dict with results for all Learners and for a consensus
                made out of those that were stable

            If some error occurred, the respective values in the Dict will be None
        """
        self.__log("Starting Calculating MLStatistics")
        statistics = {}
        if not self.__areInputsOK():
            return None
        # Set the response type
        self.responseType =  self.data.domain.classVar.varType == orange.VarTypes.Discrete and "Classification"  or "Regression"
        self.__log("  "+str(self.responseType))

        #Create the Train and test sets
        if self.usePreDefFolds:
            DataIdxs = self.preDefIndices 
        else:
            DataIdxs = self.sampler(self.data, self.nExtFolds) 
        foldsN = [f for f in dict.fromkeys(DataIdxs) if f != 0] #Folds used only from 1 on ... 0 are for fixed train Bias
        nFolds = len(foldsN)
        #Fix the Indexes based on DataIdxs
        # (0s) represents the train set  ( >= 1s) represents the test set folds
        if self.useVarCtrlCV:
            nShifted = [0] * nFolds
            for idx,isTest in enumerate(self.preDefIndices):  # self.preDefIndices == 0 are to be used in TrainBias
                if not isTest:
                    if DataIdxs[idx]:
                        nShifted[DataIdxs[idx]] += 1
                        DataIdxs[idx] = 0
            for idx,shift in enumerate(nShifted):
                self.__log("In fold "+str(idx)+", "+str(shift)+" examples were shifted to the train set.")

        #Var for saving each fold's result
        optAcc = {}
        results = {}
        exp_pred = {}
        nTrainEx = {}
        nTestEx = {}
        
        #Set a dict of learners
        MLmethods = {}
        if type(self.learner) == dict:
            for ml in self.learner:
                MLmethods[ml] = self.learner[ml]
        else:
            MLmethods[self.learner.name] = self.learner

        models={}
        self.__log("Calculating Statistics for MLmethods:")
        self.__log("  "+str([x for x in MLmethods]))

        #Check data in advance so that, by chance, it will not fail at the last fold!
        for foldN in foldsN:
            trainData = self.data.select(DataIdxs,foldN,negate=1)
            self.__checkTrainData(trainData)

        #Optional!!
        # Order Learners so that PLS is the first
        sortedML = [ml for ml in MLmethods]
        if "PLS" in sortedML:
            sortedML.remove("PLS")
            sortedML.insert(0,"PLS")

        stepsDone = 0
        nTotalSteps = len(sortedML) * self.nExtFolds  
        for ml in sortedML:
          startTime = time.time()
          self.__log("    > "+str(ml)+"...")
          try:
            #Var for saving each fold's result
            results[ml] = []
            exp_pred[ml] = []
            models[ml] = []
            nTrainEx[ml] = []
            nTestEx[ml] = []
            optAcc[ml] = []
            logTxt = ""
            for foldN in foldsN:
                if type(self.learner) == dict:
                    self.paramList = None

                trainData = self.data.select(DataIdxs,foldN,negate=1)
                testData = self.data.select(DataIdxs,foldN)
                smilesAttr = dataUtilities.getSMILESAttr(trainData)
                if smilesAttr:
                    self.__log("Found SMILES attribute:"+smilesAttr)
                    if MLmethods[ml].specialType == 1:
                       trainData = dataUtilities.attributeSelectionData(trainData, [smilesAttr, trainData.domain.classVar.name]) 
                       testData = dataUtilities.attributeSelectionData(testData, [smilesAttr, testData.domain.classVar.name]) 
                       self.__log("Selected attrs: "+str([attr.name for attr in trainData.domain]))
                    else:
                       trainData = dataUtilities.attributeDeselectionData(trainData, [smilesAttr]) 
                       testData = dataUtilities.attributeDeselectionData(testData, [smilesAttr]) 
                       self.__log("Selected attrs: "+str([attr.name for attr in trainData.domain[0:3]] + ["..."] + [attr.name for attr in trainData.domain[len(trainData.domain)-3:]]))

                nTrainEx[ml].append(len(trainData))
                nTestEx[ml].append(len(testData))
                #Test if trainsets inside optimizer will respect dataSize criteria.
                #  if not, don't optimize, but still train the model
                dontOptimize = False
                if self.responseType != "Classification" and (len(trainData)*(1-1.0/self.nInnerFolds) < 20):
                    dontOptimize = True
                else:                      
                    tmpDataIdxs = self.sampler(trainData, self.nInnerFolds)
                    tmpTrainData = trainData.select(tmpDataIdxs,1,negate=1)
                    if not self.__checkTrainData(tmpTrainData, False):
                        dontOptimize = True

                SpecialModel = None
                if dontOptimize:
                    logTxt += "       Fold "+str(foldN)+": Too few compounds to optimize model hyper-parameters\n"
                    self.__log(logTxt)
                    if trainData.domain.classVar.varType == orange.VarTypes.Discrete:
                        res = evalUtilities.crossValidation([MLmethods[ml]], trainData, folds=5, stratified=orange.MakeRandomIndices.StratifiedIfPossible, random_generator = random.randint(0, 100))
                        CA = evalUtilities.CA(res)[0]
                        optAcc[ml].append(CA)
                    else:
                        res = evalUtilities.crossValidation([MLmethods[ml]], trainData, folds=5, stratified=orange.MakeRandomIndices.StratifiedIfPossible, random_generator = random.randint(0, 100))
                        R2 = evalUtilities.R2(res)[0]
                        optAcc[ml].append(R2)
                else:
                    if MLmethods[ml].specialType == 1: 
                            if trainData.domain.classVar.varType == orange.VarTypes.Discrete:
                                    optInfo, SpecialModel = MLmethods[ml].optimizePars(trainData, folds = 5)
                                    optAcc[ml].append(optInfo["Acc"])
                            else:
                                    res = evalUtilities.crossValidation([MLmethods[ml]], trainData, folds=5, stratified=orange.MakeRandomIndices.StratifiedIfPossible, random_generator = random.randint(0, 100))
                                    R2 = evalUtilities.R2(res)[0]
                                    optAcc[ml].append(R2)
                    else:
                            runPath = miscUtilities.createScratchDir(baseDir = AZOC.NFS_SCRATCHDIR, desc = "AccWOptParam", seed = id(trainData))
                            trainData.save(os.path.join(runPath,"trainData.tab"))
                            tunedPars = paramOptUtilities.getOptParam(
                                learner = MLmethods[ml], 
                                trainDataFile = os.path.join(runPath,"trainData.tab"), 
                                paramList = self.paramList, 
                                useGrid = False, 
                                verbose = self.verbose, 
                                queueType = self.queueType, 
                                runPath = runPath, 
                                nExtFolds = None, 
                                nFolds = self.nInnerFolds,
                                logFile = self.logFile,
                                getTunedPars = True,
                                fixedParams = self.fixedParams)
                            if not MLmethods[ml] or not MLmethods[ml].optimized:
                                self.__log("       WARNING: GETACCWOPTPARAM: The learner "+str(ml)+" was not optimized.")
                                self.__log("                It will be ignored")
                                #self.__log("                It will be set to default parameters")
                                self.__log("                    DEBUG can be done in: "+runPath)
                                #Set learner back to default 
                                #MLmethods[ml] = MLmethods[ml].__class__()
                                raise Exception("The learner "+str(ml)+" was not optimized.")
                            else:
                                if trainData.domain.classVar.varType == orange.VarTypes.Discrete:
                                    optAcc[ml].append(tunedPars[0])
                                else:
                                    res = evalUtilities.crossValidation([MLmethods[ml]], trainData, folds=5, stratified=orange.MakeRandomIndices.StratifiedIfPossible, random_generator = random.randint(0, 100))
                                    R2 = evalUtilities.R2(res)[0]
                                    optAcc[ml].append(R2)

                                miscUtilities.removeDir(runPath) 
                #Train the model
                if SpecialModel is not None:
                    model = SpecialModel 
                else:
                    model = MLmethods[ml](trainData)
                models[ml].append(model)
                #Test the model
                if self.responseType == "Classification":
                    results[ml].append((evalUtilities.getClassificationAccuracy(testData, model), evalUtilities.getConfMat(testData, model) ) )
                else:
                    local_exp_pred = []
                    # Predict using bulk-predict
                    predictions = model(testData)
                    # Gather predictions
                    for n,ex in enumerate(testData):
                        local_exp_pred.append((ex.getclass().value, predictions[n].value))
                    results[ml].append((evalUtilities.calcRMSE(local_exp_pred), evalUtilities.calcRsqrt(local_exp_pred) ) )
                    #Save the experimental value and corresponding predicted value
                    exp_pred[ml] += local_exp_pred
                if callBack:
                     stepsDone += 1
                     if not callBack((100*stepsDone)/nTotalSteps): return None
                if callBackWithFoldModel:
                    callBackWithFoldModel(model) 

            res = self.createStatObj(results[ml], exp_pred[ml], nTrainEx[ml], nTestEx[ml],self.responseType, self.nExtFolds, logTxt, labels = hasattr(self.data.domain.classVar,"values") and list(self.data.domain.classVar.values) or None )
            if self.verbose > 0: 
                print "UnbiasedAccuracyGetter!Results  "+ml+":\n"
                pprint(res)
            if not res:
                raise Exception("No results available!")
            res["runningTime"] = time.time() - startTime
            statistics[ml] = copy.deepcopy(res)
            self.__writeResults(statistics)
            self.__log("       OK")
          except:
            self.__log("       Learner "+str(ml)+" failed to create/optimize the model!")
            error = str(sys.exc_info()[0]) +" "+\
                        str(sys.exc_info()[1]) +" "+\
                        str(traceback.extract_tb(sys.exc_info()[2]))
            self.__log(error)
 
            res = self.createStatObj()
            statistics[ml] = copy.deepcopy(res)
            self.__writeResults(statistics)

        if not statistics or len(statistics) < 1:
            self.__log("ERROR: No statistics to return!")
            return None
        elif len(statistics) > 1:
            #We still need to build a consensus model out of the stable models 
            #   ONLY if there is more than one stable model!
            #   When only one or no stable models, build a consensus based on all models
            # ALWAYS exclude specialType models (MLmethods[ml].specialType > 0)
            consensusMLs={}
            for modelName in statistics:
                StabilityValue = statistics[modelName]["StabilityValue"]
                if StabilityValue is not None and statistics[modelName]["stable"]:
                    consensusMLs[modelName] = copy.deepcopy(statistics[modelName])

            self.__log("Found "+str(len(consensusMLs))+" stable MLmethods out of "+str(len(statistics))+" MLmethods.")

            if len(consensusMLs) <= 1:   # we need more models to build a consensus!
                consensusMLs={}
                for modelName in statistics:
                    consensusMLs[modelName] = copy.deepcopy(statistics[modelName])

            # Exclude specialType models 
            excludeThis = []
            for learnerName in consensusMLs:
                if models[learnerName][0].specialType > 0:
                    excludeThis.append(learnerName)
            for learnerName in excludeThis:
                consensusMLs.pop(learnerName)
                self.__log("    > Excluded special model " + learnerName)
            self.__log("    > Stable models: " + str(consensusMLs.keys()))

            if len(consensusMLs) >= 2:
                #Var for saving each fold's result
                startTime = time.time()
                Cresults = []
                Cexp_pred = []
                CnTrainEx = []
                CnTestEx = []
                self.__log("Calculating the statistics for a Consensus model based on "+str([ml for ml in consensusMLs]))
                for foldN in range(self.nExtFolds):
                    if self.responseType == "Classification":
                        CLASS0 = str(self.data.domain.classVar.values[0])
                        CLASS1 = str(self.data.domain.classVar.values[1])
                        # exprTest0
                        exprTest0 = "(0"
                        for ml in consensusMLs:
                            exprTest0 += "+( "+ml+" == "+CLASS0+" )*"+str(optAcc[ml][foldN])+" "
                        exprTest0 += ")/IF0(sum([False"
                        for ml in consensusMLs:
                            exprTest0 += ", "+ml+" == "+CLASS0+" "
                        exprTest0 += "]),1)"
                        # exprTest1
                        exprTest1 = "(0"
                        for ml in consensusMLs:
                            exprTest1 += "+( "+ml+" == "+CLASS1+" )*"+str(optAcc[ml][foldN])+" "
                        exprTest1 += ")/IF0(sum([False"
                        for ml in consensusMLs:
                            exprTest1 += ", "+ml+" == "+CLASS1+" "
                        exprTest1 += "]),1)"
                        # Expression
                        expression = [exprTest0+" >= "+exprTest1+" -> "+CLASS0," -> "+CLASS1]
                    else:
                        Q2sum = sum([optAcc[ml][foldN] for ml in consensusMLs])
                        expression = "(1 / "+str(Q2sum)+") * (0"
                        for ml in consensusMLs:
                            expression += " + "+str(optAcc[ml][foldN])+" * "+ml+" "
                        expression += ")"

                    testData = self.data.select(DataIdxs,foldN+1)  # fold 0 is for the train Bias!!
                    smilesAttr = dataUtilities.getSMILESAttr(testData)
                    if smilesAttr:
                        self.__log("Found SMILES attribute:"+smilesAttr)
                        testData = dataUtilities.attributeDeselectionData(testData, [smilesAttr])
                        self.__log("Selected attrs: "+str([attr.name for attr in trainData.domain[0:3]] + ["..."] + [attr.name for attr in trainData.domain[len(trainData.domain)-3:]]))

                    CnTestEx.append(len(testData))
                    consensusClassifiers = {}
                    for learnerName in consensusMLs:
                        consensusClassifiers[learnerName] = models[learnerName][foldN]

                    model = AZorngConsensus.ConsensusClassifier(classifiers = consensusClassifiers, expression = expression)     
                    CnTrainEx.append(model.NTrainEx)
                    #Test the model
                    if self.responseType == "Classification":
                        Cresults.append((evalUtilities.getClassificationAccuracy(testData, model), evalUtilities.getConfMat(testData, model) ) )
                    else:
                        local_exp_pred = []
                        # Predict using bulk-predict
                        predictions = model(testData)
                        # Gather predictions
                        for n,ex in enumerate(testData):
                            local_exp_pred.append((ex.getclass().value, predictions[n].value))
                        Cresults.append((evalUtilities.calcRMSE(local_exp_pred), evalUtilities.calcRsqrt(local_exp_pred) ) )
                        #Save the experimental value and corresponding predicted value
                        Cexp_pred += local_exp_pred

                res = self.createStatObj(Cresults, Cexp_pred, CnTrainEx, CnTestEx, self.responseType, self.nExtFolds, labels = hasattr(self.data.domain.classVar,"values") and list(self.data.domain.classVar.values) or None )
                res["runningTime"] = time.time() - startTime
                statistics["Consensus"] = copy.deepcopy(res)
                statistics["Consensus"]["IndividualStatistics"] = copy.deepcopy(consensusMLs)
                self.__writeResults(statistics)
            self.__log("Returned multiple ML methods statistics.")
            return statistics
                 
        #By default return the only existing statistics!
        self.__writeResults(statistics)
        self.__log("Returned only one ML method statistics.")
        return statistics[statistics.keys()[0]]
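
In the example above, copy.deepcopy snapshots each per-learner result dict before it is stored in statistics (and again when stable results are collected into consensusMLs), so the stored entries do not share mutable sub-objects with the working result. A minimal sketch of that snapshot idea, with hypothetical names:

    import copy

    def evaluate(learner_name):
        # stand-in for the per-fold evaluation; returns a mutable result object
        return {"learner": learner_name, "fold_scores": [0.91, 0.88]}

    statistics = {}
    res = evaluate("PLS")
    statistics["PLS"] = copy.deepcopy(res)   # snapshot, independent of the working object

    res["fold_scores"].append(0.50)          # later mutation of the working object
    print(statistics["PLS"]["fold_scores"])  # [0.91, 0.88] -- the stored snapshot is unchanged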

Example 2

Project: firewalld
Source File: fw.py
    def _start(self, reload=False, complete_reload=False):
        # initialize firewall
        default_zone = config.FALLBACK_ZONE

        # load firewalld config
        log.debug1("Loading firewalld config file '%s'", config.FIREWALLD_CONF)
        try:
            self._firewalld_conf.read()
        except Exception as msg:
            log.warning(msg)
            log.warning("Using fallback firewalld configuration settings.")
        else:
            if self._firewalld_conf.get("DefaultZone"):
                default_zone = self._firewalld_conf.get("DefaultZone")

            if self._firewalld_conf.get("MinimalMark"):
                self._min_mark = int(self._firewalld_conf.get("MinimalMark"))

            if self._firewalld_conf.get("CleanupOnExit"):
                value = self._firewalld_conf.get("CleanupOnExit")
                if value is not None and value.lower() in [ "no", "false" ]:
                    self.cleanup_on_exit = False

            if self._firewalld_conf.get("Lockdown"):
                value = self._firewalld_conf.get("Lockdown")
                if value is not None and value.lower() in [ "yes", "true" ]:
                    log.debug1("Lockdown is enabled")
                    try:
                        self.policies.enable_lockdown()
                    except FirewallError:
                        # already enabled, this is probably reload
                        pass

            if self._firewalld_conf.get("IPv6_rpfilter"):
                value = self._firewalld_conf.get("IPv6_rpfilter")
                if value is not None:
                    if value.lower() in [ "no", "false" ]:
                        self.ipv6_rpfilter_enabled = False
                    if value.lower() in [ "yes", "true" ]:
                        self.ipv6_rpfilter_enabled = True
            if self.ipv6_rpfilter_enabled:
                log.debug1("IPv6 rpfilter is enabled")
            else:
                log.debug1("IPV6 rpfilter is disabled")

            if self._firewalld_conf.get("IndividualCalls"):
                value = self._firewalld_conf.get("IndividualCalls")
                if value is not None and value.lower() in [ "yes", "true" ]:
                    log.debug1("IndividualCalls is enabled")
                    self._individual_calls = True

            if self._firewalld_conf.get("LogDenied"):
                value = self._firewalld_conf.get("LogDenied")
                if value is None or value.lower() == "no":
                    self._log_denied = "off"
                else:
                    self._log_denied = value.lower()
                    log.debug1("LogDenied is set to '%s'", self._log_denied)

            if self._firewalld_conf.get("AutomaticHelpers"):
                value = self._firewalld_conf.get("AutomaticHelpers")
                if value is not None:
                    if value.lower() in [ "no", "false" ]:
                        self._automatic_helpers = "no"
                    elif value.lower() in [ "yes", "true" ]:
                        self._automatic_helpers = "yes"
                    else:
                        self._automatic_helpers = value.lower()
                log.debug1("AutomaticHelpers is set to '%s'",
                           self._automatic_helpers)

        self.config.set_firewalld_conf(copy.deepcopy(self._firewalld_conf))

        self._start_check()

        # load lockdown whitelist
        log.debug1("Loading lockdown whitelist")
        try:
            self.policies.lockdown_whitelist.read()
        except Exception as msg:
            if self.policies.query_lockdown():
                log.error("Failed to load lockdown whitelist '%s': %s",
                          self.policies.lockdown_whitelist.filename, msg)
            else:
                log.debug1("Failed to load lockdown whitelist '%s': %s",
                           self.policies.lockdown_whitelist.filename, msg)

        # copy policies to config interface
        self.config.set_policies(copy.deepcopy(self.policies))

        # load ipset files
        self._loader(config.FIREWALLD_IPSETS, "ipset")
        self._loader(config.ETC_FIREWALLD_IPSETS, "ipset")

        # load icmptype files
        self._loader(config.FIREWALLD_ICMPTYPES, "icmptype")
        self._loader(config.ETC_FIREWALLD_ICMPTYPES, "icmptype")

        if len(self.icmptype.get_icmptypes()) == 0:
            log.error("No icmptypes found.")

        # load helper files
        self._loader(config.FIREWALLD_HELPERS, "helper")
        self._loader(config.ETC_FIREWALLD_HELPERS, "helper")

        # load service files
        self._loader(config.FIREWALLD_SERVICES, "service")
        self._loader(config.ETC_FIREWALLD_SERVICES, "service")

        if len(self.service.get_services()) == 0:
            log.error("No services found.")

        # load zone files
        self._loader(config.FIREWALLD_ZONES, "zone")
        self._loader(config.ETC_FIREWALLD_ZONES, "zone")

        if len(self.zone.get_zones()) == 0:
            log.fatal("No zones found.")
            sys.exit(1)

        # check minimum required zones
        error = False
        for z in [ "block", "drop", "trusted" ]:
            if z not in self.zone.get_zones():
                log.fatal("Zone '%s' is not available.", z)
                error = True
        if error:
            sys.exit(1)

        # check if default_zone is a valid zone
        if default_zone not in self.zone.get_zones():
            if "public" in self.zone.get_zones():
                zone = "public"
            elif "external" in self.zone.get_zones():
                zone = "external"
            else:
                zone = "block" # block is a base zone, therefore it has to exist

            log.error("Default zone '%s' is not valid. Using '%s'.",
                      default_zone, zone)
            default_zone = zone
        else:
            log.debug1("Using default zone '%s'", default_zone)

        # load direct rules
        obj = Direct(config.FIREWALLD_DIRECT)
        if os.path.exists(config.FIREWALLD_DIRECT):
            log.debug1("Loading direct rules file '%s'" % \
                       config.FIREWALLD_DIRECT)
            try:
                obj.read()
            except Exception as msg:
                log.debug1("Failed to load direct rules file '%s': %s",
                           config.FIREWALLD_DIRECT, msg)
        self.direct.set_permanent_config(obj)
        self.config.set_direct(copy.deepcopy(obj))

        # automatic helpers
        if self._automatic_helpers != "system":
            functions.set_nf_conntrack_helper_setting(self._automatic_helpers == "yes")
        self.nf_conntrack_helper_setting = \
            functions.get_nf_conntrack_helper_setting()

        # check if needed tables are there
        self._check_tables()

        if log.getDebugLogLevel() > 0:
            # get time before flushing and applying
            tm1 = time.time()

        # Start transaction
        transaction = FirewallTransaction(self)

        if reload:
            self.set_policy("DROP", use_transaction=transaction)

        # flush rules
        self.flush(use_transaction=transaction)

        # If modules need to be unloaded in complete reload or if there are
        # ipsets to get applied, limit the transaction to set_policy and flush.
        #
        # Future optimization for the ipset case in reload: The transaction
        # only needs to be split here if there are conflicting ipset types in
        # existing ipsets and the configuration in firewalld.
        if (reload and complete_reload) or \
           (self.ipset_enabled and self.ipset.has_ipsets()):
            transaction.execute(True)
            transaction.clear()

        # complete reload: unload modules also
        if reload and complete_reload:
            log.debug1("Unloading firewall modules")
            self.modules_backend.unload_firewall_modules()

        # apply settings for loaded ipsets while reloading here
        if self.ipset_enabled and self.ipset.has_ipsets():
            log.debug1("Applying ipsets")
            self.ipset.apply_ipsets()

        # Start or continue with transaction

        # apply default rules
        log.debug1("Applying default rule set")
        self.apply_default_rules(use_transaction=transaction)

        # apply settings for loaded zones
        log.debug1("Applying used zones")
        self.zone.apply_zones(use_transaction=transaction)

        self._default_zone = self.check_zone(default_zone)
        self.zone.change_default_zone(None, self._default_zone,
                                      use_transaction=transaction)

        # Execute transaction
        transaction.execute(True)

        # Start new transaction for direct rules
        transaction.clear()

        # apply direct chains, rules and passthrough rules
        if self.direct.has_configuration():
            transaction.enable_generous_mode()
            log.debug1("Applying direct chains rules and passthrough rules")
            self.direct.apply_direct(transaction)

            # Execute transaction
            transaction.execute(True)
            transaction.disable_generous_mode()
            transaction.clear()

        del transaction

        if log.getDebugLogLevel() > 1:
            # get time after flushing and applying
            tm2 = time.time()
            log.debug2("Flushing and applying took %f seconds" % (tm2 - tm1))

        self._state = "RUNNING"
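
In this example copy.deepcopy hands copies of the runtime objects (the parsed firewalld.conf, the lockdown policies, the direct rules) to the config interface, so edits made through the configuration layer cannot silently mutate the live firewall state. A minimal sketch of that isolation pattern, using hypothetical stand-in classes:

    import copy

    class RuntimeState(object):
        def __init__(self):
            self.zones = {"public": ["ssh"]}

    class ConfigInterface(object):
        def set_state(self, state):
            # the interface keeps its own copy of whatever it is given
            self.snapshot = state

    runtime = RuntimeState()
    config = ConfigInterface()
    config.set_state(copy.deepcopy(runtime))    # config layer gets an independent copy

    config.snapshot.zones["public"].append("http")
    print(runtime.zones["public"])              # ['ssh'] -- the live state is untouched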

Example 3

Project: Cura
Source File: CuraApplication.py
    def __init__(self):
        Resources.addSearchPath(os.path.join(QtApplication.getInstallPrefix(), "share", "cura", "resources"))
        if not hasattr(sys, "frozen"):
            Resources.addSearchPath(os.path.join(os.path.abspath(os.path.dirname(__file__)), "..", "resources"))

        self._open_file_queue = []  # Files to open when plug-ins are loaded.

        # Need to do this before ContainerRegistry tries to load the machines
        SettingDefinition.addSupportedProperty("settable_per_mesh", DefinitionPropertyType.Any, default = True, read_only = True)
        SettingDefinition.addSupportedProperty("settable_per_extruder", DefinitionPropertyType.Any, default = True, read_only = True)
        # this setting can be changed for each group in one-at-a-time mode
        SettingDefinition.addSupportedProperty("settable_per_meshgroup", DefinitionPropertyType.Any, default = True, read_only = True)
        SettingDefinition.addSupportedProperty("settable_globally", DefinitionPropertyType.Any, default = True, read_only = True)

        # From which stack the setting would inherit if not defined per object (handled in the engine)
        # AND for settings which are not settable_per_mesh:
        # which extruder is the only extruder this setting is obtained from
        SettingDefinition.addSupportedProperty("limit_to_extruder", DefinitionPropertyType.Function, default = "-1")

        # For settings which are not settable_per_mesh and not settable_per_extruder:
        # A function which determines the global/meshgroup value by looking at the values of the setting in all (used) extruders
        SettingDefinition.addSupportedProperty("resolve", DefinitionPropertyType.Function, default = None)

        SettingDefinition.addSettingType("extruder", None, str, Validator)

        SettingFunction.registerOperator("extruderValues", cura.Settings.ExtruderManager.getExtruderValues)
        SettingFunction.registerOperator("extruderValue", cura.Settings.ExtruderManager.getExtruderValue)
        SettingFunction.registerOperator("resolveOrValue", cura.Settings.ExtruderManager.getResolveOrValue)

        ## Add the 4 types of profiles to storage.
        Resources.addStorageType(self.ResourceTypes.QualityInstanceContainer, "quality")
        Resources.addStorageType(self.ResourceTypes.VariantInstanceContainer, "variants")
        Resources.addStorageType(self.ResourceTypes.MaterialInstanceContainer, "materials")
        Resources.addStorageType(self.ResourceTypes.UserInstanceContainer, "user")
        Resources.addStorageType(self.ResourceTypes.ExtruderStack, "extruders")
        Resources.addStorageType(self.ResourceTypes.MachineStack, "machine_instances")

        ContainerRegistry.getInstance().addResourceType(self.ResourceTypes.QualityInstanceContainer)
        ContainerRegistry.getInstance().addResourceType(self.ResourceTypes.VariantInstanceContainer)
        ContainerRegistry.getInstance().addResourceType(self.ResourceTypes.MaterialInstanceContainer)
        ContainerRegistry.getInstance().addResourceType(self.ResourceTypes.UserInstanceContainer)
        ContainerRegistry.getInstance().addResourceType(self.ResourceTypes.ExtruderStack)
        ContainerRegistry.getInstance().addResourceType(self.ResourceTypes.MachineStack)

        ##  Initialise the version upgrade manager with Cura's storage paths.
        import UM.VersionUpgradeManager #Needs to be here to prevent circular dependencies.
        UM.VersionUpgradeManager.VersionUpgradeManager.getInstance().setCurrentVersions(
            {
                ("quality", UM.Settings.InstanceContainer.Version):    (self.ResourceTypes.QualityInstanceContainer, "application/x-uranium-instancecontainer"),
                ("machine_stack", UM.Settings.ContainerStack.Version): (self.ResourceTypes.MachineStack, "application/x-uranium-containerstack"),
                ("preferences", UM.Preferences.Version):               (Resources.Preferences, "application/x-uranium-preferences"),
                ("user", UM.Settings.InstanceContainer.Version):       (self.ResourceTypes.UserInstanceContainer, "application/x-uranium-instancecontainer")
            }
        )

        self._machine_action_manager = MachineActionManager.MachineActionManager()
        self._machine_manager = None    # This is initialized on demand.
        self._setting_inheritance_manager = None

        self._additional_components = {} # Components to add to certain areas in the interface

        super().__init__(name = "cura", version = CuraVersion, buildtype = CuraBuildType)

        self.setWindowIcon(QIcon(Resources.getPath(Resources.Images, "cura-icon.png")))

        self.setRequiredPlugins([
            "CuraEngineBackend",
            "MeshView",
            "LayerView",
            "STLReader",
            "SelectionTool",
            "CameraTool",
            "GCodeWriter",
            "LocalFileOutputDevice"
        ])
        self._physics = None
        self._volume = None
        self._output_devices = {}
        self._print_information = None
        self._previous_active_tool = None
        self._platform_activity = False
        self._scene_bounding_box = AxisAlignedBox.Null

        self._job_name = None
        self._center_after_select = False
        self._camera_animation = None
        self._cura_actions = None
        self._started = False

        self._message_box_callback = None
        self._message_box_callback_arguments = []

        self._i18n_catalog = i18nCatalog("cura")

        self.getController().getScene().sceneChanged.connect(self.updatePlatformActivity)
        self.getController().toolOperationStopped.connect(self._onToolOperationStopped)

        Resources.addType(self.ResourceTypes.QmlFiles, "qml")
        Resources.addType(self.ResourceTypes.Firmware, "firmware")

        self.showSplashMessage(self._i18n_catalog.i18nc("@info:progress", "Loading machines..."))

        # Add empty variant, material and quality containers.
        # Since they are empty, they should never be serialized and instead just programmatically created.
        # We need them to simplify the switching between materials.
        empty_container = ContainerRegistry.getInstance().getEmptyInstanceContainer()
        empty_variant_container = copy.deepcopy(empty_container)
        empty_variant_container._id = "empty_variant"
        empty_variant_container.addMetaDataEntry("type", "variant")
        ContainerRegistry.getInstance().addContainer(empty_variant_container)
        empty_material_container = copy.deepcopy(empty_container)
        empty_material_container._id = "empty_material"
        empty_material_container.addMetaDataEntry("type", "material")
        ContainerRegistry.getInstance().addContainer(empty_material_container)
        empty_quality_container = copy.deepcopy(empty_container)
        empty_quality_container._id = "empty_quality"
        empty_quality_container.setName("Not supported")
        empty_quality_container.addMetaDataEntry("quality_type", "normal")
        empty_quality_container.addMetaDataEntry("type", "quality")
        ContainerRegistry.getInstance().addContainer(empty_quality_container)
        empty_quality_changes_container = copy.deepcopy(empty_container)
        empty_quality_changes_container._id = "empty_quality_changes"
        empty_quality_changes_container.addMetaDataEntry("type", "quality_changes")
        ContainerRegistry.getInstance().addContainer(empty_quality_changes_container)

        # Set the filename to create if cura is writing in the config dir.
        self._config_lock_filename = os.path.join(Resources.getConfigStoragePath(), CONFIG_LOCK_FILENAME)
        self.waitConfigLockFile()
        ContainerRegistry.getInstance().load()

        Preferences.getInstance().addPreference("cura/active_mode", "simple")
        Preferences.getInstance().addPreference("cura/recent_files", "")
        Preferences.getInstance().addPreference("cura/categories_expanded", "")
        Preferences.getInstance().addPreference("cura/jobname_prefix", True)
        Preferences.getInstance().addPreference("view/center_on_select", True)
        Preferences.getInstance().addPreference("mesh/scale_to_fit", True)
        Preferences.getInstance().addPreference("mesh/scale_tiny_meshes", True)

        for key in [
            "dialog_load_path",  # dialog_save_path is in LocalFileOutputDevicePlugin
            "dialog_profile_path",
            "dialog_material_path"]:

            Preferences.getInstance().addPreference("local_file/%s" % key, os.path.expanduser("~/"))

        Preferences.getInstance().setDefault("local_file/last_used_type", "text/x-gcode")

        Preferences.getInstance().setDefault("general/visible_settings", """
            machine_settings
            resolution
                layer_height
            shell
                wall_thickness
                top_bottom_thickness
            infill
                infill_sparse_density
            material
                material_print_temperature
                material_bed_temperature
                material_diameter
                material_flow
                retraction_enable
            speed
                speed_print
                speed_travel
                acceleration_print
                acceleration_travel
                jerk_print
                jerk_travel
            travel
            cooling
                cool_fan_enabled
            support
                support_enable
                support_extruder_nr
                support_type
                support_interface_density
            platform_adhesion
                adhesion_type
                adhesion_extruder_nr
                brim_width
                raft_airgap
                layer_0_z_overlap
                raft_surface_layers
            dual
                prime_tower_enable
                prime_tower_size
                prime_tower_position_x
                prime_tower_position_y
            meshfix
            blackmagic
                print_sequence
                infill_mesh
            experimental
        """.replace("\n", ";").replace(" ", ""))

        JobQueue.getInstance().jobFinished.connect(self._onJobFinished)

        self.applicationShuttingDown.connect(self.saveSettings)
        self.engineCreatedSignal.connect(self._onEngineCreated)
        self._recent_files = []
        files = Preferences.getInstance().getValue("cura/recent_files").split(";")
        for f in files:
            if not os.path.isfile(f):
                continue

            self._recent_files.append(QUrl.fromLocalFile(f))
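
Here copy.deepcopy turns a single empty instance container into a template: each copy receives its own _id and metadata without the copies (or the template) sharing state. A small sketch of the template pattern, using a plain dict as a hypothetical stand-in for the container object:

    import copy

    template = {"id": "empty", "metadata": {}}

    variant = copy.deepcopy(template)
    variant["id"] = "empty_variant"
    variant["metadata"]["type"] = "variant"

    material = copy.deepcopy(template)
    material["id"] = "empty_material"
    material["metadata"]["type"] = "material"

    print(template["metadata"])   # {} -- the template is unaffected by the customized copies
    print(variant["metadata"])    # {'type': 'variant'}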

Example 4

    def minion(self,  storage=None, *args, **xargs):        
        self.app = self.Verum.app(self.parent.PluginFolder, None)
        # set storage
        if storage is None:
            storage = self.parent.storage
        self.app.set_interface(storage)

        # Check until stopped
        while not self.shutdown:
            # Check to see if it's the same day; if it is, sleep for a while, otherwise run the import
            delta = datetime.utcnow() - self.today
            if delta.days <= 0:
                time.sleep(SLEEP_TIME)
            else:
                logging.info("Starting daily {0} enrichment.".format(NAME))

                # Get the file
                r = requests.get(FEED)

                # split it out
                feed = r.text.split("\n")

                # Create list of IPs for cymru enrichment
                ips = set()

                for row in feed:
                    # Parse date
                    l = row.find("Feed generated at:")
                    if l > -1:
                        dt = row[l+18:].strip()
                        dt = dateutil.parser.parse(dt).strftime("%Y-%m-%dT%H:%M:%SZ")
                        continue
                    row = row.split(",")

                    # if it's a record, parse the record
                    if len(row) == 6:
                        try:
                            # split out sub values
                            # row[0] -> domain
                            row[1] = row[1].split("|")  # ip
                            row[2] = row[2].split("|")  # nameserver domain
                            row[3] = row[3].split("|")  # nameserver ip
                            row[4] = row[4][26:-22]  # malware
                            # row[5] -> source

                            # Validate data in row
                            ext = tldextract.extract(row[0])
                            if not ext.domain or not ext.suffix:
                                # domain is not legitimate
                                continue
                            l = list()
                            for ip in row[1]:
                                try:
                                    _ = ipaddress.ip_address(unicode(ip))
                                    l.append(ip)
                                except:
                                    pass
                            row[1] = copy.deepcopy(l)
                            l = list()
                            for domain in row[2]:
                                ext = tldextract.extract(domain)
                                if ext.domain and ext.suffix:
                                    l.append(domain)
                            row[2] = copy.deepcopy(l)
                            l = list()
                            for ip in row[3]:
                                try:
                                    _ = ipaddress.ip_address(unicode(ip))
                                    l.append(ip)
                                except:
                                    pass
                            row[3] = copy.deepcopy(l)

                            # add the ips to the set of ips
                            ips = ips.union(set(row[1])).union(set(row[3]))

                            g = nx.MultiDiGraph()

                            # Add indicator to graph
                            ## (Must account for the different types of indicators)
                            target_uri = "class=attribute&key={0}&value={1}".format('domain', row[0]) 
                            g.add_node(target_uri, {
                                'class': 'attribute',
                                'key': 'domain',
                                "value": row[0],
                                "start_time": dt,
                                "uri": target_uri
                            })


                            # Threat node
                            threat_uri = "class=attribute&key={0}&value={1}".format("malware", row[4]) 
                            g.add_node(threat_uri, {
                                'class': 'attribute',
                                'key': "malware",
                                "value": row[4],
                                "start_time": dt,
                                "uri": threat_uri
                            })

                            # Threat Edge
                            edge_attr = {
                                "relationship": "describedBy",
                                "origin": row[5],
                                "start_time": dt
                            }
                            source_hash = uuid.uuid3(uuid.NAMESPACE_URL, target_uri)
                            dest_hash = uuid.uuid3(uuid.NAMESPACE_URL, threat_uri)
                            edge_uri = "source={0}&destionation={1}".format(str(source_hash), str(dest_hash))
                            rel_chain = "relationship"
                            while rel_chain in edge_attr:
                                edge_uri = edge_uri + "&{0}={1}".format(rel_chain,edge_attr[rel_chain])
                                rel_chain = edge_attr[rel_chain]
                            if "origin" in edge_attr:
                                edge_uri += "&{0}={1}".format("origin", edge_attr["origin"])
                            edge_attr["uri"] = edge_uri
                            g.add_edge(target_uri, threat_uri, edge_uri, edge_attr)                        

                            # for each IP associated with the domain, connect it to the target
                            for ip in row[1]:
                                # Create IP node
                                target_ip_uri = "class=attribute&key={0}&value={1}".format("ip", ip) 
                                g.add_node(target_ip_uri, {
                                    'class': 'attribute',
                                    'key': "ip",
                                    "value": ip,
                                    "start_time": dt,
                                    "uri": target_ip_uri
                                })

                                # ip Edge
                                edge_attr = {
                                    "relationship": "describedBy",
                                    "origin": row[5],
                                    "start_time": dt,
                                }
                                source_hash = uuid.uuid3(uuid.NAMESPACE_URL, target_uri)
                                dest_hash = uuid.uuid3(uuid.NAMESPACE_URL, target_ip_uri)
                                edge_uri = "source={0}&destionation={1}".format(str(source_hash), str(dest_hash))
                                rel_chain = "relationship"
                                while rel_chain in edge_attr:
                                    edge_uri = edge_uri + "&{0}={1}".format(rel_chain,edge_attr[rel_chain])
                                    rel_chain = edge_attr[rel_chain]
                                if "origin" in edge_attr:
                                    edge_uri += "&{0}={1}".format("origin", edge_attr["origin"])
                                edge_attr["uri"] = edge_uri
                                g.add_edge(target_uri, target_ip_uri, edge_uri, edge_attr)


                            for nameserver in row[2]:
                                # Create nameserver node
                                ns_uri = "class=attribute&key={0}&value={1}".format("domain", nameserver) 
                                g.add_node(ns_uri, {
                                    'class': 'attribute',
                                    'key': "domain",
                                    "value": nameserver,
                                    "start_time": dt,
                                    "uri": ns_uri
                                })

                                # nameserver Edge
                                edge_attr = {
                                    "relationship": "describedBy",
                                    "origin": row[5],
                                    "start_time": dt,
                                    'describedBy': 'nameserver'
                                }
                                source_hash = uuid.uuid3(uuid.NAMESPACE_URL, target_uri)
                                dest_hash = uuid.uuid3(uuid.NAMESPACE_URL, target_ip_uri)
                                edge_uri = "source={0}&destionation={1}".format(str(source_hash), str(dest_hash))
                                rel_chain = "relationship"
                                while rel_chain in edge_attr:
                                    edge_uri = edge_uri + "&{0}={1}".format(rel_chain,edge_attr[rel_chain])
                                    rel_chain = edge_attr[rel_chain]
                                if "origin" in edge_attr:
                                    edge_uri += "&{0}={1}".format("origin", edge_attr["origin"])
                                edge_attr["uri"] = edge_uri
                                g.add_edge(target_uri, ns_uri, edge_uri, edge_attr)

                            # if the number of NS IPs is a multiple of the # of NS's, we'll assume each NS gets some of the ips
                            if len(row[2]) and len(row[3]) % len(row[2]) == 0:
                                for i in range(len(row[2])):
                                    for j in range(len(row[3])/len(row[2])):
                                        # Create NS IP node
                                        ns_ip_uri = "class=attribute&key={0}&value={1}".format("ip", row[3][i*len(row[3])/len(row[2]) + j]) 
                                        g.add_node(ns_ip_uri, {
                                            'class': 'attribute',
                                            'key': "ip",
                                            "value": ip,
                                            "start_time": dt,
                                            "uri": ns_ip_uri
                                        })

                                        # create NS uri
                                        ns_uri = "class=attribute&key={0}&value={1}".format("domain", row[2][i]) 


                                        # link NS to IP
                                        edge_attr = {
                                            "relationship": "describedBy",
                                            "origin": row[5],
                                            "start_time": dt
                                        }
                                        source_hash = uuid.uuid3(uuid.NAMESPACE_URL, ns_uri)
                                        dest_hash = uuid.uuid3(uuid.NAMESPACE_URL, ns_ip_uri)
                                        edge_uri = "source={0}&destionation={1}".format(str(source_hash), str(dest_hash))
                                        rel_chain = "relationship"
                                        while rel_chain in edge_attr:
                                            edge_uri = edge_uri + "&{0}={1}".format(rel_chain,edge_attr[rel_chain])
                                            rel_chain = edge_attr[rel_chain]
                                        if "origin" in edge_attr:
                                            edge_uri += "&{0}={1}".format("origin", edge_attr["origin"])
                                        edge_attr["uri"] = edge_uri
                                        g.add_edge(ns_uri, ns_ip_uri, edge_uri, edge_attr)

                            # otherwise we'll attach each IP to each NS
                            else:
                                for ip in row[3]:
                                    # Create NS IP node
                                    ns_ip_uri = "class=attribute&key={0}&value={1}".format("ip", ip) 
                                    g.add_node(ns_ip_uri, {
                                        'class': 'attribute',
                                        'key': "ip",
                                        "value": ip,
                                        "start_time": dt,
                                        "uri": ns_ip_uri
                                    })
                                    
                                    for ns in row[2]:
                                        # create NS uri
                                        ns_uri = "class=attribute&key={0}&value={1}".format("domain", ns)

                                         # link NS to IP
                                        edge_attr = {
                                            "relationship": "describedBy",
                                            "origin": row[5],
                                            "start_time": dt
                                        }
                                        source_hash = uuid.uuid3(uuid.NAMESPACE_URL, ns_uri)
                                        dest_hash = uuid.uuid3(uuid.NAMESPACE_URL, ns_ip_uri)
                                        edge_uri = "source={0}&destionation={1}".format(str(source_hash), str(dest_hash))
                                        rel_chain = "relationship"
                                        while rel_chain in edge_attr:
                                            edge_uri = edge_uri + "&{0}={1}".format(rel_chain,edge_attr[rel_chain])
                                            rel_chain = edge_attr[rel_chain]
                                        if "origin" in edge_attr:
                                            edge_uri += "&{0}={1}".format("origin", edge_attr["origin"])
                                        edge_attr["uri"] = edge_uri
                                        g.add_edge(ns_uri, ns_ip_uri, edge_uri, edge_attr)

                            # classify malicious and merge with current graph
                            g = self.Verum.merge_graphs(g, self.app.classify.run({'key': 'domain', 'value': row[0], 'classification': 'malice'}))

                            # enrich depending on type
                            for domain in [row[0]] + row[2]:
                                try:
                                    g = self.Verum.merge_graphs(g, self.app.run_enrichments(domain, "domain", names=['TLD Enrichment']))
                                    g = self.Verum.merge_graphs(g, self.app.run_enrichments(domain, "domain", names=['IP Whois Enrichment']))
                                except Exception as e:
                                    logging.info("Enrichment of {0} failed due to {1}.".format(domain, e))
                                    #print "Enrichment of {0} failed due to {1}.".format(domain, e)  # DEBUG
                                    #raise
                                    pass
                            for ip in row[1] + row[3]:
                                try:
                                    g = self.Verum.merge_graphs(g, self.app.run_enrichments(ip, "ip", names=[u'Maxmind ASN Enrichment']))
                                except Exception as e:
                                    logging.info("Enrichment of {0} failed due to {1}.".format(ip, e))
                                    pass

                            try:
                                self.app.store_graph(self.Verum.remove_non_ascii_from_graph(g))
                            except:
                                print g.nodes(data=True)  # DEBUG
                                print g.edges(data=True)  # DEBUG
                                raise

                            # Do cymru enrichment
                            if len(ips) >= 50:
                                # validate IPs
                                ips2 = set()
                                for ip in ips:
                                    try:
                                        _ = ipaddress.ip_address(unicode(ip))
                                        ips2.add(ip)
                                    except:
                                        pass
                                ips = ips2
                                del(ips2)
                                try:
                                    self.app.store_graph(self.app.run_enrichments(ips, 'ip', names=[u'Cymru Enrichment']))
                                    #print "Cymru enrichment complete."
                                except Exception as e:
                                    logging.info("Cymru enrichment of {0} IPs failed due to {1}.".format(len(ips), e))
                                    #print "Cymru enrichment of {0} IPs failed due to {1}.".format(len(ips), e)  # DEBUG
                                    pass
                                ips = set()

                        except Exception as e:
                            print row
                            print e
                            raise

                # Copy today's date to today
                self.today = datetime.utcnow()

                logging.info("Daily {0} enrichment complete.".format(NAME))
                print "Daily {0} enrichment complete.".format(NAME)  # DEBUG

Example 5

Project: bde-tools
Source File: buildconfigfactory.py
def make_build_config(repo_context, build_flags_parser, uplid, ufid,
                      default_rules, debug_keys=[]):
    """Create a build configuration for repository.

    Args:
        repo_context (RepoContext): Repository structure and metadata.
        build_flags_parser (BuildFlagsParser): Parser for build flags.
        uplid (Uplid): Build platform id.
        ufid (Ufid): Build flags id.
        default_rules (list of OptionRule): Option rules that are the base of
            all UORs.
        debug_keys (list of str): Print some debug information for some
            option keys.
    """

    build_config = buildconfig.BuildConfig(repo_context.root_path, uplid,
                                           ufid)

    for unit in repo_context.units.values():
        if unit.type_ == repounits.UnitType.THIRD_PARTY_DIR:
            build_config.third_party_dirs[unit.name] = unit

    uor_dep_graph = repocontextutil.get_uor_digraph(repo_context)
    uor_map = repocontextutil.get_uor_map(repo_context)

    build_config.external_dep = graphutil.find_external_nodes(uor_dep_graph)

    # BDE_CXXINCLUDES is hard coded in bde_build
    initial_options = {
        'BDE_CXXINCLUDES': '$(BDE_CXXINCLUDE)'
    }

    # Waf already knows about the flags necessary for building shared objects,
    # so don't get the necessary flags from opts files.  We won't have to do
    # the operation below once we remove the option rules for 'shr' in the
    # default option files.

    effective_ufid = copy.deepcopy(build_config.ufid)
    effective_ufid.flags.discard('shr')
    def_oe = optionsevaluator.OptionsEvaluator(build_config.uplid,
                                               effective_ufid,
                                               initial_options)

    def_oe.store_option_rules(default_rules, debug_keys)

    def_oe_copy = copy.deepcopy(def_oe)
    def_oe_copy.evaluate()
    build_config.default_flags = get_build_flags_from_opts(
        build_flags_parser, def_oe_copy.results, def_oe_copy.results)

    # These custom environment variables come from default_internal.opts and
    # need to be set when building dpkg.
    env_variables = ('SET_TMPDIR', 'XLC_LIBPATH')
    setenv_re = re.compile(r'^([^=]+)=(.*)$')
    for e in env_variables:
        if e in def_oe_copy.results:
            m = setenv_re.match(def_oe_copy.results[e])
            if m:
                build_config.custom_envs[m.group(1)] = m.group(2)

    def set_unit_loc(oe, unit):
        oe.results['%s_LOCN' %
                   unit.name.upper().replace('+', '_')] = unit.path

    def load_uor(uor):
        # Preserve the existing behavior of loading defs, opts and cap files as
        # bde_build:
        #
        # - Exported options of a UOR: read the defs files of its dependencies
        #   followed by its own. The files of the dependencies should be read
        #   in topological order; if the order of certain dependencies is
        #   ambiguous, order them first by dependency level and then by name.
        #
        # - Internal options of a UOR: read the defs files in the same way,
        #   followed by its own opts file.

        dep_levels = graphutil.levelize(uor_dep_graph,
                                        uor_dep_graph[uor.name])

        oe = copy.deepcopy(def_oe)
        # We load options in levelized order instead of any topological order
        # to preserve the behavior of bde_build (an older version of the build
        # tool).  Note that we cannot cache intermediate results because later
        # option rules may change the results from the previous rule.

        if uor.type_ == repounits.UnitType.GROUP:
            uor_bc = buildconfig.PackageGroupBuildConfig()
        elif uor.type_ in repounits.UnitTypeCategory.PACKAGE_STAND_ALONE_CAT:
            uor_bc = buildconfig.StdalonePackageBuildConfig()
        else:
            assert(False)

        for level in dep_levels:
            for dep_name in sorted(level):
                if dep_name in build_config.external_dep:
                    uor_bc.external_dep.add(dep_name)
                elif dep_name not in build_config.third_party_dirs:
                    dep_uor = uor_map[dep_name]
                    oe.store_option_rules(dep_uor.cap)
                    oe.store_option_rules(dep_uor.defs)

        if (build_config.uplid.os_type == 'windows' and
                build_config.uplid.comp_type == 'cl'):
            # By default, Visual Studio uses a single pdb file for all object
            # files compiled from a particular directory named
            # vc<vs_version>.pdb.  We want to use a separate pdb file for each
            # package group and stand-alone package.
            #
            # BDE_CXXFLAGS and BDE_CFLAGS are defined by default.opts, so the
            # code below is a bit hackish.
            pdb_option = ' /Fd%s\\%s.pdb' % (
                os.path.relpath(uor.path, build_config.root_path), uor.name)
            oe.options['BDE_CXXFLAGS'] += pdb_option
            oe.options['BDE_CFLAGS'] += pdb_option

        uor_bc.name = uor.name
        uor_bc.path = uor.path
        uor_bc.doc = uor.doc
        uor_bc.version = uor.version
        uor_bc.dep = uor.dep - build_config.external_dep
        uor_bc.external_dep |= uor.dep & build_config.external_dep

        # Store options from dependencies, options for exports, and internal
        # options separately

        dep_oe = copy.deepcopy(oe)
        dep_oe.evaluate()
        oe.store_option_rules(uor.cap)
        oe.store_option_rules(uor.defs)
        set_unit_loc(oe, uor)
        export_oe = copy.deepcopy(oe)
        int_oe = copy.deepcopy(oe)
        export_oe.evaluate()
        if export_oe.results.get('CAPABILITY') == 'NEVER':
            logutil.warn('Skipped non-supported UOR %s' % uor.name)
            return

        int_oe.store_option_rules(uor.opts)

        # Copy unevaluated internal options to be used by packages within
        # package groups.
        int_oe_copy = copy.deepcopy(int_oe)
        if debug_keys:
            logutil.info('--Evaluating %s' % uor.name)
        int_oe.evaluate(debug_keys)

        # Remove export flags of a UOR's dependencies from its own export
        # flags.  This implementation is not very optimal, but it gets the
        # job done.
        dep_flags = get_build_flags_from_opts(build_flags_parser,
                                              dep_oe.results, dep_oe.results)

        if uor.name == 'bsl':
            # Include flags such as "-lpthread" in bsl.pc, because
            # historically, some of the applications built on top of BDE
            # require this.
            exclude_exportflags = []
            exclude_exportlibs = []
        else:
            exclude_exportflags = dep_flags.export_flags
            exclude_exportlibs = dep_flags.export_libs

        uor_bc.flags = get_build_flags_from_opts(
            build_flags_parser, int_oe.results, export_oe.results,
            exclude_exportflags, exclude_exportlibs)

        if uor.type_ == repounits.UnitType.GROUP:
            load_package_group(uor, uor_bc, int_oe_copy)
        elif uor.type_ in repounits.UnitTypeCategory.PACKAGE_STAND_ALONE_CAT:
            load_sa_package(uor, uor_bc)
        else:
            assert(False)

    def load_sa_package(package, package_bc):
        package_bc.components = package.components
        package_bc.type_ = package.type_
        package_bc.has_dums = package.has_dums
        package_bc.app_main = package.app_main
        build_config.stdalone_packages[package_bc.name] = package_bc

    def load_package_group(group, group_bc, oe):
        skipped_packages = set()
        for package_name in group.mem:
            is_skipped = not load_normal_package(package_name, oe)
            if is_skipped:
                skipped_packages.add(package_name)

        group_bc.mem = group.mem - skipped_packages
        build_config.package_groups[group_bc.name] = group_bc

    def load_normal_package(package_name, oe):
        package = repo_context.units[package_name]
        int_oe = copy.deepcopy(oe)
        int_oe.store_option_rules(package.opts)
        int_oe.store_option_rules(package.cap)
        set_unit_loc(int_oe, package)

        if debug_keys:
            logutil.info('--Evaluating %s' % package_name)
        int_oe.evaluate(debug_keys)

        if int_oe.results.get('CAPABILITY') == 'NEVER':
            logutil.warn('Skipped non-supported package %s' % package_name)
            return False

        if package.type_ == repounits.PackageType.PACKAGE_PLUS:
            package_bc = buildconfig.PlusPackageBuildConfig()
        else:
            package_bc = buildconfig.InnerPackageBuildConfig()

        package_bc.name = package.name
        package_bc.path = package.path
        package_bc.dep = package.dep
        package_bc.type_ = package.type_
        package_bc.flags = get_build_flags_from_opts(build_flags_parser,
                                                     int_oe.results)
        package_bc.has_dums = package.has_dums

        if package.type_ == repounits.PackageType.PACKAGE_PLUS:
            package_bc.headers = package.pt_extras.headers
            package_bc.cpp_sources = package.pt_extras.cpp_sources
            package_bc.cpp_tests = package.pt_extras.cpp_tests
            package_bc.c_tests = package.pt_extras.c_tests
        else:
            package_bc.components = package.components

        build_config.inner_packages[package_name] = package_bc
        return True

    for unit in repo_context.units.values():
        if (unit.type_ in repounits.UnitTypeCategory.UOR_CAT and
                unit.type_ != repounits.UnitType.THIRD_PARTY_DIR):
            load_uor(unit)

    # Sometimes we want to override the default SONAME used for a shared
    # object.  This can be done by setting an environment variable.
    # E.g., we can set 'BDE_BSL_SONAME' to 'robo20150101bsl' to set the
    # SONAME of the shared object built for the package group 'bsl'.

    for name in uor_map:
        soname = os.environ.get('BDE_%s_SONAME' % name.upper())
        if soname:
            build_config.soname_overrides[name] = soname

    return build_config
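
The repeated copy.deepcopy calls above (def_oe, oe, dep_oe, export_oe, int_oe, int_oe_copy) all follow one idea: snapshot the options evaluator before layering more rules onto it, so each branch starts from the same baseline and never mutates it. A minimal sketch of that branching pattern, using a made-up Evaluator class instead of the real OptionsEvaluator, could look like this:

import copy

class Evaluator(object):
    # toy stand-in for OptionsEvaluator: accumulates rules, evaluates on demand
    def __init__(self, options=None):
        self.rules = []
        self.results = dict(options or {})

    def store_option_rules(self, rules):
        self.rules.extend(rules)

    def evaluate(self):
        for key, value in self.rules:
            self.results[key] = value

baseline = Evaluator({'BDE_CXXINCLUDES': '$(BDE_CXXINCLUDE)'})
baseline.store_option_rules([('CXXFLAGS', '-O2')])

per_unit = copy.deepcopy(baseline)       # branch for one unit's extra rules
per_unit.store_option_rules([('CXXFLAGS', '-O0 -g')])
per_unit.evaluate()

baseline_copy = copy.deepcopy(baseline)  # the shared baseline is untouched
baseline_copy.evaluate()
assert baseline_copy.results['CXXFLAGS'] == '-O2'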

Example 6

Project: veyepar
Source File: mk_mlt.py
View license
def mk_mlt(template, output, params):

    def set_text(node,prop_name,value=None):
        # print(node,prop_name,value)
        p = node.find("property[@name='{}']".format(prop_name))
        if value is None:
            node.remove(p)
        else:
            if type(value)==int:
                value = "0:{}.0".format(value)
            elif type(value)==float:
                value = "0:{}".format(value)
            p.text = value

    def set_attrib(node, attrib_name, value=None):
        if value is None:
            del node.attrib[attrib_name]
        else:
            if type(value)==int:
                value = "0:{}.0".format(value)
            # print(attrib_name, value)
            node.set(attrib_name, value)

    # parse the template 
    tree=xml.etree.ElementTree.parse(template)

    # grab nodes we are going to store values into
    nodes={}
    for id in [
        'pi_title_img', 'ti_title',
        'pi_foot_img',  'ti_foot',
        # 'spacer',
        'pl_vid0', 'pi_vid0', # Play List and Item
        'tl_vid2', 'ti_vid2', # Time Line and Item
        'audio_fade_in', 'audio_fade_out',
        'pic_in_pic', 'opacity',
        'channelcopy', 
        'mono', 
        'normalize', 
        'title_fade','foot_fade',
        ]:

        node = tree.find(".//*[@id='{}']".format(id))
        # print(id,node)
        nodes[id] = node

    # special case because Shotcut steps on the id='spacer'
    # <playlist id="playlist1">
    # <blank id="spacer" length="00:00:00.267"/>
    nodes['spacer'] = tree.find("./playlist[@id='playlist1']blank")

    # remove all placeholder nodes
    mlt = tree.find('.')

    play_list = tree.find("./playlist[@id='main bin']")
    for pe in play_list.findall("./entry[@sample]"):
        producer = pe.get('producer')
        producer_node = tree.find("./producer[@id='{}']".format(producer))
        # print( producer )
        play_list.remove(pe)
        mlt.remove(producer_node)

    # <playlist id="playlist0">
    # <entry producer="tl_vid1" in="00:00:00.667" out="00:00:03.003" sample="1" /> 
    time_line = tree.find("./playlist[@id='playlist0']")
    for te in time_line.findall("./entry[@sample]"):
        # print("te",te)
        producer = te.get('producer')
        producer_node = tree.find("./producer[@id='{}']".format(producer))
        # print( "producer", producer )
        time_line.remove(te)
        mlt.remove(producer_node)

    nodes['ti_vid2'].remove(nodes['channelcopy'])
    nodes['ti_vid2'].remove(nodes['mono'])
    nodes['ti_vid2'].remove(nodes['normalize'])

    # add each clip to the playlist
    for i,clip in enumerate(params['clips']):

        node_id = "pi_vid{}".format(clip['id'])
        # print("node_id",node_id)

        # setup and add Play List
        pl = copy.deepcopy( nodes['pl_vid0'] )
        pl.set("id", "pl_vid{}".format(clip['id']))
        pl.set("producer", node_id)
        set_attrib(pl, "in", clip['in'])
        set_attrib(pl, "out", clip['out'])
        play_list.insert(i,pl)

        # setup and add Playlist Item
        pi = copy.deepcopy( nodes['pi_vid0'] )
        pi.set("id", node_id)
        set_attrib(pi, "in", clip['in'])
        set_attrib(pi, "out", clip['out'])
        set_text(pi,'length')
        set_text(pi,'resource',clip['filename'])
        mlt.insert(i,pi)

    # add each cut to the timeline
    total_length = 0
    for i,cut in enumerate(params['cuts']):
        # print(i,cut)

        node_id = "ti_vid{}".format(cut['id'])

        tl = copy.deepcopy( nodes['tl_vid2'] )
        tl.set("producer", node_id)
        set_attrib(tl, "in", cut['in'])
        set_attrib(tl, "out", cut['out'])
        time_line.insert(i,tl)

        ti = copy.deepcopy( nodes['ti_vid2'] )
        ti.set("id", node_id)
        set_attrib(ti, "in")
        set_attrib(ti, "out")
        set_text(ti,'length')
        set_text(ti,'resource',cut['filename'])
        set_text(ti,'video_delay',cut['video_delay'])

        # apply the filters to the cuts

        if cut['channelcopy']=='00':
            pass
        elif cut['channelcopy']=='m':
            ti.insert(0,nodes['mono'])
        else:
            channelcopy = copy.deepcopy( nodes['channelcopy'] )
            set_text(channelcopy,'from' , cut['channelcopy'][0])
            set_text(channelcopy,'to' , cut['channelcopy'][1])
            ti.insert(0,channelcopy)

        if cut['normalize']!='0':
            normalize = copy.deepcopy( nodes['normalize'] )
            set_text(normalize,'program' , cut['normalize'])
            ti.insert(0,normalize)

        if nodes['pic_in_pic'] is not None:
            # for Node 15
            ti.insert(0,nodes['pic_in_pic'])
            ti.insert(0,nodes['opacity'])

        if i==0:
            # apply audio fade in/out to first/last cut
            ti.insert(0,nodes['audio_fade_in'])

        mlt.insert(i,ti)

        total_length += cut['length']
        print( total_length )

    # ti is left over from the above loop
    ti.insert(0,nodes['audio_fade_out'])

    # set title screen image
    set_text(nodes['pi_title_img'],'resource',params['title_img'])
    set_text(nodes['ti_title'],'resource',params['title_img'])

    # set footer image
    set_text(nodes['pi_foot_img'],'resource',params['foot_img'])
    set_text(nodes['ti_foot'],'resource',params['foot_img'])

    # set the length of the spacer so it puts the footer image at end-5sec
    # Duration: 27mn 53s
    # nodes['ti_foot'].set("in",str(total_length))
    # nodes['spacer'].set("length","00:27:46.00")
    nodes['spacer'].set("length","0:{}.0".format(total_length-8.0))

    # put the 1.5 fadeout at the end
    nodes['audio_fade_out'].set("in","0:{}.0".format(total_length-1.5))

    tree.write(output)
    # import code; code.interact(local=locals())

    return True
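
Here copy.deepcopy is applied to xml.etree.ElementTree elements: the nodes looked up from the template (pl_vid0, pi_vid0, tl_vid2, ti_vid2, channelcopy, normalize) act as prototypes, and every clip or cut gets its own deep copy so that setting ids, attributes, and property text on one clone never touches the template or the other clones. A minimal sketch of that clone-a-template-element pattern, with a made-up producer template, could be:

import copy
import xml.etree.ElementTree as ET

root = ET.fromstring(
    '<mlt><producer id="template"><property name="resource"/></producer></mlt>')
template = root.find("./producer[@id='template']")

for i, filename in enumerate(['a.mp4', 'b.mp4']):
    node = copy.deepcopy(template)      # clone, so the template stays pristine
    node.set('id', 'producer{}'.format(i))
    node.find("property[@name='resource']").text = filename
    root.append(node)

# the template itself still has its original id and an empty resource
assert template.get('id') == 'template'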

Example 7

Project: yank
Source File: yank.py
View license
    def _create_phase(self, thermodynamic_state, alchemical_phase):
        """
        Create a repex object for a specified phase.

        Parameters
        ----------
        thermodynamic_state : ThermodynamicState (System need not be defined)
            Thermodynamic state from which reference temperature and pressure are to be taken.
        alchemical_phase : AlchemicalPhase
           The alchemical phase to be created.

        """
        # We add default repex options only on creation; on resume, repex will pick them up from the store file
        repex_parameters = {
            'number_of_equilibration_iterations': 0,
            'number_of_iterations': 100,
            'timestep': 2.0 * unit.femtoseconds,
            'collision_rate': 5.0 / unit.picoseconds,
            'minimize': False,
            'show_mixing_statistics': True,  # this causes slowdown with iteration and should not be used for production
            'displacement_sigma': 1.0 * unit.nanometers  # attempt to displace ligand by this stddev will be made each iteration
        }
        repex_parameters.update(self._repex_parameters)

        # Convenience variables
        positions = alchemical_phase.positions
        reference_system = copy.deepcopy(alchemical_phase.reference_system)
        atom_indices = alchemical_phase.atom_indices
        alchemical_states = alchemical_phase.protocol

        # Check the dimensions of positions.
        for index in range(len(positions)):
            n_atoms, _ = (positions[index] / positions[index].unit).shape
            if n_atoms != reference_system.getNumParticles():
                err_msg = "Phase {}: number of atoms in positions {} and and " \
                          "reference system differ ({} and {} respectively)"
                err_msg.format(alchemical_phase.name, index, n_atoms,
                               reference_system.getNumParticles())
                logger.error(err_msg)
                raise RuntimeError(err_msg)

        # Initialize metadata storage.
        metadata = dict()

        # TODO: Use more general approach to determine whether system is periodic.
        is_periodic = reference_system.usesPeriodicBoundaryConditions()
        is_complex_explicit = len(atom_indices['receptor']) > 0 and is_periodic
        is_complex_implicit = len(atom_indices['receptor']) > 0 and not is_periodic

        # Make sure pressure is None if not periodic.
        if not is_periodic:
            thermodynamic_state.pressure = None
        # If temperature and pressure are specified, make sure MonteCarloBarostat is attached.
        elif thermodynamic_state.temperature and thermodynamic_state.pressure:
            forces = { reference_system.getForce(index).__class__.__name__ : reference_system.getForce(index) for index in range(reference_system.getNumForces()) }

            if 'MonteCarloAnisotropicBarostat' in forces:
                raise Exception('MonteCarloAnisotropicBarostat is unsupported.')

            if 'MonteCarloBarostat' in forces:
                logger.debug('MonteCarloBarostat found: Setting default temperature and pressure.')
                barostat = forces['MonteCarloBarostat']
                # Set temperature and pressure.
                try:
                    barostat.setDefaultTemperature(thermodynamic_state.temperature)
                except AttributeError:  # versions previous to OpenMM7.1
                    barostat.setTemperature(thermodynamic_state.temperature)
                barostat.setDefaultPressure(thermodynamic_state.pressure)
            else:
                # Create barostat and add it to the system if it doesn't have one already.
                logger.debug('MonteCarloBarostat not found: Creating one.')
                barostat = openmm.MonteCarloBarostat(thermodynamic_state.pressure, thermodynamic_state.temperature)
                reference_system.addForce(barostat)

        # Store a serialized copy of the reference system.
        metadata['reference_system'] = openmm.XmlSerializer.serialize(reference_system)
        metadata['topology'] = utils.serialize_topology(alchemical_phase.reference_topology)

        # Create a copy of the system for which the fully-interacting energy is to be computed.
        # For explicit solvent calculations, an enlarged cutoff is used to account for the anisotropic dispersion correction.

        # fully_interacting_system = copy.deepcopy(reference_system)
        reference_system_LJ = copy.deepcopy(reference_system)
        forces_to_remove = []
        for forceIndex in range(reference_system_LJ.getNumForces()):
            force = reference_system_LJ.getForce(forceIndex)
            if isinstance(force, openmm.NonbondedForce):
                for particle in range(force.getNumParticles()):
                    q, sigma, epsilon = force.getParticleParameters(particle)
                    force.setParticleParameters(particle, 0, sigma, epsilon)
                for exception in range(force.getNumExceptions()):
                    particle1, particle2, chargeprod, epsilon, sigma = force.getExceptionParameters(exception)
                    force.setExceptionParameters(exception, particle1, particle2, 0, sigma, epsilon)
            else:
                # Queue force for removal if it is not a NonbondedForce
                forces_to_remove.append(forceIndex)
        # Remove all but the NonbondedForce
        # If done in the previous loop, the number of forces changes, so the indices change
        for forceIndex in forces_to_remove[::-1]:
            reference_system_LJ.removeForce(forceIndex)

        reference_system_LJ_expanded = copy.deepcopy(reference_system_LJ)
        if is_periodic:
            # Determine minimum box side dimension
            box_vectors = reference_system_LJ_expanded.getDefaultPeriodicBoxVectors()
            min_box_dimension = min([max(vector) for vector in box_vectors])

            # Expand cutoff to minimize artifact and verify that box is big enough.
            # If we use a barostat we leave more room for volume fluctuations or
            # we risk fatal errors. If we don't use a barostat, OpenMM will raise
            # the appropriate exception on context creation.
            max_allowed_cutoff = 16 * unit.angstroms
            max_switching_distance = max_allowed_cutoff - (1 * unit.angstrom)
            # TODO: Make max_allowed_cutoff an option
            if thermodynamic_state.pressure and min_box_dimension < 2.25 * max_allowed_cutoff:
                raise RuntimeError('Barostated box sides must be at least 36 Angstroms '
                                   'to correct for missing dispersion interactions')

            logger.debug('Setting cutoff for fully interacting system to maximum allowed (%s)' % str(max_allowed_cutoff))

            # Expand the cutoff of the LJ system if needed
            # We don't want to reduce the cutoff if it's already large
            for force in reference_system_LJ_expanded.getForces():
                try:
                    if force.getCutoffDistance() < max_allowed_cutoff:
                        force.setCutoffDistance(max_allowed_cutoff)
                        # Set switch distance
                        # We don't need to check if we are using a switch since there is a setting for that.
                        force.setSwitchingDistance(max_switching_distance)
                except:
                    pass
                try:
                    if force.getCutoff() < max_allowed_cutoff:
                        force.setCutoff(max_allowed_cutoff)
                except:
                    pass

        # Construct thermodynamic states
        reference_state = copy.deepcopy(thermodynamic_state)
        reference_state.system = reference_system
        reference_LJ_state = copy.deepcopy(thermodynamic_state)
        reference_LJ_expanded_state = copy.deepcopy(thermodynamic_state)
        reference_LJ_state.system = reference_system_LJ
        reference_LJ_expanded_state.system = reference_system_LJ_expanded

        # Compute standard state corrections for complex phase.
        metadata['standard_state_correction'] = 0.0
        # TODO: Do we need to include a standard state correction for other phases in periodic boxes?
        if is_complex_implicit:
            # Impose restraints for complex system in implicit solvent to keep ligand from drifting too far away from receptor.
            logger.debug("Creating receptor-ligand restraints...")
            reference_positions = positions[0]
            restraints = create_restraints(self._restraint_type,
                alchemical_phase.reference_topology, thermodynamic_state, reference_system, reference_positions, atom_indices['receptor'], atom_indices['ligand'])
            force = restraints.get_restraint_force() # Get Force object incorporating restraints
            reference_system.addForce(force)
            metadata['standard_state_correction'] = restraints.get_standard_state_correction() # standard state correction in kT
        elif is_complex_explicit:
            # For periodic systems, we do not use a restraint, but must add a standard state correction for the box volume.
            # TODO: What if the box volume fluctuates during the simulation?
            box_vectors = reference_system.getDefaultPeriodicBoxVectors()
            box_volume = thermodynamic_state._volume(box_vectors)
            STANDARD_STATE_VOLUME = 1660.53928 * unit.angstrom**3
            metadata['standard_state_correction'] = - np.log(STANDARD_STATE_VOLUME / box_volume)

        # Create alchemically-modified states using alchemical factory.
        logger.debug("Creating alchemically-modified states...")
        try:
            alchemical_indices = atom_indices['ligand_counterions'] + atom_indices['ligand']
        except KeyError:
            alchemical_indices = atom_indices['ligand']
        factory = AbsoluteAlchemicalFactory(reference_system, ligand_atoms=alchemical_indices,
                                            **self._alchemy_parameters)
        alchemical_system = factory.alchemically_modified_system
        thermodynamic_state.system = alchemical_system

        # Check systems for finite energies.
        # TODO: Refactor this into another function.
        finite_energy_check = False
        if finite_energy_check:
            logger.debug("Checking energies are finite for all alchemical systems.")
            integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
            context = openmm.Context(alchemical_system, integrator)
            context.setPositions(positions[0])
            for index, alchemical_state in enumerate(alchemical_states):
                AbsoluteAlchemicalFactory.perturbContext(context, alchemical_state)
                potential = context.getState(getEnergy=True).getPotentialEnergy()
                if np.isnan(potential / unit.kilocalories_per_mole):
                    raise Exception("Energy for system %d is NaN." % index)
            del context, integrator
            logger.debug("All energies are finite.")

        # Randomize ligand position if requested, but only for implicit solvent systems.
        if self._randomize_ligand and is_complex_implicit:
            logger.debug("Randomizing ligand positions and excluding overlapping configurations...")
            randomized_positions = list()
            nstates = len(alchemical_states)
            for state_index in range(nstates):
                positions_index = np.random.randint(0, len(positions))
                current_positions = positions[positions_index]
                new_positions = ModifiedHamiltonianExchange.randomize_ligand_position(current_positions,
                                                                                      atom_indices['receptor'], atom_indices['ligand'],
                                                                                      self._randomize_ligand_sigma_multiplier * restraints.getReceptorRadiusOfGyration(),
                                                                                      self._randomize_ligand_close_cutoff)
                randomized_positions.append(new_positions)
            positions = randomized_positions
        if self._randomize_ligand and is_complex_explicit:
            logger.warning("Ligand randomization requested, but will not be performed for explicit solvent simulations.")

        # Identify whether any atoms will be displaced via MC, unless option is turned off.
        mc_atoms = None
        if self._mc_displacement_sigma:
            mc_atoms = list()
            if 'ligand' in atom_indices:
                mc_atoms = atom_indices['ligand']

        # Set up simulation.
        # TODO: Support MPI initialization?
        logger.debug("Creating replica exchange object...")
        store_filename = os.path.join(self._store_directory, alchemical_phase.name + '.nc')
        self._store_filenames[alchemical_phase.name] = store_filename
        simulation = ModifiedHamiltonianExchange(store_filename, platform=self._platform)
        simulation.create(thermodynamic_state, alchemical_states, positions,
                          displacement_sigma=self._mc_displacement_sigma, mc_atoms=mc_atoms,
                          options=repex_parameters, metadata=metadata,
                          reference_state = reference_state,
                          reference_LJ_state = reference_LJ_state,
                          reference_LJ_expanded_state = reference_LJ_expanded_state)

        # Initialize simulation.
        # TODO: Use the right scheme for initializing the simulation without running.
        #logger.debug("Initializing simulation...")
        #simulation.run(0)

        # Clean up simulation.
        del simulation

        # Add to list of phases that have been set up.
        self._phases.append(alchemical_phase.name)

        return
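
In this phase setup deepcopy serves two purposes: reference_system is copied before the charge-zeroing loop and the force removals, so the original System keeps all of its forces and parameters, and thermodynamic_state is copied so each copy can be pointed at a different system without sharing mutable members. A minimal, library-free sketch of the first idea (mutate an independent copy, keep the original intact), using assumed toy Force and System classes rather than the OpenMM API, could be:

import copy

class Force(object):
    def __init__(self, name, charges):
        self.name = name
        self.charges = list(charges)

class System(object):
    def __init__(self, forces):
        self.forces = list(forces)

reference = System([Force('Nonbonded', [0.5, -0.5]), Force('Bonded', [])])

lj_only = copy.deepcopy(reference)      # work on an independent copy
lj_only.forces = [f for f in lj_only.forces if f.name == 'Nonbonded']
for force in lj_only.forces:
    force.charges = [0.0] * len(force.charges)   # zero the charges, as above

# the reference system still has both forces and its original charges
assert len(reference.forces) == 2
assert reference.forces[0].charges == [0.5, -0.5]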

Example 8

Project: ck-crowdtuning
Source File: module.py
View license
def crowdsource(i):
    """
    Input:  {
              See 'crowdsource program.optimization'

              (compiler_env_uoa)           - fix compiler environment
            }

    Output: {
              return       - return code =  0, if successful
                                         >  0, if error
              (error)      - error text if return > 0
            }

    """

    global cfg, work

    import copy

    mcfg=i.get('module_cfg',{})
    if len(mcfg)>0: 
       cfg=mcfg

    mwork=i.get('module_work',{})
    if len(mwork)>0: work=mwork

    # Setting output
    o=i.get('out','')
    oo=''
    if o=='con': oo='con'

    quiet=i.get('quiet','')

    er=i.get('exchange_repo','')
    if er=='': er=ck.cfg['default_exchange_repo_uoa']
    esr=i.get('exchange_subrepo','')
    if esr=='': esr=ck.cfg['default_exchange_subrepo_uoa']

    if i.get('local','')=='yes': 
       er='local'
       esr=''

    la=i.get('local_autotuning','')

    # Get user 
    user=''

    mcfg={}
    ii={'action':'load',
        'module_uoa':'module',
        'data_uoa':cfg['module_deps']['program.optimization']}
    r=ck.access(ii)
    if r['return']==0:
       mcfg=r['dict']

       dcfg={}
       ii={'action':'load',
           'module_uoa':mcfg['module_deps']['cfg'],
           'data_uoa':mcfg['cfg_uoa']}
       r=ck.access(ii)
       if r['return']>0 and r['return']!=16: return r
       if r['return']!=16:
          dcfg=r['dict']

       user=dcfg.get('user_email','')

    ceuoa=i.get('compiler_env_uoa', '')

    if ceuoa!='':
       rx=ck.access({'action':'load',
                     'module_uoa':cfg['module_deps']['env'],
                     'data_uoa':ceuoa})
       if rx['return']>0: return rx
       ceuoa=rx['data_uid']

    # Initialize local environment for program optimization ***********************************************************
    pi=i.get('platform_info',{})
    if len(pi)==0:
       ii=copy.deepcopy(i)
       ii['action']='initialize'
       ii['module_uoa']=cfg['module_deps']['program.optimization']
       ii['exchange_repo']=er
       ii['exchange_subrepo']=esr
       r=ck.access(ii)
       if r['return']>0: return r

       pi=r['platform_info']
       user=r.get('user','')

    hos=pi['host_os_uoa']
    hosd=pi['host_os_dict']

    tos=pi['os_uoa']
    tosd=pi['os_dict']
    tbits=tosd.get('bits','')

    remote=tosd.get('remote','')

    tdid=pi['device_id']

    program_tags=i.get('program_tags','')
    if program_tags=='' and i.get('local_autotuning','')!='yes' and i.get('data_uoa','')=='':
       program_tags=cfg['program_tags']

    # Check that the minimal dependencies for this scenario are present ***********************************************************
    sdeps=i.get('dependencies',{}) # useful to preset inside crowd-tuning
    if len(sdeps)==0:
       sdeps=copy.deepcopy(cfg['deps'])
    if len(sdeps)>0:
       if o=='con':
          ck.out(line)
          ck.out('Resolving software dependencies required for this scenario ...')
          ck.out('')

       if ceuoa!='':
          x=sdeps.get('compiler',{})
          if len(x)>0:
             if 'cus' in x: del(x['cus'])
             if 'deps' in x: del(x['deps'])
             x['uoa']=ceuoa
             sdeps['compiler']=x

       ii={'action':'resolve',
           'module_uoa':cfg['module_deps']['env'],
           'host_os':hos,
           'target_os':tos,
           'device_id':tdid,
           'deps':sdeps,
           'add_customize':'yes'}
       if quiet=='yes': 
          ii['random']='yes'
       else:
          ii['out']=oo
       rx=ck.access(ii)
       if rx['return']>0: return rx

       sdeps=rx['deps'] # Update deps (add UOA)

    cpu_name=pi.get('features',{}).get('cpu',{}).get('name','')
    compiler_soft_uoa=sdeps.get('compiler',{}).get('dict',{}).get('soft_uoa','')
    compiler_env=sdeps.get('compiler',{}).get('bat','')

    plat_extra={}
    pft=pi.get('features',{})
    for q in pft:
        if q.endswith('_uid'):
           plat_extra[q]=pft[q]
        elif type(pft[q])==dict and pft[q].get('name','')!='':
           plat_extra[q+'_name']=pft[q]['name']

    # Detect real compiler version ***********************************************************
    if o=='con':
       ck.out(line)
       ck.out('Detecting compiler version ...')

    ii={'action':'internal_detect',
        'module_uoa':cfg['module_deps']['soft'],
        'data_uoa':compiler_soft_uoa,
        'host_os':hos,
        'target_os':tos,
        'target_device_id':tdid,
        'env':compiler_env}
    r=ck.access(ii)
    if r['return']>0: return r

    compiler_version=r['version_str']

    compiler=cfg.get('compiler_name','')+' '+compiler_version

    if o=='con':
       ck.out('')
       ck.out('* Compiler: '+compiler)
       ck.out('* CPU:      '+cpu_name)

    # Start preparing input to run program.optimization
    ii=copy.deepcopy(i)

    ii['action']='run'
    ii['module_uoa']=cfg['module_deps']['program.optimization']

    ii['host_os']=hos
    ii['target_os']=tos
    ii['target_device_id']=tdid
    ii['dependencies']=sdeps

    ii['scenario_cfg']=cfg

    ii['platform_info']=pi

    ii['program_tags']=program_tags

    ii['scenario_module_uoa']=work['self_module_uid']

    ii['experiment_meta']={'cpu_name':cpu_name,
                           'compiler':compiler}

    ii['experiment_meta_extra']=plat_extra

    ii['exchange_repo']=er
    ii['exchange_subrepo']=esr

    ii['user']=user

    # Select sub-scenario ********************************************************************
    from random import randint
    ss=1 # num of scenarios

    sx=randint(1,ss)

    rr={'return':0}

    if sx==1 or la=='yes':
       # **************************************************************** explore random program/dataset
       sdesc='explore random program/cmd/data set'
       if o=='con':
          ck.out('')
          ck.out('  ****** Sub-scenario: '+sdesc+' ******')

       ii['subscenario_desc']=sdesc

       rr=ck.access(ii)
       if rr['return']>0: return rr

    rr['platform_info']=pi

    return rr
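
The copy.deepcopy(i) and copy.deepcopy(cfg['deps']) calls above exist to keep the caller's dictionaries out of reach: the sub-request ii is mutated freely (new 'action', 'module_uoa', extra keys) without those changes leaking back into i, and the resolved dependencies never modify the shared cfg. A minimal sketch of why a shallow copy would not be enough for such nested dicts:

import copy

request = {'action': 'crowdsource', 'choices': {'compiler': 'gcc'}}

shallow = dict(request)                 # shallow copy shares the nested dict
shallow['choices']['compiler'] = 'llvm'
print(request['choices']['compiler'])   # 'llvm' -- the caller's input changed

request = {'action': 'crowdsource', 'choices': {'compiler': 'gcc'}}
deep = copy.deepcopy(request)           # fully independent nested structures
deep['action'] = 'run'
deep['choices']['compiler'] = 'llvm'
print(request['choices']['compiler'])   # still 'gcc'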

Example 9

Project: ck-crowdtuning
Source File: module.py
View license
def request(i):
    """
    Input:  {
              (crowd_uid)         - if !='', processing results and possibly chaining experiments

              (email)             - email or person UOA
              (platform_features) - remote device platform features

              (scenario)          - pre-set scenario
            }

    Output: {
              return       - return code =  0, if successful
                                         >  0, if error
              (error)      - error text if return > 0
            }

    """

    import os
    from random import randint
    import copy
    import shutil
    import zipfile
    import json

    # Setting output
    o=i.get('out','')
    oo=''
    if o=='con': oo='con'

    rr={'return':0}

    email=i.get('email','')

    ruoa=i.get('record_repo_uoa','')
#    if ruoa=='': ruoa='upload'
    # Hack
    ck.cfg["forbid_writing_to_local_repo"]="no"
    ck.cfg["allow_writing_only_to_allowed"]="no"
    ck.cfg["forbid_global_delete"]="no"
#    ck.cfg["allow_run_only_from_allowed_repos"]="yes"

    # Check if processing started experiments
    cuid=i.get('crowd_uid','')
    if cuid!='':
       ###################################################################################################################
       # Load info
       r=ck.access({'action':'load',
                    'module_uoa':work['self_module_uid'],
                    'data_uoa':cuid})
       if r['return']>0: return r
       d=r['dict']

       euoa=d['experiment_uoa']
       ol=d['off_line']

       suid=ol.get('solution_uid','') # should normally be prepared in advance!

       scenario_uoa=ol['scenario_module_uoa']
       condition_objective='#'+ol['meta']['objective']

       xstatus=''

       results=i.get('results',{})

       #Log
       r=ck.access({'action':'log',
                    'module_uoa':cfg['module_deps']['experiment'],
                    'text':'Finishing crowd experiment: '+cuid+' ('+email+')\n'+json.dumps(results,indent=2,sort_keys=True)+'\n'})
       if r['return']>0: return r

       if len(results)>0:

          repeat=results.get('ct_repeat','')
          if repeat=='': repeat=1

          cpu_freq0=results.get('cpu_freq0',{})
          cpu_freq1=results.get('cpu_freq1',{})

          ch0=results.get('characteristics0',{})
          ch1=results.get('characteristics1',{})

          # TBD: improve stat analysis -> use CK module (here quick prototyping)
          fch0min=-1
          fch0max=-1
          for q in ch0:
              v=ch0[q]
              if fch0min==-1 or v<fch0min:
                 fch0min=v
              if fch0max==-1 or v>fch0max:
                 fch0max=v

          var=(fch0max-fch0min)/fch0min

          fch1min=-1
          fch1max=-1
          for q in ch1:
              v=ch1[q]
              if fch1min==-1 or v<fch1min:
                 fch1min=v
              if fch1max==-1 or v>fch1max:
                 fch1max=v

          impr=0.0
          if fch1min!=0: impr=fch0min/fch1min

          ol["meta_extra"]["cpu_cur_freq"]=cpu_freq1

          sol=ol["solutions"][0]
          sol["extra_meta"]["cpu_cur_freq"]=cpu_freq1

          point=sol["points"][0]

          point["characteristics"]["##characteristics#run#execution_time_kernel_0#min"]=fch0min
          point["characteristics"]["##characteristics#run#repeat#min"]=repeat

          point["improvements"]["##characteristics#run#execution_time_kernel_0#min_imp"]=impr

#          Hack: don't write for now, otherwise most of the time ignored ...
          var=-1
          point["misc"]["##characteristics#run#execution_time_kernel_0#range_percent"]=var

          sol["points"][0]=point
          ol["solutions"][0]=sol

          # Get conditions from a scenario
          r=ck.access({'action':'load',
                       'module_uoa':cfg['module_deps']['module'],
                       'data_uoa':scenario_uoa})
          if r['return']>0: return r
          ds=r['dict']

          scon=ds.get('solution_conditions',[])
          if len(scon)>0:
             con=copy.deepcopy(point["characteristics"])
             con.update(point["improvements"])
             con.update(point["misc"])
             # Hack
             con["##characteristics#compile#md5_sum#min_imp"]=0

             ii={'action':'check',
                 'module_uoa':cfg['module_deps']['math.conditions'],
                 'new_points':['0'],
                 'results':[{'point_uid':'0', 'flat':con}],
                 'conditions':scon,
                 'middle_key':condition_objective,
                 'out':oo}
             ry=ck.access(ii)
             if ry['return']>0: return ry 

             xdpoints=ry['points_to_delete']
             if len(xdpoints)>0:
                xstatus='*** Your explored solution is not better than existing ones (conditions are not met) ***\n' 
                if o=='con':
                   ck.out('')
                   ck.out('    WARNING: conditions on characteristics were not met!')
             else:
                # Submitting solution
                ii=copy.deepcopy(ol)
                ii['action']='add_solution'
                ii['module_uoa']=cfg['module_deps']['program.optimization']
                ii['repo_uoa']='upload' # Hack 
                ii['user']=email
                rx=ck.access(ii)
                if rx['return']>0: return rx

                if rx.get('recorded','')=='yes':
                   ri=rx.get('recorded_info',{})
                   xstatus=ri.get('status','')
                   xlog=ri.get('log','')

                   rz=ck.access({'action':'log',
                                 'module_uoa':cfg['module_deps']['experiment'],
                                 'file_name':cfg['log_file_results'],
                                 'text':xlog})
                   if rz['return']>0: return rz

                else:
                   xstatus='*** Your explored solution is not better than existing ones ***\n' 

             r=ck.access({'action':'log',
                          'module_uoa':cfg['module_deps']['experiment'],
                          'text':'Result of crowd experiment (UID='+suid+') : '+cuid+' ('+email+'): '+xstatus+'\n'})
             if r['return']>0: return r


       # Cleaning experiment entry
       r=ck.access({'action':'delete',
                    'module_uoa':cfg['module_deps']['experiment'],
                    'data_uoa':euoa})
       if r['return']>0: return r

       # Cleaning crowd entry
       r=ck.access({'action':'delete',
                    'module_uoa':work['self_module_uid'],
                    'data_uoa':cuid})
       if r['return']>0: return r

       # Finishing
       status='Crowdsourced results from your mobile device were successfully processed by Collective Knowledge Aggregator!\n\n'+xstatus

       if o=='con':
          ck.out('')
          ck.out(status)

       rr['status']=status

    else:
       ###################################################################################################################
       # Initialize platform
       pf=i.get('platform_features',{})

       cpu_abi=pf.get('cpu',{}).get('cpu_abi','')
       os_bits=pf.get('os',{}).get('bits','')

       tos=''
       static=''
       max_size_pack=1200000
       extra_tags=''
       if cpu_abi.startswith('armeabi-'):
          tos='android19-arm'
          extra_tags='arm-specific'
       elif cpu_abi.startswith('arm64'):
          tos='android21-arm64'
#          extra_tags='arm-specific'
       elif cpu_abi=='x86':
          tos='android19-x86'
          static='yes'
          max_size_pack=2200000
#          if os_bits=='64':
#             tos='android21-x86_64'

       if tos=='':
          return {'return':1, 'error':'ABI of your mobile device is not yet supported for crowdtuning ('+cpu_abi+') - please contact author ([email protected]) to check if it\'s in development'}

       tdid=''
       hos=''

       xscenario=i.get('scenario','')

       scenarios=cfg['scenarios']
       ls=len(scenarios)

       # Prepare platform info
       ii={'action':'detect',
           'module_uoa':cfg['module_deps']['platform.os'],
           'target_os':tos,
           'skip_info_collection':'yes',
           'out':oo}
       pi=ck.access(ii)
       if pi['return']>0: return pi
       del(pi['return'])

       # Merge with remote device platform features
       r=ck.merge_dicts({'dict1':pi['features'], 'dict2':pf})
       if r['return']>0: return r

       pf['features']=r['dict1']

       #Log
       r=ck.dumps_json({'dict':pf, 'skip_indent':'yes', 'sort_keys':'yes'})
       if r['return']>0: return r
       x=r['string']

       r=ck.access({'action':'log',
                    'module_uoa':cfg['module_deps']['experiment'],
                    'text':email+'\n'+x+'\n'})
       if r['return']>0: return r

       # Try to generate at least one experimental pack!
       n=0
       nm=32

       success=False
       while n<nm and not success:
          n+=1

          # select scenario randomly
          if xscenario!='': scenario=xscenario
          else:             scenario=scenarios[randint(0,ls-1)]

          pic=copy.deepcopy(pi)

          ii={'action':'crowdsource',
              'module_uoa':scenario,
              'target_os':tos,
              'local':'yes',
              'quiet':'yes',
              'iterations':1,
              'platform_info':pic,
              'once':'yes',
              'skip_collaborative':'yes',
              'parametric_flags':'yes',
#              'static':'yes',
#              'program_uoa':'*susan',
              'any_flag_tags':extra_tags,
#              'cmd_key':'edges',
#              'dataset_uoa':'image-pgm-0001',
              'extra_dataset_tags':['small'],
              'no_run':'yes',
              'keep_experiments':'yes',
              'new':'yes',
              'static':static,
              'skip_pruning':'yes',
              'skip_info_collection':'yes',
              'out':oo}
          rrr=ck.access(ii)
          if rrr['return']>0:
             if o=='con':
                ck.out('')
                ck.out('WARNING: '+rrr['error']+' - can\'t continue this sub-scenario ...')
                ck.out('')
          else:
             # Prepare pack ...
             ri=rrr.get('recorded_info',{})
             ruid=ri.get('recorded_uid','')

             lio=rrr.get('last_iteration_output',{})
             fail=lio.get('fail','')
             if fail=='yes' or 'off_line' not in rrr: # sometimes off_line not in rrr, why I don't know yet
                if o=='con':
                   ck.out('')
                   ck.out('WARNING: Pipeline failed ('+lio.get('fail_reason','')+')')
                   ck.out('')

                # Delete failed experiment
                if ruid!='':
                   ii={'action':'delete',
                       'module_uoa':cfg['module_deps']['experiment'],
                       'data_uoa':ruid}
                   r=ck.access(ii)
                   # ignore output

             else:
                # Prepare pack
                ol=rrr['off_line']
                ed=rrr.get('experiment_desc',{})
                choices=ed.get('choices',{})

                d={'experiment_uoa':ruid,
                   'off_line':ol}

                ii={'action':'add',
                    'module_uoa':work['self_module_uid'],
                    'repo_uoa':ruoa,
                    'dict':d}
                r=ck.access(ii)
                if r['return']>0: return r

                p=r['path']

                cuid=r['data_uid'] # crowd experiment identifier
                rr['crowd_uid']=cuid

                x=lio.get('characteristics',{}).get('compile',{}).get('joined_compiler_flags','')

                dsc='Scenario: '+rrr.get('scenario_desc','')+'\n'
                dsc+='Sub-scenario: '+rrr.get('subscenario_desc','')+'\n'
                dsc+='Benchmark/codelet: '+choices.get('data_uoa','')+'\n'
                dsc+='CMD key: '+choices.get('cmd_key','')+'\n'
                dsc+='Dataset: '+choices.get('dataset_uoa','')+'\n'
                dsc+='Dataset file: '+choices.get('dataset_file','')+'\n'
                dsc+='Optimizations:\n'
                dsc+='* OpenCl tuning: not used\n'
                dsc+='* Compiler description: '+choices.get('compiler_description_uoa','')+'\n'
                dsc+='* Compiler flags: -O3 vs '+x+'\n'

                rr['desc']=dsc

                deps=lio.get('dependencies',{})
                for kdp in deps:
                    dp=deps[kdp]
                    z=dp.get('cus',{})
                    dl=z.get('dynamic_lib','')
                    pl=z.get('path_lib','')

                    if dl!='' and pl!='':
                       pidl=os.path.join(pl, dl)
                       if os.path.isfile(pidl):
                          pidl1=os.path.join(p, dl)
                          try:
                             shutil.copyfile(pidl, pidl1)
                          except Exception as e: 
                             pass

                if o=='con':
                   ck.out('')
                   ck.out('  Crowd UID: '+cuid)

                # Copying binaries and inputs here
                target_exe_0=rrr.get('original_target_exe','')
                target_path_0=rrr.get('original_path_exe','')
                target_exe_1=lio.get('state',{}).get('target_exe','')
                tp1=rrr.get('new_path_exe','')

                tp0=os.path.dirname(target_path_0)
                target_path_1=os.path.join(tp0,tp1)

                if o=='con':
                   ck.out('')
                   ck.out('Copying executables:')
                   ck.out(' * '+target_path_0+'  /  '+target_exe_0)
                   ck.out(' * '+target_path_1+'  /  '+target_exe_1)
                   ck.out('')

                duoa=choices.get('dataset_uoa','')
                dfile=choices.get('dataset_file','')

                # create cmd
                prog_uoa=choices.get('data_uoa','')
                cmd_key=choices.get('cmd_key','')
                r=ck.access({'action':'load',
                             'module_uoa':cfg['module_deps']['program'],
                             'data_uoa':prog_uoa})
                if r['return']>0: return r
                dd=r['dict']
                pp=r['path']

                rcm=dd.get('run_cmds',{}).get(cmd_key,{}).get('run_time',{}).get('run_cmd_main','')
                rcm=rcm.replace('$#BIN_FILE#$ ','')
                rcm=rcm.replace('$#dataset_path#$','')
                rcm=rcm.replace('$#dataset_filename#$',dfile)
                rcm=rcm.replace('$#src_path#$','')

                rif=dd.get('run_cmds',{}).get(cmd_key,{}).get('run_time',{}).get('run_input_files',[])

                if o=='con':
                   ck.out('Cmd: '+rcm)

                if target_path_0!='' and target_path_1!='' and target_exe_0!='' and target_exe_1!='' and \
                   not (rcm.find('$#')>=0 or rcm.find('#$')>=0 or rcm.find('<')>=0):

                   te0=os.path.join(target_path_0, target_exe_0)
                   te1=os.path.join(target_path_1, target_exe_1)

                   nte0=os.path.join(p, target_exe_0)
                   nte1=os.path.join(p, target_exe_1)

                   # Copying binary files
                   copied=True
                   try:
                      shutil.copyfile(te0, nte0)
                      shutil.copyfile(te1, nte1)

                      for inp in rif:
                          px1=os.path.join(pp, inp)
                          px2=os.path.join(p, inp)
                          shutil.copyfile(px1, px2)

                   except Exception as e: 
                      copied=False
                      pass

                   if copied:
                      # clean dirs
                      try:
                         shutil.rmtree(target_path_0, ignore_errors=True)
                         shutil.rmtree(target_path_1, ignore_errors=True)
                      except Exception as e: 
                         if o=='con':
                            ck.out('')
                            ck.out('WARNING: can\'t fully erase tmp dir')
                            ck.out('')
                         pass

                      if o=='con':
                         ck.out('Copying datasets ...')

                      # Check dataset files
                      rr['choices']=choices

                      copied=True
                      if duoa!='' and dfile!='':
                         r=ck.access({'action':'load',
                                      'module_uoa':cfg['module_deps']['dataset'],
                                      'data_uoa':duoa})
                         if r['return']>0: return r

                         pd=r['path']

                         td=os.path.join(pd, dfile)
                         ntd=os.path.join(p, dfile)

                         copied=True
                         try:
                            shutil.copyfile(td, ntd)
                         except Exception as e: 
                            copied=False
                            pass

                      if copied:
                         if o=='con':
                            ck.out('Preparing zip ...')

                         # Prepare archive
                         zip_method=zipfile.ZIP_DEFLATED

                         gaf=i.get('all','')

                         fl={}

                         r=ck.list_all_files({'path':p})
                         if r['return']>0: return r

                         flx=r['list']

                         for k in flx:
                             fl[k]=flx[k]

                         pfn=os.path.join(p, fpack)

                         # Write archive
                         copied=True
                         try:
                            f=open(pfn, 'wb')
                            z=zipfile.ZipFile(f, 'w', zip_method)
                            for fn in fl:
                                p1=os.path.join(p, fn)
                                z.write(p1, fn, zip_method)
                            z.close()
                            f.close()

                         except Exception as e:
                            copied=False

                         if copied:
                            if o=='con':
                               ck.out('Preparing cmd ...')

                            size=os.path.getsize(pfn) 

                            r=ck.convert_file_to_upload_string({'filename':pfn})
                            if r['return']>0: return r

                            fx=r['file_content_base64']

                            #MD5
                            import hashlib
                            md5=hashlib.md5(fx.encode()).hexdigest()

                            if o=='con':
                               ck.out('Finalizing ...')

                            calibrate='no'
                            if dd.get('run_vars',{}).get('CT_REPEAT_MAIN','')!='':
                               calibrate='yes'

                            if len(fx)>max_size_pack:
                               if o=='con':
                                  ck.out('')
                                  ck.out('WARNING: pack is too large ('+str(len(fx))+')')
                                  ck.out('')
                            else:
                               # finalize info
                               rr['file_content_base64']=fx
                               rr['size']=size 
                               rr['md5sum']=md5
                               rr['run_cmd_main']=rcm
                               rr['bin_file0']=target_exe_0
                               rr['bin_file1']=target_exe_1
                               rr['calibrate']=calibrate
                               rr['calibrate_max_iters']=10
                               rr['calibrate_time']=10.0
                               rr['repeat']=5
                               rr['ct_repeat']=1

                               success=True

                if not success:
                   if o=='con':
                      ck.out('')
                      ck.out('WARNING: some files are missing - removing crowd entry ('+cuid+') ...')

                   ii={'action':'rm',
                       'module_uoa':work['self_module_uid'],
                       'data_uoa':cuid}
                   r=ck.access(ii)
                   if r['return']>0: return r

       if not success:
          rr={'return':1, 'error':'could not create any valid experimental pack for your mobile - possibly an internal error! Please contact the authors'}

    return rr
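
The con=copy.deepcopy(point["characteristics"]) line above is another small but typical use: the conditions check needs a flat dict that merges characteristics, improvements and misc keys (plus one hacked-in md5 key), and deep-copying first means those update() calls and the extra key never contaminate the point that is later stored with the solution. A minimal sketch of that copy-then-merge pattern, with toy flat keys in the same style:

import copy

point = {
    'characteristics': {'##run#time#min': 1.2},
    'improvements':    {'##run#time#min_imp': 1.5},
    'misc':            {'##run#time#range_percent': -1},
}

flat = copy.deepcopy(point['characteristics'])   # merge into a throwaway copy
flat.update(point['improvements'])
flat.update(point['misc'])
flat['##compile#md5_sum#min_imp'] = 0            # extra key for the check only

# the stored point keeps only its own keys
assert '##run#time#min_imp' not in point['characteristics']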

Example 10

Project: ck-tensorflow
Source File: module.py
View license
def crowdsource(i):
    """
    Input:  {
              (local)               - if 'yes', local crowd-benchmarking, instead of public
              (user)                - force different user ID/email for demos

              (choices)             - force different choices to program pipeline

              (repetitions)         - statistical repetitions (default=1), for now statistical analysis is not used (TBD)
            }

    Output: {
              return       - return code =  0, if successful
                                         >  0, if error
              (error)      - error text if return > 0
            }

    """

    import copy
    import os

    # Setting output
    o=i.get('out','')
    oo=''
    if o=='con': oo='con'

    quiet=i.get('quiet','')

    er=i.get('exchange_repo','')
    if er=='': er=ck.cfg['default_exchange_repo_uoa']
    esr=i.get('exchange_subrepo','')
    if esr=='': esr=ck.cfg['default_exchange_subrepo_uoa']

    if i.get('local','')=='yes': 
       er='local'
       esr=''

    la=i.get('local_autotuning','')

    repetitions=i.get('repetitions','')
    if repetitions=='': repetitions=3
    repetitions=int(repetitions)

    record='no'

    # Check if any input has . and convert to dict
    for k in list(i.keys()):
        if k.find('.')>0:
            v=i[k]

            kk='##'+k.replace('.','#')

            del(i[k])

            r=ck.set_by_flat_key({'dict':i, 'key':kk, 'value':v})
            if r['return']>0: return r

    choices=i.get('choices',{})
    xchoices=copy.deepcopy(choices)

    # Get user 
    user=''

    mcfg={}
    ii={'action':'load',
        'module_uoa':'module',
        'data_uoa':cfg['module_deps']['program.optimization']}
    r=ck.access(ii)
    if r['return']==0:
       mcfg=r['dict']

       dcfg={}
       ii={'action':'load',
           'module_uoa':mcfg['module_deps']['cfg'],
           'data_uoa':mcfg['cfg_uoa']}
       r=ck.access(ii)
       if r['return']>0 and r['return']!=16: return r
       if r['return']!=16:
          dcfg=r['dict']

       user=dcfg.get('user_email','')

    # Initialize local environment for program optimization ***********************************************************
    pi=i.get('platform_info',{})
    if len(pi)==0:
       ii=copy.deepcopy(i)
       ii['action']='initialize'
       ii['module_uoa']=cfg['module_deps']['program.optimization']
       ii['data_uoa']='tensorflow'
       ii['exchange_repo']=er
       ii['exchange_subrepo']=esr
       ii['skip_welcome']='yes'
       ii['skip_log_wait']='yes'
       ii['crowdtuning_type']='tensorflow-crowd-benchmarking'
       r=ck.access(ii)
       if r['return']>0: return r

       pi=r['platform_info']
       user=r.get('user','')

    hos=pi['host_os_uoa']
    hosd=pi['host_os_dict']

    tos=pi['os_uoa']
    tosd=pi['os_dict']
    tbits=tosd.get('bits','')

    remote=tosd.get('remote','')

    tdid=pi['device_id']

    features=pi.get('features',{})

    fplat=features.get('platform',{})
    fos=features.get('os',{})
    fcpu=features.get('cpu',{})
    fgpu=features.get('gpu',{})

    plat_name=fplat.get('name','')
    plat_uid=features.get('platform_uid','')
    os_name=fos.get('name','')
    os_uid=features.get('os_uid','')
    cpu_name=fcpu.get('name','')
    if cpu_name=='': cpu_name='unknown-'+fcpu.get('cpu_abi','')
    cpu_uid=features.get('cpu_uid','')
    gpu_name=fgpu.get('name','')
    gpgpu_name=''
    sn=fos.get('serial_number','')

    # Ask for cmd
    tp=['cpu', 'cuda', 'opencl']

    ck.out(line)
    ck.out('Select TensorFlow library type:')
    ck.out('')
    r=ck.access({'action':'select_list',
                 'module_uoa':cfg['module_deps']['choice'],
                 'choices':tp})
    if r['return']>0: return r
    xtp=r['choice']

    # Get extra platform features if "cuda" or "opencl"
    run_cmd='default'
    tags='lib,tensorflow,tensorflow-'+xtp
    gpgpu_uid=''
    if xtp=='cuda' or xtp=='opencl':
        r=ck.access({'action':'detect',
                     'module_uoa':cfg['module_deps']['platform.gpgpu'],
                     'host_os':hos,
                     'target_os':tos,
                     'device_id':tdid,
                     'type':xtp,
                     'share':'yes',
                     'exchange_repo':er,
                     'exchange_subrepo':esr})
        if r['return']>0: return r
        gfeat=r.get('features',{})
        gpgpus=gfeat.get('gpgpu',[])

        if len(gpgpus)>0:
            gpgpu_name=gpgpus[0].get('gpgpu',{}).get('name','')
            gpgpu_uid=gpgpus[0].get('gpgpu_uoa','')

    # Get deps from TensorFlow program
    r=ck.access({'action':'load',
                 'module_uoa':cfg['module_deps']['program'],
                 'data_uoa':'tensorflow'})
    if r['return']>0: return r

    dd=r['dict']
    deps=dd['compile_deps']
    pp=r['path']

    lib_dep=deps['lib-tensorflow']
    lib_dep['tags']=tags

    # Get explicit choices (batch size, num batches)
    env=i.get('env',{})
    echoices=dd['run_vars']
    for k in echoices:
        if env.get(k,'')!='':
            echoices[k]=env[k]

    # Check environment for selected type
    r=ck.access({'action':'resolve',
                 'module_uoa':cfg['module_deps']['env'],
                 'deps':deps,
                 'host_os':hos,
                 'target_os':tos,
                 'device_id':tdid,
                 'out':o})
    if r['return']>0: return r
    deps=r['deps']

    # Prepare CK pipeline for a given workload
    ii={'action':'pipeline',

        'module_uoa':cfg['module_deps']['program'],
        'data_uoa':'tensorflow',

        'prepare':'yes',

        'env':env,
        'choices':choices,
        'dependencies':deps,
        'cmd_key':run_cmd,
        'no_state_check':'yes',
        'no_compiler_description':'yes',
        'skip_info_collection':'yes',
        'skip_calibration':'yes',
        'cpu_freq':'max',
        'gpu_freq':'max',
        'env_speed':'yes',
        'energy':'no',
        'skip_print_timers':'yes',
        'generate_rnd_tmp_dir':'no',

        'out':oo}

    rr=ck.access(ii)
    if rr['return']>0: return rr

#    ck.save_json_to_file({'json_file':'/tmp/xyz3.json','dict':rr, 'sort_keys':'yes'})
#    exit(1)

    fail=rr.get('fail','')
    if fail=='yes':
        return {'return':10, 'error':'pipeline failed ('+rr.get('fail_reason','')+')'}

    ready=rr.get('ready','')
    if ready!='yes':
        return {'return':11, 'error':'couldn\'t prepare universal CK program workflow'}

    state=rr['state']
    tmp_dir=state['tmp_dir']

    # Clean pipeline
    if 'ready' in rr: del(rr['ready'])
    if 'fail' in rr: del(rr['fail'])
    if 'return' in rr: del(rr['return'])

    # Check if aggregated stats
    aggregated_stats={} # Pre-load statistics ...

    # Prepare high-level experiment meta
    meta={'cpu_name':cpu_name,
          'os_name':os_name,
          'plat_name':plat_name,
          'gpu_name':gpu_name,
          'tensorflow_type':xtp,
          'gpgpu_name':gpgpu_name,
          'cmd_key':run_cmd,
          'echoices':echoices}

    # Process deps
    xdeps={}
    xnn=''
    xblas=''
    for k in deps:
        dp=deps[k]
        xdeps[k]={'name':dp.get('name',''), 
                  'data_name':dp.get('dict',{}).get('data_name',''), 
                  'ver':dp.get('ver','')}

    meta['xdeps']=xdeps
    meta['nn_type']='alexnet'

    mmeta=copy.deepcopy(meta)

    # Extra meta which is not used to search similar case ...
    mmeta['platform_uid']=plat_uid
    mmeta['os_uid']=os_uid
    mmeta['cpu_uid']=cpu_uid
    mmeta['gpgpu_uid']=gpgpu_uid
    mmeta['user']=user

    # Check if already exists
    # tbd

    # Run CK pipeline *****************************************************
    pipeline=copy.deepcopy(rr)
    if len(choices)>0:
        r=ck.merge_dicts({'dict1':pipeline['choices'], 'dict2':xchoices})
        if r['return']>0: return r

    ii={'action':'autotune',
        'module_uoa':cfg['module_deps']['pipeline'],

        'iterations':1,
        'repetitions':repetitions,

        'collect_all':'yes',
        'process_multi_keys':['##characteristics#*'],

        'tmp_dir':tmp_dir,

        'pipeline':pipeline,

        'stat_flat_dict':aggregated_stats,

        "features_keys_to_process":["##choices#*"],

        "record_params": {
          "search_point_by_features":"yes"
        },

        'out':oo}

    rrr=ck.access(ii)
    if rrr['return']>0: return rrr

    ls=rrr.get('last_iteration_output',{})
    state=ls.get('state',{})
    xchoices=copy.deepcopy(ls.get('choices',{}))
    lsa=rrr.get('last_stat_analysis',{})
    lsad=lsa.get('dict_flat',{})

    ddd={'meta':mmeta}

    ddd['choices']=xchoices

    features=ls.get('features',{})

    deps=ls.get('dependencies',{})

    fail=ls.get('fail','')
    fail_reason=ls.get('fail_reason','')

    ch=ls.get('characteristics',{})

    # Save pipeline
    ddd['state']={'fail':fail, 'fail_reason':fail_reason}
    ddd['characteristics']=ch

    ddd['user']=user

    if o=='con':
        ck.out('')
        ck.out('Saving results to the remote public repo ...')
        ck.out('')

        # Find remote entry
        rduid=''

        ii={'action':'search',
            'module_uoa':work['self_module_uid'],
            'repo_uoa':er,
            'remote_repo_uoa':esr,
            'search_dict':{'meta':meta}}
        rx=ck.access(ii)
        if rx['return']>0: return rx

        lst=rx['lst']

        if len(lst)==1:
            rduid=lst[0]['data_uid']
        else:
            rx=ck.gen_uid({})
            if rx['return']>0: return rx
            rduid=rx['data_uid']

        # Update meta
        rx=ck.access({'action':'update',
                      'module_uoa':work['self_module_uid'],
                      'data_uoa':rduid,
                      'repo_uoa':er,
                      'remote_repo_uoa':esr,
                      'dict':ddd,
                      'substitute':'yes',
                      'sort_keys':'yes'})
        if rx['return']>0: return rx

        # Push statistical characteristics
        fstat=os.path.join(pp,tmp_dir,ffstat)

        r=ck.save_json_to_file({'json_file':fstat, 'dict':lsad})
        if r['return']>0: return r

        rx=ck.access({'action':'push',
                      'module_uoa':work['self_module_uid'],
                      'data_uoa':rduid,
                      'repo_uoa':er,
                      'remote_repo_uoa':esr,
                      'filename':fstat,
                      'overwrite':'yes'})
        if rx['return']>0: return rx

        os.remove(fstat)

        # Info
        if o=='con':
            ck.out('Successfully recorded results in remote repo (Entry UID='+rduid+')')

            # Check host URL prefix and default module/action
            url='http://cknowledge.org/repo/web.php?template=cknowledge&action=index&module_uoa=wfe&native_action=show&native_module_uoa=program.optimization&scenario=155b6fa5a4012a93&highlight_uid='+rduid
            ck.out('')
            ck.out('You can see your results at the following URL:')
            ck.out('')
            ck.out(url)

    return {'return':0}
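
This example uses copy.deepcopy several times to keep shared dicts intact: xchoices=copy.deepcopy(choices) and pipeline=copy.deepcopy(rr) let the autotune step merge and mutate its own copies, and mmeta=copy.deepcopy(meta) lets entry-specific UIDs be added while meta itself stays a minimal search key for the remote repo lookup. A minimal sketch of that last pattern, with illustrative values in place of the collected platform features:

import copy

# Hypothetical values standing in for the platform features collected above
meta = {'cpu_name': 'generic-cpu', 'os_name': 'linux', 'nn_type': 'alexnet'}

# Deep-copy before adding entry-specific keys, so 'meta' stays a clean
# search key while 'mmeta' carries the extra, non-searchable fields
mmeta = copy.deepcopy(meta)
mmeta['platform_uid'] = '1234abcd'
mmeta['user'] = 'user@example.com'

assert 'platform_uid' not in meta   # the search key is unchanged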

Example 11

Project: ck-wa
Source File: module.py
View license
def run(i):
    """
    Input:  {
              (data_uoa)            - workload to run (see "ck list wa").

              (target)              - machine UOA (see "ck list machine")

              (record)              - if 'yes', record result in repository in 'experiment' standard
              (skip-record-raw)     - if 'yes', skip record raw results
              (overwrite)           - if 'yes', do not record date and time in result directory, but overwrite wa-results

              (repetitions)         - statistical repetitions (default=1), for now statistical analysis is not used (TBD)

              (config)              - customize config
              (params)              - workload params
              (scenario)            - use pre-defined scenario (see ck list wa-scenario)

              (keep)                - if 'yes', keep tmp file in workload (program) directory

              (cache)               - if 'yes', cache params (to automate runs)
              (cache_repo_uoa)      - repo UOA where to cache params

              (share)               - if 'yes', share benchmarking results with public cknowledge.org/repo server
                                      (our crowd-benchmarking demo)
              (exchange_repo)       - which repo to record/update info (remote-ck by default)
              (exchange_subrepo)    - if remote, remote repo UOA
              (scenario_module_uoa) - UOA of the scenario (to share results)
            }

    Output: {
              return       - return code =  0, if successful
                                         >  0, if error
              (error)      - error text if return > 0
            }

    """

    import os
    import copy
    import time
    import shutil

    o=i.get('out','')
    oo=''
    if o=='con': oo=o

    cur_dir=os.getcwd()

    # Check if any input has . and convert to dict
    for k in list(i.keys()):
        if k.find('.')>0:
            v=i[k]

            kk='##'+k.replace('.','#')

            del(i[k])

            r=ck.set_by_flat_key({'dict':i, 'key':kk, 'value':v})
            if r['return']>0: return r

    # Check if share
    share=i.get('share','')
    user=i.get('user','')
    smuoa=i.get('scenario_module_uoa','')
    if smuoa=='': smuoa=cfg['module_deps']['experiment.bench.workload.android']

    er=i.get('exchange_repo','')
    if er=='': er=ck.cfg['default_exchange_repo_uoa']

    esr=i.get('exchange_subrepo','')
    if esr=='': esr=ck.cfg['default_exchange_subrepo_uoa']

    # Get device and workload params
    config=i.get('config',{})
    params=i.get('params',{})

    # Check scenarios
    scenario=i.get('scenario','')
    if scenario=='': scenario='-'

    if scenario!='' and scenario!='-':
        r=ck.access({'action':'load',
                     'module_uoa':cfg['module_deps']['wa-scenario'],
                     'data_uoa':scenario})
        if r['return']>0: return r
        d=r['dict']

        r=ck.merge_dicts({'dict1':config, 'dict2':d.get('config',{})})
        if r['return']>0: return r

        r=ck.merge_dicts({'dict1':params, 'dict2':d.get('params',{})})
        if r['return']>0: return r

    # Check workload(s)
    duoa=i.get('data_uoa','')
    if duoa!='':
        duoa='wa-'+duoa

    r=ck.access({'action':'search',
                 'module_uoa':cfg['module_deps']['program'],
                 'add_meta':'yes',
                 'data_uoa':duoa,
                 'tags':'wa'})
    if r['return']>0: return r
    lst=r['lst']

    if len(lst)==0:
       return {'return':1, 'error':'workload is not specified or found'}

    record=i.get('record','')
    skip_record_raw=i.get('skip-record-raw','')
    overwrite=i.get('overwrite','')

    repetitions=i.get('repetitions','')
    if repetitions=='': repetitions=3
    repetitions=int(repetitions)

    cache=i.get('cache','')

    # Get target features
    target=i.get('target','')

    if target=='':
        # Check and possibly select target machines
        r=ck.search({'module_uoa':cfg['module_deps']['machine'], 'data_uoa':target, 'add_meta':'yes'})
        if r['return']>0: return r

        dlst=r['lst']

        # Prune search by only required devices
        rdat=['wa_linux', 'wa_android']

        xlst=[]

        if len(rdat)==0:
            xlst=dlst
        else:
            for q in dlst:
                if q.get('meta',{}).get('access_type','') in rdat:
                    xlst.append(q)

        if len(xlst)==0:
            return {'return':1, 'error':'no suitable target devices found (use "ck add machine" to register new target device)'}
        elif len(xlst)==1:
            target=xlst[0]['data_uoa']
        else:
            # SELECTOR *************************************
            ck.out('')
            ck.out('Please select target device to run your workloads on:')
            ck.out('')
            r=ck.select_uoa({'choices':xlst})
            if r['return']>0: return r
            target=r['choice']

    if target=='':
        return {'return':1, 'error':'--target machine is not specified (see "ck list machine")'}

    ck.out('')
    ck.out('Selected target machine: '+target)
    ck.out('')

    # Load target machine description
    r=ck.access({'action':'load',
                 'module_uoa':cfg['module_deps']['machine'],
                 'data_uoa':target})
    if r['return']>0: return r
    target_uoa=r['data_uoa']
    target_uid=r['data_uid']
    features=r['dict']['features']

    device_id=r['dict'].get('device_id','')

    fplat=features.get('platform',{})
    fos=features.get('os',{})
    fcpu=features.get('cpu',{})
    fgpu=features.get('gpu',{})

    plat_name=fplat.get('name','')
    os_name=fos.get('name','')
    cpu_name=fcpu.get('name','')
    if cpu_name=='': cpu_name='unknown-'+fcpu.get('cpu_abi','')
    gpu_name=fgpu.get('name','')
    sn=fos.get('serial_number','')

    # Iterate over workloads
    rrr={}

    cparams=copy.deepcopy(params)
    for wa in lst:
        # Reset dir
        os.chdir(cur_dir)

        # Reset params
        params=copy.deepcopy(cparams)

        duoa=wa['data_uoa']
        duid=wa['data_uid']
        dw=wa['meta']
        dp=wa['path']

        apk_name=dw.get('apk',{}).get('name','')

        ww=dw['wa_alias']

        # If cache, check if params already exist
        if cache=='yes':
            # Check extra
            cruoa=i.get('cache_repo_uoa','')

            # Attempt to load
            r=ck.access({'action':'load',
                         'module_uoa':cfg['module_deps']['wa-params'],
                         'data_uoa':duoa,
                         'repo_uoa':cruoa})
            if r['return']>0 and r['return']!=16:
                return r

            if r['return']==0:
                cruoa=r['repo_uid']

                rx=ck.merge_dicts({'dict1':params, 'dict2':r['dict'].get('params',{})})
                if rx['return']>0: return rx

        # Check params here (there is another place in pre-processing scripts
        #  to be able to run WA via program pipeline directly) 
        dparams=dw.get('params',{})

        if len(dparams)>0:
            ck.out('Parameters needed for this workload:')
            ck.out('')

        for k in sorted(dparams):
            x=dparams[k]

            ds=x.get('desc','')

            dv=params.get(k,None)
            if dv==None:
                dv=x.get('default',None)

            if dv!=None:
                ck.out(k+': '+str(dv))
            elif x.get('mandatory',False):
                r=ck.inp({'text':k+' ('+ds+'): '})
                if r['return']>0: return r
                dv=r['string'].strip()
                if dv=='':
                    dv=None

            if dv!=None:
                params[k]=dv

        # Cache params if required
        if cache=='yes':
            r=ck.access({'action':'update',
                         'module_uoa':cfg['module_deps']['wa-params'],
                         'data_uoa':duoa,
                         'repo_uoa':cruoa,
                         'dict':{'params':params},
                         'sort_keys':'yes',
                         'substitute':'yes',
                         'ignore_update':'yes'})
            if r['return']>0:
                return r

            if o=='con':
                ck.out('')
                ck.out('Parameters were cached in '+r['path']+' ...')

        # Prepare high-level experiment meta
        meta={'program_uoa':duoa,
              'program_uid':duid,
              'workload_name':ww,
              'cpu_name':cpu_name,
              'os_name':os_name,
              'plat_name':plat_name,
              'gpu_name':gpu_name,
              'scenario':scenario,
              'serial_number':sn}

        mmeta=copy.deepcopy(meta)
        mmeta['local_target_uoa']=target_uoa
        mmeta['local_target_uid']=target_uid

        if o=='con':
            ck.out(line)
            ck.out('Running workload '+ww+' (CK UOA='+duoa+') ...')

            time.sleep(1)

        aggregated_stats={} # Pre-load statistics ...

        result_path=''
        result_path0=''
        if skip_record_raw!='yes':
            if o=='con':
                ck.out('  Preparing wa_result entry to store raw results ...')

            ddd={'meta':mmeta}

            ii={'action':'search',
                'module_uoa':cfg['module_deps']['wa-result'],
                'search_dict':{'meta':meta}}
            rx=ck.access(ii)
            if rx['return']>0: return rx

            lst=rx['lst']

            if len(lst)==0:
                rx=ck.access({'action':'add',
                              'module_uoa':cfg['module_deps']['wa-result'],
                              'dict':ddd,
                              'sort_keys':'yes'})
                if rx['return']>0: return rx
                result_uid=rx['data_uid']
                result_path=rx['path']
            else:
                result_uid=lst[0]['data_uid']
                result_path=lst[0]['path']

                # Load entry
                rx=ck.access({'action':'load',
                              'module_uoa':cfg['module_deps']['wa-result'],
                              'data_uoa':result_uid})
                if rx['return']>0: return rx
                ddd=rx['dict']

            # Possible directory extension (date-time)
            result_path0=result_path
            if overwrite!='yes':
                rx=ck.get_current_date_time({})
                if rx['return']>0: return rx

                aa=rx['array']

                ady=str(aa['date_year'])
                adm=str(aa['date_month'])
                adm=('0'*(2-len(adm)))+adm
                add=str(aa['date_day'])
                add=('0'*(2-len(add)))+add
                ath=str(aa['time_hour'])
                ath=('0'*(2-len(ath)))+ath
                atm=str(aa['time_minute'])
                atm=('0'*(2-len(atm)))+atm
                ats=str(aa['time_second'])
                ats=('0'*(2-len(ats)))+ats

                pe=ady+adm+add+'-'+ath+atm+ats

                result_path=os.path.join(result_path,pe)
                if not os.path.isdir(result_path):
                    os.makedirs(result_path)

            # Record input
            finp=os.path.join(result_path,'ck-input.json')
            r=ck.save_json_to_file({'json_file':finp, 'dict':i})
            if r['return']>0: return r

            ff=os.path.join(result_path,'ck-platform-features.json')
            r=ck.save_json_to_file({'json_file':ff, 'dict':features})
            if r['return']>0: return r

            # Check stats ...
            fstat=os.path.join(result_path0,ffstat)
            if overwrite!='yes':
                # Check if file already exists (no check for parallel runs)
                if os.path.isfile(fstat):
                    r=ck.load_json_file({'json_file':fstat})
                    if r['return']==0:
                        aggregated_stats=r['dict']

        # Prepare CK pipeline for a given workload
        ii={'action':'pipeline',

            'module_uoa':cfg['module_deps']['program'],
            'data_uoa':duid,

            'target':target,
            'device_id':device_id,

            'prepare':'yes',

            'params':{'config':config,
                      'params':params},

            'no_state_check':'yes',
            'no_compiler_description':'yes',
            'skip_info_collection':'yes',
            'skip_calibration':'yes',
            'cpu_freq':'',
            'gpu_freq':'',
            'env_speed':'yes',
            'energy':'no',
            'skip_print_timers':'yes',
            'generate_rnd_tmp_dir':'yes',

            'env':{'CK_WA_RAW_RESULT_PATH':result_path},

            'out':oo}
        rr=ck.access(ii)
        if rr['return']>0: return rr

        fail=rr.get('fail','')
        if fail=='yes':
            return {'return':10, 'error':'pipeline failed ('+rr.get('fail_reason','')+')'}

        ready=rr.get('ready','')
        if ready!='yes':
            return {'return':11, 'error':'couldn\'t prepare universal CK program workflow'}

        state=rr['state']
        tmp_dir=state['tmp_dir']

        # Clean pipeline
        if 'ready' in rr: del(rr['ready'])
        if 'fail' in rr: del(rr['fail'])
        if 'return' in rr: del(rr['return'])

        pipeline=copy.deepcopy(rr)

        # Save pipeline
        if skip_record_raw!='yes':
            fpip=os.path.join(result_path,'ck-pipeline-in.json')
            r=ck.save_json_to_file({'json_file':fpip, 'dict':pipeline})
            if r['return']>0: return r

        # Run CK pipeline *****************************************************
        ii={'action':'autotune',
            'module_uoa':cfg['module_deps']['pipeline'],
            'data_uoa':cfg['module_deps']['program'],

            'device_id':device_id,

            'iterations':1,
            'repetitions':repetitions,

            'collect_all':'yes',
            'process_multi_keys':['##characteristics#*'],

            'tmp_dir':tmp_dir,

            'pipeline':pipeline,

            'stat_flat_dict':aggregated_stats,

            'record':record,

            'meta':meta,

            'tags':'wa',

            "features_keys_to_process":["##choices#*"],

            "record_params": {
              "search_point_by_features":"yes"
            },

            "record_dict":{"subview_uoa":"3d9a4f4b03b1b257"},

            'out':oo}

        rrr=ck.access(ii)
        if rrr['return']>0: return rrr

        ls=rrr.get('last_iteration_output',{})
        state=ls.get('state',{})
        xchoices=copy.deepcopy(ls.get('choices',{}))
        lsa=rrr.get('last_stat_analysis',{})
        lsad=lsa.get('dict_flat',{})

        # Not very clean - trying to remove passes ...
        xparams=xchoices.get('params',{}).get('params',{})
        to_be_deleted=[]
        for k in xparams:
            if k.find('pass')>=0:
                to_be_deleted.append(k)

        for k in to_be_deleted:
            del(xparams[k])

        ddd['choices']=xchoices

        features=ls.get('features',{})
        apk_ver=''
        if apk_name!='':
            apk_ver=features.get('apk',{}).get(apk_name,{}).get('versionName','')

        deps=ls.get('dependencies',{})
        wa_ver=deps.get('wa',{}).get('cus',{}).get('version','')

        # Update meta
        ddd['meta']['apk_name']=apk_name
        ddd['meta']['apk_version']=apk_ver
        ddd['meta']['wa_version']=wa_ver

        # Clean tmp dir
        tmp_dir=state.get('tmp_dir','')
        if dp!='' and tmp_dir!='' and i.get('keep','')!='yes':
            shutil.rmtree(os.path.join(dp,tmp_dir))

        fail=ls.get('fail','')
        fail_reason=ls.get('fail_reason','')

        ch=ls.get('characteristics',{})

#        tet=ch.get('run',{}).get('total_execution_time',0)

        # Save pipeline
        ddd['state']={'fail':fail, 'fail_reason':fail_reason}
        ddd['characteristics']=ch

        if skip_record_raw!='yes':
            fpip=os.path.join(result_path,'ck-pipeline-out.json')
            r=ck.save_json_to_file({'json_file':fpip, 'dict':rrr})
            if r['return']>0: return r

            # Write stats ...
            r=ck.save_json_to_file({'json_file':fstat, 'dict':lsad})
            if r['return']>0: return r

            # Update meta
            rx=ck.access({'action':'update',
                          'module_uoa':cfg['module_deps']['wa-result'],
                          'data_uoa':result_uid,
                          'dict':ddd,
                          'substitute':'yes',
                          'sort_keys':'yes'})
            if rx['return']>0: return rx

        # Share results if crowd-benchmarking
        if share=='yes':
            ddd['user']=user

            if o=='con':
               ck.out('')
               ck.out('Saving results to the remote public repo ...')
               ck.out('')

            # Find remote entry
            rduid=''

            ii={'action':'search',
                'module_uoa':smuoa,
                'repo_uoa':er,
                'remote_repo_uoa':esr,
                'search_dict':{'meta':meta}}
            rx=ck.access(ii)

            lst=rx['lst']

            if len(lst)==1:
                rduid=lst[0]['data_uid']
            else:
                rx=ck.gen_uid({})
                if rx['return']>0: return rx
                rduid=rx['data_uid']

            # Update meta
            rx=ck.access({'action':'update',
                          'module_uoa':smuoa,
                          'data_uoa':rduid,
                          'repo_uoa':er,
                          'remote_repo_uoa':esr,
                          'dict':ddd,
                          'substitute':'yes',
                          'sort_keys':'yes'})
            if rx['return']>0: return rx

            # Push statistical characteristics
            if os.path.isfile(fstat):
                rx=ck.access({'action':'push',
                              'module_uoa':smuoa,
                              'data_uoa':rduid,
                              'repo_uoa':er,
                              'remote_repo_uoa':esr,
                              'filename':fstat,
                              'overwrite':'yes'})
                if rx['return']>0: return rx

            # Push latest results
            fx=os.path.join(result_path,'wa-output','results.json')
            if os.path.isfile(fx):
                rx=ck.access({'action':'push',
                              'module_uoa':smuoa,
                              'data_uoa':rduid,
                              'repo_uoa':er,
                              'remote_repo_uoa':esr,
                              'filename':fx,
                              'extra_path':'wa-output',
                              'overwrite':'yes'})
                if rx['return']>0: return rx

            # Info
            if o=='con':
                ck.out('Successfully recorded results in the remote repo (Entry UID='+rduid+')')

    return rrr
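
Here copy.deepcopy is used to snapshot the caller-supplied params (cparams=copy.deepcopy(params)) before the workload loop, and to restore that snapshot at the top of each iteration, so per-workload defaults and prompted values never leak into the next workload. The same idea protects the pipeline dict (pipeline=copy.deepcopy(rr)) before merge_dicts mutates its choices. A minimal sketch of the per-iteration reset, with illustrative workload names and keys:

import copy

base_params = {'iterations': 1}             # stands in for the caller-provided params
workloads = ['wa-dhrystone', 'wa-linpack']  # illustrative workload names

cparams = copy.deepcopy(base_params)        # pristine snapshot, as with cparams above
for w in workloads:
    params = copy.deepcopy(cparams)         # each workload starts from the snapshot
    params['workload'] = w                  # per-workload mutations stay local
    # ... run the workload with 'params' ...

assert 'workload' not in cparams            # the snapshot is never polluted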

Example 13

Project: ck-wa
Source File: module.py
View license
def run(i):
    """
    Input:  {
              (data_uoa)            - workload to run (see "ck list wa").

              (target)              - machine UOA (see "ck list machine")

              (record)              - if 'yes', record result in repository in 'experiment' standard
              (skip-record-raw)     - if 'yes', skip record raw results
              (overwrite)           - if 'yes', do not record date and time in result directory, but overwrite wa-results

              (repetitions)         - statistical repetitions (default=1), for now statistical analysis is not used (TBD)

              (config)              - customize config
              (params)              - workload params
              (scenario)            - use pre-defined scenario (see ck list wa-scenario)

              (keep)                - if 'yes', keep tmp file in workload (program) directory

              (cache)               - if 'yes', cache params (to automate runs)
              (cache_repo_uoa)      - repo UOA where to cache params

              (share)               - if 'yes', share benchmarking results with public cknowledge.org/repo server
                                      (our crowd-benchmarking demo)
              (exchange_repo)       - which repo to record/update info (remote-ck by default)
              (exchange_subrepo)    - if remote, remote repo UOA
              (scenario_module_uoa) - UOA of the scenario (to share results)
            }

    Output: {
              return       - return code =  0, if successful
                                         >  0, if error
              (error)      - error text if return > 0
            }

    """

    import os
    import copy
    import time
    import shutil

    o=i.get('out','')
    oo=''
    if o=='con': oo=o

    cur_dir=os.getcwd()

    # Check if any input has . and convert to dict
    for k in list(i.keys()):
        if k.find('.')>0:
            v=i[k]

            kk='##'+k.replace('.','#')

            del(i[k])

            r=ck.set_by_flat_key({'dict':i, 'key':kk, 'value':v})
            if r['return']>0: return r

    # Check if share
    share=i.get('share','')
    user=i.get('user','')
    smuoa=i.get('scenario_module_uoa','')
    if smuoa=='': smuoa=cfg['module_deps']['experiment.bench.workload.android']

    er=i.get('exchange_repo','')
    if er=='': er=ck.cfg['default_exchange_repo_uoa']

    esr=i.get('exchange_subrepo','')
    if esr=='': esr=ck.cfg['default_exchange_subrepo_uoa']

    # Get device and workload params
    config=i.get('config',{})
    params=i.get('params',{})

    # Check scenarios
    scenario=i.get('scenario','')
    if scenario=='': scenario='-'

    if scenario!='' and scenario!='-':
        r=ck.access({'action':'load',
                     'module_uoa':cfg['module_deps']['wa-scenario'],
                     'data_uoa':scenario})
        if r['return']>0: return r
        d=r['dict']

        r=ck.merge_dicts({'dict1':config, 'dict2':d.get('config',{})})
        if r['return']>0: return r

        r=ck.merge_dicts({'dict1':params, 'dict2':d.get('params',{})})
        if r['return']>0: return r

    # Check workload(s)
    duoa=i.get('data_uoa','')
    if duoa!='':
        duoa='wa-'+duoa

    r=ck.access({'action':'search',
                 'module_uoa':cfg['module_deps']['program'],
                 'add_meta':'yes',
                 'data_uoa':duoa,
                 'tags':'wa'})
    if r['return']>0: return r
    lst=r['lst']

    if len(lst)==0:
       return {'return':1, 'error':'workload is not specified or found'}

    record=i.get('record','')
    skip_record_raw=i.get('skip-record-raw','')
    overwrite=i.get('overwrite','')

    repetitions=i.get('repetitions','')
    if repetitions=='': repetitions=3
    repetitions=int(repetitions)

    cache=i.get('cache','')

    # Get target features
    target=i.get('target','')

    if target=='':
        # Check and possibly select target machines
        r=ck.search({'module_uoa':cfg['module_deps']['machine'], 'data_uoa':target, 'add_meta':'yes'})
        if r['return']>0: return r

        dlst=r['lst']

        # Prune search by only required devices
        rdat=['wa_linux', 'wa_android']

        xlst=[]

        if len(rdat)==0:
            xlst=dlst
        else:
            for q in dlst:
                if q.get('meta',{}).get('access_type','') in rdat:
                    xlst.append(q)

        if len(xlst)==0:
            return {'return':1, 'error':'no suitable target devices found (use "ck add machine" to register new target device)'}
        elif len(xlst)==1:
            target=xlst[0]['data_uoa']
        else:
            # SELECTOR *************************************
            ck.out('')
            ck.out('Please select target device to run your workloads on:')
            ck.out('')
            r=ck.select_uoa({'choices':xlst})
            if r['return']>0: return r
            target=r['choice']

    if target=='':
        return {'return':1, 'error':'--target machine is not specified (see "ck list machine")'}

    ck.out('')
    ck.out('Selected target machine: '+target)
    ck.out('')

    # Load target machine description
    r=ck.access({'action':'load',
                 'module_uoa':cfg['module_deps']['machine'],
                 'data_uoa':target})
    if r['return']>0: return r
    target_uoa=r['data_uoa']
    target_uid=r['data_uid']
    features=r['dict']['features']

    device_id=r['dict'].get('device_id','')

    fplat=features.get('platform',{})
    fos=features.get('os',{})
    fcpu=features.get('cpu',{})
    fgpu=features.get('gpu',{})

    plat_name=fplat.get('name','')
    os_name=fos.get('name','')
    cpu_name=fcpu.get('name','')
    if cpu_name=='': cpu_name='unknown-'+fcpu.get('cpu_abi','')
    gpu_name=fgpu.get('name','')
    sn=fos.get('serial_number','')

    # Iterate over workloads
    rrr={}

    cparams=copy.deepcopy(params)
    for wa in lst:
        # Reset dir
        os.chdir(cur_dir)

        # Reset params
        params=copy.deepcopy(cparams)

        duoa=wa['data_uoa']
        duid=wa['data_uid']
        dw=wa['meta']
        dp=wa['path']

        apk_name=dw.get('apk',{}).get('name','')

        ww=dw['wa_alias']

        # If cache, check if params already exist
        if cache=='yes':
            # Check extra
            cruoa=i.get('cache_repo_uoa','')

            # Attempt to load
            r=ck.access({'action':'load',
                         'module_uoa':cfg['module_deps']['wa-params'],
                         'data_uoa':duoa,
                         'repo_uoa':cruoa})
            if r['return']>0 and r['return']!=16:
                return r

            if r['return']==0:
                cruoa=r['repo_uid']

                rx=ck.merge_dicts({'dict1':params, 'dict2':r['dict'].get('params',{})})
                if rx['return']>0: return rx

        # Check params here (there is another place in pre-processing scripts
        #  to be able to run WA via program pipeline directly) 
        dparams=dw.get('params',{})

        if len(dparams)>0:
            ck.out('Parameters needed for this workload:')
            ck.out('')

        for k in sorted(dparams):
            x=dparams[k]

            ds=x.get('desc','')

            dv=params.get(k,None)
            if dv==None:
                dv=x.get('default',None)

            if dv!=None:
                ck.out(k+': '+str(dv))
            elif x.get('mandatory',False):
                r=ck.inp({'text':k+' ('+ds+'): '})
                if r['return']>0: return r
                dv=r['string'].strip()
                if dv=='':
                    dv=None

            if dv!=None:
                params[k]=dv

        # Cache params if required
        if cache=='yes':
            r=ck.access({'action':'update',
                         'module_uoa':cfg['module_deps']['wa-params'],
                         'data_uoa':duoa,
                         'repo_uoa':cruoa,
                         'dict':{'params':params},
                         'sort_keys':'yes',
                         'substitute':'yes',
                         'ignore_update':'yes'})
            if r['return']>0:
                return r

            if o=='con':
                ck.out('')
                ck.out('Parameters were cached in '+r['path']+' ...')

        # Prepare high-level experiment meta
        meta={'program_uoa':duoa,
              'program_uid':duid,
              'workload_name':ww,
              'cpu_name':cpu_name,
              'os_name':os_name,
              'plat_name':plat_name,
              'gpu_name':gpu_name,
              'scenario':scenario,
              'serial_number':sn}

        mmeta=copy.deepcopy(meta)
        mmeta['local_target_uoa']=target_uoa
        mmeta['local_target_uid']=target_uid

        if o=='con':
            ck.out(line)
            ck.out('Running workload '+ww+' (CK UOA='+duoa+') ...')

            time.sleep(1)

        aggregated_stats={} # Pre-load statistics ...

        result_path=''
        result_path0=''
        if skip_record_raw!='yes':
            if o=='con':
                ck.out('  Preparing wa_result entry to store raw results ...')

            ddd={'meta':mmeta}

            ii={'action':'search',
                'module_uoa':cfg['module_deps']['wa-result'],
                'search_dict':{'meta':meta}}
            rx=ck.access(ii)
            if rx['return']>0: return rx

            lst=rx['lst']

            if len(lst)==0:
                rx=ck.access({'action':'add',
                              'module_uoa':cfg['module_deps']['wa-result'],
                              'dict':ddd,
                              'sort_keys':'yes'})
                if rx['return']>0: return rx
                result_uid=rx['data_uid']
                result_path=rx['path']
            else:
                result_uid=lst[0]['data_uid']
                result_path=lst[0]['path']

                # Load entry
                rx=ck.access({'action':'load',
                              'module_uoa':cfg['module_deps']['wa-result'],
                              'data_uoa':result_uid})
                if rx['return']>0: return rx
                ddd=rx['dict']

            # Possible directory extension (date-time)
            result_path0=result_path
            if overwrite!='yes':
                rx=ck.get_current_date_time({})
                if rx['return']>0: return rx

                aa=rx['array']

                ady=str(aa['date_year'])
                adm=str(aa['date_month'])
                adm=('0'*(2-len(adm)))+adm
                add=str(aa['date_day'])
                add=('0'*(2-len(add)))+add
                ath=str(aa['time_hour'])
                ath=('0'*(2-len(ath)))+ath
                atm=str(aa['time_minute'])
                atm=('0'*(2-len(atm)))+atm
                ats=str(aa['time_second'])
                ats=('0'*(2-len(ats)))+ats

                pe=ady+adm+add+'-'+ath+atm+ats

                result_path=os.path.join(result_path,pe)
                if not os.path.isdir(result_path):
                    os.makedirs(result_path)

            # Record input
            finp=os.path.join(result_path,'ck-input.json')
            r=ck.save_json_to_file({'json_file':finp, 'dict':i})
            if r['return']>0: return r

            ff=os.path.join(result_path,'ck-platform-features.json')
            r=ck.save_json_to_file({'json_file':ff, 'dict':features})
            if r['return']>0: return r

            # Check stats ...
            fstat=os.path.join(result_path0,ffstat)
            if overwrite!='yes':
                # Check if file already exists (no check for parallel runs)
                if os.path.isfile(fstat):
                    r=ck.load_json_file({'json_file':fstat})
                    if r['return']==0:
                        aggregated_stats=r['dict']

        # Prepare CK pipeline for a given workload
        ii={'action':'pipeline',

            'module_uoa':cfg['module_deps']['program'],
            'data_uoa':duid,

            'target':target,
            'device_id':device_id,

            'prepare':'yes',

            'params':{'config':config,
                      'params':params},

            'no_state_check':'yes',
            'no_compiler_description':'yes',
            'skip_info_collection':'yes',
            'skip_calibration':'yes',
            'cpu_freq':'',
            'gpu_freq':'',
            'env_speed':'yes',
            'energy':'no',
            'skip_print_timers':'yes',
            'generate_rnd_tmp_dir':'yes',

            'env':{'CK_WA_RAW_RESULT_PATH':result_path},

            'out':oo}
        rr=ck.access(ii)
        if rr['return']>0: return rr

        fail=rr.get('fail','')
        if fail=='yes':
            return {'return':10, 'error':'pipeline failed ('+rr.get('fail_reason','')+')'}

        ready=rr.get('ready','')
        if ready!='yes':
            return {'return':11, 'error':'couldn\'t prepare universal CK program workflow'}

        state=rr['state']
        tmp_dir=state['tmp_dir']

        # Clean pipeline
        if 'ready' in rr: del(rr['ready'])
        if 'fail' in rr: del(rr['fail'])
        if 'return' in rr: del(rr['return'])

        pipeline=copy.deepcopy(rr)

        # Save pipeline
        if skip_record_raw!='yes':
            fpip=os.path.join(result_path,'ck-pipeline-in.json')
            r=ck.save_json_to_file({'json_file':fpip, 'dict':pipeline})
            if r['return']>0: return r

        # Run CK pipeline *****************************************************
        ii={'action':'autotune',
            'module_uoa':cfg['module_deps']['pipeline'],
            'data_uoa':cfg['module_deps']['program'],

            'device_id':device_id,

            'iterations':1,
            'repetitions':repetitions,

            'collect_all':'yes',
            'process_multi_keys':['##characteristics#*'],

            'tmp_dir':tmp_dir,

            'pipeline':pipeline,

            'stat_flat_dict':aggregated_stats,

            'record':record,

            'meta':meta,

            'tags':'wa',

            "features_keys_to_process":["##choices#*"],

            "record_params": {
              "search_point_by_features":"yes"
            },

            "record_dict":{"subview_uoa":"3d9a4f4b03b1b257"},

            'out':oo}

        rrr=ck.access(ii)
        if rrr['return']>0: return rrr

        ls=rrr.get('last_iteration_output',{})
        state=ls.get('state',{})
        xchoices=copy.deepcopy(ls.get('choices',{}))
        lsa=rrr.get('last_stat_analysis',{})
        lsad=lsa.get('dict_flat',{})

        # Not very clean - trying to remove passes ...
        xparams=xchoices.get('params',{}).get('params',{})
        to_be_deleted=[]
        for k in xparams:
            if k.find('pass')>=0:
                to_be_deleted.append(k)

        for k in to_be_deleted:
            del(xparams[k])

        ddd['choices']=xchoices

        features=ls.get('features',{})
        apk_ver=''
        if apk_name!='':
            apk_ver=features.get('apk',{}).get(apk_name,{}).get('versionName','')

        deps=ls.get('dependencies',{})
        wa_ver=deps.get('wa',{}).get('cus',{}).get('version','')

        # Update meta
        ddd['meta']['apk_name']=apk_name
        ddd['meta']['apk_version']=apk_ver
        ddd['meta']['wa_version']=wa_ver

        # Clean tmp dir
        tmp_dir=state.get('tmp_dir','')
        if dp!='' and tmp_dir!='' and i.get('keep','')!='yes':
            shutil.rmtree(os.path.join(dp,tmp_dir))

        fail=ls.get('fail','')
        fail_reason=ls.get('fail_reason','')

        ch=ls.get('characteristics',{})

#        tet=ch.get('run',{}).get('total_execution_time',0)

        # Save pipeline
        ddd['state']={'fail':fail, 'fail_reason':fail_reason}
        ddd['characteristics']=ch

        if skip_record_raw!='yes':
            fpip=os.path.join(result_path,'ck-pipeline-out.json')
            r=ck.save_json_to_file({'json_file':fpip, 'dict':rrr})
            if r['return']>0: return r

            # Write stats ...
            r=ck.save_json_to_file({'json_file':fstat, 'dict':lsad})
            if r['return']>0: return r

            # Update meta
            rx=ck.access({'action':'update',
                          'module_uoa':cfg['module_deps']['wa-result'],
                          'data_uoa':result_uid,
                          'dict':ddd,
                          'substitute':'yes',
                          'sort_keys':'yes'})
            if rx['return']>0: return rx

        # Share results if crowd-benchmarking
        if share=='yes':
            ddd['user']=user

            if o=='con':
               ck.out('')
               ck.out('Saving results to the remote public repo ...')
               ck.out('')

            # Find remote entry
            rduid=''

            ii={'action':'search',
                'module_uoa':smuoa,
                'repo_uoa':er,
                'remote_repo_uoa':esr,
                'search_dict':{'meta':meta}}
            rx=ck.access(ii)
            if rx['return']>0: return rx

            lst=rx['lst']

            if len(lst)==1:
                rduid=lst[0]['data_uid']
            else:
                rx=ck.gen_uid({})
                if rx['return']>0: return rx
                rduid=rx['data_uid']

            # Update meta
            rx=ck.access({'action':'update',
                          'module_uoa':smuoa,
                          'data_uoa':rduid,
                          'repo_uoa':er,
                          'remote_repo_uoa':esr,
                          'dict':ddd,
                          'substitute':'yes',
                          'sort_keys':'yes'})
            if rx['return']>0: return rx

            # Push statistical characteristics
            if os.path.isfile(fstat):
                rx=ck.access({'action':'push',
                              'module_uoa':smuoa,
                              'data_uoa':rduid,
                              'repo_uoa':er,
                              'remote_repo_uoa':esr,
                              'filename':fstat,
                              'overwrite':'yes'})
                if rx['return']>0: return rx

            # Push latest results
            fx=os.path.join(result_path,'wa-output','results.json')
            if os.path.isfile(fx):
                rx=ck.access({'action':'push',
                              'module_uoa':smuoa,
                              'data_uoa':rduid,
                              'repo_uoa':er,
                              'remote_repo_uoa':esr,
                              'filename':fx,
                              'extra_path':'wa-output',
                              'overwrite':'yes'})
                if rx['return']>0: return rx

            # Info
            if o=='con':
                ck.out('Successfully recorded results in the remote repo (Entry UID='+rduid+')')

    return rrr
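
The example above takes a deep snapshot of the shared parameter dictionary before the workload loop (cparams=copy.deepcopy(params)) and restores a fresh copy at the top of every iteration (params=copy.deepcopy(cparams)), so values prompted or cached for one workload never leak into the next. Below is a minimal sketch of that snapshot-and-restore idiom; the dictionary keys and workload names are illustrative only, not part of the original module:

    import copy

    params = {'iterations': 10}                # shared defaults (illustrative)
    cparams = copy.deepcopy(params)            # snapshot taken once, before the loop

    for workload in ('dhrystone', 'memcpy'):   # illustrative workload names
        params = copy.deepcopy(cparams)        # fresh, independent copy per workload
        params[workload] = {'threads': 4}      # per-workload mutation stays local

    print(cparams)   # {'iterations': 10} -- the snapshot is untouched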

Example 14

Project: libpgm
Source File: pgmlearner.py
View license
    def discrete_condind(self, data, X, Y, U):
        '''
        Test how independent a variable *X* and a variable *Y* are in a discrete data set given by *data*, where the independence is conditioned on a set of variables given by *U*. This method works by assuming as a null hypothesis that the variables are conditionally independent on *U*, and thus that:

        .. math::

            P(X, Y, U) = P(U) \\cdot P(X|U) \\cdot P(Y|U) 

        It tests the deviance of the data from this null hypothesis, returning the result of a chi-square test and a p-value.

        Arguments:
            1. *data* -- An array of dicts containing samples from the network in {vertex: value} format. Example::

                    [
                        {
                            'Grade': 'B',
                            'SAT': 'lowscore',
                            ...
                        },
                        ...
                    ]
            2. *X* -- A variable whose dependence on Y we are testing given U.
            3. *Y* -- A variable whose dependence on X we are testing given U.
            4. *U* -- A list of variables that are given.

        Returns:
            1. *chi* -- The result of the chi-squared test on the data. This is a
                   measure of the deviance of the actual distribution of X and
                   Y given U from the expected distribution of X and Y given U.
                   Since the null hypothesis is that X and Y are independent 
                   given U, the expected distribution is that :math:`P(X, Y, U) =
                   P(U) P(X | U) P (Y | U)`.
            2. *pval* -- The p-value of the test, meaning the probability of
                    attaining a chi-square result as extreme as or more extreme
                    than the one found, assuming that the null hypothesis is
                    true. (e.g., a p-value of .05 means that if X and Y were 
                    independent given U, the chance of getting a chi-squared
                    result this high or higher is .05)
            3. *U* -- The 'witness' of X and Y's independence. This is the variable
                 that, when it is known, leaves X and Y independent.

        For more information see Koller et al. 790.
        
        '''
        # find possible outcomes and store
        _outcomes = dict()
        for key in data[0].keys():
            _outcomes[key] = [data[0][key]]
        for sample in data:
            for key in _outcomes.keys():
                if _outcomes[key].count(sample[key]) == 0:
                    _outcomes[key].append(sample[key])

        # store number of outcomes for X, Y, and U
        Xnumoutcomes = len(_outcomes[X])
        Ynumoutcomes = len(_outcomes[Y])
        Unumoutcomes = []
        for val in U:
            Unumoutcomes.append(len(_outcomes[val]))

        # calculate P(U) -- the distribution of U
        PU = 1
        
        # define helper function to add a dimension to an array recursively
        def add_dimension_to_array(mdarray, size):
            if isinstance(mdarray, list):
                for h in range(len(mdarray)):
                    mdarray[h] = add_dimension_to_array(mdarray[h], size)
                return mdarray
            else:
                mdarray = [0 for _ in range(size)]
                return mdarray

        # make PU the right size
        for size in Unumoutcomes:
            PU = add_dimension_to_array(PU, size)

        # fill with data
        if (len(U) > 0):
            for sample in data:
                tmp = PU
                for x in range(len(U)-1):
                    Uindex = _outcomes[U[x]].index(sample[U[x]])
                    tmp = tmp[Uindex]
                lastindex = _outcomes[U[-1]].index(sample[U[-1]])
                tmp[lastindex] += 1

        # calculate P(X, U) -- the distribution of X and U
        PXandU = [0 for _ in range(Xnumoutcomes)]
        for size in Unumoutcomes:
            PXandU = add_dimension_to_array(PXandU, size)

        for sample in data:
            Xindex = _outcomes[X].index(sample[X])
            if len(U) > 0: 
                tmp = PXandU[Xindex]
                for x in range(len(U)-1):
                    Uindex = _outcomes[U[x]].index(sample[U[x]])
                    tmp = tmp[Uindex]
                lastindex = _outcomes[U[-1]].index(sample[U[-1]])
                tmp[lastindex] += 1
            else:
                PXandU[Xindex] += 1

        # calculate P(Y, U) -- the distribution of Y and U
        PYandU = [0 for _ in range(Ynumoutcomes)]
        for size in Unumoutcomes:
            PYandU = add_dimension_to_array(PYandU, size)
        for sample in data:
            Yindex = _outcomes[Y].index(sample[Y])
            if len(U) > 0: 
                tmp = PYandU[Yindex]
                for x in range(len(U)-1):
                    Uindex = _outcomes[U[x]].index(sample[U[x]])
                    tmp = tmp[Uindex]
                lastindex = _outcomes[U[-1]].index(sample[U[-1]])
                tmp[lastindex] += 1
            else:
                PYandU[Yindex] += 1

        # assemble P(U)P(X|U)P(Y|U) -- the expected distribution if X and Y are
        # independent given U.
        expected = [[ 0 for _ in range(Ynumoutcomes)] for __ in range(Xnumoutcomes)] 

        # define helper function to multiply the entries of two matrices
        def multiply_entries(matrixa, matrixb):
            matrix1 = copy.deepcopy(matrixa)
            matrix2 = copy.deepcopy(matrixb)
            if isinstance(matrix1, list):
                for h in range(len(matrix1)):
                    matrix1[h] = multiply_entries(matrix1[h], matrix2[h])
                return matrix1
            else:
                return (matrix1 * matrix2)

        # define helper function to divide the entries of two matrices
        def divide_entries(matrixa, matrixb):
            matrix1 = copy.deepcopy(matrixa)
            matrix2 = copy.deepcopy(matrixb)
            if isinstance(matrix1, list):
                for h in range(len(matrix1)):
                    matrix1[h] = divide_entries(matrix1[h], matrix2[h])
                return matrix1
            else:
                return (matrix1 / float(matrix2))

        # combine known graphs to calculate P(U)P(X|U)P(Y|U)
        for x in range(Xnumoutcomes):
            for y in range(Ynumoutcomes):
                product = multiply_entries(PXandU[x], PYandU[y])
                final = divide_entries(product, PU)
                expected[x][y] = final

        # find P(XYU) -- the actual distribution of X, Y, and U -- in sample
        PXYU = [[ 0 for _ in range(Ynumoutcomes)] for __ in range(Xnumoutcomes)] 
        for size in Unumoutcomes:
            PXYU = add_dimension_to_array(PXYU, size)
        
        for sample in data:
            Xindex = _outcomes[X].index(sample[X])
            Yindex = _outcomes[Y].index(sample[Y])
            if len(U) > 0:
                tmp = PXYU[Xindex][Yindex]
                for x in range(len(U)-1):
                    Uindex = _outcomes[U[x]].index(sample[U[x]])
                    tmp = tmp[Uindex]
                lastindex = _outcomes[U[-1]].index(sample[U[-1]])
                tmp[lastindex] += 1
            else:
                PXYU[Xindex][Yindex] += 1 

        # use scipy's chisquare to determine the deviance of the evidence
        a = np.array(expected)
        a = a.flatten()
        b = np.array(PXYU)
        b = b.flatten()

        # delete entries with value 0 (they mess up the chisquare function)
        for i in reversed(range(b.size)):
            if (b[i] == 0):
                if i != 0:
                    a.itemset(i-1, a[i-1]+a[i])
                a = np.delete(a, i)
                b = np.delete(b, i)

        # run chi-squared
        chi, pv = chisquare(a, b)

        # return chi-squared result, p-value for that result, and witness
        return chi, pv, U
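
In the helper functions multiply_entries and divide_entries above, copy.deepcopy takes defensive copies of the nested count tables before the element-wise recursion assigns into them, so PU, PXandU and PYandU survive intact and can be reused for every (x, y) pair. A self-contained sketch of the multiply helper with tiny illustrative inputs:

    import copy

    def multiply_entries(matrixa, matrixb):
        # deep-copy first, so the recursive in-place assignments below
        # never touch the caller's nested lists
        matrix1 = copy.deepcopy(matrixa)
        matrix2 = copy.deepcopy(matrixb)
        if isinstance(matrix1, list):
            for h in range(len(matrix1)):
                matrix1[h] = multiply_entries(matrix1[h], matrix2[h])
            return matrix1
        return matrix1 * matrix2

    a = [[1, 2], [3, 4]]
    b = [[5, 6], [7, 8]]
    print(multiply_entries(a, b))   # [[5, 12], [21, 32]]
    print(a)                        # [[1, 2], [3, 4]] -- inputs unchanged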

Example 17

Project: letsencrypt-nosudo
Source File: sign_csr.py
View license
def sign_csr(pubkey, csr, email=None, file_based=False):
    """Use the ACME protocol to get an ssl certificate signed by a
    certificate authority.

    :param string pubkey: Path to the user account public key.
    :param string csr: Path to the certificate signing request.
    :param string email: An optional user account contact email
                         (defaults to webmaster@<shortest_domain>)
    :param bool file_based: An optional flag indicating that the
                            hosting should be file-based rather
                            than providing a simple python HTTP
                            server.

    :returns: Signed Certificate (PEM format)
    :rtype: string

    """
    #CA = "https://acme-staging.api.letsencrypt.org"
    CA = "https://acme-v01.api.letsencrypt.org"
    TERMS = "https://letsencrypt.org/documents/LE-SA-v1.1.1-August-1-2016.pdf"
    nonce_req = urllib2.Request("{0}/directory".format(CA))
    nonce_req.get_method = lambda : 'HEAD'

    def _b64(b):
        "Shortcut function to go from bytes to jwt base64 string"
        return base64.urlsafe_b64encode(b).replace("=", "")

    # Step 1: Get account public key
    sys.stderr.write("Reading pubkey file...\n")
    proc = subprocess.Popen(["openssl", "rsa", "-pubin", "-in", pubkey, "-noout", "-text"],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = proc.communicate()
    if proc.returncode != 0:
        raise IOError("Error loading {0}".format(pubkey))
    pub_hex, pub_exp = re.search(
        "Modulus(?: \((?:2048|4096) bit\)|)\:\s+00:([a-f0-9\:\s]+?)Exponent\: ([0-9]+)",
        out, re.MULTILINE|re.DOTALL).groups()
    pub_mod = binascii.unhexlify(re.sub("(\s|:)", "", pub_hex))
    pub_mod64 = _b64(pub_mod)
    pub_exp = int(pub_exp)
    pub_exp = "{0:x}".format(pub_exp)
    pub_exp = "0{0}".format(pub_exp) if len(pub_exp) % 2 else pub_exp
    pub_exp = binascii.unhexlify(pub_exp)
    pub_exp64 = _b64(pub_exp)
    header = {
        "alg": "RS256",
        "jwk": {
            "e": pub_exp64,
            "kty": "RSA",
            "n": pub_mod64,
        },
    }
    accountkey_json = json.dumps(header['jwk'], sort_keys=True, separators=(',', ':'))
    thumbprint = _b64(hashlib.sha256(accountkey_json).digest())
    sys.stderr.write("Found public key!\n")

    # Step 2: Get the domain names to be certified
    sys.stderr.write("Reading csr file...\n")
    proc = subprocess.Popen(["openssl", "req", "-in", csr, "-noout", "-text"],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = proc.communicate()
    if proc.returncode != 0:
        raise IOError("Error loading {0}".format(csr))
    domains = set([])
    common_name = re.search("Subject:.*? CN=([^\s,;/]+)", out)
    if common_name is not None:
        domains.add(common_name.group(1))
    subject_alt_names = re.search("X509v3 Subject Alternative Name: \n +([^\n]+)\n", out, re.MULTILINE|re.DOTALL)
    if subject_alt_names is not None:
        for san in subject_alt_names.group(1).split(", "):
            if san.startswith("DNS:"):
                domains.add(san[4:])
    sys.stderr.write("Found domains {0}\n".format(", ".join(domains)))

    # Step 3: Ask user for contact email
    if not email:
        default_email = "[email protected]{0}".format(min(domains, key=len))
        stdout = sys.stdout
        sys.stdout = sys.stderr
        input_email = raw_input("STEP 1: What is your contact email? ({0}) ".format(default_email))
        email = input_email if input_email else default_email
        sys.stdout = stdout

    # Step 4: Generate the payloads that need to be signed
    # registration
    sys.stderr.write("Building request payloads...\n")
    reg_nonce = urllib2.urlopen(nonce_req).headers['Replay-Nonce']
    reg_raw = json.dumps({
        "resource": "new-reg",
        "contact": ["mailto:{0}".format(email)],
        "agreement": TERMS,
    }, sort_keys=True, indent=4)
    reg_b64 = _b64(reg_raw)
    reg_protected = copy.deepcopy(header)
    reg_protected.update({"nonce": reg_nonce})
    reg_protected64 = _b64(json.dumps(reg_protected, sort_keys=True, indent=4))
    reg_file = tempfile.NamedTemporaryFile(dir=".", prefix="register_", suffix=".json")
    reg_file.write("{0}.{1}".format(reg_protected64, reg_b64))
    reg_file.flush()
    reg_file_name = os.path.basename(reg_file.name)
    reg_file_sig = tempfile.NamedTemporaryFile(dir=".", prefix="register_", suffix=".sig")
    reg_file_sig_name = os.path.basename(reg_file_sig.name)

    # need signature for each domain identifiers
    ids = []
    for domain in domains:
        sys.stderr.write("Building request for {0}...\n".format(domain))
        id_nonce = urllib2.urlopen(nonce_req).headers['Replay-Nonce']
        id_raw = json.dumps({
            "resource": "new-authz",
            "identifier": {
                "type": "dns",
                "value": domain,
            },
        }, sort_keys=True)
        id_b64 = _b64(id_raw)
        id_protected = copy.deepcopy(header)
        id_protected.update({"nonce": id_nonce})
        id_protected64 = _b64(json.dumps(id_protected, sort_keys=True, indent=4))
        id_file = tempfile.NamedTemporaryFile(dir=".", prefix="domain_", suffix=".json")
        id_file.write("{0}.{1}".format(id_protected64, id_b64))
        id_file.flush()
        id_file_name = os.path.basename(id_file.name)
        id_file_sig = tempfile.NamedTemporaryFile(dir=".", prefix="domain_", suffix=".sig")
        id_file_sig_name = os.path.basename(id_file_sig.name)
        ids.append({
            "domain": domain,
            "protected64": id_protected64,
            "data64": id_b64,
            "file": id_file,
            "file_name": id_file_name,
            "sig": id_file_sig,
            "sig_name": id_file_sig_name,
        })

    # need signature for the final certificate issuance
    sys.stderr.write("Building request for CSR...\n")
    proc = subprocess.Popen(["openssl", "req", "-in", csr, "-outform", "DER"],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    csr_der, err = proc.communicate()
    csr_der64 = _b64(csr_der)
    csr_nonce = urllib2.urlopen(nonce_req).headers['Replay-Nonce']
    csr_raw = json.dumps({
        "resource": "new-cert",
        "csr": csr_der64,
    }, sort_keys=True, indent=4)
    csr_b64 = _b64(csr_raw)
    csr_protected = copy.deepcopy(header)
    csr_protected.update({"nonce": csr_nonce})
    csr_protected64 = _b64(json.dumps(csr_protected, sort_keys=True, indent=4))
    csr_file = tempfile.NamedTemporaryFile(dir=".", prefix="cert_", suffix=".json")
    csr_file.write("{0}.{1}".format(csr_protected64, csr_b64))
    csr_file.flush()
    csr_file_name = os.path.basename(csr_file.name)
    csr_file_sig = tempfile.NamedTemporaryFile(dir=".", prefix="cert_", suffix=".sig")
    csr_file_sig_name = os.path.basename(csr_file_sig.name)

    # Step 5: Ask the user to sign the registration and requests
    sys.stderr.write("""\
STEP 2: You need to sign some files (replace 'user.key' with your user private key).

openssl dgst -sha256 -sign user.key -out {0} {1}
{2}
openssl dgst -sha256 -sign user.key -out {3} {4}

""".format(
    reg_file_sig_name, reg_file_name,
    "\n".join("openssl dgst -sha256 -sign user.key -out {0} {1}".format(i['sig_name'], i['file_name']) for i in ids),
    csr_file_sig_name, csr_file_name))

    stdout = sys.stdout
    sys.stdout = sys.stderr
    raw_input("Press Enter when you've run the above commands in a new terminal window...")
    sys.stdout = stdout

    # Step 6: Load the signatures
    reg_file_sig.seek(0)
    reg_sig64 = _b64(reg_file_sig.read())
    for n, i in enumerate(ids):
        i['sig'].seek(0)
        i['sig64'] = _b64(i['sig'].read())

    # Step 7: Register the user
    sys.stderr.write("Registering {0}...\n".format(email))
    reg_data = json.dumps({
        "header": header,
        "protected": reg_protected64,
        "payload": reg_b64,
        "signature": reg_sig64,
    }, sort_keys=True, indent=4)
    reg_url = "{0}/acme/new-reg".format(CA)
    try:
        resp = urllib2.urlopen(reg_url, reg_data)
        result = json.loads(resp.read())
    except urllib2.HTTPError as e:
        err = e.read()
        # skip already registered accounts
        if "Registration key is already in use" in err:
            sys.stderr.write("Already registered. Skipping...\n")
        else:
            sys.stderr.write("Error: reg_data:\n")
            sys.stderr.write("POST {0}\n".format(reg_url))
            sys.stderr.write(reg_data)
            sys.stderr.write("\n")
            sys.stderr.write(err)
            sys.stderr.write("\n")
            raise

    # Step 8: Request challenges for each domain
    responses = []
    tests = []
    for n, i in enumerate(ids):
        sys.stderr.write("Requesting challenges for {0}...\n".format(i['domain']))
        id_data = json.dumps({
            "header": header,
            "protected": i['protected64'],
            "payload": i['data64'],
            "signature": i['sig64'],
        }, sort_keys=True, indent=4)
        id_url = "{0}/acme/new-authz".format(CA)
        try:
            resp = urllib2.urlopen(id_url, id_data)
            result = json.loads(resp.read())
        except urllib2.HTTPError as e:
            sys.stderr.write("Error: id_data:\n")
            sys.stderr.write("POST {0}\n".format(id_url))
            sys.stderr.write(id_data)
            sys.stderr.write("\n")
            sys.stderr.write(e.read())
            sys.stderr.write("\n")
            raise
        challenge = [c for c in result['challenges'] if c['type'] == "http-01"][0]
        keyauthorization = "{0}.{1}".format(challenge['token'], thumbprint)

        # challenge request
        sys.stderr.write("Building challenge responses for {0}...\n".format(i['domain']))
        test_nonce = urllib2.urlopen(nonce_req).headers['Replay-Nonce']
        test_raw = json.dumps({
            "resource": "challenge",
            "keyAuthorization": keyauthorization,
        }, sort_keys=True, indent=4)
        test_b64 = _b64(test_raw)
        test_protected = copy.deepcopy(header)
        test_protected.update({"nonce": test_nonce})
        test_protected64 = _b64(json.dumps(test_protected, sort_keys=True, indent=4))
        test_file = tempfile.NamedTemporaryFile(dir=".", prefix="challenge_", suffix=".json")
        test_file.write("{0}.{1}".format(test_protected64, test_b64))
        test_file.flush()
        test_file_name = os.path.basename(test_file.name)
        test_file_sig = tempfile.NamedTemporaryFile(dir=".", prefix="challenge_", suffix=".sig")
        test_file_sig_name = os.path.basename(test_file_sig.name)
        tests.append({
            "uri": challenge['uri'],
            "protected64": test_protected64,
            "data64": test_b64,
            "file": test_file,
            "file_name": test_file_name,
            "sig": test_file_sig,
            "sig_name": test_file_sig_name,
        })

        # challenge response for server
        responses.append({
            "uri": ".well-known/acme-challenge/{0}".format(challenge['token']),
            "data": keyauthorization,
        })

    # Step 9: Ask the user to sign the challenge responses
    sys.stderr.write("""\
STEP 3: You need to sign some more files (replace 'user.key' with your user private key).

{0}

""".format(
    "\n".join("openssl dgst -sha256 -sign user.key -out {0} {1}".format(
        i['sig_name'], i['file_name']) for i in tests)))

    stdout = sys.stdout
    sys.stdout = sys.stderr
    raw_input("Press Enter when you've run the above commands in a new terminal window...")
    sys.stdout = stdout

    # Step 10: Load the response signatures
    for n, i in enumerate(ids):
        tests[n]['sig'].seek(0)
        tests[n]['sig64'] = _b64(tests[n]['sig'].read())

    # Step 11: Ask the user to host the token on their server
    for n, i in enumerate(ids):
        if file_based:
            sys.stderr.write("""\
STEP {0}: Please update your server to serve the following file at this URL:

--------------
URL: http://{1}/{2}
File contents: \"{3}\"
--------------

Notes:
- Do not include the quotes in the file.
- The file should be one line without any spaces.

""".format(n + 4, i['domain'], responses[n]['uri'], responses[n]['data']))

            stdout = sys.stdout
            sys.stdout = sys.stderr
            raw_input("Press Enter when you've got the file hosted on your server...")
            sys.stdout = stdout
        else:
            sys.stderr.write("""\
STEP {0}: You need to run this command on {1} (don't stop the python command until the next step).

sudo python -c "import BaseHTTPServer; \\
    h = BaseHTTPServer.BaseHTTPRequestHandler; \\
    h.do_GET = lambda r: r.send_response(200) or r.end_headers() or r.wfile.write('{2}'); \\
    s = BaseHTTPServer.HTTPServer(('0.0.0.0', 80), h); \\
    s.serve_forever()"

""".format(n + 4, i['domain'], responses[n]['data']))

            stdout = sys.stdout
            sys.stdout = sys.stderr
            raw_input("Press Enter when you've got the python command running on your server...")
            sys.stdout = stdout

        # Step 12: Let the CA know you're ready for the challenge
        sys.stderr.write("Requesting verification for {0}...\n".format(i['domain']))
        test_data = json.dumps({
            "header": header,
            "protected": tests[n]['protected64'],
            "payload": tests[n]['data64'],
            "signature": tests[n]['sig64'],
        }, sort_keys=True, indent=4)
        test_url = tests[n]['uri']
        try:
            resp = urllib2.urlopen(test_url, test_data)
            test_result = json.loads(resp.read())
        except urllib2.HTTPError as e:
            sys.stderr.write("Error: test_data:\n")
            sys.stderr.write("POST {0}\n".format(test_url))
            sys.stderr.write(test_data)
            sys.stderr.write("\n")
            sys.stderr.write(e.read())
            sys.stderr.write("\n")
            raise

        # Step 13: Wait for CA to mark test as valid
        sys.stderr.write("Waiting for {0} challenge to pass...\n".format(i['domain']))
        while True:
            try:
                resp = urllib2.urlopen(test_url)
                challenge_status = json.loads(resp.read())
            except urllib2.HTTPError as e:
                sys.stderr.write("Error: test_data:\n")
                sys.stderr.write("GET {0}\n".format(test_url))
                sys.stderr.write(test_data)
                sys.stderr.write("\n")
                sys.stderr.write(e.read())
                sys.stderr.write("\n")
                raise
            if challenge_status['status'] == "pending":
                time.sleep(2)
            elif challenge_status['status'] == "valid":
                sys.stderr.write("Passed {0} challenge!\n".format(i['domain']))
                break
            else:
                raise KeyError("'{0}' challenge did not pass: {1}".format(i['domain'], challenge_status))

    # Step 14: Get the certificate signed
    sys.stderr.write("Requesting signature...\n")
    csr_file_sig.seek(0)
    csr_sig64 = _b64(csr_file_sig.read())
    csr_data = json.dumps({
        "header": header,
        "protected": csr_protected64,
        "payload": csr_b64,
        "signature": csr_sig64,
    }, sort_keys=True, indent=4)
    csr_url = "{0}/acme/new-cert".format(CA)
    try:
        resp = urllib2.urlopen(csr_url, csr_data)
        signed_der = resp.read()
    except urllib2.HTTPError as e:
        sys.stderr.write("Error: csr_data:\n")
        sys.stderr.write("POST {0}\n".format(csr_url))
        sys.stderr.write(csr_data)
        sys.stderr.write("\n")
        sys.stderr.write(e.read())
        sys.stderr.write("\n")
        raise

    # Step 15: Convert the signed cert from DER to PEM
    sys.stderr.write("Certificate signed!\n")

    if file_based:
        sys.stderr.write("You can remove the acme-challenge file from your webserver now.\n")
    else:
        sys.stderr.write("You can stop running the python command on your server (Ctrl+C works).\n")

    signed_der64 = base64.b64encode(signed_der)
    signed_pem = """\
-----BEGIN CERTIFICATE-----
{0}
-----END CERTIFICATE-----
""".format("\n".join(textwrap.wrap(signed_der64, 64)))

    return signed_pem
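
Each ACME request above clones the base JWS header with copy.deepcopy(header) before stamping in a fresh Replay-Nonce, so the account template (including its nested 'jwk' mapping) is never modified between requests. A minimal sketch of that pattern; the header values are placeholders, not real key material:

    import copy, json

    # base JWS header template (placeholder values)
    header = {"alg": "RS256", "jwk": {"kty": "RSA", "e": "AQAB", "n": "placeholder"}}

    def protected_for(nonce):
        protected = copy.deepcopy(header)      # independent copy, nested dict included
        protected.update({"nonce": nonce})     # per-request nonce stays out of the template
        return json.dumps(protected, sort_keys=True)

    print(protected_for("nonce-1"))
    print("nonce" in header)   # False -- the shared template stays clean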

Example 18

Project: ck-caffe
Source File: module.py
View license
def crowdsource(i):
    """
    Input:  {
              (local)               - if 'yes', local crowd-benchmarking, instead of public
              (user)                - force different user ID/email for demos

              (choices)             - force different choices to program pipeline

              (repetitions)         - statistical repetitions (default=1), for now statistical analysis is not used (TBD)
            }

    Output: {
              return       - return code =  0, if successful
                                         >  0, if error
              (error)      - error text if return > 0
            }

    """

    import copy
    import os

    # Setting output
    o=i.get('out','')
    oo=''
    if o=='con': oo='con'

    quiet=i.get('quiet','')

    er=i.get('exchange_repo','')
    if er=='': er=ck.cfg['default_exchange_repo_uoa']
    esr=i.get('exchange_subrepo','')
    if esr=='': esr=ck.cfg['default_exchange_subrepo_uoa']

    if i.get('local','')=='yes': 
       er='local'
       esr=''

    la=i.get('local_autotuning','')

    repetitions=i.get('repetitions','')
    if repetitions=='': repetitions=3
    repetitions=int(repetitions)

    record='no'

    # Check if any input has . and convert to dict
    for k in list(i.keys()):
        if k.find('.')>0:
            v=i[k]

            kk='##'+k.replace('.','#')

            del(i[k])

            r=ck.set_by_flat_key({'dict':i, 'key':kk, 'value':v})
            if r['return']>0: return r

    choices=i.get('choices',{})
    xchoices=copy.deepcopy(choices)

    # Get user 
    user=''

    mcfg={}
    ii={'action':'load',
        'module_uoa':'module',
        'data_uoa':cfg['module_deps']['program.optimization']}
    r=ck.access(ii)
    if r['return']==0:
       mcfg=r['dict']

       dcfg={}
       ii={'action':'load',
           'module_uoa':mcfg['module_deps']['cfg'],
           'data_uoa':mcfg['cfg_uoa']}
       r=ck.access(ii)
       if r['return']>0 and r['return']!=16: return r
       if r['return']!=16:
          dcfg=r['dict']

       user=dcfg.get('user_email','')

    # Initialize local environment for program optimization ***********************************************************
    pi=i.get('platform_info',{})
    if len(pi)==0:
       ii=copy.deepcopy(i)
       ii['action']='initialize'
       ii['module_uoa']=cfg['module_deps']['program.optimization']
       ii['data_uoa']='caffe'
       ii['exchange_repo']=er
       ii['exchange_subrepo']=esr
       ii['skip_welcome']='yes'
       ii['skip_log_wait']='yes'
       ii['crowdtuning_type']='caffe-crowd-benchmarking'
       r=ck.access(ii)
       if r['return']>0: return r

       pi=r['platform_info']
       user=r.get('user','')

    hos=pi['host_os_uoa']
    hosd=pi['host_os_dict']

    tos=pi['os_uoa']
    tosd=pi['os_dict']
    tbits=tosd.get('bits','')

    remote=tosd.get('remote','')

    tdid=pi['device_id']

    features=pi.get('features',{})

    fplat=features.get('platform',{})
    fos=features.get('os',{})
    fcpu=features.get('cpu',{})
    fgpu=features.get('gpu',{})

    plat_name=fplat.get('name','')
    plat_uid=features.get('platform_uid','')
    os_name=fos.get('name','')
    os_uid=features.get('os_uid','')
    cpu_name=fcpu.get('name','')
    if cpu_name=='': cpu_name='unknown-'+fcpu.get('cpu_abi','')
    cpu_uid=features.get('cpu_uid','')
    gpu_name=fgpu.get('name','')
    gpgpu_name=''
    sn=fos.get('serial_number','')

    # Ask for cmd
    tp=['cpu', 'cuda', 'opencl']

    ck.out(line)
    ck.out('Select Caffe library type:')
    ck.out('')
    r=ck.access({'action':'select_list',
                 'module_uoa':cfg['module_deps']['choice'],
                 'choices':tp})
    if r['return']>0: return r
    xtp=r['choice']

    # Get extra platform features if "cuda" or "opencl"
    run_cmd='time_cpu'
    tags='lib,caffe'
    ntags='vcuda,vopencl'
    gpgpu_uid=''
    if xtp=='cuda' or xtp=='opencl':
        run_cmd='time_gpu'
        r=ck.access({'action':'detect',
                     'module_uoa':cfg['module_deps']['platform.gpgpu'],
                     'host_os':hos,
                     'target_os':tos,
                     'device_id':tdid,
                     'type':xtp,
                     'share':'yes',
                     'exchange_repo':er,
                     'exchange_subrepo':esr})
        if r['return']>0: return r
        gfeat=r.get('features',{})
        gpgpus=gfeat.get('gpgpu',[])

        if len(gpgpus)>0:
            gpgpu_name=gpgpus[0].get('gpgpu',{}).get('name','')
            gpgpu_uid=gpgpus[0].get('gpgpu_uoa','')

        ntags=''
        tags+=',v'+xtp

    # Get deps from caffe program
    r=ck.access({'action':'load',
                 'module_uoa':cfg['module_deps']['program'],
                 'data_uoa':'caffe'})
    if r['return']>0: return r

    deps=r['dict']['compile_deps']
    pp=r['path']

    lib_dep=deps['lib-caffe']
    lib_dep['tags']=tags
    lib_dep['no_tags']=ntags

    # Check environment for selected type
    r=ck.access({'action':'resolve',
                 'module_uoa':cfg['module_deps']['env'],
                 'deps':deps,
                 'host_os':hos,
                 'target_os':tos,
                 'device_id':tdid,
                 'out':o})
    if r['return']>0: return r
    deps=r['deps']

    # Prepare CK pipeline for a given workload
    ii={'action':'pipeline',

        'module_uoa':cfg['module_deps']['program'],
        'data_uoa':'caffe',

        'prepare':'yes',

        'env':i.get('env',{}),
        'choices':choices,
        'dependencies':deps,
        'cmd_key':run_cmd,
        'no_state_check':'yes',
        'no_compiler_description':'yes',
        'skip_info_collection':'yes',
        'skip_calibration':'yes',
        'cpu_freq':'max',
        'gpu_freq':'max',
        'env_speed':'yes',
        'energy':'no',
        'skip_print_timers':'yes',
        'generate_rnd_tmp_dir':'no',

        'out':oo}

    rr=ck.access(ii)
    if rr['return']>0: return rr

    fail=rr.get('fail','')
    if fail=='yes':
        return {'return':10, 'error':'pipeline failed ('+rr.get('fail_reason','')+')'}

    ready=rr.get('ready','')
    if ready!='yes':
        return {'return':11, 'error':'couldn\'t prepare universal CK program workflow'}

    state=rr['state']
    tmp_dir=state['tmp_dir']

    # Clean pipeline
    if 'ready' in rr: del(rr['ready'])
    if 'fail' in rr: del(rr['fail'])
    if 'return' in rr: del(rr['return'])

    # Check if aggregated stats
    aggregated_stats={} # Pre-load statistics ...

    # Prepare high-level experiment meta
    meta={'cpu_name':cpu_name,
          'os_name':os_name,
          'plat_name':plat_name,
          'gpu_name':gpu_name,
          'caffe_type':xtp,
          'gpgpu_name':gpgpu_name,
          'cmd_key':run_cmd}

    # Process deps
    xdeps={}
    xnn=''
    xblas=''
    for k in deps:
        dp=deps[k]
        dname=dp.get('dict',{}).get('data_name','')

        if k=='caffemodel':
            xnn=dname

            j1=xnn.rfind('(')
            if j1>0:
                xnn=xnn[j1+1:-1]

        xdeps[k]={'name':dp.get('name',''), 'data_name':dname, 'ver':dp.get('ver','')}

    meta['xdeps']=xdeps
    meta['nn_type']=xnn

    mmeta=copy.deepcopy(meta)

    # Extra meta which is not used to search for similar cases ...
    mmeta['platform_uid']=plat_uid
    mmeta['os_uid']=os_uid
    mmeta['cpu_uid']=cpu_uid
    mmeta['gpgpu_uid']=gpgpu_uid
    mmeta['user']=user

    # Check if already exists
    # tbd

    # Run CK pipeline *****************************************************
    pipeline=copy.deepcopy(rr)
    if len(choices)>0:
        r=ck.merge_dicts({'dict1':pipeline['choices'], 'dict2':xchoices})
        if r['return']>0: return r

    ii={'action':'autotune',
        'module_uoa':cfg['module_deps']['pipeline'],

        'iterations':1,
        'repetitions':repetitions,

        'collect_all':'yes',
        'process_multi_keys':['##characteristics#*'],

        'tmp_dir':tmp_dir,

        'pipeline':pipeline,

        'stat_flat_dict':aggregated_stats,

        "features_keys_to_process":["##choices#*"],

        "record_params": {
          "search_point_by_features":"yes"
        },

        'out':oo}

    rrr=ck.access(ii)
    if rrr['return']>0: return rrr

    ls=rrr.get('last_iteration_output',{})
    state=ls.get('state',{})
    xchoices=copy.deepcopy(ls.get('choices',{}))
    lsa=rrr.get('last_stat_analysis',{})
    lsad=lsa.get('dict_flat',{})

    ddd={'meta':mmeta}

    ddd['choices']=xchoices

    features=ls.get('features',{})

    deps=ls.get('dependencies',{})

    fail=ls.get('fail','')
    fail_reason=ls.get('fail_reason','')

    ch=ls.get('characteristics',{})

    # Save pipeline
    ddd['state']={'fail':fail, 'fail_reason':fail_reason}
    ddd['characteristics']=ch

    ddd['user']=user

    if o=='con':
        ck.out('')
        ck.out('Saving results to the remote public repo ...')
        ck.out('')

        # Find remote entry
        rduid=''

        ii={'action':'search',
            'module_uoa':work['self_module_uid'],
            'repo_uoa':er,
            'remote_repo_uoa':esr,
            'search_dict':{'meta':meta}}
        rx=ck.access(ii)
        if rx['return']>0: return rx

        lst=rx['lst']

        if len(lst)==1:
            rduid=lst[0]['data_uid']
        else:
            rx=ck.gen_uid({})
            if rx['return']>0: return rx
            rduid=rx['data_uid']

        # Update meta
        rx=ck.access({'action':'update',
                      'module_uoa':work['self_module_uid'],
                      'data_uoa':rduid,
                      'repo_uoa':er,
                      'remote_repo_uoa':esr,
                      'dict':ddd,
                      'substitute':'yes',
                      'sort_keys':'yes'})
        if rx['return']>0: return rx

        # Push statistical characteristics
        fstat=os.path.join(pp,tmp_dir,ffstat)

        r=ck.save_json_to_file({'json_file':fstat, 'dict':lsad})
        if r['return']>0: return r

        rx=ck.access({'action':'push',
                      'module_uoa':work['self_module_uid'],
                      'data_uoa':rduid,
                      'repo_uoa':er,
                      'remote_repo_uoa':esr,
                      'filename':fstat,
                      'overwrite':'yes'})
        if rx['return']>0: return rx

        os.remove(fstat)

        # Info
        if o=='con':
            ck.out('Successfully recorded results in remote repo (Entry UID='+rduid+')')

            # Check host URL prefix and default module/action
            url='http://cknowledge.org/repo/web.php?template=cknowledge&action=index&module_uoa=wfe&native_action=show&native_module_uoa=program.optimization&scenario=155b6fa5a4012a93&highlight_uid='+rduid
            ck.out('')
            ck.out('You can see your results at the following URL:')
            ck.out('')
            ck.out(url)

    return {'return':0}
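
This example leans on copy.deepcopy in three places: the incoming request dict i is cloned before being repurposed as an 'initialize' call, the prepared pipeline rr is cloned into pipeline before autotuning mutates it, and meta is cloned into mmeta before the UID fields are attached. A minimal sketch of the first pattern, using hypothetical key and function names rather than the real CK API, might look roughly like this:

import copy

def initialize_request(i):
    # Hypothetical helper (not part of the CK API): build an 'initialize'
    # request from an incoming one without touching the caller's dict,
    # mirroring ii=copy.deepcopy(i); ii['action']='initialize' above.
    ii = copy.deepcopy(i)              # nested dicts become independent copies
    ii['action'] = 'initialize'
    ii['skip_welcome'] = 'yes'
    return ii

request = {'action': 'crowdsource', 'env': {'CK_CAFFE_BATCH_SIZE': 2}}
init_request = initialize_request(request)
init_request['env']['CK_CAFFE_BATCH_SIZE'] = 8

# A shallow copy (dict(request)) would have shared the nested 'env' dict;
# with deepcopy the original request is left untouched.
assert request['env']['CK_CAFFE_BATCH_SIZE'] == 2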

Example 20

View license
def do(i):
    # Detect basic platform info.
    ii={'action':'detect',
        'module_uoa':'platform',
        'out':'out'}
    r=ck.access(ii)
    if r['return']>0: return r

    # Host and target OS params.
    hos=r['host_os_uoa']
    hosd=r['host_os_dict']

    tos=r['os_uoa']
    tosd=r['os_dict']
    tdid=r['device_id']

    # Load Caffe program meta and desc to check deps.
    ii={'action':'load',
        'module_uoa':'program',
        'data_uoa':'caffe'}
    rx=ck.access(ii)
    if rx['return']>0: return rx
    mm=rx['dict']

    # Update deps from GPGPU or ones remembered during autotuning.
    cdeps=mm.get('compile_deps',{})

    # Caffe libs.
    depl=copy.deepcopy(cdeps['lib-caffe'])

    ii={'action':'resolve',
        'module_uoa':'env',
        'host_os':hos,
        'target_os':tos,
        'device_id':tdid,
        'deps':{'lib-caffe':copy.deepcopy(depl)}
    }
    r=ck.access(ii)
    if r['return']>0: return r

    udepl=r['deps']['lib-caffe'].get('choices',[]) # All UOAs of env for Caffe libs.
    if len(udepl)==0:
        return {'return':1, 'error':'no installed Caffe libs'}

    # Caffe models.
    depm=copy.deepcopy(cdeps['caffemodel'])

    ii={'action':'resolve',
        'module_uoa':'env',
        'host_os':hos,
        'target_os':tos,
        'device_id':tdid,
        'deps':{'caffemodel':copy.deepcopy(depm)}
    }
    r=ck.access(ii)
    if r['return']>0: return r

    udepm=r['deps']['caffemodel'].get('choices',[]) # All UOAs of env for Caffe models.
    if len(udepm)==0:
        return {'return':1, 'error':'no installed Caffe models'}

    # Prepare pipeline.
    cdeps['lib-caffe']['uoa']=udepl[0]
    cdeps['caffemodel']['uoa']=udepm[0]

    ii={'action':'pipeline',

        'module_uoa':'program',
        'data_uoa':'caffe',

        'prepare':'yes',

        'dependencies': cdeps,

        'no_state_check':'yes',
        'no_compiler_description':'yes',
        'skip_calibration':'yes',

        'cmd_key':'time_gpu',

        'cpu_freq':'max',
        'gpu_freq':'max',

        'speed':'yes',
        'energy':'no',

        'out':'con',
        'skip_print_timers':'yes'
    }

    r=ck.access(ii)
    if r['return']>0: return r

    fail=r.get('fail','')
    if fail=='yes':
        return {'return':10, 'error':'pipeline failed ('+r.get('fail_reason','')+')'}

    ready=r.get('ready','')
    if ready!='yes':
        return {'return':11, 'error':'pipeline not ready'}

    state=r['state']
    tmp_dir=state['tmp_dir']

    # Remember resolved deps for this benchmarking session.
    xcdeps=r.get('dependencies',{})

    # Clean pipeline.
    if 'ready' in r: del(r['ready'])
    if 'fail' in r: del(r['fail'])
    if 'return' in r: del(r['return'])

    pipeline=copy.deepcopy(r)

    # For each Caffe lib.
    for lib_uoa in udepl:
        # Load Caffe lib.
        ii={'action':'load',
            'module_uoa':'env',
            'data_uoa':lib_uoa}
        r=ck.access(ii)
        if r['return']>0: return r
        # Get the tags from e.g. 'BVLC Caffe framework (libdnn,viennacl)'
        lib_name=r['data_name']
        lib_tags=re.match('BVLC Caffe framework \((?P<tags>.*)\)', lib_name)
        lib_tags=lib_tags.group('tags').replace(' ', '').replace(',', '-')
        # Use the CPU timing command for CPU-only libs, GPU timing otherwise.
        if r['dict']['customize']['params']['cpu_only']==1:
            cmd_key='time_cpu'
        else:
            cmd_key='time_gpu'

        # For each Caffe model.
        for model_uoa in udepm:
            # Load Caffe model.
            ii={'action':'load',
                'module_uoa':'env',
                'data_uoa':model_uoa}
            r=ck.access(ii)
            if r['return']>0: return r
            # Get the tags from e.g. 'Caffe model (net and weights) (deepscale, squeezenet, 1.1)'
            model_name=r['data_name']
            model_tags = re.match('Caffe model \(net and weights\) \((?P<tags>.*)\)', model_name)
            model_tags = model_tags.group('tags').replace(' ', '').replace(',', '-')

            record_repo='local'
            record_uoa=model_tags+'-'+lib_tags

            # Prepare pipeline.
            ck.out('---------------------------------------------------------------------------------------')
            ck.out('%s - %s' % (lib_name, lib_uoa))
            ck.out('%s - %s' % (model_name, model_uoa))
            ck.out('Experiment - %s:%s' % (record_repo, record_uoa))

            # Prepare autotuning input.
            cpipeline=copy.deepcopy(pipeline)

            # Reset deps and change UOA.
            new_deps={'lib-caffe':copy.deepcopy(depl), 
                      'caffemodel':copy.deepcopy(depm)}

            new_deps['lib-caffe']['uoa']=lib_uoa
            new_deps['caffemodel']['uoa']=model_uoa

            jj={'action':'resolve',
                'module_uoa':'env',
                'host_os':hos,
                'target_os':tos,
                'device_id':tdid,
                'deps':new_deps}
            r=ck.access(jj)
            if r['return']>0: return r

            cpipeline['dependencies'].update(new_deps)

            cpipeline['cmd_key']=cmd_key

            ii={'action':'autotune',

                'module_uoa':'pipeline',
                'data_uoa':'program',

                'choices_order':[
                    [
                        '##env#CK_CAFFE_BATCH_SIZE'
                    ]
                ],
                'choices_selection':[
                    {'type':'loop', 'start':p['start'], 'stop':p['stop'], 'step':p['step'], 'default':p['default']}
                ],

                'features_keys_to_process':['##choices#*'],

                'iterations':-1,
                'repetitions':p['repeat'],

                'record':'yes',
                'record_failed':'yes',
                'record_params':{
                    'search_point_by_features':'yes'
                },
                'record_repo':record_repo,
                'record_uoa':record_uoa,

                'tags':['explore-batch-size-libs-models', cmd_key, model_tags, lib_tags],

                'pipeline':cpipeline,
                'out':'con'}

            r=ck.access(ii)
            if r['return']>0: return r

            fail=r.get('fail','')
            if fail=='yes':
                return {'return':10, 'error':'pipeline failed ('+r.get('fail_reason','')+')'}

    return {'return':0}
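
Note how the pipeline is prepared once and then deep-copied (cpipeline=copy.deepcopy(pipeline)) for every library/model pair, so each autotuning run starts from a pristine template rather than from whatever state the previous run left behind. A rough, self-contained sketch of that pattern, where run_autotune is an illustrative stand-in and not a CK call:

import copy

base_pipeline = {'dependencies': {}, 'choices': {}, 'state': {'ready': 'yes'}}

def run_autotune(pipeline, lib, model):
    # Stand-in for ck.access({'action':'autotune', ...}); like the real
    # autotuner, it mutates the pipeline it is given.
    pipeline['dependencies'] = {'lib-caffe': lib, 'caffemodel': model}
    pipeline['state']['fail'] = 'no'

for lib in ['openblas', 'libdnn-viennacl']:
    for model in ['bvlc-alexnet', 'deepscale-squeezenet']:
        cpipeline = copy.deepcopy(base_pipeline)   # fresh copy per experiment
        run_autotune(cpipeline, lib, model)

# The template is still clean after all runs.
assert base_pipeline['dependencies'] == {} and 'fail' not in base_pipeline['state']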

Example 21

Project: EQcorrscan
Source File: match_filter.py
View license
def match_filter(template_names, template_list, st, threshold,
                 threshold_type, trig_int, plotvar, plotdir='.', cores=1,
                 debug=0, plot_format='png', output_cat=False,
                 extract_detections=False, arg_check=True):
    """
    Main matched-filter detection function.

    Over-arching code to run the correlations of given templates with a \
    day of seismic data and output the detections based on a given threshold.
    For a functional example see the tutorials.

    :type template_names: list
    :param template_names: List of template names in the same order as \
        template_list
    :type template_list: list
    :param template_list: A list of templates of which each template is a \
        Stream of obspy traces containing seismic data and header information.
    :type st: obspy.core.stream.Stream
    :param st: A Stream object containing all the data available and \
        required for the correlations with templates given.  For efficiency \
        this should contain no excess traces which are not in one or more of \
        the templates.  This will now remove excess traces internally, but \
        will copy the stream and work on the copy, leaving your input stream \
        untouched.
    :type threshold: float
    :param threshold: A threshold value set based on the threshold_type
    :type threshold_type: str
    :param threshold_type: The type of threshold to be used, can be MAD, \
        absolute or av_chan_corr.  See Note on thresholding below.
    :type trig_int: float
    :param trig_int: Minimum gap between detections in seconds.
    :type plotvar: bool
    :param plotvar: Turn plotting on or off
    :type plotdir: str
    :param plotdir: Path to plotting folder, plots will be output here, \
        defaults to run location.
    :type cores: int
    :param cores: Number of cores to use
    :type debug: int
    :param debug: Debug output level, the bigger the number, the more the \
        output.
    :type plot_format: str
    :param plot_format: Specify format of output plots if saved
    :type output_cat: bool
    :param output_cat: Specifies if matched_filter will output an \
        obspy.Catalog class containing events for each detection. Default \
        is False, in which case matched_filter will output a list of \
        detection classes, as normal.
    :type extract_detections: bool
    :param extract_detections: Specifies whether or not to return a list of \
        streams, one stream per detection.
    :type arg_check: bool
    :param arg_check: Check arguments, defaults to True, but if running in \
        bulk, and you are certain of your arguments, then set to False.\n

    .. rubric::
        If neither `output_cat` or `extract_detections` are set to `True`,
        then only the list of :class:`eqcorrscan.core.match_filter.DETECTION`'s
        will be output:
    :return: :class:`eqcorrscan.core.match_filter.DETECTION`'s detections for
        each detection made.
    :rtype: list
    .. rubric::
        If `output_cat` is set to `True`, then the
        :class:`obspy.core.event.Catalog` will also be output:
    :return: Catalog containing events for each detection, see above.
    :rtype: :class:`obspy.core.event.Catalog`
    .. rubric::
        If `extract_detections` is set to `True` then the list of
        :class:`obspy.core.stream.Stream`'s will also be output.
    :return:
        list of :class:`obspy.core.stream.Stream`'s for each detection, see
        above.
    :rtype: list

    .. warning::
        Plotting within the match-filter routine uses the Agg backend
        with interactive plotting turned off.  This is because the function
        is designed to work in bulk.  If you wish to turn interactive
        plotting on you must import matplotlib in your script first, when you
        them import match_filter you will get the warning that this call to
        matplotlib has no effect, which will mean that match_filter has not
        changed the plotting behaviour.

    .. note::
        **Thresholding:**

        **MAD** threshold is calculated as the:

        .. math::

            threshold {\\times} (median(abs(cccsum)))

        where :math:`cccsum` is the cross-correlation sum for a given template.

        **absolute** threshold is a true absolute threshold based on the
        cccsum value.

        **av_chan_corr** is based on the mean values of single-channel
        cross-correlations assuming all data are present as required for the
        template, e.g:

        .. math::

            av\_chan\_corr\_thresh=threshold \\times (cccsum / len(template))

        where :math:`template` is a single template from the input and the
        length is the number of channels within this template.

    .. note::
        The output_cat flag will create an :class:`obspy.core.event.Catalog`
        containing one event for each
        :class:`eqcorrscan.core.match_filter.DETECTION`'s generated by
        match_filter. Each event will contain a number of comments dealing
        with correlation values and channels used for the detection. Each
        channel used for the detection will have a corresponding
        :class:`obspy.core.event.Pick` which will contain time and
        waveform information. **HOWEVER**, the user should note that, at
        present, the pick times do not account for the
        prepick times inherent in each template. For example, if a template
        trace starts 0.1 seconds before the actual arrival of that phase,
        then the pick time generated by match_filter for that phase will be
        0.1 seconds early. We are working on a solution that will involve
        saving templates alongside associated metadata.
    """
    import matplotlib
    matplotlib.use('Agg')
    if arg_check:
        # Check the arguments to be nice - if arguments are of the wrong type
        # the parallel output for the error won't be useful
        if not type(template_names) == list:
            raise MatchFilterError('template_names must be of type: list')
        if not type(template_list) == list:
            raise MatchFilterError('templates must be of type: list')
        if not len(template_list) == len(template_names):
            raise MatchFilterError('Not the same number of templates as names')
        for template in template_list:
            if not type(template) == Stream:
                msg = 'template in template_list must be of type: ' +\
                      'obspy.core.stream.Stream'
                raise MatchFilterError(msg)
        if not type(st) == Stream:
            msg = 'st must be of type: obspy.core.stream.Stream'
            raise MatchFilterError(msg)
        if str(threshold_type) not in [str('MAD'), str('absolute'),
                                       str('av_chan_corr')]:
            msg = 'threshold_type must be one of: MAD, absolute, av_chan_corr'
            raise MatchFilterError(msg)

    # Copy the stream here because we will muck about with it
    stream = st.copy()
    templates = copy.deepcopy(template_list)
    _template_names = copy.deepcopy(template_names)
    # Debug option to confirm that the channel names match those in the
    # templates
    if debug >= 2:
        template_stachan = []
        data_stachan = []
        for template in templates:
            for tr in template:
                if isinstance(tr.data, np.ma.core.MaskedArray):
                    raise MatchFilterError('Template contains masked array,'
                                           ' split first')
                template_stachan.append(tr.stats.station + '.' +
                                        tr.stats.channel)
        for tr in stream:
            data_stachan.append(tr.stats.station + '.' + tr.stats.channel)
        template_stachan = list(set(template_stachan))
        data_stachan = list(set(data_stachan))
        if debug >= 3:
            print('I have template info for these stations:')
            print(template_stachan)
            print('I have daylong data for these stations:')
            print(data_stachan)
    # Perform a check that the continuous data are all the same length
    min_start_time = min([tr.stats.starttime for tr in stream])
    max_end_time = max([tr.stats.endtime for tr in stream])
    longest_trace_length = stream[0].stats.sampling_rate * (max_end_time -
                                                            min_start_time)
    for tr in stream:
        if not tr.stats.npts == longest_trace_length:
            msg = 'Data are not equal length, padding short traces'
            warnings.warn(msg)
            start_pad = np.zeros(int(tr.stats.sampling_rate *
                                     (tr.stats.starttime - min_start_time)))
            end_pad = np.zeros(int(tr.stats.sampling_rate *
                                   (max_end_time - tr.stats.endtime)))
            tr.data = np.concatenate([start_pad, tr.data, end_pad])
    # Perform check that all template lengths are internally consistent
    for i, temp in enumerate(template_list):
        if len(set([tr.stats.npts for tr in temp])) > 1:
            msg = ('Template %s contains traces of differing length, this is '
                   'not currently supported' % _template_names[i])
            raise MatchFilterError(msg)
    outtic = time.clock()
    if debug >= 2:
        print('Ensuring all template channels have matches in long data')
    template_stachan = {}
    # Work out what station-channel pairs are in the templates, including
    # duplicate station-channel pairs.  We will use this information to fill
    # all templates with the same station-channel pairs as required by
    # _template_loop.
    for template in templates:
        stachans_in_template = []
        for tr in template:
            stachans_in_template.append((tr.stats.network, tr.stats.station,
                                         tr.stats.location, tr.stats.channel))
        stachans_in_template = dict(Counter(stachans_in_template))
        for stachan in stachans_in_template.keys():
            if stachan not in template_stachan.keys():
                template_stachan.update({stachan:
                                         stachans_in_template[stachan]})
            elif stachans_in_template[stachan] > template_stachan[stachan]:
                template_stachan.update({stachan:
                                         stachans_in_template[stachan]})
    # Remove un-matched channels from templates.
    _template_stachan = copy.deepcopy(template_stachan)
    for stachan in template_stachan.keys():
        if not stream.select(network=stachan[0], station=stachan[1],
                             location=stachan[2], channel=stachan[3]):
            # Remove stachan from the dictionary of template_stachans
            _template_stachan.pop(stachan)
            # Remove template traces rather than adding NaN data
            for template in templates:
                if template.select(network=stachan[0], station=stachan[1],
                                   location=stachan[2], channel=stachan[3]):
                    for tr in template.select(network=stachan[0],
                                              station=stachan[1],
                                              location=stachan[2],
                                              channel=stachan[3]):
                        template.remove(tr)
    template_stachan = _template_stachan
    # Remove un-needed channels from continuous data.
    for tr in stream:
        if not (tr.stats.network, tr.stats.station,
                tr.stats.location, tr.stats.channel) in \
                template_stachan.keys():
            stream.remove(tr)
    # Check for duplicate channels
    stachans = [(tr.stats.network, tr.stats.station,
                 tr.stats.location, tr.stats.channel) for tr in stream]
    c_stachans = Counter(stachans)
    for key in c_stachans.keys():
        if c_stachans[key] > 1:
            msg = ('Multiple channels for %s.%s.%s.%s, likely a data issue'
                   % (key[0], key[1], key[2], key[3]))
            raise MatchFilterError(msg)
    # Pad out templates to have all channels
    for template, template_name in zip(templates, _template_names):
        if len(template) == 0:
            msg = ('No channels matching in continuous data for ' +
                   'template ' + template_name)
            warnings.warn(msg)
            templates.remove(template)
            _template_names.remove(template_name)
            continue
        for stachan in template_stachan.keys():
            number_of_channels = len(template.select(network=stachan[0],
                                                     station=stachan[1],
                                                     location=stachan[2],
                                                     channel=stachan[3]))
            if number_of_channels < template_stachan[stachan]:
                missed_channels = template_stachan[stachan] -\
                                  number_of_channels
                nulltrace = Trace()
                nulltrace.stats.update(
                    {'network': stachan[0], 'station': stachan[1],
                     'location': stachan[2], 'channel': stachan[3],
                     'sampling_rate': template[0].stats.sampling_rate,
                     'starttime': template[0].stats.starttime})
                nulltrace.data = np.array([np.NaN] * len(template[0].data),
                                          dtype=np.float32)
                for dummy in range(missed_channels):
                    template += nulltrace
        template.sort()
        # Quick check that this has all worked
        if len(template) != max([len(t) for t in templates]):
            raise MatchFilterError('Internal error forcing same template '
                                   'lengths, report this error.')
    if debug >= 2:
        print('Starting the correlation run for this day')
    if debug >= 4:
        for template in templates:
            print(template)
        print(stream)
    [cccsums, no_chans, chans] = _channel_loop(templates=templates,
                                               stream=stream,
                                               cores=cores,
                                               debug=debug)
    if len(cccsums[0]) == 0:
        raise MatchFilterError('Correlation has not run, zero length cccsum')
    outtoc = time.clock()
    print(' '.join(['Looping over templates and streams took:',
                    str(outtoc - outtic), 's']))
    if debug >= 2:
        print(' '.join(['The shape of the returned cccsums is:',
                        str(np.shape(cccsums))]))
        print(' '.join(['This is from', str(len(templates)), 'templates']))
        print(' '.join(['Correlated with', str(len(stream)),
                        'channels of data']))
    detections = []
    if output_cat:
        det_cat = Catalog()
    for i, cccsum in enumerate(cccsums):
        template = templates[i]
        if str(threshold_type) == str('MAD'):
            rawthresh = threshold * np.median(np.abs(cccsum))
        elif str(threshold_type) == str('absolute'):
            rawthresh = threshold
        elif str(threshold_type) == str('av_chan_corr'):
            rawthresh = threshold * no_chans[i]
        # Findpeaks returns a list of tuples in the form [(cccsum, sample)]
        print(' '.join(['Threshold is set at:', str(rawthresh)]))
        print(' '.join(['Max of data is:', str(max(cccsum))]))
        print(' '.join(['Mean of data is:', str(np.mean(cccsum))]))
        if np.abs(np.mean(cccsum)) > 0.05:
            warnings.warn('Mean is not zero!  Check this!')
        # Set up a trace object for the cccsum as this is easier to plot and
        # maintains timing
        if plotvar:
            _match_filter_plot(stream=stream, cccsum=cccsum,
                               template_names=_template_names,
                               rawthresh=rawthresh, plotdir=plotdir,
                               plot_format=plot_format, i=i)
        if debug >= 4:
            print(' '.join(['Saved the cccsum to:', _template_names[i],
                            stream[0].stats.starttime.datetime.
                           strftime('%Y%j')]))
            np.save(_template_names[i] +
                    stream[0].stats.starttime.datetime.strftime('%Y%j'),
                    cccsum)
        tic = time.clock()
        if max(cccsum) > rawthresh:
            peaks = findpeaks.find_peaks2_short(
                arr=cccsum, thresh=rawthresh,
                trig_int=trig_int * stream[0].stats.sampling_rate, debug=debug,
                starttime=stream[0].stats.starttime,
                samp_rate=stream[0].stats.sampling_rate)
        else:
            print('No peaks found above threshold')
            peaks = False
        toc = time.clock()
        if debug >= 1:
            print(' '.join(['Finding peaks took:', str(toc - tic), 's']))
        if peaks:
            for peak in peaks:
                detecttime = stream[0].stats.starttime +\
                    peak[1] / stream[0].stats.sampling_rate
                # Detect time must be valid QuakeML uri within resource_id.
                # This will write a formatted string which is still
                # readable by UTCDateTime
                rid = ResourceIdentifier(id=_template_names[i] + '_' +
                                         str(detecttime.
                                             strftime('%Y%m%dT%H%M%S.%f')),
                                         prefix='smi:local')
                ev = Event(resource_id=rid)
                cr_i = CreationInfo(author='EQcorrscan',
                                    creation_time=UTCDateTime())
                ev.creation_info = cr_i
                # All detection info in Comments for lack of a better idea
                thresh_str = 'threshold=' + str(rawthresh)
                ccc_str = 'detect_val=' + str(peak[0])
                used_chans = 'channels used: ' +\
                             ' '.join([str(pair) for pair in chans[i]])
                ev.comments.append(Comment(text=thresh_str))
                ev.comments.append(Comment(text=ccc_str))
                ev.comments.append(Comment(text=used_chans))
                min_template_tm = min([tr.stats.starttime for tr in template])
                for tr in template:
                    if (tr.stats.station, tr.stats.channel) not in chans[i]:
                        continue
                    else:
                        pick_tm = detecttime + (tr.stats.starttime -
                                                min_template_tm)
                        wv_id = WaveformStreamID(network_code=tr.stats.network,
                                                 station_code=tr.stats.station,
                                                 channel_code=tr.stats.channel)
                        ev.picks.append(Pick(time=pick_tm, waveform_id=wv_id))
                detections.append(DETECTION(_template_names[i],
                                            detecttime,
                                            no_chans[i], peak[0], rawthresh,
                                            'corr', chans[i], event=ev))
                if output_cat:
                    det_cat.append(ev)
        if extract_detections:
            detection_streams = extract_from_stream(stream, detections)
    del stream, templates
    if output_cat and not extract_detections:
        return detections, det_cat
    elif not extract_detections:
        return detections
    elif extract_detections and not output_cat:
        return detections, detection_streams
    else:
        return detections, det_cat, detection_streams
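
Two details of match_filter are easy to miss: the input template_list and template_names are deep-copied up front, so the later channel removal and NaN padding never modify the caller's objects, and the raw detection threshold is derived from threshold_type as described in the thresholding note of the docstring. A small illustrative sketch of both ideas (the station codes and numbers below are made up):

import copy
import numpy as np

def raw_threshold(cccsum, threshold, threshold_type, n_chans):
    # Mirrors the MAD / absolute / av_chan_corr branches used above.
    if threshold_type == 'MAD':
        return threshold * np.median(np.abs(cccsum))
    elif threshold_type == 'absolute':
        return threshold
    elif threshold_type == 'av_chan_corr':
        return threshold * n_chans
    raise ValueError('Unknown threshold_type: %s' % threshold_type)

cccsum = np.array([0.10, -0.20, 0.05, 0.90, -0.15])
print(raw_threshold(cccsum, 8.0, 'MAD', n_chans=5))           # 8 * median(|cccsum|) = 1.2
print(raw_threshold(cccsum, 0.6, 'absolute', n_chans=5))      # 0.6
print(raw_threshold(cccsum, 0.4, 'av_chan_corr', n_chans=5))  # 0.4 * 5 = 2.0

# Deep-copying the templates keeps the caller's data intact while the
# working copies are trimmed, as match_filter does with template_list.
templates = [['NZ.FOZ..HHZ', 'NZ.FOZ..HHN'], ['NZ.WVZ..HHZ']]
working = copy.deepcopy(templates)
working[0].pop()                      # drop a channel from the working copy only
assert templates[0] == ['NZ.FOZ..HHZ', 'NZ.FOZ..HHN']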

Example 22

Project: markovbot
Source File: markovbot.py
View license
	def _autoreply(self):
		
		"""Continuously monitors Twitter Stream and replies when a tweet
		appears that matches self._targetstring. It will include
		self._tweetprefix and self._tweetsuffix in the tweets, provided they
		are not None.
		"""
		
		# Run indefinitely
		while self._autoreplythreadlives:

			# Wait a bit before rechecking whether autoreplying should be
			# started. It's highly unlikely the bot will miss something if
			# it is a second late, and checking continuously is a waste of
			# resources.
			time.sleep(1)

			# Only start when the bot logs in to twitter, and when a
			# target string is available
			if self._loggedin and self._targetstring != None:
	
				# Acquire the TwitterStream lock
				self._tslock.acquire(True)
	
				# Create a new iterator from the TwitterStream
				iterator = self._ts.statuses.filter(track=self._targetstring)
				
				# Release the TwitterStream lock
				self._tslock.release()
	
				# Only check for tweets when autoreplying
				while self._autoreplying:
					
					# Get a new Tweet (this will block until a new
					# tweet becomes available, but can also raise a
					# StopIteration Exception every now and again.)
					try:
						# Attempt to get the next tweet.
						tweet = iterator.next()
					except StopIteration:
						# Restart the iterator, and skip the rest of
						# the loop.
						iterator = self._ts.statuses.filter(track=self._targetstring)
						continue
					
					# Restart the connection if this is a 'hangup'
					# notification, which will be {'hangup':True}
					if u'hangup' in tweet.keys():
						# Reanimate the Twitter connection.
						self._twitter_reconnect()
						# Skip further processing.
						continue
					
					# Store a copy of the latest incoming tweet, for
					# debugging purposes
					self._lasttweetin = copy.deepcopy(tweet)
					
					# Only proceed if autoreplying is still required (there
					# can be a delay before the iterator produces a new tweet,
					# and by that time autoreplying might already be stopped)
					if not self._autoreplying:
						# Skip one cycle, which will likely also make the
						# while self._autoreplying loop stop
						continue

					# Report to console
					self._message(u'_autoreply', u"I've found a new tweet!")
					try:
						self._message(u'_autoreply', u'%s (@%s): %s' % \
							(tweet[u'user'][u'name'], \
							tweet[u'user'][u'screen_name'], tweet[u'text']))
					except:
						self._message(u'_autoreply', \
							u'Failed to report on new Tweet :(')
					
					# Don't reply to this bot's own tweets
					if tweet[u'user'][u'id_str'] == self._credentials[u'id_str']:
						# Skip one cycle, which will bring us to the
						# next tweet
						self._message(u'_autoreply', \
							u"This tweet was my own, so I won't reply!")
						continue
					
					# Don't reply to retweets
					if u'retweeted_status' in tweet.keys():
						# Skip one cycle, which will bring us to the
						# next tweet
						self._message(u'_autoreply', \
							u"This was a retweet, so I won't reply!")
						continue

					# Don't reply to tweets that are in the nono-list
					if tweet[u'id_str'] in self._nonotweets:
						# Skip one cycle, which will bring us to the
						# next tweet
						self._message(u'_autoreply', \
							u"This tweet was in the nono-list, so I won't reply!")
						continue

					# Skip tweets that are too deep into a conversation
					if self._maxconvdepth != None:
						# Get the ID of the tweet that the current tweet
						# was a reply to
						orid = tweet[u'in_reply_to_status_id_str']
						# Keep digging through the tweets until the
						# top-level tweet is found, or until we pass the
						# maximum conversation depth
						counter = 0
						while orid != None and orid not in self._nonotweets:
							# If the current in-reply-to-ID is not None,
							# the current tweet was a reply. Increase
							# the reply counter by one.
							ortweet = self._t.statuses.show(id=orid)
							orid = ortweet[u'in_reply_to_status_id_str']
							counter += 1
							# Stop counting when the current value
							# exceeds the maximum allowed depth
							if counter >= self._maxconvdepth:
								# Add the current tweet's ID to the list
								# of tweets that this bot should not
								# reply to. (Keeping track prevents
								# excessive use of the Twitter API by
								# continuously asking for the
								# in-reply-to-ID of tweets)
								self._nonotweets.append(orid)
						# Don't reply if this tweet is a reply in a tweet
						# conversation of more than self._maxconvdepth tweets,
						# or if the tweet's ID is in this bot's list of
						# tweets that it shouldn't reply to
						if counter >= self._maxconvdepth or \
							orid in self._nonotweets:
							self._message(u'_autoreply', \
								u"This tweet is part of a conversation, and I don't reply to conversations with over %d tweets." % (self._maxconvdepth))
							continue
					
					# Detect the language of the tweet, if the
					# language of the reply depends on it.
					if self._autoreply_database == u'auto-language':
						# Get the language of the tweet, or default
						# to English if it isn't available.
						if u'lang' in tweet.keys():
							lang = tweet[u'lang'].lower()
							self._message(u'_autoreply', u"I detected language: '%s'." % (lang))
						else:
							lang = u'en'
							self._message(u'_autoreply', u"I couldn't detect the language, so I defaulted to '%s'." % (lang))
						# Check if the language is available in the
						# existing dicts. Select the associated
						# database, or default to English when the
						# detected language isn't available, or
						# default to u'default' when English is not
						# available.
						if lang in self.data.keys():
							database = lang
							self._message(u'_autoreply', u"I chose database: '%s'." % (database))
						elif u'en' in self.data.keys():
							database = u'en'
							self._message(u'_autoreply', u"There was no database for detected language '%s', so I defaulted to '%s'." % (lang, database))
						else:
							database = u'default'
							self._message(u'_autoreply', u"There was no database for detected language '%s', nor for 'en', so I defaulted to '%s'." % (lang, database))
					# Randomly choose a database if a random database
					# was requested. Never use an empty database,
					# though (the while loop prevents this).
					elif self._autoreply_database == u'random-database':
						database = random.choice(self.data.keys())
						while self.data[database] == {}:
							database = random.choice(self.data.keys())
						self._message(u'_autoreply', \
							u'Randomly chose database: %s' % (database))
					# Randomly choose a database out of a list of
					# potential databases.
					elif type(self._autoreply_database) in [list, tuple]:
						database = random.choice(self._autoreply_database)
						self._message(u'_autoreply', \
							u'Randomly chose database: %s' % (database))
					# Use the preferred database.
					elif type(self._autoreply_database) in [str, unicode]:
						database = copy.deepcopy(self._autoreply_database)
						self._message(u'_autoreply', \
							u'Using database: %s' % (database))
					# If none of the above options apply, default to
					# the default database.
					else:
						database = u'default'
						self._message(u'_autoreply', \
							u'Defaulted to database: %s' % (database))
					
					# If the selected database is not a string, or if
					# it is empty, then fall back on the default
					# database.
					if type(database) not in [str, unicode]:
						self._message(u'_autoreply', \
							u"Selected database '%s' is invalid, defaulting to: %s" % (database, u'default'))
						database = u'default'
					elif database not in self.data.keys():
						self._message(u'_autoreply', \
							u"Selected database '%s' does not exist, defaulting to: %s" % (database, u'default'))
						database = u'default'
					elif self.data[database] == {}:
						self._message(u'_autoreply', \
							u"Selected database '%s' is empty, defaulting to: %s" % (database, u'default'))
						database = u'default'
	
					# Separate the words in the tweet
					tw = tweet[u'text'].split()
					# Clean up the words in the tweet
					for i in range(len(tw)):
						# Remove clutter
						tw[i] = tw[i].replace(u'@',u''). \
							replace(u'#',u'').replace(u'.',u''). \
							replace(u',',u'').replace(u';',u''). \
							replace(u':',u'').replace(u'!',u''). \
							replace(u'?',u'').replace(u"'",u'')

					# Make a list of potential seed words in the tweet
					seedword = []
					if self._keywords != None:
						for kw in self._keywords:
							# Check if the keyword is in the list of
							# words from the tweet
							if kw in tw:
								seedword.append(kw)
					# If there are no potential seeds in the tweet, None
					# will lead to a random word being chosen
					if len(seedword) == 0:
						seedword = None
					# Report back on the chosen keyword
					self._message(u'_autoreply', u"I found seedwords: '%s'." % (seedword))

					# Construct a prefix for this tweet, which should
					# include the handle ('@example') of the sender
					if self._tweetprefix == None:
						prefix = u'@%s' % (tweet[u'user'][u'screen_name'])
					else:
						# Use the specified prefix.
						if type(self._tweetprefix) in [str, unicode]:
							prefix = u'@%s %s' % \
								(tweet[u'user'][u'screen_name'], \
								self._tweetprefix)
						# Randomly choose one of the specified
						# prefixes.
						elif type(self._tweetprefix) in [list, tuple]:
							prefix = u'@%s %s' % \
								(tweet[u'user'][u'screen_name'], \
								random.choice(self._tweetprefix))
						# Fall back on the default option.
						else:
							prefix = u'@%s' % (tweet[u'user'][u'screen_name'])
							self._message(u'_autoreply', \
								u"Could not recognise the type of prefix '%s'; using no prefix." % (self._tweetprefix))

					# Construct a suffix for this tweet. We use the
					# specified suffix, which can also be None. Or
					# we randomly select one from a list of potential
					# suffixes.
					if self._tweetsuffix == None:
						suffix = None
					elif type(self._tweetsuffix) in [str, unicode]:
						suffix = copy.deepcopy(self._tweetsuffix)
					elif type(self._tweetsuffix) in [list, tuple]:
						suffix = random.choice(self._tweetsuffix)
					else:
						suffix = None
						self._message(u'_autoreply', \
							u"Could not recognise the type of suffix '%s'; using no suffix." % (self._tweetsuffix))

					# Construct a new tweet
					response = self._construct_tweet(database=database, \
						seedword=None, prefix=prefix, suffix=suffix)

					# Acquire the twitter lock
					self._tlock.acquire(True)
					# Reply to the incoming tweet
					try:
						# Post a new tweet
						resp = self._t.statuses.update(status=response,
							in_reply_to_status_id=tweet[u'id_str'],
							in_reply_to_user_id=tweet[u'user'][u'id_str'],
							in_reply_to_screen_name=tweet[u'user'][u'screen_name']
							)
						# Report to the console
						self._message(u'_autoreply', u'Posted reply: %s' % (response))
						# Store a copy of the latest outgoing tweet, for
						# debugging purposes
						self._lasttweetout = copy.deepcopy(resp)
					except Exception, e:
						self._error(u'_autoreply', u"Failed to post a reply: '%s'" % (e))
					# Release the twitter lock
					self._tlock.release()
					
					# Wait for the minimal tweeting delay.
					time.sleep(60.0*self._mindelay)
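
Throughout _autoreply, copy.deepcopy is used for defensive snapshots: the latest incoming and outgoing tweets are stored as independent copies (self._lasttweetin, self._lasttweetout) so that later processing of the live dict cannot alter the debugging record, and the configured database string is copied before reuse. A minimal sketch of that snapshot idea (the SnapshotBot class below is illustrative, not the markovbot API):

import copy

class SnapshotBot(object):

    def __init__(self):
        self._lasttweetin = None

    def handle(self, tweet):
        # Keep an independent snapshot for debugging before any processing
        # mutates the incoming dict (mirrors self._lasttweetin above).
        self._lasttweetin = copy.deepcopy(tweet)
        # Subsequent clean-up only affects the working copy.
        tweet['text'] = tweet['text'].replace('@', '').replace('#', '')
        return tweet

bot = SnapshotBot()
incoming = {'id_str': '42', 'text': '@example hello #markov'}
bot.handle(incoming)

# The snapshot still holds the original, unmodified text.
assert bot._lasttweetin['text'] == '@example hello #markov'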

Example 23

Project: Nagstamon
Source File: IcingaWeb2.py
View license
    def _get_status(self):
        """
            Get status from Icinga Server - only JSON
        """
        # define CGI URLs for hosts and services
        if self.cgiurl_hosts == self.cgiurl_services == None:
            # services (unknown, warning or critical?)
            self.cgiurl_services = {'hard': self.monitor_cgi_url + '/monitoring/list/services?service_state>0&service_state<=3&service_state_type=1&addColumns=service_last_check&format=json', \
                                    'soft': self.monitor_cgi_url + '/monitoring/list/services?service_state>0&service_state<=3&service_state_type=0&addColumns=service_last_check&format=json'}
            # hosts (up or down or unreachable)
            self.cgiurl_hosts = {'hard': self.monitor_cgi_url + '/monitoring/list/hosts?host_state>0&host_state<=2&host_state_type=1&addColumns=host_last_check&format=json', \
                                 'soft': self.monitor_cgi_url + '/monitoring/list/hosts?host_state>0&host_state<=2&host_state_type=0&addColumns=host_last_check&format=json'}

        # new_hosts dictionary
        self.new_hosts = dict()

        # hosts - mostly the down ones
        # now using JSON output from Icinga
        try:
            for status_type in 'hard', 'soft':   
                # first attempt
                result = self.FetchURL(self.cgiurl_hosts[status_type], giveback='raw')            
                # authentication errors also come back with status code 200 because
                # the HTML error page itself loads fine :-(
                if result.status_code < 400 and\
                   result.result.startswith('<'):
                    # in case of auth error reset HTTP session and try again
                    self.reset_HTTP()
                    result = self.FetchURL(self.cgiurl_hosts[status_type], giveback='raw') 
                    # if it does not work again tell GUI there is a problem
                    if result.status_code < 400 and\
                       result.result.startswith('<'):
                        self.refresh_authentication = True
                        return Result(result=result.result,
                                      error='Authentication error',
                                      status_code=result.status_code)
                
                # purify JSON result of unnecessary control sequence \n
                jsonraw, error, status_code = copy.deepcopy(result.result.replace('\n', '')),\
                                              copy.deepcopy(result.error),\
                                              result.status_code

                if error != '' or status_code >= 400:
                    return Result(result=jsonraw,
                                  error=error,
                                  status_code=status_code)

                # check if any error occurred
                self.check_for_error(jsonraw, error, status_code)

                hosts = json.loads(jsonraw)

                for host in hosts:
                    # make dict of tuples for better reading
                    h = dict(host.items())

                    # host
                    if self.use_display_name_host == False:
                        # according to http://sourceforge.net/p/nagstamon/bugs/83/ it might be
                        # better to use host_name instead of host_display_name
                        # legacy Icinga adjustments
                        if 'host_name' in h: host_name = h['host_name']
                        elif 'host' in h: host_name = h['host']
                    else:
                        # https://github.com/HenriWahl/Nagstamon/issues/46 on the other hand has
                        # problems with that so here we go with extra display_name option
                        host_name = h['host_display_name']

                    # host objects contain service objects
                    if not host_name in self.new_hosts:
                        self.new_hosts[host_name] = GenericHost()
                        self.new_hosts[host_name].name = host_name
                        self.new_hosts[host_name].server = self.name
                        self.new_hosts[host_name].status = self.STATES_MAPPING['hosts'][int(h['host_state'])]
                        self.new_hosts[host_name].last_check = datetime.datetime.fromtimestamp(int(h['host_last_check']))
                        self.new_hosts[host_name].attempt = h['host_attempt']
                        self.new_hosts[host_name].status_information = BeautifulSoup(h['host_output'].replace('\n', ' ').strip(), 'html.parser').text
                        self.new_hosts[host_name].passiveonly = not(int(h['host_active_checks_enabled']))
                        self.new_hosts[host_name].notifications_disabled = not(int(h['host_notifications_enabled']))
                        self.new_hosts[host_name].flapping = int(h['host_is_flapping'])
                        self.new_hosts[host_name].acknowledged = int(h['host_acknowledged'])
                        self.new_hosts[host_name].scheduled_downtime = int(h['host_in_downtime'])
                        self.new_hosts[host_name].status_type = status_type
                        
                        # extra Icinga properties to solve https://github.com/HenriWahl/Nagstamon/issues/192
                        # acknowledge needs host_description and no display name
                        self.new_hosts[host_name].real_name = h['host_name']
       
                        # extra duration needed for calculation
                        duration = datetime.datetime.now() - datetime.datetime.fromtimestamp(int(h['host_last_state_change']))
                        self.new_hosts[host_name].duration = strfdelta(duration, '{days}d {hours}h {minutes}m {seconds}s')
                        
                    del h, host_name
        except:
            import traceback
            traceback.print_exc(file=sys.stdout)

            # set checking flag back to False
            self.isChecking = False
            result, error = self.Error(sys.exc_info())
            return Result(result=result, error=error)

        # services
        try:
            for status_type in 'hard', 'soft':
                result = self.FetchURL(self.cgiurl_services[status_type], giveback='raw')
                # purify JSON result of unnecessary control sequence \n
                jsonraw, error, status_code = copy.deepcopy(result.result.replace('\n', '')),\
                                              copy.deepcopy(result.error),\
                                              result.status_code

                if error != '' or status_code >= 400:
                    return Result(result=jsonraw,
                                  error=error,
                                  status_code=status_code)
                
                # check if any error occurred
                self.check_for_error(jsonraw, error, status_code)

                services = copy.deepcopy(json.loads(jsonraw))

                for service in services:
                    # make dict of tuples for better reading
                    s = dict(service.items())

                    if self.use_display_name_host == False:
                        # according to http://sourceforge.net/p/nagstamon/bugs/83/ it might be
                        # better to use host_name instead of host_display_name
                        # legacy Icinga adjustments
                        if 'host_name' in s: host_name = s['host_name']
                        elif 'host' in s: host_name = s['host']
                    else:
                        # https://github.com/HenriWahl/Nagstamon/issues/46 on the other hand has
                        # problems with that so here we go with extra display_name option
                        host_name = s['host_display_name']

                    # host objects contain service objects
                    # ##if not self.new_hosts.has_key(host_name):
                    if not host_name in self.new_hosts:
                        self.new_hosts[host_name] = GenericHost()
                        self.new_hosts[host_name].name = host_name
                        self.new_hosts[host_name].status = 'UP'
                        # extra Icinga properties to solve https://github.com/HenriWahl/Nagstamon/issues/192
                        # acknowledge needs host_description and no display name
                        self.new_hosts[host_name].real_name = s['host_name']

                    if self.use_display_name_host == False:
                        # legacy Icinga adjustments
                        if 'service_description' in s: service_name = s['service_description']
                        elif 'description' in s: service_name = s['description']
                        elif 'service' in s: service_name = s['service']
                    else:
                        service_name = s['service_display_name']

                    # if a service does not exist create its object
                    if not service_name in self.new_hosts[host_name].services:
                        self.new_hosts[host_name].services[service_name] = GenericService()
                        self.new_hosts[host_name].services[service_name].host = host_name
                        self.new_hosts[host_name].services[service_name].name = service_name
                        self.new_hosts[host_name].services[service_name].server = self.name
                        self.new_hosts[host_name].services[service_name].status = self.STATES_MAPPING['services'][int(s['service_state'])]
                        self.new_hosts[host_name].services[service_name].last_check = datetime.datetime.fromtimestamp(int(s['service_last_check']))                      
                        self.new_hosts[host_name].services[service_name].attempt = s['service_attempt']
                        self.new_hosts[host_name].services[service_name].status_information = BeautifulSoup(s['service_output'].replace('\n', ' ').strip(), 'html.parser').text
                        self.new_hosts[host_name].services[service_name].passiveonly = not(int(s['service_active_checks_enabled']))
                        self.new_hosts[host_name].services[service_name].notifications_disabled = not(int(s['service_notifications_enabled']))
                        self.new_hosts[host_name].services[service_name].flapping = int(s['service_is_flapping'])
                        self.new_hosts[host_name].services[service_name].acknowledged = int(s['service_acknowledged'])
                        self.new_hosts[host_name].services[service_name].scheduled_downtime = int(s['service_in_downtime'])
                        self.new_hosts[host_name].services[service_name].status_type = status_type
                        
                        # extra Icinga properties to solve https://github.com/HenriWahl/Nagstamon/issues/192
                        # acknowledge needs service_description and no display name
                        self.new_hosts[host_name].services[service_name].real_name = s['service_description']
                        
                        # extra duration needed for calculation
                        duration = datetime.datetime.now() - datetime.datetime.fromtimestamp(int(s['service_last_state_change']))
                        self.new_hosts[host_name].services[service_name].duration = strfdelta(duration, '{days}d {hours}h {minutes}m {seconds}s')                      
                        
                    del s, host_name, service_name
        except:

            import traceback
            traceback.print_exc(file=sys.stdout)

            # set checking flag back to False
            self.isChecking = False
            result, error = self.Error(sys.exc_info())
            return Result(result=result, error=error)

        # some cleanup
        del jsonraw, error, hosts, services

        # dummy return in case all is OK
        return Result()
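
Two different deepcopy uses appear in this example: copy.deepcopy(result.result.replace('\n', '')) copies a plain string, which is immutable and would be safe to share anyway, while services = copy.deepcopy(json.loads(jsonraw)) detaches the parsed JSON structure before it is picked apart into host and service objects. A small sketch of the second, more meaningful pattern, using invented field names:

import copy
import json

raw = '[{"host": "web01", "service": "load", "state": 1}]'

# Parse once, then deep-copy before handing the data to code that mutates it.
parsed = json.loads(raw)
working = copy.deepcopy(parsed)

for row in working:
    row['state'] = 'WARNING' if row['state'] == 1 else 'OK'   # mutate the copy only

assert parsed[0]['state'] == 1   # the original parse result is untouched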

Example 24

Project: Nagstamon
Source File: Multisite.py
View license
    def _get_status(self):
        """
            Get status from Check_MK Server
        """

        ret = Result()

        # Create URLs for the configured filters
        url_params = ''

        if self.force_authuser:
            url_params += "&force_authuser=1"

        url_params += '&is_host_acknowledged=-1&is_service_acknowledged=-1'
        url_params += '&is_host_notifications_enabled=-1&is_service_notifications_enabled=-1'
        url_params += '&is_host_active_checks_enabled=-1&is_service_active_checks_enabled=-1'
        url_params += '&host_scheduled_downtime_depth=-1&is_in_downtime=-1'

        try:
            response = []
            try:
                response = self._get_url(self.urls['api_hosts'] + url_params)
            except MultisiteError as e:
                if e.terminate:
                    return e.result

            if response == '':
                return Result(result='',
                              error='Login failed',
                              status_code=401)

            for row in response[1:]:
                host= dict(list(zip(copy.deepcopy(response[0]), copy.deepcopy(row))))
                n = {
                    'host':               host['host'],
                    'status':             self.statemap.get(host['host_state'], host['host_state']),
                    'last_check':         host['host_check_age'],
                    'duration':           host['host_state_age'],
                    'status_information': html.unescape(host['host_plugin_output'].replace('\n', ' ')),
                    'attempt':            host['host_attempt'],
                    'site':               host['sitename_plain'],
                    'address':            host['host_address']
                }

                # host objects contain service objects
                if n['host'] not in self.new_hosts:
                    new_host = n['host']
                    self.new_hosts[new_host] = GenericHost()
                    self.new_hosts[new_host].name = n['host']
                    self.new_hosts[new_host].server = self.name
                    self.new_hosts[new_host].status = n['status']
                    self.new_hosts[new_host].last_check = n['last_check']
                    self.new_hosts[new_host].duration = n['duration']
                    self.new_hosts[new_host].attempt = n['attempt']
                    self.new_hosts[new_host].status_information= html.unescape(n['status_information'].replace('\n', ' '))
                    self.new_hosts[new_host].site = n['site']
                    self.new_hosts[new_host].address = n['address']

                    # transition to Check_MK 1.1.10p2
                    if 'host_in_downtime' in host:
                        if host['host_in_downtime'] == 'yes':
                            self.new_hosts[new_host].scheduled_downtime = True
                    if 'host_acknowledged' in host:
                        if host['host_acknowledged'] == 'yes':
                            self.new_hosts[new_host].acknowledged = True
                    if 'host_notifications_enabled' in host:
                        if host['host_notifications_enabled'] == 'no':
                            self.new_hosts[new_host].notifications_disabled = True

                    # hard/soft state for later filter evaluation
                    real_attempt, max_attempt = self.new_hosts[new_host].attempt.split('/')
                    if real_attempt != max_attempt:
                        self.new_hosts[new_host].status_type = 'soft'
                    else:
                        self.new_hosts[new_host].status_type = 'hard'

            del response

        except:
            import traceback
            traceback.print_exc(file=sys.stdout)

            self.isChecking = False
            result, error = self.Error(sys.exc_info())
            return Result(result=result, error=error)

        # Add filters to the url which should only be applied to the service request
        if conf.filter_services_on_unreachable_hosts == True:
            url_params += '&hst2=0'

        # services
        try:
            response = []
            try:
                response = self._get_url(self.urls['api_services'] + url_params)
            except MultisiteError as e:
                if e.terminate:
                    return e.result
                else:
                    response = copy.deepcopy(e.result.content)
                    ret = copy.deepcopy(e.result)

            for row in response[1:]:
                service = dict(list(zip(copy.deepcopy(response[0]), copy.deepcopy(row))))
                n = {
                    'host':               service['host'],
                    'service':            service['service_description'],
                    'status':             self.statemap.get(service['service_state'], service['service_state']),
                    'last_check':         service['svc_check_age'],
                    'duration':           service['svc_state_age'],
                    'attempt':            service['svc_attempt'],
                    'status_information': html.unescape(service['svc_plugin_output'].replace('\n', ' ')),
                    # Check_MK passive services can be re-scheduled by using the Check_MK service
                    'passiveonly':        service['svc_is_active'] == 'no' and not service['svc_check_command'].startswith('check_mk'),
                    'flapping':           service['svc_flapping'] == 'yes',
                    'site':               service['sitename_plain'],
                    'address':            service['host_address'],
                    'command':            service['svc_check_command'],
                }

                # host objects contain service objects
                if n['host'] not in self.new_hosts:
                    self.new_hosts[n['host']] = GenericHost()
                    self.new_hosts[n['host']].name = n['host']
                    self.new_hosts[n['host']].status = 'UP'
                    self.new_hosts[n['host']].site = n['site']
                    self.new_hosts[n['host']].address = n['address']
                # if a service does not exist create its object
                if n['service'] not in self.new_hosts[n['host']].services:
                    new_service = n['service']
                    self.new_hosts[n['host']].services[new_service] = GenericService()
                    self.new_hosts[n['host']].services[new_service].host = n['host']
                    self.new_hosts[n['host']].services[new_service].server = self.name
                    self.new_hosts[n['host']].services[new_service].name = n['service']
                    self.new_hosts[n['host']].services[new_service].status = n['status']
                    self.new_hosts[n['host']].services[new_service].last_check = n['last_check']
                    self.new_hosts[n['host']].services[new_service].duration = n['duration']
                    self.new_hosts[n['host']].services[new_service].attempt = n['attempt']
                    self.new_hosts[n['host']].services[new_service].status_information = n['status_information'].strip()
                    self.new_hosts[n['host']].services[new_service].passiveonly = n['passiveonly']
                    self.new_hosts[n['host']].services[new_service].flapping = n['flapping']
                    self.new_hosts[n['host']].services[new_service].site = n['site']
                    self.new_hosts[n['host']].services[new_service].address = n['address']
                    self.new_hosts[n['host']].services[new_service].command = n['command']

                    # transition to Check_MK 1.1.10p2
                    if 'svc_in_downtime' in service:
                        if service['svc_in_downtime'] == 'yes':
                            self.new_hosts[n['host']].services[new_service].scheduled_downtime = True
                    if 'svc_acknowledged' in service:
                        if service['svc_acknowledged'] == 'yes':
                            self.new_hosts[n['host']].services[new_service].acknowledged = True
                    if 'svc_flapping' in service:
                        if service['svc_flapping'] == 'yes':
                            self.new_hosts[n['host']].services[new_service].flapping = True
                    if 'svc_notifications_enabled' in service:
                        if service['svc_notifications_enabled'] == 'no':
                            self.new_hosts[n['host']].services[new_service].notifications_disabled = True

                    # hard/soft state for later filter evaluation
                    real_attempt, max_attempt = self.new_hosts[n['host']].services[new_service].attempt.split('/')
                    if real_attempt != max_attempt:
                        self.new_hosts[n['host']].services[new_service].status_type = 'soft'
                    else:
                        self.new_hosts[n['host']].services[new_service].status_type = 'hard'

            del response

        except:
            import traceback
            traceback.print_exc(file=sys.stdout)

            # set checking flag back to False
            self.isChecking = False
            result, error = self.Error(sys.exc_info())
            return Result(result=copy.deepcopy(result), error=copy.deepcopy(error))

        del url_params

        return ret
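
The Check_MK web API returns a list of rows whose first element is the header row; each data row is zipped against a deep copy of that header to build one dict per host or service (host = dict(list(zip(copy.deepcopy(response[0]), copy.deepcopy(row))))). For rows of plain strings the deep copies are arguably defensive overkill, but they guarantee nothing in response stays shared with the dicts built from it. A minimal reconstruction of the idea with made-up column names:

import copy

response = [
    ['host', 'host_state', 'host_plugin_output'],        # header row
    ['web01', 'DOWN', 'CRITICAL - ping timed out'],
    ['db01', 'UP', 'OK - 0.3 ms'],
]

hosts = []
for row in response[1:]:
    # zip a copied header with a copied data row into one dict per host
    hosts.append(dict(zip(copy.deepcopy(response[0]), copy.deepcopy(row))))

assert hosts[0]['host_state'] == 'DOWN'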

Example 25

Project: kay
Source File: media_compiler.py
View license
def compile_js_(tag_name, js_config, force):
  if IS_APPSERVER:
    return

  def needs_update(media_info):
    if js_config['tool'] != 'goog_calcdeps':
      # update if target file does not exist
      target_path = make_output_path_(js_config, js_config['subdir'],
                                      js_config['output_filename'])
      if not os.path.exists(target_path):
        return True

    # update if it lacks required info in _media.yaml
    last_info = media_info.get(js_config['subdir'], tag_name)
    if not last_info:
      return True
    last_config = last_info.get('config')
    if not last_config:
      return True

    # update if any configuration setting is changed
    if not equal_object_(last_config, js_config):
      return True

    if 'related_files' not in last_info:
      return True
    for path, mtime in last_info['related_files']:
      if mtime != os.path.getmtime(path):
        return True
      
  def jsminify(js_path):
    from StringIO import StringIO
    from kay.ext.media_compressor.jsmin import JavascriptMinify
    ifile = open(js_path)
    outs = StringIO()
    JavascriptMinify().minify(ifile, outs)
    ret = outs.getvalue()
    if len(ret) > 0 and ret[0] == '\n':
      ret = ret[1:]
    return ret

  def concat(js_path):
    print_status(" concat %s" % js_path)
    ifile = open(js_path)
    js = ifile.read()
    ifile.close()
    return js

  def goog_calcdeps():
    deps_config = copy.deepcopy(js_config['goog_common'])
    deps_config.update(js_config['goog_calcdeps'])

    if deps_config.get('method') not in \
          ['separate', 'concat', 'concat_refs', 'compile']:
      print_status("COMPILE_MEDIA_JS['goog_calcdeps']['method'] setting is"
                   " invalid; unknown method `%s'" % deps_config.get('method'))
      sys.exit(1)

    output_urls = []
    if deps_config['method'] == 'separate':
      source_files, output_urls = goog_calcdeps_separate(deps_config)
    elif deps_config['method'] == 'concat':
      source_files, output_urls = goog_calcdeps_concat(deps_config)
    elif deps_config['method'] == 'concat_refs':
      source_files, output_urls = goog_calcdeps_concat_refs(deps_config)
    elif deps_config['method'] == 'compile':
      source_files, output_urls = goog_calcdeps_compile(deps_config)
      source_files = [file[0] for file in source_files]

    related_files = union_list(source_files, 
                               [make_input_path_(path)
                                  for path in js_config['source_files']])
    related_file_info = [(path, os.path.getmtime(path))
                           for path in related_files]
    
    # create yaml info
    last_info = {'config': copy.deepcopy(js_config),
                 'related_files': related_file_info,
                 'result_urls': output_urls}
    media_info.set(js_config['subdir'], tag_name, last_info)
    media_info.save()

  def goog_calcdeps_separate(deps_config):
    source_files = goog_calcdeps_list(deps_config)
    (output_urls, extern_urls) = goog_calcdeps_copy_files(deps_config,
                                                          source_files)
    return (source_files, extern_urls + output_urls)

  def goog_calcdeps_concat(deps_config):
    source_files = goog_calcdeps_list(deps_config)
    (output_urls, extern_urls) = goog_calcdeps_concat_files(deps_config,
                                                            source_files)
    return (source_files, extern_urls + output_urls)

  def goog_calcdeps_concat_refs(deps_config):
    source_files = goog_calcdeps_list(deps_config)
    original_files = [make_input_path_(path)
                      for path in js_config['source_files']]
    ref_files = [path for path in source_files if path not in original_files]
    (output_urls, extern_urls) = goog_calcdeps_concat_files(deps_config,
                                                            ref_files)
    original_urls = [path[len(kay.PROJECT_DIR):] for path in original_files]
    return (source_files, extern_urls + output_urls + original_urls)

  def goog_calcdeps_compile(deps_config):
    comp_config = copy.deepcopy(js_config['goog_common'])
    comp_config.update(js_config['goog_compiler'])

    source_files = []
    extern_urls = []

    command = '%s -o compiled -c "%s" ' % (deps_config['path'],
                                                 comp_config['path'])
    for path in deps_config.get('search_paths', []):
      command += '-p %s ' % make_input_path_(path)
    for path in js_config['source_files']:
      path = make_input_path_(path)
      command += '-i %s ' % path
      source_files.append((path, os.path.getmtime(path)))

    if comp_config['level'] == 'minify':
      level = 'WHITESPACE_ONLY'
    elif comp_config['level'] == 'advanced':
      level = 'ADVANCED_OPTIMIZATIONS'
    else:
      level = 'SIMPLE_OPTIMIZATIONS'
    flags = '--compilation_level=%s' % level
#    for path in comp_config.get('externs', []):
#      flags += '--externs=%s ' % make_input_path_(path)
#    if comp_config.get('externs'):
#      flags += ' --externs=%s ' % " ".join(comp_config['externs'])
    command += '-f "%s" ' % flags
    print_status(command)
    command_output = os.popen(command).read()

    output_path = make_output_path_(js_config, js_config['subdir'],
                                    js_config['output_filename'])
    ofile = create_file_(output_path)
    try:
      for path in comp_config.get('externs', []):
        if re.match(r'^https?://', path):
          extern_urls.append(path)
          continue
        path = make_input_path_(path)
        ifile = open(path)
        try:
          ofile.write(ifile.read())
        finally:
          ifile.close()
        source_files.append((path, os.path.getmtime(path)))
      ofile.write(command_output)
    finally:
      ofile.close()
    return (source_files, extern_urls + [output_path[len(kay.PROJECT_DIR):]])

  def goog_calcdeps_list(deps_config):
    source_files = []

    command = '%s -o list ' % deps_config['path']
    for path in deps_config['search_paths']:
      command += '-p %s ' % make_input_path_(path)
    for path in js_config['source_files']:
      command += '-i %s ' % make_input_path_(path)
    print_status(command)
    command_output = os.popen(command).read()
    for path in command_output.split("\n"):
      if path == '': continue
      source_files.append(path)
    return source_files

  def goog_calcdeps_copy_files(deps_config, source_files):
    extern_urls = []
    output_urls = []

    output_dir_base = make_output_path_(js_config, 'separated_js')

    if not os.path.exists(output_dir_base):
      os.makedirs(output_dir_base)
    if not deps_config.get('use_dependency_file', True):
      output_path = os.path.join(output_dir_base, '__goog_nodeps.js')
      ofile = open(output_path, "w")
      output_urls.append(output_path[len(kay.PROJECT_DIR):])
      try:
        ofile.write('CLOSURE_NO_DEPS = true;')
      finally:
        ofile.close()

    output_dirs = {}
    search_paths = [make_input_path_(path)
                    for path in deps_config['search_paths']]
    for path in search_paths:
      output_dirs[path] = os.path.join(output_dir_base,
                                       md5.new(path).hexdigest())

    all_paths = [make_input_path_(path)
                 for path in deps_config.get('externs', [])]
    all_paths.extend(source_files)
    for path in all_paths:
      if re.match(r'^https?://', path):
        extern_urls.append(path)
        continue

      path = make_input_path_(path)
      output_path = os.path.join(output_dir_base, re.sub('^/', '', path))
      for dir in search_paths:
        if path[0:len(dir)] == dir:
          output_path = os.path.join(output_dirs[dir],
                                     re.sub('^/', '', path[len(dir):]))
          break
      output_dir = os.path.dirname(output_path)

      if not os.path.exists(output_dir):
        os.makedirs(output_dir)
      shutil.copy2(path, output_path)
      output_urls.append(output_path[len(kay.PROJECT_DIR):])
    return (output_urls, extern_urls)
    
  def goog_calcdeps_concat_files(deps_config, source_files):
    extern_urls = []

    output_path = make_output_path_(js_config, js_config['subdir'],
                                    js_config['output_filename'])
    ofile = create_file_(output_path)
    try:
      if not deps_config.get('use_dependency_file', True):
        ofile.write('CLOSURE_NO_DEPS = true;')
      all_paths = [make_input_path_(path)
                   for path in deps_config.get('externs', [])]
      all_paths.extend(source_files)
      for path in all_paths:
        if re.match(r'^https?://', path):
          extern_urls.append(path)
          continue
        ifile = open(make_input_path_(path))
        ofile.write(ifile.read())
        ifile.close()
    finally:
      ofile.close()

    return ([output_path[len(kay.PROJECT_DIR):]], extern_urls)

  selected_tool = js_config['tool']

  if selected_tool not in \
        (None, 'jsminify', 'concat', 'goog_calcdeps', 'goog_compiler'):
    print_status("COMPILE_MEDIA_JS['tool'] setting is invalid;"
                 " unknown tool `%s'" % selected_tool)
    sys.exit(1)

  global media_info
  if media_info is None:
    media_info = MediaInfo.load()

  if not force and not needs_update(media_info):
    print_status(' up to date.')
    return

  if selected_tool == 'goog_calcdeps':
    return goog_calcdeps()

  if selected_tool is None:
    last_info = {'config': copy.deepcopy(js_config),
                 'result_urls': ['/'+f for f in js_config['source_files']]}
    media_info.set(js_config['subdir'], tag_name, last_info)
    media_info.save()
    return

  dest_path = make_output_path_(js_config, js_config['subdir'],
                                js_config['output_filename'])
  ofile = create_file_(dest_path)
  try:
    if selected_tool == 'jsminify':
      for path in js_config['source_files']:
        src_path = make_input_path_(path)
        ofile.write(jsminify(src_path))
    elif selected_tool == 'concat':
      for path in js_config['source_files']:
        src_path = make_input_path_(path)
        ofile.write(concat(src_path))
  finally:
    ofile.close()
  
  if selected_tool == 'goog_compiler':
    comp_config = copy.deepcopy(js_config['goog_common'])
    comp_config.update(js_config['goog_compiler'])
    if comp_config['level'] == 'minify':
      level = 'WHITESPACE_ONLY'
    elif comp_config['level'] == 'advanced':
      level = 'ADVANCED_OPTIMIZATIONS'
    else:
      level = 'SIMPLE_OPTIMIZATIONS'
    command_args = '--compilation_level=%s' % level
    for path in js_config['source_files']:
      command_args += ' --js %s' % make_input_path_(path)
    command_args += ' --js_output_file %s' % dest_path
    command = 'java -jar %s %s' % (comp_config['path'], command_args)
    command_output = os.popen(command).read()

  info = copy.deepcopy(js_config)
  info['output_filename'] = make_output_path_(js_config, js_config['subdir'],
                                              js_config['output_filename'],
                                              relative=True)
  info['result_urls'] = ['/'+info['output_filename']]
  media_info.set(js_config['subdir'], tag_name, info)
  media_info.save()
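
This example deep-copies the shared 'goog_common' settings before layering tool-specific options on top (deps_config = copy.deepcopy(js_config['goog_common']) followed by deps_config.update(...)), and it also deep-copies js_config itself before storing it in the media info record. Without the copy, dict.update() would mutate the shared base configuration, and the saved info would keep aliasing a dict that later code still modifies. A small sketch of that merge pattern with hypothetical keys; the deep copy only becomes essential once the values are themselves lists or dicts:

import copy

js_config = {
    'goog_common': {'path': 'tools/closure', 'level': 'simple'},
    'goog_compiler': {'level': 'advanced'},
}

# Deep-copy the common block so update() cannot leak options back into it.
comp_config = copy.deepcopy(js_config['goog_common'])
comp_config.update(js_config['goog_compiler'])

assert comp_config['level'] == 'advanced'
assert js_config['goog_common']['level'] == 'simple'   # base config untouched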

Example 26

Project: picochess
Source File: picochess.py
View license
def main():

    def display_system_info():
        if args.enable_internet:
            place = get_location()
            addr = get_ip()
        else:
            place = '?'
            addr = None
        DisplayMsg.show(Message.SYSTEM_INFO(info={'version': version, 'location': place, 'ip': addr,
                                                  'engine_name': engine_name, 'user_name': user_name
                                                  }))

    def compute_legal_fens(g):
        """
        Compute a list of legal FENs for the given game.
        :param g: The game
        :return: A list of legal FENs
        """
        fens = []
        for move in g.legal_moves:
            g.push(move)
            fens.append(g.board_fen())
            g.pop()
        return fens

    def probe_tablebase(game):
        if not gaviota:
            return None
        score = gaviota.probe_dtm(game)
        if score is not None:
            Observable.fire(Event.NEW_SCORE(score='gaviota', mate=score))
        return score

    def think(game, tc):
        """
        Start a new search on the current game.
        If a move is found in the opening book, fire an event in a few seconds.
        :return:
        """
        start_clock()
        book_move = searchmoves.book(bookreader, game)
        if book_move:
            Observable.fire(Event.BEST_MOVE(result=book_move, inbook=True))
        else:
            probe_tablebase(game)
            while not engine.is_waiting():
                time.sleep(0.1)
                logging.warning('engine is still not waiting')
            engine.position(copy.deepcopy(game))
            uci_dict = tc.uci()
            uci_dict['searchmoves'] = searchmoves.all(game)
            engine.go(uci_dict)

    def analyse(game):
        """
        Start a new ponder search on the current game.
        :return:
        """
        probe_tablebase(game)
        engine.position(copy.deepcopy(game))
        engine.ponder()

    def observe(game):
        """
        Starts a new ponder search on the current game.
        :return:
        """
        start_clock()
        analyse(game)

    def stop_search_and_clock():
        if interaction_mode == Mode.NORMAL:
            stop_clock()
        elif interaction_mode in (Mode.REMOTE, Mode.OBSERVE):
            stop_clock()
            stop_search()
        elif interaction_mode in (Mode.ANALYSIS, Mode.KIBITZ, Mode.PONDER):
            stop_search()

    def stop_search():
        """
        Stop current search.
        :return:
        """
        engine.stop()

    def stop_clock():
        if interaction_mode in (Mode.NORMAL, Mode.OBSERVE, Mode.REMOTE):
            time_control.stop()
            DisplayMsg.show(Message.CLOCK_STOP())
        else:
            logging.warning('wrong mode: {}'.format(interaction_mode))

    def start_clock():
        if interaction_mode in (Mode.NORMAL, Mode.OBSERVE, Mode.REMOTE):
            time_control.start(game.turn)
            DisplayMsg.show(Message.CLOCK_START(turn=game.turn, time_control=time_control))
        else:
            logging.warning('wrong mode: {}'.format(interaction_mode))

    def check_game_state(game, play_mode):
        """
        Check if the game has ended or not; it also sends a Message to the Displays if the game has ended.
        :param game:
        :param play_mode:
        :return: True if the game continues, False if it has ended
        """
        result = None
        if game.is_stalemate():
            result = GameResult.STALEMATE
        if game.is_insufficient_material():
            result = GameResult.INSUFFICIENT_MATERIAL
        if game.is_seventyfive_moves():
            result = GameResult.SEVENTYFIVE_MOVES
        if game.is_fivefold_repetition():
            result = GameResult.FIVEFOLD_REPETITION
        if game.is_checkmate():
            result = GameResult.MATE

        if result is None:
            return True
        else:
            DisplayMsg.show(Message.GAME_ENDS(result=result, play_mode=play_mode, game=game.copy()))
            return False

    def user_move(move):
        logging.debug('user move [%s]', move)
        if move not in game.legal_moves:
            logging.warning('Illegal move [%s]', move)
        else:
            handle_move(move=move)

    def process_fen(fen):
        nonlocal last_computer_fen
        nonlocal last_legal_fens
        nonlocal searchmoves
        nonlocal legal_fens

        # Check for same position
        if (fen == game.board_fen() and not last_computer_fen) or fen == last_computer_fen:
            logging.debug('Already in this fen: ' + fen)

        # Check if we have to undo a previous move (sliding)
        elif fen in last_legal_fens:
            if interaction_mode == Mode.NORMAL:
                if (play_mode == PlayMode.USER_WHITE and game.turn == chess.BLACK) or \
                        (play_mode == PlayMode.USER_BLACK and game.turn == chess.WHITE):
                    stop_search()
                    game.pop()
                    logging.debug('User move in computer turn, reverting to: ' + game.board_fen())
                elif last_computer_fen:
                    last_computer_fen = None
                    game.pop()
                    game.pop()
                    logging.debug('User move while computer move is displayed, reverting to: ' + game.board_fen())
                else:
                    logging.error("last_legal_fens not cleared: " + game.board_fen())
            elif interaction_mode == Mode.REMOTE:
                if (play_mode == PlayMode.USER_WHITE and game.turn == chess.BLACK) or \
                        (play_mode == PlayMode.USER_BLACK and game.turn == chess.WHITE):
                    game.pop()
                    logging.debug('User move in remote turn, reverting to: ' + game.board_fen())
                elif last_computer_fen:
                    last_computer_fen = None
                    game.pop()
                    game.pop()
                    logging.debug('User move while remote move is displayed, reverting to: ' + game.board_fen())
                else:
                    logging.error('last_legal_fens not cleared: ' + game.board_fen())
            else:
                game.pop()
                logging.debug('Wrong color move -> sliding, reverting to: ' + game.board_fen())
            legal_moves = list(game.legal_moves)
            user_move(legal_moves[last_legal_fens.index(fen)])
            if interaction_mode == Mode.NORMAL or interaction_mode == Mode.REMOTE:
                legal_fens = []
            else:
                legal_fens = compute_legal_fens(game)

        # legal move
        elif fen in legal_fens:
            time_control.add_inc(game.turn)
            legal_moves = list(game.legal_moves)
            user_move(legal_moves[legal_fens.index(fen)])
            last_legal_fens = legal_fens
            if interaction_mode == Mode.NORMAL or interaction_mode == Mode.REMOTE:
                legal_fens = []
            else:
                legal_fens = compute_legal_fens(game)

        # Player had done the computer or remote move on the board
        elif last_computer_fen and fen == game.board_fen():
            last_computer_fen = None
            if check_game_state(game, play_mode) and interaction_mode in (Mode.NORMAL, Mode.REMOTE):
                # finally reset all alternative moves see: handle_move()
                nonlocal searchmoves
                searchmoves.reset()
                time_control.add_inc(not game.turn)
                if time_control.mode != TimeMode.FIXED:
                    start_clock()
                DisplayMsg.show(Message.COMPUTER_MOVE_DONE_ON_BOARD())
                legal_fens = compute_legal_fens(game)
            else:
                legal_fens = []
            last_legal_fens = []

        # Check if this is a previous legal position and allow user to restart from this position
        else:
            game_history = copy.deepcopy(game)
            if last_computer_fen:
                game_history.pop()
            while game_history.move_stack:
                game_history.pop()
                if game_history.board_fen() == fen:
                    logging.debug("Current game FEN      : {}".format(game.fen()))
                    logging.debug("Undoing game until FEN: {}".format(fen))
                    stop_search_and_clock()
                    while len(game_history.move_stack) < len(game.move_stack):
                        game.pop()
                    last_computer_fen = None
                    last_legal_fens = []
                    if (interaction_mode == Mode.REMOTE or interaction_mode == Mode.NORMAL) and \
                            ((play_mode == PlayMode.USER_WHITE and game_history.turn == chess.BLACK)
                              or (play_mode == PlayMode.USER_BLACK and game_history.turn == chess.WHITE)):
                        legal_fens = []
                        if interaction_mode == Mode.NORMAL:
                            searchmoves.reset()
                            if check_game_state(game, play_mode):
                                think(game, time_control)
                    else:
                        legal_fens = compute_legal_fens(game)

                    if interaction_mode in (Mode.ANALYSIS, Mode.KIBITZ, Mode.PONDER):
                        analyse(game)
                    elif interaction_mode in (Mode.OBSERVE, Mode.REMOTE):
                        observe(game)
                    start_clock()
                    DisplayMsg.show(Message.USER_TAKE_BACK())
                    break

    def set_wait_state(start_search=True):
        if interaction_mode == Mode.NORMAL:
            nonlocal play_mode
            play_mode = PlayMode.USER_WHITE if game.turn == chess.WHITE else PlayMode.USER_BLACK
        if start_search:
            # Go back to analysing or observing
            if interaction_mode in (Mode.ANALYSIS, Mode.KIBITZ, Mode.PONDER):
                analyse(game)
            if interaction_mode in (Mode.OBSERVE, Mode.REMOTE):
                observe(game)

    def handle_move(move, ponder=None, inbook=False):
        nonlocal game
        nonlocal last_computer_fen
        nonlocal searchmoves
        fen = game.fen()
        turn = game.turn

        # clock must be stopped BEFORE the "book_move" event because SetNRun resets the clock display
        stop_search_and_clock()

        # engine or remote move
        if (interaction_mode == Mode.NORMAL or interaction_mode == Mode.REMOTE) and \
                ((play_mode == PlayMode.USER_WHITE and game.turn == chess.BLACK) or
                     (play_mode == PlayMode.USER_BLACK and game.turn == chess.WHITE)):
            last_computer_fen = game.board_fen()
            game.push(move)
            if inbook:
                DisplayMsg.show(Message.BOOK_MOVE())
            searchmoves.add(move)
            text = Message.COMPUTER_MOVE(move=move, ponder=ponder, fen=fen, turn=turn, game=game.copy(),
                                         time_control=time_control, wait=inbook)
            DisplayMsg.show(text)
        else:
            last_computer_fen = None
            game.push(move)
            if inbook:
                DisplayMsg.show(Message.BOOK_MOVE())
            searchmoves.reset()
            if interaction_mode == Mode.NORMAL:
                if check_game_state(game, play_mode):
                    think(game, time_control)
                text = Message.USER_MOVE(move=move, fen=fen, turn=turn, game=game.copy())
            elif interaction_mode == Mode.REMOTE:
                if check_game_state(game, play_mode):
                    observe(game)
                text = Message.USER_MOVE(move=move, fen=fen, turn=turn, game=game.copy())
            elif interaction_mode == Mode.OBSERVE:
                if check_game_state(game, play_mode):
                    observe(game)
                text = Message.REVIEW_MOVE(move=move, fen=fen, turn=turn, game=game.copy(), mode=interaction_mode)
            else:  # interaction_mode in (Mode.ANALYSIS, Mode.KIBITZ):
                if check_game_state(game, play_mode):
                    analyse(game)
                text = Message.REVIEW_MOVE(move=move, fen=fen, turn=turn, game=game.copy(), mode=interaction_mode)
            DisplayMsg.show(text)

    def transfer_time(time_list):
        def num(ts):
            try:
                return int(ts)
            except ValueError:
                return 1

        if len(time_list) == 1:
            secs = num(time_list[0])
            time_control = TimeControl(TimeMode.FIXED, seconds_per_move=secs)
            text = dgttranslate.text('B00_tc_fixed', '{:2d}'.format(secs))
        elif len(time_list) == 2:
            mins = num(time_list[0])
            finc = num(time_list[1])
            if finc == 0:
                time_control = TimeControl(TimeMode.BLITZ, minutes_per_game=mins)
                text = dgttranslate.text('B00_tc_blitz', '{:2d}'.format(mins))
            else:
                time_control = TimeControl(TimeMode.FISCHER, minutes_per_game=mins, fischer_increment=finc)
                text = dgttranslate.text('B00_tc_fisch', '{:2d} {:2d}'.format(mins, finc))
        else:
            time_control = TimeControl(TimeMode.BLITZ, minutes_per_game=5)
            text = dgttranslate.text('B00_tc_blitz', ' 5')
        return time_control, text

    def get_engine_level_dict(engine_level):
        from engine import get_installed_engines

        installed_engines = get_installed_engines(engine.get_shell(), engine.get_file())
        for index in range(0, len(installed_engines)):
            eng = installed_engines[index]
            if eng['file'] == engine.get_file():
                level_list = sorted(eng['level_dict'])
                try:
                    level_index = level_list.index(engine_level)
                    return eng['level_dict'][level_list[level_index]]
                except ValueError:
                    break
        return {}

    # Enable garbage collection - needed for engine swapping as objects get orphaned
    gc.enable()

    # Command line argument parsing
    parser = configargparse.ArgParser(default_config_files=[os.path.join(os.path.dirname(__file__), 'picochess.ini')])
    parser.add_argument('-e', '--engine', type=str, help='UCI engine executable path', default=None)
    parser.add_argument('-el', '--engine-level', type=str, help='UCI engine level', default=None)
    parser.add_argument('-d', '--dgt-port', type=str,
                        help='enable dgt board on the given serial port such as /dev/ttyUSB0')
    parser.add_argument('-b', '--book', type=str, help='full path of book such as books/b-flank.bin',
                        default='h-varied.bin')
    parser.add_argument('-t', '--time', type=str, default='5 0',
                        help="Time settings <FixSec> or <StMin IncSec> like '10'(move) or '5 0'(game) '3 2'(fischer)")
    parser.add_argument('-g', '--enable-gaviota', action='store_true', help='enable gaviota tablebase probing')
    parser.add_argument('-leds', '--enable-revelation-leds', action='store_true', help='enable Revelation leds')
    parser.add_argument('-l', '--log-level', choices=['notset', 'debug', 'info', 'warning', 'error', 'critical'],
                        default='warning', help='logging level')
    parser.add_argument('-lf', '--log-file', type=str, help='log to the given file')
    parser.add_argument('-rs', '--remote-server', type=str, help='remote server running the engine')
    parser.add_argument('-ru', '--remote-user', type=str, help='remote user on server running the engine')
    parser.add_argument('-rp', '--remote-pass', type=str, help='password for the remote user')
    parser.add_argument('-rk', '--remote-key', type=str, help='key file used to connect to the remote server')
    parser.add_argument('-pf', '--pgn-file', type=str, help='pgn file used to store the games', default='games.pgn')
    parser.add_argument('-pu', '--pgn-user', type=str, help='user name for the pgn file', default=None)
    parser.add_argument('-ar', '--auto-reboot', action='store_true', help='reboot system after update')
    parser.add_argument('-web', '--web-server', dest='web_server_port', nargs='?', const=80, type=int, metavar='PORT',
                        help='launch web server')
    parser.add_argument('-m', '--email', type=str, help='email used to send pgn files', default=None)
    parser.add_argument('-ms', '--smtp-server', type=str, help='address of email server', default=None)
    parser.add_argument('-mu', '--smtp-user', type=str, help='username for email server', default=None)
    parser.add_argument('-mp', '--smtp-pass', type=str, help='password for email server', default=None)
    parser.add_argument('-me', '--smtp-encryption', action='store_true',
                        help='use ssl encryption connection to smtp-Server')
    parser.add_argument('-mf', '--smtp-from', type=str, help='From email', default='[email protected]')
    parser.add_argument('-mk', '--mailgun-key', type=str, help='key used to send emails via Mailgun Webservice',
                        default=None)
    parser.add_argument('-bc', '--beep-config', choices=['none', 'some', 'all'], help='sets standard beep config',
                        default='some')
    parser.add_argument('-beep', '--beep-level', type=int, default=0x03,
                        help='sets (some-)beep level from 0(=no beeps) to 15(=all beeps)')
    parser.add_argument('-uvoice', '--user-voice', type=str, help='voice for user', default=None)
    parser.add_argument('-cvoice', '--computer-voice', type=str, help='voice for computer', default=None)
    parser.add_argument('-inet', '--enable-internet', action='store_true', help='enable internet lookups')
    parser.add_argument('-nook', '--disable-ok-message', action='store_true', help='disable ok confirmation messages')
    parser.add_argument('-v', '--version', action='version', version='%(prog)s version {}'.format(version),
                        help='show current version', default=None)
    parser.add_argument('-pi', '--dgtpi', action='store_true', help='use the dgtpi hardware')
    parser.add_argument('-lang', '--language', choices=['en', 'de', 'nl', 'fr', 'es'], default='en',
                        help='picochess language')
    parser.add_argument('-c', '--console', action='store_true', help='use console interface')

    args = parser.parse_args()
    if args.engine is None:
        el = read_engine_ini()
        args.engine = el[0]['file']  # read the first engine filename and use it as standard

    # Enable logging
    if args.log_file:
        handler = RotatingFileHandler('logs' + os.sep + args.log_file, maxBytes=1024*1024, backupCount=9)
        logging.basicConfig(level=getattr(logging, args.log_level.upper()),
                            format='%(asctime)s.%(msecs)03d %(levelname)5s %(module)10s - %(funcName)s: %(message)s',
                            datefmt="%Y-%m-%d %H:%M:%S", handlers=[handler])
    logging.getLogger('chess.uci').setLevel(logging.INFO)  # don't want to get so many python-chess uci messages

    logging.debug('#'*20 + ' PicoChess v' + version + ' ' + '#'*20)
    # log the startup parameters but hide the password fields
    p = copy.copy(vars(args))
    p['mailgun_key'] = p['remote_key'] = p['remote_pass'] = p['smtp_pass'] = '*****'
    logging.debug('startup parameters: {}'.format(p))

    # Update
    if args.enable_internet:
        update_picochess(args.auto_reboot)

    gaviota = None
    if args.enable_gaviota:
        try:
            gaviota = chess.gaviota.open_tablebases('tablebases/gaviota')
            logging.debug('Tablebases gaviota loaded')
        except OSError:
            logging.error('Tablebases gaviota does not exist')
            gaviota = None

    # The class dgtDisplay talks to DgtHw/DgtPi or DgtVr
    dgttranslate = DgtTranslate(args.beep_config, args.beep_level, args.language)
    DgtDisplay(args.disable_ok_message, dgttranslate).start()

    # Launch web server
    if args.web_server_port:
        WebServer(args.web_server_port).start()

    dgtserial = DgtSerial(args.dgt_port, args.enable_revelation_leds, args.dgtpi)

    if args.console:
        # Enable keyboard input and terminal display
        logging.debug('starting picochess in virtual mode')
        KeyboardInput(dgttranslate, args.dgtpi).start()
        TerminalDisplay().start()
        DgtVr(dgtserial, dgttranslate).start()
    else:
        # Connect to DGT board
        logging.debug('starting picochess in board mode')
        if args.dgtpi:
            DgtPi(dgtserial, dgttranslate).start()
        DgtHw(dgtserial, dgttranslate).start()
    # Save to PGN
    emailer = Emailer(
        net=args.enable_internet, email=args.email, mailgun_key=args.mailgun_key,
        smtp_server=args.smtp_server, smtp_user=args.smtp_user,
        smtp_pass=args.smtp_pass, smtp_encryption=args.smtp_encryption, smtp_from=args.smtp_from)

    PgnDisplay(args.pgn_file, emailer).start()
    if args.pgn_user:
        user_name = args.pgn_user
    else:
        if args.email:
            user_name = args.email.split('@')[0]
        else:
            user_name = 'Player'

    # Create PicoTalker for speech output
    if args.user_voice or args.computer_voice:
        from talker.picotalker import PicoTalkerDisplay
        logging.debug("initializing PicoTalker [%s, %s]", str(args.user_voice), str(args.computer_voice))
        PicoTalkerDisplay(args.user_voice, args.computer_voice).start()
    else:
        logging.debug('PicoTalker disabled')

    # Gentlemen, start your engines...
    engine = UciEngine(args.engine, hostname=args.remote_server, username=args.remote_user,
                       key_file=args.remote_key, password=args.remote_pass)
    try:
        engine_name = engine.get().name
    except AttributeError:
        logging.error('no engines started')
        sys.exit(-1)

    # Startup - internal
    game = chess.Board()  # Create the current game
    legal_fens = compute_legal_fens(game)  # Compute the legal FENs
    all_books = get_opening_books()
    try:
        book_index = [book['file'] for book in all_books].index(args.book)
    except ValueError:
        logging.warning("selected book not present, defaulting to %s", all_books[7]['file'])
        book_index = 7
    bookreader = chess.polyglot.open_reader(all_books[book_index]['file'])
    searchmoves = AlternativeMover()
    interaction_mode = Mode.NORMAL
    play_mode = PlayMode.USER_WHITE

    last_computer_fen = None
    last_legal_fens = []
    game_declared = False  # User declared resignation or draw

    engine.startup(get_engine_level_dict(args.engine_level))

    # Startup - external
    time_control, time_text = transfer_time(args.time.split())
    time_text.beep = False
    if args.engine_level:
        level_text = dgttranslate.text('B00_level', args.engine_level)
        level_text.beep = False
    else:
        level_text = None
    DisplayMsg.show(Message.STARTUP_INFO(info={'interaction_mode': interaction_mode, 'play_mode': play_mode,
                                               'books': all_books, 'book_index': book_index, 'level_text': level_text,
                                               'time_control': time_control, 'time_text': time_text}))
    DisplayMsg.show(Message.ENGINE_STARTUP(shell=engine.get_shell(), file=engine.get_file(),
                                           has_levels=engine.has_levels(), has_960=engine.has_chess960()))

    system_info_thread = threading.Timer(0, display_system_info)
    system_info_thread.start()

    # Event loop
    logging.info('evt_queue ready')
    while True:
        try:
            event = evt_queue.get()
        except queue.Empty:
            pass
        else:
            logging.debug('received event from evt_queue: %s', event)
            for case in switch(event):
                if case(EventApi.FEN):
                    process_fen(event.fen)
                    break

                if case(EventApi.KEYBOARD_MOVE):
                    move = event.move
                    logging.debug('keyboard move [%s]', move)
                    if move not in game.legal_moves:
                        logging.warning('illegal move [%s]', move)
                    else:
                        g = copy.deepcopy(game)
                        g.push(move)
                        fen = g.fen().split(' ')[0]
                        if event.flip_board:
                            fen = fen[::-1]
                        DisplayMsg.show(Message.KEYBOARD_MOVE(fen=fen))
                    break

                if case(EventApi.LEVEL):
                    if event.options:
                        engine.startup(event.options, False)
                    DisplayMsg.show(Message.LEVEL(level_text=event.level_text))
                    break

                if case(EventApi.NEW_ENGINE):
                    config = ConfigObj('picochess.ini')
                    config['engine'] = event.eng['file']
                    config.write()
                    old_file = engine.get_file()
                    engine_shutdown = True
                    # Stop the old engine cleanly
                    engine.stop()
                    # Closeout the engine process and threads
                    # They all return non-zero error codes, 0=success
                    if engine.quit():  # Ask nicely
                        if engine.terminate():  # If you won't go nicely.... 
                            if engine.kill():  # Right that does it!
                                logging.error('engine shutdown failure')
                                DisplayMsg.show(Message.ENGINE_FAIL())
                                engine_shutdown = False
                    if engine_shutdown:
                        # Load the new one and send args.
                        # Local engines only
                        engine_fallback = False
                        engine = UciEngine(event.eng['file'])
                        try:
                            engine_name = engine.get().name
                        except AttributeError:
                            # New engine failed to start, restart old engine
                            logging.error("new engine failed to start, reverting to %s", old_file)
                            engine_fallback = True
                            event.options = {}  # Reset options. This will load the last(=strongest?) level
                            engine = UciEngine(old_file)
                            try:
                                engine_name = engine.get().name
                            except AttributeError:
                                # Help - old engine failed to restart. There is no engine
                                logging.error('no engines started')
                                sys.exit(-1)
                        # Schedule cleanup of old objects
                        gc.collect()
                        engine.startup(event.options)
                        # All done - rock'n'roll
                        if not engine_fallback:
                            DisplayMsg.show(Message.ENGINE_READY(eng=event.eng, engine_name=engine_name,
                                                                 eng_text=event.eng_text,
                                                                 has_levels=engine.has_levels(),
                                                                 has_960=engine.has_chess960(), ok_text=event.ok_text))
                        else:
                            DisplayMsg.show(Message.ENGINE_FAIL())
                        set_wait_state(not engine_fallback)
                    break

                if case(EventApi.SETUP_POSITION):
                    logging.debug("setting up custom fen: {}".format(event.fen))
                    uci960 = event.uci960

                    if game.move_stack:
                        if not (game.is_game_over() or game_declared):
                            DisplayMsg.show(Message.GAME_ENDS(result=GameResult.ABORT, play_mode=play_mode, game=game.copy()))
                    game = chess.Board(event.fen, uci960)
                    # see new_game
                    stop_search_and_clock()
                    if engine.has_chess960():
                        engine.option('UCI_Chess960', uci960)
                        engine.send()
                    legal_fens = compute_legal_fens(game)
                    last_legal_fens = []
                    last_computer_fen = None
                    time_control.reset()
                    searchmoves.reset()
                    DisplayMsg.show(Message.START_NEW_GAME(time_control=time_control, game=game.copy()))
                    game_declared = False
                    set_wait_state()
                    break

                if case(EventApi.NEW_GAME):
                    logging.debug('starting a new game with code: {}'.format(event.pos960))
                    uci960 = event.pos960 != 518

                    if game.move_stack:
                        if not (game.is_game_over() or game_declared):
                            DisplayMsg.show(Message.GAME_ENDS(result=GameResult.ABORT, play_mode=play_mode, game=game.copy()))
                    game = chess.Board()
                    if uci960:
                        game.set_chess960_pos(event.pos960)
                    # see setup_position
                    stop_search_and_clock()
                    if engine.has_chess960():
                        engine.option('UCI_Chess960', uci960)
                        engine.send()
                    legal_fens = compute_legal_fens(game)
                    last_legal_fens = []
                    last_computer_fen = None
                    time_control.reset()
                    searchmoves.reset()
                    DisplayMsg.show(Message.START_NEW_GAME(time_control=time_control, game=game.copy()))
                    game_declared = False
                    set_wait_state()
                    break

                if case(EventApi.PAUSE_RESUME):
                    if engine.is_thinking():
                        stop_clock()
                        engine.stop(show_best=True)
                    else:
                        if time_control.is_ticking():
                            stop_clock()
                        else:
                            start_clock()
                    break

                if case(EventApi.ALTERNATIVE_MOVE):
                    if last_computer_fen:
                        last_computer_fen = None
                        game.pop()
                        DisplayMsg.show(Message.ALTERNATIVE_MOVE())
                        think(game, time_control)
                    break

                if case(EventApi.SWITCH_SIDES):
                    if interaction_mode == Mode.NORMAL:
                        user_to_move = False
                        last_legal_fens = []

                        if engine.is_thinking():
                            stop_clock()
                            engine.stop(show_best=False)
                            user_to_move = True
                        if event.engine_finished:
                            last_computer_fen = None
                            move = game.pop()
                            user_to_move = True
                        else:
                            move = chess.Move.null()
                        if user_to_move:
                            last_legal_fens = []
                            play_mode = PlayMode.USER_WHITE if game.turn == chess.WHITE else PlayMode.USER_BLACK
                        else:
                            play_mode = PlayMode.USER_WHITE if game.turn == chess.BLACK else PlayMode.USER_BLACK

                        if not user_to_move and check_game_state(game, play_mode):
                            time_control.reset_start_time()
                            think(game, time_control)
                            legal_fens = []
                        else:
                            start_clock()
                            legal_fens = compute_legal_fens(game)

                        text = dgttranslate.text(play_mode.value)
                        DisplayMsg.show(Message.PLAY_MODE(play_mode=play_mode, play_mode_text=text))

                        if event.engine_finished:
                            DisplayMsg.show(Message.SWITCH_SIDES(move=move))
                    break

                if case(EventApi.DRAWRESIGN):
                    if not game_declared:  # in case user leaves kings in place while moving other pieces
                        stop_search_and_clock()
                        DisplayMsg.show(Message.GAME_ENDS(result=event.result, play_mode=play_mode, game=game.copy()))
                        game_declared = True
                    break

                if case(EventApi.REMOTE_MOVE):
                    if interaction_mode == Mode.REMOTE:
                        handle_move(move=chess.Move.from_uci(event.move))
                        legal_fens = compute_legal_fens(game)
                    break

                if case(EventApi.BEST_MOVE):
                    handle_move(move=event.result.bestmove, ponder=event.result.ponder, inbook=event.inbook)
                    break

                if case(EventApi.NEW_PV):
                    # illegal moves can occur if a pv from the engine arrives at the same time as a user move.
                    if game.is_legal(event.pv[0]):
                        DisplayMsg.show(Message.NEW_PV(pv=event.pv, mode=interaction_mode, fen=game.fen(), turn=game.turn))
                    else:
                        logging.info('illegal move can not be displayed. move:%s fen=%s', event.pv[0], game.fen())
                    break

                if case(EventApi.NEW_SCORE):
                    DisplayMsg.show(Message.NEW_SCORE(score=event.score, mate=event.mate, mode=interaction_mode, turn=game.turn))
                    break

                if case(EventApi.NEW_DEPTH):
                    DisplayMsg.show(Message.NEW_DEPTH(depth=event.depth))
                    break

                if case(EventApi.SET_INTERACTION_MODE):
                    if interaction_mode in (Mode.NORMAL, Mode.OBSERVE, Mode.REMOTE):
                        stop_clock()
                    interaction_mode = event.mode
                    if engine.is_thinking():
                        stop_search()
                    if engine.is_pondering():
                        stop_search()
                    set_wait_state()
                    DisplayMsg.show(Message.INTERACTION_MODE(mode=event.mode, mode_text=event.mode_text, ok_text=event.ok_text))
                    break

                if case(EventApi.SET_OPENING_BOOK):
                    config = ConfigObj('picochess.ini')
                    config['book'] = event.book['file']
                    config.write()
                    logging.debug("changing opening book [%s]", event.book['file'])
                    bookreader = chess.polyglot.open_reader(event.book['file'])
                    DisplayMsg.show(Message.OPENING_BOOK(book_text=event.book_text, ok_text=event.ok_text))
                    break

                if case(EventApi.SET_TIME_CONTROL):
                    time_control = event.time_control
                    config = ConfigObj('picochess.ini')
                    if time_control.mode == TimeMode.BLITZ:
                        config['time'] = '{:d} 0'.format(time_control.minutes_per_game)
                    elif time_control.mode == TimeMode.FISCHER:
                        config['time'] = '{:d} {:d}'.format(time_control.minutes_per_game, time_control.fischer_increment)
                    elif time_control.mode == TimeMode.FIXED:
                        config['time'] = '{:d}'.format(time_control.seconds_per_move)
                    config.write()
                    DisplayMsg.show(Message.TIME_CONTROL(time_text=event.time_text, ok_text=event.ok_text))
                    break

                if case(EventApi.OUT_OF_TIME):
                    stop_search_and_clock()
                    DisplayMsg.show(Message.GAME_ENDS(result=GameResult.OUT_OF_TIME, play_mode=play_mode, game=game.copy()))
                    break

                if case(EventApi.SHUTDOWN):
                    DisplayMsg.show(Message.GAME_ENDS(result=GameResult.ABORT, play_mode=play_mode, game=game.copy()))
                    shutdown(args.dgtpi)
                    break

                if case(EventApi.REBOOT):
                    DisplayMsg.show(Message.GAME_ENDS(result=GameResult.ABORT, play_mode=play_mode, game=game.copy()))
                    reboot()
                    break

                if case(EventApi.EMAIL_LOG):
                    if args.log_file:
                        email_logger = Emailer(net=args.enable_internet, email=args.email, mailgun_key=args.mailgun_key,
                                               smtp_server=args.smtp_server, smtp_user=args.smtp_user,
                                               smtp_pass=args.smtp_pass, smtp_encryption=args.smtp_encryption,
                                               smtp_from=args.smtp_from)
                        body = 'You probably want to forward this file to a picochess developer ;-)'
                        email_logger.send('Picochess LOG', body, '/opt/picochess/logs/{}'.format(args.log_file))
                    break

                if case(EventApi.DGT_BUTTON):
                    DisplayMsg.show(Message.DGT_BUTTON(button=event.button))
                    break

                if case(EventApi.DGT_FEN):
                    DisplayMsg.show(Message.DGT_FEN(fen=event.fen))
                    break

                if case():  # Default
                    logging.warning("event not handled : [%s]", event)

            evt_queue.task_done()
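
The KEYBOARD_MOVE handler above shows the core pattern of this example: deep-copy the board, push the candidate move on the copy, and read the resulting FEN while the live game stays untouched. A minimal sketch of that pattern, assuming only that python-chess is installed:

import copy
import chess

game = chess.Board()
move = chess.Move.from_uci('e2e4')
if move in game.legal_moves:
    preview = copy.deepcopy(game)        # scratch board; the live game is not modified
    preview.push(move)
    print(preview.fen().split(' ')[0])   # piece placement of the previewed position
    assert len(game.move_stack) == 0     # the original board still has no moves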

Example 27

Project: picochess
Source File: picochess.py
View license
def main():

    def display_system_info():
        if args.enable_internet:
            place = get_location()
            addr = get_ip()
        else:
            place = '?'
            addr = None
        DisplayMsg.show(Message.SYSTEM_INFO(info={'version': version, 'location': place, 'ip': addr,
                                                  'engine_name': engine_name, 'user_name': user_name
                                                  }))

    def compute_legal_fens(g):
        """
        Compute a list of legal FENs for the given game.
        :param g: The game
        :return: A list of legal FENs
        """
        fens = []
        for move in g.legal_moves:
            g.push(move)
            fens.append(g.board_fen())
            g.pop()
        return fens

    def probe_tablebase(game):
        if not gaviota:
            return None
        score = gaviota.probe_dtm(game)
        if score is not None:
            Observable.fire(Event.NEW_SCORE(score='gaviota', mate=score))
        return score

    def think(game, tc):
        """
        Start a new search on the current game.
        If a move is found in the opening book, fire an event in a few seconds.
        :return:
        """
        start_clock()
        book_move = searchmoves.book(bookreader, game)
        if book_move:
            Observable.fire(Event.BEST_MOVE(result=book_move, inbook=True))
        else:
            probe_tablebase(game)
            while not engine.is_waiting():
                time.sleep(0.1)
                logging.warning('engine is still not waiting')
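            # hand the engine a deep copy so the search never shares state with the live game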
            engine.position(copy.deepcopy(game))
            uci_dict = tc.uci()
            uci_dict['searchmoves'] = searchmoves.all(game)
            engine.go(uci_dict)

    def analyse(game):
        """
        Start a new ponder search on the current game.
        :return:
        """
        probe_tablebase(game)
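        # ponder on a deep copy; user moves on the live board cannot interfere with the analysis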
        engine.position(copy.deepcopy(game))
        engine.ponder()

    def observe(game):
        """
        Starts a new ponder search on the current game.
        :return:
        """
        start_clock()
        analyse(game)

    def stop_search_and_clock():
        if interaction_mode == Mode.NORMAL:
            stop_clock()
        elif interaction_mode in (Mode.REMOTE, Mode.OBSERVE):
            stop_clock()
            stop_search()
        elif interaction_mode in (Mode.ANALYSIS, Mode.KIBITZ, Mode.PONDER):
            stop_search()

    def stop_search():
        """
        Stop current search.
        :return:
        """
        engine.stop()

    def stop_clock():
        if interaction_mode in (Mode.NORMAL, Mode.OBSERVE, Mode.REMOTE):
            time_control.stop()
            DisplayMsg.show(Message.CLOCK_STOP())
        else:
            logging.warning('wrong mode: {}'.format(interaction_mode))

    def start_clock():
        if interaction_mode in (Mode.NORMAL, Mode.OBSERVE, Mode.REMOTE):
            time_control.start(game.turn)
            DisplayMsg.show(Message.CLOCK_START(turn=game.turn, time_control=time_control))
        else:
            logging.warning('wrong mode: {}'.format(interaction_mode))

    def check_game_state(game, play_mode):
        """
        Check whether the game has ended; if so, it also sends a Message to the displays.
        :param game:
        :param play_mode:
        :return: True if the game continues, False if it has ended
        """
        result = None
        if game.is_stalemate():
            result = GameResult.STALEMATE
        if game.is_insufficient_material():
            result = GameResult.INSUFFICIENT_MATERIAL
        if game.is_seventyfive_moves():
            result = GameResult.SEVENTYFIVE_MOVES
        if game.is_fivefold_repetition():
            result = GameResult.FIVEFOLD_REPETITION
        if game.is_checkmate():
            result = GameResult.MATE

        if result is None:
            return True
        else:
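            # game.copy() (python-chess) hands the display its own snapshot, independent of the live board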
            DisplayMsg.show(Message.GAME_ENDS(result=result, play_mode=play_mode, game=game.copy()))
            return False

    def user_move(move):
        logging.debug('user move [%s]', move)
        if move not in game.legal_moves:
            logging.warning('Illegal move [%s]', move)
        else:
            handle_move(move=move)

    def process_fen(fen):
        nonlocal last_computer_fen
        nonlocal last_legal_fens
        nonlocal searchmoves
        nonlocal legal_fens

        # Check for same position
        if (fen == game.board_fen() and not last_computer_fen) or fen == last_computer_fen:
            logging.debug('Already in this fen: ' + fen)

        # Check if we have to undo a previous move (sliding)
        elif fen in last_legal_fens:
            if interaction_mode == Mode.NORMAL:
                if (play_mode == PlayMode.USER_WHITE and game.turn == chess.BLACK) or \
                        (play_mode == PlayMode.USER_BLACK and game.turn == chess.WHITE):
                    stop_search()
                    game.pop()
                    logging.debug('User move in computer turn, reverting to: ' + game.board_fen())
                elif last_computer_fen:
                    last_computer_fen = None
                    game.pop()
                    game.pop()
                    logging.debug('User move while computer move is displayed, reverting to: ' + game.board_fen())
                else:
                    logging.error("last_legal_fens not cleared: " + game.board_fen())
            elif interaction_mode == Mode.REMOTE:
                if (play_mode == PlayMode.USER_WHITE and game.turn == chess.BLACK) or \
                        (play_mode == PlayMode.USER_BLACK and game.turn == chess.WHITE):
                    game.pop()
                    logging.debug('User move in remote turn, reverting to: ' + game.board_fen())
                elif last_computer_fen:
                    last_computer_fen = None
                    game.pop()
                    game.pop()
                    logging.debug('User move while remote move is displayed, reverting to: ' + game.board_fen())
                else:
                    logging.error('last_legal_fens not cleared: ' + game.board_fen())
            else:
                game.pop()
                logging.debug('Wrong color move -> sliding, reverting to: ' + game.board_fen())
            legal_moves = list(game.legal_moves)
            user_move(legal_moves[last_legal_fens.index(fen)])
            if interaction_mode == Mode.NORMAL or interaction_mode == Mode.REMOTE:
                legal_fens = []
            else:
                legal_fens = compute_legal_fens(game)

        # legal move
        elif fen in legal_fens:
            time_control.add_inc(game.turn)
            legal_moves = list(game.legal_moves)
            user_move(legal_moves[legal_fens.index(fen)])
            last_legal_fens = legal_fens
            if interaction_mode == Mode.NORMAL or interaction_mode == Mode.REMOTE:
                legal_fens = []
            else:
                legal_fens = compute_legal_fens(game)

        # Player has made the computer or remote move on the board
        elif last_computer_fen and fen == game.board_fen():
            last_computer_fen = None
            if check_game_state(game, play_mode) and interaction_mode in (Mode.NORMAL, Mode.REMOTE):
                # finally reset all alternative moves see: handle_move()
                nonlocal searchmoves
                searchmoves.reset()
                time_control.add_inc(not game.turn)
                if time_control.mode != TimeMode.FIXED:
                    start_clock()
                DisplayMsg.show(Message.COMPUTER_MOVE_DONE_ON_BOARD())
                legal_fens = compute_legal_fens(game)
            else:
                legal_fens = []
            last_legal_fens = []

        # Check if this is a previous legal position and allow user to restart from this position
        else:
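            # replay history on a deep copy; the live game is only popped once a matching position is found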
            game_history = copy.deepcopy(game)
            if last_computer_fen:
                game_history.pop()
            while game_history.move_stack:
                game_history.pop()
                if game_history.board_fen() == fen:
                    logging.debug("Current game FEN      : {}".format(game.fen()))
                    logging.debug("Undoing game until FEN: {}".format(fen))
                    stop_search_and_clock()
                    while len(game_history.move_stack) < len(game.move_stack):
                        game.pop()
                    last_computer_fen = None
                    last_legal_fens = []
                    if (interaction_mode == Mode.REMOTE or interaction_mode == Mode.NORMAL) and \
                            ((play_mode == PlayMode.USER_WHITE and game_history.turn == chess.BLACK)
                              or (play_mode == PlayMode.USER_BLACK and game_history.turn == chess.WHITE)):
                        legal_fens = []
                        if interaction_mode == Mode.NORMAL:
                            searchmoves.reset()
                            if check_game_state(game, play_mode):
                                think(game, time_control)
                    else:
                        legal_fens = compute_legal_fens(game)

                    if interaction_mode in (Mode.ANALYSIS, Mode.KIBITZ, Mode.PONDER):
                        analyse(game)
                    elif interaction_mode in (Mode.OBSERVE, Mode.REMOTE):
                        observe(game)
                    start_clock()
                    DisplayMsg.show(Message.USER_TAKE_BACK())
                    break

    def set_wait_state(start_search=True):
        if interaction_mode == Mode.NORMAL:
            nonlocal play_mode
            play_mode = PlayMode.USER_WHITE if game.turn == chess.WHITE else PlayMode.USER_BLACK
        if start_search:
            # Go back to analysing or observing
            if interaction_mode in (Mode.ANALYSIS, Mode.KIBITZ, Mode.PONDER):
                analyse(game)
            if interaction_mode in (Mode.OBSERVE, Mode.REMOTE):
                observe(game)

    def handle_move(move, ponder=None, inbook=False):
        nonlocal game
        nonlocal last_computer_fen
        nonlocal searchmoves
        fen = game.fen()
        turn = game.turn

        # clock must be stopped BEFORE the "book_move" event because SetNRun resets the clock display
        stop_search_and_clock()

        # engine or remote move
        if (interaction_mode == Mode.NORMAL or interaction_mode == Mode.REMOTE) and \
                ((play_mode == PlayMode.USER_WHITE and game.turn == chess.BLACK) or
                     (play_mode == PlayMode.USER_BLACK and game.turn == chess.WHITE)):
            last_computer_fen = game.board_fen()
            game.push(move)
            if inbook:
                DisplayMsg.show(Message.BOOK_MOVE())
            searchmoves.add(move)
            text = Message.COMPUTER_MOVE(move=move, ponder=ponder, fen=fen, turn=turn, game=game.copy(),
                                         time_control=time_control, wait=inbook)
            DisplayMsg.show(text)
        else:
            last_computer_fen = None
            game.push(move)
            if inbook:
                DisplayMsg.show(Message.BOOK_MOVE())
            searchmoves.reset()
            if interaction_mode == Mode.NORMAL:
                if check_game_state(game, play_mode):
                    think(game, time_control)
                text = Message.USER_MOVE(move=move, fen=fen, turn=turn, game=game.copy())
            elif interaction_mode == Mode.REMOTE:
                if check_game_state(game, play_mode):
                    observe(game)
                text = Message.USER_MOVE(move=move, fen=fen, turn=turn, game=game.copy())
            elif interaction_mode == Mode.OBSERVE:
                if check_game_state(game, play_mode):
                    observe(game)
                text = Message.REVIEW_MOVE(move=move, fen=fen, turn=turn, game=game.copy(), mode=interaction_mode)
            else:  # interaction_mode in (Mode.ANALYSIS, Mode.KIBITZ):
                if check_game_state(game, play_mode):
                    analyse(game)
                text = Message.REVIEW_MOVE(move=move, fen=fen, turn=turn, game=game.copy(), mode=interaction_mode)
            DisplayMsg.show(text)

    def transfer_time(time_list):
        def num(ts):
            try:
                return int(ts)
            except ValueError:
                return 1

        if len(time_list) == 1:
            secs = num(time_list[0])
            time_control = TimeControl(TimeMode.FIXED, seconds_per_move=secs)
            text = dgttranslate.text('B00_tc_fixed', '{:2d}'.format(secs))
        elif len(time_list) == 2:
            mins = num(time_list[0])
            finc = num(time_list[1])
            if finc == 0:
                time_control = TimeControl(TimeMode.BLITZ, minutes_per_game=mins)
                text = dgttranslate.text('B00_tc_blitz', '{:2d}'.format(mins))
            else:
                time_control = TimeControl(TimeMode.FISCHER, minutes_per_game=mins, fischer_increment=finc)
                text = dgttranslate.text('B00_tc_fisch', '{:2d} {:2d}'.format(mins, finc))
        else:
            time_control = TimeControl(TimeMode.BLITZ, minutes_per_game=5)
            text = dgttranslate.text('B00_tc_blitz', ' 5')
        return time_control, text

    def get_engine_level_dict(engine_level):
        from engine import get_installed_engines

        installed_engines = get_installed_engines(engine.get_shell(), engine.get_file())
        for index in range(0, len(installed_engines)):
            eng = installed_engines[index]
            if eng['file'] == engine.get_file():
                level_list = sorted(eng['level_dict'])
                try:
                    level_index = level_list.index(engine_level)
                    return eng['level_dict'][level_list[level_index]]
                except ValueError:
                    break
        return {}

    # Enable garbage collection - needed for engine swapping, as old engine objects become orphaned
    gc.enable()

    # Command line argument parsing
    parser = configargparse.ArgParser(default_config_files=[os.path.join(os.path.dirname(__file__), 'picochess.ini')])
    parser.add_argument('-e', '--engine', type=str, help='UCI engine executable path', default=None)
    parser.add_argument('-el', '--engine-level', type=str, help='UCI engine level', default=None)
    parser.add_argument('-d', '--dgt-port', type=str,
                        help='enable dgt board on the given serial port such as /dev/ttyUSB0')
    parser.add_argument('-b', '--book', type=str, help='full path of book such as books/b-flank.bin',
                        default='h-varied.bin')
    parser.add_argument('-t', '--time', type=str, default='5 0',
                        help="Time settings <FixSec> or <StMin IncSec> like '10'(move) or '5 0'(game) '3 2'(fischer)")
    parser.add_argument('-g', '--enable-gaviota', action='store_true', help='enable gaviota tablebase probing')
    parser.add_argument('-leds', '--enable-revelation-leds', action='store_true', help='enable Revelation leds')
    parser.add_argument('-l', '--log-level', choices=['notset', 'debug', 'info', 'warning', 'error', 'critical'],
                        default='warning', help='logging level')
    parser.add_argument('-lf', '--log-file', type=str, help='log to the given file')
    parser.add_argument('-rs', '--remote-server', type=str, help='remote server running the engine')
    parser.add_argument('-ru', '--remote-user', type=str, help='remote user on server running the engine')
    parser.add_argument('-rp', '--remote-pass', type=str, help='password for the remote user')
    parser.add_argument('-rk', '--remote-key', type=str, help='key file used to connect to the remote server')
    parser.add_argument('-pf', '--pgn-file', type=str, help='pgn file used to store the games', default='games.pgn')
    parser.add_argument('-pu', '--pgn-user', type=str, help='user name for the pgn file', default=None)
    parser.add_argument('-ar', '--auto-reboot', action='store_true', help='reboot system after update')
    parser.add_argument('-web', '--web-server', dest='web_server_port', nargs='?', const=80, type=int, metavar='PORT',
                        help='launch web server')
    parser.add_argument('-m', '--email', type=str, help='email used to send pgn files', default=None)
    parser.add_argument('-ms', '--smtp-server', type=str, help='address of email server', default=None)
    parser.add_argument('-mu', '--smtp-user', type=str, help='username for email server', default=None)
    parser.add_argument('-mp', '--smtp-pass', type=str, help='password for email server', default=None)
    parser.add_argument('-me', '--smtp-encryption', action='store_true',
                        help='use ssl encryption connection to smtp-Server')
    parser.add_argument('-mf', '--smtp-from', type=str, help='From email', default='[email protected]')
    parser.add_argument('-mk', '--mailgun-key', type=str, help='key used to send emails via Mailgun Webservice',
                        default=None)
    parser.add_argument('-bc', '--beep-config', choices=['none', 'some', 'all'], help='sets standard beep config',
                        default='some')
    parser.add_argument('-beep', '--beep-level', type=int, default=0x03,
                        help='sets (some-)beep level from 0(=no beeps) to 15(=all beeps)')
    parser.add_argument('-uvoice', '--user-voice', type=str, help='voice for user', default=None)
    parser.add_argument('-cvoice', '--computer-voice', type=str, help='voice for computer', default=None)
    parser.add_argument('-inet', '--enable-internet', action='store_true', help='enable internet lookups')
    parser.add_argument('-nook', '--disable-ok-message', action='store_true', help='disable ok confirmation messages')
    parser.add_argument('-v', '--version', action='version', version='%(prog)s version {}'.format(version),
                        help='show current version', default=None)
    parser.add_argument('-pi', '--dgtpi', action='store_true', help='use the dgtpi hardware')
    parser.add_argument('-lang', '--language', choices=['en', 'de', 'nl', 'fr', 'es'], default='en',
                        help='picochess language')
    parser.add_argument('-c', '--console', action='store_true', help='use console interface')

    args = parser.parse_args()
    if args.engine is None:
        el = read_engine_ini()
        args.engine = el[0]['file']  # read the first engine filename and use it as standard

    # Enable logging
    if args.log_file:
        handler = RotatingFileHandler('logs' + os.sep + args.log_file, maxBytes=1024*1024, backupCount=9)
        logging.basicConfig(level=getattr(logging, args.log_level.upper()),
                            format='%(asctime)s.%(msecs)03d %(levelname)5s %(module)10s - %(funcName)s: %(message)s',
                            datefmt="%Y-%m-%d %H:%M:%S", handlers=[handler])
    logging.getLogger('chess.uci').setLevel(logging.INFO)  # don't want to get so many python-chess uci messages

    logging.debug('#'*20 + ' PicoChess v' + version + ' ' + '#'*20)
    # log the startup parameters but hide the password fields
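    # a shallow copy is enough here: only top-level values are replaced before logging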
    p = copy.copy(vars(args))
    p['mailgun_key'] = p['remote_key'] = p['remote_pass'] = p['smtp_pass'] = '*****'
    logging.debug('startup parameters: {}'.format(p))

    # Update
    if args.enable_internet:
        update_picochess(args.auto_reboot)

    gaviota = None
    if args.enable_gaviota:
        try:
            gaviota = chess.gaviota.open_tablebases('tablebases/gaviota')
            logging.debug('Tablebases gaviota loaded')
        except OSError:
            logging.error('Gaviota tablebases do not exist')
            gaviota = None

    # The class dgtDisplay talks to DgtHw/DgtPi or DgtVr
    dgttranslate = DgtTranslate(args.beep_config, args.beep_level, args.language)
    DgtDisplay(args.disable_ok_message, dgttranslate).start()

    # Launch web server
    if args.web_server_port:
        WebServer(args.web_server_port).start()

    dgtserial = DgtSerial(args.dgt_port, args.enable_revelation_leds, args.dgtpi)

    if args.console:
        # Enable keyboard input and terminal display
        logging.debug('starting picochess in virtual mode')
        KeyboardInput(dgttranslate, args.dgtpi).start()
        TerminalDisplay().start()
        DgtVr(dgtserial, dgttranslate).start()
    else:
        # Connect to DGT board
        logging.debug('starting picochess in board mode')
        if args.dgtpi:
            DgtPi(dgtserial, dgttranslate).start()
        DgtHw(dgtserial, dgttranslate).start()
    # Save to PGN
    emailer = Emailer(
        net=args.enable_internet, email=args.email, mailgun_key=args.mailgun_key,
        smtp_server=args.smtp_server, smtp_user=args.smtp_user,
        smtp_pass=args.smtp_pass, smtp_encryption=args.smtp_encryption, smtp_from=args.smtp_from)

    PgnDisplay(args.pgn_file, emailer).start()
    if args.pgn_user:
        user_name = args.pgn_user
    else:
        if args.email:
            user_name = args.email.split('@')[0]
        else:
            user_name = 'Player'

    # Create PicoTalker for speech output
    if args.user_voice or args.computer_voice:
        from talker.picotalker import PicoTalkerDisplay
        logging.debug("initializing PicoTalker [%s, %s]", str(args.user_voice), str(args.computer_voice))
        PicoTalkerDisplay(args.user_voice, args.computer_voice).start()
    else:
        logging.debug('PicoTalker disabled')

    # Gentlemen, start your engines...
    engine = UciEngine(args.engine, hostname=args.remote_server, username=args.remote_user,
                       key_file=args.remote_key, password=args.remote_pass)
    try:
        engine_name = engine.get().name
    except AttributeError:
        logging.error('no engines started')
        sys.exit(-1)

    # Startup - internal
    game = chess.Board()  # Create the current game
    legal_fens = compute_legal_fens(game)  # Compute the legal FENs
    all_books = get_opening_books()
    try:
        book_index = [book['file'] for book in all_books].index(args.book)
    except ValueError:
        logging.warning("selected book not present, defaulting to %s", all_books[7]['file'])
        book_index = 7
    bookreader = chess.polyglot.open_reader(all_books[book_index]['file'])
    searchmoves = AlternativeMover()
    interaction_mode = Mode.NORMAL
    play_mode = PlayMode.USER_WHITE

    last_computer_fen = None
    last_legal_fens = []
    game_declared = False  # User declared resignation or draw

    engine.startup(get_engine_level_dict(args.engine_level))

    # Startup - external
    time_control, time_text = transfer_time(args.time.split())
    time_text.beep = False
    if args.engine_level:
        level_text = dgttranslate.text('B00_level', args.engine_level)
        level_text.beep = False
    else:
        level_text = None
    DisplayMsg.show(Message.STARTUP_INFO(info={'interaction_mode': interaction_mode, 'play_mode': play_mode,
                                               'books': all_books, 'book_index': book_index, 'level_text': level_text,
                                               'time_control': time_control, 'time_text': time_text}))
    DisplayMsg.show(Message.ENGINE_STARTUP(shell=engine.get_shell(), file=engine.get_file(),
                                           has_levels=engine.has_levels(), has_960=engine.has_chess960()))

    system_info_thread = threading.Timer(0, display_system_info)
    system_info_thread.start()

    # Event loop
    logging.info('evt_queue ready')
    while True:
        try:
            event = evt_queue.get()
        except queue.Empty:
            pass
        else:
            logging.debug('received event from evt_queue: %s', event)
            for case in switch(event):
                if case(EventApi.FEN):
                    process_fen(event.fen)
                    break

                if case(EventApi.KEYBOARD_MOVE):
                    move = event.move
                    logging.debug('keyboard move [%s]', move)
                    if move not in game.legal_moves:
                        logging.warning('illegal move [%s]', move)
                    else:
                        g = copy.deepcopy(game)
                        g.push(move)
                        fen = g.fen().split(' ')[0]
                        if event.flip_board:
                            fen = fen[::-1]
                        DisplayMsg.show(Message.KEYBOARD_MOVE(fen=fen))
                    break

                if case(EventApi.LEVEL):
                    if event.options:
                        engine.startup(event.options, False)
                    DisplayMsg.show(Message.LEVEL(level_text=event.level_text))
                    break

                if case(EventApi.NEW_ENGINE):
                    config = ConfigObj('picochess.ini')
                    config['engine'] = event.eng['file']
                    config.write()
                    old_file = engine.get_file()
                    engine_shutdown = True
                    # Stop the old engine cleanly
                    engine.stop()
                    # Closeout the engine process and threads
                    # They each return a non-zero code on error; 0 means success
                    if engine.quit():  # Ask nicely
                        if engine.terminate():  # If you won't go nicely.... 
                            if engine.kill():  # Right that does it!
                                logging.error('engine shutdown failure')
                                DisplayMsg.show(Message.ENGINE_FAIL())
                                engine_shutdown = False
                    if engine_shutdown:
                        # Load the new one and send args.
                        # Local engines only
                        engine_fallback = False
                        engine = UciEngine(event.eng['file'])
                        try:
                            engine_name = engine.get().name
                        except AttributeError:
                            # New engine failed to start, restart old engine
                            logging.error("new engine failed to start, reverting to %s", old_file)
                            engine_fallback = True
                            event.options = {}  # Reset options. This will load the last(=strongest?) level
                            engine = UciEngine(old_file)
                            try:
                                engine_name = engine.get().name
                            except AttributeError:
                                # Help - old engine failed to restart. There is no engine
                                logging.error('no engines started')
                                sys.exit(-1)
                        # Schedule cleanup of old objects
                        gc.collect()
                        engine.startup(event.options)
                        # All done - rock'n'roll
                        if not engine_fallback:
                            DisplayMsg.show(Message.ENGINE_READY(eng=event.eng, engine_name=engine_name,
                                                                 eng_text=event.eng_text,
                                                                 has_levels=engine.has_levels(),
                                                                 has_960=engine.has_chess960(), ok_text=event.ok_text))
                        else:
                            DisplayMsg.show(Message.ENGINE_FAIL())
                        set_wait_state(not engine_fallback)
                    break

                if case(EventApi.SETUP_POSITION):
                    logging.debug("setting up custom fen: {}".format(event.fen))
                    uci960 = event.uci960

                    if game.move_stack:
                        if not (game.is_game_over() or game_declared):
                            DisplayMsg.show(Message.GAME_ENDS(result=GameResult.ABORT, play_mode=play_mode, game=game.copy()))
                    game = chess.Board(event.fen, uci960)
                    # see new_game
                    stop_search_and_clock()
                    if engine.has_chess960():
                        engine.option('UCI_Chess960', uci960)
                        engine.send()
                    legal_fens = compute_legal_fens(game)
                    last_legal_fens = []
                    last_computer_fen = None
                    time_control.reset()
                    searchmoves.reset()
                    DisplayMsg.show(Message.START_NEW_GAME(time_control=time_control, game=game.copy()))
                    game_declared = False
                    set_wait_state()
                    break

                if case(EventApi.NEW_GAME):
                    logging.debug('starting a new game with code: {}'.format(event.pos960))
                    uci960 = event.pos960 != 518

                    if game.move_stack:
                        if not (game.is_game_over() or game_declared):
                            DisplayMsg.show(Message.GAME_ENDS(result=GameResult.ABORT, play_mode=play_mode, game=game.copy()))
                    game = chess.Board()
                    if uci960:
                        game.set_chess960_pos(event.pos960)
                    # see setup_position
                    stop_search_and_clock()
                    if engine.has_chess960():
                        engine.option('UCI_Chess960', uci960)
                        engine.send()
                    legal_fens = compute_legal_fens(game)
                    last_legal_fens = []
                    last_computer_fen = None
                    time_control.reset()
                    searchmoves.reset()
                    DisplayMsg.show(Message.START_NEW_GAME(time_control=time_control, game=game.copy()))
                    game_declared = False
                    set_wait_state()
                    break

                if case(EventApi.PAUSE_RESUME):
                    if engine.is_thinking():
                        stop_clock()
                        engine.stop(show_best=True)
                    else:
                        if time_control.is_ticking():
                            stop_clock()
                        else:
                            start_clock()
                    break

                if case(EventApi.ALTERNATIVE_MOVE):
                    if last_computer_fen:
                        last_computer_fen = None
                        game.pop()
                        DisplayMsg.show(Message.ALTERNATIVE_MOVE())
                        think(game, time_control)
                    break

                if case(EventApi.SWITCH_SIDES):
                    if interaction_mode == Mode.NORMAL:
                        user_to_move = False
                        last_legal_fens = []

                        if engine.is_thinking():
                            stop_clock()
                            engine.stop(show_best=False)
                            user_to_move = True
                        if event.engine_finished:
                            last_computer_fen = None
                            move = game.pop()
                            user_to_move = True
                        else:
                            move = chess.Move.null()
                        if user_to_move:
                            last_legal_fens = []
                            play_mode = PlayMode.USER_WHITE if game.turn == chess.WHITE else PlayMode.USER_BLACK
                        else:
                            play_mode = PlayMode.USER_WHITE if game.turn == chess.BLACK else PlayMode.USER_BLACK

                        if not user_to_move and check_game_state(game, play_mode):
                            time_control.reset_start_time()
                            think(game, time_control)
                            legal_fens = []
                        else:
                            start_clock()
                            legal_fens = compute_legal_fens(game)

                        text = dgttranslate.text(play_mode.value)
                        DisplayMsg.show(Message.PLAY_MODE(play_mode=play_mode, play_mode_text=text))

                        if event.engine_finished:
                            DisplayMsg.show(Message.SWITCH_SIDES(move=move))
                    break

                if case(EventApi.DRAWRESIGN):
                    if not game_declared:  # in case user leaves kings in place while moving other pieces
                        stop_search_and_clock()
                        DisplayMsg.show(Message.GAME_ENDS(result=event.result, play_mode=play_mode, game=game.copy()))
                        game_declared = True
                    break

                if case(EventApi.REMOTE_MOVE):
                    if interaction_mode == Mode.REMOTE:
                        handle_move(move=chess.Move.from_uci(event.move))
                        legal_fens = compute_legal_fens(game)
                    break

                if case(EventApi.BEST_MOVE):
                    handle_move(move=event.result.bestmove, ponder=event.result.ponder, inbook=event.inbook)
                    break

                if case(EventApi.NEW_PV):
                    # illegal moves can occur if a pv from the engine arrives at the same time as a user move.
                    if game.is_legal(event.pv[0]):
                        DisplayMsg.show(Message.NEW_PV(pv=event.pv, mode=interaction_mode, fen=game.fen(), turn=game.turn))
                    else:
                        logging.info('illegal move can not be displayed. move:%s fen=%s', event.pv[0], game.fen())
                    break

                if case(EventApi.NEW_SCORE):
                    DisplayMsg.show(Message.NEW_SCORE(score=event.score, mate=event.mate, mode=interaction_mode, turn=game.turn))
                    break

                if case(EventApi.NEW_DEPTH):
                    DisplayMsg.show(Message.NEW_DEPTH(depth=event.depth))
                    break

                if case(EventApi.SET_INTERACTION_MODE):
                    if interaction_mode in (Mode.NORMAL, Mode.OBSERVE, Mode.REMOTE):
                        stop_clock()
                    interaction_mode = event.mode
                    if engine.is_thinking():
                        stop_search()
                    if engine.is_pondering():
                        stop_search()
                    set_wait_state()
                    DisplayMsg.show(Message.INTERACTION_MODE(mode=event.mode, mode_text=event.mode_text, ok_text=event.ok_text))
                    break

                if case(EventApi.SET_OPENING_BOOK):
                    config = ConfigObj('picochess.ini')
                    config['book'] = event.book['file']
                    config.write()
                    logging.debug("changing opening book [%s]", event.book['file'])
                    bookreader = chess.polyglot.open_reader(event.book['file'])
                    DisplayMsg.show(Message.OPENING_BOOK(book_text=event.book_text, ok_text=event.ok_text))
                    break

                if case(EventApi.SET_TIME_CONTROL):
                    time_control = event.time_control
                    config = ConfigObj('picochess.ini')
                    if time_control.mode == TimeMode.BLITZ:
                        config['time'] = '{:d} 0'.format(time_control.minutes_per_game)
                    elif time_control.mode == TimeMode.FISCHER:
                        config['time'] = '{:d} {:d}'.format(time_control.minutes_per_game, time_control.fischer_increment)
                    elif time_control.mode == TimeMode.FIXED:
                        config['time'] = '{:d}'.format(time_control.seconds_per_move)
                    config.write()
                    DisplayMsg.show(Message.TIME_CONTROL(time_text=event.time_text, ok_text=event.ok_text))
                    break

                if case(EventApi.OUT_OF_TIME):
                    stop_search_and_clock()
                    DisplayMsg.show(Message.GAME_ENDS(result=GameResult.OUT_OF_TIME, play_mode=play_mode, game=game.copy()))
                    break

                if case(EventApi.SHUTDOWN):
                    DisplayMsg.show(Message.GAME_ENDS(result=GameResult.ABORT, play_mode=play_mode, game=game.copy()))
                    shutdown(args.dgtpi)
                    break

                if case(EventApi.REBOOT):
                    DisplayMsg.show(Message.GAME_ENDS(result=GameResult.ABORT, play_mode=play_mode, game=game.copy()))
                    reboot()
                    break

                if case(EventApi.EMAIL_LOG):
                    if args.log_file:
                        email_logger = Emailer(net=args.enable_internet, email=args.email, mailgun_key=args.mailgun_key,
                                               smtp_server=args.smtp_server, smtp_user=args.smtp_user,
                                               smtp_pass=args.smtp_pass, smtp_encryption=args.smtp_encryption,
                                               smtp_from=args.smtp_from)
                        body = 'You probably want to forward this file to a picochess developer ;-)'
                        email_logger.send('Picochess LOG', body, '/opt/picochess/logs/{}'.format(args.log_file))
                    break

                if case(EventApi.DGT_BUTTON):
                    DisplayMsg.show(Message.DGT_BUTTON(button=event.button))
                    break

                if case(EventApi.DGT_FEN):
                    DisplayMsg.show(Message.DGT_FEN(fen=event.fen))
                    break

                if case():  # Default
                    logging.warning("event not handled : [%s]", event)

            evt_queue.task_done()
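
In think() and analyse() above, the position handed to the engine is a deep copy, so user moves pushed onto the live game afterwards cannot alter the position the search is working from. A small sketch of the same snapshot-before-handoff idea; search_worker is an illustrative stand-in, not part of the picochess API:

import copy
import threading
import chess

def search_worker(board):
    # the worker sees a frozen snapshot; later pushes on the live game never reach it
    print('searching from', board.fen())

game = chess.Board()
snapshot = copy.deepcopy(game)                  # snapshot taken before the hand-off
worker = threading.Thread(target=search_worker, args=(snapshot,))
worker.start()
game.push(chess.Move.from_uci('e2e4'))          # the live game keeps changing independently
worker.join()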

Example 28

View license
def save(document, db, tool):
    """Save the project details in the Lair database.

    :param document: A complete representation of the project model
    :param db: A connection to the target Lair database
    :raise: MissingRequiredSchemaField, ProjectDoesNotExistError
    """

    has_errors = False
    valid_statuses = [lair_models.STATUS_GREY, lair_models.STATUS_BLUE,
                      lair_models.STATUS_GREEN, lair_models.STATUS_ORANGE,
                      lair_models.STATUS_RED]

    # Validate compatible versions
    version = db.versions.find_one()
    if version['version'] != VERSION:
        raise IncompatibleVersionError(VERSION, version['version'])

    # Validate the schema - will raise an error if invalid
    validate(document)

    print "[+] Processing project {0}".format(document['project_id'])

    temp_drone_log = list()

    q = {'_id': document['project_id']}

    # Ensure the project exists in the database
    if db.projects.find(q).count() != 1:
        raise ProjectDoesNotExistError(document['project_id'])

    project = db.projects.find_one(q)

    # Add the command
    project['commands'].extend(document['commands'])

    # Add project notes
    project['notes'].extend(document['notes'])

    # Add the owner if it isn't already set
    if 'owner' not in project or not project['owner']:
        project['owner'] = document['owner']

    # Add the industry if not already set
    if 'industry' not in project or not project['industry']:
        project['industry'] = document.get('industry', 'N/A')

    # Add the creation date if not already set
    if 'creation_date' not in project or not project['creation_date']:
        project['creation_date'] = document['creation_date']

    # Add the description if not already set
    if 'description' not in project or not project['description']:
        project['description'] = document.get('description', '')

    if 'hosts' not in project:
        project['hosts'] = list()

    if 'vulnerabilities' not in project:
        project['vulnerabilities'] = list()

    if len(project['vulnerabilities']) == 0 and len(project['hosts']) == 0:
        now = datetime.utcnow().isoformat()
        temp_drone_log.append("{0} - Initial project load".format(now))

    # Create indexes
    db.hosts.ensure_index([
        ('project_id', ASCENDING),
        ('string_addr', ASCENDING)
    ])
    db.ports.ensure_index([
        ('project_id', ASCENDING),
        ('host_id', ASCENDING),
        ('port', ASCENDING),
        ('protocol', ASCENDING)
    ])
    db.vulnerabilities.ensure_index([
        ('project_id', ASCENDING),
        ('plugin_ids', ASCENDING)
    ])

    # For each host in the parsed scan, check to see if it already
    # exists in the database.
    for file_host in document['hosts']:

        is_known_host = True
        host = db.hosts.find_one({'project_id': project['_id'], 'string_addr': file_host['string_addr']})
        if not host:
            is_known_host = False
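            # deep-copy the template so nested lists (hostnames, notes, os) are not shared with the module-level model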
            host = copy.deepcopy(lair_models.host_model)

        pre_md5 = hashlib.md5()
        pre_md5.update(str(host))

        host['project_id'] = project['_id']
        host['alive'] = file_host['alive']
        host['string_addr'] = file_host['string_addr']
        host['long_addr'] = file_host['long_addr']
        host['is_profiled'] = file_host.get('is_profiled', False)
        host['is_enumerated'] = file_host.get('is_enumerated', False)

        # Include any host notes
        if file_host['notes']:
            host['notes'].extend(file_host['notes'])

        # Add hostnames
        if len(file_host['hostnames']) > 0:
            # Only update if new host names were identified
            if not set(file_host['hostnames']).issubset(host['hostnames']):
                host['hostnames'].extend(file_host['hostnames'])
                host['hostnames'] = list(set(host['hostnames']))
                host['last_modified_by'] = tool

        # Update MAC address if it's not set already
        if not host['mac_addr']:
            host['mac_addr'] = file_host['mac_addr']

        # Add the operating system
        if file_host['os']:
            os_list = []
            # The following ensures that no duplicate entries are
            # added to the database.
            for file_os in file_host['os']:
                dupe_found = False
                for db_os in host['os']:
                    if db_os['tool'] == file_os['tool'] and \
                       db_os['fingerprint'] == \
                       file_os['fingerprint']:
                        dupe_found = True

                if not dupe_found:
                    os_list.append(file_os)
                    host['last_modified_by'] = tool

            host['os'].extend(os_list)

        post_md5 = hashlib.md5()
        post_md5.update(str(host))

        # Only save if changes were detected
        if pre_md5.hexdigest() != post_md5.hexdigest():  # compare digests; the hash objects themselves never compare equal
            host['last_modified_by'] = tool
            if not is_known_host:
                id = str(ObjectId())
                host['_id'] = id
                s = file_host.get('status', lair_models.STATUS_GREY)
                host['status'] = s if s in valid_statuses else lair_models.STATUS_GREY

            db.hosts.save(host)

        if not is_known_host:
            now = datetime.utcnow().isoformat()
            temp_drone_log.append("{0} - New host found: {1}".format(
                now,
                file_host['string_addr'])
            )

        # Process each web directory for the host, checking against existing dirs
        if 'web_directories' in file_host:
            # Check if web directories are supported by the remote Lair server.
            if 'web_directories' in db.collection_names():
                for file_directory in file_host['web_directories']:
                    q = {
                        'project_id': project['_id'],
                        'host_id': host['_id'],
                        'path_clean': file_directory['path_clean'],
                        'port': file_directory['port'],
                        'response_code': file_directory['response_code'],
                    }
                    directory = db.web_directories.find_one(q)

                    is_known_directory = False
                    if directory:
                        is_known_directory = True
                    else:
                        directory = copy.deepcopy(lair_models.web_directory_model)

                    pre_md5 = hashlib.md5()
                    pre_md5.update(str(directory))

                    directory['project_id'] = project['_id']
                    directory['host_id'] = host['_id']
                    directory['path'] = file_directory['path']
                    directory['path_clean'] = file_directory['path_clean']
                    directory['port'] = file_directory['port']
                    directory['response_code'] = file_directory['response_code']

                    post_md5 = hashlib.md5()
                    post_md5.update(str(directory))

                    if pre_md5 != post_md5:
                        directory['last_modified_by'] = tool
                        if not is_known_directory:
                            id = str(ObjectId())
                            directory['_id'] = id
                        db.web_directories.save(directory)
            else:
                has_errors = True
                print "[!] Your version of Lair does not support the addition of web directories."
                print "[!] Please check the Lair project on GitHub for more information (https://github.com/lair-framework/lair)."

        # Process each port for the host, checking against known ports
        for file_port in file_host['ports']:

            q = {
                'project_id': project['_id'],
                'host_id': host['_id'],
                'port': file_port['port'],
                'protocol': file_port['protocol']
            }
            port = db.ports.find_one(q)

            is_known_port = False
            if port:
                is_known_port = True
            else:
                port = copy.deepcopy(lair_models.port_model)

            pre_md5 = hashlib.md5()
            pre_md5.update(str(port))

            port['host_id'] = host['_id']
            port['project_id'] = project['_id']
            port['protocol'] = file_port['protocol']
            port['port'] = file_port['port']

            # TODO: Determine how to handle a closed port
            port['alive'] = file_port['alive']

            # Update product if it is unknown
            if port['product'] == lair_models.PRODUCT_UNKNOWN:
                port['product'] = file_port['product']

            # Set the service if it is not set
            if not port['service'] or port['service'] == 'unknown':
                port['service'] = file_port['service']

            # Include any script output for the port
            if file_port['notes']:
                port['notes'].extend(file_port['notes'])

            # Include any credentials
            if file_port['credentials']:
                port['credentials'].extend(file_port['credentials'])

            if not is_known_port:
                id = str(ObjectId())
                port['_id'] = id
                s = file_port.get('status', lair_models.STATUS_GREY)
                port['status'] = s if s in valid_statuses else lair_models.STATUS_GREY
                now = datetime.utcnow().isoformat()
                temp_drone_log.append("{0} - New port found: {1}/{2} ({3})".format(
                    now,
                    str(file_port['port']),
                    file_port['protocol'],
                    file_port['service'])
                )

            post_md5 = hashlib.md5()
            post_md5.update(str(port))

            if pre_md5 != post_md5:
                port['last_modified_by'] = tool
                db.ports.save(port)

    # For each vulnerability in the parsed scan, check to see if it already
    # exists in the database.
    for file_vuln in document.get('vulnerabilities', []):

        is_known_vuln = False

        # Attempt a lookup by plugin_id...
        q = {
            'project_id': project['_id'],
            'plugin_ids': {'$all': file_vuln['plugin_ids']}
        }
        db_vuln = db.vulnerabilities.find_one(q)

        if db_vuln:
            is_known_vuln = True

        # No vuln found by plugin_id, treat as new
        if not is_known_vuln:
            db_vuln = copy.deepcopy(file_vuln)
            id = str(ObjectId())
            s = file_vuln.get('status', lair_models.STATUS_GREY)
            db_vuln['status'] = s if s in valid_statuses else lair_models.STATUS_GREY
            db_vuln['_id'] = id
            db_vuln['project_id'] = project['_id']
            db_vuln['last_modified_by'] = tool
            now = datetime.utcnow().isoformat()
            temp_drone_log.append("{0} - New vulnerability found: {1}".format(
                now,
                file_vuln['title'].encode("utf-8"))
            )
            db.vulnerabilities.save(db_vuln)

        if is_known_vuln:
            pre_md5 = hashlib.md5()
            pre_md5.update(str(db_vuln))

            db_vuln['cves'].extend(file_vuln['cves'])
            db_vuln['cves'] = list(set(db_vuln['cves']))
            db_vuln['identified_by'].extend(file_vuln['identified_by'])

            # Only set 'flag' if it's true for parsed vuln
            db_vuln['flag'] = file_vuln['flag'] \
                if file_vuln.get('flag', False) else db_vuln.get('flag', False)

            # Include any notes for the vulnerability
            if file_vuln['notes']:
                db_vuln['notes'].extend(file_vuln['notes'])

            for file_host in file_vuln['hosts']:
                if file_host not in db_vuln['hosts']:
                    db_vuln['hosts'].append(file_host)
                    now = datetime.utcnow().isoformat()
                    temp_drone_log.append("{0} - {1}:{2}/{3} - New vulnerability found: {4}".format(
                        now,
                        file_host['string_addr'],
                        str(file_host['port']),
                        file_host['protocol'],
                        file_vuln['title'])
                    )

            post_md5 = hashlib.md5()
            post_md5.update(str(db_vuln))

            # Vulnerability was known, but change was detected
            if pre_md5 != post_md5:
                db_vuln['last_modified_by'] = tool
                db.vulnerabilities.save(db_vuln)

    # Ensure history log does not exceed DRONE_LOG_HISTORY limit
    project['drone_log'].extend(temp_drone_log)
    length = len(project['drone_log'])
    del project['drone_log'][0:(length - DRONE_LOG_HISTORY)]

    db.projects.save(project)

    if not has_errors:
        print "[+] Processing completed: {0} host(s) processed.".format(
            str(len(document['hosts'])))
    else:
        print "[!] Could not process this drone's data. See above for any error messages."

Example 29

View license
def save(document, db, tool):
    """Save the project details in the Lair database.

    :param document: A complete representation of the project model
    :param db: A connection to the target Lair database
    :param tool: Name of the tool whose results are being saved (used for last_modified_by)
    :raise: IncompatibleVersionError, MissingRequiredSchemaField, ProjectDoesNotExistError
    """

    has_errors = False
    valid_statuses = [lair_models.STATUS_GREY, lair_models.STATUS_BLUE,
                      lair_models.STATUS_GREEN, lair_models.STATUS_ORANGE,
                      lair_models.STATUS_RED]

    # Validate compatible versions
    version = db.versions.find_one()
    if version['version'] != VERSION:
        raise IncompatibleVersionError(VERSION, version['version'])

    # Validate the schema - will raise an error if invalid
    validate(document)

    print "[+] Processing project {0}".format(document['project_id'])

    temp_drone_log = list()

    q = {'_id': document['project_id']}

    # Ensure the project exists in the database
    if db.projects.find(q).count() != 1:
        raise ProjectDoesNotExistError(document['project_id'])

    project = db.projects.find_one(q)

    # Add the command
    project['commands'].extend(document['commands'])

    # Add project notes
    project['notes'].extend(document['notes'])

    # Add the owner if it isn't already set
    if 'owner' not in project or not project['owner']:
        project['owner'] = document['owner']

    # Add the industry if not already set
    if 'industry' not in project or not project['industry']:
        project['industry'] = document.get('industry', 'N/A')

    # Add the creation date if not already set
    if 'creation_date' not in project or not project['creation_date']:
        project['creation_date'] = document['creation_date']

    # Add the description if not already set
    if 'description' not in project or not project['description']:
        project['description'] = document.get('description', '')

    if 'hosts' not in project:
        project['hosts'] = list()

    if 'vulnerabilities' not in project:
        project['vulnerabilities'] = list()

    if len(project['vulnerabilities']) == 0 and len(project['hosts']) == 0:
        now = datetime.utcnow().isoformat()
        temp_drone_log.append("{0} - Initial project load".format(now))

    # Create indexes
    db.hosts.ensure_index([
        ('project_id', ASCENDING),
        ('string_addr', ASCENDING)
    ])
    db.ports.ensure_index([
        ('project_id', ASCENDING),
        ('host_id', ASCENDING),
        ('port', ASCENDING),
        ('protocol', ASCENDING)
    ])
    db.vulnerabilities.ensure_index([
        ('project_id', ASCENDING),
        ('plugin_ids', ASCENDING)
    ])

    # For each host in the parsed scan, check to see if it already
    # exists in the database.
    for file_host in document['hosts']:

        is_known_host = True
        host = db.hosts.find_one({'project_id': project['_id'], 'string_addr': file_host['string_addr']})
        if not host:
            is_known_host = False
            host = copy.deepcopy(lair_models.host_model)

        pre_md5 = hashlib.md5()
        pre_md5.update(str(host))

        host['project_id'] = project['_id']
        host['alive'] = file_host['alive']
        host['string_addr'] = file_host['string_addr']
        host['long_addr'] = file_host['long_addr']
        host['is_profiled'] = file_host.get('is_profiled', False)
        host['is_enumerated'] = file_host.get('is_enumerated', False)

        # Include any host notes
        if file_host['notes']:
            host['notes'].extend(file_host['notes'])

        # Add hostnames
        if len(file_host['hostnames']) > 0:
            # Only update if new host names were identified
            if not set(file_host['hostnames']).issubset(host['hostnames']):
                host['hostnames'].extend(file_host['hostnames'])
                host['hostnames'] = list(set(host['hostnames']))
                host['last_modified_by'] = tool

        # Update MAC address if it's not set already
        if not host['mac_addr']:
            host['mac_addr'] = file_host['mac_addr']

        # Add the operating system
        if file_host['os']:
            os_list = []
            # The following ensures that no duplicate entries are
            # added to the database.
            for file_os in file_host['os']:
                dupe_found = False
                for db_os in host['os']:
                    if db_os['tool'] == file_os['tool'] and \
                       db_os['fingerprint'] == \
                       file_os['fingerprint']:
                        dupe_found = True

                if not dupe_found:
                    os_list.append(file_os)
                    host['last_modified_by'] = tool

            host['os'].extend(os_list)

        post_md5 = hashlib.md5()
        post_md5.update(str(host))

        # Only save if changes were detected
        if pre_md5 != post_md5:
            host['last_modified_by'] = tool
            if not is_known_host:
                id = str(ObjectId())
                host['_id'] = id
                s = file_host.get('status', lair_models.STATUS_GREY)
                host['status'] = s if s in valid_statuses else lair_models.STATUS_GREY

            db.hosts.save(host)

        if not is_known_host:
            now = datetime.utcnow().isoformat()
            temp_drone_log.append("{0} - New host found: {1}".format(
                now,
                file_host['string_addr'])
            )

        # Process each web directory for the host, checking against existing dirs
        if 'web_directories' in file_host:
            # Check if web directories are supported by the remote Lair server.
            if 'web_directories' in db.collection_names():
                for file_directory in file_host['web_directories']:
                    q = {
                        'project_id': project['_id'],
                        'host_id': host['_id'],
                        'path_clean': file_directory['path_clean'],
                        'port': file_directory['port'],
                        'response_code': file_directory['response_code'],
                    }
                    directory = db.web_directories.find_one(q)

                    is_known_directory = False
                    if directory:
                        is_known_directory = True
                    else:
                        directory = copy.deepcopy(lair_models.web_directory_model)

                    pre_md5 = hashlib.md5()
                    pre_md5.update(str(directory))

                    directory['project_id'] = project['_id']
                    directory['host_id'] = host['_id']
                    directory['path'] = file_directory['path']
                    directory['path_clean'] = file_directory['path_clean']
                    directory['port'] = file_directory['port']
                    directory['response_code'] = file_directory['response_code']

                    post_md5 = hashlib.md5()
                    post_md5.update(str(directory))

                    if pre_md5 != post_md5:
                        directory['last_modified_by'] = tool
                        if not is_known_directory:
                            id = str(ObjectId())
                            directory['_id'] = id
                        db.web_directories.save(directory)
            else:
                has_errors = True
                print "[!] Your version of Lair does not support the addition of web directories."
                print "[!] Please check the Lair project on GitHub for more information (https://github.com/lair-framework/lair)."

        # Process each port for the host, checking against known ports
        for file_port in file_host['ports']:

            q = {
                'project_id': project['_id'],
                'host_id': host['_id'],
                'port': file_port['port'],
                'protocol': file_port['protocol']
            }
            port = db.ports.find_one(q)

            is_known_port = False
            if port:
                is_known_port = True
            else:
                port = copy.deepcopy(lair_models.port_model)

            pre_md5 = hashlib.md5()
            pre_md5.update(str(port))

            port['host_id'] = host['_id']
            port['project_id'] = project['_id']
            port['protocol'] = file_port['protocol']
            port['port'] = file_port['port']

            # TODO: Determine how to handle a closed port
            port['alive'] = file_port['alive']

            # Update product if it is unknown
            if port['product'] == lair_models.PRODUCT_UNKNOWN:
                port['product'] = file_port['product']

            # Set the service if it is not set
            if not port['service'] or port['service'] == 'unknown':
                port['service'] = file_port['service']

            # Include any script output for the port
            if file_port['notes']:
                port['notes'].extend(file_port['notes'])

            # Include any credentials
            if file_port['credentials']:
                port['credentials'].extend(file_port['credentials'])

            if not is_known_port:
                id = str(ObjectId())
                port['_id'] = id
                s = file_port.get('status', lair_models.STATUS_GREY)
                port['status'] = s if s in valid_statuses else lair_models.STATUS_GREY
                now = datetime.utcnow().isoformat()
                temp_drone_log.append("{0} - New port found: {1}/{2} ({3})".format(
                    now,
                    str(file_port['port']),
                    file_port['protocol'],
                    file_port['service'])
                )

            post_md5 = hashlib.md5()
            post_md5.update(str(port))

            if pre_md5 != post_md5:
                port['last_modified_by'] = tool
                db.ports.save(port)

    # For each vulnerability in the parsed scan, check to see if it already
    # exists in the database.
    for file_vuln in document.get('vulnerabilities', []):

        is_known_vuln = False

        # Attempt a lookup by plugin_id...
        q = {
            'project_id': project['_id'],
            'plugin_ids': {'$all': file_vuln['plugin_ids']}
        }
        db_vuln = db.vulnerabilities.find_one(q)

        if db_vuln:
            is_known_vuln = True

        # No vuln found by plugin_id, treat as new
        if not is_known_vuln:
            db_vuln = copy.deepcopy(file_vuln)
            id = str(ObjectId())
            s = file_vuln.get('status', lair_models.STATUS_GREY)
            db_vuln['status'] = s if s in valid_statuses else lair_models.STATUS_GREY
            db_vuln['_id'] = id
            db_vuln['project_id'] = project['_id']
            db_vuln['last_modified_by'] = tool
            now = datetime.utcnow().isoformat()
            temp_drone_log.append("{0} - New vulnerability found: {1}".format(
                now,
                file_vuln['title'].encode("utf-8"))
            )
            db.vulnerabilities.save(db_vuln)

        if is_known_vuln:
            pre_md5 = hashlib.md5()
            pre_md5.update(str(db_vuln))

            db_vuln['cves'].extend(file_vuln['cves'])
            db_vuln['cves'] = list(set(db_vuln['cves']))
            db_vuln['identified_by'].extend(file_vuln['identified_by'])

            # Only set 'flag' if it's true for parsed vuln
            db_vuln['flag'] = file_vuln['flag'] \
                if file_vuln.get('flag', False) else db_vuln.get('flag', False)

            # Include any notes for the vulnerability
            if file_vuln['notes']:
                db_vuln['notes'].extend(file_vuln['notes'])

            for file_host in file_vuln['hosts']:
                if file_host not in db_vuln['hosts']:
                    db_vuln['hosts'].append(file_host)
                    now = datetime.utcnow().isoformat()
                    temp_drone_log.append("{0} - {1}:{2}/{3} - New vulnerability found: {4}".format(
                        now,
                        file_host['string_addr'],
                        str(file_host['port']),
                        file_host['protocol'],
                        file_vuln['title'])
                    )

            post_md5 = hashlib.md5()
            post_md5.update(str(db_vuln))

            # Vulnerability was known, but change was detected
            if pre_md5 != post_md5:
                db_vuln['last_modified_by'] = tool
                db.vulnerabilities.save(db_vuln)

    # Ensure history log does not exceed DRONE_LOG_HISTORY limit
    project['drone_log'].extend(temp_drone_log)
    length = len(project['drone_log'])
    del project['drone_log'][0:(length - DRONE_LOG_HISTORY)]

    db.projects.save(project)

    if not has_errors:
        print "[+] Processing completed: {0} host(s) processed.".format(
            str(len(document['hosts'])))
    else:
        print "[!] Could not process this drone's data. See above for any error messages."

Example 30

View license
def parse(project, nessus_file, include_informational=False, min_note_sev=2):
    """Parses a Nessus XMLv2 file and updates the Hive database

    :param project: The project id
    :param nessus_file: The Nessus xml file to be parsed
    :param include_informational: Whether to include info findings in data. Default False
    :param min_note_sev: The minimum severity of notes that will be saved. Default 2
    """

    cve_pattern = re.compile(r'(CVE-|CAN-)')
    false_udp_pattern = re.compile(r'.*\?$')

    tree = et.parse(nessus_file)
    root = tree.getroot()
    note_id = 1

    # Create the project dictionary which acts as the foundation of the document
    project_dict = dict(models.project_model)
    project_dict['commands'] = list()
    project_dict['vulnerabilities'] = list()
    project_dict['project_id'] = project

    # Used to maintain a running list of host:port vulnerabilities by plugin
    vuln_host_map = dict()

    for host in root.iter('ReportHost'):

        temp_ip = host.attrib['name']

        host_dict = dict(models.host_model)
        host_dict['os'] = list()
        host_dict['ports'] = list()
        host_dict['hostnames'] = list()

        # Tags contain host-specific information
        for tag in host.iter('tag'):

            # Operating system tag
            if tag.attrib['name'] == 'operating-system':
                os_dict = dict(models.os_model)
                os_dict['tool'] = TOOL
                os_dict['weight'] = OS_WEIGHT
                os_dict['fingerprint'] = tag.text
                host_dict['os'].append(os_dict)

            # IP address tag
            if tag.attrib['name'] == 'host-ip':
                host_dict['string_addr'] = tag.text
                host_dict['long_addr'] = helper.ip2long(tag.text)

            # MAC address tag
            if tag.attrib['name'] == 'mac-address':
                host_dict['mac_addr'] = tag.text

            # Hostname tag
            if tag.attrib['name'] == 'host-fqdn':
                host_dict['hostnames'].append(tag.text)

            # NetBIOS name tag
            if tag.attrib['name'] == 'netbios-name':
                host_dict['hostnames'].append(tag.text)

        # Track the unique port/protocol combos for a host so we don't
        # add duplicate entries
        ports_processed = dict()

        # Process each 'ReportItem'
        for item in host.findall('ReportItem'):
            plugin_id = item.attrib['pluginID']
            plugin_family = item.attrib['pluginFamily']
            severity = int(item.attrib['severity'])
            title = item.attrib['pluginName']

            port = int(item.attrib['port'])
            protocol = item.attrib['protocol']
            service = item.attrib['svc_name']
            evidence = item.find('plugin_output')

            # Ignore false positive UDP services
            if protocol == "udp" and false_udp_pattern.match(service):
                continue

            # Create a port model and temporarily store it in the dict
            # for tracking purposes. The ports_processed dict is used
            # later to add ports to the host so that no duplicates are
            # present. This is necessary due to the format of the Nessus
            # XML files.
            if '{0}:{1}'.format(port, protocol) not in ports_processed:
                port_dict = copy.deepcopy(models.port_model)
                port_dict['port'] = port
                port_dict['protocol'] = protocol
                port_dict['service'] = service
                ports_processed['{0}:{1}'.format(port, protocol)] = port_dict

            # Set the evidence as a port note if it exists
            if evidence is not None and \
                    severity >= min_note_sev and \
                    plugin_family != 'Port scanners' and \
                    plugin_family != 'Service detection':
                note_dict = copy.deepcopy(models.note_model)
                note_dict['title'] = "{0} (ID{1})".format(title, str(note_id))
                e = evidence.text.strip()
                for line in e.split("\n"):
                    line = line.strip()
                    if line:
                        note_dict['content'] += "    " + line + "\n"
                note_dict['last_modified_by'] = TOOL
                ports_processed['{0}:{1}'.format(port, protocol)]['notes'].append(note_dict)
                note_id += 1

            # This plugin is general scan info...use it for 'command' element
            if plugin_id == '19506':

                command = item.find('plugin_output')

                command_dict = dict(models.command_model)
                command_dict['tool'] = TOOL

                if command is not None:
                    command_dict['command'] = command.text

                if not project_dict['commands']:
                    project_dict['commands'].append(command_dict)

                continue

            # Check if this vulnerability has been seen in this file for
            # another host. If not, create a new vulnerability_model and
            # maintain a mapping between plugin-id and vulnerability as
            # well as a mapping between plugin-id and host. These mappings
            # are later used to complete the Hive schema such that host
            # IP and port information are embedded within each vulnerability
            # while ensuring no duplicate data exists.
            if plugin_id not in vuln_host_map:

                v = copy.deepcopy(models.vulnerability_model)
                v['cves'] = list()
                v['plugin_ids'] = list()
                v['identified_by'] = list()
                v['hosts'] = list()

                # Set the title
                v['title'] = title

                # Set the description
                description = item.find('description')
                if description is not None:
                    v['description'] = description.text

                # Set the solution
                solution = item.find('solution')
                if solution is not None:
                    v['solution'] = solution.text

                # Set the evidence
                if evidence is not None:
                    v['evidence'] = evidence.text

                # Set the vulnerability flag if exploit exists
                exploit = item.find('exploit_available')
                if exploit is not None:
                    v['flag'] = exploit.text == 'true'

                    # Grab Metasploit details
                    exploit_detail = item.find('exploit_framework_metasploit')
                    if exploit_detail is not None and \
                            exploit_detail.text == 'true':
                        note_dict = copy.deepcopy(models.note_model)
                        note_dict['title'] = 'Metasploit Exploit'
                        note_dict['content'] = 'Exploit exists. Details unknown.'
                        module = item.find('metasploit_name')
                        if module is not None:
                            note_dict['content'] = module.text
                        note_dict['last_modified_by'] = TOOL
                        v['notes'].append(note_dict)

                    # Grab Canvas details
                    exploit_detail = item.find('exploit_framework_canvas')
                    if exploit_detail is not None and \
                            exploit_detail.text == 'true':
                        note_dict = copy.deepcopy(models.note_model)
                        note_dict['title'] = 'Canvas Exploit'
                        note_dict['content'] = 'Exploit exists. Details unknown.'
                        module = item.find('canvas_package')
                        if module is not None:
                            note_dict['content'] = module.text
                        note_dict['last_modified_by'] = TOOL
                        v['notes'].append(note_dict)

                    # Grab Core Impact details
                    exploit_detail = item.find('exploit_framework_core')
                    if exploit_detail is not None and \
                            exploit_detail.text == 'true':
                        note_dict = copy.deepcopy(models.note_model)
                        note_dict['title'] = 'Core Impact Exploit'
                        note_dict['content'] = 'Exploit exists. Details unknown.'
                        module = item.find('core_name')
                        if module is not None:
                            note_dict['content'] = module.text
                        note_dict['last_modified_by'] = TOOL
                        v['notes'].append(note_dict)

                    # Grab ExploitHub SKUs
                    exploit_detail = item.find('exploit_framework_exploithub')
                    if exploit_detail is not None and \
                            exploit_detail.text == 'true':
                        note_dict = copy.deepcopy(models.note_model)
                        note_dict['title'] = 'Exploit Hub Exploit'
                        note_dict['content'] = 'Exploit exists. Details unknown.'
                        module = item.find('exploithub_sku')
                        if module is not None:
                            note_dict['content'] = module.text
                        note_dict['last_modified_by'] = TOOL
                        v['notes'].append(note_dict)

                    # Grab any and all ExploitDB IDs
                    details = item.iter('edb-id')
                    if details is not None:
                        for module in details:
                            note_dict = copy.deepcopy(models.note_model)
                            note_dict['title'] = 'Exploit-DB Exploit ' \
                                                 '({0})'.format(module.text)
                            note_dict['content'] = module.text
                            note_dict['last_modified_by'] = TOOL
                            v['notes'].append(note_dict)

                # Set the CVSS score
                cvss = item.find('cvss_base_score')
                if cvss is not None:
                    v['cvss'] = float(cvss.text)
                else:
                    risk_factor = item.find('risk_factor')
                    if risk_factor is not None:
                        rf = risk_factor.text
                        if rf == "Low":
                            v['cvss'] = 3.0
                        elif rf == "Medium":
                            v['cvss'] = 5.0
                        elif rf == "High":
                            v['cvss'] = 7.5
                        elif rf == "Critical":
                            v['cvss'] = 10.0

                # Set the CVE(s)
                for cve in item.findall('cve'):
                    c = cve_pattern.sub('', cve.text)
                    v['cves'].append(c)

                # Set the plugin information
                plugin_dict = dict(models.plugin_id_model)
                plugin_dict['tool'] = TOOL
                plugin_dict['id'] = plugin_id
                v['plugin_ids'].append(plugin_dict)

                # Set the identified by information
                identified_dict = dict(models.identified_by_model)
                identified_dict['tool'] = TOOL
                identified_dict['id'] = plugin_id
                v['identified_by'].append(identified_dict)

                # By default, don't include informational findings unless
                # explicitly told to do so.
                if v['cvss'] == 0 and not include_informational:
                    continue

                vuln_host_map[plugin_id] = dict()
                vuln_host_map[plugin_id]['hosts'] = set()
                vuln_host_map[plugin_id]['vuln'] = v

            if plugin_id in vuln_host_map:
                vuln_host_map[plugin_id]['hosts'].add(
                    "{0}:{1}:{2}".format(
                        host_dict['string_addr'],
                        str(port),
                        protocol
                    )
                )

        # In the event no IP was found, use the 'name' attribute of
        # the 'ReportHost' element
        if not host_dict['string_addr']:
            host_dict['string_addr'] = temp_ip
            host_dict['long_addr'] = helper.ip2long(temp_ip)

        # Add all encountered ports to the host
        host_dict['ports'].extend(ports_processed.values())

        project_dict['hosts'].append(host_dict)

    # This code block uses the plugin/host/vuln mapping to associate
    # all vulnerable hosts to their vulnerability data within the
    # context of the expected Hive schema structure.
    for plugin_id, data in vuln_host_map.items():

        # Build list of host and ports affected by vulnerability and
        # assign that list to the vulnerability model
        for key in data['hosts']:
            (string_addr, port, protocol) = key.split(':')

            host_key_dict = dict(models.host_key_model)
            host_key_dict['string_addr'] = string_addr
            host_key_dict['port'] = int(port)
            host_key_dict['protocol'] = protocol
            data['vuln']['hosts'].append(host_key_dict)

        project_dict['vulnerabilities'].append(data['vuln'])

    if not project_dict['commands']:
        # Adds a dummy 'command' in the event that the Nessus plugin used
        # to populate the data was not run. The Lair API expects it to
        # contain a value.
        command = copy.deepcopy(models.command_model)
        command['tool'] = TOOL
        command['command'] = "Nessus scan - command unknown"
        project_dict['commands'].append(command)

    return project_dict
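
A structural note on the parser above: each vulnerability is created once per Nessus plugin id in vuln_host_map, with the affected hosts tracked as a set of 'ip:port:protocol' strings that is only expanded into per-host dicts at the end. A condensed sketch of that pattern; the template and field names are illustrative, not the actual Lair models:

import copy

vulnerability_model = {'title': '', 'hosts': []}    # hypothetical template
vuln_host_map = {}

def record(plugin_id, title, addr, port, protocol):
    entry = vuln_host_map.setdefault(
        plugin_id,
        {'vuln': copy.deepcopy(vulnerability_model), 'hosts': set()})
    entry['vuln']['title'] = title
    # The set silently drops duplicate host/port/protocol combinations.
    entry['hosts'].add('{0}:{1}:{2}'.format(addr, port, protocol))

record('12345', 'Example finding', '10.0.0.5', 445, 'tcp')
record('12345', 'Example finding', '10.0.0.5', 445, 'tcp')   # duplicate, ignored

# Expand the compact keys into per-host dicts, mirroring the final loop above.
for data in vuln_host_map.values():
    for key in data['hosts']:
        string_addr, port, protocol = key.split(':')
        data['vuln']['hosts'].append(
            {'string_addr': string_addr, 'port': int(port), 'protocol': protocol})
    print(data['vuln'])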

Example 31

View license
def parse(project, nexpose_file, include_informational=False):
    """Parses a Nexpose XMLv2 file and updates the Lair database

    :param project: The project id
    :param nexpose_file: The Nexpose xml file to be parsed
    :param include_informational: Whether to include info findings in data. Default False
    """

    cve_pattern = re.compile(r'(CVE-|CAN-)')
    html_tag_pattern = re.compile(r'<.*?>')
    white_space_pattern = re.compile(r'\s+', re.MULTILINE)

    # Used to create unique notes in DB
    note_id = 1

    tree = et.parse(nexpose_file)
    root = tree.getroot()
    if root is None or \
            root.tag != "NexposeReport" or \
            root.attrib['version'] != "2.0":
        raise IncompatibleDataVersionError("Nexpose XML 2.0")

    # Create the project dictionary which acts as the foundation of the document
    project_dict = dict(models.project_model)
    project_dict['commands'] = list()
    project_dict['vulnerabilities'] = list()
    project_dict['project_id'] = project
    project_dict['commands'].append({'tool': TOOL, 'command': 'scan'})

    # Used to maintain a running list of host:port vulnerabilities by plugin
    vuln_host_map = dict()

    for vuln in root.iter('vulnerability'):
        v = copy.deepcopy(models.vulnerability_model)
        v['cves'] = list()
        v['plugin_ids'] = list()
        v['identified_by'] = list()
        v['hosts'] = list()

        v['cvss'] = float(vuln.attrib['cvssScore'])
        v['title'] = vuln.attrib['title']
        plugin_id = vuln.attrib['id'].lower()

        # Set plugin id
        plugin_dict = dict(models.plugin_id_model)
        plugin_dict['tool'] = TOOL
        plugin_dict['id'] = plugin_id
        v['plugin_ids'].append(plugin_dict)

        # Set identified by information
        identified_dict = dict(models.identified_by_model)
        identified_dict['tool'] = TOOL
        identified_dict['id'] = plugin_id
        v['identified_by'].append(identified_dict)

        # Search for exploits
        for exploit in vuln.iter('exploit'):
            v['flag'] = True
            note_dict = copy.deepcopy(models.note_model)
            note_dict['title'] = "{0} ({1})".format(
                exploit.attrib['type'],
                exploit.attrib['id']
            )
            note_dict['content'] = "{0}\n{1}".format(
                exploit.attrib['title'].encode('ascii', 'replace'),
                exploit.attrib['link'].encode('ascii', 'replace')
            )
            note_dict['last_modified_by'] = TOOL
            v['notes'].append(note_dict)

        # Search for CVE references
        for reference in vuln.iter('reference'):
            if reference.attrib['source'] == 'CVE':
                cve = cve_pattern.sub('', reference.text)
                v['cves'].append(cve)

        # Search for solution
        solution = vuln.find('solution')
        if solution is not None:
            for text in solution.itertext():
                s = text.encode('ascii', 'replace').strip()
                v['solution'] += white_space_pattern.sub(" ", s)

        # Search for description
        description = vuln.find('description')
        if description is not None:
            for text in description.itertext():
                s = text.encode('ascii', 'replace').strip()
                v['description'] += white_space_pattern.sub(" ", s)

        # Build mapping of plugin-id to host to vuln dictionary
        vuln_host_map[plugin_id] = dict()
        vuln_host_map[plugin_id]['vuln'] = v
        vuln_host_map[plugin_id]['hosts'] = set()

    for node in root.iter('node'):

        host_dict = dict(models.host_model)
        host_dict['os'] = list()
        host_dict['ports'] = list()
        host_dict['hostnames'] = list()

        # Set host status
        if node.attrib['status'] != 'alive':
            host_dict['alive'] = False

        # Set IP address
        host_dict['string_addr'] = node.attrib['address']
        host_dict['long_addr'] = helper.ip2long(node.attrib['address'])

        # Set the OS fingerprint
        certainty = 0
        for os in node.iter('os'):
            if float(os.attrib['certainty']) > certainty:
                certainty = float(os.attrib['certainty'])
                os_dict = dict(models.os_model)
                os_dict['tool'] = TOOL
                os_dict['weight'] = OS_WEIGHT

                fingerprint = ''
                if 'vendor' in os.attrib:
                    fingerprint += os.attrib['vendor'] + " "

                # Make an extra check to limit duplication of data in the
                # event that the product name was already in the vendor name
                if 'product' in os.attrib and \
                        os.attrib['product'] not in fingerprint:
                    fingerprint += os.attrib['product'] + " "

                fingerprint = fingerprint.strip()
                os_dict['fingerprint'] = fingerprint

                host_dict['os'] = list()
                host_dict['os'].append(os_dict)

        # Test for general, non-port related vulnerabilities
        # Add them as tcp, port 0
        tests = node.find('tests')
        if tests is not None:
            port_dict = dict(models.port_model)
            port_dict['service'] = "general"

            for test in tests.findall('test'):
                # vulnerable-since attribute is used to flag
                # confirmed vulns
                if 'vulnerable-since' in test.attrib:
                    plugin_id = test.attrib['id'].lower()

                    # This is used to track evidence for the host/port
                    # and plugin
                    h = "{0}:{1}:{2}".format(
                        host_dict['string_addr'],
                        "0",
                        models.PROTOCOL_TCP
                    )
                    vuln_host_map[plugin_id]['hosts'].add(h)

            host_dict['ports'].append(port_dict)

        # Use the endpoint elements to populate port data
        for endpoint in node.iter('endpoint'):
            port_dict = copy.deepcopy(models.port_model)
            port_dict['port'] = int(endpoint.attrib['port'])
            port_dict['protocol'] = endpoint.attrib['protocol']
            if endpoint.attrib['status'] != 'open':
                port_dict['alive'] = False

            # Use the service elements to identify service
            for service in endpoint.iter('service'):

                # Ignore unknown services
                if 'unknown' not in service.attrib['name'].lower():
                    if not port_dict['service']:
                        port_dict['service'] = service.attrib['name'].lower()

                # Use the test elements to identify vulnerabilities for
                # the host
                for test in service.iter('test'):
                    # vulnerable-since attribute is used to flag
                    # confirmed vulns
                    if 'vulnerable-since' in test.attrib:
                        plugin_id = test.attrib['id'].lower()

                        # Add service notes for evidence
                        note_dict = copy.deepcopy(models.note_model)
                        note_dict['title'] = "{0} (ID{1})".format(plugin_id,
                                                              str(note_id))
                        for evidence in test.iter():
                            if evidence.text:
                                for line in evidence.text.split("\n"):
                                    line = line.strip()
                                    if line:
                                        note_dict['content'] += "    " + \
                                                                line + "\n"
                            elif evidence.tag == "URLLink":
                                note_dict['content'] += "    "
                                note_dict['content'] += evidence.attrib[
                                                            'LinkURL'
                                                        ] + "\n"

                        note_dict['last_modified_by'] = TOOL
                        port_dict['notes'].append(note_dict)
                        note_id += 1

                        # This is used to track evidence for the host/port
                        # and plugin
                        h = "{0}:{1}:{2}".format(
                            host_dict['string_addr'],
                            str(port_dict['port']),
                            port_dict['protocol']
                        )
                        vuln_host_map[plugin_id]['hosts'].add(h)

            # Use the fingerprint elements to identify product
            certainty = 0
            for fingerprint in endpoint.iter('fingerprint'):
                if float(fingerprint.attrib['certainty']) > certainty:
                    certainty = float(fingerprint.attrib['certainty'])
                    prod = ''
                    if 'vendor' in fingerprint.attrib:
                        prod += fingerprint.attrib['vendor'] + " "

                    if 'product' in fingerprint.attrib:
                        prod += fingerprint.attrib['product'] + " "

                    if 'version' in fingerprint.attrib:
                        prod += fingerprint.attrib['version'] + " "

                    prod = prod.strip()
                    port_dict['product'] = prod

            host_dict['ports'].append(port_dict)

        project_dict['hosts'].append(host_dict)

    # This code block uses the plugin/host/vuln mapping to associate
    # all vulnerable hosts to their vulnerability data within the
    # context of the expected Lair schema structure.
    for plugin_id, data in vuln_host_map.items():

        # Build list of host and ports affected by vulnerability and
        # assign that list to the vulnerability model
        for key in data['hosts']:
            (string_addr, port, protocol) = key.split(':')

            host_key_dict = dict(models.host_key_model)
            host_key_dict['string_addr'] = string_addr
            host_key_dict['port'] = int(port)
            host_key_dict['protocol'] = protocol
            data['vuln']['hosts'].append(host_key_dict)

        # By default, don't include informational findings unless
        # explicitly told to do so.
        if data['vuln']['cvss'] == 0 and not include_informational:
            continue

        project_dict['vulnerabilities'].append(data['vuln'])

    return project_dict
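
The OS and product fingerprints in this example are chosen by keeping only the highest-certainty candidate seen while iterating, and the product is skipped when the vendor string already contains it. A compact sketch of that selection; the helper and the plain dicts standing in for the XML attributes are purely illustrative:

def best_fingerprint(candidates):
    # Keep the vendor/product string with the highest certainty.
    best, certainty = '', 0.0
    for c in candidates:
        if float(c.get('certainty', 0)) > certainty:
            certainty = float(c['certainty'])
            fingerprint = c.get('vendor', '')
            product = c.get('product', '')
            # Avoid duplicating data when the product name is already
            # part of the vendor string.
            if product and product not in fingerprint:
                fingerprint = (fingerprint + ' ' + product).strip()
            best = fingerprint.strip()
    return best

print(best_fingerprint([
    {'certainty': '0.80', 'vendor': 'Microsoft', 'product': 'Windows Server 2008'},
    {'certainty': '0.95', 'vendor': 'Canonical', 'product': 'Ubuntu Linux'},
]))    # Canonical Ubuntu Linux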

Example 32

View license
def parse(project, nessus_file, include_informational=False, min_note_sev=2):
    """Parses a Nessus XMLv2 file and updates the Hive database

    :param project: The project id
    :param nessus_file: The Nessus xml file to be parsed
    :param include_informational: Whether to include info findings in data. Default False
    :param min_note_sev: The minimum severity of notes that will be saved. Default 2
    """

    cve_pattern = re.compile(r'(CVE-|CAN-)')
    false_udp_pattern = re.compile(r'.*\?$')

    tree = et.parse(nessus_file)
    root = tree.getroot()
    note_id = 1

    # Create the project dictionary which acts as the foundation of the document
    project_dict = dict(models.project_model)
    project_dict['commands'] = list()
    project_dict['vulnerabilities'] = list()
    project_dict['project_id'] = project

    # Used to maintain a running list of host:port vulnerabilities by plugin
    vuln_host_map = dict()

    for host in root.iter('ReportHost'):

        temp_ip = host.attrib['name']

        host_dict = dict(models.host_model)
        host_dict['os'] = list()
        host_dict['ports'] = list()
        host_dict['hostnames'] = list()

        # Tags contain host-specific information
        for tag in host.iter('tag'):

            # Operating system tag
            if tag.attrib['name'] == 'operating-system':
                os_dict = dict(models.os_model)
                os_dict['tool'] = TOOL
                os_dict['weight'] = OS_WEIGHT
                os_dict['fingerprint'] = tag.text
                host_dict['os'].append(os_dict)

            # IP address tag
            if tag.attrib['name'] == 'host-ip':
                host_dict['string_addr'] = tag.text
                host_dict['long_addr'] = helper.ip2long(tag.text)

            # MAC address tag
            if tag.attrib['name'] == 'mac-address':
                host_dict['mac_addr'] = tag.text

            # Hostname tag
            if tag.attrib['name'] == 'host-fqdn':
                host_dict['hostnames'].append(tag.text)

            # NetBIOS name tag
            if tag.attrib['name'] == 'netbios-name':
                host_dict['hostnames'].append(tag.text)

        # Track the unique port/protocol combos for a host so we don't
        # add duplicate entries
        ports_processed = dict()

        # Process each 'ReportItem'
        for item in host.findall('ReportItem'):
            plugin_id = item.attrib['pluginID']
            plugin_family = item.attrib['pluginFamily']
            severity = int(item.attrib['severity'])
            title = item.attrib['pluginName']

            port = int(item.attrib['port'])
            protocol = item.attrib['protocol']
            service = item.attrib['svc_name']
            evidence = item.find('plugin_output')

            # Ignore false positive UDP services
            if protocol == "udp" and false_udp_pattern.match(service):
                continue

            # Create a port model and temporarily store it in the dict
            # for tracking purposes. The ports_processed dict is used
            # later to add ports to the host so that no duplicates are
            # present. This is necessary due to the format of the Nessus
            # XML files.
            if '{0}:{1}'.format(port, protocol) not in ports_processed:
                port_dict = copy.deepcopy(models.port_model)
                port_dict['port'] = port
                port_dict['protocol'] = protocol
                port_dict['service'] = service
                ports_processed['{0}:{1}'.format(port, protocol)] = port_dict

            # Set the evidence as a port note if it exists
            if evidence is not None and \
                    severity >= min_note_sev and \
                    plugin_family != 'Port scanners' and \
                    plugin_family != 'Service detection':
                note_dict = copy.deepcopy(models.note_model)
                note_dict['title'] = "{0} (ID{1})".format(title, str(note_id))
                e = evidence.text.strip()
                for line in e.split("\n"):
                    line = line.strip()
                    if line:
                        note_dict['content'] += "    " + line + "\n"
                note_dict['last_modified_by'] = TOOL
                ports_processed['{0}:{1}'.format(port, protocol)]['notes'].append(note_dict)
                note_id += 1

            # This plugin is general scan info...use it for 'command' element
            if plugin_id == '19506':

                command = item.find('plugin_output')

                command_dict = dict(models.command_model)
                command_dict['tool'] = TOOL

                if command is not None:
                    command_dict['command'] = command.text

                if not project_dict['commands']:
                    project_dict['commands'].append(command_dict)

                continue

            # Check if this vulnerability has been seen in this file for
            # another host. If not, create a new vulnerability_model and
            # maintain a mapping between plugin-id and vulnerability as
            # well as a mapping between plugin-id and host. These mappings
            # are later used to complete the Hive schema such that host
            # IP and port information are embedded within each vulnerability
            # while ensuring no duplicate data exists.
            if plugin_id not in vuln_host_map:

                v = copy.deepcopy(models.vulnerability_model)
                v['cves'] = list()
                v['plugin_ids'] = list()
                v['identified_by'] = list()
                v['hosts'] = list()

                # Set the title
                v['title'] = title

                # Set the description
                description = item.find('description')
                if description is not None:
                    v['description'] = description.text

                # Set the solution
                solution = item.find('solution')
                if solution is not None:
                    v['solution'] = solution.text

                # Set the evidence
                if evidence is not None:
                    v['evidence'] = evidence.text

                # Set the vulnerability flag if exploit exists
                exploit = item.find('exploit_available')
                if exploit is not None:
                    v['flag'] = exploit.text == 'true'

                    # Grab Metasploit details
                    exploit_detail = item.find('exploit_framework_metasploit')
                    if exploit_detail is not None and \
                            exploit_detail.text == 'true':
                        note_dict = copy.deepcopy(models.note_model)
                        note_dict['title'] = 'Metasploit Exploit'
                        note_dict['content'] = 'Exploit exists. Details unknown.'
                        module = item.find('metasploit_name')
                        if module is not None:
                            note_dict['content'] = module.text
                        note_dict['last_modified_by'] = TOOL
                        v['notes'].append(note_dict)

                    # Grab Canvas details
                    exploit_detail = item.find('exploit_framework_canvas')
                    if exploit_detail is not None and \
                            exploit_detail.text == 'true':
                        note_dict = copy.deepcopy(models.note_model)
                        note_dict['title'] = 'Canvas Exploit'
                        note_dict['content'] = 'Exploit exists. Details unknown.'
                        module = item.find('canvas_package')
                        if module is not None:
                            note_dict['content'] = module.text
                        note_dict['last_modified_by'] = TOOL
                        v['notes'].append(note_dict)

                    # Grab Core Impact details
                    exploit_detail = item.find('exploit_framework_core')
                    if exploit_detail is not None and \
                            exploit_detail.text == 'true':
                        note_dict = copy.deepcopy(models.note_model)
                        note_dict['title'] = 'Core Impact Exploit'
                        note_dict['content'] = 'Exploit exists. Details unknown.'
                        module = item.find('core_name')
                        if module is not None:
                            note_dict['content'] = module.text
                        note_dict['last_modified_by'] = TOOL
                        v['notes'].append(note_dict)

                    # Grab ExploitHub SKUs
                    exploit_detail = item.find('exploit_framework_exploithub')
                    if exploit_detail is not None and \
                            exploit_detail.text == 'true':
                        note_dict = copy.deepcopy(models.note_model)
                        note_dict['title'] = 'Exploit Hub Exploit'
                        note_dict['content'] = 'Exploit exists. Details unknown.'
                        module = item.find('exploithub_sku')
                        if module is not None:
                            note_dict['content'] = module.text
                        note_dict['last_modified_by'] = TOOL
                        v['notes'].append(note_dict)

                    # Grab any and all ExploitDB IDs
                    details = item.iter('edb-id')
                    if details is not None:
                        for module in details:
                            note_dict = copy.deepcopy(models.note_model)
                            note_dict['title'] = 'Exploit-DB Exploit ' \
                                                 '({0})'.format(module.text)
                            note_dict['content'] = module.text
                            note_dict['last_modified_by'] = TOOL
                            v['notes'].append(note_dict)

                # Set the CVSS score
                cvss = item.find('cvss_base_score')
                if cvss is not None:
                    v['cvss'] = float(cvss.text)
                else:
                    risk_factor = item.find('risk_factor')
                    if risk_factor is not None:
                        rf = risk_factor.text
                        if rf == "Low":
                            v['cvss'] = 3.0
                        elif rf == "Medium":
                            v['cvss'] = 5.0
                        elif rf == "High":
                            v['cvss'] = 7.5
                        elif rf == "Critical":
                            v['cvss'] = 10.0

                # Set the CVE(s)
                for cve in item.findall('cve'):
                    c = cve_pattern.sub('', cve.text)
                    v['cves'].append(c)

                # Set the plugin information
                plugin_dict = dict(models.plugin_id_model)
                plugin_dict['tool'] = TOOL
                plugin_dict['id'] = plugin_id
                v['plugin_ids'].append(plugin_dict)

                # Set the identified by information
                identified_dict = dict(models.identified_by_model)
                identified_dict['tool'] = TOOL
                identified_dict['id'] = plugin_id
                v['identified_by'].append(identified_dict)

                # By default, don't include informational findings unless
                # explicitly told to do so.
                if v['cvss'] == 0 and not include_informational:
                    continue

                vuln_host_map[plugin_id] = dict()
                vuln_host_map[plugin_id]['hosts'] = set()
                vuln_host_map[plugin_id]['vuln'] = v

            if plugin_id in vuln_host_map:
                vuln_host_map[plugin_id]['hosts'].add(
                    "{0}:{1}:{2}".format(
                        host_dict['string_addr'],
                        str(port),
                        protocol
                    )
                )

        # In the event no IP was found, use the 'name' attribute of
        # the 'ReportHost' element
        if not host_dict['string_addr']:
            host_dict['string_addr'] = temp_ip
            host_dict['long_addr'] = helper.ip2long(temp_ip)

        # Add all encountered ports to the host
        host_dict['ports'].extend(ports_processed.values())

        project_dict['hosts'].append(host_dict)

    # This code block uses the plugin/host/vuln mapping to associate
    # all vulnerable hosts to their vulnerability data within the
    # context of the expected Hive schema structure.
    for plugin_id, data in vuln_host_map.items():

        # Build list of host and ports affected by vulnerability and
        # assign that list to the vulnerability model
        for key in data['hosts']:
            (string_addr, port, protocol) = key.split(':')

            host_key_dict = dict(models.host_key_model)
            host_key_dict['string_addr'] = string_addr
            host_key_dict['port'] = int(port)
            host_key_dict['protocol'] = protocol
            data['vuln']['hosts'].append(host_key_dict)

        project_dict['vulnerabilities'].append(data['vuln'])

    if not project_dict['commands']:
        # Adds a dummy 'command' in the event that the Nessus plugin used
        # to populate the data was not run. The Lair API expects it to
        # contain a value.
        command = copy.deepcopy(models.command_model)
        command['tool'] = TOOL
        command['command'] = "Nessus scan - command unknown"
        project_dict['commands'].append(command)

    return project_dict
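
The dummy command above is built with copy.deepcopy rather than a plain dict() so that nothing inside the shared command_model template gets aliased into the new document. A minimal sketch of the difference, using a hypothetical template with a nested list (the real Lair command_model may be flat, in which case a shallow copy would also work):

import copy

# Hypothetical template resembling the 'models.command_model' used above;
# the real Lair model may differ.
command_model = {'tool': '', 'command': '', 'hosts': []}

shallow = dict(command_model)          # copies only the top level
deep = copy.deepcopy(command_model)    # copies nested objects too

shallow['hosts'].append('10.0.0.1')    # mutates the shared template list
deep['hosts'].append('10.0.0.2')       # touches an independent copy

print(command_model['hosts'])  # ['10.0.0.1'] -- leaked through the shallow copy
print(deep['hosts'])           # ['10.0.0.2'] -- isolated from the template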

Example 33

View license
def parse(project, nexpose_file, include_informational=False):
    """Parses a Nexpose XMLv2 file and updates the Lair database

    :param project: The project id
    :param nexpose_file: The Nexpose xml file to be parsed
    :param include_informational: Whether to include informational findings in the data. Default False
    """

    cve_pattern = re.compile(r'(CVE-|CAN-)')
    html_tag_pattern = re.compile(r'<.*?>')
    white_space_pattern = re.compile(r'\s+', re.MULTILINE)

    # Used to create unique notes in DB
    note_id = 1

    tree = et.parse(nexpose_file)
    root = tree.getroot()
    if root is None or \
            root.tag != "NexposeReport" or \
            root.attrib['version'] != "2.0":
        raise IncompatibleDataVersionError("Nexpose XML 2.0")

    # Create the project dictionary which acts as foundation of document
    project_dict = dict(models.project_model)
    project_dict['commands'] = list()
    project_dict['vulnerabilities'] = list()
    project_dict['project_id'] = project
    project_dict['commands'].append({'tool': TOOL, 'command': 'scan'})

    # Used to maintain a running list of host:port vulnerabilities by plugin
    vuln_host_map = dict()

    for vuln in root.iter('vulnerability'):
        v = copy.deepcopy(models.vulnerability_model)
        v['cves'] = list()
        v['plugin_ids'] = list()
        v['identified_by'] = list()
        v['hosts'] = list()

        v['cvss'] = float(vuln.attrib['cvssScore'])
        v['title'] = vuln.attrib['title']
        plugin_id = vuln.attrib['id'].lower()

        # Set plugin id
        plugin_dict = dict(models.plugin_id_model)
        plugin_dict['tool'] = TOOL
        plugin_dict['id'] = plugin_id
        v['plugin_ids'].append(plugin_dict)

        # Set identified by information
        identified_dict = dict(models.identified_by_model)
        identified_dict['tool'] = TOOL
        identified_dict['id'] = plugin_id
        v['identified_by'].append(identified_dict)

        # Search for exploits
        for exploit in vuln.iter('exploit'):
            v['flag'] = True
            note_dict = copy.deepcopy(models.note_model)
            note_dict['title'] = "{0} ({1})".format(
                exploit.attrib['type'],
                exploit.attrib['id']
            )
            note_dict['content'] = "{0}\n{1}".format(
                exploit.attrib['title'].encode('ascii', 'replace'),
                exploit.attrib['link'].encode('ascii', 'replace')
            )
            note_dict['last_modified_by'] = TOOL
            v['notes'].append(note_dict)

        # Search for CVE references
        for reference in vuln.iter('reference'):
            if reference.attrib['source'] == 'CVE':
                cve = cve_pattern.sub('', reference.text)
                v['cves'].append(cve)

        # Search for solution
        solution = vuln.find('solution')
        if solution is not None:
            for text in solution.itertext():
                s = text.encode('ascii', 'replace').strip()
                v['solution'] += white_space_pattern.sub(" ", s)

        # Search for description
        description = vuln.find('description')
        if description is not None:
            for text in description.itertext():
                s = text.encode('ascii', 'replace').strip()
                v['description'] += white_space_pattern.sub(" ", s)

        # Build mapping of plugin-id to host to vuln dictionary
        vuln_host_map[plugin_id] = dict()
        vuln_host_map[plugin_id]['vuln'] = v
        vuln_host_map[plugin_id]['hosts'] = set()

    for node in root.iter('node'):

        host_dict = dict(models.host_model)
        host_dict['os'] = list()
        host_dict['ports'] = list()
        host_dict['hostnames'] = list()

        # Set host status
        if node.attrib['status'] != 'alive':
            host_dict['alive'] = False

        # Set IP address
        host_dict['string_addr'] = node.attrib['address']
        host_dict['long_addr'] = helper.ip2long(node.attrib['address'])

        # Set the OS fingerprint
        certainty = 0
        for os in node.iter('os'):
            if float(os.attrib['certainty']) > certainty:
                certainty = float(os.attrib['certainty'])
                os_dict = dict(models.os_model)
                os_dict['tool'] = TOOL
                os_dict['weight'] = OS_WEIGHT

                fingerprint = ''
                if 'vendor' in os.attrib:
                    fingerprint += os.attrib['vendor'] + " "

                # Make an extra check to limit duplication of data in the
                # event that the product name was already in the vendor name
                if 'product' in os.attrib and \
                        os.attrib['product'] not in fingerprint:
                    fingerprint += os.attrib['product'] + " "

                fingerprint = fingerprint.strip()
                os_dict['fingerprint'] = fingerprint

                host_dict['os'] = list()
                host_dict['os'].append(os_dict)

        # Test for general, non-port related vulnerabilities
        # Add them as tcp, port 0
        tests = node.find('tests')
        if tests is not None:
            port_dict = dict(models.port_model)
            port_dict['service'] = "general"

            for test in tests.findall('test'):
                # vulnerable-since attribute is used to flag
                # confirmed vulns
                if 'vulnerable-since' in test.attrib:
                    plugin_id = test.attrib['id'].lower()

                    # This is used to track evidence for the host/port
                    # and plugin
                    h = "{0}:{1}:{2}".format(
                        host_dict['string_addr'],
                        "0",
                        models.PROTOCOL_TCP
                    )
                    vuln_host_map[plugin_id]['hosts'].add(h)

            host_dict['ports'].append(port_dict)

        # Use the endpoint elements to populate port data
        for endpoint in node.iter('endpoint'):
            port_dict = copy.deepcopy(models.port_model)
            port_dict['port'] = int(endpoint.attrib['port'])
            port_dict['protocol'] = endpoint.attrib['protocol']
            if endpoint.attrib['status'] != 'open':
                port_dict['alive'] = False

            # Use the service elements to identify service
            for service in endpoint.iter('service'):

                # Ignore unknown services
                if 'unknown' not in service.attrib['name'].lower():
                    if not port_dict['service']:
                        port_dict['service'] = service.attrib['name'].lower()

                # Use the test elements to identify vulnerabilities for
                # the host
                for test in service.iter('test'):
                    # vulnerable-since attribute is used to flag
                    # confirmed vulns
                    if 'vulnerable-since' in test.attrib:
                        plugin_id = test.attrib['id'].lower()

                        # Add service notes for evidence
                        note_dict = copy.deepcopy(models.note_model)
                        note_dict['title'] = "{0} (ID{1})".format(plugin_id, str(note_id))
                        for evidence in test.iter():
                            if evidence.text:
                                for line in evidence.text.split("\n"):
                                    line = line.strip()
                                    if line:
                                        note_dict['content'] += "    " + \
                                                                line + "\n"
                            elif evidence.tag == "URLLink":
                                note_dict['content'] += "    "
                                note_dict['content'] += evidence.attrib[
                                                            'LinkURL'
                                                        ] + "\n"

                        note_dict['last_modified_by'] = TOOL
                        port_dict['notes'].append(note_dict)
                        note_id += 1

                        # This is used to track evidence for the host/port
                        # and plugin
                        h = "{0}:{1}:{2}".format(
                            host_dict['string_addr'],
                            str(port_dict['port']),
                            port_dict['protocol']
                        )
                        vuln_host_map[plugin_id]['hosts'].add(h)

            # Use the fingerprint elements to identify product
            certainty = 0
            for fingerprint in endpoint.iter('fingerprint'):
                if float(fingerprint.attrib['certainty']) > certainty:
                    certainty = float(fingerprint.attrib['certainty'])
                    prod = ''
                    if 'vendor' in fingerprint.attrib:
                        prod += fingerprint.attrib['vendor'] + " "

                    if 'product' in fingerprint.attrib:
                        prod += fingerprint.attrib['product'] + " "

                    if 'version' in fingerprint.attrib:
                        prod += fingerprint.attrib['version'] + " "

                    prod = prod.strip()
                    port_dict['product'] = prod

            host_dict['ports'].append(port_dict)

        project_dict['hosts'].append(host_dict)

    # This code block uses the plugin/host/vuln mapping to associate
    # all vulnerable hosts to their vulnerability data within the
    # context of the expected Lair schema structure.
    for plugin_id, data in vuln_host_map.items():

        # Build list of host and ports affected by vulnerability and
        # assign that list to the vulnerability model
        for key in data['hosts']:
            (string_addr, port, protocol) = key.split(':')

            host_key_dict = dict(models.host_key_model)
            host_key_dict['string_addr'] = string_addr
            host_key_dict['port'] = int(port)
            host_key_dict['protocol'] = protocol
            data['vuln']['hosts'].append(host_key_dict)

        # By default, don't include informational findings unless
        # explicitly told to do so.
        if data['vuln']['cvss'] == 0 and not include_informational:
            continue

        project_dict['vulnerabilities'].append(data['vuln'])

    return project_dict
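
Notice that the parser deep-copies models.vulnerability_model for every <vulnerability> element and then re-binds its list fields; either step keeps findings from sharing mutable state with each other or with the template. A small illustrative sketch, with a hypothetical template whose field names are assumed rather than taken from the Lair source:

import copy

# Hypothetical vulnerability template; the real models.vulnerability_model
# in Lair will have more fields than this.
vulnerability_model = {'title': '', 'cvss': 0.0, 'cves': [], 'hosts': []}

findings = []
for title, cve in (('Vuln A', 'CVE-2016-0001'), ('Vuln B', 'CVE-2016-0002')):
    v = copy.deepcopy(vulnerability_model)  # independent nested lists per finding
    v['title'] = title
    v['cves'].append(cve)
    findings.append(v)

# Each finding keeps only its own CVE; the template itself stays untouched.
assert findings[0]['cves'] == ['CVE-2016-0001']
assert findings[1]['cves'] == ['CVE-2016-0002']
assert vulnerability_model['cves'] == []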

Example 34

Project: zmirror
Source File: test_regex.py
View license
    def test__regex_adv_url_rewriter__and__regex_url_reassemble(self):
        test_cases = (
            dict(
                raw='background: url(../images/boardsearch/mso-hd.gif);',
                main='background: url({path_up}/images/boardsearch/mso-hd.gif);',
                ext='background: url(/extdomains/{ext_domain}{path_up}/images/boardsearch/mso-hd.gif);',
            ),
            dict(
                raw='background: url(http://www.google.com/images/boardsearch/mso-hd.gif););',
                main='background: url({our_scheme}{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif););',
                ext='background: url({our_scheme}{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif););'
            ),
            dict(
                raw=": url('http://www.google.com/images/boardsearch/mso-hd.gif');",
                main=": url('{our_scheme}{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif');",
                ext=": url('{our_scheme}{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif');",
            ),
            dict(
                raw='background: url("//www.google.com/images/boardsearch/mso-hd.gif");',
                main='background: url("//{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif");',
                ext='background: url("//{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif");',
            ),
            dict(
                raw=r"""background: url ( "//www.google.com/images/boardsearch/mso-hd.gif" );""",
                main=r"""background: url ( "//{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif" );""",
                ext=r"""background: url ( "//{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif" );""",
            ),
            dict(
                raw=r""" src="https://ssl.gstatic.com/233.jpg" """,
                main=r""" src="{our_scheme}{our_domain}/extdomains/ssl.gstatic.com/233.jpg" """,
                ext=r""" src="{our_scheme}{our_domain}/extdomains/ssl.gstatic.com/233.jpg" """,
            ),
            dict(
                raw=r""" src="/233.jpg" """,
                main=r""" src="/233.jpg" """,
                ext=r""" src="/extdomains/{ext_domain}/233.jpg" """,
            ),
            dict(
                raw=r"""href="http://ssl.gstatic.com/233.jpg" """,
                main=r"""href="{our_scheme}{our_domain}/extdomains/ssl.gstatic.com/233.jpg" """,
                ext=r"""href="{our_scheme}{our_domain}/extdomains/ssl.gstatic.com/233.jpg" """,
            ),
            dict(
                raw=r"""background: url("//ssl.gstatic.com/images/boardsearch/mso-hd.gif"); """,
                main=r"""background: url("//{our_domain}/extdomains/ssl.gstatic.com/images/boardsearch/mso-hd.gif"); """,
                ext=r"""background: url("//{our_domain}/extdomains/ssl.gstatic.com/images/boardsearch/mso-hd.gif"); """,
            ),
            dict(
                raw=r"""background: url ( "//ssl.gstatic.com/images/boardsearch/mso-hd.gif" ); """,
                main=r"""background: url ( "//{our_domain}/extdomains/ssl.gstatic.com/images/boardsearch/mso-hd.gif" ); """,
                ext=r"""background: url ( "//{our_domain}/extdomains/ssl.gstatic.com/images/boardsearch/mso-hd.gif" ); """,
            ),
            dict(
                raw=r"""src="http://www.google.com/233.jpg" """,
                main=r"""src="{our_scheme}{our_domain}/extdomains/www.google.com/233.jpg" """,
                ext=r"""src="{our_scheme}{our_domain}/extdomains/www.google.com/233.jpg" """,
            ),
            dict(
                raw=r"""href="http://www.google.com/233.jpg" """,
                main=r"""href="{our_scheme}{our_domain}/extdomains/www.google.com/233.jpg" """,
                ext=r"""href="{our_scheme}{our_domain}/extdomains/www.google.com/233.jpg" """,
            ),
            dict(
                raw=r"""href="https://www.foo.com/233.jpg" """,
                main=r"""href="https://www.foo.com/233.jpg" """,
                ext=r"""href="https://www.foo.com/233.jpg" """,
            ),
            dict(
                raw=r"""xhref="http://www.google.com/233.jpg" """,
                main=r"""xhref="http://www.google.com/233.jpg" """,
                ext=r"""xhref="http://www.google.com/233.jpg" """,
            ),
            dict(
                raw=r"""s.href="http://www.google.com/path/233.jpg" """,
                main=r"""s.href="{our_scheme}{our_domain}/extdomains/www.google.com/path/233.jpg" """,
                ext=r"""s.href="{our_scheme}{our_domain}/extdomains/www.google.com/path/233.jpg" """,
            ),
            dict(
                raw=r"""background: url(../images/boardsearch/mso-hd.gif?a=x&bb=fr%34fd);""",
                main=r"""background: url({path_up}/images/boardsearch/mso-hd.gif?a=x&bb=fr%34fd);""",
                ext=r"""background: url(/extdomains/{ext_domain}{path_up}/images/boardsearch/mso-hd.gif?a=x&bb=fr%34fd);""",
            ),
            dict(
                raw=r"""background: url(http://www.google.com/images/boardsearch/mso-hd.gif?a=x&bb=fr%34fd);""",
                main=r"""background: url({our_scheme}{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif?a=x&bb=fr%34fd);""",
                ext=r"""background: url({our_scheme}{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif?a=x&bb=fr%34fd);""",
            ),
            dict(
                raw=r"""src="http://ssl.gstatic.com/233.jpg?a=x&bb=fr%34fd" """,
                main=r"""src="{our_scheme}{our_domain}/extdomains/ssl.gstatic.com/233.jpg?a=x&bb=fr%34fd" """,
                ext=r"""src="{our_scheme}{our_domain}/extdomains/ssl.gstatic.com/233.jpg?a=x&bb=fr%34fd" """,
            ),
            dict(
                raw=r"""href="index.php/img/233.jx" """,
                main=r"""href="{path}/index.php/img/233.jx" """,
                ext=r"""href="/extdomains/{ext_domain}{path}/index.php/img/233.jx" """,
            ),
            dict(
                raw=r"""href="/img/233.jss" """,
                main=r"""href="/img/233.jss" """,
                ext=r"""href="/extdomains/{ext_domain}/img/233.jss" """,
            ),
            dict(
                raw=r"""href="img/233.jpg" """,
                main=r"""href="{path}/img/233.jpg" """,
                ext=r"""href="/extdomains/{ext_domain}{path}/img/233.jpg" """,
            ),
            dict(
                raw=r"""nd-image:url(/static/images/project-logos/zhwiki.png)}@media""",
                main=r"""nd-image:url(/static/images/project-logos/zhwiki.png)}@media""",
                ext=r"""nd-image:url(/extdomains/{ext_domain}/static/images/project-logos/zhwiki.png)}@media""",
            ),
            dict(
                raw=r"""nd-image:url(static/images/project-logos/zhwiki.png)}@media""",
                main=r"""nd-image:url({path}/static/images/project-logos/zhwiki.png)}@media""",
                ext=r"""nd-image:url(/extdomains/{ext_domain}{path}/static/images/project-logos/zhwiki.png)}@media""",
            ),
            dict(
                raw=r"""@import "/wikipedia/zh/w/index.php?title=MediaWiki:Gadget-fontsize.css&action=raw&ctype=text/css";""",
                main=r"""@import "/wikipedia/zh/w/index.php?title=MediaWiki:Gadget-fontsize.css&action=raw&ctype=text/css";""",
                ext=r"""@import "/wikipedia/zh/w/index.php?title=MediaWiki:Gadget-fontsize.css&action=raw&ctype=text/css";""",
            ),
            dict(
                raw=r"""(window['gbar']=window['gbar']||{})._CONFIG=[[[0,"www.gstatic.com","og.og2.en_US.8UP-Hyjzcx8.O","com","zh-CN","1",0,[3,2,".40.64.","","1300102,3700275,3700388","1461637855","0"],"40400","LJ8qV4WxEI_QjwOio6SoDw",0,0,"og.og2.w5jrmmcgm1gp.L.F4.O","AA2YrTt48BbbcLnincZsbUECyYqIio-xhw","AA2YrTu9IQdyFrx2T9b82QPSt9PVPEWOIw","",2,0,200,"USA"],null,0,["m;/_/scs/abc-static/_/js/k=gapi.gapi.en.CqFrPIKIxF4.O/m=__features__/rt=j/d=1/rs=AHpOoo_SqGYjlKSpzsbc2UGyTC5n3Z0ZtQ","https://apis.google.com","","","","",null,1,"es_plusone_gc_20160421.0_p0","zh-CN"],["1","gci_91f30755d6a6b787dcc2a4062e6e9824.js","googleapis.client:plusone:gapi.iframes","","zh-CN"],null,null,null,[0.009999999776482582,"com","1",[null,"","w",null,1,5184000,1,0,""],null,[["","","",0,0,-1]],[null,0,0],0,null,null,["5061451","google\\.(com|ru|ca|by|kz|com\\.mx|com\\.tr)$",1]],null,[0,0,0,null,"","","",""],[1,0.001000000047497451,1],[1,0.1000000014901161,2,1],[0,"",null,"",0,"加载您的 Marketplace 应用时出错。","您没有任何 Marketplace 应用。",0,[1,"https://www.google.com/webhp?tab=ww","搜索","","0 -276px",null,0],null,null,1,0],[1],[0,1,["lg"],1,["lat"]],[["","","","","","","","","","","","","","","","","","","","def","","","","","",""],[""]],null,null,null,[30,127,1,0,60],null,null,null,null,null,[1,1]]];(window['gbar']=window['gbar']||{})._LDD=["in","fot"];this.gbar_=this.gbar_||{};(function(_){var window=this;""",
                main=r"""(window['gbar']=window['gbar']||{})._CONFIG=[[[0,"www.gstatic.com","og.og2.en_US.8UP-Hyjzcx8.O","com","zh-CN","1",0,[3,2,".40.64.","","1300102,3700275,3700388","1461637855","0"],"40400","LJ8qV4WxEI_QjwOio6SoDw",0,0,"og.og2.w5jrmmcgm1gp.L.F4.O","AA2YrTt48BbbcLnincZsbUECyYqIio-xhw","AA2YrTu9IQdyFrx2T9b82QPSt9PVPEWOIw","",2,0,200,"USA"],null,0,["m;/_/scs/abc-static/_/js/k=gapi.gapi.en.CqFrPIKIxF4.O/m=__features__/rt=j/d=1/rs=AHpOoo_SqGYjlKSpzsbc2UGyTC5n3Z0ZtQ","https://apis.google.com","","","","",null,1,"es_plusone_gc_20160421.0_p0","zh-CN"],["1","gci_91f30755d6a6b787dcc2a4062e6e9824.js","googleapis.client:plusone:gapi.iframes","","zh-CN"],null,null,null,[0.009999999776482582,"com","1",[null,"","w",null,1,5184000,1,0,""],null,[["","","",0,0,-1]],[null,0,0],0,null,null,["5061451","google\\.(com|ru|ca|by|kz|com\\.mx|com\\.tr)$",1]],null,[0,0,0,null,"","","",""],[1,0.001000000047497451,1],[1,0.1000000014901161,2,1],[0,"",null,"",0,"加载您的 Marketplace 应用时出错。","您没有任何 Marketplace 应用。",0,[1,"https://www.google.com/webhp?tab=ww","搜索","","0 -276px",null,0],null,null,1,0],[1],[0,1,["lg"],1,["lat"]],[["","","","","","","","","","","","","","","","","","","","def","","","","","",""],[""]],null,null,null,[30,127,1,0,60],null,null,null,null,null,[1,1]]];(window['gbar']=window['gbar']||{})._LDD=["in","fot"];this.gbar_=this.gbar_||{};(function(_){var window=this;""",
                ext=r"""(window['gbar']=window['gbar']||{})._CONFIG=[[[0,"www.gstatic.com","og.og2.en_US.8UP-Hyjzcx8.O","com","zh-CN","1",0,[3,2,".40.64.","","1300102,3700275,3700388","1461637855","0"],"40400","LJ8qV4WxEI_QjwOio6SoDw",0,0,"og.og2.w5jrmmcgm1gp.L.F4.O","AA2YrTt48BbbcLnincZsbUECyYqIio-xhw","AA2YrTu9IQdyFrx2T9b82QPSt9PVPEWOIw","",2,0,200,"USA"],null,0,["m;/_/scs/abc-static/_/js/k=gapi.gapi.en.CqFrPIKIxF4.O/m=__features__/rt=j/d=1/rs=AHpOoo_SqGYjlKSpzsbc2UGyTC5n3Z0ZtQ","https://apis.google.com","","","","",null,1,"es_plusone_gc_20160421.0_p0","zh-CN"],["1","gci_91f30755d6a6b787dcc2a4062e6e9824.js","googleapis.client:plusone:gapi.iframes","","zh-CN"],null,null,null,[0.009999999776482582,"com","1",[null,"","w",null,1,5184000,1,0,""],null,[["","","",0,0,-1]],[null,0,0],0,null,null,["5061451","google\\.(com|ru|ca|by|kz|com\\.mx|com\\.tr)$",1]],null,[0,0,0,null,"","","",""],[1,0.001000000047497451,1],[1,0.1000000014901161,2,1],[0,"",null,"",0,"加载您的 Marketplace 应用时出错。","您没有任何 Marketplace 应用。",0,[1,"https://www.google.com/webhp?tab=ww","搜索","","0 -276px",null,0],null,null,1,0],[1],[0,1,["lg"],1,["lat"]],[["","","","","","","","","","","","","","","","","","","","def","","","","","",""],[""]],null,null,null,[30,127,1,0,60],null,null,null,null,null,[1,1]]];(window['gbar']=window['gbar']||{})._LDD=["in","fot"];this.gbar_=this.gbar_||{};(function(_){var window=this;""",
            ),
            dict(
                raw=r""" src="" """,
                main=r""" src="" """,
                ext=r""" src="" """,
            ),
            dict(
                raw=r""" this.src=c; """,
                main=r""" this.src=c; """,
                ext=r""" this.src=c; """,
            ),
            dict(
                raw=r""" href="http://www.google.com/" """,
                main=r""" href="{our_scheme}{our_domain}/extdomains/www.google.com/" """,
                ext=r""" href="{our_scheme}{our_domain}/extdomains/www.google.com/" """,
            ),
            dict(
                raw=r"""_.Gd=function(a){if(_.na(a)||!a||a.Gb)return!1;var c=a.src;if(_.nd(c))return c.uc(a);var d=a.type,e=a.b;c.removeEventListener?c.removeEventListener(d,e,a.fc):c.detachEvent&&c.detachEvent(Cd(d),e);xd--;(d=_.Ad(c))?(td(d,a),0==d.o&&(d.src=null,c[vd]=null)):qd(a);return!0};Cd=function(a){return a in wd?wd[a]:wd[a]="on"+a};Id=function(a,c,d,e){var f=!0;if(a=_.Ad(a))if(c=a.b[c.toString()])for(c=c.concat(),a=0;a<c.length;a++){var g=c[a];g&&g.fc==d&&!g.Gb&&(g=Hd(g,e),f=f&&!1!==g)}return f};""",
                main=r"""_.Gd=function(a){if(_.na(a)||!a||a.Gb)return!1;var c=a.src;if(_.nd(c))return c.uc(a);var d=a.type,e=a.b;c.removeEventListener?c.removeEventListener(d,e,a.fc):c.detachEvent&&c.detachEvent(Cd(d),e);xd--;(d=_.Ad(c))?(td(d,a),0==d.o&&(d.src=null,c[vd]=null)):qd(a);return!0};Cd=function(a){return a in wd?wd[a]:wd[a]="on"+a};Id=function(a,c,d,e){var f=!0;if(a=_.Ad(a))if(c=a.b[c.toString()])for(c=c.concat(),a=0;a<c.length;a++){var g=c[a];g&&g.fc==d&&!g.Gb&&(g=Hd(g,e),f=f&&!1!==g)}return f};""",
                ext=r"""_.Gd=function(a){if(_.na(a)||!a||a.Gb)return!1;var c=a.src;if(_.nd(c))return c.uc(a);var d=a.type,e=a.b;c.removeEventListener?c.removeEventListener(d,e,a.fc):c.detachEvent&&c.detachEvent(Cd(d),e);xd--;(d=_.Ad(c))?(td(d,a),0==d.o&&(d.src=null,c[vd]=null)):qd(a);return!0};Cd=function(a){return a in wd?wd[a]:wd[a]="on"+a};Id=function(a,c,d,e){var f=!0;if(a=_.Ad(a))if(c=a.b[c.toString()])for(c=c.concat(),a=0;a<c.length;a++){var g=c[a];g&&g.fc==d&&!g.Gb&&(g=Hd(g,e),f=f&&!1!==g)}return f};""",
            ),
            dict(
                raw=r"""<script>(function(){window.google={kEI:'wZ4qV6KnMtjwjwOztI2ABQ',kEXPI:'10201868',authuser:0,j:{en:1,bv:24,u:'e4f4906d',qbp:0},kscs:'e4f4906d_24'};google.kHL='zh-CN';})();(function(){google.lc=[];google.li=0;google.getEI=function(a){for(var b;a&&(!a.getAttribute||!(b=a.getAttribute("eid")));)a=a.parentNode;return b||google.kEI};google.getLEI=function(a){for(var b=null;a&&(!a.getAttribute||!(b=a.getAttribute("leid")));)a=a.parentNode;return b};google.https=function(){return"https:"==window.location.protocol};google.ml=function(){return null};google.wl=function(a,b){try{google.ml(Error(a),!1,b)}catch(c){}};google.time=function(){return(new Date).getTime()};google.log=function(a,b,c,e,g){a=google.logUrl(a,b,c,e,g);if(""!=a){b=new Image;var d=google.lc,f=google.li;d[f]=b;b.onerror=b.onload=b.onabort=function(){delete d[f]};window.google&&window.google.vel&&window.google.vel.lu&&window.google.vel.lu(a);b.src=a;google.li=f+1}};google.logUrl=function(a,b,c,e,g){var d="",f=google.ls||"";if(!c&&-1==b.search("&ei=")){var h=google.getEI(e),d="&ei="+h;-1==b.search("&lei=")&&((e=google.getLEI(e))?d+="&lei="+e:h!=google.kEI&&(d+="&lei="+google.kEI))}a=c||"/"+(g||"gen_204")+"?atyp=i&ct="+a+"&cad="+b+d+f+"&zx="+google.time();/^http:/i.test(a)&&google.https()&&(google.ml(Error("a"),!1,{src:a,glmm:1}),a="");return a};google.y={};google.x=function(a,b){google.y[a.id]=[a,b];return!1};google.load=function(a,b,c){google.x({id:a+k++},function(){google.load(a,b,c)})};var k=0;})();""",
                main=r"""<script>(function(){window.google={kEI:'wZ4qV6KnMtjwjwOztI2ABQ',kEXPI:'10201868',authuser:0,j:{en:1,bv:24,u:'e4f4906d',qbp:0},kscs:'e4f4906d_24'};google.kHL='zh-CN';})();(function(){google.lc=[];google.li=0;google.getEI=function(a){for(var b;a&&(!a.getAttribute||!(b=a.getAttribute("eid")));)a=a.parentNode;return b||google.kEI};google.getLEI=function(a){for(var b=null;a&&(!a.getAttribute||!(b=a.getAttribute("leid")));)a=a.parentNode;return b};google.https=function(){return"https:"==window.location.protocol};google.ml=function(){return null};google.wl=function(a,b){try{google.ml(Error(a),!1,b)}catch(c){}};google.time=function(){return(new Date).getTime()};google.log=function(a,b,c,e,g){a=google.logUrl(a,b,c,e,g);if(""!=a){b=new Image;var d=google.lc,f=google.li;d[f]=b;b.onerror=b.onload=b.onabort=function(){delete d[f]};window.google&&window.google.vel&&window.google.vel.lu&&window.google.vel.lu(a);b.src=a;google.li=f+1}};google.logUrl=function(a,b,c,e,g){var d="",f=google.ls||"";if(!c&&-1==b.search("&ei=")){var h=google.getEI(e),d="&ei="+h;-1==b.search("&lei=")&&((e=google.getLEI(e))?d+="&lei="+e:h!=google.kEI&&(d+="&lei="+google.kEI))}a=c||"/"+(g||"gen_204")+"?atyp=i&ct="+a+"&cad="+b+d+f+"&zx="+google.time();/^http:/i.test(a)&&google.https()&&(google.ml(Error("a"),!1,{src:a,glmm:1}),a="");return a};google.y={};google.x=function(a,b){google.y[a.id]=[a,b];return!1};google.load=function(a,b,c){google.x({id:a+k++},function(){google.load(a,b,c)})};var k=0;})();""",
                ext=r"""<script>(function(){window.google={kEI:'wZ4qV6KnMtjwjwOztI2ABQ',kEXPI:'10201868',authuser:0,j:{en:1,bv:24,u:'e4f4906d',qbp:0},kscs:'e4f4906d_24'};google.kHL='zh-CN';})();(function(){google.lc=[];google.li=0;google.getEI=function(a){for(var b;a&&(!a.getAttribute||!(b=a.getAttribute("eid")));)a=a.parentNode;return b||google.kEI};google.getLEI=function(a){for(var b=null;a&&(!a.getAttribute||!(b=a.getAttribute("leid")));)a=a.parentNode;return b};google.https=function(){return"https:"==window.location.protocol};google.ml=function(){return null};google.wl=function(a,b){try{google.ml(Error(a),!1,b)}catch(c){}};google.time=function(){return(new Date).getTime()};google.log=function(a,b,c,e,g){a=google.logUrl(a,b,c,e,g);if(""!=a){b=new Image;var d=google.lc,f=google.li;d[f]=b;b.onerror=b.onload=b.onabort=function(){delete d[f]};window.google&&window.google.vel&&window.google.vel.lu&&window.google.vel.lu(a);b.src=a;google.li=f+1}};google.logUrl=function(a,b,c,e,g){var d="",f=google.ls||"";if(!c&&-1==b.search("&ei=")){var h=google.getEI(e),d="&ei="+h;-1==b.search("&lei=")&&((e=google.getLEI(e))?d+="&lei="+e:h!=google.kEI&&(d+="&lei="+google.kEI))}a=c||"/"+(g||"gen_204")+"?atyp=i&ct="+a+"&cad="+b+d+f+"&zx="+google.time();/^http:/i.test(a)&&google.https()&&(google.ml(Error("a"),!1,{src:a,glmm:1}),a="");return a};google.y={};google.x=function(a,b){google.y[a.id]=[a,b];return!1};google.load=function(a,b,c){google.x({id:a+k++},function(){google.load(a,b,c)})};var k=0;})();""",
            ),
            dict(
                raw=r"""background-image: url("../skin/default/tabs_m_tile.gif");""",
                main=r"""background-image: url("{path_up}/skin/default/tabs_m_tile.gif");""",
                ext=r"""background-image: url("/extdomains/{ext_domain}{path_up}/skin/default/tabs_m_tile.gif");""",
            ),
            dict(
                raw=r"""background-image: url("xx/skin/default/tabs_m_tile.gif");""",
                main=r"""background-image: url("{path}/xx/skin/default/tabs_m_tile.gif");""",
                ext=r"""background-image: url("/extdomains/{ext_domain}{path}/xx/skin/default/tabs_m_tile.gif");""",
            ),
            dict(
                raw=r"""background-image: url('xx/skin/default/tabs_m_tile.gif");""",
                main=r"""background-image: url('xx/skin/default/tabs_m_tile.gif");""",
                ext=r"""background-image: url('xx/skin/default/tabs_m_tile.gif");""",
            ),
            dict(
                raw=r"""} else 2 == e ? this.Ea ? this.Ea.style.display = "" : (e = QS_XA("sbsb_j " + this.$.ef), f = QS_WA("a"), f.id = "sbsb_f", f.href = "http://www.google.com/support/websearch/bin/answer.py?hl=" + this.$.Xe + "&answer=106230", f.innerHTML = this.$.$k, e.appendChild(f), e.onmousedown = QS_c(this.Ia, this), this.Ea = e, this.ma.appendChild(this.Ea)) : 3 == e ? (e = this.cf.pop(), e || (e = QS_WA("li"), e.VLa = !0, f = QS_WA("div", "sbsb_e"), e.appendChild(f)), this.qa.appendChild(e)) : QS_rhb(this, e) &&""",
                main=r"""} else 2 == e ? this.Ea ? this.Ea.style.display = "" : (e = QS_XA("sbsb_j " + this.$.ef), f = QS_WA("a"), f.id = "sbsb_f", f.href = "{our_scheme}{our_domain}/extdomains/www.google.com/support/websearch/bin/answer.py?hl=" + this.$.Xe + "&answer=106230", f.innerHTML = this.$.$k, e.appendChild(f), e.onmousedown = QS_c(this.Ia, this), this.Ea = e, this.ma.appendChild(this.Ea)) : 3 == e ? (e = this.cf.pop(), e || (e = QS_WA("li"), e.VLa = !0, f = QS_WA("div", "sbsb_e"), e.appendChild(f)), this.qa.appendChild(e)) : QS_rhb(this, e) &&""",
                ext=r"""} else 2 == e ? this.Ea ? this.Ea.style.display = "" : (e = QS_XA("sbsb_j " + this.$.ef), f = QS_WA("a"), f.id = "sbsb_f", f.href = "{our_scheme}{our_domain}/extdomains/www.google.com/support/websearch/bin/answer.py?hl=" + this.$.Xe + "&answer=106230", f.innerHTML = this.$.$k, e.appendChild(f), e.onmousedown = QS_c(this.Ia, this), this.Ea = e, this.ma.appendChild(this.Ea)) : 3 == e ? (e = this.cf.pop(), e || (e = QS_WA("li"), e.VLa = !0, f = QS_WA("div", "sbsb_e"), e.appendChild(f)), this.qa.appendChild(e)) : QS_rhb(this, e) &&""",
            ),
            dict(
                raw=r"""m.background = "url(" + f + ") no-repeat " + b.Ea""",
                main=r"""m.background = "url(" + f + ") no-repeat " + b.Ea""",
                ext=r"""m.background = "url(" + f + ") no-repeat " + b.Ea""",
            ),
            dict(
                raw=r"""m.background="url("+f+") no-repeat " + b.Ea""",
                main=r"""m.background="url("+f+") no-repeat " + b.Ea""",
                ext=r"""m.background="url("+f+") no-repeat " + b.Ea""",
            ),
            dict(
                raw=r""" "assetsBasePath" : "https:\/\/encrypted-tbn0.gstatic.com\/a\/1462524371\/", """,
                main=r""" "assetsBasePath" : "{our_scheme_esc}{our_domain}\/extdomains\/encrypted-tbn0.gstatic.com\/a\/1462524371\/", """,
                ext=r""" "assetsBasePath" : "{our_scheme_esc}{our_domain}\/extdomains\/encrypted-tbn0.gstatic.com\/a\/1462524371\/", """,
            ),
            dict(
                raw=r""" " fullName" : "\/i\/start\/Aploium", """,
                main=r""" " fullName" : "\/i\/start\/Aploium", """,
                ext=r""" " fullName" : "\/i\/start\/Aploium", """,
            ),
            dict(
                raw=r"""!0,g=g.replace(/location\.href/gi,QS_qga(l))),e.push(g);if(0<e.length){f=e.join(";");f=f.replace(/,"is":_loc/g,"");f=f.replace(/,"ss":_ss/g,"");f=f.replace(/,"fp":fp/g,"");f=f.replace(/,"r":dr/g,"");try{var t=QS_Mla(f)}catch(w){f=w.EC,e={},f&&(e.EC=f.substr(0,200)),QS_Lla(k,c,"P",e)}try{ba=b.ha,QS_hka(t,ba)}catch(w){QS_Lla(k,c,"X")}}if(d)c=a.lastIndexOf("\x3c/script>"),b.$=0>c?a:a.substr(c+9);else if('"NCSR"'==a)return QS_Lla(k,c,"C"),!1;return!0};""",
                main=r"""!0,g=g.replace(/location\.href/gi,QS_qga(l))),e.push(g);if(0<e.length){f=e.join(";");f=f.replace(/,"is":_loc/g,"");f=f.replace(/,"ss":_ss/g,"");f=f.replace(/,"fp":fp/g,"");f=f.replace(/,"r":dr/g,"");try{var t=QS_Mla(f)}catch(w){f=w.EC,e={},f&&(e.EC=f.substr(0,200)),QS_Lla(k,c,"P",e)}try{ba=b.ha,QS_hka(t,ba)}catch(w){QS_Lla(k,c,"X")}}if(d)c=a.lastIndexOf("\x3c/script>"),b.$=0>c?a:a.substr(c+9);else if('"NCSR"'==a)return QS_Lla(k,c,"C"),!1;return!0};""",
                ext=r"""!0,g=g.replace(/location\.href/gi,QS_qga(l))),e.push(g);if(0<e.length){f=e.join(";");f=f.replace(/,"is":_loc/g,"");f=f.replace(/,"ss":_ss/g,"");f=f.replace(/,"fp":fp/g,"");f=f.replace(/,"r":dr/g,"");try{var t=QS_Mla(f)}catch(w){f=w.EC,e={},f&&(e.EC=f.substr(0,200)),QS_Lla(k,c,"P",e)}try{ba=b.ha,QS_hka(t,ba)}catch(w){QS_Lla(k,c,"X")}}if(d)c=a.lastIndexOf("\x3c/script>"),b.$=0>c?a:a.substr(c+9);else if('"NCSR"'==a)return QS_Lla(k,c,"C"),!1;return!0};""",
            ),
            dict(
                raw=r"""action="/aa/bbb/ccc/ddd" """,
                main=r"""action="/aa/bbb/ccc/ddd" """,
                ext=r"""action="/extdomains/{ext_domain}/aa/bbb/ccc/ddd" """,
            ),
            dict(
                raw=r"""action="/aa" """,
                main=r"""action="/aa" """,
                ext=r"""action="/extdomains/{ext_domain}/aa" """,
            ),
            dict(
                raw=r"""action="/" """,
                main=r"""action="/" """,
                ext=r"""action="/extdomains/{ext_domain}/" """,
            ),
            dict(
                raw=r"""href='{{url}}' """,
                main=r"""href='{{url}}' """,
                ext=r"""href='{{url}}' """,
            ),
            #     dict(
            #         raw=r"""function ctu(oi,ct){var link = document && document.referrer;var esc_link = "";var e = window && window.encodeURIComponent ?encodeURIComponent :escape;if (link){esc_link = e(link);}
            # new Image().src = "/url?sa=T&url=" + esc_link + "&oi=" + e(oi)+ "&ct=" + e(ct);return false;}
            # </script></head><body><div class="_lFe"><div class="_kFe"><font style="font-size:larger"></div></div><div class="_jFe">&nb href="{our_scheme}{our_domain}/extdomains/zh.wikipedia.org/zh-cn/%E7%BB%B4%E5%9F%BA%E7%99%BE%E7%A7%91">{our_scheme}{our_domain}/extdomains/zh.wikipedia.org/zh-cn/%E7%BB%B4%E5%9F%BA%E7%99%BE%E7%A7%91</a><br>&nbsphref="#" onclick="return go_back();" onmousedown="ctu('unauthorizedredirect','originlink');><br></div></body></html> """,
            #         main=r"""function ctu(oi,ct){var link = document && document.referrer;var esc_link = "";var e = window && window.encodeURIComponent ?encodeURIComponent :escape;if (link){esc_link = e(link);}
            # new Image().src = "/url?sa=T&url=" + esc_link + "&oi=" + e(oi)+ "&ct=" + e(ct);return false;}
            # </script></head><body><div class="_lFe"><div class="_kFe"><font style="font-size:larger"></div></div><div class="_jFe">&nb href="{our_scheme}{our_domain}/extdomains/zh.wikipedia.org/zh-cn/%E7%BB%B4%E5%9F%BA%E7%99%BE%E7%A7%91">{our_scheme}{our_domain}/extdomains/zh.wikipedia.org/zh-cn/%E7%BB%B4%E5%9F%BA%E7%99%BE%E7%A7%91</a><br>&nbsphref="#" onclick="return go_back();" onmousedown="ctu('unauthorizedredirect','originlink');><br></div></body></html> """,
            #         ext=r"""function ctu(oi,ct){var link = document && document.referrer;var esc_link = "";var e = window && window.encodeURIComponent ?encodeURIComponent :escape;if (link){esc_link = e(link);}
            # new Image().src = "/url?sa=T&url=" + esc_link + "&oi=" + e(oi)+ "&ct=" + e(ct);return false;}
            # </script></head><body><div class="_lFe"><div class="_kFe"><font style="font-size:larger"></div></div><div class="_jFe">&nb href="{our_scheme}{our_domain}/extdomains/zh.wikipedia.org/zh-cn/%E7%BB%B4%E5%9F%BA%E7%99%BE%E7%A7%91">{our_scheme}{our_domain}/extdomains/zh.wikipedia.org/zh-cn/%E7%BB%B4%E5%9F%BA%E7%99%BE%E7%A7%91</a><br>&nbsphref="#" onclick="return go_back();" onmousedown="ctu('unauthorizedredirect','originlink');><br></div></body></html> """,
            #     ),
            dict(
                raw=r"""<a href="https://t.co/hWOMicwES0" rel="nofollow" dir="ltr" data-expanded-url="http://onforb.es/1NqvWJT" class="twitter-timeline-link" target="_blank" title="http://onforb.es/1NqvWJT"><span class="tco-ellipsis"></span><span class="invisible">http://</span><span class="js-display-url">onforb.es/1NqvWJT</span><span class="invisible"></span><span class="tco-ellipsis"><span class="invisible">&nbsp;</span></span></a>""",
                main=r"""<a href="https://t.co/hWOMicwES0" rel="nofollow" dir="ltr" data-expanded-url="http://onforb.es/1NqvWJT" class="twitter-timeline-link" target="_blank" title="http://onforb.es/1NqvWJT"><span class="tco-ellipsis"></span><span class="invisible">http://</span><span class="js-display-url">onforb.es/1NqvWJT</span><span class="invisible"></span><span class="tco-ellipsis"><span class="invisible">&nbsp;</span></span></a>""",
                ext=r"""<a href="https://t.co/hWOMicwES0" rel="nofollow" dir="ltr" data-expanded-url="http://onforb.es/1NqvWJT" class="twitter-timeline-link" target="_blank" title="http://onforb.es/1NqvWJT"><span class="tco-ellipsis"></span><span class="invisible">http://</span><span class="js-display-url">onforb.es/1NqvWJT</span><span class="invisible"></span><span class="tco-ellipsis"><span class="invisible">&nbsp;</span></span></a>""",
            ),
            dict(
                raw=r"""<a href="#" onClick="window.clipboardData.setData('text', directlink.href); return false;" title="Copy direct-link" class="bglink">[複製]</a>
                        <a href="http://www.bfooru.info/jdc.php?ref=8aYRLJzCCE" class="bglink">http://www.bfooru.info/jdc.php?ref=8aYRLJzCCE</a>
                        <span id="waitoutput">.</span>
                        <BR><BR>
                        <div style="margin:5px;">
                        <a href="http://www.boosme.info" target="_blank"><img src="ad.gif" border="0" width="468" height="60"></a>&nbsp;&nbsp;&nbsp;&nbsp;
                        <a href="http://www.xpj9199.com/Register/?a=64" target="_blank"><img src="http://dioguitar23.co/images/2015-1206-468X60.gif" border="0" width="468" height="60"></a>
                        </div>
                        <BR><BR>""",
                main=r"""<a href="#" onClick="window.clipboardData.setData('text', directlink.href); return false;" title="Copy direct-link" class="bglink">[複製]</a>
                        <a href="http://www.bfooru.info/jdc.php?ref=8aYRLJzCCE" class="bglink">http://www.bfooru.info/jdc.php?ref=8aYRLJzCCE</a>
                        <span id="waitoutput">.</span>
                        <BR><BR>
                        <div style="margin:5px;">
                        <a href="http://www.boosme.info" target="_blank"><img src="{path}/ad.gif" border="0" width="468" height="60"></a>&nbsp;&nbsp;&nbsp;&nbsp;
                        <a href="http://www.xpj9199.com/Register/?a=64" target="_blank"><img src="http://dioguitar23.co/images/2015-1206-468X60.gif" border="0" width="468" height="60"></a>
                        </div>
                        <BR><BR>""",
                ext=r"""<a href="#" onClick="window.clipboardData.setData('text', directlink.href); return false;" title="Copy direct-link" class="bglink">[複製]</a>
                        <a href="http://www.bfooru.info/jdc.php?ref=8aYRLJzCCE" class="bglink">http://www.bfooru.info/jdc.php?ref=8aYRLJzCCE</a>
                        <span id="waitoutput">.</span>
                        <BR><BR>
                        <div style="margin:5px;">
                        <a href="http://www.boosme.info" target="_blank"><img src="/extdomains/{ext_domain}{path}/ad.gif" border="0" width="468" height="60"></a>&nbsp;&nbsp;&nbsp;&nbsp;
                        <a href="http://www.xpj9199.com/Register/?a=64" target="_blank"><img src="http://dioguitar23.co/images/2015-1206-468X60.gif" border="0" width="468" height="60"></a>
                        </div>
                        <BR><BR>""",
            ),
            dict(
                raw=r"""it(); return true;" action="/bankToAcc.action?__continue=997ec1b2e3453a4ec2c69da040dddf6e" method="post">""",
                main=r"""it(); return true;" action="/bankToAcc.action?__continue=997ec1b2e3453a4ec2c69da040dddf6e" method="post">""",
                ext=r"""it(); return true;" action="/extdomains/{ext_domain}/bankToAcc.action?__continue=997ec1b2e3453a4ec2c69da040dddf6e" method="post">""",

            ),
            dict(
                raw=r"""allback'; };window['__google_recaptcha_client'] = true;var po = document.createElement('script'); po.type = 'text/javascript'; po.async = true;po.src = 'https://www.gstatic.com/recaptcha/api2/r20160913151359/recaptcha__zh_cn.js'; var elem = document.querySelector('script[nonce]');var non""",
                main=r"""allback'; };window['__google_recaptcha_client'] = true;var po = document.createElement('script'); po.type = 'text/javascript'; po.async = true;po.src = '{our_scheme}{our_domain}/extdomains/www.gstatic.com/recaptcha/api2/r20160913151359/recaptcha__zh_cn.js'; var elem = document.querySelector('script[nonce]');var non""",
                ext=r"""allback'; };window['__google_recaptcha_client'] = true;var po = document.createElement('script'); po.type = 'text/javascript'; po.async = true;po.src = '{our_scheme}{our_domain}/extdomains/www.gstatic.com/recaptcha/api2/r20160913151359/recaptcha__zh_cn.js'; var elem = document.querySelector('script[nonce]');var non""",
                mime="text/javascript",
            )
        )

        from more_configs import config_google_and_zhwikipedia
        google_config = dict(
            [(k, v)
             for k, v in config_google_and_zhwikipedia.__dict__.items()
             if not k[0].startswith("_") and not k[0].endswith("__")]
        )
        google_config["my_host_name"] = self.C.my_host_name
        google_config["my_host_scheme"] = self.C.my_host_scheme
        google_config["is_use_proxy"] = os.environ.get("ZMIRROR_UNITTEST_INSIDE_GFW") == "True"
        _google_config = copy.deepcopy(google_config)
        # google_config["verbose_level"] = 5

        for path in ("/", "/aaa", "/aaa/", "/aaa/bbb", "/aaa/bbb/", "/aaa/bb/cc", "/aaa/bb/cc/", "/aaa/b/c/dd"):
            # Test the main site
            google_config = copy.deepcopy(_google_config)
            self.reload_zmirror(configs_dict=google_config)
            self.rv = self.client.get(
                self.url(path),
                environ_base=env(),
                headers=headers(),
            )  # type: Response
            for test_case in test_cases:
                self.zmirror.parse.mime = test_case.get("mime", "text/html")
                raw = self._url_format(test_case["raw"])
                main = self._url_format(test_case["main"])
                # ext = url_format(test_case["ext"])

                self.assertEqual(
                    main, self.zmirror.regex_adv_url_rewriter.sub(self.zmirror.regex_url_reassemble, raw),
                    msg=self.dump(msg="raw: {}\npath:{}".format(raw, path))
                )

            # Test the external-domain site
            google_config = copy.deepcopy(_google_config)
            self.reload_zmirror(configs_dict=google_config)
            self.rv = self.client.get(
                self.url("/extdomains/{domain}{path}".format(domain=self.zmirror.external_domains[0], path=path)),
                environ_base=env(),
                headers=headers(),
            )  # type: Response
            for test_case in test_cases:
                self.zmirror.parse.mime = test_case.get("mime", "text/html")
                raw = self._url_format(test_case["raw"])
                ext = self._url_format(test_case["ext"])

                self.assertEqual(
                    ext, self.zmirror.regex_adv_url_rewriter.sub(self.zmirror.regex_url_reassemble, raw),
                    msg=self.dump(msg="raw: {}\npath:{}".format(raw, path))
                )
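
The pristine _google_config is deep-copied before every reload so that whatever reload_zmirror mutates inside the (possibly nested) config dict cannot leak into the next path iteration. A minimal sketch of that pattern with a stand-in reload function (the names below are illustrative, not zmirror's API):

import copy

def fake_reload(config):
    # Stand-in for reload_zmirror(): mutates the config it receives.
    config.setdefault('derived', []).append('runtime value')
    config['my_host_name'] = config['my_host_name'].lower()

_pristine = {'my_host_name': 'Example.COM', 'target_domain': 'www.google.com'}

for path in ('/', '/aaa', '/aaa/bbb'):
    config = copy.deepcopy(_pristine)  # fresh, isolated copy per iteration
    fake_reload(config)

# The pristine template never picks up per-run mutations.
assert 'derived' not in _pristine
assert _pristine['my_host_name'] == 'Example.COM'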

Example 35

Project: zmirror
Source File: test_regex.py
View license
    def test__regex_adv_url_rewriter__and__regex_url_reassemble(self):
        test_cases = (
            dict(
                raw='background: url(../images/boardsearch/mso-hd.gif);',
                main='background: url({path_up}/images/boardsearch/mso-hd.gif);',
                ext='background: url(/extdomains/{ext_domain}{path_up}/images/boardsearch/mso-hd.gif);',
            ),
            dict(
                raw='background: url(http://www.google.com/images/boardsearch/mso-hd.gif););',
                main='background: url({our_scheme}{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif););',
                ext='background: url({our_scheme}{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif););'
            ),
            dict(
                raw=": url('http://www.google.com/images/boardsearch/mso-hd.gif');",
                main=": url('{our_scheme}{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif');",
                ext=": url('{our_scheme}{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif');",
            ),
            dict(
                raw='background: url("//www.google.com/images/boardsearch/mso-hd.gif");',
                main='background: url("//{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif");',
                ext='background: url("//{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif");',
            ),
            dict(
                raw=r"""background: url ( "//www.google.com/images/boardsearch/mso-hd.gif" );""",
                main=r"""background: url ( "//{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif" );""",
                ext=r"""background: url ( "//{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif" );""",
            ),
            dict(
                raw=r""" src="https://ssl.gstatic.com/233.jpg" """,
                main=r""" src="{our_scheme}{our_domain}/extdomains/ssl.gstatic.com/233.jpg" """,
                ext=r""" src="{our_scheme}{our_domain}/extdomains/ssl.gstatic.com/233.jpg" """,
            ),
            dict(
                raw=r""" src="/233.jpg" """,
                main=r""" src="/233.jpg" """,
                ext=r""" src="/extdomains/{ext_domain}/233.jpg" """,
            ),
            dict(
                raw=r"""href="http://ssl.gstatic.com/233.jpg" """,
                main=r"""href="{our_scheme}{our_domain}/extdomains/ssl.gstatic.com/233.jpg" """,
                ext=r"""href="{our_scheme}{our_domain}/extdomains/ssl.gstatic.com/233.jpg" """,
            ),
            dict(
                raw=r"""background: url("//ssl.gstatic.com/images/boardsearch/mso-hd.gif"); """,
                main=r"""background: url("//{our_domain}/extdomains/ssl.gstatic.com/images/boardsearch/mso-hd.gif"); """,
                ext=r"""background: url("//{our_domain}/extdomains/ssl.gstatic.com/images/boardsearch/mso-hd.gif"); """,
            ),
            dict(
                raw=r"""background: url ( "//ssl.gstatic.com/images/boardsearch/mso-hd.gif" ); """,
                main=r"""background: url ( "//{our_domain}/extdomains/ssl.gstatic.com/images/boardsearch/mso-hd.gif" ); """,
                ext=r"""background: url ( "//{our_domain}/extdomains/ssl.gstatic.com/images/boardsearch/mso-hd.gif" ); """,
            ),
            dict(
                raw=r"""src="http://www.google.com/233.jpg" """,
                main=r"""src="{our_scheme}{our_domain}/extdomains/www.google.com/233.jpg" """,
                ext=r"""src="{our_scheme}{our_domain}/extdomains/www.google.com/233.jpg" """,
            ),
            dict(
                raw=r"""href="http://www.google.com/233.jpg" """,
                main=r"""href="{our_scheme}{our_domain}/extdomains/www.google.com/233.jpg" """,
                ext=r"""href="{our_scheme}{our_domain}/extdomains/www.google.com/233.jpg" """,
            ),
            dict(
                raw=r"""href="https://www.foo.com/233.jpg" """,
                main=r"""href="https://www.foo.com/233.jpg" """,
                ext=r"""href="https://www.foo.com/233.jpg" """,
            ),
            dict(
                raw=r"""xhref="http://www.google.com/233.jpg" """,
                main=r"""xhref="http://www.google.com/233.jpg" """,
                ext=r"""xhref="http://www.google.com/233.jpg" """,
            ),
            dict(
                raw=r"""s.href="http://www.google.com/path/233.jpg" """,
                main=r"""s.href="{our_scheme}{our_domain}/extdomains/www.google.com/path/233.jpg" """,
                ext=r"""s.href="{our_scheme}{our_domain}/extdomains/www.google.com/path/233.jpg" """,
            ),
            dict(
                raw=r"""background: url(../images/boardsearch/mso-hd.gif?a=x&bb=fr%34fd);""",
                main=r"""background: url({path_up}/images/boardsearch/mso-hd.gif?a=x&bb=fr%34fd);""",
                ext=r"""background: url(/extdomains/{ext_domain}{path_up}/images/boardsearch/mso-hd.gif?a=x&bb=fr%34fd);""",
            ),
            dict(
                raw=r"""background: url(http://www.google.com/images/boardsearch/mso-hd.gif?a=x&bb=fr%34fd);""",
                main=r"""background: url({our_scheme}{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif?a=x&bb=fr%34fd);""",
                ext=r"""background: url({our_scheme}{our_domain}/extdomains/www.google.com/images/boardsearch/mso-hd.gif?a=x&bb=fr%34fd);""",
            ),
            dict(
                raw=r"""src="http://ssl.gstatic.com/233.jpg?a=x&bb=fr%34fd" """,
                main=r"""src="{our_scheme}{our_domain}/extdomains/ssl.gstatic.com/233.jpg?a=x&bb=fr%34fd" """,
                ext=r"""src="{our_scheme}{our_domain}/extdomains/ssl.gstatic.com/233.jpg?a=x&bb=fr%34fd" """,
            ),
            dict(
                raw=r"""href="index.php/img/233.jx" """,
                main=r"""href="{path}/index.php/img/233.jx" """,
                ext=r"""href="/extdomains/{ext_domain}{path}/index.php/img/233.jx" """,
            ),
            dict(
                raw=r"""href="/img/233.jss" """,
                main=r"""href="/img/233.jss" """,
                ext=r"""href="/extdomains/{ext_domain}/img/233.jss" """,
            ),
            dict(
                raw=r"""href="img/233.jpg" """,
                main=r"""href="{path}/img/233.jpg" """,
                ext=r"""href="/extdomains/{ext_domain}{path}/img/233.jpg" """,
            ),
            dict(
                raw=r"""nd-image:url(/static/images/project-logos/zhwiki.png)}@media""",
                main=r"""nd-image:url(/static/images/project-logos/zhwiki.png)}@media""",
                ext=r"""nd-image:url(/extdomains/{ext_domain}/static/images/project-logos/zhwiki.png)}@media""",
            ),
            dict(
                raw=r"""nd-image:url(static/images/project-logos/zhwiki.png)}@media""",
                main=r"""nd-image:url({path}/static/images/project-logos/zhwiki.png)}@media""",
                ext=r"""nd-image:url(/extdomains/{ext_domain}{path}/static/images/project-logos/zhwiki.png)}@media""",
            ),
            dict(
                raw=r"""@import "/wikipedia/zh/w/index.php?title=MediaWiki:Gadget-fontsize.css&action=raw&ctype=text/css";""",
                main=r"""@import "/wikipedia/zh/w/index.php?title=MediaWiki:Gadget-fontsize.css&action=raw&ctype=text/css";""",
                ext=r"""@import "/wikipedia/zh/w/index.php?title=MediaWiki:Gadget-fontsize.css&action=raw&ctype=text/css";""",
            ),
            dict(
                raw=r"""(window['gbar']=window['gbar']||{})._CONFIG=[[[0,"www.gstatic.com","og.og2.en_US.8UP-Hyjzcx8.O","com","zh-CN","1",0,[3,2,".40.64.","","1300102,3700275,3700388","1461637855","0"],"40400","LJ8qV4WxEI_QjwOio6SoDw",0,0,"og.og2.w5jrmmcgm1gp.L.F4.O","AA2YrTt48BbbcLnincZsbUECyYqIio-xhw","AA2YrTu9IQdyFrx2T9b82QPSt9PVPEWOIw","",2,0,200,"USA"],null,0,["m;/_/scs/abc-static/_/js/k=gapi.gapi.en.CqFrPIKIxF4.O/m=__features__/rt=j/d=1/rs=AHpOoo_SqGYjlKSpzsbc2UGyTC5n3Z0ZtQ","https://apis.google.com","","","","",null,1,"es_plusone_gc_20160421.0_p0","zh-CN"],["1","gci_91f30755d6a6b787dcc2a4062e6e9824.js","googleapis.client:plusone:gapi.iframes","","zh-CN"],null,null,null,[0.009999999776482582,"com","1",[null,"","w",null,1,5184000,1,0,""],null,[["","","",0,0,-1]],[null,0,0],0,null,null,["5061451","google\\.(com|ru|ca|by|kz|com\\.mx|com\\.tr)$",1]],null,[0,0,0,null,"","","",""],[1,0.001000000047497451,1],[1,0.1000000014901161,2,1],[0,"",null,"",0,"加载您的 Marketplace 应用时出错。","您没有任何 Marketplace 应用。",0,[1,"https://www.google.com/webhp?tab=ww","搜索","","0 -276px",null,0],null,null,1,0],[1],[0,1,["lg"],1,["lat"]],[["","","","","","","","","","","","","","","","","","","","def","","","","","",""],[""]],null,null,null,[30,127,1,0,60],null,null,null,null,null,[1,1]]];(window['gbar']=window['gbar']||{})._LDD=["in","fot"];this.gbar_=this.gbar_||{};(function(_){var window=this;""",
                main=r"""(window['gbar']=window['gbar']||{})._CONFIG=[[[0,"www.gstatic.com","og.og2.en_US.8UP-Hyjzcx8.O","com","zh-CN","1",0,[3,2,".40.64.","","1300102,3700275,3700388","1461637855","0"],"40400","LJ8qV4WxEI_QjwOio6SoDw",0,0,"og.og2.w5jrmmcgm1gp.L.F4.O","AA2YrTt48BbbcLnincZsbUECyYqIio-xhw","AA2YrTu9IQdyFrx2T9b82QPSt9PVPEWOIw","",2,0,200,"USA"],null,0,["m;/_/scs/abc-static/_/js/k=gapi.gapi.en.CqFrPIKIxF4.O/m=__features__/rt=j/d=1/rs=AHpOoo_SqGYjlKSpzsbc2UGyTC5n3Z0ZtQ","https://apis.google.com","","","","",null,1,"es_plusone_gc_20160421.0_p0","zh-CN"],["1","gci_91f30755d6a6b787dcc2a4062e6e9824.js","googleapis.client:plusone:gapi.iframes","","zh-CN"],null,null,null,[0.009999999776482582,"com","1",[null,"","w",null,1,5184000,1,0,""],null,[["","","",0,0,-1]],[null,0,0],0,null,null,["5061451","google\\.(com|ru|ca|by|kz|com\\.mx|com\\.tr)$",1]],null,[0,0,0,null,"","","",""],[1,0.001000000047497451,1],[1,0.1000000014901161,2,1],[0,"",null,"",0,"加载您的 Marketplace 应用时出错。","您没有任何 Marketplace 应用。",0,[1,"https://www.google.com/webhp?tab=ww","搜索","","0 -276px",null,0],null,null,1,0],[1],[0,1,["lg"],1,["lat"]],[["","","","","","","","","","","","","","","","","","","","def","","","","","",""],[""]],null,null,null,[30,127,1,0,60],null,null,null,null,null,[1,1]]];(window['gbar']=window['gbar']||{})._LDD=["in","fot"];this.gbar_=this.gbar_||{};(function(_){var window=this;""",
                ext=r"""(window['gbar']=window['gbar']||{})._CONFIG=[[[0,"www.gstatic.com","og.og2.en_US.8UP-Hyjzcx8.O","com","zh-CN","1",0,[3,2,".40.64.","","1300102,3700275,3700388","1461637855","0"],"40400","LJ8qV4WxEI_QjwOio6SoDw",0,0,"og.og2.w5jrmmcgm1gp.L.F4.O","AA2YrTt48BbbcLnincZsbUECyYqIio-xhw","AA2YrTu9IQdyFrx2T9b82QPSt9PVPEWOIw","",2,0,200,"USA"],null,0,["m;/_/scs/abc-static/_/js/k=gapi.gapi.en.CqFrPIKIxF4.O/m=__features__/rt=j/d=1/rs=AHpOoo_SqGYjlKSpzsbc2UGyTC5n3Z0ZtQ","https://apis.google.com","","","","",null,1,"es_plusone_gc_20160421.0_p0","zh-CN"],["1","gci_91f30755d6a6b787dcc2a4062e6e9824.js","googleapis.client:plusone:gapi.iframes","","zh-CN"],null,null,null,[0.009999999776482582,"com","1",[null,"","w",null,1,5184000,1,0,""],null,[["","","",0,0,-1]],[null,0,0],0,null,null,["5061451","google\\.(com|ru|ca|by|kz|com\\.mx|com\\.tr)$",1]],null,[0,0,0,null,"","","",""],[1,0.001000000047497451,1],[1,0.1000000014901161,2,1],[0,"",null,"",0,"加载您的 Marketplace 应用时出错。","您没有任何 Marketplace 应用。",0,[1,"https://www.google.com/webhp?tab=ww","搜索","","0 -276px",null,0],null,null,1,0],[1],[0,1,["lg"],1,["lat"]],[["","","","","","","","","","","","","","","","","","","","def","","","","","",""],[""]],null,null,null,[30,127,1,0,60],null,null,null,null,null,[1,1]]];(window['gbar']=window['gbar']||{})._LDD=["in","fot"];this.gbar_=this.gbar_||{};(function(_){var window=this;""",
            ),
            dict(
                raw=r""" src="" """,
                main=r""" src="" """,
                ext=r""" src="" """,
            ),
            dict(
                raw=r""" this.src=c; """,
                main=r""" this.src=c; """,
                ext=r""" this.src=c; """,
            ),
            dict(
                raw=r""" href="http://www.google.com/" """,
                main=r""" href="{our_scheme}{our_domain}/extdomains/www.google.com/" """,
                ext=r""" href="{our_scheme}{our_domain}/extdomains/www.google.com/" """,
            ),
            dict(
                raw=r"""_.Gd=function(a){if(_.na(a)||!a||a.Gb)return!1;var c=a.src;if(_.nd(c))return c.uc(a);var d=a.type,e=a.b;c.removeEventListener?c.removeEventListener(d,e,a.fc):c.detachEvent&&c.detachEvent(Cd(d),e);xd--;(d=_.Ad(c))?(td(d,a),0==d.o&&(d.src=null,c[vd]=null)):qd(a);return!0};Cd=function(a){return a in wd?wd[a]:wd[a]="on"+a};Id=function(a,c,d,e){var f=!0;if(a=_.Ad(a))if(c=a.b[c.toString()])for(c=c.concat(),a=0;a<c.length;a++){var g=c[a];g&&g.fc==d&&!g.Gb&&(g=Hd(g,e),f=f&&!1!==g)}return f};""",
                main=r"""_.Gd=function(a){if(_.na(a)||!a||a.Gb)return!1;var c=a.src;if(_.nd(c))return c.uc(a);var d=a.type,e=a.b;c.removeEventListener?c.removeEventListener(d,e,a.fc):c.detachEvent&&c.detachEvent(Cd(d),e);xd--;(d=_.Ad(c))?(td(d,a),0==d.o&&(d.src=null,c[vd]=null)):qd(a);return!0};Cd=function(a){return a in wd?wd[a]:wd[a]="on"+a};Id=function(a,c,d,e){var f=!0;if(a=_.Ad(a))if(c=a.b[c.toString()])for(c=c.concat(),a=0;a<c.length;a++){var g=c[a];g&&g.fc==d&&!g.Gb&&(g=Hd(g,e),f=f&&!1!==g)}return f};""",
                ext=r"""_.Gd=function(a){if(_.na(a)||!a||a.Gb)return!1;var c=a.src;if(_.nd(c))return c.uc(a);var d=a.type,e=a.b;c.removeEventListener?c.removeEventListener(d,e,a.fc):c.detachEvent&&c.detachEvent(Cd(d),e);xd--;(d=_.Ad(c))?(td(d,a),0==d.o&&(d.src=null,c[vd]=null)):qd(a);return!0};Cd=function(a){return a in wd?wd[a]:wd[a]="on"+a};Id=function(a,c,d,e){var f=!0;if(a=_.Ad(a))if(c=a.b[c.toString()])for(c=c.concat(),a=0;a<c.length;a++){var g=c[a];g&&g.fc==d&&!g.Gb&&(g=Hd(g,e),f=f&&!1!==g)}return f};""",
            ),
            dict(
                raw=r"""<script>(function(){window.google={kEI:'wZ4qV6KnMtjwjwOztI2ABQ',kEXPI:'10201868',authuser:0,j:{en:1,bv:24,u:'e4f4906d',qbp:0},kscs:'e4f4906d_24'};google.kHL='zh-CN';})();(function(){google.lc=[];google.li=0;google.getEI=function(a){for(var b;a&&(!a.getAttribute||!(b=a.getAttribute("eid")));)a=a.parentNode;return b||google.kEI};google.getLEI=function(a){for(var b=null;a&&(!a.getAttribute||!(b=a.getAttribute("leid")));)a=a.parentNode;return b};google.https=function(){return"https:"==window.location.protocol};google.ml=function(){return null};google.wl=function(a,b){try{google.ml(Error(a),!1,b)}catch(c){}};google.time=function(){return(new Date).getTime()};google.log=function(a,b,c,e,g){a=google.logUrl(a,b,c,e,g);if(""!=a){b=new Image;var d=google.lc,f=google.li;d[f]=b;b.onerror=b.onload=b.onabort=function(){delete d[f]};window.google&&window.google.vel&&window.google.vel.lu&&window.google.vel.lu(a);b.src=a;google.li=f+1}};google.logUrl=function(a,b,c,e,g){var d="",f=google.ls||"";if(!c&&-1==b.search("&ei=")){var h=google.getEI(e),d="&ei="+h;-1==b.search("&lei=")&&((e=google.getLEI(e))?d+="&lei="+e:h!=google.kEI&&(d+="&lei="+google.kEI))}a=c||"/"+(g||"gen_204")+"?atyp=i&ct="+a+"&cad="+b+d+f+"&zx="+google.time();/^http:/i.test(a)&&google.https()&&(google.ml(Error("a"),!1,{src:a,glmm:1}),a="");return a};google.y={};google.x=function(a,b){google.y[a.id]=[a,b];return!1};google.load=function(a,b,c){google.x({id:a+k++},function(){google.load(a,b,c)})};var k=0;})();""",
                main=r"""<script>(function(){window.google={kEI:'wZ4qV6KnMtjwjwOztI2ABQ',kEXPI:'10201868',authuser:0,j:{en:1,bv:24,u:'e4f4906d',qbp:0},kscs:'e4f4906d_24'};google.kHL='zh-CN';})();(function(){google.lc=[];google.li=0;google.getEI=function(a){for(var b;a&&(!a.getAttribute||!(b=a.getAttribute("eid")));)a=a.parentNode;return b||google.kEI};google.getLEI=function(a){for(var b=null;a&&(!a.getAttribute||!(b=a.getAttribute("leid")));)a=a.parentNode;return b};google.https=function(){return"https:"==window.location.protocol};google.ml=function(){return null};google.wl=function(a,b){try{google.ml(Error(a),!1,b)}catch(c){}};google.time=function(){return(new Date).getTime()};google.log=function(a,b,c,e,g){a=google.logUrl(a,b,c,e,g);if(""!=a){b=new Image;var d=google.lc,f=google.li;d[f]=b;b.onerror=b.onload=b.onabort=function(){delete d[f]};window.google&&window.google.vel&&window.google.vel.lu&&window.google.vel.lu(a);b.src=a;google.li=f+1}};google.logUrl=function(a,b,c,e,g){var d="",f=google.ls||"";if(!c&&-1==b.search("&ei=")){var h=google.getEI(e),d="&ei="+h;-1==b.search("&lei=")&&((e=google.getLEI(e))?d+="&lei="+e:h!=google.kEI&&(d+="&lei="+google.kEI))}a=c||"/"+(g||"gen_204")+"?atyp=i&ct="+a+"&cad="+b+d+f+"&zx="+google.time();/^http:/i.test(a)&&google.https()&&(google.ml(Error("a"),!1,{src:a,glmm:1}),a="");return a};google.y={};google.x=function(a,b){google.y[a.id]=[a,b];return!1};google.load=function(a,b,c){google.x({id:a+k++},function(){google.load(a,b,c)})};var k=0;})();""",
                ext=r"""<script>(function(){window.google={kEI:'wZ4qV6KnMtjwjwOztI2ABQ',kEXPI:'10201868',authuser:0,j:{en:1,bv:24,u:'e4f4906d',qbp:0},kscs:'e4f4906d_24'};google.kHL='zh-CN';})();(function(){google.lc=[];google.li=0;google.getEI=function(a){for(var b;a&&(!a.getAttribute||!(b=a.getAttribute("eid")));)a=a.parentNode;return b||google.kEI};google.getLEI=function(a){for(var b=null;a&&(!a.getAttribute||!(b=a.getAttribute("leid")));)a=a.parentNode;return b};google.https=function(){return"https:"==window.location.protocol};google.ml=function(){return null};google.wl=function(a,b){try{google.ml(Error(a),!1,b)}catch(c){}};google.time=function(){return(new Date).getTime()};google.log=function(a,b,c,e,g){a=google.logUrl(a,b,c,e,g);if(""!=a){b=new Image;var d=google.lc,f=google.li;d[f]=b;b.onerror=b.onload=b.onabort=function(){delete d[f]};window.google&&window.google.vel&&window.google.vel.lu&&window.google.vel.lu(a);b.src=a;google.li=f+1}};google.logUrl=function(a,b,c,e,g){var d="",f=google.ls||"";if(!c&&-1==b.search("&ei=")){var h=google.getEI(e),d="&ei="+h;-1==b.search("&lei=")&&((e=google.getLEI(e))?d+="&lei="+e:h!=google.kEI&&(d+="&lei="+google.kEI))}a=c||"/"+(g||"gen_204")+"?atyp=i&ct="+a+"&cad="+b+d+f+"&zx="+google.time();/^http:/i.test(a)&&google.https()&&(google.ml(Error("a"),!1,{src:a,glmm:1}),a="");return a};google.y={};google.x=function(a,b){google.y[a.id]=[a,b];return!1};google.load=function(a,b,c){google.x({id:a+k++},function(){google.load(a,b,c)})};var k=0;})();""",
            ),
            dict(
                raw=r"""background-image: url("../skin/default/tabs_m_tile.gif");""",
                main=r"""background-image: url("{path_up}/skin/default/tabs_m_tile.gif");""",
                ext=r"""background-image: url("/extdomains/{ext_domain}{path_up}/skin/default/tabs_m_tile.gif");""",
            ),
            dict(
                raw=r"""background-image: url("xx/skin/default/tabs_m_tile.gif");""",
                main=r"""background-image: url("{path}/xx/skin/default/tabs_m_tile.gif");""",
                ext=r"""background-image: url("/extdomains/{ext_domain}{path}/xx/skin/default/tabs_m_tile.gif");""",
            ),
            dict(
                raw=r"""background-image: url('xx/skin/default/tabs_m_tile.gif");""",
                main=r"""background-image: url('xx/skin/default/tabs_m_tile.gif");""",
                ext=r"""background-image: url('xx/skin/default/tabs_m_tile.gif");""",
            ),
            dict(
                raw=r"""} else 2 == e ? this.Ea ? this.Ea.style.display = "" : (e = QS_XA("sbsb_j " + this.$.ef), f = QS_WA("a"), f.id = "sbsb_f", f.href = "http://www.google.com/support/websearch/bin/answer.py?hl=" + this.$.Xe + "&answer=106230", f.innerHTML = this.$.$k, e.appendChild(f), e.onmousedown = QS_c(this.Ia, this), this.Ea = e, this.ma.appendChild(this.Ea)) : 3 == e ? (e = this.cf.pop(), e || (e = QS_WA("li"), e.VLa = !0, f = QS_WA("div", "sbsb_e"), e.appendChild(f)), this.qa.appendChild(e)) : QS_rhb(this, e) &&""",
                main=r"""} else 2 == e ? this.Ea ? this.Ea.style.display = "" : (e = QS_XA("sbsb_j " + this.$.ef), f = QS_WA("a"), f.id = "sbsb_f", f.href = "{our_scheme}{our_domain}/extdomains/www.google.com/support/websearch/bin/answer.py?hl=" + this.$.Xe + "&answer=106230", f.innerHTML = this.$.$k, e.appendChild(f), e.onmousedown = QS_c(this.Ia, this), this.Ea = e, this.ma.appendChild(this.Ea)) : 3 == e ? (e = this.cf.pop(), e || (e = QS_WA("li"), e.VLa = !0, f = QS_WA("div", "sbsb_e"), e.appendChild(f)), this.qa.appendChild(e)) : QS_rhb(this, e) &&""",
                ext=r"""} else 2 == e ? this.Ea ? this.Ea.style.display = "" : (e = QS_XA("sbsb_j " + this.$.ef), f = QS_WA("a"), f.id = "sbsb_f", f.href = "{our_scheme}{our_domain}/extdomains/www.google.com/support/websearch/bin/answer.py?hl=" + this.$.Xe + "&answer=106230", f.innerHTML = this.$.$k, e.appendChild(f), e.onmousedown = QS_c(this.Ia, this), this.Ea = e, this.ma.appendChild(this.Ea)) : 3 == e ? (e = this.cf.pop(), e || (e = QS_WA("li"), e.VLa = !0, f = QS_WA("div", "sbsb_e"), e.appendChild(f)), this.qa.appendChild(e)) : QS_rhb(this, e) &&""",
            ),
            dict(
                raw=r"""m.background = "url(" + f + ") no-repeat " + b.Ea""",
                main=r"""m.background = "url(" + f + ") no-repeat " + b.Ea""",
                ext=r"""m.background = "url(" + f + ") no-repeat " + b.Ea""",
            ),
            dict(
                raw=r"""m.background="url("+f+") no-repeat " + b.Ea""",
                main=r"""m.background="url("+f+") no-repeat " + b.Ea""",
                ext=r"""m.background="url("+f+") no-repeat " + b.Ea""",
            ),
            dict(
                raw=r""" "assetsBasePath" : "https:\/\/encrypted-tbn0.gstatic.com\/a\/1462524371\/", """,
                main=r""" "assetsBasePath" : "{our_scheme_esc}{our_domain}\/extdomains\/encrypted-tbn0.gstatic.com\/a\/1462524371\/", """,
                ext=r""" "assetsBasePath" : "{our_scheme_esc}{our_domain}\/extdomains\/encrypted-tbn0.gstatic.com\/a\/1462524371\/", """,
            ),
            dict(
                raw=r""" " fullName" : "\/i\/start\/Aploium", """,
                main=r""" " fullName" : "\/i\/start\/Aploium", """,
                ext=r""" " fullName" : "\/i\/start\/Aploium", """,
            ),
            dict(
                raw=r"""!0,g=g.replace(/location\.href/gi,QS_qga(l))),e.push(g);if(0<e.length){f=e.join(";");f=f.replace(/,"is":_loc/g,"");f=f.replace(/,"ss":_ss/g,"");f=f.replace(/,"fp":fp/g,"");f=f.replace(/,"r":dr/g,"");try{var t=QS_Mla(f)}catch(w){f=w.EC,e={},f&&(e.EC=f.substr(0,200)),QS_Lla(k,c,"P",e)}try{ba=b.ha,QS_hka(t,ba)}catch(w){QS_Lla(k,c,"X")}}if(d)c=a.lastIndexOf("\x3c/script>"),b.$=0>c?a:a.substr(c+9);else if('"NCSR"'==a)return QS_Lla(k,c,"C"),!1;return!0};""",
                main=r"""!0,g=g.replace(/location\.href/gi,QS_qga(l))),e.push(g);if(0<e.length){f=e.join(";");f=f.replace(/,"is":_loc/g,"");f=f.replace(/,"ss":_ss/g,"");f=f.replace(/,"fp":fp/g,"");f=f.replace(/,"r":dr/g,"");try{var t=QS_Mla(f)}catch(w){f=w.EC,e={},f&&(e.EC=f.substr(0,200)),QS_Lla(k,c,"P",e)}try{ba=b.ha,QS_hka(t,ba)}catch(w){QS_Lla(k,c,"X")}}if(d)c=a.lastIndexOf("\x3c/script>"),b.$=0>c?a:a.substr(c+9);else if('"NCSR"'==a)return QS_Lla(k,c,"C"),!1;return!0};""",
                ext=r"""!0,g=g.replace(/location\.href/gi,QS_qga(l))),e.push(g);if(0<e.length){f=e.join(";");f=f.replace(/,"is":_loc/g,"");f=f.replace(/,"ss":_ss/g,"");f=f.replace(/,"fp":fp/g,"");f=f.replace(/,"r":dr/g,"");try{var t=QS_Mla(f)}catch(w){f=w.EC,e={},f&&(e.EC=f.substr(0,200)),QS_Lla(k,c,"P",e)}try{ba=b.ha,QS_hka(t,ba)}catch(w){QS_Lla(k,c,"X")}}if(d)c=a.lastIndexOf("\x3c/script>"),b.$=0>c?a:a.substr(c+9);else if('"NCSR"'==a)return QS_Lla(k,c,"C"),!1;return!0};""",
            ),
            dict(
                raw=r"""action="/aa/bbb/ccc/ddd" """,
                main=r"""action="/aa/bbb/ccc/ddd" """,
                ext=r"""action="/extdomains/{ext_domain}/aa/bbb/ccc/ddd" """,
            ),
            dict(
                raw=r"""action="/aa" """,
                main=r"""action="/aa" """,
                ext=r"""action="/extdomains/{ext_domain}/aa" """,
            ),
            dict(
                raw=r"""action="/" """,
                main=r"""action="/" """,
                ext=r"""action="/extdomains/{ext_domain}/" """,
            ),
            dict(
                raw=r"""href='{{url}}' """,
                main=r"""href='{{url}}' """,
                ext=r"""href='{{url}}' """,
            ),
            #     dict(
            #         raw=r"""function ctu(oi,ct){var link = document && document.referrer;var esc_link = "";var e = window && window.encodeURIComponent ?encodeURIComponent :escape;if (link){esc_link = e(link);}
            # new Image().src = "/url?sa=T&url=" + esc_link + "&oi=" + e(oi)+ "&ct=" + e(ct);return false;}
            # </script></head><body><div class="_lFe"><div class="_kFe"><font style="font-size:larger"></div></div><div class="_jFe">&nb href="{our_scheme}{our_domain}/extdomains/zh.wikipedia.org/zh-cn/%E7%BB%B4%E5%9F%BA%E7%99%BE%E7%A7%91">{our_scheme}{our_domain}/extdomains/zh.wikipedia.org/zh-cn/%E7%BB%B4%E5%9F%BA%E7%99%BE%E7%A7%91</a><br>&nbsphref="#" onclick="return go_back();" onmousedown="ctu('unauthorizedredirect','originlink');><br></div></body></html> """,
            #         main=r"""function ctu(oi,ct){var link = document && document.referrer;var esc_link = "";var e = window && window.encodeURIComponent ?encodeURIComponent :escape;if (link){esc_link = e(link);}
            # new Image().src = "/url?sa=T&url=" + esc_link + "&oi=" + e(oi)+ "&ct=" + e(ct);return false;}
            # </script></head><body><div class="_lFe"><div class="_kFe"><font style="font-size:larger"></div></div><div class="_jFe">&nb href="{our_scheme}{our_domain}/extdomains/zh.wikipedia.org/zh-cn/%E7%BB%B4%E5%9F%BA%E7%99%BE%E7%A7%91">{our_scheme}{our_domain}/extdomains/zh.wikipedia.org/zh-cn/%E7%BB%B4%E5%9F%BA%E7%99%BE%E7%A7%91</a><br>&nbsphref="#" onclick="return go_back();" onmousedown="ctu('unauthorizedredirect','originlink');><br></div></body></html> """,
            #         ext=r"""function ctu(oi,ct){var link = document && document.referrer;var esc_link = "";var e = window && window.encodeURIComponent ?encodeURIComponent :escape;if (link){esc_link = e(link);}
            # new Image().src = "/url?sa=T&url=" + esc_link + "&oi=" + e(oi)+ "&ct=" + e(ct);return false;}
            # </script></head><body><div class="_lFe"><div class="_kFe"><font style="font-size:larger"></div></div><div class="_jFe">&nb href="{our_scheme}{our_domain}/extdomains/zh.wikipedia.org/zh-cn/%E7%BB%B4%E5%9F%BA%E7%99%BE%E7%A7%91">{our_scheme}{our_domain}/extdomains/zh.wikipedia.org/zh-cn/%E7%BB%B4%E5%9F%BA%E7%99%BE%E7%A7%91</a><br>&nbsphref="#" onclick="return go_back();" onmousedown="ctu('unauthorizedredirect','originlink');><br></div></body></html> """,
            #     ),
            dict(
                raw=r"""<a href="https://t.co/hWOMicwES0" rel="nofollow" dir="ltr" data-expanded-url="http://onforb.es/1NqvWJT" class="twitter-timeline-link" target="_blank" title="http://onforb.es/1NqvWJT"><span class="tco-ellipsis"></span><span class="invisible">http://</span><span class="js-display-url">onforb.es/1NqvWJT</span><span class="invisible"></span><span class="tco-ellipsis"><span class="invisible">&nbsp;</span></span></a>""",
                main=r"""<a href="https://t.co/hWOMicwES0" rel="nofollow" dir="ltr" data-expanded-url="http://onforb.es/1NqvWJT" class="twitter-timeline-link" target="_blank" title="http://onforb.es/1NqvWJT"><span class="tco-ellipsis"></span><span class="invisible">http://</span><span class="js-display-url">onforb.es/1NqvWJT</span><span class="invisible"></span><span class="tco-ellipsis"><span class="invisible">&nbsp;</span></span></a>""",
                ext=r"""<a href="https://t.co/hWOMicwES0" rel="nofollow" dir="ltr" data-expanded-url="http://onforb.es/1NqvWJT" class="twitter-timeline-link" target="_blank" title="http://onforb.es/1NqvWJT"><span class="tco-ellipsis"></span><span class="invisible">http://</span><span class="js-display-url">onforb.es/1NqvWJT</span><span class="invisible"></span><span class="tco-ellipsis"><span class="invisible">&nbsp;</span></span></a>""",
            ),
            dict(
                raw=r"""<a href="#" onClick="window.clipboardData.setData('text', directlink.href); return false;" title="Copy direct-link" class="bglink">[複製]</a>
                        <a href="http://www.bfooru.info/jdc.php?ref=8aYRLJzCCE" class="bglink">http://www.bfooru.info/jdc.php?ref=8aYRLJzCCE</a>
                        <span id="waitoutput">.</span>
                        <BR><BR>
                        <div style="margin:5px;">
                        <a href="http://www.boosme.info" target="_blank"><img src="ad.gif" border="0" width="468" height="60"></a>&nbsp;&nbsp;&nbsp;&nbsp;
                        <a href="http://www.xpj9199.com/Register/?a=64" target="_blank"><img src="http://dioguitar23.co/images/2015-1206-468X60.gif" border="0" width="468" height="60"></a>
                        </div>
                        <BR><BR>""",
                main=r"""<a href="#" onClick="window.clipboardData.setData('text', directlink.href); return false;" title="Copy direct-link" class="bglink">[複製]</a>
                        <a href="http://www.bfooru.info/jdc.php?ref=8aYRLJzCCE" class="bglink">http://www.bfooru.info/jdc.php?ref=8aYRLJzCCE</a>
                        <span id="waitoutput">.</span>
                        <BR><BR>
                        <div style="margin:5px;">
                        <a href="http://www.boosme.info" target="_blank"><img src="{path}/ad.gif" border="0" width="468" height="60"></a>&nbsp;&nbsp;&nbsp;&nbsp;
                        <a href="http://www.xpj9199.com/Register/?a=64" target="_blank"><img src="http://dioguitar23.co/images/2015-1206-468X60.gif" border="0" width="468" height="60"></a>
                        </div>
                        <BR><BR>""",
                ext=r"""<a href="#" onClick="window.clipboardData.setData('text', directlink.href); return false;" title="Copy direct-link" class="bglink">[複製]</a>
                        <a href="http://www.bfooru.info/jdc.php?ref=8aYRLJzCCE" class="bglink">http://www.bfooru.info/jdc.php?ref=8aYRLJzCCE</a>
                        <span id="waitoutput">.</span>
                        <BR><BR>
                        <div style="margin:5px;">
                        <a href="http://www.boosme.info" target="_blank"><img src="/extdomains/{ext_domain}{path}/ad.gif" border="0" width="468" height="60"></a>&nbsp;&nbsp;&nbsp;&nbsp;
                        <a href="http://www.xpj9199.com/Register/?a=64" target="_blank"><img src="http://dioguitar23.co/images/2015-1206-468X60.gif" border="0" width="468" height="60"></a>
                        </div>
                        <BR><BR>""",
            ),
            dict(
                raw=r"""it(); return true;" action="/bankToAcc.action?__continue=997ec1b2e3453a4ec2c69da040dddf6e" method="post">""",
                main=r"""it(); return true;" action="/bankToAcc.action?__continue=997ec1b2e3453a4ec2c69da040dddf6e" method="post">""",
                ext=r"""it(); return true;" action="/extdomains/{ext_domain}/bankToAcc.action?__continue=997ec1b2e3453a4ec2c69da040dddf6e" method="post">""",

            ),
            dict(
                raw=r"""allback'; };window['__google_recaptcha_client'] = true;var po = document.createElement('script'); po.type = 'text/javascript'; po.async = true;po.src = 'https://www.gstatic.com/recaptcha/api2/r20160913151359/recaptcha__zh_cn.js'; var elem = document.querySelector('script[nonce]');var non""",
                main=r"""allback'; };window['__google_recaptcha_client'] = true;var po = document.createElement('script'); po.type = 'text/javascript'; po.async = true;po.src = '{our_scheme}{our_domain}/extdomains/www.gstatic.com/recaptcha/api2/r20160913151359/recaptcha__zh_cn.js'; var elem = document.querySelector('script[nonce]');var non""",
                ext=r"""allback'; };window['__google_recaptcha_client'] = true;var po = document.createElement('script'); po.type = 'text/javascript'; po.async = true;po.src = '{our_scheme}{our_domain}/extdomains/www.gstatic.com/recaptcha/api2/r20160913151359/recaptcha__zh_cn.js'; var elem = document.querySelector('script[nonce]');var non""",
                mime="text/javascript",
            )
        )

        from more_configs import config_google_and_zhwikipedia
        google_config = dict(
            [(k, v)
             for k, v in config_google_and_zhwikipedia.__dict__.items()
             if not k.startswith("_") and not k.endswith("__")]
        )
        google_config["my_host_name"] = self.C.my_host_name
        google_config["my_host_scheme"] = self.C.my_host_scheme
        google_config["is_use_proxy"] = os.environ.get("ZMIRROR_UNITTEST_INSIDE_GFW") == "True"
        _google_config = copy.deepcopy(google_config)
        # google_config["verbose_level"] = 5

        for path in ("/", "/aaa", "/aaa/", "/aaa/bbb", "/aaa/bbb/", "/aaa/bb/cc", "/aaa/bb/cc/", "/aaa/b/c/dd"):
            # Test the main (primary) site
            google_config = copy.deepcopy(_google_config)
            self.reload_zmirror(configs_dict=google_config)
            self.rv = self.client.get(
                self.url(path),
                environ_base=env(),
                headers=headers(),
            )  # type: Response
            for test_case in test_cases:
                self.zmirror.parse.mime = test_case.get("mime", "text/html")
                raw = self._url_format(test_case["raw"])
                main = self._url_format(test_case["main"])
                # ext = url_format(test_case["ext"])

                self.assertEqual(
                    main, self.zmirror.regex_adv_url_rewriter.sub(self.zmirror.regex_url_reassemble, raw),
                    msg=self.dump(msg="raw: {}\npath:{}".format(raw, path))
                )

            # Test the external-domain site
            google_config = copy.deepcopy(_google_config)
            self.reload_zmirror(configs_dict=google_config)
            self.rv = self.client.get(
                self.url("/extdomains/{domain}{path}".format(domain=self.zmirror.external_domains[0], path=path)),
                environ_base=env(),
                headers=headers(),
            )  # type: Response
            for test_case in test_cases:
                self.zmirror.parse.mime = test_case.get("mime", "text/html")
                raw = self._url_format(test_case["raw"])
                ext = self._url_format(test_case["ext"])

                self.assertEqual(
                    ext, self.zmirror.regex_adv_url_rewriter.sub(self.zmirror.regex_url_reassemble, raw),
                    msg=self.dump(msg="raw: {}\npath:{}".format(raw, path))
                )
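
Worth noting about the zmirror test above: it takes one deep snapshot of the configuration (_google_config = copy.deepcopy(google_config)) and then restores a fresh copy with another copy.deepcopy at the top of every loop pass, so whatever reload_zmirror mutates cannot leak into the next iteration. Below is a minimal sketch of that snapshot-and-restore pattern; the apply_settings helper and the config contents are hypothetical, chosen only to illustrate the idea.

import copy

baseline = {
    "verbose_level": 3,
    "external_domains": ["www.google.com"],
    "flags": {"proxy": False},
}
# Take a deep snapshot once; nested lists and dicts are copied too,
# not just the top-level mapping.
_baseline = copy.deepcopy(baseline)

def apply_settings(cfg):
    # Hypothetical stand-in for something like reload_zmirror():
    # it mutates the dict it receives, including nested containers.
    cfg["flags"]["proxy"] = True
    cfg["external_domains"].append("www.gstatic.com")

for _ in range(3):
    # Restore a fully independent copy for every iteration.
    cfg = copy.deepcopy(_baseline)
    apply_settings(cfg)

# The snapshot is untouched, even though nested objects were mutated.
assert _baseline["flags"]["proxy"] is False
assert _baseline["external_domains"] == ["www.google.com"]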

Example 36

Project: mne-python
Source File: test_time_gen.py
View license
@slow_test
@requires_sklearn_0_15
def test_generalization_across_time():
    """Test time generalization decoding
    """
    from sklearn.svm import SVC
    from sklearn.base import is_classifier
    # KernelRidge is used for testing 1) regression analyses 2) n-dimensional
    # predictions.
    from sklearn.kernel_ridge import KernelRidge
    from sklearn.preprocessing import LabelEncoder
    from sklearn.metrics import roc_auc_score, mean_squared_error

    epochs = make_epochs()
    y_4classes = np.hstack((epochs.events[:7, 2], epochs.events[7:, 2] + 1))
    if check_version('sklearn', '0.18'):
        from sklearn.model_selection import (KFold, StratifiedKFold,
                                             ShuffleSplit, LeaveOneGroupOut)
        cv = LeaveOneGroupOut()
        cv_shuffle = ShuffleSplit()
        # XXX we cannot pass any other parameters than X and y to cv.split
        # so we have to build it before hand
        cv_lolo = [(train, test) for train, test in cv.split(
                   y_4classes, y_4classes, y_4classes)]

        # With sklearn >= 0.17, `clf` can be identified as a regressor, and
        # the scoring metrics can therefore be automatically assigned.
        scorer_regress = None
    else:
        from sklearn.cross_validation import (KFold, StratifiedKFold,
                                              ShuffleSplit, LeaveOneLabelOut)
        cv_shuffle = ShuffleSplit(len(epochs))
        cv_lolo = LeaveOneLabelOut(y_4classes)

        # With sklearn < 0.17, `clf` cannot be identified as a regressor, and
        # therefore the scoring metrics cannot be automatically assigned.
        scorer_regress = mean_squared_error
    # Test default running
    gat = GeneralizationAcrossTime(picks='foo')
    assert_equal("<GAT | no fit, no prediction, no score>", "%s" % gat)
    assert_raises(ValueError, gat.fit, epochs)
    with warnings.catch_warnings(record=True):
        # check classic fit + check manual picks
        gat.picks = [0]
        gat.fit(epochs)
        # check optional y as array
        gat.picks = None
        gat.fit(epochs, y=epochs.events[:, 2])
        # check optional y as list
        gat.fit(epochs, y=epochs.events[:, 2].tolist())
    assert_equal(len(gat.picks_), len(gat.ch_names), 1)
    assert_equal("<GAT | fitted, start : -0.200 (s), stop : 0.499 (s), no "
                 "prediction, no score>", '%s' % gat)
    assert_equal(gat.ch_names, epochs.ch_names)
    # test different predict function:
    gat = GeneralizationAcrossTime(predict_method='decision_function')
    gat.fit(epochs)
    # With classifier, the default cv is StratifiedKFold
    assert_true(gat.cv_.__class__ == StratifiedKFold)
    gat.predict(epochs)
    assert_array_equal(np.shape(gat.y_pred_), (15, 15, 14, 1))
    gat.predict_method = 'predict_proba'
    gat.predict(epochs)
    assert_array_equal(np.shape(gat.y_pred_), (15, 15, 14, 2))
    gat.predict_method = 'foo'
    assert_raises(NotImplementedError, gat.predict, epochs)
    gat.predict_method = 'predict'
    gat.predict(epochs)
    assert_array_equal(np.shape(gat.y_pred_), (15, 15, 14, 1))
    assert_equal("<GAT | fitted, start : -0.200 (s), stop : 0.499 (s), "
                 "predicted 14 epochs, no score>",
                 "%s" % gat)
    gat.score(epochs)
    assert_true(gat.scorer_.__name__ == 'accuracy_score')
    # check clf / predict_method combinations for which the scoring metrics
    # cannot be inferred.
    gat.scorer = None
    gat.predict_method = 'decision_function'
    assert_raises(ValueError, gat.score, epochs)
    # Check specifying y manually
    gat.predict_method = 'predict'
    gat.score(epochs, y=epochs.events[:, 2])
    gat.score(epochs, y=epochs.events[:, 2].tolist())
    assert_equal("<GAT | fitted, start : -0.200 (s), stop : 0.499 (s), "
                 "predicted 14 epochs,\n scored "
                 "(accuracy_score)>", "%s" % gat)
    with warnings.catch_warnings(record=True):
        gat.fit(epochs, y=epochs.events[:, 2])

    old_mode = gat.predict_mode
    gat.predict_mode = 'super-foo-mode'
    assert_raises(ValueError, gat.predict, epochs)
    gat.predict_mode = old_mode

    gat.score(epochs, y=epochs.events[:, 2])
    assert_true("accuracy_score" in '%s' % gat.scorer_)
    epochs2 = epochs.copy()

    # check _DecodingTime class
    assert_equal("<DecodingTime | start: -0.200 (s), stop: 0.499 (s), step: "
                 "0.050 (s), length: 0.050 (s), n_time_windows: 15>",
                 "%s" % gat.train_times_)
    assert_equal("<DecodingTime | start: -0.200 (s), stop: 0.499 (s), step: "
                 "0.050 (s), length: 0.050 (s), n_time_windows: 15 x 15>",
                 "%s" % gat.test_times_)

    # the y-check
    gat.predict_mode = 'mean-prediction'
    epochs2.events[:, 2] += 10
    gat_ = copy.deepcopy(gat)
    with use_log_level('error'):
        assert_raises(ValueError, gat_.score, epochs2)
    gat.predict_mode = 'cross-validation'

    # Test basics
    # --- number of trials
    assert_true(gat.y_train_.shape[0] ==
                gat.y_true_.shape[0] ==
                len(gat.y_pred_[0][0]) == 14)
    # ---  number of folds
    assert_true(np.shape(gat.estimators_)[1] == gat.cv)
    # ---  length training size
    assert_true(len(gat.train_times_['slices']) == 15 ==
                np.shape(gat.estimators_)[0])
    # ---  length testing sizes
    assert_true(len(gat.test_times_['slices']) == 15 ==
                np.shape(gat.scores_)[0])
    assert_true(len(gat.test_times_['slices'][0]) == 15 ==
                np.shape(gat.scores_)[1])

    # Test score_mode
    gat.score_mode = 'foo'
    assert_raises(ValueError, gat.score, epochs)
    gat.score_mode = 'fold-wise'
    scores = gat.score(epochs)
    assert_array_equal(np.shape(scores), [15, 15, 5])
    gat.score_mode = 'mean-sample-wise'
    scores = gat.score(epochs)
    assert_array_equal(np.shape(scores), [15, 15])
    gat.score_mode = 'mean-fold-wise'
    scores = gat.score(epochs)
    assert_array_equal(np.shape(scores), [15, 15])
    gat.predict_mode = 'mean-prediction'
    with warnings.catch_warnings(record=True) as w:
        gat.score(epochs)
        assert_true(any("score_mode changed from " in str(ww.message)
                        for ww in w))

    # Test longer time window
    gat = GeneralizationAcrossTime(train_times={'length': .100})
    with warnings.catch_warnings(record=True):
        gat2 = gat.fit(epochs)
    assert_true(gat is gat2)  # return self
    assert_true(hasattr(gat2, 'cv_'))
    assert_true(gat2.cv_ != gat.cv)
    with warnings.catch_warnings(record=True):  # not vectorizing
        scores = gat.score(epochs)
    assert_true(isinstance(scores, np.ndarray))  # type check
    assert_equal(len(scores[0]), len(scores))  # shape check
    assert_equal(len(gat.test_times_['slices'][0][0]), 2)
    # Decim training steps
    gat = GeneralizationAcrossTime(train_times={'step': .100})
    with warnings.catch_warnings(record=True):
        gat.fit(epochs)
    gat.score(epochs)
    assert_true(len(gat.scores_) == len(gat.estimators_) == 8)  # training time
    assert_equal(len(gat.scores_[0]), 15)  # testing time

    # Test start stop training & test cv without n_fold params
    y_4classes = np.hstack((epochs.events[:7, 2], epochs.events[7:, 2] + 1))
    train_times = dict(start=0.090, stop=0.250)
    gat = GeneralizationAcrossTime(cv=cv_lolo, train_times=train_times)
    # predict without fit
    assert_raises(RuntimeError, gat.predict, epochs)
    with warnings.catch_warnings(record=True):
        gat.fit(epochs, y=y_4classes)
    gat.score(epochs)
    assert_equal(len(gat.scores_), 4)
    assert_equal(gat.train_times_['times'][0], epochs.times[6])
    assert_equal(gat.train_times_['times'][-1], epochs.times[9])

    # Test score without passing epochs & Test diagonal decoding
    gat = GeneralizationAcrossTime(test_times='diagonal')
    with warnings.catch_warnings(record=True):  # not vectorizing
        gat.fit(epochs)
    assert_raises(RuntimeError, gat.score)
    with warnings.catch_warnings(record=True):  # not vectorizing
        gat.predict(epochs)
    scores = gat.score()
    assert_true(scores is gat.scores_)
    assert_equal(np.shape(gat.scores_), (15, 1))
    assert_array_equal([tim for ttime in gat.test_times_['times']
                        for tim in ttime], gat.train_times_['times'])
    # Test generalization across conditions
    gat = GeneralizationAcrossTime(predict_mode='mean-prediction', cv=2)
    with warnings.catch_warnings(record=True):
        gat.fit(epochs[0:6])
    with warnings.catch_warnings(record=True):
        # There are some empty test folds because of n_trials
        gat.predict(epochs[7:])
        gat.score(epochs[7:])

    # Test training time parameters
    gat_ = copy.deepcopy(gat)
    # --- start stop outside time range
    gat_.train_times = dict(start=-999.)
    with use_log_level('error'):
        assert_raises(ValueError, gat_.fit, epochs)
    gat_.train_times = dict(start=999.)
    assert_raises(ValueError, gat_.fit, epochs)
    # --- impossible slices
    gat_.train_times = dict(step=.000001)
    assert_raises(ValueError, gat_.fit, epochs)
    gat_.train_times = dict(length=.000001)
    assert_raises(ValueError, gat_.fit, epochs)
    gat_.train_times = dict(length=999.)
    assert_raises(ValueError, gat_.fit, epochs)

    # Test testing time parameters
    # --- outside time range
    gat.test_times = dict(start=-999.)
    with warnings.catch_warnings(record=True):  # no epochs in fold
        assert_raises(ValueError, gat.predict, epochs)
    gat.test_times = dict(start=999.)
    with warnings.catch_warnings(record=True):  # no test epochs
        assert_raises(ValueError, gat.predict, epochs)
    # --- impossible slices
    gat.test_times = dict(step=.000001)
    with warnings.catch_warnings(record=True):  # no test epochs
        assert_raises(ValueError, gat.predict, epochs)
    gat_ = copy.deepcopy(gat)
    gat_.train_times_['length'] = .000001
    gat_.test_times = dict(length=.000001)
    with warnings.catch_warnings(record=True):  # no test epochs
        assert_raises(ValueError, gat_.predict, epochs)
    # --- test time region of interest
    gat.test_times = dict(step=.150)
    with warnings.catch_warnings(record=True):  # not vectorizing
        gat.predict(epochs)
    assert_array_equal(np.shape(gat.y_pred_), (15, 5, 14, 1))
    # --- silly value
    gat.test_times = 'foo'
    with warnings.catch_warnings(record=True):  # no test epochs
        assert_raises(ValueError, gat.predict, epochs)
    assert_raises(RuntimeError, gat.score)
    # --- unmatched length between training and testing time
    gat.test_times = dict(length=.150)
    assert_raises(ValueError, gat.predict, epochs)
    # --- irregular length training and testing times
    # 2 estimators, the first one is trained on two successive time samples
    # whereas the second one is trained on a single time sample.
    train_times = dict(slices=[[0, 1], [1]])
    # The first estimator is tested once, the second estimator is tested on
    # two successive time samples.
    test_times = dict(slices=[[[0, 1]], [[0], [1]]])
    gat = GeneralizationAcrossTime(train_times=train_times,
                                   test_times=test_times)
    gat.fit(epochs)
    with warnings.catch_warnings(record=True):  # not vectorizing
        gat.score(epochs)
    assert_array_equal(np.shape(gat.y_pred_[0]), [1, len(epochs), 1])
    assert_array_equal(np.shape(gat.y_pred_[1]), [2, len(epochs), 1])
    # check that testing times cannot be automatically inferred for ad hoc training times
    gat.test_times = None
    assert_raises(ValueError, gat.predict, epochs)

    svc = SVC(C=1, kernel='linear', probability=True)
    gat = GeneralizationAcrossTime(clf=svc, predict_mode='mean-prediction')
    with warnings.catch_warnings(record=True):
        gat.fit(epochs)

    # sklearn needs it: c.f.
    # https://github.com/scikit-learn/scikit-learn/issues/2723
    # and http://bit.ly/1u7t8UT
    with use_log_level('error'):
        assert_raises(ValueError, gat.score, epochs2)
        gat.score(epochs)
    assert_true(0.0 <= np.min(scores) <= 1.0)
    assert_true(0.0 <= np.max(scores) <= 1.0)

    # Test that gets error if train on one dataset, test on another, and don't
    # specify appropriate cv:
    gat = GeneralizationAcrossTime(cv=cv_shuffle)
    gat.fit(epochs)
    with warnings.catch_warnings(record=True):
        gat.fit(epochs)

    gat.predict(epochs)
    assert_raises(ValueError, gat.predict, epochs[:10])

    # Make CV with some empty train and test folds:
    # --- empty test fold(s) should warn when gat.predict()
    gat._cv_splits[0] = [gat._cv_splits[0][0], np.empty(0)]
    with warnings.catch_warnings(record=True) as w:
        gat.predict(epochs)
        assert_true(len(w) > 0)
        assert_true(any('do not have any test epochs' in str(ww.message)
                        for ww in w))
    # --- empty train fold(s) should raise when gat.fit()
    gat = GeneralizationAcrossTime(cv=[([0], [1]), ([], [0])])
    assert_raises(ValueError, gat.fit, epochs[:2])

    # Check that still works with classifier that output y_pred with
    # shape = (n_trials, 1) instead of (n_trials,)
    if check_version('sklearn', '0.17'):  # no is_regressor before v0.17
        gat = GeneralizationAcrossTime(clf=KernelRidge(), cv=2)
        epochs.crop(None, epochs.times[2])
        gat.fit(epochs)
        # With regression the default cv is KFold and not StratifiedKFold
        assert_true(gat.cv_.__class__ == KFold)
        gat.score(epochs)
        # with regression the default scoring metrics is mean squared error
        assert_true(gat.scorer_.__name__ == 'mean_squared_error')

    # Test combinations of complex scenarios
    # 2 or more distinct classes
    n_classes = [2, 4]  # 4 tested
    # nicely ordered labels or not
    le = LabelEncoder()
    y = le.fit_transform(epochs.events[:, 2])
    y[len(y) // 2:] += 2
    ys = (y, y + 1000)
    # Univariate and multivariate prediction
    svc = SVC(C=1, kernel='linear', probability=True)
    reg = KernelRidge()

    def scorer_proba(y_true, y_pred):
        return roc_auc_score(y_true, y_pred[:, 0])

    # We are testing 3 scenarios: default, classifier + predict_proba, regressor
    scorers = [None, scorer_proba, scorer_regress]
    predict_methods = [None, 'predict_proba', None]
    clfs = [svc, svc, reg]
    # Test all combinations
    for clf, predict_method, scorer in zip(clfs, predict_methods, scorers):
        for y in ys:
            for n_class in n_classes:
                for predict_mode in ['cross-validation', 'mean-prediction']:
                    # Cannot use AUC for n_class > 2
                    if (predict_method == 'predict_proba' and n_class != 2):
                        continue

                    y_ = y % n_class

                    with warnings.catch_warnings(record=True):
                        gat = GeneralizationAcrossTime(
                            cv=2, clf=clf, scorer=scorer,
                            predict_mode=predict_mode)
                        gat.fit(epochs, y=y_)
                        gat.score(epochs, y=y_)

                    # Check that scorer is correctly defined manually and
                    # automatically.
                    scorer_name = gat.scorer_.__name__
                    if scorer is None:
                        if is_classifier(clf):
                            assert_equal(scorer_name, 'accuracy_score')
                        else:
                            assert_equal(scorer_name, 'mean_squared_error')
                    else:
                        assert_equal(scorer_name, scorer.__name__)
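
A pattern that recurs in the mne-python test above is cloning the fitted decoder with copy.deepcopy(gat) before assigning deliberately invalid settings (for instance train_times = dict(start=-999.)), so the failing calls are exercised on the clone while the original object keeps its valid state. Here is a minimal sketch of that clone-then-mutate pattern, using a hypothetical Estimator class rather than the real GeneralizationAcrossTime:

import copy

class Estimator:
    """Hypothetical estimator holding mutable settings and fitted state."""
    def __init__(self):
        self.train_times = {"start": 0.0, "stop": 0.5}
        self.scores_ = [0.8, 0.9]  # stands in for fitted attributes

    def fit(self):
        # Reject settings that fall outside the valid time range.
        if self.train_times["start"] < -100:
            raise ValueError("start is outside the valid time range")
        return self

est = Estimator().fit()

# Deep-copy the fitted object; nested dicts and lists are duplicated,
# so edits on the clone stay local to it.
est_bad = copy.deepcopy(est)
est_bad.train_times = {"start": -999.0}

try:
    est_bad.fit()
except ValueError:
    pass  # the invalid clone fails as expected

# The original keeps its valid configuration and fitted state.
assert est.train_times == {"start": 0.0, "stop": 0.5}
assert est.scores_ == [0.8, 0.9]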

Example 38

Project: socorro
Source File: test_legacy_processor.py
View license
    def test_create_basic_processed_crash_normal(self):
        config = setup_config_with_mocks()
        config.collect_addon = False
        config.collect_crash_process = False
        mocked_transform_rules_str = \
            'socorro.processor.legacy_processor.TransformRuleSystem'
        with mock.patch(mocked_transform_rules_str) as m_transform_class:
            m_transform = mock.Mock()
            m_transform_class.return_value = m_transform
            m_transform.attach_mock(mock.Mock(), 'apply_all_rules')
            utc_now_str = 'socorro.processor.legacy_processor.utc_now'
            with mock.patch(utc_now_str) as m_utc_now:
                m_utc_now.return_value = datetime(2012, 5, 4, 15, 11,
                                                  tzinfo=UTC)

                started_timestamp = datetime(2012, 5, 4, 15, 10, tzinfo=UTC)

                raw_crash = canonical_standard_raw_crash
                leg_proc = LegacyCrashProcessor(config, config.mock_quit_fn)
                processor_notes = []

                # test 01
                processed_crash = leg_proc._create_basic_processed_crash(
                  '3bc4bcaa-b61d-4d1f-85ae-30cb32120504',
                  raw_crash,
                  datetimeFromISOdateString(raw_crash.submitted_timestamp),
                  started_timestamp,
                  processor_notes,
                )
                assert 'exploitability' in processed_crash
                eq_(
                  processed_crash,
                  dict(cannonical_basic_processed_crash)
                )

                # test 02
                processor_notes = []
                raw_crash_missing_product = copy.deepcopy(raw_crash)
                del raw_crash_missing_product['ProductName']
                processed_crash = leg_proc._create_basic_processed_crash(
                  '3bc4bcaa-b61d-4d1f-85ae-30cb32120504',
                  raw_crash_missing_product,
                  datetimeFromISOdateString(raw_crash.submitted_timestamp),
                  started_timestamp,
                  processor_notes,
                )
                processed_crash_missing_product = \
                    copy.copy(cannonical_basic_processed_crash)
                processed_crash_missing_product.product = None
                eq_(
                  processed_crash,
                  processed_crash_missing_product
                )
                ok_('WARNING: raw_crash missing ProductName' in
                                processor_notes)
                eq_(len(processor_notes), 1)

                # test 03
                processor_notes = []
                raw_crash_missing_version = copy.deepcopy(raw_crash)
                del raw_crash_missing_version['Version']
                processed_crash = leg_proc._create_basic_processed_crash(
                  '3bc4bcaa-b61d-4d1f-85ae-30cb32120504',
                  raw_crash_missing_version,
                  datetimeFromISOdateString(raw_crash.submitted_timestamp),
                  started_timestamp,
                  processor_notes,
                )
                processed_crash_missing_version = \
                    copy.copy(cannonical_basic_processed_crash)
                processed_crash_missing_version.version = None
                eq_(
                  processed_crash,
                  processed_crash_missing_version
                )
                ok_('WARNING: raw_crash missing Version' in
                                processor_notes)
                eq_(len(processor_notes), 1)

                # test 04
                processor_notes = []
                raw_crash_with_hangid = copy.deepcopy(raw_crash)
                raw_crash_with_hangid.HangID = \
                    '30cb3212-b61d-4d1f-85ae-3bc4bcaa0504'
                processed_crash = leg_proc._create_basic_processed_crash(
                  '3bc4bcaa-b61d-4d1f-85ae-30cb32120504',
                  raw_crash_with_hangid,
                  datetimeFromISOdateString(raw_crash.submitted_timestamp),
                  started_timestamp,
                  processor_notes,
                )
                processed_crash_with_hangid = \
                    copy.copy(cannonical_basic_processed_crash)
                processed_crash_with_hangid.hangid = \
                    raw_crash_with_hangid.HangID
                processed_crash_with_hangid.hang_type = -1
                eq_(
                  processed_crash,
                  processed_crash_with_hangid
                )
                eq_(len(processor_notes), 0)

                # test 05
                processor_notes = []
                raw_crash_with_pluginhang = copy.deepcopy(raw_crash)
                raw_crash_with_pluginhang.PluginHang = '1'
                processed_crash = leg_proc._create_basic_processed_crash(
                  '3bc4bcaa-b61d-4d1f-85ae-30cb32120504',
                  raw_crash_with_pluginhang,
                  datetimeFromISOdateString(raw_crash.submitted_timestamp),
                  started_timestamp,
                  processor_notes,
                )
                processed_crash_with_pluginhang = \
                    copy.copy(cannonical_basic_processed_crash)
                processed_crash_with_pluginhang.hangid = \
                    'fake-3bc4bcaa-b61d-4d1f-85ae-30cb32120504'
                processed_crash_with_pluginhang.hang_type = -1
                eq_(
                  processed_crash,
                  processed_crash_with_pluginhang
                )
                eq_(len(processor_notes), 0)

                # test 06
                processor_notes = []
                raw_crash_with_hang_only = copy.deepcopy(raw_crash)
                raw_crash_with_hang_only.Hang = 16
                processed_crash = leg_proc._create_basic_processed_crash(
                  '3bc4bcaa-b61d-4d1f-85ae-30cb32120504',
                  raw_crash_with_hang_only,
                  datetimeFromISOdateString(raw_crash.submitted_timestamp),
                  started_timestamp,
                  processor_notes,
                )
                processed_crash_with_hang_only = \
                    copy.copy(cannonical_basic_processed_crash)
                processed_crash_with_hang_only.hang_type = 1
                eq_(
                  processed_crash,
                  processed_crash_with_hang_only
                )
                eq_(len(processor_notes), 0)
                leg_proc._statistics.assert_has_calls(
                    [
                        mock.call.incr('restarts'),
                    ],
                    any_order=True
                )

                # test 07
                processor_notes = []
                raw_crash_with_hang_only = copy.deepcopy(raw_crash)
                raw_crash_with_hang_only.Hang = 'bad value'
                processed_crash = leg_proc._create_basic_processed_crash(
                  '3bc4bcaa-b61d-4d1f-85ae-30cb32120504',
                  raw_crash_with_hang_only,
                  datetimeFromISOdateString(raw_crash.submitted_timestamp),
                  started_timestamp,
                  processor_notes,
                )
                processed_crash_with_hang_only = \
                    copy.copy(cannonical_basic_processed_crash)
                processed_crash_with_hang_only.hang_type = 0
                eq_(
                  processed_crash,
                  processed_crash_with_hang_only
                )
                eq_(len(processor_notes), 0)
                leg_proc._statistics.assert_has_calls(
                    [
                        mock.call.incr('restarts'),
                    ],
                    any_order=True
                )

                # test 08
                processor_notes = []
                bad_raw_crash = copy.deepcopy(raw_crash)
                bad_raw_crash['SecondsSinceLastCrash'] = 'badness'
                processed_crash = leg_proc._create_basic_processed_crash(
                  '3bc4bcaa-b61d-4d1f-85ae-30cb32120504',
                  bad_raw_crash,
                  datetimeFromISOdateString(raw_crash.submitted_timestamp),
                  started_timestamp,
                  processor_notes,
                )
                eq_(processed_crash.last_crash, None)
                ok_(
                    'non-integer value of "SecondsSinceLastCrash"' in
                    processor_notes
                )

                # test 09
                processor_notes = []
                bad_raw_crash = copy.deepcopy(raw_crash)
                bad_raw_crash['CrashTime'] = 'badness'
                processed_crash = leg_proc._create_basic_processed_crash(
                  '3bc4bcaa-b61d-4d1f-85ae-30cb32120504',
                  bad_raw_crash,
                  datetimeFromISOdateString(raw_crash.submitted_timestamp),
                  started_timestamp,
                  processor_notes,
                )
                eq_(processed_crash.crash_time, 0)
                ok_(
                    'non-integer value of "CrashTime"' in processor_notes
                )

                # test 10
                processor_notes = []
                bad_raw_crash = copy.deepcopy(raw_crash)
                bad_raw_crash['StartupTime'] = 'badness'
                bad_raw_crash['InstallTime'] = 'more badness'
                bad_raw_crash['CrashTime'] = 'even more badness'
                processed_crash = leg_proc._create_basic_processed_crash(
                  '3bc4bcaa-b61d-4d1f-85ae-30cb32120504',
                  bad_raw_crash,
                  datetimeFromISOdateString(raw_crash.submitted_timestamp),
                  started_timestamp,
                  processor_notes,
                )
                eq_(processed_crash.install_age, 0)
                ok_(
                    'non-integer value of "StartupTime"' in processor_notes
                )
                ok_(
                    'non-integer value of "InstallTime"' in processor_notes
                )
                ok_(
                    'non-integer value of "CrashTime"' in processor_notes
                )
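
A recurring pattern in this example is worth noting: each test case starts from copy.deepcopy(raw_crash) so that deleting a key such as ProductName, or overwriting a nested value, never leaks back into the shared canonical fixture. A minimal, standalone sketch of that idea (hypothetical dictionaries and field names, not socorro code):

import copy

# Shared fixture with a nested mutable value.
canonical_raw_crash = {
    'ProductName': 'Firefox',
    'Version': '12.0',
    'meta': {'submitted_timestamp': '2012-05-04T15:10:00'},
}

# Each test case mutates its own deep copy; the fixture stays intact.
crash_missing_product = copy.deepcopy(canonical_raw_crash)
del crash_missing_product['ProductName']
crash_missing_product['meta']['submitted_timestamp'] = 'later'

assert 'ProductName' in canonical_raw_crash
assert canonical_raw_crash['meta']['submitted_timestamp'] == '2012-05-04T15:10:00'

# A shallow copy would not be enough: the nested 'meta' dict is still shared.
shallow = copy.copy(canonical_raw_crash)
shallow['meta']['submitted_timestamp'] = 'mutated'
assert canonical_raw_crash['meta']['submitted_timestamp'] == 'mutated'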

Example 40

Project: py_smartyparse
Source File: trashtest.py
View license
def run():
    # Generic format
    tf_1 = SmartyParser()
    tf_1['magic'] = ParseHelper(Blob(length=4))
    tf_1['version'] = ParseHelper(Int32(signed=False))
    tf_1['cipher'] = ParseHelper(Int8(signed=False))
    tf_1['body1_length'] = ParseHelper(Int32(signed=False))
    tf_1['body1'] = ParseHelper(Blob())
    tf_1['body2_length'] = ParseHelper(Int32(signed=False))
    tf_1['body2'] = ParseHelper(Blob())
    tf_1.link_length('body1', 'body1_length')
    tf_1.link_length('body2', 'body2_length')
    
    # Nested formats
    tf_nest = SmartyParser()
    tf_nest['first'] = tf_1
    tf_nest['second'] = tf_1
    
    tf_nest2 = SmartyParser()
    tf_nest2['_0'] = ParseHelper(Int32())
    tf_nest2['_1'] = tf_1
    tf_nest2['_2'] = ParseHelper(Int32())
    
    # More exhaustive, mostly deterministic format
    tf_2 = SmartyParser()
    tf_2['_0'] = ParseHelper(parsers.Null())
    tf_2['_1'] = ParseHelper(parsers.Int8(signed=True))
    tf_2['_2'] = ParseHelper(parsers.Int8(signed=False))
    tf_2['_3'] = ParseHelper(parsers.Int16(signed=True))
    tf_2['_4'] = ParseHelper(parsers.Int16(signed=False))
    tf_2['_5'] = ParseHelper(parsers.Int32(signed=True))
    tf_2['_6'] = ParseHelper(parsers.Int32(signed=False))
    tf_2['_7'] = ParseHelper(parsers.Int64(signed=True))
    tf_2['_8'] = ParseHelper(parsers.Int64(signed=False))
    tf_2['_9'] = ParseHelper(parsers.Float(double=False))
    tf_2['_10'] = ParseHelper(parsers.Float())
    tf_2['_11'] = ParseHelper(parsers.ByteBool())
    tf_2['_12'] = ParseHelper(parsers.Padding(length=4))
    tf_2['_13'] = ParseHelper(parsers.String())
     
    tv1 = {}
    tv1['magic'] = b'[00]'
    tv1['version'] = 1
    tv1['cipher'] = 2
    tv1['body1'] = b'[tv1 byte string, first]'
    tv1['body2'] = b'[tv1 byte string, 2nd]'
     
    tv2 = {}
    tv2['magic'] = b'[aa]'
    tv2['version'] = 5
    tv2['cipher'] = 6
    tv2['body1'] = b'[new test byte string, first]'
    tv2['body2'] = b'[new test byte string, 2nd]'
    
    tv3 = {
            'first': copy.deepcopy(tv1), 
            'second': copy.deepcopy(tv2)
        }
    
    tv4 = {}
    tv4['_0'] = None
    tv4['_1'] = -10
    tv4['_2'] = 11
    tv4['_3'] = -300
    tv4['_4'] = 301
    tv4['_5'] = -100000
    tv4['_6'] = 100001
    tv4['_7'] = -10000000000
    tv4['_8'] = 10000000001
    tv4['_9'] = 11.11
    tv4['_10'] = 1e-50
    tv4['_11'] = True
    tv4['_12'] = None
    tv4['_13'] = 'EOF'
    
    tv5 = {}
    tv5['_0'] = 42
    tv5['_1'] = copy.deepcopy(tv1)
    tv5['_2'] = -42
    
    print('-----------------------------------------------')
    print('Testing all "other" parsers...')
    # print('    ', tv4)
    
    bites4 = tf_2.pack(tv4)
    
    # print('Successfully packed.')
    # print('    ', bytes(bites4))
    
    recycle4 = tf_2.unpack(bites4)
    # Note that floating-point precision keeps us from simply asserting:
    # assert recycle4 == tv4
    
    # print('Successfully reunpacked.')
    # print(recycle4)
    
    # print('    ', tv5)
    
    bites5 = tf_nest2.pack(tv5)
    
    # print('Successfully packed.')
    # print('    ', bytes(bites5))
    
    recycle5 = tf_nest2.unpack(bites5)
    assert recycle5 == tv5
    print('Successfully reunpacked.')
    
    # print(recycle5)
    # print('-----------------------------------------------')
    
    print('-----------------------------------------------')
    print('Starting TV1, serial...')
    # print('    ', tv1)
    
    bites1 = tf_1.pack(tv1)
    
    # print('Successfully packed.')
    # print('    ', bytes(bites1))
    
    recycle1 = tf_1.unpack(bites1)
    assert recycle1 == tv1
    print('Successfully reunpacked.')
    
    # print(recycle1)
    # print('-----------------------------------------------')
    
    print('-----------------------------------------------')
    print('Starting TV2, serial...')
    # print('    ', tv2)
    
    bites2 = tf_1.pack(tv2)
    
    # print('Successfully packed.')
    # print('    ', bytes(bites2))
    
    recycle2 = tf_1.unpack(bites2)
    assert recycle2 == tv2
    print('Successfully reunpacked.')
    
    # print(recycle2)
    # print('-----------------------------------------------')
    
    print('-----------------------------------------------')
    print('Starting TV1, TV2 parallel...')
    # print('    ', tv1)
    
    bites1 = tf_1.pack(tv1)
    
    # print('Successfully packed TV1.')
    # print('    ', bytes(bites1))
    
    # print('    ', tv2)
    
    bites2 = tf_1.pack(tv2)
    
    # print('Successfully packed TV2.')
    # print('    ', bytes(bites2))
    
    recycle1 = tf_1.unpack(bites1)
    assert recycle1 == tv1
    print('Successfully reunpacked TV1.')
    
    # print(recycle1)
    
    recycle2 = tf_1.unpack(bites2)
    assert recycle2 == tv2
    print('Successfully reunpacked TV2.')
    
    # print(recycle2)
    
    print('-----------------------------------------------')
    print('Starting (nested) TV3...')
    # print(tv3)
    
    bites3 = tf_nest.pack(tv3)
    
    # print('-----------------------------------------------')
    # print('Successfully packed.')
    # print(bytes(bites3))
    # print('-----------------------------------------------')
    
    recycle3 = tf_nest.unpack(bites3)
    assert recycle3 == tv3
    print('Successfully reunpacked.')
    
    # print(recycle3)
    print('-----------------------------------------------')
    print('Testing toggle...')
    
    parent = SmartyParser()
    parent['switch'] = ParseHelper(Int8(signed=False))
    parent['light'] = None
    
    @references(parent)
    def decide(self, switch):
        if switch == 1:
            self['light'] = ParseHelper(Int8())
        else:
            self['light'] = ParseHelper(Blob(length=11))
            
    parent['switch'].register_callback('prepack', decide)
    parent['switch'].register_callback('postunpack', decide)
            
    off = {'switch': 1, 'light': -55}
    on = {'switch': 0, 'light': b'Hello world'}
    
    o1 = parent.pack(off)
    o2 = parent.pack(on)
    assert parent.unpack(o1) == off
    assert parent.unpack(o2) == on
    print('Success.')
    
    # -----------------------------------------------------------------
    print('-----------------------------------------------')
    print('Testing listyparser...')
        
    pastr = SmartyParser()    
    pastr['length'] = ParseHelper(parsers.Int8(signed=False))
    pastr['body'] = ParseHelper(parsers.String())
    pastr.link_length('body', 'length')
    
    tag_typed = SmartyParser()
    tag_typed['tag'] = ParseHelper(parsers.Int8(signed=False))
    tag_typed['toggle'] = None
    
    @references(tag_typed)
    def switch(self, tag):
        if tag == 0:
            self['toggle'] = ParseHelper(parsers.Int8(signed=False))
        elif tag == 1:
            self['toggle'] = ParseHelper(parsers.Int16(signed=False))
        elif tag == 2:
            self['toggle'] = ParseHelper(parsers.Int32(signed=False))
        elif tag == 3:
            self['toggle'] = ParseHelper(parsers.Int64(signed=False))
        else:
            self['toggle'] = pastr
    tag_typed['tag'].register_callback('prepack', switch)
    tag_typed['tag'].register_callback('postunpack', switch)
    
    tf_list = ListyParser(parsers=[tag_typed])
    tv_list = [
            {'tag': 0, 'toggle': 5}, 
            {'tag': 1, 'toggle': 51}, 
            {'tag': 65, 'toggle': {'body': 'hello world'}}, 
            {'tag': 2, 'toggle': 3453}
        ]
    tv_list_pack = tf_list.pack(tv_list)
    for it1, it2 in zip(tv_list, tf_list.unpack(tv_list_pack)):
        assert it1 == it2
    print('Success.')
    # assert tf_list.unpack(tv_list_pack) == tv_list
    
    print('Testing nested explicit listyparser...')
    tf_list_nest = SmartyParser()
    tf_list_nest['_0'] = ParseHelper(parsers.Int8(signed=False))
    tf_list_nest['_1'] = ParseHelper(parsers.Int16(signed=False))
    tf_list_nest['_2'] = tf_list
    tf_list_nest['_3'] = ParseHelper(parsers.Int8(signed=False))
    tf_list_nest.link_length('_2', '_1')
    
    tv_list_nest = {
        '_0': 12,
        '_2': copy.deepcopy(tv_list),
        '_3': 14,
    }
    
    tv_list_nest_pack = tf_list_nest.pack(tv_list_nest)
    tv_list_nest_recycle = tf_list_nest.unpack(tv_list_nest_pack)
    print('No errors, but no test for equivalency yet.')
    
    print('Testing nested implicit listyparser.')
    
    terminant = ParseHelper(parsers.Literal(b'h', verify=False))
    tf_exlist = ListyParser(parsers=[tag_typed], terminant=terminant)
    
    tf_exlist_nest = SmartyParser()
    tf_exlist_nest['_0'] = ParseHelper(parsers.Int8(signed=False))
    tf_exlist_nest['_2'] = tf_exlist
    tf_exlist_nest['_3'] = ParseHelper(parsers.Int8(signed=False))
        
    tv_exlist_pack = tf_exlist_nest.pack(tv_list_nest)
    tv_exlist_recycle = tf_exlist_nest.unpack(tv_exlist_pack)
    print('No errors, but no test for equivalency yet.')
    
    # Can do some kind of check for len of self.obj to determine if there's 
    # only a single entry in the smartyparser, and thereby expand any 
    # objects to pack or objects unpacked.
    
    print('-----------------------------------------------')
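
The deep copies of tv1 and tv2 when building the nested test vectors matter here: if tv3 simply referenced the originals, editing one composite value would silently alter the other (and the flat vectors themselves). A short, self-contained sketch of the aliasing hazard being avoided (hypothetical data, not part of trashtest.py):

import copy

tv1 = {'magic': b'[00]', 'body1': b'[tv1 byte string, first]'}

# Embed independent copies so later edits to one composite cannot alias the other.
nested_a = {'first': copy.deepcopy(tv1)}
nested_b = {'first': copy.deepcopy(tv1)}

nested_a['first']['magic'] = b'[aa]'

assert nested_b['first']['magic'] == b'[00]'   # the sibling composite is untouched
assert tv1['magic'] == b'[00]'                 # and so is the original vector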

Example 41

Project: nupic
Source File: EnsembleOnline.py
View license
def getStableVote(scores, stableSize, votes, currModel):
  scores = sorted(scores, key=lambda t: t[0])[:stableSize]
  median=True
  if not median:
    for s in scores:
      if s[3]==currModel:
        print [(score[0], score[3]) for score in scores]
      
        return s[1], currModel
    print [(s[0], s[3]) for s in scores], "switching voting Model!"
    return scores[0][1], scores[0][3]
  else:
    print [(s[0], s[3]) for s in scores]
    voters = sorted(scores, key=lambda t: t[1])
    for voter in voters:
      votes[voter[3]]=votes[voter[3]]+1
    vote=voters[int(stableSize/2)][1]
    return vote, currModel
  
        
def getFieldPermutations(config, predictedField):
  encoders=config['modelParams']['sensorParams']['encoders']
  encoderList=[]
  for encoder in encoders:
    if encoder==None:
      continue
    if encoder['name']==predictedField:
      encoderList.append([encoder])
      for e in encoders:
        if e==None:
          continue
        if e['name'] != predictedField:
          encoderList.append([encoder, e])
  return encoderList
              
        
def getModelDescriptionLists(numProcesses, experiment):
    config, control = opfhelpers.loadExperiment(experiment)
    encodersList=getFieldPermutations(config, 'pounds')
    ns=range(50, 140, 120)
    clAlphas=np.arange(0.01, 0.16, 0.104)
    synPermInactives=np.arange(0.01, 0.16, 0.105)
    tpPamLengths=range(5, 8, 2)
    tpSegmentActivations=range(13, 17, 12)
    
    if control['environment'] == 'opfExperiment':
      experimentTasks = control['tasks']
      task = experimentTasks[0]
      datasetURI = task['dataset']['streams'][0]['source']

    elif control['environment'] == 'nupic':
      datasetURI = control['dataset']['streams'][0]['source']

    metricSpecs = control['metrics']

    datasetPath = datasetURI[len("file://"):]
    ModelSetUpData=[]
    name=0
    
    for n in ns:
      for clAlpha in clAlphas:
        for synPermInactive in synPermInactives:
          for tpPamLength in tpPamLengths:
            for tpSegmentActivation in tpSegmentActivations:
              for encoders in encodersList:
                encodersmod=copy.deepcopy(encoders)
                configmod=copy.deepcopy(config)
                configmod['modelParams']['sensorParams']['encoders']=encodersmod
                configmod['modelParams']['clParams']['alpha']=clAlpha
                configmod['modelParams']['spParams']['synPermInactiveDec']=synPermInactive
                configmod['modelParams']['tpParams']['pamLength']=tpPamLength
                configmod['modelParams']['tpParams']['activationThreshold']=tpSegmentActivation
                for encoder in encodersmod:
                  if encoder['name']==predictedField:
                    encoder['n']=n
                
                ModelSetUpData.append((name,{'modelConfig':configmod, 'inferenceArgs':control['inferenceArgs'], 'metricSpecs':metricSpecs, 'sourceSpec':datasetPath,'sinkSpec':None,}))
                name=name+1
              #print modelInfo['modelConfig']['modelParams']['tpParams']
              #print modelInfo['modelConfig']['modelParams']['sensorParams']['encoders'][4]['n']
    print "num Models"+str( len(ModelSetUpData))
    
    shuffle(ModelSetUpData)
    #print [ (m[1]['modelConfig']['modelParams']['tpParams']['pamLength'], m[1]['modelConfig']['modelParams']['sensorParams']['encoders']) for m in ModelSetUpData]       
    return list(chunk(ModelSetUpData,numProcesses))

    
def chunk(l, n):
    """ Yield n successive chunks from l.
    """
    newn = int(1.0 * len(l) / n + 0.5)
    for i in xrange(0, n-1):
        yield l[i*newn:i*newn+newn]
    yield l[n*newn-newn:]

def command(command, work_queues, aux):
  for queue in work_queues:
    queue.put((command, aux))


def getDuplicateList(streams, delta):
  delList=[]
  keys=streams.keys()
  for key1 in keys:
    if key1 in streams:
      for key2 in streams.keys():
        if(key1 !=key2):
          print 'comparing model'+str(key1)+" to "+str(key2)
          dist=sum([(a-b)**2 for a, b in zip(streams[key1], streams[key2])])
          print dist
          if(dist<delta):
            delList.append(key2)
            del streams[key2]
  return delList
    
def slice_sampler(px, N = 1, x = None):
    """
    Provides samples from a user-defined distribution.
    
    slice_sampler(px, N = 1, x = None)
    
    Inputs:
    px = A discrete probability distribution.
    N  = Number of samples to return, default is 1
    x  = Optional list/array of observation values to return, where prob(x) = px.

    Outputs:
    If x=None (default) or if len(x) != len(px), it will return an array of integers
    between 0 and len(px)-1. If x is supplied, it will return the
    samples from x according to the distribution px.    
    """
    values = np.zeros(N, dtype=np.int)
    samples = np.arange(len(px))
    px = np.array(px) / (1.*sum(px))
    u = uniform(0, max(px))
    for n in xrange(N):
        included = px>=u
        choice = random.sample(range(np.sum(included)), 1)[0]
        values[n] = samples[included][choice]
        u = uniform(0, px[included][choice])
    if x:
        if len(x) == len(px):
            x=np.array(x)
            values = x[values]
        else:
            print "px and x are different lengths. Returning index locations for px."
    
    return values
    
    
def getPSOVariants(modelInfos, votes, n):
  # get x, px lists for sampling 
  norm=sum(votes.values())
  xpx =[(m, float(votes[m])/norm) for m in votes.keys()] 
  x,px = [[z[i] for z in xpx] for i in (0,1)]
  # sample from the set of models
  variantIDs=slice_sampler(px, n, x)
  print "variant IDS"
  print variantIDs
  #best X
  x_best=modelInfos[0][2][0]
  # create PSO variates of models
  modelDescriptions=[]
  for variantID in variantIDs:
    t=modelInfos[[i for i, v in enumerate(modelInfos) if v[0] == variantID][0]]
    x=t[2][0]
    v=t[2][1]
    print "old x"
    print x
    modelDescriptionMod=copy.deepcopy(t[3])
    configmod=modelDescriptionMod['modelConfig']
    v=inertia*v+socRate*np.random.random_sample(len(v))*(x_best-x)
    x=x+v
    print "new x"
    print x    
    configmod['modelParams']['clParams']['alpha']=max(0.01, x[0])
    configmod['modelParams']['spParams']['synPermInactiveDec']=max(0.01, x[2])
    configmod['modelParams']['tpParams']['pamLength']=int(round(max(1, x[4])))
    configmod['modelParams']['tpParams']['activationThreshold']=int(round(max(1, x[3])))
    for encoder in configmod['modelParams']['sensorParams']['encoders']:
      if encoder['name']==predictedField:
        encoder['n']=int(round(max(encoder['w']+1, x[1]) ))
    modelDescriptions.append((modelDescriptionMod, x, v))
  return modelDescriptions 
            
    
def computeAAE(truth, predictions, windowSize):
  windowSize=min(windowSize, len(truth))
  zipped=zip(truth[-windowSize:], predictions[-windowSize-1:])
  AAE=sum([abs(a - b) for a, b in zipped])/windowSize
  return AAE
 
    
              
if __name__ == "__main__":
    cutPercentage=0.1
    currModel=0
    stableSize=3
    delta=1
    predictedField='pounds'
    truth=[]
    ensemblePredictions=[0,]
    divisor=4
    ModelSetUpData=getModelDescriptionLists(divisor, './')
    num_processes=len(ModelSetUpData)
    print num_processes 
    work_queues=[]
    votes={}
    votingParameterStats={"tpSegmentActivationThreshold":[], "tpPamLength":[], "synPermInactiveDec":[], "clAlpha":[], "numBuckets":[]}  
    # create a queue to pass to workers to store the results
    result_queue = multiprocessing.Queue(len(ModelSetUpData))
 
    # spawn workers
    workerName=0
    modelNameCount=0
    for modelData in ModelSetUpData:
        print len(modelData)
        modelNameCount+=len(modelData)
        work_queue= multiprocessing.Queue()
        work_queues.append(work_queue)
        worker = Worker(work_queue, result_queue, stableSize, windowSize, predictedField, modelData, workerName)
        worker.start()
        workerName=workerName+1
        
    #init votes dict
    for dataList in ModelSetUpData:
      for data in dataList:
        votes[data[0]]=0
      
    
    for i in range(2120):
      
      command('predict', work_queues, i)
      scores=[]
      for j in range(num_processes):
        subscore=result_queue.get()
        scores.extend(subscore)
      print ""
      print i
      ensemblePrediction, currModel=getStableVote(scores, stableSize, votes, currModel)
      ensemblePredictions.append(ensemblePrediction)
      truth.append(scores[0][2])
      print  computeAAE(truth,ensemblePredictions, windowSize), int(currModel)
      assert(result_queue.empty())
      if i%r==0 and i!=0: #refresh ensemble
        assert(result_queue.empty())
        #get AAES of models over last i records
        command('getAAEs', work_queues, None)
        AAEs=[]
        for j in range(num_processes):
          subAAEs=result_queue.get()
          AAEs.extend(subAAEs)
        AAEs=sorted(AAEs, key=lambda t: t[1])
        numToDelete=int(round(cutPercentage*len(AAEs)))
        print "Single Model AAES"
        print [(aae[0], aae[1]) for aae in AAEs]
        print "Ensemble AAE"
        print computeAAE(truth, ensemblePredictions, r)
        #add bottom models to delList
        print "Vote counts"
        print votes
        delList=[t[0] for t in AAEs[-numToDelete:]]   
        print "delList"     
        print delList
        #find duplicate models(now unnecessary)
        #command('getPredictionStreams', work_queues, None)
        #streams={}
        #for j in range(num_processes):
        #  subList=result_queue.get()
        #  streams.update(subList)
        #delList.extend(getDuplicateList(streams, delta))
        #print delList
        command('delete', work_queues, delList)
        for iden in delList:
          del votes[iden]
        print votes  
        #wait for deletion to finish and collect processIndices for addition
        processIndices=[]
        for j in range(num_processes):
          processIndices.append( result_queue.get())
        # pick new set of models for PSO variants
        newModelDescriptions=getPSOVariants(AAEs, votes, len(delList))
        assert(result_queue.empty())
        # send new model descriptions to the queue and have the processes pick them up
        aux=[]
        for i in range(len(newModelDescriptions)):
          votes[modelNameCount]=0
          aux.append((processIndices[i],newModelDescriptions[i],modelNameCount) )
          modelNameCount=modelNameCount+1
        
        command('addPSOVariants', work_queues, aux)
        #set votes to 0
        for key in votes.keys():
          votes[key]=0
                
        
      
 
    print "AAE over full stream"
    print computeAAE(truth, ensemblePredictions, len(truth))
    print "AAE1000"
    print computeAAE(truth, ensemblePredictions, 1000)
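
The key deepcopy use above is inside the parameter sweep in getModelDescriptionLists(): the base config is deep-copied once per variant so every model gets its own nested parameter dicts. A minimal sketch of that sweep pattern under a hypothetical config (not NuPIC code):

import copy

base_config = {'modelParams': {'clParams': {'alpha': 0.05},
                               'tpParams': {'pamLength': 5}}}

variants = []
for alpha in (0.01, 0.1):
    for pam_length in (5, 7):
        cfg = copy.deepcopy(base_config)          # independent nested dicts per variant
        cfg['modelParams']['clParams']['alpha'] = alpha
        cfg['modelParams']['tpParams']['pamLength'] = pam_length
        variants.append(cfg)

# Without deepcopy every entry in `variants` would share the same inner dicts,
# and only the values from the final loop iteration would survive.
assert variants[0]['modelParams']['clParams']['alpha'] == 0.01
assert base_config['modelParams']['clParams']['alpha'] == 0.05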

Example 43

Project: heat
Source File: test_neutron_floating_ip.py
View license
    def test_floatip_association_port(self):
        t = template_format.parse(neutron_floating_template)
        stack = utils.parse_stack(t)

        neutronV20.find_resourceid_by_name_or_id(
            mox.IsA(neutronclient.Client),
            'network',
            'abcd1234',
            cmd_resource=None,
        ).MultipleTimes().AndReturn('abcd1234')
        neutronV20.find_resourceid_by_name_or_id(
            mox.IsA(neutronclient.Client),
            'subnet',
            'sub1234',
            cmd_resource=None,
        ).MultipleTimes().AndReturn('sub1234')
        neutronclient.Client.create_floatingip({
            'floatingip': {'floating_network_id': u'abcd1234'}
        }).AndReturn({'floatingip': {
            "status": "ACTIVE",
            "id": "fc68ea2c-b60b-4b4f-bd82-94ec81110766"
        }})

        neutronclient.Client.create_port({'port': {
            'network_id': u'abcd1234',
            'fixed_ips': [
                {'subnet_id': u'sub1234', 'ip_address': u'10.0.0.10'}
            ],
            'name': utils.PhysName(stack.name, 'port_floating'),
            'admin_state_up': True}}
        ).AndReturn({'port': {
            "status": "BUILD",
            "id": "fc68ea2c-b60b-4b4f-bd82-94ec81110766"
        }})
        neutronclient.Client.show_port(
            'fc68ea2c-b60b-4b4f-bd82-94ec81110766'
        ).AndReturn({'port': {
            "status": "ACTIVE",
            "id": "fc68ea2c-b60b-4b4f-bd82-94ec81110766"
        }})
        # create as
        neutronclient.Client.update_floatingip(
            'fc68ea2c-b60b-4b4f-bd82-94ec81110766',
            {
                'floatingip': {
                    'port_id': u'fc68ea2c-b60b-4b4f-bd82-94ec81110766'}}
        ).AndReturn({'floatingip': {
            "status": "ACTIVE",
            "id": "fc68ea2c-b60b-4b4f-bd82-94ec81110766"
        }})
        # update as with port_id
        neutronclient.Client.update_floatingip(
            'fc68ea2c-b60b-4b4f-bd82-94ec81110766',
            {
                'floatingip': {
                    'port_id': u'2146dfbf-ba77-4083-8e86-d052f671ece5',
                    'fixed_ip_address': None}}
        ).AndReturn({'floatingip': {
            "status": "ACTIVE",
            "id": "fc68ea2c-b60b-4b4f-bd82-94ec81110766"
        }})
        # update as with floatingip_id
        neutronclient.Client.update_floatingip(
            'fc68ea2c-b60b-4b4f-bd82-94ec81110766',
            {'floatingip': {
                'port_id': None
            }}).AndReturn(None)
        neutronclient.Client.update_floatingip(
            '2146dfbf-ba77-4083-8e86-d052f671ece5',
            {
                'floatingip': {
                    'port_id': u'2146dfbf-ba77-4083-8e86-d052f671ece5',
                    'fixed_ip_address': None}}
        ).AndReturn({'floatingip': {
            "status": "ACTIVE",
            "id": "2146dfbf-ba77-4083-8e86-d052f671ece5"
        }})
        # update as with both
        neutronclient.Client.update_floatingip(
            '2146dfbf-ba77-4083-8e86-d052f671ece5',
            {'floatingip': {
                'port_id': None
            }}).AndReturn(None)
        neutronclient.Client.update_floatingip(
            'fc68ea2c-b60b-4b4f-bd82-94ec81110766',
            {
                'floatingip': {
                    'port_id': u'ade6fcac-7d47-416e-a3d7-ad12efe445c1',
                    'fixed_ip_address': None}}
        ).AndReturn({'floatingip': {
            "status": "ACTIVE",
            "id": "fc68ea2c-b60b-4b4f-bd82-94ec81110766"
        }})
        # delete as
        neutronclient.Client.update_floatingip(
            'fc68ea2c-b60b-4b4f-bd82-94ec81110766',
            {'floatingip': {
                'port_id': None
            }}).AndReturn(None)

        neutronclient.Client.delete_port(
            'fc68ea2c-b60b-4b4f-bd82-94ec81110766'
        ).AndReturn(None)

        neutronclient.Client.show_port(
            'fc68ea2c-b60b-4b4f-bd82-94ec81110766'
        ).AndRaise(qe.PortNotFoundClient(status_code=404))

        neutronclient.Client.delete_floatingip(
            'fc68ea2c-b60b-4b4f-bd82-94ec81110766'
        ).AndReturn(None)
        neutronclient.Client.show_floatingip(
            'fc68ea2c-b60b-4b4f-bd82-94ec81110766').AndRaise(
                qe.NeutronClientException(status_code=404))

        neutronclient.Client.delete_port(
            'fc68ea2c-b60b-4b4f-bd82-94ec81110766'
        ).AndRaise(qe.PortNotFoundClient(status_code=404))

        neutronclient.Client.delete_floatingip(
            'fc68ea2c-b60b-4b4f-bd82-94ec81110766'
        ).AndRaise(qe.NeutronClientException(status_code=404))
        self.stub_PortConstraint_validate()

        self.m.ReplayAll()

        fip = stack['floating_ip']
        scheduler.TaskRunner(fip.create)()
        self.assertEqual((fip.CREATE, fip.COMPLETE), fip.state)

        p = stack['port_floating']
        scheduler.TaskRunner(p.create)()
        self.assertEqual((p.CREATE, p.COMPLETE), p.state)

        fipa = stack['floating_ip_assoc']
        scheduler.TaskRunner(fipa.create)()
        self.assertEqual((fipa.CREATE, fipa.COMPLETE), fipa.state)
        self.assertIsNotNone(fipa.id)
        self.assertEqual(fipa.id, fipa.resource_id)

        fipa.validate()

        # test update FloatingIpAssociation with port_id
        props = copy.deepcopy(fipa.properties.data)
        update_port_id = '2146dfbf-ba77-4083-8e86-d052f671ece5'
        props['port_id'] = update_port_id
        update_snippet = rsrc_defn.ResourceDefinition(fipa.name, fipa.type(),
                                                      stack.t.parse(stack,
                                                                    props))

        scheduler.TaskRunner(fipa.update, update_snippet)()
        self.assertEqual((fipa.UPDATE, fipa.COMPLETE), fipa.state)

        # test update FloatingIpAssociation with floatingip_id
        props = copy.deepcopy(fipa.properties.data)
        update_flip_id = '2146dfbf-ba77-4083-8e86-d052f671ece5'
        props['floatingip_id'] = update_flip_id
        update_snippet = rsrc_defn.ResourceDefinition(fipa.name, fipa.type(),
                                                      props)

        scheduler.TaskRunner(fipa.update, update_snippet)()
        self.assertEqual((fipa.UPDATE, fipa.COMPLETE), fipa.state)

        # test update FloatingIpAssociation with port_id and floatingip_id
        props = copy.deepcopy(fipa.properties.data)
        update_flip_id = 'fc68ea2c-b60b-4b4f-bd82-94ec81110766'
        update_port_id = 'ade6fcac-7d47-416e-a3d7-ad12efe445c1'
        props['floatingip_id'] = update_flip_id
        props['port_id'] = update_port_id
        update_snippet = rsrc_defn.ResourceDefinition(fipa.name, fipa.type(),
                                                      props)

        scheduler.TaskRunner(fipa.update, update_snippet)()
        self.assertEqual((fipa.UPDATE, fipa.COMPLETE), fipa.state)

        scheduler.TaskRunner(fipa.delete)()
        scheduler.TaskRunner(p.delete)()
        scheduler.TaskRunner(fip.delete)()

        fip.state_set(fip.CREATE, fip.COMPLETE, 'to delete again')
        p.state_set(p.CREATE, p.COMPLETE, 'to delete again')

        self.assertIsNone(scheduler.TaskRunner(p.delete)())
        scheduler.TaskRunner(fip.delete)()

        self.m.VerifyAll()
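
The update steps above all follow the same shape: deep-copy the association's current properties, change one or two fields on the copy, and build the new resource definition from that copy so the live properties are never edited in place. A hedged sketch of the idea with plain dictionaries (not heat objects):

import copy

current_properties = {
    'floating_network': 'abcd1234',
    'port_id': 'fc68ea2c-b60b-4b4f-bd82-94ec81110766',
}

# Work on a deep copy when preparing the update snippet.
props = copy.deepcopy(current_properties)
props['port_id'] = '2146dfbf-ba77-4083-8e86-d052f671ece5'

# The resource still reports its original association until the update runs.
assert current_properties['port_id'] == 'fc68ea2c-b60b-4b4f-bd82-94ec81110766'
assert props['port_id'] == '2146dfbf-ba77-4083-8e86-d052f671ece5'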

Example 44

Project: framework
Source File: dataobject.py
View license
    def save(self, recursive=False, skip=None, _hook=None):
        """
        Save the object to the persistent backend and clear cache, making use
        of the specified conflict resolve settings.
        It will also invalidate certain caches if required. For example lists pointing towards this
        object
        :param recursive: Save related sub-objects recursively
        :param skip: Skip certain relations
        :param _hook:
        """
        if self.volatile is True:
            raise VolatileObjectException()

        tries = 0
        successful = False
        optimistic = True
        last_assert = None
        while successful is False:
            tries += 1
            if tries > 5:
                DataObject._logger.error('Raising RaceConditionException. Last AssertException: {0}'.format(last_assert))
                raise RaceConditionException()

            invalid_fields = []
            for prop in self._properties:
                if prop.mandatory is True and self._data[prop.name] is None:
                    invalid_fields.append(prop.name)
            for relation in self._relations:
                if relation.mandatory is True and self._data[relation.name]['guid'] is None:
                    invalid_fields.append(relation.name)
            if len(invalid_fields) > 0:
                raise MissingMandatoryFieldsException('Missing fields on {0}: {1}'.format(self._classname, ', '.join(invalid_fields)))

            if recursive:
                # Save objects that point to us (e.g. disk.vmachine - if this is disk)
                for relation in self._relations:
                    if relation.name != skip:  # disks will be skipped
                        item = getattr(self, relation.name)
                        if item is not None:
                            item.save(recursive=True, skip=relation.foreign_key)

                # Save object we point at (e.g. machine.vdisks - if this is machine)
                relations = RelationMapper.load_foreign_relations(self.__class__)
                if relations is not None:
                    for key, info in relations.iteritems():
                        if key != skip:  # machine will be skipped
                            if info['list'] is True:
                                for item in getattr(self, key).iterloaded():
                                    item.save(recursive=True, skip=info['key'])
                            else:
                                item = getattr(self, key)
                                if item is not None:
                                    item.save(recursive=True, skip=info['key'])

            validation_keys = []
            for relation in self._relations:
                if self._data[relation.name]['guid'] is not None:
                    if relation.foreign_type is None:
                        cls = self.__class__
                    else:
                        cls = relation.foreign_type
                    validation_keys.append('{0}_{1}_{2}'.format(DataObject.NAMESPACE, cls.__name__.lower(), self._data[relation.name]['guid']))
            try:
                [_ for _ in self._persistent.get_multi(validation_keys)]
            except KeyNotFoundException:
                raise ObjectNotFoundException('One of the relations specified in {0} with guid \'{1}\' was not found'.format(
                    self.__class__.__name__, self._guid
                ))

            transaction = self._persistent.begin_transaction()
            if self._new is True:
                data = {'_version': 0}
                store_data = {'_version': 0}
            elif optimistic is True:
                self._persistent.assert_value(self._key, self._original, transaction=transaction)
                data = copy.deepcopy(self._original)
                store_data = copy.deepcopy(self._original)
            else:
                try:
                    current_data = self._persistent.get(self._key)
                except KeyNotFoundException:
                    raise ObjectNotFoundException('{0} with guid \'{1}\' was deleted'.format(
                        self.__class__.__name__, self._guid
                    ))
                self._persistent.assert_value(self._key, current_data, transaction=transaction)
                data = copy.deepcopy(current_data)
                store_data = copy.deepcopy(current_data)

            changed_fields = []
            data_conflicts = []
            for attribute in self._data.keys():
                if attribute == '_version':
                    continue
                if self._data[attribute] != self._original[attribute]:
                    # We changed this value
                    changed_fields.append(attribute)
                    if attribute in data and self._original[attribute] != data[attribute]:
                        # Some other process also wrote to the database
                        if self._datastore_wins is None:
                            # In case we didn't set a policy, we raise the conflicts
                            data_conflicts.append(attribute)
                        elif self._datastore_wins is False:
                            # If the datastore should not win, we just overwrite the data
                            data[attribute] = self._data[attribute]
                        # If the datastore should win, we discard/ignore our change
                    else:
                        # Normal scenario, saving data
                        data[attribute] = self._data[attribute]
                elif attribute not in data:
                    data[attribute] = self._data[attribute]
            for attribute in data.keys():
                if attribute == '_version':
                    continue
                if attribute not in self._data:
                    del data[attribute]
            if data_conflicts:
                raise ConcurrencyException('Got field conflicts while saving {0}. Conflicts: {1}'.format(
                    self._classname, ', '.join(data_conflicts)
                ))

            # Refresh internal data structure
            self._data = copy.deepcopy(data)

            # First, update reverse index
            base_reverse_key = 'ovs_reverseindex_{0}_{1}|{2}|{3}'
            for relation in self._relations:
                key = relation.name
                original_guid = self._original[key]['guid']
                new_guid = self._data[key]['guid']
                if original_guid != new_guid:
                    if relation.foreign_type is None:
                        classname = self.__class__.__name__.lower()
                    else:
                        classname = relation.foreign_type.__name__.lower()
                    if original_guid is not None:
                        reverse_key = base_reverse_key.format(classname, original_guid, relation.foreign_key, self.guid)
                        self._persistent.delete(reverse_key, must_exist=False, transaction=transaction)
                    if new_guid is not None:
                        reverse_key = base_reverse_key.format(classname, new_guid, relation.foreign_key, self.guid)
                        self._persistent.assert_exists('{0}_{1}_{2}'.format(DataObject.NAMESPACE, classname, new_guid))
                        self._persistent.set(reverse_key, 0, transaction=transaction)

            # Second, invalidate property lists
            cache_key = '{0}_{1}|'.format(DataList.CACHELINK, self._classname)
            list_keys = set()
            cache_keys = {}
            for key in list(self._persistent.prefix(cache_key)):
                list_key, field = key.replace(cache_key, '').split('|')
                if list_key not in cache_keys:
                    cache_keys[list_key] = [False, []]
                cache_keys[list_key][1].append(key)
                if field in changed_fields or self._new is True:
                    list_keys.add(list_key)
                    cache_keys[list_key][0] = True
            for list_key in list_keys:
                self._volatile.delete(list_key)
                if cache_keys[list_key][0] is True:
                    for key in cache_keys[list_key][1]:
                        self._persistent.delete(key, must_exist=False, transaction=transaction)

            # Validate unique constraints
            unique_key = 'ovs_unique_{0}_{{0}}_{{1}}'.format(self._classname)
            for prop in self._properties:
                if prop.unique is True:
                    if prop.property_type not in [str, int, float, long]:
                        raise RuntimeError('A unique constraint can only be set on field of type str, int, float, or long')
                    if self._new is False and prop.name in changed_fields:
                        key = unique_key.format(prop.name, hashlib.sha1(str(store_data[prop.name])).hexdigest())
                        self._persistent.assert_value(key, self._key, transaction=transaction)
                        self._persistent.delete(key, transaction=transaction)
                    key = unique_key.format(prop.name, hashlib.sha1(str(self._data[prop.name])).hexdigest())
                    if self._new is True or prop.name in changed_fields:
                        self._persistent.assert_value(key, None, transaction=transaction)
                    self._persistent.set(key, self._key, transaction=transaction)

            if _hook is not None:
                _hook()

            # Save the data
            self._data['_version'] += 1
            try:
                self._mutex_version.acquire(30)
                self._persistent.set(self._key, self._data, transaction=transaction)
                self._persistent.apply_transaction(transaction)
                self._volatile.delete(self._key)
                successful = True
            except KeyNotFoundException as ex:
                if 'ovs_unique' in ex.message and tries == 1:
                    optimistic = False
                elif ex.message != self._key:
                    raise
                else:
                    raise ObjectNotFoundException('{0} with guid \'{1}\' was deleted'.format(
                        self.__class__.__name__, self._guid
                    ))
            except AssertException as ex:
                if 'ovs_unique' in str(ex.message):
                    field = str(ex.message).split('_', 3)[-1].rsplit('_', 1)[0]
                    raise UniqueConstraintViolationException('The unique constraint on {0}.{1} was violated'.format(
                        self.__class__.__name__, field
                    ))
                last_assert = ex
                optimistic = False
                self._mutex_version.release()  # Make sure it's released before a sleep
                time.sleep(randint(0, 25) / 100.0)
            finally:
                self._mutex_version.release()

        self.invalidate_dynamics()
        self._original = copy.deepcopy(self._data)

        self.dirty = False
        self._new = False
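
The save() method above relies on copy.deepcopy to keep three independent snapshots: the original data loaded from the store, the version currently in the store, and the working copy being merged, which is what makes per-field conflict detection possible. A minimal sketch of that three-way merge is below; merge_changes and its arguments are illustrative names, not part of the framework.

import copy

def merge_changes(original, current, store):
    """Apply our changes (original -> current) on top of what is in the
    store, reporting attributes that were changed on both sides."""
    merged = copy.deepcopy(store)        # never mutate the stored snapshot
    conflicts = []
    for key, value in current.items():
        if value != original.get(key):               # we changed this field
            if key in store and store[key] != original.get(key):
                conflicts.append(key)                 # so did someone else
            else:
                merged[key] = value
    return merged, conflicts

orig = {'name': 'vm1', 'size': 10}
ours = {'name': 'vm1', 'size': 20}               # we grew the disk
theirs = {'name': 'vm1-renamed', 'size': 10}     # another process renamed it
merged, conflicts = merge_changes(orig, ours, theirs)
assert merged == {'name': 'vm1-renamed', 'size': 20}
assert conflicts == []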

Example 45

Project: osmc
Source File: xmlfunctions.py
View license
    def writexml( self, profilelist, mainmenuID, groups, numLevels, buildMode, progress, options ): 
        # Reset the hashlist, add the profile list and script version
        hashlist.list = []
        hashlist.list.append( ["::PROFILELIST::", profilelist] )
        hashlist.list.append( ["::SCRIPTVER::", __addonversion__] )
        hashlist.list.append( ["::XBMCVER::", __xbmcversion__] )
        
        # Clear any skin settings for backgrounds and widgets
        DATA._reset_backgroundandwidgets()
        self.widgetCount = 1
        
        # Create a new tree and includes for the various groups
        tree = xmltree.ElementTree( xmltree.Element( "includes" ) )
        root = tree.getroot()
        
        # Create a Template object and pass it the root
        Template = template.Template()
        Template.includes = root
        
        # Get any shortcuts we're checking for
        self.checkForShortcuts = []
        overridestree = DATA._get_overrides_skin()
        if overridestree is not None:
            checkForShorctcutsOverrides = overridestree.getroot().findall( "checkforshortcut" )
            for checkForShortcutOverride in checkForShorctcutsOverrides:
                if "property" in checkForShortcutOverride.attrib:
                    # Add this to the list of shortcuts we'll check for
                    self.checkForShortcuts.append( ( checkForShortcutOverride.text.lower(), checkForShortcutOverride.attrib.get( "property" ), "False" ) )
        
        mainmenuTree = xmltree.SubElement( root, "include" )
        mainmenuTree.set( "name", "skinshortcuts-mainmenu" )
        
        submenuTrees = []
        for level in range( 0,  int( numLevels) + 1 ):
            subtree = xmltree.SubElement( root, "include" )
            if level == 0:
                subtree.set( "name", "skinshortcuts-submenu" )
            else:
                subtree.set( "name", "skinshortcuts-submenu-" + str( level ) )
            if not subtree in submenuTrees:
                submenuTrees.append( subtree )
        
        if buildMode == "single":
            allmenuTree = xmltree.SubElement( root, "include" )
            allmenuTree.set( "name", "skinshortcuts-allmenus" )
        
        profilePercent = 100 / len( profilelist )
        profileCount = -1
        
        submenuNodes = {}
        
        for profile in profilelist:
            log( "Building menu for profile %s" %( profile[ 2 ] ) )
            # Load profile details
            profileDir = profile[0]
            profileVis = profile[1]
            profileCount += 1
            
            # Reset whether we have settings
            self.hasSettings = False
            
            # Reset any checkForShortcuts to say we haven't found them
            newCheckForShortcuts = []
            for checkforShortcut in self.checkForShortcuts:
                newCheckForShortcuts.append( ( checkforShortcut[ 0 ], checkforShortcut[ 1 ], "False" ) )
            self.checkForShortcuts = newCheckForShortcuts

            # Clear any previous labelID's
            DATA._clear_labelID()
            
            # Create objects to hold the items
            menuitems = []
            templateMainMenuItems = xmltree.Element( "includes" )
            
            # If building the main menu, split the mainmenu shortcut nodes into the menuitems list
            if groups == "" or groups.split( "|" )[0] == "mainmenu":
                # Set a skinstring that marks that we're providing the whole menu
                xbmc.executebuiltin( "Skin.SetBool(SkinShortcuts-FullMenu)" )
                for node in DATA._get_shortcuts( "mainmenu", None, True, profile[0] ).findall( "shortcut" ):
                    menuitems.append( node )
            else:
                # Clear any skinstring marking that we're providing the whole menu
                xbmc.executebuiltin( "Skin.Reset(SkinShortcuts-FullMenu)" )
                    
            # If building specific groups, split them into the menuitems list
            count = 0
            if groups != "":
                for group in groups.split( "|" ):
                    if count != 0 or group != "mainmenu":
                        menuitems.append( group )
                        
            if len( menuitems ) == 0:
                # No groups to build
                break
                
            itemidmainmenu = 0
            percent = profilePercent / len( menuitems )
            
            i = 0
            for item in menuitems:
                i += 1
                itemidmainmenu += 1
                progress.update( ( profilePercent * profileCount) + percent * i )
                submenuDefaultID = None

                if not isinstance( item, basestring ):
                    # This is a main menu item (we know this because it's an element, not a string)
                    submenu = item.find( "labelID" ).text

                    # Build the menu item
                    menuitem = self.buildElement( item, "mainmenu", None, profile[1], DATA.slugify( submenu, convertInteger=True ), itemid = itemidmainmenu, options = options )

                    # Add the menu item to the various includes, retaining a reference to them
                    mainmenuItemA = copy.deepcopy( menuitem )
                    mainmenuTree.append( mainmenuItemA )

                    if buildMode == "single":
                        mainmenuItemB = copy.deepcopy( menuitem )
                        allmenuTree.append( mainmenuItemB )

                    templateMainMenuItems.append( copy.deepcopy( menuitem ) )

                    # Get submenu defaultID
                    submenuDefaultID = item.find( "defaultID" ).text
                else:
                    # It's an additional menu, so get its labelID
                    submenu = DATA._get_labelID( item, None )
                    
                # Build the submenu
                count = 0 # Used to keep track of additional submenu
                for submenuTree in submenuTrees:
                    submenuVisibilityName = submenu
                    if count == 1:
                        submenu = submenu + "." + str( count )
                    elif count != 0:
                        submenu = submenu[:-1] + str( count )
                        submenuVisibilityName = submenu[:-2]
                        
                    # Get the tree's we're going to write the menu to
                    if submenu in submenuNodes:
                        justmenuTreeA = submenuNodes[ submenu ][ 0 ]
                        justmenuTreeB = submenuNodes[ submenu ][ 1 ]
                    else:
                        # Create these nodes
                        justmenuTreeA = xmltree.SubElement( root, "include" )
                        justmenuTreeB = xmltree.SubElement( root, "include" )
                        
                        justmenuTreeA.set( "name", "skinshortcuts-group-" + DATA.slugify( submenu ) )
                        justmenuTreeB.set( "name", "skinshortcuts-group-alt-" + DATA.slugify( submenu ) )
                        
                        submenuNodes[ submenu ] = [ justmenuTreeA, justmenuTreeB ]
                        
                    itemidsubmenu = 0
                    
                    # Get the shortcuts for the submenu
                    if count == 0:
                        submenudata = DATA._get_shortcuts( submenu, submenuDefaultID, True, profile[0] )
                    else:
                        submenudata = DATA._get_shortcuts( submenu, None, True, profile[0] )
                        
                    if type( submenudata ) == list:
                        submenuitems = submenudata
                    else:
                        submenuitems = submenudata.findall( "shortcut" )
                    
                    # Are there any submenu items for the main menu?
                    if count == 0:
                        if len( submenuitems ) != 0:
                            try:
                                hasSubMenu = xmltree.SubElement( mainmenuItemA, "property" )
                                hasSubMenu.set( "name", "hasSubmenu" )
                                hasSubMenu.text = "True"
                                if buildMode == "single":
                                    hasSubMenu = xmltree.SubElement( mainmenuItemB, "property" )
                                    hasSubMenu.set( "name", "hasSubmenu" )
                                    hasSubMenu.text = "True"
                            except:
                                # There probably isn't a main menu
                                pass
                        else:   
                            try:
                                hasSubMenu = xmltree.SubElement( mainmenuItemA, "property" )
                                hasSubMenu.set( "name", "hasSubmenu" )
                                hasSubMenu.text = "False"
                                if buildMode == "single":
                                    hasSubMenu = xmltree.SubElement( mainmenuItemB, "property" )
                                    hasSubMenu.set( "name", "hasSubmenu" )
                                    hasSubMenu.text = "False"
                            except:
                                # There probably isn't a main menu
                                pass
                
                    # If we're building a single menu, update the onclicks of the main menu
                    if buildMode == "single" and not len( submenuitems ) == 0:
                        for onclickelement in mainmenuItemB.findall( "onclick" ):
                            if "condition" in onclickelement.attrib:
                                onclickelement.set( "condition", "StringCompare(Window(10000).Property(submenuVisibility)," + DATA.slugify( submenuVisibilityName, convertInteger=True ) + ") + [" + onclickelement.attrib.get( "condition" ) + "]" )
                                newonclick = xmltree.SubElement( mainmenuItemB, "onclick" )
                                newonclick.text = "SetProperty(submenuVisibility," + DATA.slugify( submenuVisibilityName, convertInteger=True ) + ",10000)"
                                newonclick.set( "condition", onclickelement.attrib.get( "condition" ) )
                            else:
                                onclickelement.set( "condition", "StringCompare(Window(10000).Property(submenuVisibility)," + DATA.slugify( submenuVisibilityName, convertInteger=True ) + ")" )
                                newonclick = xmltree.SubElement( mainmenuItemB, "onclick" )
                                newonclick.text = "SetProperty(submenuVisibility," + DATA.slugify( submenuVisibilityName, convertInteger=True ) + ",10000)"
                    
                    # Build the submenu items
                    for submenuItem in submenuitems:
                        itemidsubmenu += 1
                        # Build the item without any visibility conditions
                        menuitem = self.buildElement( submenuItem, submenu, None, profile[1], itemid = itemidsubmenu, options = options )
                        isSubMenuElement = xmltree.SubElement( menuitem, "property" )
                        isSubMenuElement.set( "name", "isSubmenu" )
                        isSubMenuElement.text = "True"

                        # Add it, with appropriate visibility conditions, to the various submenu includes
                        justmenuTreeA.append( copy.deepcopy( menuitem ) )

                        menuitemCopy = copy.deepcopy( menuitem )
                        visibilityElement = menuitemCopy.find( "visible" )
                        visibilityElement.text = "[%s] + %s" %( visibilityElement.text, "StringCompare(Window(10000).Property(submenuVisibility)," + DATA.slugify( submenuVisibilityName, convertInteger=True ) + ")" )
                        justmenuTreeB.append( menuitemCopy )

                        if buildMode == "single":
                            allmenuTree.append( copy.deepcopy( menuitemCopy ) )

                        menuitemCopy = copy.deepcopy( menuitem )
                        visibilityElement = menuitemCopy.find( "visible" )
                        visibilityElement.text = "[%s] + %s" %( visibilityElement.text, "StringCompare(Container(" + mainmenuID + ").ListItem.Property(submenuVisibility)," + DATA.slugify( submenuVisibilityName, convertInteger=True ) + ")" )
                        submenuTree.append( menuitemCopy )
                            
                    # Build the template for the submenu
                    Template.parseItems( "submenu", count, justmenuTreeA, profile[ 2 ], profile[ 1 ], "StringCompare(Container(" + mainmenuID + ").ListItem.Property(submenuVisibility)," + DATA.slugify( submenuVisibilityName, convertInteger=True ) + ")", item )
                        
                    count += 1

            if self.hasSettings == False:
                # Check if the overrides asks for a forced settings...
                overridestree = DATA._get_overrides_skin()
                if overridestree is not None:
                    forceSettings = overridestree.getroot().find( "forcesettings" )
                    if forceSettings is not None:
                        # We want a settings option to be added
                        newelement = xmltree.SubElement( mainmenuTree, "item" )
                        xmltree.SubElement( newelement, "label" ).text = "$LOCALIZE[10004]"
                        xmltree.SubElement( newelement, "icon" ).text = "DefaultShortcut.png"
                        xmltree.SubElement( newelement, "onclick" ).text = "ActivateWindow(settings)" 
                        xmltree.SubElement( newelement, "visible" ).text = profile[1]
                        
                        if buildMode == "single":
                            newelement = xmltree.SubElement( mainmenuTree, "item" )
                            xmltree.SubElement( newelement, "label" ).text = "$LOCALIZE[10004]"
                            xmltree.SubElement( newelement, "icon" ).text = "DefaultShortcut.png"
                            xmltree.SubElement( newelement, "onclick" ).text = "ActivateWindow(settings)" 
                            xmltree.SubElement( newelement, "visible" ).text = profile[1]
                            
            if len( self.checkForShortcuts ) != 0:
                # Add a value to the variable for all checkForShortcuts
                for checkForShortcut in self.checkForShortcuts:
                    if profile[ 1 ] is not None and xbmc.getCondVisibility( profile[ 1 ] ):
                        # Current profile - set the skin bool
                        if checkForShortcut[ 2 ] == "True":
                            xbmc.executebuiltin( "Skin.SetBool(%s)" %( checkForShortcut[ 1 ] ) )
                        else:
                            xbmc.executebuiltin( "Skin.Reset(%s)" %( checkForShortcut[ 1 ] ) )
                    # Save this to the hashes file, so we can set it on profile changes
                    hashlist.list.append( [ "::SKINBOOL::", [ profile[ 1 ], checkForShortcut[ 1 ], checkForShortcut[ 2 ] ] ] )

            # Build the template for the main menu
            Template.parseItems( "mainmenu", 0, templateMainMenuItems, profile[ 2 ], profile[ 1 ], "", "", mainmenuID )
                    
        # Build any 'Other' templates
        Template.writeOthers()
        
        progress.update( 100, message = __language__( 32098 ) )
                
        # Get the skins addon.xml file
        addonpath = xbmc.translatePath( os.path.join( "special://skin/", 'addon.xml').encode("utf-8") ).decode("utf-8")
        addon = xmltree.parse( addonpath )
        extensionpoints = addon.findall( "extension" )
        paths = []
        for extensionpoint in extensionpoints:
            if extensionpoint.attrib.get( "point" ) == "xbmc.gui.skin":
                resolutions = extensionpoint.findall( "res" )
                for resolution in resolutions:
                    path = xbmc.translatePath( os.path.join( self.skinDir, resolution.attrib.get( "folder" ), "script-skinshortcuts-includes.xml").encode("utf-8") ).decode("utf-8")
                    paths.append( path )
        skinVersion = addon.getroot().attrib.get( "version" )
        
        # Save the tree
        DATA.indent( tree.getroot() )
        for path in paths:
            tree.write( path, encoding="UTF-8" )
            
            # Save the hash of the file we've just written
            with open(path, "r+") as f:
                DATA._save_hash( path, f.read() )
                f.close()
            
        # Save the hashes
        # Append the skin version to the hashlist
        hashlist.list.append( ["::SKINVER::", skinVersion] )

        # Save the hashes
        file = xbmcvfs.File( os.path.join( __masterpath__ , xbmc.getSkinDir() + ".hash" ), "w" )
        file.write( repr( hashlist.list ) )
        file.close()
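
Every copy.deepcopy call in the example above exists because an ElementTree element can only live under one parent: the same built menu item has to be appended to several different include trees, and some of those copies are then given their own visibility conditions. A minimal sketch of that pattern, using only the standard xml.etree API (add_to_includes is an illustrative name):

import copy
import xml.etree.ElementTree as ET

def add_to_includes(menuitem, includes, visibility=None):
    """Append an independent deep copy of `menuitem` to every include,
    optionally stamping a per-include visibility condition on the copy."""
    for include in includes:
        item = copy.deepcopy(menuitem)   # each tree gets its own element
        if visibility is not None:
            item.find('visible').text = visibility
        include.append(item)

root = ET.Element('includes')
group_a = ET.SubElement(root, 'include', name='skinshortcuts-group-a')
group_b = ET.SubElement(root, 'include', name='skinshortcuts-group-alt-a')

menuitem = ET.Element('item')
ET.SubElement(menuitem, 'label').text = 'Movies'
ET.SubElement(menuitem, 'visible').text = 'true'

add_to_includes(menuitem, [group_a],
                visibility='String.IsEmpty(Window(10000).Property(submenuVisibility))')
add_to_includes(menuitem, [group_b])
assert group_a[0].find('visible').text != group_b[0].find('visible').text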

Example 46

Project: pygmi
Source File: igrf.py
View license
    def settings(self):
        """
        Settings Dialog. This is the main entry point into this routine. It also
        contains the main IGRF code.
        """
# Variable declaration
# Control variables
        self.proj.set_current(self.indata['Raster'][0].wkt)

        data = dp.merge(self.indata['Raster'])
        self.combobox_dtm.clear()
        self.combobox_mag.clear()
        for i in data:
            self.combobox_dtm.addItem(i.dataid)
            self.combobox_mag.addItem(i.dataid)

        tmp = self.exec_()

        if tmp == 1:
            self.acceptall()
            tmp = True
        else:
            return False

#        again = 1
        mdf = open(__file__.rpartition('\\')[0]+'\\IGRF11.cof')
        modbuff = mdf.readlines()
        fileline = -1                            # First line will be 1
        model = []
        epoch = []
        max1 = []
        max2 = []
        max3 = []
        yrmin = []
        yrmax = []
        altmin = []
        altmax = []
        irec_pos = []
# First model will be 0
        for i in modbuff:
            fileline += 1  # On new line
            if i[:3] == '   ':
                i2 = i.split()
                model.append(i2[0])
                epoch.append(float(i2[1]))
                max1.append(int(i2[2]))
                max2.append(int(i2[3]))
                max3.append(int(i2[4]))
                yrmin.append(float(i2[5]))
                yrmax.append(float(i2[6]))
                altmin.append(float(i2[7]))
                altmax.append(float(i2[8]))
                irec_pos.append(fileline)

        i = self.combobox_mag.currentIndex()
        maggrid = data[i]

        i = self.combobox_dtm.currentIndex()
        data = data[i]
        altgrid = data.data.flatten() * 0.001  # in km

        maxyr = max(yrmax)
        sdate = self.dateedit.date()
        sdate = sdate.year()+sdate.dayOfYear()/sdate.daysInYear()
        alt = self.dsb_alt.value()
        xrange = data.tlx+data.xdim/2.+np.arange(data.cols)*data.xdim
        yrange = data.tly-data.ydim/2.-np.arange(data.rows)*data.ydim
        xdat, ydat = np.meshgrid(xrange, yrange)
        xdat = xdat.flatten()
        ydat = ydat.flatten()

        igrf_F = altgrid * 0
        # Pick model
        yrmax = np.array(yrmax)
        modelI = sum(yrmax < sdate)
        igdgc = 1

        if (sdate > maxyr) and (sdate < maxyr+1):
            print("\nWarning: The date %4.2f is out of range,\n", sdate)
            print("but still within one year of model expiration date.\n")
            print("An updated model file is available before 1.1.%4.0f\n",
                  maxyr)

        if max2[modelI] == 0:
            self.getshc(modbuff, 1, irec_pos[modelI], max1[modelI], 0)
            self.getshc(modbuff, 1, irec_pos[modelI+1], max1[modelI+1], 1)
            nmax = self.interpsh(sdate, yrmin[modelI], max1[modelI],
                                 yrmin[modelI+1], max1[modelI+1], 2)
            nmax = self.interpsh(sdate+1, yrmin[modelI], max1[modelI],
                                 yrmin[modelI+1], max1[modelI+1], 3)
        else:
            self.getshc(modbuff, 1, irec_pos[modelI], max1[modelI], 0)
            self.getshc(modbuff, 0, irec_pos[modelI], max2[modelI], 1)
            nmax = self.extrapsh(sdate, epoch[modelI], max1[modelI],
                                 max2[modelI], 2)
            nmax = self.extrapsh(sdate+1, epoch[modelI], max1[modelI],
                                 max2[modelI], 3)

        progress = 0
        maxlen = xdat.size

        for i in self.pbar.iter(range(maxlen)):
            if igrf_F.mask[i] == True:
                continue

            tmp = int(i*100/maxlen)
            if tmp > progress:
                progress = tmp
#                self.reportback('Calculation: ' + str(progress) + '%', True)

            longitude, latitude, _ = self.ctrans.TransformPoint(xdat[i],
                                                                ydat[i])
            alt = altgrid[i]

# Do the first calculations
            self.shval3(igdgc, latitude, longitude, alt, nmax, 3)
            self.dihf(3)
#            self.shval3(igdgc, latitude, longitude, alt, nmax, 4)
#            self.dihf(4)
#
#            RAD2DEG = (180.0/np.pi)
#
#            self.ddot = ((self.dtemp - self.d)*RAD2DEG)
#            if self.ddot > 180.0:
#                self.ddot -= 360.0
#            if self.ddot <= -180.0:
#                self.ddot += 360.0
#            self.ddot *= 60.0
#
#            self.idot = ((self.itemp - self.i)*RAD2DEG)*60
#            self.d = self.d*(RAD2DEG)
#            self.i = self.i*(RAD2DEG)
#            self.hdot = self.htemp - self.h
#            self.xdot = self.xtemp - self.x
#            self.ydot = self.ytemp - self.y
#            self.zdot = self.ztemp - self.z
#            self.fdot = self.ftemp - self.f
#
#          # deal with geographic and magnetic poles
#
#            if self.h < 100.0:  # at magnetic poles
#                self.d = np.nan
#                self.ddot = np.nan
#              # while rest is ok
#
#            if 90.0-abs(latitude) <= 0.001:  # at geographic poles
#                self.x = np.nan
#                self.y = np.nan
#                self.d = np.nan
#                self.xdot = np.nan
#                self.ydot = np.nan
#                self.ddot = np.nan

#            print('Test Data')
#            print('==========')
#
#            print('# Date 2014.5', sdate)
#            print('# Coord-System D')
#            print('# Altitude K100', alt)
#            print('# Latitude 70.3', latitude)
#            print('# Longitude 30.8', longitude)
#            print('# D_deg 13d', int(self.d))
#            print('# D_min 51m ', int((self.d-int(self.d))*60))
#            print('# I_deg 78d', int(self.i))
#            print('# I_min 55m', int((self.i-int(self.i))*60))
#            print('# H_nT 9987.9 {0:.1f}'.format(self.h))
#            print('# X_nT 9697.4 {0:.1f}'.format(self.x))
#            print('# Y_nT 2391.4 {0:.1f}'.format(self.y))
#            print('# Z_nT 51022.3 {0:.1f}'.format(self.z))
#            print('# F_nT 51990.7 {0:.1f}'.format(self.f))
#            print('# dD_min 10.9 {0:.1f}'.format(self.ddot))
#            print('# dI_min 1.0 {0:.1f}'.format(self.idot))
#            print('# dH_nT -10.4 {0:.1f}'.format(self.hdot))
#            print('# dX_nT -17.7 {0:.1f}'.format(self.xdot))
#            print('# dY_nT 28.1 {0:.1f}'.format(self.ydot))
#            print('# dZ_nT 29.0 {0:.1f}'.format(self.zdot))
#            print('# dF_nT 26.5 {0:.1f}'.format(self.fdot))
            igrf_F[i] = self.f

        self.outdata['Raster'] = copy.deepcopy(self.indata['Raster'])
        igrf_F = np.ma.array(igrf_F)
        igrf_F.shape = data.data.shape
        igrf_F.mask = data.data.mask
        self.outdata['Raster'].append(copy.deepcopy(data))
        self.outdata['Raster'][-1].data = igrf_F
        self.outdata['Raster'][-1].dataid = 'IGRF'
        self.outdata['Raster'].append(copy.deepcopy(maggrid))
        self.outdata['Raster'][-1].data -= igrf_F
        self.outdata['Raster'][-1].dataid = 'Magnetic Data: IGRF Corrected'

        self.reportback('Calculation: Completed', True)

        return True
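
In the example above copy.deepcopy keeps self.indata read-only: the output raster list starts as a deep copy of the input, and each grid that receives derived data (the IGRF field, the corrected magnetic data) is itself a deep copy before its .data and .dataid are overwritten. A minimal sketch of that idea with a toy grid class (Grid and derive_output are illustrative, not the pygmi types):

import copy
import numpy as np

class Grid(object):
    """Toy stand-in for a raster band."""
    def __init__(self, dataid, data):
        self.dataid = dataid
        self.data = data

def derive_output(indata, correction):
    """Build an output list of grids without mutating the input grids."""
    outdata = copy.deepcopy(indata)              # inputs stay pristine
    corrected = copy.deepcopy(indata[0])
    corrected.data = corrected.data - correction
    corrected.dataid = 'Magnetic Data: IGRF Corrected'
    outdata.append(corrected)
    return outdata

mag = Grid('mag', np.array([51990.7, 52001.2]))
out = derive_output([mag], correction=np.array([51990.7, 51990.7]))
assert mag.data[0] == 51990.7                    # input grid untouched
assert out[-1].dataid == 'Magnetic Data: IGRF Corrected'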

Example 47

Project: GAE-Bulk-Mailer
Source File: base.py
View license
    def __new__(cls, name, bases, attrs):
        super_new = super(ModelBase, cls).__new__

        # six.with_metaclass() inserts an extra class called 'NewBase' in the
        # inheritance tree: Model -> NewBase -> object. But the initialization
        # should be executed only once for a given model class.

        # attrs will never be empty for classes declared in the standard way
        # (ie. with the `class` keyword). This is quite robust.
        if name == 'NewBase' and attrs == {}:
            return super_new(cls, name, bases, attrs)

        # Also ensure initialization is only performed for subclasses of Model
        # (excluding Model class itself).
        parents = [b for b in bases if isinstance(b, ModelBase) and
                not (b.__name__ == 'NewBase' and b.__mro__ == (b, object))]
        if not parents:
            return super_new(cls, name, bases, attrs)

        # Create the class.
        module = attrs.pop('__module__')
        new_class = super_new(cls, name, bases, {'__module__': module})
        attr_meta = attrs.pop('Meta', None)
        abstract = getattr(attr_meta, 'abstract', False)
        if not attr_meta:
            meta = getattr(new_class, 'Meta', None)
        else:
            meta = attr_meta
        base_meta = getattr(new_class, '_meta', None)

        if getattr(meta, 'app_label', None) is None:
            # Figure out the app_label by looking one level up.
            # For 'django.contrib.sites.models', this would be 'sites'.
            model_module = sys.modules[new_class.__module__]
            kwargs = {"app_label": model_module.__name__.split('.')[-2]}
        else:
            kwargs = {}

        new_class.add_to_class('_meta', Options(meta, **kwargs))
        if not abstract:
            new_class.add_to_class('DoesNotExist', subclass_exception(str('DoesNotExist'),
                    tuple(x.DoesNotExist
                          for x in parents if hasattr(x, '_meta') and not x._meta.abstract)
                    or (ObjectDoesNotExist,),
                    module, attached_to=new_class))
            new_class.add_to_class('MultipleObjectsReturned', subclass_exception(str('MultipleObjectsReturned'),
                    tuple(x.MultipleObjectsReturned
                          for x in parents if hasattr(x, '_meta') and not x._meta.abstract)
                    or (MultipleObjectsReturned,),
                    module, attached_to=new_class))
            if base_meta and not base_meta.abstract:
                # Non-abstract child classes inherit some attributes from their
                # non-abstract parent (unless an ABC comes before it in the
                # method resolution order).
                if not hasattr(meta, 'ordering'):
                    new_class._meta.ordering = base_meta.ordering
                if not hasattr(meta, 'get_latest_by'):
                    new_class._meta.get_latest_by = base_meta.get_latest_by

        is_proxy = new_class._meta.proxy

        # If the model is a proxy, ensure that the base class
        # hasn't been swapped out.
        if is_proxy and base_meta and base_meta.swapped:
            raise TypeError("%s cannot proxy the swapped model '%s'." % (name, base_meta.swapped))

        if getattr(new_class, '_default_manager', None):
            if not is_proxy:
                # Multi-table inheritance doesn't inherit default manager from
                # parents.
                new_class._default_manager = None
                new_class._base_manager = None
            else:
                # Proxy classes do inherit parent's default manager, if none is
                # set explicitly.
                new_class._default_manager = new_class._default_manager._copy_to_model(new_class)
                new_class._base_manager = new_class._base_manager._copy_to_model(new_class)

        # Bail out early if we have already created this class.
        m = get_model(new_class._meta.app_label, name,
                      seed_cache=False, only_installed=False)
        if m is not None:
            return m

        # Add all attributes to the class.
        for obj_name, obj in attrs.items():
            new_class.add_to_class(obj_name, obj)

        # All the fields of any type declared on this model
        new_fields = new_class._meta.local_fields + \
                     new_class._meta.local_many_to_many + \
                     new_class._meta.virtual_fields
        field_names = set([f.name for f in new_fields])

        # Basic setup for proxy models.
        if is_proxy:
            base = None
            for parent in [cls for cls in parents if hasattr(cls, '_meta')]:
                if parent._meta.abstract:
                    if parent._meta.fields:
                        raise TypeError("Abstract base class containing model fields not permitted for proxy model '%s'." % name)
                    else:
                        continue
                if base is not None:
                    raise TypeError("Proxy model '%s' has more than one non-abstract model base class." % name)
                else:
                    base = parent
            if base is None:
                raise TypeError("Proxy model '%s' has no non-abstract model base class." % name)
            if (new_class._meta.local_fields or
                    new_class._meta.local_many_to_many):
                raise FieldError("Proxy model '%s' contains model fields." % name)
            new_class._meta.setup_proxy(base)
            new_class._meta.concrete_model = base._meta.concrete_model
        else:
            new_class._meta.concrete_model = new_class

        # Do the appropriate setup for any model parents.
        o2o_map = dict([(f.rel.to, f) for f in new_class._meta.local_fields
                if isinstance(f, OneToOneField)])

        for base in parents:
            original_base = base
            if not hasattr(base, '_meta'):
                # Things without _meta aren't functional models, so they're
                # uninteresting parents.
                continue

            parent_fields = base._meta.local_fields + base._meta.local_many_to_many
            # Check for clashes between locally declared fields and those
            # on the base classes (we cannot handle shadowed fields at the
            # moment).
            for field in parent_fields:
                if field.name in field_names:
                    raise FieldError('Local field %r in class %r clashes '
                                     'with field of similar name from '
                                     'base class %r' %
                                        (field.name, name, base.__name__))
            if not base._meta.abstract:
                # Concrete classes...
                base = base._meta.concrete_model
                if base in o2o_map:
                    field = o2o_map[base]
                elif not is_proxy:
                    attr_name = '%s_ptr' % base._meta.module_name
                    field = OneToOneField(base, name=attr_name,
                            auto_created=True, parent_link=True)
                    new_class.add_to_class(attr_name, field)
                else:
                    field = None
                new_class._meta.parents[base] = field
            else:
                # .. and abstract ones.
                for field in parent_fields:
                    new_class.add_to_class(field.name, copy.deepcopy(field))

                # Pass any non-abstract parent classes onto child.
                new_class._meta.parents.update(base._meta.parents)

            # Inherit managers from the abstract base classes.
            new_class.copy_managers(base._meta.abstract_managers)

            # Proxy models inherit the non-abstract managers from their base,
            # unless they have redefined any of them.
            if is_proxy:
                new_class.copy_managers(original_base._meta.concrete_managers)

            # Inherit virtual fields (like GenericForeignKey) from the parent
            # class
            for field in base._meta.virtual_fields:
                if base._meta.abstract and field.name in field_names:
                    raise FieldError('Local field %r in class %r clashes '\
                                     'with field of similar name from '\
                                     'abstract base class %r' % \
                                        (field.name, name, base.__name__))
                new_class.add_to_class(field.name, copy.deepcopy(field))

        if abstract:
            # Abstract base models can't be instantiated and don't appear in
            # the list of models for an app. We do the final setup for them a
            # little differently from normal models.
            attr_meta.abstract = False
            new_class.Meta = attr_meta
            return new_class

        new_class._prepare()
        register_models(new_class._meta.app_label, new_class)

        # Because of the way imports happen (recursively), we may or may not be
        # the first time this model tries to register with the framework. There
        # should only be one class for each model, so we always return the
        # registered version.
        return get_model(new_class._meta.app_label, name,
                         seed_cache=False, only_installed=False)
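
The copy.deepcopy calls in this metaclass exist because field objects are stateful: attaching a field to a model records that model on the field, so letting several concrete subclasses share the single instance declared on an abstract parent would cross-wire them. A minimal sketch of why each child needs its own copy (Field, inherit_fields and contribute_to_class are simplified stand-ins, not Django's real API):

import copy

class Field(object):
    """Toy stateful field: remembers the model it is attached to."""
    def __init__(self, name):
        self.name = name
        self.model = None

    def contribute_to_class(self, cls):
        self.model = cls                 # mutated when attached to a model

def inherit_fields(abstract_fields, *subclasses):
    """Give every subclass its own copy of each inherited field."""
    for cls in subclasses:
        cls.fields = []
        for field in abstract_fields:
            f = copy.deepcopy(field)     # never share the instance
            f.contribute_to_class(cls)
            cls.fields.append(f)

class Article(object): pass
class Comment(object): pass

inherit_fields([Field('created')], Article, Comment)
assert Article.fields[0].model is Article
assert Comment.fields[0].model is Comment    # not clobbered by Article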

Example 48

View license
    def run_cgi(self):
        """Execute a CGI script."""
        dir, rest = self.cgi_info
        path = dir + '/' + rest
        i = path.find('/', len(dir)+1)
        while i >= 0:
            nextdir = path[:i]
            nextrest = path[i+1:]

            scriptdir = self.translate_path(nextdir)
            if os.path.isdir(scriptdir):
                dir, rest = nextdir, nextrest
                i = path.find('/', len(dir)+1)
            else:
                break

        # find an explicit query string, if present.
        rest, _, query = rest.partition('?')

        # dissect the part after the directory name into a script name &
        # a possible additional path, to be stored in PATH_INFO.
        i = rest.find('/')
        if i >= 0:
            script, rest = rest[:i], rest[i:]
        else:
            script, rest = rest, ''

        scriptname = dir + '/' + script
        scriptfile = self.translate_path(scriptname)
        if not os.path.exists(scriptfile):
            self.send_error(404, "No such CGI script (%r)" % scriptname)
            return
        if not os.path.isfile(scriptfile):
            self.send_error(403, "CGI script is not a plain file (%r)" %
                            scriptname)
            return
        ispy = self.is_python(scriptname)
        if not ispy:
            if not (self.have_fork or self.have_popen2 or self.have_popen3):
                self.send_error(403, "CGI script is not a Python script (%r)" %
                                scriptname)
                return
            if not self.is_executable(scriptfile):
                self.send_error(403, "CGI script is not executable (%r)" %
                                scriptname)
                return

        # Reference: http://hoohoo.ncsa.uiuc.edu/cgi/env.html
        # XXX Much of the following could be prepared ahead of time!
        env = copy.deepcopy(os.environ)
        env['SERVER_SOFTWARE'] = self.version_string()
        env['SERVER_NAME'] = self.server.server_name
        env['GATEWAY_INTERFACE'] = 'CGI/1.1'
        env['SERVER_PROTOCOL'] = self.protocol_version
        env['SERVER_PORT'] = str(self.server.server_port)
        env['REQUEST_METHOD'] = self.command
        uqrest = urllib.unquote(rest)
        env['PATH_INFO'] = uqrest
        env['PATH_TRANSLATED'] = self.translate_path(uqrest)
        env['SCRIPT_NAME'] = scriptname
        if query:
            env['QUERY_STRING'] = query
        host = self.address_string()
        if host != self.client_address[0]:
            env['REMOTE_HOST'] = host
        env['REMOTE_ADDR'] = self.client_address[0]
        authorization = self.headers.getheader("authorization")
        if authorization:
            authorization = authorization.split()
            if len(authorization) == 2:
                import base64, binascii
                env['AUTH_TYPE'] = authorization[0]
                if authorization[0].lower() == "basic":
                    try:
                        authorization = base64.decodestring(authorization[1])
                    except binascii.Error:
                        pass
                    else:
                        authorization = authorization.split(':')
                        if len(authorization) == 2:
                            env['REMOTE_USER'] = authorization[0]
        # XXX REMOTE_IDENT
        if self.headers.typeheader is None:
            env['CONTENT_TYPE'] = self.headers.type
        else:
            env['CONTENT_TYPE'] = self.headers.typeheader
        length = self.headers.getheader('content-length')
        if length:
            env['CONTENT_LENGTH'] = length
        referer = self.headers.getheader('referer')
        if referer:
            env['HTTP_REFERER'] = referer
        accept = []
        for line in self.headers.getallmatchingheaders('accept'):
            if line[:1] in "\t\n\r ":
                accept.append(line.strip())
            else:
                accept = accept + line[7:].split(',')
        env['HTTP_ACCEPT'] = ','.join(accept)
        ua = self.headers.getheader('user-agent')
        if ua:
            env['HTTP_USER_AGENT'] = ua
        co = filter(None, self.headers.getheaders('cookie'))
        if co:
            env['HTTP_COOKIE'] = ', '.join(co)
        # XXX Other HTTP_* headers
        # Since we're setting the env in the parent, provide empty
        # values to override previously set values
        for k in ('QUERY_STRING', 'REMOTE_HOST', 'CONTENT_LENGTH',
                  'HTTP_USER_AGENT', 'HTTP_COOKIE', 'HTTP_REFERER'):
            env.setdefault(k, "")

        self.send_response(200, "Script output follows")

        decoded_query = query.replace('+', ' ')

        if self.have_fork:
            # Unix -- fork as we should
            args = [script]
            if '=' not in decoded_query:
                args.append(decoded_query)
            nobody = nobody_uid()
            self.wfile.flush() # Always flush before forking
            pid = os.fork()
            if pid != 0:
                # Parent
                pid, sts = os.waitpid(pid, 0)
                # throw away additional data [see bug #427345]
                while select.select([self.rfile], [], [], 0)[0]:
                    if not self.rfile.read(1):
                        break
                if sts:
                    self.log_error("CGI script exit status %#x", sts)
                return
            # Child
            try:
                try:
                    os.setuid(nobody)
                except os.error:
                    pass
                os.dup2(self.rfile.fileno(), 0)
                os.dup2(self.wfile.fileno(), 1)
                os.execve(scriptfile, args, env)
            except:
                self.server.handle_error(self.request, self.client_address)
                os._exit(127)

        else:
            # Non Unix - use subprocess
            import subprocess
            cmdline = [scriptfile]
            if self.is_python(scriptfile):
                interp = sys.executable
                if interp.lower().endswith("w.exe"):
                    # On Windows, use python.exe, not pythonw.exe
                    interp = interp[:-5] + interp[-4:]
                cmdline = [interp, '-u'] + cmdline
            if '=' not in query:
                cmdline.append(query)

            self.log_message("command: %s", subprocess.list2cmdline(cmdline))
            try:
                nbytes = int(length)
            except (TypeError, ValueError):
                nbytes = 0
            p = subprocess.Popen(cmdline,
                                 stdin = subprocess.PIPE,
                                 stdout = subprocess.PIPE,
                                 stderr = subprocess.PIPE,
                                 env = env
                                )
            if self.command.lower() == "post" and nbytes > 0:
                data = self.rfile.read(nbytes)
            else:
                data = None
            # throw away additional data [see bug #427345]
            while select.select([self.rfile._sock], [], [], 0)[0]:
                if not self.rfile._sock.recv(1):
                    break
            stdout, stderr = p.communicate(data)
            self.wfile.write(stdout)
            if stderr:
                self.log_error('%s', stderr)
            p.stderr.close()
            p.stdout.close()
            status = p.returncode
            if status:
                self.log_error("CGI script exit status %#x", status)
            else:
                self.log_message("CGI script exited OK")

Example 49

Project: ochopod
Source File: marathon.py
View license
    def boot(self, lifecycle, model=Reactive, tools=None, local=False):

        #
        # - quick check to make sure we get the right implementations
        #
        assert issubclass(model, Model), 'model must derive from ochopod.api.Model'
        assert issubclass(lifecycle, LifeCycle), 'lifecycle must derive from ochopod.api.LifeCycle'

        #
        # - instantiate our flask endpoint
        # - default to a json handler for all HTTP errors (including an unexpected 500)
        #
        def _handler(error):
            http = error.code if isinstance(error, HTTPException) else 500
            return '{}', http, {'Content-Type': 'application/json; charset=utf-8'}

        web = Flask(__name__)
        for code in default_exceptions.iterkeys():
            web.error_handler_spec[None][code] = _handler

        #
        # - default presets in case we run outside of marathon (local vm testing)
        # - any environment variable prefixed with "ochopod." is of interest for us (e.g this is what the user puts
        #   in the marathon application configuration for instance)
        # - the other settings come from marathon (namely the port bindings & application/task identifiers)
        # - the MESOS_TASK_ID is important to keep around to enable task deletion via the marathon REST API
        #
        env = \
            {
                'ochopod_application':  '',
                'ochopod_cluster':      'default',
                'ochopod_debug':        'true',
                'ochopod_local':        'false',
                'ochopod_namespace':    'marathon',
                'ochopod_port':         '8080',
                'ochopod_start':        'true',
                'ochopod_task':         '',
                'ochopod_zk':           '',
                'PORT_8080':            '8080'
            }

        env.update(os.environ)
        ochopod.enable_cli_log(debug=env['ochopod_debug'] == 'true')
        try:

            #
            # - grab our environment variables (which are set by the marathon executor)
            # - extract the mesos PORT_* bindings and construct a small remapping dict
            #
            ports = {}
            logger.debug('environment ->\n%s' % '\n'.join(['\t%s -> %s' % (k, v) for k, v in env.items()]))
            for key, val in env.items():
                if key.startswith('PORT_'):
                    ports[key[5:]] = int(val)

            #
            # - keep any "ochopod_" environment variable & trim its prefix
            # - default all our settings, especially the mandatory ones
            # - the ip and zookeeper are defaulted to localhost to enable easy testing
            #
            hints = {k[8:]: v for k, v in env.items() if k.startswith('ochopod_')}
            if local or hints['local'] == 'true':

                #
                # - we are running in local mode (e.g on a dev workstation)
                # - default everything to localhost
                #
                logger.info('running in local mode (make sure you run a standalone zookeeper)')
                hints.update(
                    {
                        'fwk':          'marathon (debug)',
                        'ip':           '127.0.0.1',
                        'node':         'local',
                        'ports':        ports,
                        'public':       '127.0.0.1',
                        'zk':           '127.0.0.1:2181'
                    })
            else:

                #
                # - extend our hints
                # - add the application + task
                #
                hints.update(
                    {
                        'application':  env['MARATHON_APP_ID'][1:],
                        'fwk':          'marathon',
                        'ip':           '',
                        'node':         '',
                        'ports':        ports,
                        'public':       '',
                        'task':         env['MESOS_TASK_ID'],
                        'zk':           ''
                    })

                #
                # - use whatever subclass is implementing us to infer 'ip', 'node' and 'public'
                #
                hints.update(self.get_node_details())

                #
                # - lookup for the zookeeper connection string from environment variable or on disk
                # - we have to look into different places depending on how mesos was installed
                #
                def _1():

                    #
                    # - most recent DCOS release
                    # - $MESOS_MASTER is located in /opt/mesosphere/etc/mesos-slave-common
                    # - the snippet in there is prefixed by MESOS_MASTER=zk://<ip:port>/mesos
                    #
                    logger.debug('checking /opt/mesosphere/etc/mesos-slave-common...')
                    _, lines = shell("grep MESOS_MASTER /opt/mesosphere/etc/mesos-slave-common")
                    return lines[0][13:]

                def _2():

                    #
                    # - same as above except for slightly older DCOS releases
                    # - $MESOS_MASTER is located in /opt/mesosphere/etc/mesos-slave
                    #
                    logger.debug('checking /opt/mesosphere/etc/mesos-slave...')
                    _, lines = shell("grep MESOS_MASTER /opt/mesosphere/etc/mesos-slave")
                    return lines[0][13:]

                def _3():

                    #
                    # - a regular package install will write the slave settings under /etc/mesos/zk (the snippet in
                    #   there looks like zk://10.0.0.56:2181/mesos)
                    #
                    logger.debug('checking /etc/mesos/zk...')
                    _, lines = shell("cat /etc/mesos/zk")
                    return lines[0]

                def _4():

                    #
                    # - look for ZK from environment variables
                    # - user can pass down ZK using $ochopod_zk
                    # - this last-resort situation is used mostly for debugging
                    #
                    logger.debug('checking $ochopod_zk environment variable...')
                    return env['ochopod_zk']

                #
                # - depending on how the slave has been installed we might have to look in various places
                #   to find out what our zookeeper connection string is
                # - use urlparse to keep the host:port part of the URL (possibly including a login+password)
                #
                for method in [_1, _2, _3, _4]:
                    try:
                        hints['zk'] = urlparse(method()).netloc
                        break

                    except Exception:
                        pass

            #
            # - the cluster must be fully qualified with a namespace (which is defaulted anyway)
            #
            assert hints['zk'], 'unable to determine where zookeeper is located (unsupported/bogus mesos setup ?)'
            assert hints['cluster'] and hints['namespace'], 'no cluster and/or namespace defined (user error ?)'

            #
            # - load the tools
            #
            if tools:
                tools = {tool.tag: tool for tool in [clz() for clz in tools if issubclass(clz, Tool)] if tool.tag}
                logger.info('supporting tools %s' % ', '.join(tools.keys()))

            #
            # - start the life-cycle actor which will pass our hints (as a json object) to its underlying sub-process
            # - start our coordinator which will connect to zookeeper and attempt to lead the cluster
            # - upon grabbing the lock the model actor will start and implement the configuration process
            # - the hints are a convenient bag for any data that may change at runtime and needs to be returned (via
            #   the HTTP POST /info request)
            # - what's being registered in zookeeper is immutable though and decorated with additional details by
            #   the coordinator (especially the pod index which is derived from zookeeper)
            #
            latch = ThreadingFuture()
            logger.info('starting %s.%s (marathon) @ %s' % (hints['namespace'], hints['cluster'], hints['node']))
            breadcrumbs = deepcopy(hints)
            hints['metrics'] = {}
            hints['dependencies'] = model.depends_on
            env.update({'ochopod': json.dumps(hints)})
            executor = lifecycle.start(env, latch, hints)
            coordinator = Coordinator.start(
                hints['zk'].split(','),
                hints['namespace'],
                hints['cluster'],
                int(hints['port']),
                breadcrumbs,
                model,
                hints)

            #
            # - external hook forcing a coordinator reset
            # - this will force a re-connection to zookeeper and pod registration
            # - please note this will not impact the pod lifecycle (e.g the underlying sub-process will be
            #   left running)
            #
            @web.route('/reset', methods=['POST'])
            def _reset():

                logger.debug('http in -> /reset')
                coordinator.tell({'request': 'reset'})
                return '{}', 200, {'Content-Type': 'application/json; charset=utf-8'}

            #
            # - external hook exposing information about our pod
            # - this is a subset of what's registered in zookeeper at boot-time
            # - the data is dynamic and updated from time to time by the model and executor actors
            # - from @pferro -> the pod's dependencies defined in the model are now added as well
            #
            @web.route('/info', methods=['POST'])
            def _info():

                logger.debug('http in -> /info')
                keys = \
                    [
                        'application',
                        'dependencies',
                        'ip',
                        'metrics',
                        'node',
                        'port',
                        'ports',
                        'process',
                        'public',
                        'state',
                        'status',
                        'task'
                    ]

                subset = dict(filter(lambda i: i[0] in keys, hints.iteritems()))
                return json.dumps(subset), 200, {'Content-Type': 'application/json; charset=utf-8'}

            #
            # - external hook exposing our circular log
            # - reverse and dump ochopod.log as a json array
            #
            @web.route('/log', methods=['POST'])
            def _log():

                logger.debug('http in -> /log')
                with open(ochopod.LOG, 'r+') as log:
                    lines = [line for line in log]
                    return json.dumps(lines), 200, {'Content-Type': 'application/json; charset=utf-8'}

            #
            # - RPC call to run a custom tool within the pod
            #
            @web.route('/exec', methods=['POST'])
            def _exec():

                logger.debug('http in -> /exec')

                #
                # - make sure the command (first token in the X-Shell header) maps to a tool
                # - if no match abort on a 404
                #
                line = request.headers['X-Shell']
                tokens = line.split(' ')
                cmd = tokens[0]
                if not tools or cmd not in tools:
                    return '{}', 404, {'Content-Type': 'application/json; charset=utf-8'}

                code = 1
                tool = tools[cmd]

                #
                # - make sure the parser does not sys.exit()
                #
                class _Parser(ArgumentParser):
                    def exit(self, status=0, message=None):
                        raise ValueError(message)

                #
                # - prep a temporary directory
                # - invoke define_cmdline_parsing()
                # - switch off parsing if NotImplementedError is raised
                #
                use_parser = 1
                parser = _Parser(prog=tool.tag)
                try:
                    tool.define_cmdline_parsing(parser)

                except NotImplementedError:
                    use_parser = 0

                tmp = tempfile.mkdtemp()
                try:

                    #
                    # - parse the command line
                    # - upload any attachment
                    #
                    args = parser.parse_args(tokens[1:]) if use_parser else ' '.join(tokens[1:])
                    for tag, upload in request.files.items():
                        where = path.join(tmp, tag)
                        logger.debug('uploading %s @ %s' % (tag, tmp))
                        upload.save(where)

                    #
                    # - run the tool method
                    # - pass the temporary directory as well
                    #
                    logger.info('invoking "%s"' % line)
                    code, lines = tool.body(args, tmp)

                except ValueError as failure:

                    lines = [parser.format_help() if failure.message is None else failure.message]

                except Exception as failure:

                    lines = ['unexpected failure -> %s' % failure]

                finally:

                    #
                    # - make sure to cleanup our temporary directory
                    #
                    shutil.rmtree(tmp)

                out = \
                    {
                        'code': code,
                        'stdout': lines
                    }

                return json.dumps(out), 200, {'Content-Type': 'application/json; charset=utf-8'}

            #
            # - web-hook used to receive requests from the leader or the CLI tools
            # - those requests are passed down to the executor actor
            # - any non HTTP 200 response is a failure
            # - failure to acknowledge within the specified timeout will result in a HTTP 408 (REQUEST TIMEOUT)
            # - attempting to send a control request to a dead pod will result in a HTTP 410 (GONE)
            #
            @web.route('/control/<task>', methods=['POST'])
            @web.route('/control/<task>/<timeout>', methods=['POST'])
            def _control(task, timeout='60'):

                logger.debug('http in -> /control/%s' % task)
                if task not in ['check', 'on', 'off', 'ok', 'kill', 'signal']:

                    #
                    # - fail on a HTTP 400 if the request is not supported
                    #
                    return '{}', 400, {'Content-Type': 'application/json; charset=utf-8'}

                try:

                    ts = time.time()
                    latch = ThreadingFuture()
                    executor.tell({'request': task, 'latch': latch, 'data': request.data})
                    js, code = latch.get(timeout=int(timeout))
                    ms = time.time() - ts
                    logger.debug('http out -> HTTP %s (%d ms)' % (code, ms))
                    return json.dumps(js), code, {'Content-Type': 'application/json; charset=utf-8'}

                except Timeout:

                    #
                    # - we failed to match the specified timeout
                    # - gracefully fail on a HTTP 408
                    #
                    return '{}', 408, {'Content-Type': 'application/json; charset=utf-8'}

                except ActorDeadError:

                    #
                    # - the executor has been shutdown (probably after a /control/kill)
                    # - gracefully fail on a HTTP 410
                    #
                    return '{}', 410, {'Content-Type': 'application/json; charset=utf-8'}

            #
            # - internal hook required to shutdown the web-server
            # - it's not possible to do it outside of a request handler
            # - make sure this call only comes from localhost (todo)
            #
            @web.route('/terminate', methods=['POST'])
            def _terminate():

                request.environ.get('werkzeug.server.shutdown')()
                return '{}', 200, {'Content-Type': 'application/json; charset=utf-8'}

            #
            # - run werkzeug from a separate thread to avoid blocking the main one
            # - we'll have to shut it down using a dedicated HTTP POST
            #
            class _Runner(threading.Thread):

                def run(self):
                    web.run(host='0.0.0.0', port=int(hints['port']), threaded=True)

            try:

                #
                # - block on the lifecycle actor until it goes down (usually after a /control/kill request)
                #
                _Runner().start()
                spin_lock(latch)
                logger.debug('pod is dead, idling')
                while 1:

                    #
                    # - simply idle forever (since the framework would restart any container that terminates)
                    # - /log and /hints HTTP requests will succeed (and show the pod as being killed)
                    # - any control request will now fail
                    #
                    time.sleep(60.0)

            finally:

                #
                # - when we exit the block, first shut down our executor (which is probably already down)
                # - then shutdown the coordinator to un-register from zookeeper
                # - finally ask werkzeug to shutdown via a REST call
                #
                shutdown(executor)
                shutdown(coordinator)
                post('http://127.0.0.1:%s/terminate' % env['ochopod_port'])

        except KeyboardInterrupt:

            logger.fatal('CTRL-C pressed')

        except Exception as failure:

            logger.fatal('unexpected condition -> %s' % diagnostic(failure))
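
The deepcopy call at the heart of this example is breadcrumbs = deepcopy(hints): the pod takes a snapshot of its hints dictionary before mutating it (adding 'metrics' and 'dependencies') and handing it to the executor and coordinator actors, so the snapshot registered in zookeeper stays immutable while the live dict keeps changing at runtime. Below is a minimal sketch of that snapshot-before-mutation pattern; the names are illustrative only, not ochopod's API.

from copy import deepcopy

hints = {'cluster': 'default', 'ports': {'8080': 31000}, 'metrics': {}}

# deepcopy() gives the registration path its own copy of the nested dict,
# so later in-place updates to hints cannot leak into what was persisted
breadcrumbs = deepcopy(hints)

hints['metrics']['requests'] = 42       # runtime mutation of the live dict
assert breadcrumbs['metrics'] == {}     # the snapshot is unaffected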

Example 50

Project: pupil
Source File: main.py
View license
def session(rec_dir):

    system_plugins = [Log_Display,Seek_Bar,Trim_Marks]
    vis_plugins = sorted([Vis_Circle,Vis_Polyline,Vis_Light_Points,Vis_Cross,Vis_Watermark,Eye_Video_Overlay,Scan_Path], key=lambda x: x.__name__)
    analysis_plugins = sorted([Gaze_Position_2D_Fixation_Detector,Pupil_Angle_3D_Fixation_Detector,Pupil_Angle_3D_Fixation_Detector,Manual_Gaze_Correction,Video_Export_Launcher,Offline_Surface_Tracker,Raw_Data_Exporter,Batch_Exporter,Annotation_Player], key=lambda x: x.__name__)
    other_plugins = sorted([Show_Calibration,Log_History], key=lambda x: x.__name__)
    user_plugins = sorted(import_runtime_plugins(os.path.join(user_dir,'plugins')), key=lambda x: x.__name__)
    user_launchable_plugins = vis_plugins + analysis_plugins + other_plugins + user_plugins
    available_plugins = system_plugins + user_launchable_plugins
    name_by_index = [p.__name__ for p in available_plugins]
    index_by_name = dict(zip(name_by_index,range(len(name_by_index))))
    plugin_by_name = dict(zip(name_by_index,available_plugins))


    # Callback functions
    def on_resize(window,w, h):
        g_pool.gui.update_window(w,h)
        g_pool.gui.collect_menus()
        graph.adjust_size(w,h)
        adjust_gl_view(w,h)
        for p in g_pool.plugins:
            p.on_window_resize(window,w,h)

    def on_key(window, key, scancode, action, mods):
        g_pool.gui.update_key(key,scancode,action,mods)

    def on_char(window,char):
        g_pool.gui.update_char(char)

    def on_button(window,button, action, mods):
        g_pool.gui.update_button(button,action,mods)
        pos = glfwGetCursorPos(window)
        pos = normalize(pos,glfwGetWindowSize(window))
        pos = denormalize(pos,(frame.img.shape[1],frame.img.shape[0]) ) # Position in img pixels
        for p in g_pool.plugins:
            p.on_click(pos,button,action)

    def on_pos(window,x, y):
        hdpi_factor = float(glfwGetFramebufferSize(window)[0]/glfwGetWindowSize(window)[0])
        g_pool.gui.update_mouse(x*hdpi_factor,y*hdpi_factor)

    def on_scroll(window,x,y):
        g_pool.gui.update_scroll(x,y*y_scroll_factor)


    def on_drop(window,count,paths):
        for x in range(count):
            new_rec_dir =  paths[x]
            if is_pupil_rec_dir(new_rec_dir):
                logger.debug("Starting new session with '%s'"%new_rec_dir)
                global rec_dir
                rec_dir = new_rec_dir
                glfwSetWindowShouldClose(window,True)
            else:
                logger.error("'%s' is not a valid pupil recording"%new_rec_dir)




    tick = delta_t()
    def get_dt():
        return next(tick)

    update_recording_to_recent(rec_dir)

    video_path = [f for f in glob(os.path.join(rec_dir,"world.*")) if f[-3:] in ('mp4','mkv','avi')][0]
    timestamps_path = os.path.join(rec_dir, "world_timestamps.npy")
    pupil_data_path = os.path.join(rec_dir, "pupil_data")

    meta_info = load_meta_info(rec_dir)
    rec_version = read_rec_version(meta_info)
    app_version = get_version(version_file)

    # log info about Pupil Platform and Platform in player.log
    logger.info('Application Version: %s'%app_version)
    logger.info('System Info: %s'%get_system_info())

    timestamps = np.load(timestamps_path)

    # create container for globally scoped vars
    g_pool = Global_Container()
    g_pool.app = 'player'

    # Initialize capture
    cap = File_Source(g_pool,video_path,timestamps=list(timestamps))

    # load session persistent settings
    session_settings = Persistent_Dict(os.path.join(user_dir,"user_settings"))
    if session_settings.get("version",VersionFormat('0.0')) < get_version(version_file):
        logger.info("Session settings are from an older version of this app. I will not use those.")
        session_settings.clear()

    width,height = session_settings.get('window_size',cap.frame_size)
    window_pos = session_settings.get('window_position',(0,0))
    main_window = glfwCreateWindow(width, height, "Pupil Player: "+meta_info["Recording Name"]+" - "+ rec_dir.split(os.path.sep)[-1], None, None)
    glfwSetWindowPos(main_window,window_pos[0],window_pos[1])
    glfwMakeContextCurrent(main_window)
    cygl.utils.init()

    # load pupil_positions, gaze_positions
    pupil_data = load_object(pupil_data_path)
    pupil_list = pupil_data['pupil_positions']
    gaze_list = pupil_data['gaze_positions']

    g_pool.binocular = meta_info.get('Eye Mode','monocular') == 'binocular'
    g_pool.version = app_version
    g_pool.capture = cap
    g_pool.timestamps = timestamps
    g_pool.play = False
    g_pool.new_seek = True
    g_pool.user_dir = user_dir
    g_pool.rec_dir = rec_dir
    g_pool.rec_version = rec_version
    g_pool.meta_info = meta_info
    g_pool.min_data_confidence = session_settings.get('min_data_confidence',0.6)
    g_pool.pupil_positions_by_frame = correlate_data(pupil_list,g_pool.timestamps)
    g_pool.gaze_positions_by_frame = correlate_data(gaze_list,g_pool.timestamps)
    g_pool.fixations_by_frame = [[] for x in g_pool.timestamps] #populated by the fixation detector plugin

    def next_frame(_):
        try:
            cap.seek_to_frame(cap.get_frame_index())
        except FileSeekError:
            logger.warning("Could not seek to next frame.")
        else:
            g_pool.new_seek = True

    def prev_frame(_):
        try:
            cap.seek_to_frame(cap.get_frame_index()-2)
        except FileSeekError:
            logger.warning("Could not seek to previous frame.")
        else:
            g_pool.new_seek = True

    def toggle_play(new_state):
        if cap.get_frame_index() >= cap.get_frame_count()-5:
            cap.seek_to_frame(1) #avoid pause set by hitting trimmark pause.
            logger.warning("End of video - restart at beginning.")
        g_pool.play = new_state

    def set_scale(new_scale):
        g_pool.gui.scale = new_scale
        g_pool.gui.collect_menus()

    def set_data_confidence(new_confidence):
        g_pool.min_data_confidence = new_confidence
        notification = {'subject':'min_data_confidence_changed'}
        notification['_notify_time_'] = time()+.8
        g_pool.delayed_notifications[notification['subject']] = notification

    def open_plugin(plugin):
        if plugin ==  "Select to load":
            return
        g_pool.plugins.add(plugin)

    def purge_plugins():
        for p in g_pool.plugins:
            if p.__class__ in user_launchable_plugins:
                p.alive = False
        g_pool.plugins.clean()

    def do_export(_):
        export_range = slice(g_pool.trim_marks.in_mark,g_pool.trim_marks.out_mark)
        export_dir = os.path.join(g_pool.rec_dir,'exports','%s-%s'%(export_range.start,export_range.stop))
        try:
            os.makedirs(export_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                logger.error("Could not create export dir")
                raise e
            else:
                logger.warning("Previous export for range [%s-%s] already exists - overwriting."%(export_range.start,export_range.stop))
        else:
            logger.info('Created export dir at "%s"'%export_dir)

        notification = {'subject':'should_export','range':export_range,'export_dir':export_dir}
        g_pool.notifications.append(notification)

    g_pool.gui = ui.UI()
    g_pool.gui.scale = session_settings.get('gui_scale',1)
    g_pool.main_menu = ui.Scrolling_Menu("Settings",pos=(-350,20),size=(300,500))
    g_pool.main_menu.append(ui.Button("Close Pupil Player",lambda:glfwSetWindowShouldClose(main_window,True)))
    g_pool.main_menu.append(ui.Slider('scale',g_pool.gui, setter=set_scale,step = .05,min=0.75,max=2.5,label='Interface Size'))
    g_pool.main_menu.append(ui.Info_Text('Player Version: %s'%g_pool.version))
    g_pool.main_menu.append(ui.Info_Text('Recording Version: %s'%rec_version))
    g_pool.main_menu.append(ui.Slider('min_data_confidence',g_pool, setter=set_data_confidence,step=.05 ,min=0.0,max=1.0,label='Confidence threshold'))

    selector_label = "Select to load"

    vis_labels = ["   " + p.__name__.replace('_',' ') for p in vis_plugins]
    analysis_labels = ["   " + p.__name__.replace('_',' ') for p in analysis_plugins]
    other_labels = ["   " + p.__name__.replace('_',' ') for p in other_plugins]
    user_labels = ["   " + p.__name__.replace('_',' ') for p in user_plugins]

    plugins = [selector_label, selector_label] + vis_plugins + [selector_label] + analysis_plugins + [selector_label] + other_plugins + [selector_label] + user_plugins
    labels = [selector_label, "Visualization"] + vis_labels + ["Analysis"] + analysis_labels + ["Other"] + other_labels + ["User added"] + user_labels

    g_pool.main_menu.append(ui.Selector('Open plugin:',
                                        selection = plugins,
                                        labels    = labels,
                                        setter    = open_plugin,
                                        getter    = lambda: selector_label))

    g_pool.main_menu.append(ui.Button('Close all plugins',purge_plugins))
    g_pool.main_menu.append(ui.Button('Reset window size',lambda: glfwSetWindowSize(main_window,cap.frame_size[0],cap.frame_size[1])) )
    g_pool.quickbar = ui.Stretching_Menu('Quick Bar',(0,100),(120,-100))
    g_pool.play_button = ui.Thumb('play',g_pool,label=unichr(0xf04b).encode('utf-8'),setter=toggle_play,hotkey=GLFW_KEY_SPACE,label_font='fontawesome',label_offset_x=5,label_offset_y=0,label_offset_size=-24)
    g_pool.play_button.on_color[:] = (0,1.,.0,.8)
    g_pool.forward_button = ui.Thumb('forward',label=unichr(0xf04e).encode('utf-8'),getter = lambda: False,setter= next_frame, hotkey=GLFW_KEY_RIGHT,label_font='fontawesome',label_offset_x=5,label_offset_y=0,label_offset_size=-24)
    g_pool.backward_button = ui.Thumb('backward',label=unichr(0xf04a).encode('utf-8'),getter = lambda: False, setter = prev_frame, hotkey=GLFW_KEY_LEFT,label_font='fontawesome',label_offset_x=-5,label_offset_y=0,label_offset_size=-24)
    g_pool.export_button = ui.Thumb('export',label=unichr(0xf063).encode('utf-8'),getter = lambda: False, setter = do_export, hotkey='e',label_font='fontawesome',label_offset_x=0,label_offset_y=2,label_offset_size=-24)
    g_pool.quickbar.extend([g_pool.play_button,g_pool.forward_button,g_pool.backward_button,g_pool.export_button])
    g_pool.gui.append(g_pool.quickbar)
    g_pool.gui.append(g_pool.main_menu)


    #we always load these plugins
    system_plugins = [('Trim_Marks',{}),('Seek_Bar',{})]
    default_plugins = [('Log_Display',{}),('Scan_Path',{}),('Vis_Polyline',{}),('Vis_Circle',{}),('Video_Export_Launcher',{})]
    previous_plugins = session_settings.get('loaded_plugins',default_plugins)
    g_pool.notifications = []
    g_pool.delayed_notifications = {}
    g_pool.plugins = Plugin_List(g_pool,plugin_by_name,system_plugins+previous_plugins)


    # Register callbacks main_window
    glfwSetFramebufferSizeCallback(main_window,on_resize)
    glfwSetKeyCallback(main_window,on_key)
    glfwSetCharCallback(main_window,on_char)
    glfwSetMouseButtonCallback(main_window,on_button)
    glfwSetCursorPosCallback(main_window,on_pos)
    glfwSetScrollCallback(main_window,on_scroll)
    glfwSetDropCallback(main_window,on_drop)
    #trigger on_resize
    on_resize(main_window, *glfwGetFramebufferSize(main_window))

    g_pool.gui.configuration = session_settings.get('ui_config',{})

    # gl_state settings
    basic_gl_setup()
    g_pool.image_tex = Named_Texture()

    #set up performance graphs:
    pid = os.getpid()
    ps = psutil.Process(pid)
    ts = None

    cpu_graph = graph.Bar_Graph()
    cpu_graph.pos = (20,110)
    cpu_graph.update_fn = ps.cpu_percent
    cpu_graph.update_rate = 5
    cpu_graph.label = 'CPU %0.1f'

    fps_graph = graph.Bar_Graph()
    fps_graph.pos = (140,110)
    fps_graph.update_rate = 5
    fps_graph.label = "%0.0f REC FPS"

    pupil_graph = graph.Bar_Graph(max_val=1.0)
    pupil_graph.pos = (260,110)
    pupil_graph.update_rate = 5
    pupil_graph.label = "Confidence: %0.2f"

    while not glfwWindowShouldClose(main_window):


        #grab new frame
        if g_pool.play or g_pool.new_seek:
            g_pool.new_seek = False
            try:
                new_frame = cap.get_frame_nowait()
            except EndofVideoFileError:
                #end of video logic: pause at last frame.
                g_pool.play=False
                logger.warning("end of video")
            update_graph = True
        else:
            update_graph = False


        frame = new_frame.copy()
        events = {}
        #report time between now and the last loop iteration
        events['dt'] = get_dt()
        #new positions: we make a deepcopy just like the image is a copy.
        events['gaze_positions'] = deepcopy(g_pool.gaze_positions_by_frame[frame.index])
        events['pupil_positions'] = deepcopy(g_pool.pupil_positions_by_frame[frame.index])

        if update_graph:
            #update performance graphs
            for p in  events['pupil_positions']:
                pupil_graph.add(p['confidence'])

            t = new_frame.timestamp
            if ts and ts != t:
                dt,ts = t-ts,t
                fps_graph.add(1./dt)

            g_pool.play_button.status_text = str(frame.index)
        #always update the CPU graph
        cpu_graph.update()


        # publish delayed notifications when their time has come.
        for n in g_pool.delayed_notifications.values():
            if n['_notify_time_'] < time():
                del n['_notify_time_']
                del g_pool.delayed_notifications[n['subject']]
                g_pool.notifications.append(n)

        # notify each plugin if there are new notifications:
        while g_pool.notifications:
            n = g_pool.notifications.pop(0)
            for p in g_pool.plugins:
                p.on_notify(n)

        # allow each Plugin to do its work.
        for p in g_pool.plugins:
            p.update(frame,events)

        #check if a plugin needs to be destroyed
        g_pool.plugins.clean()

        # render camera image
        glfwMakeContextCurrent(main_window)
        make_coord_system_norm_based()
        g_pool.image_tex.update_from_frame(frame)
        g_pool.image_tex.draw()
        make_coord_system_pixel_based(frame.img.shape)
        # render visual feedback from loaded plugins
        for p in g_pool.plugins:
            p.gl_display()

        graph.push_view()
        fps_graph.draw()
        cpu_graph.draw()
        pupil_graph.draw()
        graph.pop_view()
        g_pool.gui.update()

        #present frames at appropriate speed
        cap.wait(frame)

        glfwSwapBuffers(main_window)
        glfwPollEvents()

    session_settings['loaded_plugins'] = g_pool.plugins.get_initializers()
    session_settings['min_data_confidence'] = g_pool.min_data_confidence
    session_settings['gui_scale'] = g_pool.gui.scale
    session_settings['ui_config'] = g_pool.gui.configuration
    session_settings['window_size'] = glfwGetWindowSize(main_window)
    session_settings['window_position'] = glfwGetWindowPos(main_window)
    session_settings['version'] = g_pool.version
    session_settings.close()

    # de-init all running plugins
    for p in g_pool.plugins:
        p.alive = False
    g_pool.plugins.clean()

    cap.cleanup()
    g_pool.gui.terminate()
    glfwDestroyWindow(main_window)
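
Here deepcopy appears in the main playback loop: the cached per-frame lists (g_pool.gaze_positions_by_frame and g_pool.pupil_positions_by_frame) are deep-copied into the events dict before the plugins run, so a plugin can annotate or discard entries without corrupting the cache that will be reused the next time the same frame is visited. A minimal sketch of that idea follows, with hypothetical names standing in for Pupil's plugin machinery.

from copy import deepcopy

# one cached list of gaze datums per frame, built once at load time
gaze_by_frame = [[{'norm_pos': (0.5, 0.5), 'confidence': 0.9}]]

def dispatch(frame_index, plugins):
    # each dispatch hands the plugins their own deep copy of the cached data
    events = {'gaze_positions': deepcopy(gaze_by_frame[frame_index])}
    for plugin in plugins:
        plugin(events)                  # plugins may freely mutate their copy
    return events

dispatch(0, [lambda ev: ev['gaze_positions'].pop()])
assert gaze_by_frame[0]                 # the cache still holds the original datum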