os.path.exists

Here are examples of the Python API os.path.exists, taken from open source projects.
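
For orientation before the examples, a minimal sketch of the call itself (the paths are illustrative): os.path.exists(path) returns True if path names an existing file or directory, and False otherwise.

import os

# True for an existing file or directory; symbolic links are followed.
print(os.path.exists("/etc/hosts"))

# False for a missing path. A broken symlink also reports False, and
# the call may return False if the process cannot stat() the path.
print(os.path.exists("/no/such/path"))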

200 Examples

Example 1

Project: firewalld
Source File: zone.py
def zone_writer(zone, path=None):
    _path = path if path else zone.path

    if zone.filename:
        name = "%s/%s" % (_path, zone.filename)
    else:
        name = "%s/%s.xml" % (_path, zone.name)

    if os.path.exists(name):
        try:
            shutil.copy2(name, "%s.old" % name)
        except Exception as msg:
            log.error("Backup of file '%s' failed: %s", name, msg)

    dirpath = os.path.dirname(name)
    if dirpath.startswith(ETC_FIREWALLD) and not os.path.exists(dirpath):
        if not os.path.exists(ETC_FIREWALLD):
            os.mkdir(ETC_FIREWALLD, 0o750)
        os.mkdir(dirpath, 0o750)

    f = io.open(name, mode='wt', encoding='UTF-8')
    handler = IO_Object_XMLGenerator(f)
    handler.startDocument()

    # start zone element
    attrs = {}
    if zone.version and zone.version != "":
        attrs["version"] = zone.version
    if zone.target != DEFAULT_ZONE_TARGET:
        attrs["target"] = zone.target
    handler.startElement("zone", attrs)
    handler.ignorableWhitespace("\n")

    # short
    if zone.short and zone.short != "":
        handler.ignorableWhitespace("  ")
        handler.startElement("short", { })
        handler.characters(zone.short)
        handler.endElement("short")
        handler.ignorableWhitespace("\n")

    # description
    if zone.description and zone.description != "":
        handler.ignorableWhitespace("  ")
        handler.startElement("description", { })
        handler.characters(zone.description)
        handler.endElement("description")
        handler.ignorableWhitespace("\n")

    # interfaces
    for interface in uniqify(zone.interfaces):
        handler.ignorableWhitespace("  ")
        handler.simpleElement("interface", { "name": interface })
        handler.ignorableWhitespace("\n")

    # source
    for source in uniqify(zone.sources):
        handler.ignorableWhitespace("  ")
        if "ipset:" in source:
            handler.simpleElement("source", { "ipset": source[6:] })
        else:
            handler.simpleElement("source", { "address": source })
        handler.ignorableWhitespace("\n")

    # services
    for service in uniqify(zone.services):
        handler.ignorableWhitespace("  ")
        handler.simpleElement("service", { "name": service })
        handler.ignorableWhitespace("\n")

    # ports
    for port in uniqify(zone.ports):
        handler.ignorableWhitespace("  ")
        handler.simpleElement("port", { "port": port[0], "protocol": port[1] })
        handler.ignorableWhitespace("\n")

    # protocols
    for protocol in uniqify(zone.protocols):
        handler.ignorableWhitespace("  ")
        handler.simpleElement("protocol", { "value": protocol })
        handler.ignorableWhitespace("\n")

    # icmp-block-inversion
    if zone.icmp_block_inversion:
        handler.ignorableWhitespace("  ")
        handler.simpleElement("icmp-block-inversion", { })
        handler.ignorableWhitespace("\n")

    # icmp-blocks
    for icmp in uniqify(zone.icmp_blocks):
        handler.ignorableWhitespace("  ")
        handler.simpleElement("icmp-block", { "name": icmp })
        handler.ignorableWhitespace("\n")

    # masquerade
    if zone.masquerade:
        handler.ignorableWhitespace("  ")
        handler.simpleElement("masquerade", { })
        handler.ignorableWhitespace("\n")

    # forward-ports
    for forward in uniqify(zone.forward_ports):
        handler.ignorableWhitespace("  ")
        attrs = { "port": forward[0], "protocol": forward[1] }
        if forward[2] and forward[2] != "":
            attrs["to-port"] = forward[2]
        if forward[3] and forward[3] != "":
            attrs["to-addr"] = forward[3]
        handler.simpleElement("forward-port", attrs)
        handler.ignorableWhitespace("\n")

    # source-ports
    for port in uniqify(zone.source_ports):
        handler.ignorableWhitespace("  ")
        handler.simpleElement("source-port", { "port": port[0],
                                               "protocol": port[1] })
        handler.ignorableWhitespace("\n")

    # rules
    for rule in zone.rules:
        attrs = { }
        if rule.family:
            attrs["family"] = rule.family
        handler.ignorableWhitespace("  ")
        handler.startElement("rule", attrs)
        handler.ignorableWhitespace("\n")

        # source
        if rule.source:
            attrs = { }
            if rule.source.addr:
                attrs["address"] = rule.source.addr
            if rule.source.mac:
                attrs["mac"] = rule.source.mac
            if rule.source.ipset:
                attrs["ipset"] = rule.source.ipset
            if rule.source.invert:
                attrs["invert"] = "True"
            handler.ignorableWhitespace("    ")
            handler.simpleElement("source", attrs)
            handler.ignorableWhitespace("\n")

        # destination
        if rule.destination:
            attrs = { "address": rule.destination.addr }
            if rule.destination.invert:
                attrs["invert"] = "True"
            handler.ignorableWhitespace("    ")
            handler.simpleElement("destination", attrs)
            handler.ignorableWhitespace("\n")

        # element
        if rule.element:
            element = ""
            attrs = { }

            if type(rule.element) == rich.Rich_Service:
                element = "service"
                attrs["name"] = rule.element.name
            elif type(rule.element) == rich.Rich_Port:
                element = "port"
                attrs["port"] = rule.element.port
                attrs["protocol"] = rule.element.protocol
            elif type(rule.element) == rich.Rich_Protocol:
                element = "protocol"
                attrs["value"] = rule.element.value
            elif type(rule.element) == rich.Rich_Masquerade:
                element = "masquerade"
            elif type(rule.element) == rich.Rich_IcmpBlock:
                element = "icmp-block"
                attrs["name"] = rule.element.name
            elif type(rule.element) == rich.Rich_ForwardPort:
                element = "forward-port"
                attrs["port"] = rule.element.port
                attrs["protocol"] = rule.element.protocol
                if rule.element.to_port != "":
                    attrs["to-port"] = rule.element.to_port
                if rule.element.to_address != "":
                    attrs["to-addr"] = rule.element.to_address
            elif type(rule.element) == rich.Rich_SourcePort:
                element = "source-port"
                attrs["port"] = rule.element.port
                attrs["protocol"] = rule.element.protocol
            else:
                log.warning("Unknown element '%s'", type(rule.element))

            handler.ignorableWhitespace("    ")
            handler.simpleElement(element, attrs)
            handler.ignorableWhitespace("\n")

        # rule.element

        # log
        if rule.log:
            attrs = { }
            if rule.log.prefix:
                attrs["prefix"] = rule.log.prefix
            if rule.log.level:
                attrs["level"] = rule.log.level
            if rule.log.limit:
                handler.ignorableWhitespace("    ")
                handler.startElement("log", attrs)
                handler.ignorableWhitespace("\n      ")
                handler.simpleElement("limit",
                                      { "value": rule.log.limit.value })
                handler.ignorableWhitespace("\n    ")
                handler.endElement("log")
            else:
                handler.ignorableWhitespace("    ")
                handler.simpleElement("log", attrs)
            handler.ignorableWhitespace("\n")

        # audit
        if rule.audit:
            attrs = {}
            if rule.audit.limit:
                handler.ignorableWhitespace("    ")
                handler.startElement("audit", { })
                handler.ignorableWhitespace("\n      ")
                handler.simpleElement("limit",
                                      { "value": rule.audit.limit.value })
                handler.ignorableWhitespace("\n    ")
                handler.endElement("audit")
            else:
                handler.ignorableWhitespace("    ")
                handler.simpleElement("audit", attrs)
            handler.ignorableWhitespace("\n")

        # action
        if rule.action:
            action = ""
            attrs = { }
            if type(rule.action) == rich.Rich_Accept:
                action = "accept"
            elif type(rule.action) == rich.Rich_Reject:
                action = "reject"
                if rule.action.type:
                    attrs["type"] = rule.action.type
            elif type(rule.action) == rich.Rich_Drop:
                action = "drop"
            elif type(rule.action) == rich.Rich_Mark:
                action = "mark"
                attrs["set"] = rule.action.set
            else:
                log.warning("Unknown action '%s'", type(rule.action))
            if rule.action.limit:
                handler.ignorableWhitespace("    ")
                handler.startElement(action, attrs)
                handler.ignorableWhitespace("\n      ")
                handler.simpleElement("limit",
                                      { "value": rule.action.limit.value })
                handler.ignorableWhitespace("\n    ")
                handler.endElement(action)
            else:
                handler.ignorableWhitespace("    ")
                handler.simpleElement(action, attrs)
            handler.ignorableWhitespace("\n")

        handler.ignorableWhitespace("  ")
        handler.endElement("rule")
        handler.ignorableWhitespace("\n")

    # end zone element
    handler.endElement("zone")
    handler.ignorableWhitespace("\n")
    handler.endDocument()
    f.close()
    del handler
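
A side note on the pattern at the top of zone_writer(): checking os.path.exists() before os.mkdir() leaves a small window in which another process can create the directory first, making the mkdir raise. A minimal alternative sketch (not part of the firewalld source) avoids the separate check:

import os

# Creates any missing intermediate directories in one call; the path is
# illustrative. exist_ok=True suppresses the error if it already exists.
os.makedirs("zones/custom", mode=0o750, exist_ok=True)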

Example 2

Project: pyriscope
Source File: processor.py
def process(args):
    # Make sure there are args, do a primary check for help.
    if len(args) == 0 or args[0] in ARGLIST_HELP:
        show_help()

    # Defaults arg flag settings.
    url_parts_list = []
    ffmpeg = True
    convert = False
    clean = False
    rotate = False
    agent_mocking = False
    name = ""
    live_duration = ""
    req_headers = {}

    # Check for ffmpeg.
    if shutil.which("ffmpeg") is None:
        ffmpeg = False

    # Read in args and set appropriate flags.
    cont = None
    for i in range(len(args)):
        if cont == ARGLIST_NAME:
            if args[i][0] in ('\'', '\"'):
                if args[i][-1:] == args[i][0]:
                    cont = None
                    name = args[i][1:-1]
                else:
                    cont = args[i][0]
                    name = args[i][1:]
            else:
                cont = None
                name = args[i]
            continue
        if cont in ('\'', '\"'):
            if args[i][-1:] == cont:
                cont = None
                name += " {}".format(args[i][:-1])
            else:
                name += " {}".format(args[i])
            continue
        if cont == ARGLIST_TIME:
            cont = None
            live_duration = args[i]

        if re.search(URL_PATTERN, args[i]) is not None:
            url_parts_list.append(dissect_url(args[i]))
        if args[i] in ARGLIST_HELP:
            show_help()
        if args[i] in ARGLIST_CONVERT:
            convert = True
        if args[i] in ARGLIST_CLEAN:
            convert = True
            clean = True
        if args[i] in ARGLIST_ROTATE:
            convert = True
            rotate = True
        if args[i] in ARGLIST_AGENTMOCK:
            agent_mocking = True
        if args[i] in ARGLIST_NAME:
            cont = ARGLIST_NAME
        if args[i] in ARGLIST_TIME:
            cont = ARGLIST_TIME


    # Check for URLs found.
    if len(url_parts_list) < 1:
        print("\nError: No valid URLs entered.")
        sys.exit(1)

    # Disable conversion/rotation if ffmpeg is not found.
    if convert and not ffmpeg:
        print("ffmpeg not found: Disabling conversion/rotation.")
        convert = False
        clean = False
        rotate = False

    # Set a mocked user agent.
    if agent_mocking:
        stdout("Getting mocked User-Agent.")
        req_headers['User-Agent'] = get_mocked_user_agent()
    else:
        req_headers['User-Agent'] = DEFAULT_UA


    url_count = 0
    for url_parts in url_parts_list:
        url_count += 1

        # Disable custom naming for multiple URLs.
        if len(url_parts_list) > 1:
            name = ""

        # Public Periscope API call to get information about the broadcast.
        if url_parts['token'] == "":
            req_url = PERISCOPE_GETBROADCAST.format("broadcast_id", url_parts['broadcast_id'])
        else:
            req_url = PERISCOPE_GETBROADCAST.format("token", url_parts['token'])

        stdout("Downloading broadcast information.")
        response = requests.get(req_url, headers=req_headers)
        broadcast_public = json.loads(response.text)

        if 'success' in broadcast_public and broadcast_public['success'] == False:
            print("\nError: Video expired/deleted/wasn't found: {}".format(url_parts['url']))
            continue

        # Loaded the correct JSON. Create file name.
        if name[-3:] == ".ts":
            name = name[:-3]
        if name[-4:] == ".mp4":
            name = name[:-4]
        if name == "":
            broadcast_start_time_end = broadcast_public['broadcast']['start'].rfind('.')
            timezone = broadcast_public['broadcast']['start'][broadcast_start_time_end:]
            timezone_start = timezone.rfind('-') if timezone.rfind('-') != -1 else timezone.rfind('+')
            timezone = timezone[timezone_start:].replace(':', '')
            to_zone = tz.tzlocal()
            broadcast_start_time = broadcast_public['broadcast']['start'][:broadcast_start_time_end]
            broadcast_start_time = "{}{}".format(broadcast_start_time, timezone)
            broadcast_start_time_dt = datetime.strptime(broadcast_start_time, '%Y-%m-%dT%H:%M:%S%z')
            broadcast_start_time_dt = broadcast_start_time_dt.astimezone(to_zone)
            broadcast_start_time = "{}-{:02d}-{:02d} {:02d}-{:02d}-{:02d}".format(
                broadcast_start_time_dt.year, broadcast_start_time_dt.month, broadcast_start_time_dt.day,
                broadcast_start_time_dt.hour, broadcast_start_time_dt.minute, broadcast_start_time_dt.second)
            name = "{} ({})".format(broadcast_public['broadcast']['username'], broadcast_start_time)

        name = sanitize(name)

        # Get ready to start capturing.
        if broadcast_public['broadcast']['state'] == 'RUNNING':
            # Cannot record live stream without ffmpeg.
            if not ffmpeg:
                print("\nError: Cannot record live stream without ffmpeg: {}".format(url_parts['url']))
                continue

            # The stream is live, start live capture.
            name = "{}.live".format(name)

            if url_parts['token'] == "":
                req_url = PERISCOPE_GETACCESS.format("broadcast_id", url_parts['broadcast_id'])
            else:
                req_url = PERISCOPE_GETACCESS.format("token", url_parts['token'])

            stdout("Downloading live stream information.")
            response = requests.get(req_url, headers=req_headers)
            access_public = json.loads(response.text)

            if 'success' in access_public and access_public['success'] == False:
                print("\nError: Video expired/deleted/wasn't found: {}".format(url_parts['url']))
                continue

            time_argument = ""
            if not live_duration == "":
                time_argument = " -t {}".format(live_duration)

            live_url = FFMPEG_LIVE.format(
                url_parts['url'],
                req_headers['User-Agent'],
                access_public['hls_url'],
                time_argument,
                name)

            # Start downloading live stream.
            stdout("Recording stream to {}.ts".format(name))

            Popen(live_url, shell=True, stdout=PIPE).stdout.read()

            stdoutnl("{}.ts Downloaded!".format(name))

            # Convert video to .mp4.
            if convert:
                stdout("Converting to {}.mp4".format(name))

                if rotate:
                    Popen(FFMPEG_ROT.format(name), shell=True, stdout=PIPE).stdout.read()
                else:
                    Popen(FFMPEG_NOROT.format(name), shell=True, stdout=PIPE).stdout.read()

                stdoutnl("Converted to {}.mp4!".format(name))

                if clean and os.path.exists("{}.ts".format(name)):
                    os.remove("{}.ts".format(name))
            continue

        else:
            if not broadcast_public['broadcast']['available_for_replay']:
                print("\nError: Replay unavailable: {}".format(url_parts['url']))
                continue

            # Broadcast replay is available.
            if url_parts['token'] == "":
                req_url = PERISCOPE_GETACCESS.format("broadcast_id", url_parts['broadcast_id'])
            else:
                req_url = PERISCOPE_GETACCESS.format("token", url_parts['token'])

            stdout("Downloading replay information.")
            response = requests.get(req_url, headers=req_headers)
            access_public = json.loads(response.text)

            if 'success' in access_public and access_public['success'] == False:
                print("\nError: Video expired/deleted/wasn't found: {}".format(url_parts['url']))
                continue

            base_url = access_public['replay_url']
            base_url_parts = dissect_replay_url(base_url)

            req_headers['Cookie'] = "{}={};{}={};{}={}".format(access_public['cookies'][0]['Name'],
                                                               access_public['cookies'][0]['Value'],
                                                               access_public['cookies'][1]['Name'],
                                                               access_public['cookies'][1]['Value'],
                                                               access_public['cookies'][2]['Name'],
                                                               access_public['cookies'][2]['Value'])
            req_headers['Host'] = "replay.periscope.tv"

            # Get the list of chunks to download.
            stdout("Downloading chunk list.")
            response = requests.get(access_public['replay_url'], headers=req_headers)
            chunks = response.text
            chunk_pattern = re.compile(r'chunk_\d+\.ts')

            download_list = []
            for chunk in re.findall(chunk_pattern, chunks):
                download_list.append(
                    {
                        'url': REPLAY_URL.format(base_url_parts['key'], chunk),
                        'file_name': chunk
                    }
                )

            # Download chunk .ts files and append them.
            pool = ThreadPool(name, DEFAULT_DL_THREADS, len(download_list))

            temp_dir_name = ".pyriscope.{}".format(name)
            if not os.path.exists(temp_dir_name):
                os.makedirs(temp_dir_name)

            stdout("Downloading replay {}.ts.".format(name))

            for chunk_info in download_list:
                temp_file_path = "{}/{}".format(temp_dir_name, chunk_info['file_name'])
                chunk_info['file_path'] = temp_file_path
                pool.add_task(download_chunk, chunk_info['url'], req_headers, temp_file_path)

            pool.wait_completion()

            if os.path.exists("{}.ts".format(name)):
                try:
                    os.remove("{}.ts".format(name))
                except:
                    stdoutnl("Failed to delete preexisting {}.ts.".format(name))

            with open("{}.ts".format(name), 'wb') as handle:
                for chunk_info in download_list:
                    file_path = chunk_info['file_path']
                    if not os.path.exists(file_path) or os.path.getsize(file_path) == 0:
                        break
                    with open(file_path, 'rb') as ts_file:
                        handle.write(ts_file.read())

            # don't delete temp if the download had missing chunks, just in case
            if pool.is_complete() and os.path.exists(temp_dir_name):
                try:
                    shutil.rmtree(temp_dir_name)
                except:
                    stdoutnl("Failed to delete temp folder: {}.".format(temp_dir_name))

            if pool.is_complete():
                stdoutnl("{}.ts Downloaded!".format(name))
            else:
                stdoutnl("{}.ts partially Downloaded!".format(name))

            # Convert video to .mp4.
            if convert:
                stdout("Converting to {}.mp4".format(name))

                if rotate:
                    Popen(FFMPEG_ROT.format(name), shell=True, stdout=PIPE).stdout.read()
                else:
                    Popen(FFMPEG_NOROT.format(name), shell=True, stdout=PIPE).stdout.read()

                stdoutnl("Converted to {}.mp4!".format(name))

                if clean and os.path.exists("{}.ts".format(name)):
                    try:
                        os.remove("{}.ts".format(name))
                    except:
                        stdout("Failed to delete {}.ts.".format(name))

    sys.exit(0)
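
The exists-then-remove sequence above (checking os.path.exists("{}.ts".format(name)) before os.remove()) has the same time-of-check/time-of-use caveat: the file can disappear between the two calls. A hedged alternative sketch that tolerates a missing file without a prior check:

import contextlib
import os

# Same intent as: if os.path.exists(path): os.remove(path)
# but with no window between the check and the removal.
with contextlib.suppress(FileNotFoundError):
    os.remove("example.ts")  # illustrative file name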

Example 3

Project: deepnl
Source File: dl-ner.py
def main():

    # set the seed for replicability
    np.random.seed(89) #(42)

    defaults = {}
    
    parser = argparse.ArgumentParser(description="Train or use a Named Entity tagger.")
    
    parser.add_argument('-c', '--config', dest='config_file',
                        help='Specify config file', metavar='FILE')

    # args, remaining_argv = parser.parse_known_args()

    # if args.config_file:
    #     config = ConfigParser.SafeConfigParser()
    #     config.read([args.config_file])
    #     defaults = dict(config.items('Defaults'))

    # parser.set_defaults(**defaults)

    parser.add_argument('model', type=str,
                        help='Model file to train/use.')

    # training options
    train = parser.add_argument_group('Train')
    train.add_argument('-t', '--train', type=str, default='',
                        help='File with annotated data for training.')
    train.add_argument('-w', '--window', type=int, default=2,
                        help='Size of the word window (default %(default)s)')
    train.add_argument('-s', '--embeddings-size', type=int, default=50,
                        help='Number of features per word (default %(default)s)',
                        dest='embeddings_size')
    train.add_argument('-e', '--epochs', type=int, default=100,
                        help='Number of training epochs (default %(default)s)',
                        dest='iterations')
    train.add_argument('-l', '--learning_rate', type=float, default=0.001,
                        help='Learning rate for network weights (default %(default)s)',
                        dest='learning_rate')
    train.add_argument('-n', '--hidden', type=int, default=200,
                        help='Number of hidden neurons (default %(default)s)',
                        dest='hidden')
    train.add_argument('--eps', type=float, default=1e-6,
                        help='Epsilon value for AdaGrad (default %(default)s)')
    train.add_argument('--ro', type=float, default=0.95,
                        help='Ro value for AdaDelta (default %(default)s)')
    train.add_argument('-o', '--output', type=str, default='',
                        help='File where to save embeddings')

    # Embeddings
    embeddings = parser.add_argument_group('Embeddings')
    embeddings.add_argument('--vocab', type=str, default='',
                        help='Vocabulary file, either read or created')
    embeddings.add_argument('--vocab-size', type=int, default=0,
                            help='Maximum size of vocabulary from corpus (default %(default)s)')
    embeddings.add_argument('--vectors', type=str, default='',
                        help='Embeddings file, either read or created')
    embeddings.add_argument('--min-occurr', type=int, default=3,
                        help='Minimum occurrences for inclusion in vocabulary (default %(default)s)',
                        dest='minOccurr')
    embeddings.add_argument('--load', type=str, default='',
                        help='Load previously saved model')
    embeddings.add_argument('--variant', type=str, default='',
                        help='Either "senna" (default), "polyglot" or "word2vec".')

    # Extractors:
    extractors = parser.add_argument_group('Extractors')
    extractors.add_argument('--caps', const=5, nargs='?', type=int, default=None,
                        help='Include capitalization features. Optionally, supply the number of features (default %(default)s)')
    extractors.add_argument('--pos', const=1, type=int, nargs='?', default=None,
                        help='Use POS tag. Optionally supply the POS token field index (default %(default)s)')
    extractors.add_argument('--suffix', const=5, nargs='?', type=int, default=None,
                            help='Include suffix features. Optionally, supply the number of features (default %(default)s)')
    extractors.add_argument('--suffixes', type=str, default='',
                        help='Load suffixes from this file')
    extractors.add_argument('--prefix', const=5, nargs='?', type=int, default=None,
                            help='Include prefix features. Optionally, '\
                            'supply the number of features (default %(default)s)')
    extractors.add_argument('--prefixes', type=str, default='',
                        help='Load prefixes from this file')
    extractors.add_argument('--gazetteer', type=str,
                        help='Load gazetteer from this file')
    extractors.add_argument('--gsize', type=int, default=5,
                        help='Size of gazetteer features (default %(default)s)')

    # reader
    parser.add_argument('--form-field', type=int, default=0,
                        help='Token field containing form (default %(default)s)',
                        dest='formField')

    # common
    parser.add_argument('--threads', type=int, default=1,
                        help='Number of threads (default %(default)s)')
    parser.add_argument('-v', '--verbose', help='Verbose mode',
                        action='store_true')

    # Use this for obtaining defaults from config file:
    #args = arguments.get_args()
    args = parser.parse_args()

    log_format = '%(message)s'
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(format=log_format, level=log_level)
    logger = logging.getLogger("Logger")

    config = ConfigParser()
    if args.config_file:
        config.read(args.config_file)

    # merge args with config

    if args.train:
        reader = NerReader(args.formField)

        # a generator (can be iterated several times)
        sentence_iter = reader.read(args.train)

        if args.vocab and os.path.exists(args.vocab):
            if args.vectors and os.path.exists(args.vectors):
                # use supplied embeddings
                embeddings = Embeddings(vectors=args.vectors, vocab_file=args.vocab,
                                        variant=args.variant)
            else:
                # create random embeddings
                embeddings = Embeddings(args.embeddings_size, vocab_file=args.vocab,
                                        variant=args.variant)
            # add the ngrams from the corpus
            # build vocabulary and tag set
            if args.vocab_size:
                vocab, tagset = reader.create_vocabulary(sentence_iter,
                                                         args.vocab_size,
                                                         args.minOccurr)
                embeddings.merge(vocab)
                logger.info("Overriding vocabulary in %s" % args.vocab)
                embeddings.save_vocabulary(args.vocab)
            else:
                tagset = reader.create_tagset(sentence_iter)

        elif args.variant == 'word2vec':
            if os.path.exists(args.vectors):
                embeddings = Embeddings(vectors=args.vectors,
                                        variant=args.variant)
                vocab, tagset = reader.create_vocabulary(sentence_iter,
                                                         args.vocab_size,
                                                         args.minOccurr)
                embeddings.merge(vocab)
            else:
                vocab, tagset = reader.create_vocabulary(sentence_iter,
                                                         args.vocab_size,
                                                         args.minOccurr)
                embeddings = Embeddings(vocab=vocab,
                                        variant=args.variant)
            if args.vocab:
                logger.info("Saving vocabulary in %s" % args.vocab)
                embeddings.save_vocabulary(args.vocab)
        elif not args.vocab_size:
            logger.error("Missing parameter --vocab-size")
            return
        else:
            # build vocabulary and tag set
            vocab, tagset = reader.create_vocabulary(sentence_iter,
                                                     args.vocab_size,
                                                     args.minOccurr)
            logger.info("Creating word embeddings")
            embeddings = Embeddings(args.embeddings_size, vocab=vocab,
                                    variant=args.variant)
            if args.vocab:
                logger.info("Saving vocabulary in %s" % args.vocab)
                embeddings.save_vocabulary(args.vocab)

        converter = Converter()
        # pass just the formField from tokens to the extractor
        converter.add(embeddings, reader.formField)
        
        if args.caps:
            logger.info("Creating capitalization features...")
            converter.add(CapsExtractor(args.caps), reader.formField)

        if args.pos:
            logger.info("Creating POS features...")
            postags = frozenset((token[args.pos] for sent in sentence_iter for token in sent))
            # tell the extractor which field to use 
            converter.add(AttributeExtractor(postags), args.pos) # no variant, preserve case

        if ((args.suffixes and not os.path.exists(args.suffixes)) or
            (args.prefixes and not os.path.exists(args.prefixes))):
            # collect the forms once
            words = (tok[reader.formField] for sent in sentence_iter for tok in sent)

        if args.suffix:
            if os.path.exists(args.suffixes):
                logger.info("Loading suffix list...")
                extractor = SuffixExtractor(args.suffix, args.suffixes)
                converter.add(extractor, reader.formField)
            else:
                logger.info("Creating suffix list...")
                extractor = SuffixExtractor(args.suffix, None, words)
                converter.add(extractor, reader.formField)
                if args.suffixes:
                    logger.info("Saving suffix list to: %s", args.suffixes)
                    extractor.write(args.suffixes)

        if args.prefix:
            if os.path.exists(args.prefixes):
                logger.info("Loading prefix list...")
                extractor = PrefixExtractor(args.prefix, args.prefixes)
                converter.add(extractor, reader.formField)
            else:
                logger.info("Creating prefix list...")
                extractor = PrefixExtractor(args.prefix, None, words)
                converter.add(extractor, reader.formField)
                if args.prefixes:
                    logger.info("Saving prefix list to: %s", args.prefixes)
                    extractor.write(args.prefixes)

        if args.gazetteer:
            if os.path.exists(args.gazetteer):
                logger.info("Loading gazetteers")
                for extractor in GazetteerExtractor.create(args.gazetteer, args.gsize):
                    # tell the extractor which field to use 
                    converter.add(extractor, reader.formField)
            else:
                logger.info("Creating gazetteer")
                tries = GazetteerExtractor.build(sentence_iter, reader.formField, reader.tagField)
                for tag, trie in tries.items():
                    # tell the extractor which field to use 
                    converter.add(GazetteerExtractor(trie, args.gsize), reader.formField)
                logger.info("Saving gazetteer list to: %s", args.gazetteer)
                with open(args.gazetteer, 'w', encoding='UTF-8') as file:
                    for tag, trie in tries.items():
                        for ngram in trie:
                            print('%s\t%s' % (tag, ' '.join(ngram)), file=file)

        # if args.pos:
        #     converter.add(POS(arg.pos))

        # obtain the tags for each sentence
        tag_index = { t:i for i,t in enumerate(tagset) }
        sentences = []
        tags = []
        for sent in sentence_iter:
            sentences.append(converter.convert(sent))
            tags.append(np.array([tag_index[token[reader.tagField]] for token in sent],
                                 np.int32))
        logger.info("Vocabulary size: %d" % embeddings.dict.size())
        logger.info("Tagset size: %d" % len(tagset))
        trainer = create_trainer(args, converter, tag_index)
        logger.info("Starting training with %d sentences" % len(sentences))

        report_frequency = max(args.iterations / 200, 1)
        report_frequency = 1    # DEBUG
        trainer.train(sentences, tags, args.iterations, report_frequency,
                      args.threads)
    
        logger.info("Saving trained model ...")
        trainer.saver(trainer)
        logger.info("... to %s" % args.model)

    else:
        with open(args.model) as file:
            tagger = NerTagger.load(file)
        reader = TaggerReader()
        for sent in reader:
            ConllWriter.write(tagger.tag(sent, reader.tagField))
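
When an existence check guards a subsequent open(), as with args.vocab and args.vectors above, os.path.isfile() is often the more precise test: os.path.exists() is also True for directories, which open() would then reject. A short sketch of the related predicates (the path is illustrative):

import os

path = "vocab.txt"
print(os.path.exists(path))  # True for a regular file or a directory
print(os.path.isfile(path))  # True only for a regular file
print(os.path.isdir(path))   # True only for a directory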

Example 4

Project: python-mode
Source File: runmod.py
def __rope_start_everything():
    import os
    import sys
    import socket
    import pickle
    import marshal
    import inspect
    import types
    import threading

    class _MessageSender(object):

        def send_data(self, data):
            pass

    class _SocketSender(_MessageSender):

        def __init__(self, port):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect(('127.0.0.1', port))
            self.my_file = s.makefile('wb')

        def send_data(self, data):
            if not self.my_file.closed:
                pickle.dump(data, self.my_file)

        def close(self):
            self.my_file.close()

    class _FileSender(_MessageSender):

        def __init__(self, file_name):
            self.my_file = open(file_name, 'wb')

        def send_data(self, data):
            if not self.my_file.closed:
                marshal.dump(data, self.my_file)

        def close(self):
            self.my_file.close()


    def _cached(func):
        cache = {}
        def newfunc(self, arg):
            if arg in cache:
                return cache[arg]
            result = func(self, arg)
            cache[arg] = result
            return result
        return newfunc

    class _FunctionCallDataSender(object):

        def __init__(self, send_info, project_root):
            self.project_root = project_root
            if send_info.isdigit():
                self.sender = _SocketSender(int(send_info))
            else:
                self.sender = _FileSender(send_info)

            def global_trace(frame, event, arg):
                # HACK: Ignoring out->in calls
                # This might lose some information
                if self._is_an_interesting_call(frame):
                    return self.on_function_call
            sys.settrace(global_trace)
            threading.settrace(global_trace)

        def on_function_call(self, frame, event, arg):
            if event != 'return':
                return
            args = []
            returned = ('unknown',)
            code = frame.f_code
            for argname in code.co_varnames[:code.co_argcount]:
                try:
                    args.append(self._object_to_persisted_form(frame.f_locals[argname]))
                except (TypeError, AttributeError):
                    args.append(('unknown',))
            try:
                returned = self._object_to_persisted_form(arg)
            except (TypeError, AttributeError):
                pass
            try:
                data = (self._object_to_persisted_form(frame.f_code),
                        tuple(args), returned)
                self.sender.send_data(data)
            except TypeError:
                pass
            return self.on_function_call

        def _is_an_interesting_call(self, frame):
            #if frame.f_code.co_name in ['?', '<module>']:
            #    return False
            #return not frame.f_back or not self._is_code_inside_project(frame.f_back.f_code)

            if not self._is_code_inside_project(frame.f_code) and \
               (not frame.f_back or not self._is_code_inside_project(frame.f_back.f_code)):
                return False
            return True

        def _is_code_inside_project(self, code):
            source = self._path(code.co_filename)
            return source is not None and os.path.exists(source) and \
                   _realpath(source).startswith(self.project_root)

        @_cached
        def _get_persisted_code(self, object_):
            source = self._path(object_.co_filename)
            if not os.path.exists(source):
                raise TypeError('no source')
            return ('defined', _realpath(source), str(object_.co_firstlineno))

        @_cached
        def _get_persisted_class(self, object_):
            try:
                return ('defined', _realpath(inspect.getsourcefile(object_)),
                        object_.__name__)
            except (TypeError, AttributeError):
                return ('unknown',)

        def _get_persisted_builtin(self, object_):
            if isinstance(object_, str):
                return ('builtin', 'str')
            if isinstance(object_, list):
                holding = None
                if len(object_) > 0:
                    holding = object_[0]
                return ('builtin', 'list', self._object_to_persisted_form(holding))
            if isinstance(object_, dict):
                keys = None
                values = None
                if len(object_) > 0:
                    keys = list(object_.keys())[0]
                    values = object_[keys]
                    if values == object_ and len(object_) > 1:
                        keys = list(object_.keys())[1]
                        values = object_[keys]
                return ('builtin', 'dict',
                        self._object_to_persisted_form(keys),
                        self._object_to_persisted_form(values))
            if isinstance(object_, tuple):
                objects = []
                if len(object_) < 3:
                    for holding in object_:
                        objects.append(self._object_to_persisted_form(holding))
                else:
                    objects.append(self._object_to_persisted_form(object_[0]))
                return tuple(['builtin', 'tuple'] + objects)
            if isinstance(object_, set):
                holding = None
                if len(object_) > 0:
                    for o in object_:
                        holding = o
                        break
                return ('builtin', 'set', self._object_to_persisted_form(holding))
            return ('unknown',)

        def _object_to_persisted_form(self, object_):
            if object_ is None:
                return ('none',)
            if isinstance(object_, types.CodeType):
                return self._get_persisted_code(object_)
            if isinstance(object_, types.FunctionType):
                return self._get_persisted_code(object_.__code__)
            if isinstance(object_, types.MethodType):
                return self._get_persisted_code(object_.__func__.__code__)
            if isinstance(object_, types.ModuleType):
                return self._get_persisted_module(object_)
            if isinstance(object_, (str, list, dict, tuple, set)):
                return self._get_persisted_builtin(object_)
            if isinstance(object_, type):
                return self._get_persisted_class(object_)
            return ('instance', self._get_persisted_class(type(object_)))

        @_cached
        def _get_persisted_module(self, object_):
            path = self._path(object_.__file__)
            if path and os.path.exists(path):
                return ('defined', _realpath(path))
            return ('unknown',)

        def _path(self, path):
            if path.endswith('.pyc'):
                path = path[:-1]
            if path.endswith('.py'):
                return path

        def close(self):
            self.sender.close()
            sys.settrace(None)

    def _realpath(path):
        return os.path.realpath(os.path.abspath(os.path.expanduser(path)))

    send_info = sys.argv[1]
    project_root = sys.argv[2]
    file_to_run = sys.argv[3]
    run_globals = globals()
    run_globals.update({'__name__': '__main__',
                        'builtins': __builtins__,
                        '__file__': file_to_run})
    if send_info != '-':
        data_sender = _FunctionCallDataSender(send_info, project_root)
    del sys.argv[1:4]
    with open(file_to_run) as file:
        exec(compile(file.read(), file_to_run, 'exec'), run_globals)
    if send_info != '-':
        data_sender.close()
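
One subtlety relevant to _is_code_inside_project() above, which pairs os.path.exists() with _realpath(): os.path.exists() follows symbolic links, so a dangling link reports False, while os.path.lexists() tests the link itself. A POSIX-only sketch (the link name is illustrative):

import os

os.symlink("/no/such/target", "dangling")  # link to a missing target
print(os.path.exists("dangling"))   # False: the target does not exist
print(os.path.lexists("dangling"))  # True: the link itself exists
os.remove("dangling")               # clean up the illustration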

Example 5

Project: catkin
Source File: builder.py
def build_workspace_isolated(
    workspace='.',
    sourcespace=None,
    buildspace=None,
    develspace=None,
    installspace=None,
    merge=False,
    install=False,
    force_cmake=False,
    colorize=True,
    build_packages=None,
    quiet=False,
    cmake_args=None,
    make_args=None,
    catkin_make_args=None,
    continue_from_pkg=False,
    only_pkg_with_deps=None,
    destdir=None,
    use_ninja=False,
    use_nmake=False,
    override_build_tool_check=False
):
    '''
    Runs ``cmake``, ``make`` and optionally ``make install`` for all
    catkin packages in sourcespace_dir.  It creates several folders
    in the current working directory. For non-catkin packages it runs
    ``cmake``, ``make`` and ``make install`` for each, installing it to
    the devel space or install space if the ``install`` option is specified.

    :param workspace: path to the current workspace, ``str``
    :param sourcespace: workspace folder containing catkin packages, ``str``
    :param buildspace: path to build space location, ``str``
    :param develspace: path to devel space location, ``str``
    :param installspace: path to install space (CMAKE_INSTALL_PREFIX), ``str``
    :param merge: if True, build each catkin package into the same
        devel space (not affecting plain cmake packages), ``bool``
    :param install: if True, install all packages to the install space,
        ``bool``
    :param force_cmake: (optional), if True calls cmake explicitly for each
        package, ``bool``
    :param colorize: if True, colorize cmake output and other messages,
        ``bool``
    :param build_packages: specific packages to build (all parent packages
        in the topological order must have been built before), ``str``
    :param quiet: if True, hides some build output, ``bool``
    :param cmake_args: additional arguments for cmake, ``[str]``
    :param make_args: additional arguments for make, ``[str]``
    :param catkin_make_args: additional arguments for make but only for catkin
        packages, ``[str]``
    :param continue_from_pkg: indicates whether or not cmi should continue
        when a package is reached, ``bool``
    :param only_pkg_with_deps: only consider the specific packages and their
        recursive dependencies and ignore all other packages in the workspace,
        ``[str]``
    :param destdir: define DESTDIR for cmake/invocation, ``string``
    :param use_ninja: if True, use ninja instead of make, ``bool``
    :param use_nmake: if True, use nmake instead of make, ``bool``
    :param override_build_tool_check: if True, build even if a space was built
        by another tool previously.
    '''
    if not colorize:
        disable_ANSI_colors()

    # Check workspace existence
    if not os.path.exists(workspace):
        sys.exit("Workspace path '{0}' does not exist.".format(workspace))
    workspace = os.path.abspath(workspace)

    # Check source space existence
    if sourcespace is None:
        sourcespace = os.path.join(workspace, 'src')
    if not os.path.exists(sourcespace):
        sys.exit('Could not find source space: {0}'.format(sourcespace))
    print('Base path: ' + str(workspace))
    print('Source space: ' + str(sourcespace))

    # Check build space
    if buildspace is None:
        buildspace = os.path.join(workspace, 'build_isolated')
    if not os.path.exists(buildspace):
        os.mkdir(buildspace)
    print('Build space: ' + str(buildspace))

    # ensure the build space was previously built by catkin_make_isolated
    previous_tool = get_previous_tool_used_on_the_space(buildspace)
    if previous_tool is not None and previous_tool != 'catkin_make_isolated':
        if override_build_tool_check:
            print(fmt(
                "@{yf}Warning: build space at '%s' was previously built by '%s', "
                "but --override-build-tool-check was passed so continuing anyways."
                % (buildspace, previous_tool)))
        else:
            sys.exit(fmt(
                "@{rf}The build space at '%s' was previously built by '%s'. "
                "Please remove the build space or pick a different build space."
                % (buildspace, previous_tool)))
    mark_space_as_built_by(buildspace, 'catkin_make_isolated')

    # Check devel space
    if develspace is None:
        develspace = os.path.join(workspace, 'devel_isolated')
    print('Devel space: ' + str(develspace))

    # ensure the devel space was previously built by catkin_make_isolated
    previous_tool = get_previous_tool_used_on_the_space(develspace)
    if previous_tool is not None and previous_tool != 'catkin_make_isolated':
        if override_build_tool_check:
            print(fmt(
                "@{yf}Warning: devel space at '%s' was previously built by '%s', "
                "but --override-build-tool-check was passed so continuing anyways."
                % (develspace, previous_tool)))
        else:
            sys.exit(fmt(
                "@{rf}The devel space at '%s' was previously built by '%s'. "
                "Please remove the devel space or pick a different devel space."
                % (develspace, previous_tool)))
    mark_space_as_built_by(develspace, 'catkin_make_isolated')

    # Check install space
    if installspace is None:
        installspace = os.path.join(workspace, 'install_isolated')
    print('Install space: ' + str(installspace))

    if cmake_args:
        print("Additional CMake Arguments: " + " ".join(cmake_args))
    else:
        cmake_args = []

    if not [arg for arg in cmake_args if arg.startswith('-G')]:
        if use_ninja:
            cmake_args += ['-G', 'Ninja']
        elif use_nmake:
            cmake_args += ['-G', 'NMake Makefiles']
        else:
            cmake_args += ['-G', 'Unix Makefiles']
    elif use_ninja or use_nmake:
        print(colorize_line("Error: either specify a generator using '-G...' or '--use-[ninja|nmake]' but not both"))
        sys.exit(1)

    if make_args:
        print("Additional make Arguments: " + " ".join(make_args))
    else:
        make_args = []

    if catkin_make_args:
        print("Additional make Arguments for catkin packages: " + " ".join(catkin_make_args))
    else:
        catkin_make_args = []

    # Find packages
    packages = find_packages(sourcespace, exclude_subspaces=True)
    if not packages:
        print(fmt("@{yf}No packages found in source space: %[email protected]|" % sourcespace))

    # whitelist packages and their dependencies in workspace
    if only_pkg_with_deps:
        package_names = [p.name for p in packages.values()]
        unknown_packages = [name for name in only_pkg_with_deps if name not in package_names]
        if unknown_packages:
            sys.exit('Packages not found in the workspace: %s' % ', '.join(unknown_packages))

        whitelist_pkg_names = get_package_names_with_recursive_dependencies(packages, only_pkg_with_deps)
        print('Whitelisted packages: %s' % ', '.join(sorted(whitelist_pkg_names)))
        packages = {path: p for path, p in packages.items() if p.name in whitelist_pkg_names}

    # verify that specified package exists in workspace
    if build_packages:
        packages_by_name = {p.name: path for path, p in packages.items()}
        unknown_packages = [p for p in build_packages if p not in packages_by_name]
        if unknown_packages:
            sys.exit('Packages not found in the workspace: %s' % ', '.join(unknown_packages))

    # Report topological ordering
    ordered_packages = topological_order_packages(packages)
    unknown_build_types = []
    msg = []
    msg.append('@{pf}~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~' + ('~' * len(str(len(ordered_packages)))))
    msg.append('@{pf}~~@|  traversing %d packages in topological order:' % len(ordered_packages))
    for path, package in ordered_packages:
        if path is None:
            print(fmt('@{rf}Error: Circular dependency in subset of packages: @!%s@|' % package))
            sys.exit('Can not build workspace with circular dependency')

        export_tags = [e.tagname for e in package.exports]
        if 'build_type' in export_tags:
            build_type_tag = [e.content for e in package.exports if e.tagname == 'build_type'][0]
        else:
            build_type_tag = 'catkin'
        if build_type_tag == 'catkin':
            msg.append('@{pf}~~@|  - @!@{bf}' + package.name + '@|')
        elif build_type_tag == 'cmake':
            msg.append(
                '@{pf}~~@|  - @!@{bf}' + package.name + '@|' +
                ' (@!@{cf}plain cmake@|)'
            )
        else:
            msg.append(
                '@{pf}~~@|  - @!@{bf}' + package.name + '@|' +
                ' (@{rf}unknown@|)'
            )
            unknown_build_types.append(package)
    msg.append('@{pf}~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~' + ('~' * len(str(len(ordered_packages)))))
    for index in range(len(msg)):
        msg[index] = fmt(msg[index])
    print('\n'.join(msg))

    # Error if there are packages with unknown build_types
    if unknown_build_types:
        print(colorize_line('Error: Packages with unknown build types exist'))
        sys.exit('Can not build workspace with packages of unknown build_type')

    # Check to see if the workspace has changed
    cmake_args_with_spaces = list(cmake_args)
    if develspace:
        cmake_args_with_spaces.append('-DCATKIN_DEVEL_PREFIX=' + develspace)
    if installspace:
        cmake_args_with_spaces.append('-DCMAKE_INSTALL_PREFIX=' + installspace)
    if (
        not force_cmake and
        cmake_input_changed(packages, buildspace, cmake_args=cmake_args_with_spaces, filename='catkin_make_isolated')
    ):
        print('The packages or cmake arguments have changed, forcing cmake invocation')
        force_cmake = True

    ensure_workspace_marker(workspace)

    # Build packages
    pkg_develspace = None
    last_env = None
    for index, path_package in enumerate(ordered_packages):
        path, package = path_package
        if merge:
            pkg_develspace = develspace
        else:
            pkg_develspace = os.path.join(develspace, package.name)
        if not build_packages or package.name in build_packages:
            if continue_from_pkg and build_packages and package.name in build_packages:
                build_packages = None
            try:
                print()
                last_env = build_package(
                    path, package,
                    workspace, buildspace, pkg_develspace, installspace,
                    install, force_cmake,
                    quiet, last_env, cmake_args, make_args, catkin_make_args,
                    destdir=destdir, use_ninja=use_ninja,
                    number=index + 1, of=len(ordered_packages)
                )
            except subprocess.CalledProcessError as e:
                _print_build_error(package, e)
                # Let users know how to reproduce
                # First add the cd to the build folder of the package
                cmd = 'cd ' + quote(os.path.join(buildspace, package.name)) + ' && '
                # Then reproduce the command called
                if isinstance(e.cmd, list):
                    # quote arguments to allow copy-n-paste of command
                    cmd += ' '.join([quote(arg) for arg in e.cmd])
                else:
                    cmd += e.cmd
                print(fmt("\[email protected]{rf}Reproduce this error by running:"))
                print(fmt("@{gf}@!==> @|") + cmd + "\n")
                sys.exit('Command failed, exiting.')
            except Exception as e:
                print("Unhandled exception of type '{0}':".format(type(e).__name__))
                import traceback
                traceback.print_exc()
                _print_build_error(package, e)
                sys.exit('Command failed, exiting.')
        else:
            cprint("Skipping package: '@[email protected]{bf}" + package.name + "@|'")
            last_env = get_new_env(package, pkg_develspace, installspace, install, last_env, destdir)

    # Provide a top level devel space environment setup script
    if not os.path.exists(develspace):
        os.makedirs(develspace)
    if not build_packages:
        generated_env_sh = os.path.join(develspace, 'env.sh')
        generated_setup_util_py = os.path.join(develspace, '_setup_util.py')
        if not merge and pkg_develspace:
            # generate env.sh and setup.sh|bash|zsh which relay to last devel space
            with open(generated_env_sh, 'w') as f:
                f.write("""\
#!/usr/bin/env sh
# generated from catkin.builder module

{0} "[email protected]"
""".format(os.path.join(pkg_develspace, 'env.sh')))
            os.chmod(generated_env_sh, stat.S_IXUSR | stat.S_IWUSR | stat.S_IRUSR)

            for shell in ['sh', 'bash', 'zsh']:
                with open(os.path.join(develspace, 'setup.%s' % shell), 'w') as f:
                    f.write("""\
#!/usr/bin/env {1}
# generated from catkin.builder module

. "{0}/setup.{1}"
""".format(pkg_develspace, shell))

            # remove _setup_util.py file which might have been generated for an empty devel space before
            if os.path.exists(generated_setup_util_py):
                os.remove(generated_setup_util_py)

        elif not pkg_develspace:
            # generate env.sh and setup.sh|bash|zsh for an empty devel space
            if 'CMAKE_PREFIX_PATH' in os.environ.keys():
                variables = {
                    'CATKIN_GLOBAL_BIN_DESTINATION': 'bin',
                    'CATKIN_LIB_ENVIRONMENT_PATHS': "'lib'",
                    'CATKIN_PKGCONFIG_ENVIRONMENT_PATHS': "os.path.join('lib', 'pkgconfig')",
                    'CMAKE_PREFIX_PATH_AS_IS': ';'.join(os.environ['CMAKE_PREFIX_PATH'].split(os.pathsep)),
                    'PYTHON_EXECUTABLE': sys.executable,
                    'PYTHON_INSTALL_DIR': get_python_install_dir(),
                }
                with open(generated_setup_util_py, 'w') as f:
                    f.write(configure_file(
                        os.path.join(get_cmake_path(), 'templates', '_setup_util.py.in'),
                        variables))
                os.chmod(generated_setup_util_py, stat.S_IXUSR | stat.S_IWUSR | stat.S_IRUSR)
            else:
                sys.exit("Unable to process CMAKE_PREFIX_PATH from environment. Cannot generate environment files.")

            variables = {'SETUP_FILENAME': 'setup'}
            with open(generated_env_sh, 'w') as f:
                f.write(configure_file(os.path.join(get_cmake_path(), 'templates', 'env.sh.in'), variables))
            os.chmod(generated_env_sh, stat.S_IXUSR | stat.S_IWUSR | stat.S_IRUSR)

            variables = {'SETUP_DIR': develspace}
            for shell in ['sh', 'bash', 'zsh']:
                with open(os.path.join(develspace, 'setup.%s' % shell), 'w') as f:
                    f.write(configure_file(
                        os.path.join(get_cmake_path(), 'templates', 'setup.%s.in' % shell),
                        variables))

Example 6

Project: deepnl
Source File: dl-conv.py
View license
def main():

    # set the seed for replicability
    np.random.seed(42)          # DEBUG

    defaults = {}
    
    parser = argparse.ArgumentParser(description="Convolutional network classifier.")
    
    parser.add_argument('-c', '--config', dest='config_file',
                        help='Specify config file', metavar='FILE')

    # args, remaining_argv = parser.parse_known_args()

    # if args.config_file:
    #     config = ConfigParser.SafeConfigParser()
    #     config.read([args.config_file])
    #     defaults = dict(config.items('Defaults'))

    # parser.set_defaults(**defaults)

    parser.add_argument('model', type=str,
                        help='Model file to train/use.')

    # input format
    format = parser.add_argument_group('Format')

    format.add_argument('--label-field', type=int, default=1,
                        help='Field containing label (default %(default)s).')
    format.add_argument('--text-field', type=int, default=2,
                        help='Field containing text (default %(default)s).')

    # training options
    train = parser.add_argument_group('Train')

    train.add_argument('-t', '--train', type=str, default=None,
                       help='File with annotated data for training.')

    train.add_argument('-w', '--window', type=int, default=5,
                       help='Size of the word window (default %(default)s)')
    train.add_argument('-s', '--embeddings-size', type=int, default=50,
                       help='Number of features per word (default %(default)s)',
                       dest='embeddings_size')
    train.add_argument('-e', '--epochs', type=int, default=100,
                       help='Number of training epochs (default %(default)s)',
                       dest='iterations')
    train.add_argument('-l', '--learning_rate', type=float, default=0.001,
                       help='Learning rate for network weights (default %(default)s)',
                       dest='learning_rate')
    train.add_argument('--eps', type=float, default=1e-6,
                        help='Epsilon value for AdaGrad (default %(default)s)')
    train.add_argument('-n', '--hidden', type=int, default=200,
                       help='Number of hidden neurons (default %(default)s)')
    train.add_argument('-n2', '--hidden2', type=int, default=200,
                       help='Number of hidden neurons (default %(default)s)')

    # Extractors:
    extractors = parser.add_argument_group('Extractors')
    extractors.add_argument('--caps', const=5, nargs='?', type=int, default=None,
                            help='Include capitalization features. Optionally, supply the number of features (default %(default)s)')
    extractors.add_argument('--suffix', const=5, nargs='?', type=int, default=None,
                            help='Include suffix features. Optionally, supply the number of features (default %(default)s)')
    extractors.add_argument('--suffixes', type=str, default='',
                        help='Load suffixes from this file')
    extractors.add_argument('--prefix', const=0, nargs='?', type=int, default=None,
                        help='Include prefix features. Optionally, '\
                            'supply the number of features (default %(default)s)')
    extractors.add_argument('--prefixes', type=str, default='',
                        help='Load prefixes from this file')
    # Embeddings
    embeddings = parser.add_argument_group('Embeddings')
    embeddings.add_argument('--vocab', type=str, default=None,
                        help='Vocabulary file, either read or created')
    embeddings.add_argument('--vectors', type=str, default=None,
                        help='Embeddings file, either read or created')
    embeddings.add_argument('--min-occurr', type=int, default=3,
                        help='Minimum occurrences for inclusion in vocabulary',
                        dest='minOccurr')
    embeddings.add_argument('--load', type=str, default=None,
                        help='Load previously saved model')
    embeddings.add_argument('--variant', type=str, default=None,
                        help='Either "senna" (default), "polyglot" or "word2vec".')

    # common
    parser.add_argument('--threads', type=int, default=1,
                        help='Number of threads (default %(default)s)')
    parser.add_argument('-v', '--verbose', help='Verbose mode',
                        action='store_true')

    # Use this for obtaining defaults from config file:
    #args = arguments.get_args()
    args = parser.parse_args()

    log_format = '%(message)s'
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(format=log_format, level=log_level)
    logger = logging.getLogger("Logger")

    config = ConfigParser()
    if args.config_file:
        config.read(args.config_file)

    # merge args with config

    if args.train:
        reader = ClassifyReader(text_field=args.text_field, label_field=args.label_field)
        # a generator (can be iterated several times)
        sentences = reader.read(args.train)

        if args.vocab and os.path.exists(args.vocab):
            if args.vectors and os.path.exists(args.vectors):
                # use supplied embeddings
                embeddings = Embeddings(vectors=args.vectors, vocab_file=args.vocab,
                                        variant=args.variant)
            else:
                # create random embeddings
                embeddings = Embeddings(args.embeddings_size, vocab_file=args.vocab,
                                        variant=args.variant)
            # collect words from the corpus
            # build vocabulary
            vocab, bigrams, trigrams = reader.create_vocabulary(sentences,
                                                                #size=args.vocab_size,
                                                                min_occurrences=args.minOccurr)
            # add them to the given vocabulary
            embeddings.merge(vocab)
            logger.info("Overriding vocabulary in %s" % args.vocab)
            embeddings.save_vocabulary(args.vocab)

        elif args.variant == 'word2vec':
            if os.path.exists(args.vectors):
                embeddings = Embeddings(vectors=args.vectors,
                                        variant=args.variant)
                vocab, bigrams, trigrams = reader.create_vocabulary(sentences,
                                                                    #args.vocab_size,
                                                                    min_occurrences=args.minOccurr)
                embeddings.merge(vocab)
            else:
                vocab, bigrams, trigrams = reader.create_vocabulary(sentences,
                                                                    #args.vocab_size,
                                                                    min_occurrences=args.minOccurr)
                embeddings = Embeddings(vocab=vocab,
                                        variant=args.variant)
            if args.vocab:
                logger.info("Saving vocabulary in %s" % args.vocab)
                embeddings.save_vocabulary(args.vocab)

        elif not args.vocab_size:
            logger.error("Missing parameter --vocab-size")
            return
        else:
            # build vocabulary and tag set
            vocab, bigrams, trigrams = reader.create_vocabulary(sentences,
                                                                #args.vocab_size,
                                                                min_occurrences=args.minOccurr)
            logger.info("Creating word embeddings")
            embeddings = Embeddings(args.embeddings_size, vocab=vocab,
                                    variant=args.variant)
            if args.vocab:
                logger.info("Saving vocabulary in %s" % args.vocab)
                embeddings.save_vocabulary(args.vocab)

        converter = Converter()
        converter.add(embeddings)

        if args.caps:
            logger.info("Creating capitalization features...")
            converter.add(CapsExtractor(args.caps))

        if ((args.suffixes and not os.path.exists(args.suffixes)) or
            (args.prefixes and not os.path.exists(args.prefixes))):
            # collect the forms once
            words = (tok for sent in sentences for tok in sent)

        if args.suffix:
            if os.path.exists(args.suffixes):
                logger.info("Loading suffix list...")
                extractor = SuffixExtractor(args.suffix, args.suffixes)
                converter.add(extractor)
            else:
                logger.info("Creating suffix list...")
                extractor = SuffixExtractor(args.suffix, None, words)
                converter.add(extractor)
                if args.suffixes:
                    logger.info("Saving suffix list to: %s", args.suffixes)
                    extractor.write(args.suffixes)

        if args.prefix:
            if os.path.exists(args.prefixes):
                logger.info("Loading prefix list...")
                extractor = PrefixExtractor(args.prefix, args.prefixes)
                converter.add(extractor)
            else:
                logger.info("Creating prefix list...")
                extractor = PrefixExtractor(args.prefix, None, words)
                converter.add(extractor)
                if args.prefixes:
                    logger.info("Saving prefix list to: %s", args.prefixes)
                    extractor.write(args.prefixes)

        # labels from all examples
        examples = [converter.convert(example) for example in sentences]
        # assign index to labels
        sent_labels = reader.polarities
        labels_index = {}
        labels = []
        for i,c in enumerate(set(sent_labels)):
            labels_index[c] = i
            labels.append(c)
        trainer = create_trainer(args, converter, labels)
        logger.info("Starting training with %d examples" % len(examples))

        report_frequency = max(args.iterations / 200, 1)
        report_frequency = 1    # DEBUG
        labels_ids = [labels_index[label] for label in sent_labels]
        trainer.train(examples, labels_ids, args.iterations, report_frequency,
                      args.threads)
    
        logger.info("Saving trained model ...")
        trainer.saver(trainer)
        logger.info("... to %s" % args.model)

    else:
        # predict
        with open(args.model) as file:
            classifier = Classifier.load(file)
        reader = ClassifyReader(text_field=args.text_field, label_field=args.label_field)
        
        for example in reader:
            words = example[reader.text_field].split()
            example[reader.label_field] = classifier.predict(words)
            print('\t'.join(example).encode('utf-8'))
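
Example 6 branches on os.path.exists for each optional input (--vocab, --vectors, --suffixes, --prefixes): load the file when it is present, generate the data otherwise, and optionally save it back. That load-or-create pattern reduced to a sketch (the file name and builder are hypothetical):

import os

def load_or_create(path, build):
    # Reuse a cached file when it exists; otherwise build the data and cache it.
    if path and os.path.exists(path):
        with open(path) as f:
            return [line.strip() for line in f]
    items = build()
    if path:
        with open(path, 'w') as f:
            f.write('\n'.join(items))
    return items

suffixes = load_or_create('suffixes.txt', lambda: ['ing', 'ed', 'ly'])
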

Example 7

Project: morituri
Source File: readdisc.py
View license
def main(argv):
    parser = optparse.OptionParser()

    default = 'cli'
    parser.add_option('-r', '--runner',
        action="store", dest="runner",
        help="runner ('cli' or 'gtk', defaults to %s)" % default,
        default=default)
    default = 0
    parser.add_option('-o', '--offset',
        action="store", dest="offset",
        help="sample offset (defaults to %d)" % default,
        default=default)
    parser.add_option('-t', '--table-pickle',
        action="store", dest="table_pickle",
        help="pickle to use for reading and writing the table",
        default=default)
    parser.add_option('-T', '--toc-pickle',
        action="store", dest="toc_pickle",
        help="pickle to use for reading and writing the TOC",
        default=default)
    default = '%A - %d/%t. %a - %n'
    parser.add_option('', '--track-template',
        action="store", dest="track_template",
        help="template for track file naming (default %s)" % default,
        default=default)
    default = '%A - %d/%A - %d'
    parser.add_option('', '--disc-template',
        action="store", dest="disc_template",
        help="template for disc file naming (default %s)" % default,
        default=default)


    options, args = parser.parse_args(argv[1:])

    if options.runner == 'cli':
        runner = task.SyncRunner()
        function = climain
    elif options.runner == 'gtk':
        from morituri.common import taskgtk
        runner = taskgtk.GtkProgressRunner()
        function = gtkmain

    # first, read the normal TOC, which is fast
    ptoc = common.Persister(options.toc_pickle or None)
    if not ptoc.object:
        t = cdrdao.ReadTOCTask()
        function(runner, t)
        ptoc.persist(t.table)
    ittoc = ptoc.object
    assert ittoc.hasTOC()

    # already show us some info based on this
    print "CDDB disc id", ittoc.getCDDBDiscId()
    metadata = musicbrainz(ittoc.getMusicBrainzDiscId())

    # now, read the complete index table, which is slower
    ptable = common.Persister(options.table_pickle or None)
    if not ptable.object:
        t = cdrdao.ReadTableTask()
        function(runner, t)
        ptable.persist(t.table)
    itable = ptable.object

    assert itable.hasTOC()

    assert itable.getCDDBDiscId() == ittoc.getCDDBDiscId(), \
        "full table's id %s differs from toc id %s" % (
            itable.getCDDBDiscId(), ittoc.getCDDBDiscId())
    assert itable.getMusicBrainzDiscId() == ittoc.getMusicBrainzDiscId()

    lastTrackStart = 0

    # check for hidden track one audio
    htoapath = None
    index = None
    track = itable.tracks[0]
    try:
        index = track.getIndex(0)
    except KeyError:
        pass

    if index:
        start = index.absolute
        stop = track.getIndex(1).absolute
        print 'Found Hidden Track One Audio from frame %d to %d' % (start, stop)
            
        # rip it
        htoapath = getPath(options.track_template, metadata, -1) + '.wav'
        htoalength = stop - start
        if not os.path.exists(htoapath):
            print 'Ripping track %d: %s' % (0, os.path.basename(htoapath))
            t = cdparanoia.ReadVerifyTrackTask(htoapath, ittoc,
                start, stop - 1,
                offset=int(options.offset))
            function(runner, t)
            if t.checksum:
                print 'Checksums match for track %d' % 0
            else:
                print 'ERROR: checksums did not match for track %d' % 0
            # overlay this rip onto the Table
        itable.setFile(1, 0, htoapath, htoalength, 0)


    for i, track in enumerate(itable.tracks):
        path = getPath(options.track_template, metadata, i) + '.wav'
        dirname = os.path.dirname(path)
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        # FIXME: optionally allow overriding reripping
        if not os.path.exists(path):
            print 'Ripping track %d: %s' % (i + 1, os.path.basename(path))
            t = cdparanoia.ReadVerifyTrackTask(path, ittoc,
                ittoc.getTrackStart(i + 1),
                ittoc.getTrackEnd(i + 1),
                offset=int(options.offset))
            t.description = 'Reading Track %d' % (i + 1)
            function(runner, t)
            if t.checksum:
                print 'Checksums match for track %d' % (i + 1)
            else:
                print 'ERROR: checksums did not match for track %d' % (i + 1)

        # overlay this rip onto the Table
        itable.setFile(i + 1, 1, path, ittoc.getTrackLength(i + 1), i + 1)


    ### write disc files
    discName = getPath(options.disc_template, metadata, i)
    dirname = os.path.dirname(discName)
    if not os.path.exists(dirname):
        os.makedirs(dirname)

    # write .cue file
    cuePath = '%s.cue' % discName
    handle = open(cuePath, 'w')
    handle.write(itable.cue())
    handle.close()

    # write .m3u file
    m3uPath = '%s.m3u' % discName
    handle = open(m3uPath, 'w')
    handle.write('#EXTM3U\n')
    if htoapath:
        handle.write('#EXTINF:%d,%s\n' % (
            htoalength / common.FRAMES_PER_SECOND,
                os.path.basename(htoapath[:-4])))
        handle.write('%s\n' % os.path.basename(htoapath))

    for i, track in enumerate(itable.tracks):
        path = getPath(options.track_template, metadata, i) + '.wav'
        handle.write('#EXTINF:%d,%s\n' % (
            itable.getTrackLength(i + 1) / common.FRAMES_PER_SECOND,
            os.path.basename(path)))
        handle.write('%s\n' % os.path.basename(path))
    handle.close()

    # verify using accuraterip
    print "CDDB disc id", itable.getCDDBDiscId()
    print "MusicBrainz disc id", itable.getMusicBrainzDiscId()
    url = itable.getAccurateRipURL()
    print "AccurateRip URL", url

    # FIXME: download url as a task too
    responses = []
    import urllib2
    try:
        handle = urllib2.urlopen(url)
        data = handle.read()
        responses = image.getAccurateRipResponses(data)
    except urllib2.HTTPError, e:
        if e.code == 404:
            print 'Album not found in AccurateRip database'
        else:
            raise

    if responses:
        print '%d AccurateRip responses found' % len(responses)

        if responses[0].cddbDiscId != itable.getCDDBDiscId():
            print "AccurateRip response discid different: %s" % \
                responses[0].cddbDiscId

       
    # FIXME: put accuraterip verification into a separate task/function
    # and apply here
    cueImage = image.Image(cuePath)
    verifytask = image.ImageVerifyTask(cueImage)
    cuetask = image.AccurateRipChecksumTask(cueImage)
    function(runner, verifytask)
    function(runner, cuetask)

    response = None # track which response matches, for all tracks

    # loop over tracks
    for i, sum in enumerate(cuetask.checksums):
        status = 'rip NOT accurate'

        confidence = None
        arsum = None

        # match against each response's checksum
        for j, r in enumerate(responses):
            if "%08x" % sum == r.checksums[i]:
                if not response:
                    response = r
                else:
                    assert r == response, \
                        "checksum %s for %d matches wrong response %d, "\
                        "checksum %s" % (
                            sum, i + 1, j + 1, response.checksums[i])
                status = 'rip accurate    '
                arsum = sum
                confidence = response.confidences[i]

        c = "(not found)"
        ar = "(not in database)"
        if responses:
            if not response:
                print 'ERROR: none of the responses matched.'
            else:
                maxConfidence = max(r.confidences[i] for r in responses)
                     
                c = "(max confidence %3d)" % maxConfidence
                if confidence is not None:
                    if confidence < maxConfidence:
                        c = "(confidence %3d of %3d)" % (confidence, maxConfidence)

                ar = ", AR [%s]" % response.checksums[i]
        print "Track %2d: %s %s [%08x]%s" % (
            i + 1, status, c, sum, ar)

Example 8

View license
def import_photos(iphoto_dir, shotwell_db, photos_dir):
    # Sanity check the iPhoto dir and Shotwell DB.
    _log.debug("Performing sanity checks on iPhoto and Shotwell DBs.")
    now = int(time.time())
    album_data_filename = join_path(iphoto_dir, "AlbumData.xml")
    if not os.path.exists(album_data_filename):
        _log.error("Failed to find expected file inside iPhoto library: %s", 
                   album_data_filename)
        sys.exit(1)
    if not os.path.exists(shotwell_db):
        _log.error("Shotwell DB not found at %s", shotwell_db)
        sys.exit(2)
    db = sqlite3.connect(shotwell_db) #@UndefinedVariable
    with db:
        cursor = db.execute("SELECT schema_version from VersionTable;")
        schema_version = cursor.fetchone()[0]
        if schema_version not in SUPPORTED_SHOTWELL_SCHEMAS:
            _log.error("Shotwell DB uses unsupported schema version %s. "
                       "Giving up, just to be safe.", schema_version)
            sys.exit(3)
        _log.debug("Sanity checks passed.")
        
        # Back up the Shotwell DB.
        fmt_now = time.strftime('%Y-%m-%d_%H%M%S')
        db_backup = "%s.iphotobak_%s" % (shotwell_db, fmt_now)
        _log.debug("Backing up shotwell DB to %s", db_backup)
        shutil.copy(shotwell_db, db_backup)
        _log.debug("Backup complete")
        
        # Load and parse the iPhoto DB.
        _log.debug("Loading the iPhoto library file. Might take a while for a large DB!")
        album_data = plistlib.readPlist(album_data_filename)
        _log.debug("Finished loading the iPhoto library.")
        path_prefix = album_data["Archive Path"]
        def fix_prefix(path, new_prefix=iphoto_dir):
            if path:
                if path[:len(path_prefix)] != path_prefix:
                    raise AssertionError("Path %s didn't begin with %s" % (path, path_prefix))
                path = path[len(path_prefix):]
                path = join_path(new_prefix, path.strip(os.path.sep)) 
            return path
        photos = {} # Map from photo ID to photo info.
        copy_queue = []
        
#                  id = 224
#            filename = /home/shaun/Pictures/Photos/2008/03/24/DSCN2416 (Modified (2)).JPG
#               width = 1600
#              height = 1200
#            filesize = 480914
#           timestamp = 1348718403
#       exposure_time = 1206392706
#         orientation = 1
#original_orientation = 1
#           import_id = 1348941635
#            event_id = 3
#     transformations = 
#                 md5 = 3ca3cf05312d0c1a4c141bb582fc43d0
#       thumbnail_md5 = 
#            exif_md5 = cec27a47c34c89f571c0fd4e9eb4a9fe
#        time_created = 1348941635
#               flags = 0
#              rating = 0
#         file_format = 0
#               title = 
#           backlinks = 
#     time_reimported = 
#         editable_id = 1
#      metadata_dirty = 1
#           developer = SHOTWELL
# develop_shotwell_id = -1
#   develop_camera_id = -1
# develop_embedded_id = -1
        skipped = []
        for key, i_photo in album_data["Master Image List"].items():
            mod_image_path = fix_prefix(i_photo.get("ImagePath", None))
            orig_image_path = fix_prefix(i_photo.get("OriginalPath", None))
            
            new_mod_path = fix_prefix(i_photo.get("ImagePath"), new_prefix=photos_dir)
            new_orig_path = fix_prefix(i_photo.get("OriginalPath", None), 
                                       new_prefix=photos_dir)
            
            if not orig_image_path or not os.path.exists(mod_image_path):
                orig_image_path = mod_image_path
                new_orig_path = new_mod_path
                new_mod_path = None
                mod_image_path = None
                mod_file_size = None
            else:
                mod_file_size = os.path.getsize(mod_image_path)
                
            if not os.path.exists(orig_image_path):
                _log.error("Original file not found %s", orig_image_path)
                skipped.append(orig_image_path)
                continue
            
            copy_queue.append((orig_image_path, new_orig_path))
            if mod_image_path: copy_queue.append((mod_image_path, new_mod_path))
                
            mime, _ = mimetypes.guess_type(orig_image_path)
                
            sys.stdout.write('.')
            sys.stdout.flush()
            if mime not in ("image/jpeg", "image/png", "image/x-ms-bmp", "image/tiff"):
                print
                _log.error("Skipping %s, it's not an image %s", orig_image_path, mime)
                skipped.append(orig_image_path)
                continue
            
            caption = i_photo.get("Caption", "")
            
            img = Image.open(orig_image_path)
            w, h = img.size
            
            md5 = md5_for_file(orig_image_path)
            orig_timestamp = int(os.path.getmtime(orig_image_path))
            
            mod_w, mod_h, mod_md5, mod_timestamp = None, None, None, None
            if mod_image_path:
                try:
                    mod_img = Image.open(mod_image_path)
                except Exception:
                    _log.error("Failed to open modified image %s, skipping", mod_image_path)
                    orig_image_path = mod_image_path
                    new_orig_path = new_mod_path
                    new_mod_path = None
                    mod_image_path = None
                    mod_file_size = None
                else:
                    mod_w, mod_h = mod_img.size
                    mod_md5 = md5_for_file(mod_image_path)
                    mod_timestamp = int(os.path.getmtime(mod_image_path))
            
            file_format = FILE_FORMAT.get(mime, -1)
            if file_format == -1:
                raise Exception("Unknown image type %s" % mime)
            
            photo = {"orig_image_path": orig_image_path,
                           "mod_image_path": mod_image_path,
                           "new_mod_path": new_mod_path,
                           "new_orig_path": new_orig_path,
                           "orig_file_size": os.path.getsize(orig_image_path),
                           "mod_file_size": mod_file_size,
                           "mod_timestamp": mod_timestamp,
                           "orig_timestamp": orig_timestamp,
                           "caption": caption,
                           "rating": i_photo["Rating"],
                           "event": i_photo["Roll"],
                           "orig_exposure_time": int(parse_date(i_photo["DateAsTimerInterval"])),
                           "width": w,
                           "height": h,
                           "mod_width": mod_w,
                           "mod_height": mod_h,
                           "orig_md5": md5,
                           "mod_md5": md5,
                           "file_format": file_format,
                           "time_created": now,
                           "import_id": now,
                           }
            def read_metadata(path, photo, prefix="orig_"):
                photo[prefix + "orientation"] = 1
                photo[prefix + "original_orientation"] = 1
                try:
                    meta = ImageMetadata(path)
                    meta.read()
                    try:
                        photo[prefix + "orientation"] = meta["Exif.Image.Orientation"].value
                        photo[prefix + "original_orientation"] = meta["Exif.Image.Orientation"].value
                    except KeyError:
                        print
                        _log.debug("Failed to read the orientation from %s" % path)
                    exposure_dt = meta["Exif.Image.DateTime"].value
                    photo[prefix + "exposure_time"] = exif_datetime_to_time(exposure_dt)
                except KeyError:
                    pass
                except Exception:
                    print
                    _log.exception("Failed to read date from %s", path)
                    raise
                    
            try:
                read_metadata(orig_image_path, photo, "orig_")
                photo["orientation"] = photo["orig_orientation"]
                if mod_image_path:
                    read_metadata(mod_image_path, photo, "mod_")
                    photo["orientation"] = photo["mod_orientation"]
            except Exception:
                _log.error("**** Skipping %s" % orig_image_path)
                skipped.append(orig_image_path)
                continue
            
            photos[key] = photo
        
        events = {}
        for event in album_data["List of Rolls"]:
            key = event["RollID"]
            events[key] = {
                "date": parse_date(event["RollDateAsTimerInterval"]),
                "key_photo": event["KeyPhotoKey"], 
                "photos": event["KeyList"],
                "name": event["RollName"]
            }
            for photo_key in event["KeyList"]:
                assert photo_key not in photos or photos[photo_key]["event"] == key
        
        # Insert into the Shotwell DB.
        for _, event in events.items():
            c = db.execute("""
                INSERT INTO EventTable (time_created, name) 
                VALUES (?, ?)
            """, (event["date"], event["name"]))
            assert c.lastrowid is not None
            event["row_id"] = c.lastrowid
            for photo_key in event["photos"]:
                if photo_key in photos:
                    photos[photo_key]["event_id"] = event["row_id"]

        
            
        for key, photo in photos.items():
            # The BackingPhotoTable
#                  id = 1
#            filepath = /home/shaun/Pictures/Photos/2008/03/24/DSCN2416 (Modified (2))_modified.JPG
#           timestamp = 1348968706
#            filesize = 1064375
#               width = 1600
#              height = 1200
#original_orientation = 1
#         file_format = 0
#        time_created = 1348945103
            if "event_id" not in photo:
                _log.error("Photo didn't have an event: %s", photo)
                skipped.append(photo["orig_image_path"])
                continue
            editable_id = -1
            if photo["mod_image_path"] is not None:
                # This photo has a backing image
                c = db.execute("""
                    INSERT INTO BackingPhotoTable (filepath,
                                                   timestamp,
                                                   filesize,
                                                   width,
                                                   height,
                                                   original_orientation,
                                                   file_format,
                                                   time_created)
                    VALUES (:new_mod_path,
                            :mod_timestamp,
                            :mod_file_size,
                            :mod_width,
                            :mod_height,
                            :mod_original_orientation,
                            :file_format,
                            :time_created)
                """, photo)
                editable_id = c.lastrowid
            
            photo["editable_id"] = editable_id
            try:
                c = db.execute("""
                    INSERT INTO PhotoTable (filename,
                                            width,
                                            height,
                                            filesize,
                                            timestamp,
                                            exposure_time,
                                            orientation,
                                            original_orientation,
                                            import_id,
                                            event_id,
                                            md5,
                                            time_created,
                                            flags,
                                            rating,
                                            file_format,
                                            title,
                                            editable_id,
                                            metadata_dirty,
                                            developer,
                                            develop_shotwell_id,
                                            develop_camera_id,
                                            develop_embedded_id)
                    VALUES (:new_orig_path,
                            :width,
                            :height,
                            :orig_file_size,
                            :orig_timestamp,
                            :orig_exposure_time,
                            :orientation,
                            :orig_original_orientation,
                            :import_id,
                            :event_id,
                            :orig_md5,
                            :time_created,
                            0,
                            :rating,
                            :file_format,
                            :caption,
                            :editable_id,
                            1,
                            'SHOTWELL',
                            -1,
                            -1,
                            -1);
                """, photo)
            except Exception:
                _log.exception("Failed to insert photo %s" % photo)
                raise
            
        print >> sys.stderr, "Skipped importing these files:\n", "\n".join(skipped)
        print >> sys.stderr, "%s files skipped (they will still be copied)" % len(skipped)
        
        for src, dst in copy_queue:
            safe_link_file(src, dst)
        
        db.commit()
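
Example 8 uses os.path.exists purely for pre-flight sanity checks, aborting with a distinct exit code per missing input before any work (or the DB backup) happens. The shape of that guard, with illustrative paths and codes:

import os
import sys

def require_file(path, exit_code, what):
    # Fail fast, with a distinguishable status, if a required input is missing.
    if not os.path.exists(path):
        sys.stderr.write('%s not found at %s\n' % (what, path))
        sys.exit(exit_code)

require_file('AlbumData.xml', 1, 'iPhoto library')  # hypothetical inputs
require_file('photo.db', 2, 'Shotwell DB')
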

Example 9

Project: stonix
Source File: ConfigureLDAPServer.py
View license
    def fix(self):
        try:
            if not self.ci.getcurrvalue():
                return
            success = True
            self.iditerator = 0
            eventlist = self.statechglogger.findrulechanges(self.rulenumber)
            for event in eventlist:
                self.statechglogger.deleteentry(event)
            if self.ph.check(self.ldap):
                #is ldap configured for tls?

                #do the ldap files have the correct permissions?
                slapd = "/etc/openldap/slapd.conf"
                if os.path.exists(slapd):
                    statdata = os.stat(slapd)
                    mode = stat.S_IMODE(statdata.st_mode)
                    ownergrp = getUserGroupName(slapd)
                    owner = ownergrp[0]
                    group = ownergrp[1]
                    if mode != 416 or owner != "root" or group != "ldap":
                        origuid = statdata.st_uid
                        origgid = statdata.st_gid
                        if grp.getgrnam("ldap")[2]:
                            gid = grp.getgrnam("ldap")[2]
                            self.iditerator += 1
                            myid = iterate(self.iditerator, self.rulenumber)
                            event = {"eventtype": "perm",
                                     "startstate": [origuid, origgid, mode],
                                     "endstate": [0, gid, 416],
                                     "filepath": slapd}
                            self.statechglogger.recordchgevent(myid, event)
                            os.chmod(slapd, 416)
                            os.chown(slapd, 0, gid)
                            resetsecon(slapd)
                        else:
                            success = False
                            debug = "Unable to determine the id " + \
                                "number of ldap group.  Will not change " + \
                                "permissions on " + slapd + " file\n"
                            self.logger.log(LogPriority.DEBUG, debug)
                #apt-get systems
                slapd = "/etc/ldap/ldap.conf"
                if os.path.exists(slapd):
                    statdata = os.stat(slapd)
                    mode = stat.S_IMODE(statdata.st_mode)
                    ownergrp = getUserGroupName(slapd)
                    owner = ownergrp[0]
                    group = ownergrp[1]
                    if mode != 420 or owner != "root" or group != "root":
                        origuid = statdata.st_uid
                        origgid = statdata.st_gid
                        if grp.getgrnam("root")[2] != "":
                            gid = grp.getgrnam("root")[2]
                            self.iditerator += 1
                            myid = iterate(self.iditerator, self.rulenumber)
                            event = {"eventtype": "perm",
                                     "startstate": [origuid, origgid, mode],
                                     "endstate": [0, gid, 420],
                                     "filepath": slapd}
                            self.statechglogger.recordchgevent(myid, event)
                            os.chmod(slapd, 420)
                            os.chown(slapd, 0, gid)
                            resetsecon(slapd)
                        else:
                            success = False
                            debug = "Unable to determine the id " + \
                                "number of root group.  Will not change " + \
                                "permissions on " + slapd + " file\n"
                            self.logger.log(LogPriority.DEBUG, debug)
                slapdd = "/etc/openldap/slapd.d/"
                if os.path.exists(slapdd):
                    dirs = glob.glob(slapdd + "*")
                    for loc in dirs:
                        if not os.path.isdir(loc):
                            statdata = os.stat(loc)
                            mode = stat.S_IMODE(statdata.st_mode)
                            ownergrp = getUserGroupName(loc)
                            owner = ownergrp[0]
                            group = ownergrp[1]
                            if mode != 416 or owner != "ldap" or group \
                                    != "ldap":
                                origuid = statdata.st_uid
                                origgid = statdata.st_gid
                                if grp.getgrnam("ldap")[2] != "":
                                    if pwd.getpwnam("ldap")[2] != "":
                                        gid = grp.getgrnam("ldap")[2]
                                        uid = pwd.getpwnam("ldap")[2]
                                        self.iditerator += 1
                                        myid = iterate(self.iditerator,
                                                       self.rulenumber)
                                        event = {"eventtype": "perm",
                                                 "startstate": [origuid,
                                                                origgid, mode],
                                                 "endstate": [uid, gid, 416],
                                                 "filepath": loc}
                                        self.statechglogger.recordchgevent(myid, event)
                                        os.chmod(loc, 416)
                                        os.chown(loc, uid, gid)
                                        resetsecon(loc)
                                    else:
                                        debug = "Unable to determine the " + \
                                            "id number of ldap user.  " + \
                                            "Will not change permissions " + \
                                            "on " + loc + " file\n"
                                        self.logger.log(LogPriority.DEBUG, debug)
                                        success = False
                                else:
                                    success = False
                                    debug = "Unable to determine the id " + \
                                        "number of ldap group.  Will not " + \
                                        "change permissions on " + loc + \
                                        " file\n"
                                    self.logger.log(LogPriority.DEBUG, debug)
                #apt-get systems
                slapdd = "/etc/ldap/slapd.d/"
                if os.path.exists(slapdd):
                    dirs = glob.glob(slapdd + "*")
                    for loc in dirs:
                        if not os.path.isdir(loc):
                            statdata = os.stat(loc)
                            mode = stat.S_IMODE(statdata.st_mode)
                            ownergrp = getUserGroupName(loc)
                            owner = ownergrp[0]
                            group = ownergrp[1]
                            if mode != 384 or owner != "openldap" or group \
                                    != "openldap":
                                origuid = statdata.st_uid
                                origgid = statdata.st_gid
                                if grp.getgrnam("openldap")[2] != "":
                                    if pwd.getpwnam("openldap")[2] != "":
                                        gid = grp.getgrnam("openldap")[2]
                                        uid = pwd.getpwnam("openldap")[2]
                                        self.iditerator += 1
                                        myid = iterate(self.iditerator,
                                                       self.rulenumber)
                                        event = {"eventtype": "perm",
                                                 "startstate": [origuid,
                                                                origgid, mode],
                                                 "endstate": [uid, gid, 384],
                                                 "filepath": loc}
                                        self.statechglogger.recordchgevent(myid, event)
                                        os.chmod(loc, 384)
                                        os.chown(loc, uid, gid)
                                        resetsecon(loc)
                                    else:
                                        debug = "Unable to determine the " + \
                                            "id number of ldap user.  " + \
                                            "Will not change permissions " + \
                                            "on " + loc + " file\n"
                                        self.logger.log(LogPriority.DEBUG, debug)
                                        success = False
                                else:
                                    success = False
                                    debug = "Unable to determine the id " + \
                                        "number of ldap group.  Will not " + \
                                        "change permissions on " + loc + \
                                        " file\n"
                                    self.logger.log(LogPriority.DEBUG, debug)
                cnconfig = "/etc/openldap/slapd.d/cn=config/"
                if os.path.exists(cnconfig):
                    dirs = glob.glob(cnconfig + "*")
                    for loc in dirs:
                        if not os.path.isdir(loc):
                            statdata = os.stat(loc)
                            mode = stat.S_IMODE(statdata.st_mode)
                            ownergrp = getUserGroupName(loc)
                            owner = ownergrp[0]
                            group = ownergrp[1]
                            if mode != 416 or owner != "ldap" or group != "ldap":
                                origuid = statdata.st_uid
                                origgid = statdata.st_gid
                                if grp.getgrnam("ldap")[2] != "":
                                    if pwd.getpwnam("ldap")[2] != "":
                                        gid = grp.getgrnam("ldap")[2]
                                        uid = pwd.getpwnam("ldap")[2]
                                        self.iditerator += 1
                                        myid = iterate(self.iditerator,
                                                       self.rulenumber)
                                        event = {"eventtype": "perm",
                                                 "startstate": [origuid,
                                                                origgid, mode],
                                                 "endstate": [uid, gid, 416],
                                                 "filepath": loc}
                                        self.statechglogger.recordchgevent(myid, event)
                                        os.chmod(loc, 416)
                                        os.chown(loc, uid, gid)
                                        resetsecon(loc)
                                    else:
                                        debug = "Unable to determine the " + \
                                            "id number of ldap user.  " + \
                                            "Will not change permissions " + \
                                            "on " + loc + " file\n"
                                        self.logger.log(LogPriority.DEBUG, debug)
                                        success = False
                                else:
                                    success = False
                                    debug = "Unable to determine the id " + \
                                        "number of ldap group.  Will not " + \
                                        "change permissions on " + loc + \
                                        " file\n"
                                    self.logger.log(LogPriority.DEBUG, debug)
                #apt-get systems
                cnconfig = "/etc/ldap/slapd.d/cn=config/"
                if os.path.exists(cnconfig):
                    dirs = glob.glob(cnconfig + "*")
                    for loc in dirs:
                        if not os.path.isdir(loc):
                            statdata = os.stat(loc)
                            mode = stat.S_IMODE(statdata.st_mode)
                            ownergrp = getUserGroupName(loc)
                            owner = ownergrp[0]
                            group = ownergrp[1]
                            if mode != 384 or owner != "openldap" or group != "openldap":
                                origuid = statdata.st_uid
                                origgid = statdata.st_gid
                                if grp.getgrnam("openldap")[2] != "":
                                    if pwd.getpwnam("openldap")[2] != "":
                                        gid = grp.getgrnam("openldap")[2]
                                        uid = pwd.getpwnam("openldap")[2]
                                        self.iditerator += 1
                                        myid = iterate(self.iditerator,
                                                       self.rulenumber)
                                        event = {"eventtype": "perm",
                                                 "startstate": [origuid,
                                                                origgid, mode],
                                                 "endstate": [uid, gid, 384],
                                                 "filepath": loc}
                                        self.statechglogger.recordchgevent(myid, event)
                                        os.chmod(loc, 384)
                                        os.chown(loc, uid, gid)
                                        resetsecon(loc)
                                    else:
                                        debug = "Unable to determine the " + \
                                            "id number of ldap user.  " + \
                                            "Will not change permissions " + \
                                            "on " + loc + " file\n"
                                        self.logger.log(LogPriority.DEBUG, debug)
                                        success = False
                                else:
                                    success = False
                                    debug = "Unable to determine the id " + \
                                        "number of ldap group.  Will not " + \
                                        "change permissions on " + loc + \
                                        " file\n"
                                    self.logger.log(LogPriority.DEBUG, debug)
                pki = "/etc/pki/tls/ldap/"
                if os.path.exists(pki):
                    dirs = glob.glob(pki + "*")
                    for loc in dirs:
                        if not os.path.isdir(loc):
                            statdata = os.stat(loc)
                            mode = stat.S_IMODE(statdata.st_mode)
                            ownergrp = getUserGroupName(loc)
                            owner = ownergrp[0]
                            group = ownergrp[1]
                            if mode != 416 or owner != "root" or group != "ldap":
                                origuid = statdata.st_uid
                                origgid = statdata.st_gid
                                if grp.getgrnam("ldap")[2] != "":
                                    gid = grp.getgrnam("ldap")[2]
                                    self.iditerator += 1
                                    myid = iterate(self.iditerator, self.rulenumber)
                                    event = {"eventtype": "perm",
                                             "startstate": [origuid, origgid, mode],
                                             "endstate": [0, gid, 416],
                                             "filepath": loc}
                                    self.statechglogger.recordchgevent(myid, event)
                                    os.chmod(loc, 416)
                                    os.chown(loc, 0, gid)
                                    resetsecon(loc)
                                else:
                                    success = False
                                    debug = "Unable to determine the id " + \
                                        "number of ldap group.  Will not change " + \
                                        "permissions on " + loc + " file\n"
                                    self.logger.log(LogPriority.DEBUG, debug)
                if os.path.exists("/etc/pki/tls/CA/"):
                    dirs = glob.glob("/etc/pki/tls/CA/*")
                    for loc in dirs:
                        if not os.path.isdir(loc):
                            if not checkPerms(loc, [0, 0, 420], self.logger):
                                self.iditerator += 1
                                myid = iterate(self.iditerator, self.rulenumber)
                                if not setPerms(loc, [0, 0, 420], self.logger,
                                                self.statechglogger, myid):
                                    debug = "Unable to set permissions on " + \
                                        loc + " file\n"
                                    self.logger.log(LogPriority.DEBUG, debug)
                                    success = False
                                else:
                                    resetsecon(loc)
            self.rulesuccess = success
        except (KeyboardInterrupt, SystemExit):
            # User initiated exit
            raise
        except Exception:
            self.rulesuccess = False
            self.detailedresults += "\n" + traceback.format_exc()
            self.logdispatch.log(LogPriority.ERROR, self.detailedresults)
        self.formatDetailedResults("fix", self.rulesuccess,
                                   self.detailedresults)
        self.logdispatch.log(LogPriority.INFO, self.detailedresults)
        return self.rulesuccess
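
Example 9 tests os.path.exists before every os.stat so that absent configuration files are skipped silently instead of raising. The EAFP alternative folds the check into exception handling and saves the second filesystem hit; both forms, using the same slapd.conf path (Python 3 sketch):

import os
import stat

slapd = '/etc/openldap/slapd.conf'

# Look-before-you-leap, as in the example: test, then stat.
if os.path.exists(slapd):
    mode = stat.S_IMODE(os.stat(slapd).st_mode)

# Easier-to-ask-forgiveness: one stat call, missing file handled explicitly.
try:
    mode = stat.S_IMODE(os.stat(slapd).st_mode)
except FileNotFoundError:
    pass  # file absent on this system; nothing to check
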

Example 10

Project: deepnl
Source File: dl-pos.py
View license
def main():

    # set the seed for replicability
    np.random.seed(42)

    defaults = {}
    
    parser = argparse.ArgumentParser(description="POS tagger using word embeddings.")
    
    parser.add_argument('-c', '--config', dest='config_file',
                        help='Specify config file', metavar='FILE')

    # args, remaining_argv = parser.parse_known_args()

    # if args.config_file:
    #     config = ConfigParser.SafeConfigParser()
    #     config.read([args.config_file])
    #     defaults = dict(config.items('Defaults'))

    # parser.set_defaults(**defaults)

    parser.add_argument('model', type=str,
                        help='Model file to train/use.')
    parser.add_argument('--threads', type=int, default=1,
                        help='Number of threads (default %(default)s)')
    parser.add_argument('-v', '--verbose', help='Verbose mode',
                        action='store_true')

    # training options
    train = parser.add_argument_group('Train')
    train.add_argument('-t', '--train', type=str, default=None,
                        help='File with annotated data for training.')

    train.add_argument('-w', '--window', type=int, default=2,
                        help='Size of the word window (default %(default)s)')
    train.add_argument('-s', '--embeddings-size', type=int, default=50,
                        help='Number of features per word (default %(default)s)',
                        dest='embeddings_size')
    train.add_argument('-e', '--epochs', type=int, default=100,
                        help='Number of training epochs (default %(default)s)',
                        dest='iterations')
    train.add_argument('-l', '--learning_rate', type=float, default=0.001,
                        help='Learning rate for network weights (default %(default)s)',
                        dest='learning_rate')
    train.add_argument('-n', '--hidden', type=int, default=200,
                        help='Number of hidden neurons (default %(default)s)',
                        dest='hidden')
    train.add_argument('--eps', type=float, default=1e-8,
                        help='Epsilon value for AdaGrad (default %(default)s)')
    train.add_argument('--ro', type=float, default=0.95,
                        help='Ro value for AdaDelta (default %(default)s)')
    train.add_argument('-o', '--output', type=str, default='',
                        help='File where to save embeddings')

    # Embeddings
    embeddings = parser.add_argument_group('Embeddings')
    embeddings.add_argument('--vocab', type=str, default='',
                        help='Vocabulary file, either read or created')
    embeddings.add_argument('--vocab-size', type=int, default=0,
                            help='Maximum size of vocabulary (default %(default)s)')
    embeddings.add_argument('--vectors', type=str, default='',
                        help='Embeddings file, either read or created')
    embeddings.add_argument('--min-occurr', type=int, default=3,
                        help='Minimum occurrences for inclusion in vocabulary',
                        dest='minOccurr')
    embeddings.add_argument('--load', type=str, default='',
                        help='Load previously saved model')
    embeddings.add_argument('--variant', type=str, default='',
                        help='Either "senna" (default), "polyglot" or "word2vec".')

    # Extractors:
    extractors = parser.add_argument_group('Extractors')
    extractors.add_argument('--caps', const=5, nargs='?', type=int, default=None,
                            help='Include capitalization features. Optionally, supply the number of features (default %(default)s)')
    extractors.add_argument('--suffix', const=5, nargs='?', type=int, default=None,
                            help='Include suffix features. Optionally, supply the number of features (default %(default)s)')
    extractors.add_argument('--suffixes', type=str, default='',
                        help='Load suffixes from this file')
    extractors.add_argument('--prefix', const=5, nargs='?', type=int, default=None,
                            help='Include prefix features. Optionally, '\
                            'supply the number of features (default %(default)s)')
    extractors.add_argument('--prefixes', type=str, default='',
                        help='Load prefixes from this file')

    # reader
    parser.add_argument('--form-field', type=int, default=0,
                        help='Token field containing form (default %(default)s)',
                        dest='formField')

    # Use this for obtaining defaults from config file:
    #args = arguments.get_args()
    args = parser.parse_args()

    log_format = '%(message)s'
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(format=log_format, level=log_level)
    logger = logging.getLogger("Logger")

    config = ConfigParser()
    if args.config_file:
        config.read(args.config_file)

    # TODO: merge args with values from the config file

    if args.train:
        reader = PosReader(args.formField)
        # an iterable over sentences (can be traversed more than once)
        sentence_iter = reader.read(args.train)

        if args.vocab and os.path.exists(args.vocab):
            # start with the given vocabulary
            base_vocab = reader.load_vocabulary(args.vocab)
            if args.vectors and os.path.exists(args.vectors):
                embeddings = Embeddings(vectors=args.vectors, vocab=base_vocab,
                                        variant=args.variant)
            else:
                # create random embeddings
                embeddings = Embeddings(args.embeddings_size, vocab=base_vocab,
                                        variant=args.variant)
            # add the ngrams from the corpus
            # build vocabulary and tag set
            if args.vocab_size:
                vocab, tagset = reader.create_vocabulary(sentence_iter,
                                                         args.vocab_size,
                                                         args.minOccurr)
                embeddings.merge(vocab)
                logger.info("Overriding vocabulary in %s" % args.vocab)
                embeddings.save_vocabulary(args.vocab)
            else:
                vocab = base_vocab
                tagset = reader.create_tagset(sentence_iter)

        elif args.vocab:
            if not args.vectors:
                logger.error("No --vectors specified")
                return
            embeddings = Embeddings(args.embeddings_size, args.vocab,
                                    args.vectors, variant=args.variant)
            tagset = reader.create_tagset(sentence_iter)
            logger.info("Creating vocabulary in %s" % args.vocab)
            embeddings.save_vocabulary(args.vocab)

        elif args.variant == 'word2vec':
            if os.path.exists(args.vectors):
                embeddings = Embeddings(vectors=args.vectors,
                                        variant=args.variant)
                vocab, tagset = reader.create_vocabulary(sentence_iter,
                                                         args.vocab_size,
                                                         args.minOccurr)
                embeddings.merge(vocab)
            else:
                embeddings = Embeddings(vectors=args.vectors,
                                        variant=args.variant)
                tagset = reader.create_tagset(sentence_iter)
            if args.vocab:
                logger.info("Creating vocabulary in %s" % args.vocab)
                embeddings.save_vocabulary(args.vocab)
        else:
            # build vocabulary and tag set
            vocab, tagset = reader.create_vocabulary(sentence_iter,
                                                     args.vocab_size,
                                                     args.minOccurr)
            logger.info("Creating word embeddings")
            embeddings = Embeddings(args.embeddings_size, vocab=vocab,
                                    variant=args.variant)
            if args.vocab:
                logger.info("Creating vocabulary in %s" % args.vocab)
                embeddings.save_vocabulary(args.vocab)

        converter = Converter()
        converter.add(embeddings)

        if args.caps:
            logger.info("Creating capitalization features...")
            converter.add(CapsExtractor(args.caps))

        if ((args.suffixes and not os.path.exists(args.suffixes)) or
            (args.prefixes and not os.path.exists(args.prefixes))):
            # collect the forms once, as a list, so that both the suffix
            # and the prefix extractor below can consume them
            words = [tok[reader.formField] for sent in sentence_iter for tok in sent]

        if args.suffix:
            if os.path.exists(args.suffixes):
                logger.info("Loading suffix list...")
                extractor = SuffixExtractor(args.suffix, args.suffixes)
                converter.add(extractor)
            else:
                logger.info("Creating suffix list...")
                extractor = SuffixExtractor(args.suffix, None, words)
                converter.add(extractor)
                if args.suffixes:
                    logger.info("Saving suffix list to: %s", args.suffixes)
                    extractor.write(args.suffixes)

        if args.prefix:
            if os.path.exists(args.prefixes):
                logger.info("Loading prefix list...")
                extractor = PrefixExtractor(args.prefix, args.prefixes)
                converter.add(extractor)
            else:
                logger.info("Creating prefix list...")
                extractor = PrefixExtractor(args.prefix, None, words)
                converter.add(extractor)
                if args.prefixes:
                    logger.info("Saving prefix list to: %s", args.prefixes)
                    extractor.write(args.prefixes)

        # obtain the tags for each sentence
        tag_index = { t:i for i,t in enumerate(tagset) }
        sentences = []
        tags = []
        for sent in sentence_iter:
            sentences.append(converter.convert([token[reader.formField] for token in sent]))
            tags.append(np.array([tag_index[token[reader.tagField]] for token in sent]))
    
        trainer = create_trainer(args, converter, tag_index)
        logger.info("Starting training with %d sentences" % len(sentences))

        report_frequency = max(args.iterations / 200, 1)
        trainer.train(sentences, tags, args.iterations, report_frequency,
                      args.threads)
    
        logger.info("Saving trained model ...")
        trainer.saver(trainer)
        logger.info("... to %s" % args.model)

    else:
        with open(args.model) as file:
            tagger = Tagger.load(file)
        reader = PosReader()
        for sent in reader:
            for tok, tag in tagger.tag(sent):
                tok[reader.tagField] = tag
            ConllWriter.write(sent)

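A pattern worth noting before the next example: the training path above leans on os.path.exists to decide between loading a saved resource (vocabulary, vectors, model) and building a fresh one. A minimal sketch of that load-or-create idiom, where load_fn and create_fn are hypothetical callables standing in for the Embeddings constructors:

import os

def load_or_create(path, load_fn, create_fn):
    # Load the resource if its file is already on disk, otherwise build it.
    # load_fn and create_fn are hypothetical stand-ins, not part of the
    # example above.
    if path and os.path.exists(path):
        return load_fn(path)
    return create_fn()
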
Example 11

Project: laikaboss
Source File: laika.py
View license
def main():
    # Define default configuration location

    parser = OptionParser(usage="usage: %prog [options] /path/to/file")
    parser.add_option("-d", "--debug",
                      action="store_true",
                      dest="debug",
                      help="enable debug messages to the console.")
    parser.add_option("-c", "--config-path",
                      action="store", type="string",
                      dest="config_path",
                      help="path to configuration for laikaboss framework.")
    parser.add_option("-o", "--out-path",
                      action="store", type="string",
                      dest="save_path",
                      help="Write all results to the specified path")
    parser.add_option("-s", "--source",
                      action="store", type="string",
                      dest="source",
                      help="Set the source (may affect dispatching) [default:laika]")
    parser.add_option("-p", "--num_procs",
                      action="store", type="int",
                      dest="num_procs",
                      default=8,
                      help="Specify the number of CPU's to use for a recursive scan. [default:8]")
    parser.add_option("-l", "--log",
                      action="store_true",
                      dest="log_result",
                      help="enable logging to syslog")
    parser.add_option("-j", "--log-json",
                      action="store", type="string",
                      dest="log_json",
                      help="enable logging JSON results to file")
    parser.add_option("-m", "--module",
                      action="store", type="string",
                      dest="scan_modules",
                      help="Specify individual module(s) to run and their arguments. If multiple, must be a space-separated list.")
    parser.add_option("--parent",
                      action="store", type="string",
                      dest="parent", default="",
                      help="Define the parent of the root object")
    parser.add_option("-e", "--ephID",
                      action="store", type="string",
                      dest="ephID", default="",
                      help="Specify an ephemeralID to send with the object")
    parser.add_option("--metadata",
                      action="store",
                      dest="ext_metadata",
                      help="Define metadata to add to the scan or specify a file containing the metadata.")
    parser.add_option("--size-limit",
                      action="store", type="int", default=10,
                      dest="sizeLimit",
                      help="Specify a size limit in MB (default: 10)")
    parser.add_option("--file-limit",
                      action="store", type="int", default=0,
                      dest="fileLimit",
                      help="Specify a limited number of files to scan (default: off)")
    parser.add_option("--progress",
                      action="store_true",
                      dest="progress",
                      default=False,
                      help="enable the progress bar")
    (options, args) = parser.parse_args()
    
    logger = logging.getLogger()

    if options.debug:
        # stdout is added by default, we'll capture this object here
        #lhStdout = logger.handlers[0]
        fileHandler = logging.FileHandler('laika-debug.log', 'w')
        formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
        fileHandler.setFormatter(formatter)
        logger.addHandler(fileHandler)
        # remove stdout from handlers so that debug info is only written to the file
        #logger.removeHandler(lhStdout)
        logging.basicConfig(level=logging.DEBUG)
        logger.setLevel(logging.DEBUG)

    global EXT_METADATA
    if options.ext_metadata:
        if os.path.exists(options.ext_metadata):
            with open(options.ext_metadata) as metafile:
                EXT_METADATA = json.loads(metafile.read())
        else:
            EXT_METADATA = json.loads(options.ext_metadata)
    else:
        EXT_METADATA = getConfig("ext_metadata")
    
    global EPHID
    if options.ephID:
        EPHID = options.ephID
    else:
        EPHID = getConfig("ephID")

    global SCAN_MODULES
    if options.scan_modules:
        SCAN_MODULES = options.scan_modules.split()
    else:
        SCAN_MODULES = None
    logging.debug("SCAN_MODULES: %s"  % (SCAN_MODULES))

    global PROGRESS_BAR
    if options.progress:
        PROGRESS_BAR = 1
    else:
        PROGRESS_BAR = strtobool(getConfig('progress_bar'))
    logging.debug("PROGRESS_BAR: %s"  % (PROGRESS_BAR))

    global LOG_RESULT
    if options.log_result:
        LOG_RESULT = 1
    else:
        LOG_RESULT = strtobool(getConfig('log_result'))
    logging.debug("LOG_RESULT: %s" % (LOG_RESULT))

    global LOG_JSON
    if options.log_json:
        LOG_JSON = options.log_json
    else:
        LOG_JSON = getConfig('log_json')

    global NUM_PROCS
    if options.num_procs:
        NUM_PROCS = options.num_procs
    else:
        NUM_PROCS = int(getConfig('num_procs'))
    logging.debug("NUM_PROCS: %s"  % (NUM_PROCS))

    global MAX_BYTES
    if options.sizeLimit:
        MAX_BYTES = options.sizeLimit * 1024 * 1024
    else:
        MAX_BYTES = int(getConfig('max_bytes'))
    logging.debug("MAX_BYTES: %s"  % (MAX_BYTES))

    global MAX_FILES
    if options.fileLimit:
        MAX_FILES = options.fileLimit
    else:
        MAX_FILES = int(getConfig('max_files'))
    logging.debug("MAX_FILES: %s"  % (MAX_FILES))

    global SOURCE
    if options.source:
        SOURCE = options.source
    else:
        SOURCE = getConfig('source')

    global SAVE_PATH
    if options.save_path:
        SAVE_PATH = options.save_path
    else:
        SAVE_PATH = getConfig('save_path')

    global CONFIG_PATH
    # Highest priority configuration is via argument
    if options.config_path:
        CONFIG_PATH = options.config_path
        logging.debug("using alternative config path: %s" % options.config_path)
        if not os.path.exists(options.config_path):
            error("the provided config path is not valid, exiting")
            return 1
    # Next, check to see if we're in the top level source directory (dev environment)
    elif os.path.exists(default_configs['dev_config_path']):
        CONFIG_PATH = default_configs['dev_config_path']
    # Next, check for an installed copy of the default configuration
    elif os.path.exists(default_configs['sys_config_path']):
        CONFIG_PATH = default_configs['sys_config_path']
    # Exit
    else:
        error('A valid framework configuration was not found in either of '
              'the following locations:\n%s\n%s'
              % (default_configs['dev_config_path'], default_configs['sys_config_path']))
        return 1
       

    # Check for stdin if no arguments were provided
    if len(args) == 0:

        DATA_PATH = []

        if not sys.stdin.isatty():
            while True:
                f = sys.stdin.readline().strip()
                if not f:
                    break
                else:
                    # os.path.isfile() is False for directories, so test
                    # isdir first or the directory message is unreachable
                    if os.path.isdir(f):
                        error("One of the files you specified is actually a directory: %s" % (f))
                        return 1
                    if not os.path.isfile(f):
                        error("One of the specified files does not exist: %s" % (f))
                        return 1
                    DATA_PATH.append(f)

        if not DATA_PATH:
            error("You must provide files via stdin when no arguments are provided")
            return 1
        logging.debug("Loaded %s files from stdin" % (len(DATA_PATH)))
    elif len(args) == 1:
        if os.path.isdir(args[0]):
            DATA_PATH = args[0]
        elif os.path.isfile(args[0]):
            DATA_PATH = [args[0]]
        else:
            error("File or directory does not exist: %s" % (args[0]))
            return 1
    else:
        for f in args:
            if os.path.isdir(f):
                error("One of the files you specified is actually a directory: %s" % (f))
                return 1
            if not os.path.isfile(f):
                error("One of the specified files does not exist: %s" % (f))
                return 1
        
        DATA_PATH = args

   
    tasks = multiprocessing.JoinableQueue()
    results = multiprocessing.Queue()
    
    fileList = []
    if type(DATA_PATH) is str:
        for root, dirs, files in os.walk(DATA_PATH):
            files = [f for f in files if not f[0] == '.']
            dirs[:] = [d for d in dirs if not d[0] == '.']
            for fname in files:
                fullpath = os.path.join(root, fname)
                if not os.path.islink(fullpath) and os.path.isfile(fullpath):
                    fileList.append(fullpath)
    else:
        fileList = DATA_PATH

    if MAX_FILES:
        fileList = fileList[:MAX_FILES]

    num_jobs = len(fileList)
    logging.debug("Loaded %s files for scanning" % (num_jobs))
    
    # Start consumers
    # If there's less files to process than processes, reduce the number of processes
    if num_jobs < NUM_PROCS:
        NUM_PROCS = num_jobs
    logging.debug("Starting %s processes" % (NUM_PROCS))
    consumers = [ Consumer(tasks, results)
                  for i in xrange(NUM_PROCS) ]
    try:
        
        for w in consumers:
            w.start()

        # Enqueue jobs
        for fname in fileList:
            tasks.put(fname)
        
        # Add a poison pill for each consumer
        for i in xrange(NUM_PROCS):
            tasks.put(None)

        if PROGRESS_BAR:
            monitor = QueueMonitor(tasks, num_jobs)
            monitor.start()

        # Wait for all of the tasks to finish
        tasks.join()
        if PROGRESS_BAR:
            monitor.join()

        while num_jobs:
            answer = zlib.decompress(results.get())
            print(answer)
            num_jobs -= 1

    except KeyboardInterrupt:
        error("Cancelled by user.. Shutting down.")
        for w in consumers:
            w.terminate()
            w.join()
        return None
    except:
        raise

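The CONFIG_PATH selection in main() above is a fallback cascade: honor the explicit command-line path if it exists, otherwise take the first default location that does. The same check can be factored into a small helper; a sketch, with hypothetical paths in the usage comment:

import os

def first_existing(*paths):
    # Return the first candidate path that exists on disk, or None.
    for p in paths:
        if p and os.path.exists(p):
            return p
    return None

# e.g. first_existing(options.config_path,
#                     './etc/framework.conf',           # hypothetical dev path
#                     '/etc/laikaboss/framework.conf')  # hypothetical system path
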
Example 12

Project: fb2mobi
Source File: fb2mobi.py
View license
def process_file(config, infile, outfile=None):
    critical_error = False

    start_time = time.clock()
    temp_dir = tempfile.mkdtemp()

    if not os.path.exists(infile):
        config.log.critical('File {0} not found'.format(infile))
        return

    config.log.info('Converting "{0}"...'.format(os.path.split(infile)[1]))
    config.log.info('Using profile "{0}".'.format(config.current_profile['name']))

    # Проверка корректности параметров
    if infile:
        if not infile.lower().endswith(('.fb2', '.fb2.zip', '.zip', '.epub')):
            config.log.critical('"{0}" not *.fb2, *.fb2.zip, *.zip or *.epub'.format(infile))
            return

    if not config.current_profile['css'] and not infile.lower().endswith(('.epub')):
        config.log.warning('Profile does not have link to css file.')

    if 'xslt' in config.current_profile and not os.path.exists(config.current_profile['xslt']):
        config.log.critical('Transformation file {0} not found'.format(config.current_profile['xslt']))
        return

    if config.kindle_compression_level < 0 or config.kindle_compression_level > 2:
        config.log.warning('Parameter kindleCompressionLevel should be between 0 and 2, using default value (1).')
        config.kindle_compression_level = 1

    # Если не задано имя выходного файла - вычислим
    if not outfile:

        outdir, outputfile = os.path.split(infile)
        outputfile = get_mobi_filename(outputfile, config.transliterate)

        if config.output_dir:
            if not os.path.exists(config.output_dir):
                os.makedirs(config.output_dir)
            if config.input_dir and config.save_structure:
                rel_path = os.path.join(config.output_dir, os.path.split(os.path.relpath(infile, config.input_dir))[0])
                if not os.path.exists(rel_path):
                    os.makedirs(rel_path)
                outfile = os.path.join(rel_path, outputfile)
            else:
                outfile = os.path.join(config.output_dir, outputfile)
        else:
            outfile = os.path.join(outdir, outputfile)
    else:
        _output_format = os.path.splitext(outfile)[1].lower()[1:]
        if _output_format not in ('mobi', 'azw3', 'epub'):
            config.log.critical('Unknown output format: {0}'.format(_output_format))
            return -1
        else:
            if not config.mhl:
                config.output_format = _output_format
            outfile = '{0}.{1}'.format(os.path.splitext(outfile)[0], config.output_format)

    if config.output_format.lower() == 'epub':
        # Для epub всегда разбиваем по главам
        config.current_profile['chapterOnNewPage'] = True

    debug_dir = os.path.abspath(os.path.splitext(infile)[0])
    if os.path.splitext(debug_dir)[1].lower() == '.fb2':
        debug_dir = os.path.splitext(debug_dir)[0]

    input_epub = False

    if os.path.splitext(infile)[1].lower() == '.zip':
        config.log.info('Unpacking...')
        tmp_infile = infile
        try:
            infile = unzip(infile, temp_dir)
        except:
            config.log.critical('Error unpacking file "{0}".'.format(tmp_infile))
            return

        if not infile:
            config.log.critical('Error unpacking file "{0}".'.format(tmp_infile))
            return

    elif os.path.splitext(infile)[1].lower() == '.epub':
        config.log.info('Unpacking epub...')
        tmp_infile = infile
        try:
            infile = unzip_epub(infile, temp_dir)
        except:
            config.log.critical('Error unpacking file "{0}".'.format(tmp_infile))
            return

        if not infile:
            config.log.critical('Error unpacking file "{0}".'.format(tmp_infile))
            return

        input_epub = True

    if input_epub:
        # Let's see what we could do
        config.log.info('Processing epub...')
        epubparser = EpubProc(infile, config)
        epubparser.process()
        document_id = epubparser.book_uuid
    else:
        # Конвертируем в html
        config.log.info('Converting fb2 to html...')
        try:
            fb2parser = Fb2XHTML(infile, outfile, temp_dir, config)
            fb2parser.generate()
            document_id = fb2parser.book_uuid
            infile = os.path.join(temp_dir, 'OEBPS', 'content.opf')
        except:
            config.log.critical('Error while converting file "{0}"'.format(infile))
            config.log.debug('Getting details', exc_info=True)
            return

    config.log.info('Processing took {0} sec.'.format(round(time.clock() - start_time, 2)))

    if config.output_format.lower() in ('mobi', 'azw3'):
        # Запускаем kindlegen
        application_path = get_executable_path()
        if sys.platform == 'win32':
            if os.path.exists(os.path.join(application_path, 'kindlegen.exe')):
                kindlegen_cmd = os.path.join(application_path, 'kindlegen.exe')
            else:
                kindlegen_cmd = 'kindlegen.exe'
        else:
            if os.path.exists(os.path.join(application_path, 'kindlegen')):
                kindlegen_cmd = os.path.join(application_path, 'kindlegen')
            else:
                kindlegen_cmd = 'kindlegen'

        try:
            config.log.info('Running kindlegen...')
            kindlegen_cmd_pars = '-c{0}'.format(config.kindle_compression_level)

            startupinfo = None
            if os.name == 'nt':
                startupinfo = subprocess.STARTUPINFO()
                startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW

            with subprocess.Popen([kindlegen_cmd, infile, kindlegen_cmd_pars, '-locale', 'en'], stdout=subprocess.PIPE,
                                  stderr=subprocess.STDOUT, startupinfo=startupinfo) as result:
                config.log.debug(str(result.stdout.read(), 'utf-8', errors='replace'))

        except OSError as e:
            if e.errno == os.errno.ENOENT:
                config.log.critical('{0} not found'.format(kindlegen_cmd))
                critical_error = True
            else:
                config.log.critical(e.winerror)
                config.log.critical(e.strerror)
                config.log.debug('Getting details', exc_info=True, stack_info=True)
                raise e

    elif config.output_format.lower() == 'epub':
        # Собираем epub
        outfile = os.path.splitext(outfile)[0] + '.epub'
        config.log.info('Creating epub...')
        create_epub(temp_dir, outfile)

    if config.debug:
        # В режиме отладки копируем получившиеся файлы в выходной каталог
        config.log.info('Copying intermediate files to {0}...'.format(debug_dir))
        if os.path.exists(debug_dir):
            rm_tmp_files(debug_dir)
        shutil.copytree(temp_dir, debug_dir)

    # Копируем mobi(azw3) из временного в выходной каталог
    if not critical_error:
        ext = config.output_format.lower()
        if ext in ('mobi', 'azw3'):
            result_book = infile.replace('.opf', '.mobi')
            if not os.path.isfile(result_book):
                config.log.critical('kindlegen error, conversion interrupted.')
                critical_error = True
            else:
                try:
                    remove_personal = config.current_profile['kindleRemovePersonalLabel'] if ext != 'mobi' or not config.send_to_kindle['send'] else False
                    if ext == 'mobi' and config.noMOBIoptimization:
                        config.log.info('Copying resulting file...')
                        shutil.copyfile(result_book, outfile)
                    else:
                        config.log.info('Optimizing resulting file...')
                        splitter = mobi_split(result_book, document_id, remove_personal, ext)
                        open(os.path.splitext(outfile)[0] + '.' + ext, 'wb').write(splitter.getResult() if ext == 'mobi' else splitter.getResult8())
                except:
                    config.log.critical('Error optimizing file, conversion interrupted.')
                    config.log.debug('Getting details', exc_info=True, stack_info=True)
                    critical_error = True

    if not critical_error:
        config.log.info('Book conversion completed in {0} sec.\n'.format(round(time.clock() - start_time, 2)))

        if config.send_to_kindle['send']:
            if config.output_format.lower() != 'mobi':
                config.log.warning('Kindle Personal Documents Service only accepts personal mobi files')
            else:
                config.log.info('Sending book...')
                try:
                    kindle = SendToKindle()
                    kindle.smtp_server = config.send_to_kindle['smtpServer']
                    kindle.smtp_port = config.send_to_kindle['smtpPort']
                    kindle.smtp_login = config.send_to_kindle['smtpLogin']
                    kindle.smtp_password = config.send_to_kindle['smtpPassword']
                    kindle.user_email = config.send_to_kindle['fromUserEmail']
                    kindle.kindle_email = config.send_to_kindle['toKindleEmail']
                    kindle.convert = False
                    kindle.send_mail([outfile])

                    config.log.info('Book has been sent to "{0}"'.format(config.send_to_kindle['toKindleEmail']))

                    if config.send_to_kindle['deleteSendedBook']:
                        try:
                            os.remove(outfile)
                        except:
                            config.log.error('Unable to remove file "{0}".'.format(outfile))
                            return -1

                except KeyboardInterrupt:
                    print('User interrupt. Exiting...')
                    sys.exit(-1)

                except:
                    config.log.error('Error sending file')
                    config.log.debug('Getting details', exc_info=True, stack_info=True)

    # Чистим временные файлы
    rm_tmp_files(temp_dir)

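process_file above guards every os.makedirs with an os.path.exists check. Between the check and the call another process can create the same directory, and makedirs then raises FileExistsError. On Python 3.2+ the pair collapses into a single call that tolerates concurrent creation (the path shown is hypothetical):

import os

# Equivalent to: if not os.path.exists(d): os.makedirs(d)
# but without the check-then-create race.
os.makedirs('/tmp/fb2mobi-out', exist_ok=True)
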
Example 13

Project: tracpy
Source File: plotting.py
View license
def hist(lonp, latp, fname, tind='final', which='contour', vmax=None,
         fig=None, ax=None, bins=(40, 40), N=10, grid=None, xlims=None,
         ylims=None, C=None, Title=None, weights=None,
         Label='Final drifter location (%)', isll=True, binscale=None):
    """
    Plot histogram of given track data at time index tind.

    Args:
        lonp,latp: Drifter track positions in lon/lat [time x ndrifters]
        fname: Plot name to save
        tind (Optional): Default is 'final', in which case the final
         position of each drifter in the array is found and plotted.
         Alternatively, a time index can be input and drifters at that time
         will be plotted. Note that once drifters hit the outer numerical
         boundary, they are nan'ed out so this may miss some drifters.
        which (Optional[str]): 'contour', 'pcolor', 'hexbin', 'hist2d' for
         type of plot used. Default 'contour'.
        bins (Optional): Number of bins used in histogram. Default (40, 40).
        N (Optional[int]): Number of contours to make. Default 10.
        grid (Optional): grid as read in by inout.readgrid()
        xlims (Optional): value limits on the x axis
        ylims (Optional): value limits on the y axis
        isll: Default True. Inputs are in lon/lat. If False, assume they
         are in projected coords.

    Note:
        Currently assuming we are plotting the final location of each drifter
        regardless of tind.
    """

    if grid is None:
        loc = 'http://barataria.tamu.edu:8080/thredds/dodsC/NcML/txla_nesting6.nc'
        grid = inout.readgrid(loc)

    if isll:  # if inputs are in lon/lat, change to projected x/y
        # Change positions from lon/lat to x/y
        xp, yp = grid.proj(lonp, latp)
        # Need to retain nan's since basemap changes them to values
        ind = np.isnan(lonp)
        xp[ind] = np.nan
        yp[ind] = np.nan
    else:
        xp = lonp
        yp = latp

    if fig is None:
        fig = plt.figure(figsize=(11, 10))
    else:
        fig = fig
    background(grid)  # Plot coastline and such

    if tind == 'final':
        # Find final positions of drifters
        xpc, ypc = tools.find_final(xp, yp)
    elif isinstance(tind, int):
        xpc = xp[:, tind]
        ypc = yp[:, tind]
    else:  # just plot what is input if some other string
        xpc = xp.flatten()
        ypc = yp.flatten()

    if which == 'contour':

        # Info for 2d histogram
        H, xedges, yedges = np.histogram2d(xpc, ypc,
                                           range=[[grid.x_rho.min(),
                                                   grid.x_rho.max()],
                                                  [grid.y_rho.min(),
                                                   grid.y_rho.max()]],
                                           bins=bins)

        # Contour Plot
        XE, YE = np.meshgrid(op.resize(xedges, 0), op.resize(yedges, 0))
        d = (H/H.sum())*100
        # # from http://matplotlib.1069221.n5.nabble.com/question-about-contours-and-clim-td21111.html
        # locator = ticker.MaxNLocator(50) # if you want no more than 10 contours
        # locator.create_dummy_axis()
        # locator.set_bounds(0,1)#d.min(),d.max())
        # levs = locator()
        con = plt.contourf(XE, YE, d.T, N)  # ,levels=levs)#(0,15,30,45,60,75,90,105,120))
        con.set_cmap('YlOrRd')

        if Title is not None:
            plt.title(Title)

        # Horizontal colorbar below plot
        cax = fig.add_axes([0.3725, 0.25, 0.48, 0.02])  # colorbar axes
        cb = fig.colorbar(con, cax=cax, orientation='horizontal')
        cb.set_label('Final drifter location (percent)')

        # Save figure into a local directory called figures. Make directory
        # if it doesn't exist.
        if not os.path.exists('figures'):
            os.makedirs('figures')

        fig.savefig('figures/' + fname + 'histcon.png', bbox_inches='tight')

    elif which == 'pcolor':

        # Info for 2d histogram
        H, xedges, yedges = np.histogram2d(xpc, ypc,
                                           range=[[grid.x_rho.min(),
                                                   grid.x_rho.max()],
                                                  [grid.y_rho.min(),
                                                   grid.y_rho.max()]],
                                           bins=bins, weights=weights)

        # Pcolor plot
        # C is the z value plotted, and is normalized by the total number of
        # drifters
        if C is None:
            C = (H.T/H.sum())*100
        else:
            # or, provide some other weighting
            C = (H.T/C)*100

        p = plt.pcolor(xedges, yedges, C, cmap='YlOrRd')

        if Title is not None:
            plt.title(Title)

        # Set x and y limits
        if xlims is not None:
            plt.xlim(xlims)
        if ylims is not None:
            plt.ylim(ylims)

        # Horizontal colorbar below plot
        cax = fig.add_axes([0.3775, 0.25, 0.48, 0.02])  # colorbar axes
        cb = fig.colorbar(p, cax=cax, orientation='horizontal')
        cb.set_label('Final drifter location (percent)')

        # Save figure into a local directory called figures. Make directory
        # if it doesn't exist.
        if not os.path.exists('figures'):
            os.makedirs('figures')

        fig.savefig('figures/' + fname + 'histpcolor.png', bbox_inches='tight')
        # savefig('figures/' + fname + 'histpcolor.pdf',bbox_inches='tight')

    elif which == 'hexbin':

        if ax is None:
            ax = plt.gca()
        else:
            ax = ax

        if C is None:
            # C with the reduce_C_function as sum is what makes it a percent
            C = np.ones(len(xpc))*(1./len(xpc))*100
        else:
            C = C*np.ones(len(xpc))*100
        hb = plt.hexbin(xpc, ypc, C=C, cmap='YlOrRd', gridsize=bins[0],
                    extent=(grid.x_psi.min(), grid.x_psi.max(),
                            grid.y_psi.min(), grid.y_psi.max()),
                    reduce_C_function=sum, vmax=vmax, axes=ax, bins=binscale)

        # Set x and y limits
        if xlims is not None:
            plt.xlim(xlims)
        if ylims is not None:
            plt.ylim(ylims)

        if Title is not None:
            ax.set_title(Title)

        # Want colorbar at the given location relative to axis so this works
        # regardless of # of subplots, so convert from axis to figure
        # coordinates. To do this, first convert from axis to display coords
        # transformations:
        # http://matplotlib.org/users/transforms_tutorial.html
        # axis: [x_left, y_bottom, width, height]
        ax_coords = [0.35, 0.25, 0.6, 0.02]
        # display: [x_left,y_bottom,x_right,y_top]
        disp_coords = ax.transAxes.transform([(ax_coords[0], ax_coords[1]),
                                              (ax_coords[0]+ax_coords[2],
                                               ax_coords[1]+ax_coords[3])])
        # inverter object to go from display coords to figure coords
        inv = fig.transFigure.inverted()
        # figure: [x_left,y_bottom,x_right,y_top]
        fig_coords = inv.transform(disp_coords)
        # actual desired figure coords. figure:
        # [x_left, y_bottom, width, height]
        fig_coords = [fig_coords[0, 0], fig_coords[0, 1], fig_coords[1, 0] -
                      fig_coords[0, 0], fig_coords[1, 1] - fig_coords[0, 1]]
        # Inlaid colorbar
        cax = fig.add_axes(fig_coords)

        # # Horizontal colorbar below plot
        # cax = fig.add_axes([0.3775, 0.25, 0.48, 0.02]) # colorbar axes
        cb = fig.colorbar(hb, cax=cax, orientation='horizontal')
        cb.set_label(Label)

        # Save figure into a local directory called figures. Make directory
        # if it doesn't exist.
        if not os.path.exists('figures'):
            os.makedirs('figures')

        fig.savefig('figures/' + fname + 'histhexbin.png', bbox_inches='tight')
        # savefig('figures/' + fname + 'histhexbin.pdf',bbox_inches='tight')

    elif which == 'hist2d':

        _, _, _, im = plt.hist2d(xpc, ypc, bins=40,
                                 range=[[grid.x_rho.min(), grid.x_rho.max()],
                                        [grid.y_rho.min(), grid.y_rho.max()]],
                                 normed=True)
        plt.set_cmap('YlOrRd')
        # Set x and y limits
        if xlims is not None:
            plt.xlim(xlims)
        if ylims is not None:
            plt.ylim(ylims)

        # Horizontal colorbar below plot
        cax = fig.add_axes([0.3775, 0.25, 0.48, 0.02])  # colorbar axes
        cb = fig.colorbar(im, cax=cax, orientation='horizontal')
        cb.set_label('Final drifter location (percent)')

        # Save figure into a local directory called figures. Make directory
        # if it doesn't exist.
        if not os.path.exists('figures'):
            os.makedirs('figures')

        fig.savefig('figures/' + fname + 'hist2d.png', bbox_inches='tight')

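All four plotting branches in hist() repeat the same stanza: check os.path.exists('figures'), create the directory if missing, then save. One way to keep the branches in sync is a tiny helper; a sketch, not part of tracpy:

import os

def figure_path(fname, suffix, dirname='figures'):
    # Ensure the output directory exists and build the save path used by
    # the fig.savefig calls above.
    if not os.path.exists(dirname):
        os.makedirs(dirname)
    return os.path.join(dirname, fname + suffix)

# e.g. fig.savefig(figure_path(fname, 'histcon.png'), bbox_inches='tight')
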
Example 14

Project: RopeMate.tmbundle
Source File: runmod.py
View license
def __rope_start_everything():
    import os
    import sys
    import socket
    import cPickle as pickle
    import marshal
    import inspect
    import types
    import threading

    class _MessageSender(object):

        def send_data(self, data):
            pass

    class _SocketSender(_MessageSender):

        def __init__(self, port):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect(('127.0.0.1', port))
            self.my_file = s.makefile('w')

        def send_data(self, data):
            if not self.my_file.closed:
                pickle.dump(data, self.my_file)

        def close(self):
            self.my_file.close()

    class _FileSender(_MessageSender):

        def __init__(self, file_name):
            self.my_file = open(file_name, 'wb')

        def send_data(self, data):
            if not self.my_file.closed:
                marshal.dump(data, self.my_file)

        def close(self):
            self.my_file.close()


    def _cached(func):
        cache = {}
        def newfunc(self, arg):
            if arg in cache:
                return cache[arg]
            result = func(self, arg)
            cache[arg] = result
            return result
        return newfunc

    class _FunctionCallDataSender(object):

        def __init__(self, send_info, project_root):
            self.project_root = project_root
            if send_info.isdigit():
                self.sender = _SocketSender(int(send_info))
            else:
                self.sender = _FileSender(send_info)

            def global_trace(frame, event, arg):
                # HACK: Ignoring out->in calls
                # This might lose some information
                if self._is_an_interesting_call(frame):
                    return self.on_function_call
            sys.settrace(global_trace)
            threading.settrace(global_trace)

        def on_function_call(self, frame, event, arg):
            if event != 'return':
                return
            args = []
            returned = ('unknown',)
            code = frame.f_code
            for argname in code.co_varnames[:code.co_argcount]:
                try:
                    args.append(self._object_to_persisted_form(frame.f_locals[argname]))
                except (TypeError, AttributeError):
                    args.append(('unknown',))
            try:
                returned = self._object_to_persisted_form(arg)
            except (TypeError, AttributeError):
                pass
            try:
                data = (self._object_to_persisted_form(frame.f_code),
                        tuple(args), returned)
                self.sender.send_data(data)
            except (TypeError):
                pass
            return self.on_function_call

        def _is_an_interesting_call(self, frame):
            #if frame.f_code.co_name in ['?', '<module>']:
            #    return False
            #return not frame.f_back or not self._is_code_inside_project(frame.f_back.f_code)

            if not self._is_code_inside_project(frame.f_code) and \
               (not frame.f_back or not self._is_code_inside_project(frame.f_back.f_code)):
                return False
            return True

        def _is_code_inside_project(self, code):
            source = self._path(code.co_filename)
            return source is not None and os.path.exists(source) and \
                   _realpath(source).startswith(self.project_root)

        @_cached
        def _get_persisted_code(self, object_):
            source = self._path(object_.co_filename)
            if not os.path.exists(source):
                raise TypeError('no source')
            return ('defined', _realpath(source), str(object_.co_firstlineno))

        @_cached
        def _get_persisted_class(self, object_):
            try:
                return ('defined', _realpath(inspect.getsourcefile(object_)),
                        object_.__name__)
            except (TypeError, AttributeError):
                return ('unknown',)

        def _get_persisted_builtin(self, object_):
            if isinstance(object_, (str, unicode)):
                return ('builtin', 'str')
            if isinstance(object_, list):
                holding = None
                if len(object_) > 0:
                    holding = object_[0]
                return ('builtin', 'list', self._object_to_persisted_form(holding))
            if isinstance(object_, dict):
                keys = None
                values = None
                if len(object_) > 0:
                    keys = object_.keys()[0]
                    values = object_[keys]
                return ('builtin', 'dict',
                        self._object_to_persisted_form(keys),
                        self._object_to_persisted_form(values))
            if isinstance(object_, tuple):
                objects = []
                if len(object_) < 3:
                    for holding in object_:
                        objects.append(self._object_to_persisted_form(holding))
                else:
                    objects.append(self._object_to_persisted_form(object_[0]))
                return tuple(['builtin', 'tuple'] + objects)
            if isinstance(object_, set):
                holding = None
                if len(object_) > 0:
                    for o in object_:
                        holding = o
                        break
                return ('builtin', 'set', self._object_to_persisted_form(holding))
            return ('unknown',)

        def _object_to_persisted_form(self, object_):
            if object_ is None:
                return ('none',)
            if isinstance(object_, types.CodeType):
                return self._get_persisted_code(object_)
            if isinstance(object_, types.FunctionType):
                return self._get_persisted_code(object_.func_code)
            if isinstance(object_, types.MethodType):
                return self._get_persisted_code(object_.im_func.func_code)
            if isinstance(object_, types.ModuleType):
                return self._get_persisted_module(object_)
            if isinstance(object_, (str, unicode, list, dict, tuple, set)):
                return self._get_persisted_builtin(object_)
            if isinstance(object_, (types.TypeType, types.ClassType)):
                return self._get_persisted_class(object_)
            return ('instance', self._get_persisted_class(type(object_)))

        @_cached
        def _get_persisted_module(self, object_):
            path = self._path(object_.__file__)
            if path and os.path.exists(path):
                return ('defined', _realpath(path))
            return ('unknown',)

        def _path(self, path):
            if path.endswith('.pyc'):
                path = path[:-1]
            if path.endswith('.py'):
                return path

        def close(self):
            self.sender.close()

    def _realpath(path):
        return os.path.realpath(os.path.abspath(os.path.expanduser(path)))

    send_info = sys.argv[1]
    project_root = sys.argv[2]
    file_to_run = sys.argv[3]
    run_globals = globals()
    run_globals.update({'__name__': '__main__',
                        '__builtins__': __builtins__,
                        '__file__': file_to_run})
    if send_info != '-':
        data_sender = _FunctionCallDataSender(send_info, project_root)
    del sys.argv[1:4]
    execfile(file_to_run, run_globals)
    if send_info != '-':
        data_sender.close()

Example 15

Project: python-mode
Source File: runmod.py
View license
def __rope_start_everything():
    import os
    import sys
    import socket
    import cPickle as pickle
    import marshal
    import inspect
    import types
    import threading

    class _MessageSender(object):

        def send_data(self, data):
            pass

    class _SocketSender(_MessageSender):

        def __init__(self, port):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect(('127.0.0.1', port))
            self.my_file = s.makefile('w')

        def send_data(self, data):
            if not self.my_file.closed:
                pickle.dump(data, self.my_file)

        def close(self):
            self.my_file.close()

    class _FileSender(_MessageSender):

        def __init__(self, file_name):
            self.my_file = open(file_name, 'wb')

        def send_data(self, data):
            if not self.my_file.closed:
                marshal.dump(data, self.my_file)

        def close(self):
            self.my_file.close()

    def _cached(func):
        cache = {}

        def newfunc(self, arg):
            if arg in cache:
                return cache[arg]
            result = func(self, arg)
            cache[arg] = result
            return result
        return newfunc

    class _FunctionCallDataSender(object):

        def __init__(self, send_info, project_root):
            self.project_root = project_root
            if send_info.isdigit():
                self.sender = _SocketSender(int(send_info))
            else:
                self.sender = _FileSender(send_info)

            def global_trace(frame, event, arg):
                # HACK: Ignoring out->in calls
                # This might lose some information
                if self._is_an_interesting_call(frame):
                    return self.on_function_call
            sys.settrace(global_trace)
            threading.settrace(global_trace)

        def on_function_call(self, frame, event, arg):
            if event != 'return':
                return
            args = []
            returned = ('unknown',)
            code = frame.f_code
            for argname in code.co_varnames[:code.co_argcount]:
                try:
                    args.append(self._object_to_persisted_form(
                        frame.f_locals[argname]))
                except (TypeError, AttributeError):
                    args.append(('unknown',))
            try:
                returned = self._object_to_persisted_form(arg)
            except (TypeError, AttributeError):
                pass
            try:
                data = (self._object_to_persisted_form(frame.f_code),
                        tuple(args), returned)
                self.sender.send_data(data)
            except (TypeError):
                pass
            return self.on_function_call

        def _is_an_interesting_call(self, frame):
            #if frame.f_code.co_name in ['?', '<module>']:
            #    return False
            #return not frame.f_back or
            #    not self._is_code_inside_project(frame.f_back.f_code)

            if not self._is_code_inside_project(frame.f_code) and \
               (not frame.f_back or
                    not self._is_code_inside_project(frame.f_back.f_code)):
                return False
            return True

        def _is_code_inside_project(self, code):
            source = self._path(code.co_filename)
            return source is not None and os.path.exists(source) and \
                _realpath(source).startswith(self.project_root)

        @_cached
        def _get_persisted_code(self, object_):
            source = self._path(object_.co_filename)
            if not os.path.exists(source):
                raise TypeError('no source')
            return ('defined', _realpath(source), str(object_.co_firstlineno))

        @_cached
        def _get_persisted_class(self, object_):
            try:
                return ('defined', _realpath(inspect.getsourcefile(object_)),
                        object_.__name__)
            except (TypeError, AttributeError):
                return ('unknown',)

        def _get_persisted_builtin(self, object_):
            if isinstance(object_, (str, unicode)):
                return ('builtin', 'str')
            if isinstance(object_, list):
                holding = None
                if len(object_) > 0:
                    holding = object_[0]
                return ('builtin', 'list',
                        self._object_to_persisted_form(holding))
            if isinstance(object_, dict):
                keys = None
                values = None
                if len(object_) > 0:
                    keys = object_.keys()[0]
                    values = object_[keys]
                return ('builtin', 'dict',
                        self._object_to_persisted_form(keys),
                        self._object_to_persisted_form(values))
            if isinstance(object_, tuple):
                objects = []
                if len(object_) < 3:
                    for holding in object_:
                        objects.append(self._object_to_persisted_form(holding))
                else:
                    objects.append(self._object_to_persisted_form(object_[0]))
                return tuple(['builtin', 'tuple'] + objects)
            if isinstance(object_, set):
                holding = None
                if len(object_) > 0:
                    for o in object_:
                        holding = o
                        break
                return ('builtin', 'set',
                        self._object_to_persisted_form(holding))
            return ('unknown',)

        def _object_to_persisted_form(self, object_):
            if object_ is None:
                return ('none',)
            if isinstance(object_, types.CodeType):
                return self._get_persisted_code(object_)
            if isinstance(object_, types.FunctionType):
                return self._get_persisted_code(object_.func_code)
            if isinstance(object_, types.MethodType):
                return self._get_persisted_code(object_.im_func.func_code)
            if isinstance(object_, types.ModuleType):
                return self._get_persisted_module(object_)
            if isinstance(object_, (str, unicode, list, dict, tuple, set)):
                return self._get_persisted_builtin(object_)
            if isinstance(object_, (types.TypeType, types.ClassType)):
                return self._get_persisted_class(object_)
            return ('instance', self._get_persisted_class(type(object_)))

        @_cached
        def _get_persisted_module(self, object_):
            path = self._path(object_.__file__)
            if path and os.path.exists(path):
                return ('defined', _realpath(path))
            return ('unknown',)

        def _path(self, path):
            if path.endswith('.pyc'):
                path = path[:-1]
            if path.endswith('.py'):
                return path

        def close(self):
            self.sender.close()
            sys.settrace(None)

    def _realpath(path):
        return os.path.realpath(os.path.abspath(os.path.expanduser(path)))

    send_info = sys.argv[1]
    project_root = sys.argv[2]
    file_to_run = sys.argv[3]
    run_globals = globals()
    run_globals.update({'__name__': '__main__',
                        '__builtins__': __builtins__,
                        '__file__': file_to_run})
    if send_info != '-':
        data_sender = _FunctionCallDataSender(send_info, project_root)
    del sys.argv[1:4]
    execfile(file_to_run, run_globals)
    if send_info != '-':
        data_sender.close()

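_get_persisted_code in the two runmod variants above is look-before-you-leap: it calls os.path.exists and raises TypeError('no source') when the check fails. Where the caller intends to read the file anyway, the EAFP variant avoids the window between check and use; a sketch (read_source is a hypothetical helper):

def read_source(source):
    # Attempt the read directly; translate a missing file (or a None path,
    # which _path may return) into the same TypeError the original raises.
    try:
        with open(source) as f:
            return f.read()
    except (IOError, OSError, TypeError):
        raise TypeError('no source')
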
Example 16

Project: rope
Source File: runmod.py
View license
def __rope_start_everything():
    import os
    import sys
    import socket
    try:
        import pickle
    except ImportError:
        import cPickle as pickle
    import marshal
    import inspect
    import types
    import threading
    import rope.base.utils.pycompat as pycompat

    class _MessageSender(object):

        def send_data(self, data):
            pass

    class _SocketSender(_MessageSender):

        def __init__(self, port):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect(('127.0.0.1', port))
            self.my_file = s.makefile('wb')

        def send_data(self, data):
            if not self.my_file.closed:
                pickle.dump(data, self.my_file)

        def close(self):
            self.my_file.close()

    class _FileSender(_MessageSender):

        def __init__(self, file_name):
            self.my_file = open(file_name, 'wb')

        def send_data(self, data):
            if not self.my_file.closed:
                marshal.dump(data, self.my_file)

        def close(self):
            self.my_file.close()

    def _cached(func):
        cache = {}

        def newfunc(self, arg):
            if arg in cache:
                return cache[arg]
            result = func(self, arg)
            cache[arg] = result
            return result
        return newfunc

    class _FunctionCallDataSender(object):

        def __init__(self, send_info, project_root):
            self.project_root = project_root
            if send_info.isdigit():
                self.sender = _SocketSender(int(send_info))
            else:
                self.sender = _FileSender(send_info)

            def global_trace(frame, event, arg):
                # HACK: Ignoring out->in calls
                # This might lose some information
                if self._is_an_interesting_call(frame):
                    return self.on_function_call
            sys.settrace(global_trace)
            threading.settrace(global_trace)

        def on_function_call(self, frame, event, arg):
            if event != 'return':
                return
            args = []
            returned = ('unknown',)
            code = frame.f_code
            for argname in code.co_varnames[:code.co_argcount]:
                try:
                    argvalue = self._object_to_persisted_form(
                        frame.f_locals[argname])
                    args.append(argvalue)
                except (TypeError, AttributeError):
                    args.append(('unknown',))
            try:
                returned = self._object_to_persisted_form(arg)
            except (TypeError, AttributeError):
                pass
            try:
                data = (self._object_to_persisted_form(frame.f_code),
                        tuple(args), returned)
                self.sender.send_data(data)
            except (TypeError):
                pass
            return self.on_function_call

        def _is_an_interesting_call(self, frame):
            #if frame.f_code.co_name in ['?', '<module>']:
            #    return False
            #return not frame.f_back or
            #    not self._is_code_inside_project(frame.f_back.f_code)
            if not self._is_code_inside_project(frame.f_code) and \
               (not frame.f_back or
                    not self._is_code_inside_project(frame.f_back.f_code)):
                return False
            return True

        def _is_code_inside_project(self, code):
            source = self._path(code.co_filename)
            return source is not None and os.path.exists(source) and \
                _realpath(source).startswith(self.project_root)

        @_cached
        def _get_persisted_code(self, object_):
            source = self._path(object_.co_filename)
            if not os.path.exists(source):
                raise TypeError('no source')
            return ('defined', _realpath(source), str(object_.co_firstlineno))

        @_cached
        def _get_persisted_class(self, object_):
            try:
                return ('defined', _realpath(inspect.getsourcefile(object_)),
                        object_.__name__)
            except (TypeError, AttributeError):
                return ('unknown',)

        def _get_persisted_builtin(self, object_):
            if isinstance(object_, pycompat.string_types):
                return ('builtin', 'str')
            if isinstance(object_, list):
                holding = None
                if len(object_) > 0:
                    holding = object_[0]
                return ('builtin', 'list',
                        self._object_to_persisted_form(holding))
            if isinstance(object_, dict):
                keys = None
                values = None
                if len(object_) > 0:
                    # @todo - fix it properly, why is __locals__ being
                    # duplicated ?
                    keys = [key for key in object_.keys() if key != '__locals__'][0]
                    values = object_[keys]
                return ('builtin', 'dict',
                        self._object_to_persisted_form(keys),
                        self._object_to_persisted_form(values))
            if isinstance(object_, tuple):
                objects = []
                if len(object_) < 3:
                    for holding in object_:
                        objects.append(self._object_to_persisted_form(holding))
                else:
                    objects.append(self._object_to_persisted_form(object_[0]))
                return tuple(['builtin', 'tuple'] + objects)
            if isinstance(object_, set):
                holding = None
                if len(object_) > 0:
                    for o in object_:
                        holding = o
                        break
                return ('builtin', 'set',
                        self._object_to_persisted_form(holding))
            return ('unknown',)

        def _object_to_persisted_form(self, object_):
            if object_ is None:
                return ('none',)
            if isinstance(object_, types.CodeType):
                return self._get_persisted_code(object_)
            if isinstance(object_, types.FunctionType):
                return self._get_persisted_code(object_.__code__)
            if isinstance(object_, types.MethodType):
                return self._get_persisted_code(object_.__func__.__code__)
            if isinstance(object_, types.ModuleType):
                return self._get_persisted_module(object_)
            if isinstance(object_, pycompat.string_types + (list, dict, tuple, set)):
                return self._get_persisted_builtin(object_)
            if isinstance(object_, type):
                return self._get_persisted_class(object_)
            return ('instance', self._get_persisted_class(type(object_)))

        @_cached
        def _get_persisted_module(self, object_):
            path = self._path(object_.__file__)
            if path and os.path.exists(path):
                return ('defined', _realpath(path))
            return ('unknown',)

        def _path(self, path):
            if path.endswith('.pyc'):
                path = path[:-1]
            if path.endswith('.py'):
                return path

        def close(self):
            self.sender.close()
            sys.settrace(None)

    def _realpath(path):
        return os.path.realpath(os.path.abspath(os.path.expanduser(path)))

    send_info = sys.argv[1]
    project_root = sys.argv[2]
    file_to_run = sys.argv[3]
    run_globals = globals()
    run_globals.update({'__name__': '__main__',
                        '__builtins__': __builtins__,
                        '__file__': file_to_run})

    if send_info != '-':
        data_sender = _FunctionCallDataSender(send_info, project_root)
    del sys.argv[1:4]
    pycompat.execfile(file_to_run, run_globals)
    if send_info != '-':
        data_sender.close()
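
The tracer above decides whether a frame belongs to the project by resolving the code object's co_filename back to a .py source, checking it with os.path.exists, and comparing its real path against the project root. Below is a minimal standalone sketch of that pattern; source_for and the sample function are illustrative names, not part of the original module:

import os

def _realpath(path):
    return os.path.realpath(os.path.abspath(os.path.expanduser(path)))

def source_for(code, project_root):
    # Map a code object's filename back to its .py source, if any.
    path = code.co_filename
    if path.endswith('.pyc'):
        path = path[:-1]
    if not path.endswith('.py'):
        return None
    # Accept only sources that exist and live under the project root.
    if os.path.exists(path) and _realpath(path).startswith(project_root):
        return path
    return None

def sample():
    pass

print(source_for(sample.__code__, _realpath(os.path.dirname(__file__) or '.')))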

Example 17

Project: tp-libvirt
Source File: virsh_blockcopy.py
View license
def run(test, params, env):
    """
    Test command: virsh blockcopy.

    This command can copy a disk backing image chain to dest.
    1. Positive testing
        1.1 Copy a disk to a new image file.
        1.2 Reuse existing destination copy.
        1.3 Valid blockcopy timeout and bandwidth test.
    2. Negative testing
        2.1 Copy a disk to a non-exist directory.
        2.2 Copy a disk with invalid options.
        2.3 Do block copy for a persistent domain.
    """

    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)
    target = params.get("target_disk", "")
    replace_vm_disk = "yes" == params.get("replace_vm_disk", "no")
    disk_source_protocol = params.get("disk_source_protocol")
    disk_type = params.get("disk_type")
    pool_name = params.get("pool_name")
    image_size = params.get("image_size")
    emu_image = params.get("emulated_image")
    copy_to_nfs = "yes" == params.get("copy_to_nfs", "no")
    mnt_path_name = params.get("mnt_path_name")
    options = params.get("blockcopy_options", "")
    bandwidth = params.get("blockcopy_bandwidth", "")
    bandwidth_byte = "yes" == params.get("bandwidth_byte", "no")
    reuse_external = "yes" == params.get("reuse_external", "no")
    persistent_vm = params.get("persistent_vm", "no")
    status_error = "yes" == params.get("status_error", "no")
    active_error = "yes" == params.get("active_error", "no")
    active_snap = "yes" == params.get("active_snap", "no")
    active_save = "yes" == params.get("active_save", "no")
    check_state_lock = "yes" == params.get("check_state_lock", "no")
    with_shallow = "yes" == params.get("with_shallow", "no")
    with_blockdev = "yes" == params.get("with_blockdev", "no")
    setup_libvirt_polkit = "yes" == params.get('setup_libvirt_polkit')
    bug_url = params.get("bug_url", "")
    timeout = int(params.get("timeout", 1200))
    rerun_flag = 0
    blkdev_n = None
    back_n = 'blockdev-backing-iscsi'
    snapshot_external_disks = []
    # Skip/Fail early
    if with_blockdev and not libvirt_version.version_compare(1, 2, 13):
        raise exceptions.TestSkipError("--blockdev option not supported in "
                                       "current version")
    if not target:
        raise exceptions.TestSkipError("Require target disk to copy")
    if setup_libvirt_polkit and not libvirt_version.version_compare(1, 1, 1):
        raise exceptions.TestSkipError("API acl test not supported in current"
                                       " libvirt version")
    if copy_to_nfs and not libvirt_version.version_compare(1, 1, 1):
        raise exceptions.TestSkipError("Bug will not fix: %s" % bug_url)
    if bandwidth_byte and not libvirt_version.version_compare(1, 3, 3):
        raise exceptions.TestSkipError("--bytes option not supported in "
                                       "current version")

    # Check the source disk
    if vm_xml.VMXML.check_disk_exist(vm_name, target):
        logging.debug("Find %s in domain %s", target, vm_name)
    else:
        raise exceptions.TestFail("Can't find %s in domain %s" % (target,
                                                                  vm_name))

    original_xml = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)
    tmp_dir = data_dir.get_tmp_dir()

    # Prepare dest path params
    dest_path = params.get("dest_path", "")
    dest_format = params.get("dest_format", "")
    # Ugh... this piece of chicanery is brought to you by QemuImg, which
    # will "add" the 'dest_format' extension during the check_format code.
    # So if we create the file with the extension and then remove it when
    # doing check_format later, we avoid erroneous failures.
    dest_extension = ""
    if dest_format != "":
        dest_extension = ".%s" % dest_format

    # Prepare for --reuse-external option
    if reuse_external:
        options += "--reuse-external --wait"
        # Set rerun_flag=1 to run blockcopy twice; the file created by the
        # first run can be reused by the second run if no dest_path is given.
        # This makes sure the image size equals the original disk size.
        if dest_path == "/path/non-exist":
            if os.path.exists(dest_path) and not os.path.isdir(dest_path):
                os.remove(dest_path)
        else:
            rerun_flag = 1

    # Prepare other options
    if dest_format == "raw":
        options += " --raw"
    if with_blockdev:
        options += " --blockdev"
    if len(bandwidth):
        options += " --bandwidth %s" % bandwidth
    if bandwidth_byte:
        options += " --bytes"
    if with_shallow:
        options += " --shallow"

    # Prepare acl options
    uri = params.get("virsh_uri")
    unprivileged_user = params.get('unprivileged_user')
    if unprivileged_user:
        if unprivileged_user.count('EXAMPLE'):
            unprivileged_user = 'testacl'

    extra_dict = {'uri': uri, 'unprivileged_user': unprivileged_user,
                  'debug': True, 'ignore_status': True, 'timeout': timeout}

    libvirtd_utl = utils_libvirtd.Libvirtd()
    libvirtd_conf = utils_config.LibvirtdConfig()
    libvirtd_conf["log_filters"] = '"3:json 1:libvirt 1:qemu"'
    libvirtd_log_path = os.path.join(test.tmpdir, "libvirtd.log")
    libvirtd_conf["log_outputs"] = '"1:file:%s"' % libvirtd_log_path
    logging.debug("the libvirtd config file content is:\n %s" %
                  libvirtd_conf)
    libvirtd_utl.restart()

    def check_format(dest_path, dest_extension, expect):
        """
        Check the image format

        :param dest_path: Path of the copy to create
        :param dest_extension: Extension QemuImg adds, to be stripped off
        :param expect: Expected image format
        """
        # QemuImg will add the extension for us, so we have to remove it
        # here. Note that str.strip() removes a set of characters, not a
        # suffix, so slice the extension off instead.
        path_noext = dest_path[:-len(dest_extension)] if dest_extension else dest_path
        params['image_name'] = path_noext
        params['image_format'] = expect
        image = qemu_storage.QemuImg(params, "/", path_noext)
        if image.get_format() == expect:
            logging.debug("%s format is %s", dest_path, expect)
        else:
            raise exceptions.TestFail("%s format is not %s" % (dest_path,
                                                               expect))

    def _blockjob_and_libvirtd_chk(cmd_result):
        """
        Raise TestFail when blockcopy fails with a block-job-complete error
        or hangs with a state change lock.
        This verifies a specific bug, so status_error is ignored here.
        """
        bug_url_ = "https://bugzilla.redhat.com/show_bug.cgi?id=1197592"
        err_msg = "internal error: unable to execute QEMU command"
        err_msg += " 'block-job-complete'"
        if err_msg in cmd_result.stderr:
            raise exceptions.TestFail("Hit on bug: %s" % bug_url_)

        err_pattern = "Timed out during operation: cannot acquire"
        err_pattern += " state change lock"
        ret = chk_libvirtd_log(libvirtd_log_path, err_pattern, "error")
        if ret:
            raise exceptions.TestFail("Hit on bug: %s" % bug_url_)

    def _make_snapshot():
        """
        Make external disk snapshot
        """
        snap_xml = snapshot_xml.SnapshotXML()
        snapshot_name = "blockcopy_snap"
        snap_xml.snap_name = snapshot_name
        snap_xml.description = "blockcopy snapshot"

        # Add all disks into xml file.
        vmxml = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)
        disks = vmxml.devices.by_device_tag('disk')
        new_disks = []
        src_disk_xml = disks[0]
        disk_xml = snap_xml.SnapDiskXML()
        disk_xml.xmltreefile = src_disk_xml.xmltreefile
        del disk_xml.device
        del disk_xml.address
        disk_xml.snapshot = "external"
        disk_xml.disk_name = disk_xml.target['dev']

        # Only qcow2 works as external snapshot file format, update it
        # here
        driver_attr = disk_xml.driver
        driver_attr.update({'type': 'qcow2'})
        disk_xml.driver = driver_attr

        hosts = None
        new_attrs = disk_xml.source.attrs
        if 'file' in new_attrs:
            new_file = os.path.join(tmp_dir, "blockcopy_shallow.snap")
            snapshot_external_disks.append(new_file)
            new_attrs.update({'file': new_file})
        elif ('dev' in new_attrs or
              'name' in new_attrs or
              'pool' in new_attrs):
            if (disk_xml.type_name == 'block' or
                    disk_source_protocol == 'iscsi'):
                disk_xml.type_name = 'block'
                if 'name' in new_attrs:
                    del new_attrs['name']
                    del new_attrs['protocol']
                elif 'pool' in new_attrs:
                    del new_attrs['pool']
                    del new_attrs['volume']
                    del new_attrs['mode']
                back_path = utl.setup_or_cleanup_iscsi(is_setup=True,
                                                       is_login=True,
                                                       image_size="1G",
                                                       emulated_image=back_n)
                emulated_iscsi.append(back_n)
                cmd = "qemu-img create -f qcow2 %s 1G" % back_path
                process.run(cmd, shell=True)
                new_attrs.update({'dev': back_path})
                hosts = None

        new_src_dict = {"attrs": new_attrs}
        if hosts:
            new_src_dict.update({"hosts": hosts})
        disk_xml.source = disk_xml.new_disk_source(**new_src_dict)

        new_disks.append(disk_xml)

        snap_xml.set_disks(new_disks)
        snapshot_xml_path = snap_xml.xml
        logging.debug("The snapshot xml is: %s" % snap_xml.xmltreefile)

        options = "--disk-only --xmlfile %s " % snapshot_xml_path

        snapshot_result = virsh.snapshot_create(
            vm_name, options, debug=True)

        if snapshot_result.exit_status != 0:
            raise exceptions.TestFail(snapshot_result.stderr)

    snap_path = ''
    save_path = ''
    emulated_iscsi = []
    nfs_cleanup = False
    try:
        # Prepare dest_path
        tmp_file = time.strftime("%Y-%m-%d-%H.%M.%S.img")
        tmp_file += dest_extension
        if not dest_path:
            if with_blockdev:
                blkdev_n = 'blockdev-iscsi'
                dest_path = utl.setup_or_cleanup_iscsi(is_setup=True,
                                                       is_login=True,
                                                       image_size=image_size,
                                                       emulated_image=blkdev_n)
                emulated_iscsi.append(blkdev_n)
                # Make sure the new disk shows up
                utils_misc.wait_for(lambda: os.path.exists(dest_path), 5)
            else:
                if copy_to_nfs:
                    tmp_dir = "%s/%s" % (tmp_dir, mnt_path_name)
                dest_path = os.path.join(tmp_dir, tmp_file)

        # Domain disk replacement with desired type
        if replace_vm_disk:
            # Calling 'set_vm_disk' is a bad idea as it leaves lots of
            # cleanup jobs after the test, such as pool, volume, nfs,
            # iscsi and so on
            # TODO: remove this function in the future
            if disk_source_protocol == 'iscsi':
                emulated_iscsi.append(emu_image)
            if disk_source_protocol == 'netfs':
                nfs_cleanup = True
            utl.set_vm_disk(vm, params, tmp_dir, test)
            new_xml = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)

        if with_shallow:
            _make_snapshot()

        # Prepare transient/persistent vm
        if persistent_vm == "no" and vm.is_persistent():
            vm.undefine()
        elif persistent_vm == "yes" and not vm.is_persistent():
            new_xml.define()

        # Run blockcopy command to create destination file
        if rerun_flag == 1:
            options1 = "--wait %s --finish --verbose" % dest_format
            if with_blockdev:
                options1 += " --blockdev"
            if with_shallow:
                options1 += " --shallow"
            cmd_result = virsh.blockcopy(vm_name, target,
                                         dest_path, options1,
                                         **extra_dict)
            status = cmd_result.exit_status
            if status != 0:
                raise exceptions.TestFail("Run blockcopy command fail: %s" %
                                          cmd_result.stdout + cmd_result.stderr)
            elif not os.path.exists(dest_path):
                raise exceptions.TestFail("Cannot find the created copy")

        # Run the real testing command
        cmd_result = virsh.blockcopy(vm_name, target, dest_path,
                                     options, **extra_dict)

        # check BZ#1197592
        _blockjob_and_libvirtd_chk(cmd_result)
        status = cmd_result.exit_status

        if not libvirtd_utl.is_running():
            raise exceptions.TestFail("Libvirtd service is dead")

        if not status_error:
            if status == 0:
                ret = utils_misc.wait_for(
                    lambda: check_xml(vm_name, target, dest_path, options), 5)
                if not ret:
                    raise exceptions.TestFail("Domain xml not expected after"
                                              " blockcopy")
                if options.count("--bandwidth"):
                    if options.count('--bytes'):
                        bandwidth += 'B'
                    else:
                        bandwidth += 'M'
                    if not utl.check_blockjob(vm_name, target, "bandwidth",
                                              bandwidth):
                        raise exceptions.TestFail("Check bandwidth failed")
                val = options.count("--pivot") + options.count("--finish")
                # Don't wait for job finish when using --byte option
                val += options.count('--bytes')
                if val == 0:
                    try:
                        finish_job(vm_name, target, timeout)
                    except JobTimeout as excpt:
                        raise exceptions.TestFail("Run command failed: %s" %
                                                  excpt)
                if options.count("--raw") and not with_blockdev:
                    check_format(dest_path, dest_extension, dest_format)
                if active_snap:
                    snap_path = "%s/%s.snap" % (tmp_dir, vm_name)
                    snap_opt = "--disk-only --atomic --no-metadata "
                    snap_opt += "vda,snapshot=external,file=%s" % snap_path
                    ret = virsh.snapshot_create_as(vm_name, snap_opt,
                                                   ignore_status=True,
                                                   debug=True)
                    utl.check_exit_status(ret, active_error)
                if active_save:
                    save_path = "%s/%s.save" % (tmp_dir, vm_name)
                    ret = virsh.save(vm_name, save_path,
                                     ignore_status=True,
                                     debug=True)
                    utl.check_exit_status(ret, active_error)
                if check_state_lock:
                    # Run blockjob --pivot in a subprocess as it will hang
                    # for a while; then run blockjob --info again to check
                    # the job state
                    command = "virsh blockjob %s %s --pivot" % (vm_name,
                                                                target)
                    session = aexpect.ShellSession(command)
                    ret = virsh.blockjob(vm_name, target, "--info")
                    err_info = "cannot acquire state change lock"
                    if err_info in ret.stderr:
                        raise exceptions.TestFail("Hit on bug: %s" % bug_url)
                    utl.check_exit_status(ret, status_error)
                    session.close()
            else:
                raise exceptions.TestFail(cmd_result.stdout + cmd_result.stderr)
        else:
            if status:
                logging.debug("Expect error: %s", cmd_result.stderr)
            else:
                # Commit id '4c297728' changed how virsh exits when
                # unexpectedly failing due to timeout from a fail (1)
                # to a success (0), so we need to look for a different
                # marker to indicate the copy aborted. As "Now in
                # mirroring phase" could appear in stdout and fail the
                # check, also check the libvirtd log to confirm.
                if options.count("--timeout") and options.count("--wait"):
                    log_pattern = "Copy aborted"
                    if (re.search(log_pattern, cmd_result.stdout) or
                            chk_libvirtd_log(libvirtd_log_path,
                                             log_pattern, "debug")):
                        logging.debug("Found success a timed out block copy")
                else:
                    raise exceptions.TestFail("Expect fail, but run "
                                              "successfully: %s" % bug_url)
    finally:
        # Recovering the VM may fail unexpectedly, so use try/except to
        # proceed with the following cleanup steps
        try:
            # Abort any existing blockjob to avoid possible lock errors
            virsh.blockjob(vm_name, target, '--abort', ignore_status=True)
            vm.destroy(gracefully=False)
            # It may take a long time to shutdown the VM which has
            # blockjob running
            utils_misc.wait_for(
                lambda: virsh.domstate(vm_name,
                                       ignore_status=True).exit_status, 180)
            if virsh.domain_exists(vm_name):
                if active_snap or with_shallow:
                    option = "--snapshots-metadata"
                else:
                    option = None
                original_xml.sync(option)
            else:
                original_xml.define()
        except Exception as e:
            logging.error(e)
        for disk in snapshot_external_disks:
            if os.path.exists(disk):
                os.remove(disk)
        # Clean up libvirt pool, which may be created by 'set_vm_disk'
        if disk_type == 'volume':
            virsh.pool_destroy(pool_name, ignore_status=True, debug=True)
        # Restore libvirtd conf and restart libvirtd
        libvirtd_conf.restore()
        libvirtd_utl.restart()
        if libvirtd_log_path and os.path.exists(libvirtd_log_path):
            os.unlink(libvirtd_log_path)
        # Clean up NFS
        try:
            if nfs_cleanup:
                utl.setup_or_cleanup_nfs(is_setup=False)
        except Exception as e:
            logging.error(e)
        # Clean up iSCSI
        try:
            for iscsi_n in list(set(emulated_iscsi)):
                utl.setup_or_cleanup_iscsi(is_setup=False, emulated_image=iscsi_n)
                # iscsid will be restarted, so give it a break before next loop
                time.sleep(5)
        except Exception as e:
            logging.error(e)
        if os.path.exists(dest_path):
            os.remove(dest_path)
        if os.path.exists(snap_path):
            os.remove(snap_path)
        if os.path.exists(save_path):
            os.remove(save_path)
        # Restart virtlogd service to release VM log file lock
        try:
            path.find_command('virtlogd')
            process.run('systemctl reset-failed virtlogd')
            process.run('systemctl restart virtlogd ')
        except path.CmdNotFoundError:
            pass
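
Note how every cleanup in the finally block above is guarded by os.path.exists, so a test that failed before creating an artifact does not fail again while removing it, and how utils_misc.wait_for polls for dest_path to appear after the iSCSI setup. A minimal standalone sketch of both idioms; the paths, timeout, and helper names here are made up for illustration:

import os
import time

def wait_for_path(path, timeout, step=0.5):
    # Poll until the path shows up or the timeout expires.
    deadline = time.time() + timeout
    while time.time() < deadline:
        if os.path.exists(path):
            return True
        time.sleep(step)
    return False

def cleanup(paths):
    # Remove only what actually exists; missing files are not an error.
    for p in paths:
        if os.path.exists(p):
            os.remove(p)

cleanup(['/tmp/blockcopy-demo.img', '/tmp/blockcopy-demo.snap'])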

Example 18

Project: tractor
Source File: wise2.py
View license
def main(opt, ps):
	#ralo = 36
	#rahi = 42
	#declo = -1.25
	#dechi = 1.25
	#width = 7
	  
	ralo = 37.5
	rahi = 41.5
	declo = -1.5
	dechi = 2.5
	width = 2.5

	rl,rh = 39,40
	dl,dh = 0,1
	roipoly = np.array([(rl,dl),(rl,dh),(rh,dh),(rh,dl)])

	ra  = (ralo  + rahi ) / 2.
	dec = (declo + dechi) / 2.

	bandnum = 1
	band = 'w%i' % bandnum
	plt.figure(figsize=(12,12))

	#basedir = '/project/projectdirs/bigboss'
	#wisedatadir = os.path.join(basedir, 'data', 'wise')

	wisedatadirs = ['/clusterfs/riemann/raid007/bosswork/boss/wise_level1b',
					'/clusterfs/riemann/raid000/bosswork/boss/wise1ext']

	wisecatdir = '/home/boss/products/NULL/wise/trunk/fits/'

	ofn = 'wise-images-overlapping.fits'

	if os.path.exists(ofn):
		print('File exists:', ofn)
		T = fits_table(ofn)
		print('Found', len(T), 'images overlapping')

		print('Reading WCS headers...')
		wcses = []
		T.filename = [fn.strip() for fn in T.filename]
		for fn in T.filename:
			wcs = anwcs(fn, 0)
			wcses.append(wcs)

	else:
		TT = []
		for d in wisedatadirs:
			ifn = os.path.join(d, 'WISE-index-L1b.fits') #'index-allsky-astr-L1b.fits')
			T = fits_table(ifn, columns=['ra','dec','scan_id','frame_num'])
			print('Read', len(T), 'from WISE index', ifn)
			I = np.flatnonzero((T.ra > ralo) * (T.ra < rahi) * (T.dec > declo) * (T.dec < dechi))
			print(len(I), 'overlap RA,Dec box')
			T.cut(I)

			fns = []
			for sid,fnum in zip(T.scan_id, T.frame_num):
				print('scan,frame', sid, fnum)
				fn = get_l1b_file(d, sid, fnum, bandnum)
				print('-->', fn)
				assert(os.path.exists(fn))
				fns.append(fn)
			T.filename = np.array(fns)
			TT.append(T)
		T = merge_tables(TT)

		wcses = []
		corners = []
		ii = []
		for i in range(len(T)):
			wcs = anwcs(T.filename[i], 0)
			W,H = wcs.get_width(), wcs.get_height()
			rd = []
			for x,y in [(1,1),(1,H),(W,H),(W,1)]:
				rd.append(wcs.pixelxy2radec(x,y))
			rd = np.array(rd)
			if polygons_intersect(roipoly, rd):
				wcses.append(wcs)
				corners.append(rd)
				ii.append(i)

		print('Found', len(wcses), 'overlapping')
		I = np.array(ii)
		T.cut(I)

		outlines = corners
		corners = np.vstack(corners)

		nin = sum([1 if point_in_poly(ra,dec,ol) else 0 for ol in outlines])
		print('Number of images containing RA,Dec,', ra,dec, 'is', nin)

		r0,r1 = corners[:,0].min(), corners[:,0].max()
		d0,d1 = corners[:,1].min(), corners[:,1].max()
		print('RA,Dec extent', r0,r1, d0,d1)

		T.writeto(ofn)
		print('Wrote', ofn)


	# MAGIC 2.75: approximate pixel scale, "/pix
	S = int(3600. / 2.75)
	print('Coadd size', S)
	cowcs = anwcs_create_box(ra, dec, 1., S, S)

	if False:
		print('Plotting map...')
		plot = Plotstuff(outformat='png', ra=ra, dec=dec, width=width, size=(800,800))
		out = plot.outline
		plot.color = 'white'
		plot.alpha = 0.07
		plot.apply_settings()

		for wcs in wcses:
			out.wcs = wcs
			out.fill = False
			plot.plot('outline')
			out.fill = True
			plot.plot('outline')

		plot.color = 'gray'
		plot.alpha = 1.0
		plot.lw = 1
		plot.plot_grid(1, 1, 1, 1)

		plot.color = 'red'
		plot.lw = 3
		plot.alpha = 0.75
		out.wcs = cowcs
		out.fill = False
		plot.plot('outline')

		if opt.sources:
			rd = plot.radec
			plot_radec_set_filename(rd, opt.sources)
			plot.plot('radec')

		pfn = ps.getnext()
		plot.write(pfn)
		print('Wrote', pfn)


	# Re-sort by distance to RA,Dec center...
	#I = np.argsort(np.hypot(T.ra - ra, T.dec - dec))
	#T.cut(I)
	# IF YOU DO THIS, MUST ALSO RE-SORT 'wcses'!

	
	if opt.sources:

		# Look at a radius this big, in arcsec, around each source position.
		# 15" = about 6 WISE pixels
		Wrad = 15. / 3600.

		# Look for SDSS objects within this radius; Wrad + a margin
		Srad = Wrad + 5./3600.


		S = fits_table(opt.sources)
		print('Read', len(S), 'sources from', opt.sources)

		groups,singles = cluster_radec(S.ra, S.dec, Wrad, singles=True)
		print('Source clusters:', groups)
		print('Singletons:', singles)

		tractors = []

		sdss = DR9(basedir='data-dr9')
		sband = 'r'

		for i in singles:
			r,d = S.ra[i],S.dec[i]
			print('Source', i, 'at', r,d)
			fn = sdss.retrieve('photoObj', S.run[i], S.camcol[i], S.field[i], band=sband)
			print('Reading', fn)
			oo = fits_table(fn)
			print('Got', len(oo))
			cat1,obj1,I = get_tractor_sources_dr9(None, None, None, bandname=sband,
												  objs=oo, radecrad=(r,d,Srad), bands=[],
												  nanomaggies=True, extrabands=[band],
												  fixedComposites=True,
												  getobjs=True, getobjinds=True)
			print('Got', len(cat1), 'SDSS sources nearby')

			# Find images that overlap?

			ims = []
			for j,wcs in enumerate(wcses):

				print('Filename', T.filename[j])
				ok,x,y = wcs.radec2pixelxy(r,d)
				print('WCS', j, '-> x,y:', x,y)

				if not anwcs_radec_is_inside_image(wcs, r, d):
					continue

				tim = wise.read_wise_level1b(
					T.filename[j].replace('-int-1b.fits',''),
					nanomaggies=True, mask_gz=True, unc_gz=True,
					sipwcs=True, constantInvvar=True, radecrad=(r,d,Wrad))
				ims.append(tim)
			print('Found', len(ims), 'images containing this source')

			tr = Tractor(ims, cat1)
			tractors.append(tr)
			

		if len(groups):
			# TODO!
			assert(False)

		sys.exit(0)



		# Find additional SDSS sources nearby = within R pixels radius.
		R = 30.
		#R = 50.
		rad = R * 0.396 / 3600.

		cats = []
		objs = []
		for run,camcol,field,r,d in zip(S.run, S.camcol, S.field, S.ra, S.dec):
			fn = sdss.retrieve('photoObj', run, camcol, field, band=sband)
			print('Reading', fn)
			oo = fits_table(fn)
			print('Got', len(oo))
			cat1,obj1,I = get_tractor_sources_dr9(None, None, None, bandname=sband,
												  objs=oo, radecrad=(r,d,rad), bands=[],
												  nanomaggies=True, extrabands=[band],
												  fixedComposites=True,
												  getobjs=True, getobjinds=True)
			print('Got', len(cat1), 'SDSS sources nearby')
			cats.append(cat1)
			objs.append(obj1[I])

		# Merge into one big catalog.
		cat = Catalog()
		for c in cats:
			for src in c:
				cat.append(src)
		S = merge_tables(objs)

		print('Merged catalog has', len(cat), 'entries')
		print('S table has', len(S))
		assert(len(S) == len(cat))

		if opt.ptsrc:
			print('Converting all sources to PointSources')
			pcat = Catalog()
			for src in cat:
				ps = PointSource(src.getPosition(), src.getBrightness())
				pcat.append(ps)
			print('PointSource catalog:', pcat)
			cat = pcat

		# ??
		WW = S
		#WW = tabledata()

		# cat = get_tractor_sources_dr9(None, None, None, bandname=sband,
		# 							  objs=S, bands=[], nanomaggies=True,
		# 							  extrabands=[band])

		print('Got', len(cat), 'tractor sources')
		#cat = Catalog(*cat)
		print(cat)
		for src in cat:
			print('  ', src)

		### FIXME -- match to WISE catalog to initialize mags?

		# Initialize WISE mags to be at least detectable
		# so that we identify the right pixel ROIs below.

		#minbright = NanoMaggies.magToNanomaggies()
		#minbright = 50.
		minbright = 250.

		cat.freezeParamsRecursive('*')
		cat.thawPathsTo(band)
		p0 = cat.getParams()
		cat.setParams(np.maximum(minbright, p0))

		print('Set minimum W1 brightness:')
		for src in cat:
			print('  ', src)

		# Cut images that don't overlap.
		ii = []
		for i,wcs in enumerate(wcses):
			isin = False
			for r,d in zip(S.ra, S.dec):
				if anwcs_radec_is_inside_image(wcs, r, d):
					isin = True
					break
			if isin:
				ii.append(i)
		T.cut(np.array(ii))
		print('Cut to', len(T), 'images containing sources')


		
	else:
		wfn = 'wise-sources-nearby.fits'
		if os.path.exists(wfn):
			print('Reading existing file', wfn)
			W = fits_table(wfn)
			print('Got', len(W), 'with range RA', W.ra.min(), W.ra.max(), ', Dec', W.dec.min(), W.dec.max())
		else:
			# Range of WISE slices (inclusive) containing this Dec range.
			ws0, ws1 = 26,27
			WW = []
			for w in range(ws0, ws1+1):
				fn = os.path.join(wisecatdir, 'wise-allsky-cat-part%02i-radec.fits' % w)
				print('Searching for sources in', fn)
				W = fits_table(fn)
				I = np.flatnonzero((W.ra >= r0) * (W.ra <= r1) * (W.dec >= d0) * (W.dec <= d1))
				fn = os.path.join(wisecatdir, 'wise-allsky-cat-part%02i.fits' % w)
				print('Reading', len(I), 'rows from', fn)
				W = fits_table(fn, rows=I)
				print('Cut to', len(W), 'sources in range')
				WW.append(W)
			W = merge_tables(WW)
			del WW
			print('Total of', len(W))
			W.writeto(wfn)
			print('Wrote', wfn)
	
		# DEBUG
		W.cut((W.ra >= rl) * (W.ra <= rh) * (W.dec >= dl) * (W.dec <= dh))
		print('Cut to', len(W), 'in the central region')
	
		print('Creating', len(W), 'Tractor sources')
		cat = Catalog()
		for i in range(len(W)):
			w1 = W.w1mpro[i]
			nm = NanoMaggies.magToNanomaggies(w1)
			cat.append(PointSource(RaDecPos(W.ra[i], W.dec[i]), NanoMaggies(w1=nm)))

		WW = W

	cat.freezeParamsRecursive('*')
	cat.thawPathsTo(band)

	cat0 = cat.getParams()
	br0 = [src.getBrightness().copy() for src in cat]
	nm0 = np.array([b.getBand(band) for b in br0])

	WW.nm0 = nm0

	w1psf = wise.get_psf_model(bandnum, opt.pixpsf)

	# Create fake image in the "coadd" footprint in order to find overlapping
	# sources.
	H,W = int(cowcs.imageh), int(cowcs.imagew)
	# MAGIC -- sigma a bit smaller than typical images (4.0-ish)
	sig = 3.5
	# typical zeropoint
	zp = 20.752
	
	faketim = Image(data=np.zeros((H,W), np.float32),
					invvar=np.zeros((H,W), np.float32) + (1./sig**2),
					psf=w1psf, wcs=ConstantFitsWcs(cowcs), sky=ConstantSky(0.),
					photocal = LinearPhotoCal(NanoMaggies.zeropointToScale(zp),
											  band=band),
					#photocal=LinearPhotoCal(1., band=band),
					name='fake')
	minsb = 0.1 * sig
	#minsb = 0.

	# pc = faketim.getPhotoCal()
	# print 'Source counts:'
	# for src in cat:
	# 	print '  ', src
	# 	print '-->', pc.brightnessToCounts(src.getBrightness())
	# 	print '  -->', [pc.brightnessToCounts(br) for br in src.getBrightnesses()]
	# print 'Source pixel positions:'
	# wcs = faketim.getWcs()
	# for src in cat:
	# 	print '  ', src
	# 	print '--> x,y', wcs.positionToPixel(src.getPosition())
	

	print('Finding overlapping sources...')
	t0 = Time()
	tractor = Tractor([faketim], cat)
	groups,L,fakemod = tractor.getOverlappingSources(0, minsb=minsb)
	print('Overlapping sources took', Time()-t0)
	print('Got', len(groups), 'groups of sources')
	nl = L.max()
	gslices = find_objects(L, nl)

	print('unique labels:', np.unique(L))

	# plt.clf()
	# plt.imshow(fakemod, interpolation='nearest', origin='lower',
	# 		   vmin=0, vmax=sig*3.)
	# plt.title('Fakemod')
	# ps.savefig()
	# 
	# for IM in [L, (L>0)]:
	# 	plt.clf()
	# 	plt.imshow(IM, interpolation='nearest', origin='lower')
	# 	plt.gray()
	# 	wcs = faketim.getWcs()
	# 	xy = []
	# 	for src in cat:
	# 		x,y = wcs.positionToPixel(src.getPosition())
	# 		xy.append((x,y))
	# 	xy = np.array(xy)
	# 	ax = plt.axis()
	# 	plt.plot(xy[:,0], xy[:,1], 'r+')
	# 	plt.title('Source groups')
	# 	ps.savefig()

	# Find sources touching each group's (rectangular) ROI
	tgroups = {}
	for i,gslice in enumerate(gslices):
		gl = i+1
		tg = np.unique(L[gslice])
		tsrcs = []
		for g in tg:
			if g not in (gl, 0):
				if g in groups:
					tsrcs.extend(groups[g])
		tgroups[gl] = tsrcs


	# for i,gslice in enumerate(gslices):
	# 	if not (i+1) in groups:
	# 		continue
	# 
	# 	plt.clf()
	# 	plt.imshow(IM[gslice], interpolation='nearest', origin='lower')
	# 	plt.gray()
	# 	wcs = faketim.getWcs()
	# 	xy = []
	# 	y0,x0 = gslice[0].start, gslice[1].start
	# 	for src in cat:
	# 		x,y = wcs.positionToPixel(src.getPosition())
	# 		xy.append((x-x0,y-y0))
	# 	xy = np.array(xy)
	# 
	# 	ax = plt.axis()
	# 
	# 	plt.plot(xy[:,0], xy[:,1], 'r+')
	# 
	# 	I = np.array(groups[i+1])
	# 	if len(I):
	# 		plt.plot(xy[I,0], xy[I,1], 'g.')
	# 
	# 	I = np.array(tgroups[i+1])
	# 	if len(I):
	# 		plt.plot(xy[I,0], xy[I,1], 'gx')
	# 
	# 	ps.savefig()
	# 
	# 	plt.axis(ax)
	# 	ps.savefig()



	print('Group size histogram:')
	ng = Counter()
	for g in groups.values():
		ng[len(g)] += 1
	kk = sorted(ng.keys())
	for k in kk:
		print('  ', k, 'sources:', ng[k], 'groups')

	nms = []
	tims = []
	allrois = {}
	badrois = {}

	if opt.threads:
		mp = multiproc(opt.threads)
	else:
		mp = multiproc(1)

	tims = mp.map(_read_l1b, T.filename)

	for imi,tim in enumerate(tims):
		tim.psf = w1psf
		H,W = tim.shape
		nin = 0
		for src in cat:
			x,y = tim.getWcs().positionToPixel(src.getPosition())
			if x >= 0 and y >= 0 and x < W and y < H:
				nin += 1
		print('Number of sources inside image:', nin)

		tractor = Tractor([tim], cat)
		tractor.freezeParam('images')
		### ??
		cat.setParams(cat0)

		pgroups = 0
		pobjs = 0

		for gi in range(len(gslices)):
			gl = gi
			# note, gslices is zero-indexed
			gslice = gslices[gl]
			gl += 1
			if gl not in groups:
				print('Group', gl, 'not in groups array; skipping')
				continue
			gsrcs = groups[gl]
			tsrcs = tgroups[gl]

			# print 'Group number', (gi+1), 'of', len(Gorder), ', id', gl, ': sources', gsrcs
			# print 'sources in groups touching slice:', tsrcs

			# Convert from 'canonical' ROI to this image.
			yl,yh = gslice[0].start, gslice[0].stop
			xl,xh = gslice[1].start, gslice[1].stop
			x0,y0 = W-1,H-1
			x1,y1 = 0,0
			for x,y in [(xl,yl),(xh-1,yl),(xh-1,yh-1),(xl,yh-1)]:
				r,d = cowcs.pixelxy2radec(x+1, y+1)
				x,y = tim.getWcs().positionToPixel(RaDecPos(r,d))
				x = int(np.round(x))
				y = int(np.round(y))

				x = np.clip(x, 0, W-1)
				y = np.clip(y, 0, H-1)
				x0 = min(x0, x)
				y0 = min(y0, y)
				x1 = max(x1, x)
				y1 = max(y1, y)
			if x1 == x0 or y1 == y0:
				print('Gslice', gslice, 'is completely outside this image')
				continue
			
			gslice = (slice(y0,y1+1), slice(x0, x1+1))

			if np.all(tim.getInvError()[gslice] == 0):
				print('This whole object group has invvar = 0.')

				if gl not in badrois:
					badrois[gl] = {}
				badrois[gl][imi] = gslice

				continue

			if gl not in allrois:
				allrois[gl] = {}
			allrois[gl][imi] = gslice

			if not opt.individual:
				continue

			fullcat = tractor.catalog
			subcat = Catalog(*[fullcat[i] for i in gsrcs + tsrcs])
			for i in range(len(tsrcs)):
				subcat.freezeParam(len(gsrcs) + i)
			tractor.catalog = subcat

			print(len(gsrcs), 'sources unfrozen; total', len(subcat))

			pgroups += 1
			pobjs += len(gsrcs)
			
			t0 = Time()
			tractor.optimize_forced_photometry(minsb=minsb, mindlnp=1.,
											   rois=[gslice])
			print('optimize_forced_photometry took', Time()-t0)

			tractor.catalog = fullcat

		# mod = tractor.getModelImage(0, minsb=minsb)
		# noise = np.random.normal(size=mod.shape)
		# noise[tim.getInvError() == 0] = 0.
		# nz = (tim.getInvError() > 0)
		# noise[nz] *= (1./tim.getInvError()[nz])
		# ima = dict(interpolation='nearest', origin='lower',
		# 		   vmin=tim.zr[0], vmax=tim.zr[1])
		# imchi = dict(interpolation='nearest', origin='lower',
		# 		   vmin=-5, vmax=5)
		# plt.clf()
		# plt.subplot(2,2,1)
		# plt.imshow(tim.getImage(), **ima)
		# plt.gray()
		# plt.subplot(2,2,2)
		# plt.imshow(mod, **ima)
		# plt.gray()
		# plt.subplot(2,2,3)
		# plt.imshow((tim.getImage() - mod) * tim.getInvError(), **imchi)
		# plt.gray()
		# plt.subplot(2,2,4)
		# plt.imshow(mod + noise, **ima)
		# plt.gray()
		# plt.suptitle('W1, scan %s, frame %i' % (sid, fnum))
		# ps.savefig()

		if opt.individual:
			print('Photometered', pgroups, 'groups containing', pobjs, 'objects')
	
			cat.thawPathsTo(band)
			nm1 = np.array([src.getBrightness().getBand(band) for src in cat])
			nms.append(nm1)
	
			WW.nms = np.array(nms).T
			fn = opt.output % imi
			WW.writeto(fn)
			print('Wrote', fn)

	return dict(cat0=cat0, WW=WW, band=band, tims=tims,
				allrois=allrois, badrois=badrois, groups=groups,
				tgroups=tgroups, minsb=minsb,
				gslices=gslices, cat=cat)
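
Both top-level branches of main() rely on the same cache-file idiom around os.path.exists: if the table was written by a previous run, read it back; otherwise compute it and save it for next time. A stripped-down sketch of the idiom, using pickle in place of fits_table; the filename and compute step are placeholders:

import os
import pickle

def cached(ofn, compute):
    # Reuse a previous result if the file exists; otherwise build and save it.
    if os.path.exists(ofn):
        print('File exists:', ofn)
        with open(ofn, 'rb') as f:
            return pickle.load(f)
    result = compute()
    with open(ofn, 'wb') as f:
        pickle.dump(result, f)
    print('Wrote', ofn)
    return result

table = cached('overlapping-demo.pickle', lambda: list(range(10)))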

Example 19

Project: plexpy
Source File: webstart.py
View license
def initialize(options):

    # HTTPS stuff stolen from sickbeard
    enable_https = options['enable_https']
    https_cert = options['https_cert']
    https_key = options['https_key']

    if enable_https:
        # If either the HTTPS certificate or key do not exist, try to make self-signed ones.
        if plexpy.CONFIG.HTTPS_CREATE_CERT and \
            (not (https_cert and os.path.exists(https_cert)) or not (https_key and os.path.exists(https_key))):
            if not create_https_certificates(https_cert, https_key):
                logger.warn(u"PlexPy WebStart :: Unable to create certificate and key. Disabling HTTPS")
                enable_https = False

        if not (os.path.exists(https_cert) and os.path.exists(https_key)):
            logger.warn(u"PlexPy WebStart :: Disabled HTTPS because of missing certificate and key.")
            enable_https = False

    options_dict = {
        'server.socket_port': options['http_port'],
        'server.socket_host': options['http_host'],
        'environment': options['http_environment'],
        'server.thread_pool': 10,
        'tools.encode.on': True,
        'tools.encode.encoding': 'utf-8',
        'tools.decode.on': True
    }

    if plexpy.DEV:
        options_dict['environment'] = "test_suite"
        options_dict['engine.autoreload.on'] = True

    if enable_https:
        options_dict['server.ssl_certificate'] = https_cert
        options_dict['server.ssl_private_key'] = https_key
        protocol = "https"
    else:
        protocol = "http"

    if options['http_password']:
        logger.info(u"PlexPy WebStart :: Web server authentication is enabled, username is '%s'", options['http_username'])
        if options['http_basic_auth']:
            auth_enabled = session_enabled = False
            basic_auth_enabled = True
        else:
            options_dict['tools.sessions.on'] = auth_enabled = session_enabled = True
            basic_auth_enabled = False
            cherrypy.tools.auth = cherrypy.Tool('before_handler', webauth.check_auth)
    else:
        auth_enabled = session_enabled = basic_auth_enabled = False

    if not options['http_root'] or options['http_root'] == '/':
        plexpy.HTTP_ROOT = options['http_root'] = '/'
    else:
        plexpy.HTTP_ROOT = options['http_root'] = '/' + options['http_root'].strip('/') + '/'

    cherrypy.config.update(options_dict)

    conf = {
        '/': {
            'tools.staticdir.root': os.path.join(plexpy.PROG_DIR, 'data'),
            'tools.proxy.on': options['http_proxy'],  # pay attention to X-Forwarded-Proto header
            'tools.gzip.on': True,
            'tools.gzip.mime_types': ['text/html', 'text/plain', 'text/css',
                                      'text/javascript', 'application/json',
                                      'application/javascript'],
            'tools.auth.on': auth_enabled,
            'tools.sessions.on': session_enabled,
            'tools.sessions.timeout': 30 * 24 * 60,  # 30 days
            'tools.auth_basic.on': basic_auth_enabled,
            'tools.auth_basic.realm': 'PlexPy web server',
            'tools.auth_basic.checkpassword': cherrypy.lib.auth_basic.checkpassword_dict({
                options['http_username']: options['http_password']})
        },
        '/api': {
            'tools.auth_basic.on': False
        },
        '/interfaces': {
            'tools.staticdir.on': True,
            'tools.staticdir.dir': "interfaces",
            'tools.caching.on': True,
            'tools.caching.force': True,
            'tools.caching.delay': 0,
            'tools.expires.on': True,
            'tools.expires.secs': 60 * 60 * 24 * 30,  # 30 days
            'tools.auth.on': False,
            'tools.sessions.on': False
        },
        '/images': {
            'tools.staticdir.on': True,
            'tools.staticdir.dir': "interfaces/default/images",
            'tools.caching.on': True,
            'tools.caching.force': True,
            'tools.caching.delay': 0,
            'tools.expires.on': True,
            'tools.expires.secs': 60 * 60 * 24 * 30,  # 30 days
            'tools.auth.on': False,
            'tools.sessions.on': False
        },
        '/css': {
            'tools.staticdir.on': True,
            'tools.staticdir.dir': "interfaces/default/css",
            'tools.caching.on': True,
            'tools.caching.force': True,
            'tools.caching.delay': 0,
            'tools.expires.on': True,
            'tools.expires.secs': 60 * 60 * 24 * 30,  # 30 days
            'tools.auth.on': False,
            'tools.sessions.on': False
        },
        '/fonts': {
            'tools.staticdir.on': True,
            'tools.staticdir.dir': "interfaces/default/fonts",
            'tools.caching.on': True,
            'tools.caching.force': True,
            'tools.caching.delay': 0,
            'tools.expires.on': True,
            'tools.expires.secs': 60 * 60 * 24 * 30,  # 30 days
            'tools.auth.on': False,
            'tools.sessions.on': False
        },
        '/js': {
            'tools.staticdir.on': True,
            'tools.staticdir.dir': "interfaces/default/js",
            'tools.caching.on': True,
            'tools.caching.force': True,
            'tools.caching.delay': 0,
            'tools.expires.on': True,
            'tools.expires.secs': 60 * 60 * 24 * 30,  # 30 days
            'tools.auth.on': False,
            'tools.sessions.on': False
        },
        '/json': {
            'tools.staticdir.on': True,
            'tools.staticdir.dir': "interfaces/default/json",
            'tools.caching.on': True,
            'tools.caching.force': True,
            'tools.caching.delay': 0,
            'tools.expires.on': True,
            'tools.expires.secs': 60 * 60 * 24 * 30,  # 30 days
            'tools.auth.on': False,
            'tools.sessions.on': False
        },
        '/xml': {
            'tools.staticdir.on': True,
            'tools.staticdir.dir': "interfaces/default/xml",
            'tools.caching.on': True,
            'tools.caching.force': True,
            'tools.caching.delay': 0,
            'tools.expires.on': True,
            'tools.expires.secs': 60 * 60 * 24 * 30,  # 30 days
            'tools.auth.on': False,
            'tools.sessions.on': False
        },
        '/cache': {
            'tools.staticdir.on': True,
            'tools.staticdir.dir': plexpy.CONFIG.CACHE_DIR,
            'tools.caching.on': True,
            'tools.caching.force': True,
            'tools.caching.delay': 0,
            'tools.expires.on': True,
            'tools.expires.secs': 60 * 60 * 24 * 30,  # 30 days
            'tools.auth.on': False,
            'tools.sessions.on': False
        },
        #'/pms_image_proxy': {
        #    'tools.staticdir.on': True,
        #    'tools.staticdir.dir': os.path.join(plexpy.CONFIG.CACHE_DIR, 'images'),
        #    'tools.caching.on': True,
        #    'tools.caching.force': True,
        #    'tools.caching.delay': 0,
        #    'tools.expires.on': True,
        #    'tools.expires.secs': 60 * 60 * 24 * 30,  # 30 days
        #    'tools.auth.on': False,
        #    'tools.sessions.on': False
        #},
        '/favicon.ico': {
            'tools.staticfile.on': True,
            'tools.staticfile.filename': os.path.abspath(os.path.join(plexpy.PROG_DIR, 'data/interfaces/default/images/favicon.ico')),
            'tools.caching.on': True,
            'tools.caching.force': True,
            'tools.caching.delay': 0,
            'tools.expires.on': True,
            'tools.expires.secs': 60 * 60 * 24 * 30,  # 30 days
            'tools.auth.on': False,
            'tools.sessions.on': False
        }
    }

    # Prevent time-outs
    cherrypy.engine.timeout_monitor.unsubscribe()
    cherrypy.tree.mount(WebInterface(), options['http_root'], config=conf)

    try:
        logger.info(u"PlexPy WebStart :: Starting PlexPy web server on %s://%s:%d%s", protocol,
                    options['http_host'], options['http_port'], options['http_root'])
        cherrypy.process.servers.check_port(str(options['http_host']), options['http_port'])
        if not plexpy.DEV:
            cherrypy.server.start()
        else:
            cherrypy.engine.signals.subscribe()
            cherrypy.engine.start()
            cherrypy.engine.block()
    except IOError:
        sys.stderr.write('Failed to start on port: %i. Is something else running?\n' % (options['http_port']))
        sys.exit(1)

    cherrypy.server.wait()
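
The HTTPS setup above degrades gracefully: if either the certificate or the key is missing, it first tries to create a self-signed pair, and if that fails (or the files still do not exist) it falls back to plain HTTP instead of crashing at startup. A condensed sketch of that guard, with create_https_certificates stubbed out and the paths invented:

import os

def create_https_certificates(cert, key):
    # Stub: the real function would generate a self-signed pair.
    return False

def resolve_protocol(enable_https, https_cert, https_key):
    if enable_https:
        if not (https_cert and os.path.exists(https_cert)) or \
                not (https_key and os.path.exists(https_key)):
            if not create_https_certificates(https_cert, https_key):
                enable_https = False
        if not (os.path.exists(https_cert) and os.path.exists(https_key)):
            enable_https = False
    return 'https' if enable_https else 'http'

print(resolve_protocol(True, '/tmp/no-cert.pem', '/tmp/no-key.pem'))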

Example 20

Project: stonix
Source File: DisableRemoveableStorage.py
View license
    def fixMac(self):
        '''This method will attempt to disable certain storage ports by moving
        certain kernel extensions.  If the check box is checked, we will move
        the kernel extension (if present) associated with that storage
        port/device into a folder designated for disabled extensions.  If the
        check box is unchecked, we assume the user doesn't want this disabled;
        if the kernel extension is no longer where it should be, we check the
        disabled extensions folder to see if it was previously disabled, and
        if it is in that folder, we move it back.
        @author: bemalmbe
        @return: bool
        @change: dwalker 8/19/2014
        '''
        debug = ""
        check = "/usr/sbin/kextstat "
        unload = "/sbin/kextunload "
        load = "/sbin/kextload "
        filepath = "/System/Library/Extensions/"
        success = True
        #created1 = False
        created2 = False
        if not os.path.exists(self.plistpath):
            createFile(self.plistpath, self.logger)
        self.iditerator += 1
        myid = iterate(self.iditerator, self.rulenumber)
        cmd = "/bin/launchctl unload " + self.plistpath
        event = {"eventtype": "commandstring",
                 "command": cmd}
        self.statechglogger.recordchgevent(myid, event)
        #created1 = True
        self.iditerator += 1
        myid = iterate(self.iditerator, self.rulenumber)
        event = {"eventtype": "creation",
                 "filepath": self.plistpath}
        self.statechglogger.recordchgevent(myid, event)
        if os.path.exists(self.plistpath):
            uid, gid = "", ""
            statdata = os.stat(self.plistpath)
            mode = stat.S_IMODE(statdata.st_mode)
            ownergrp = getUserGroupName(self.plistpath)
            owner = ownergrp[0]
            group = ownergrp[1]
            if grp.getgrnam("wheel")[2] != "":
                gid = grp.getgrnam("wheel")[2]
            if pwd.getpwnam("root")[2] != "":
                uid = pwd.getpwnam("root")[2]
#             if not created1:
#                 if mode != 420 or owner != "root" or group != "wheel":
#                     origuid = statdata.st_uid
#                     origgid = statdata.st_gid
#                     if gid:
#                         if uid:
#                             self.iditerator += 1
#                             myid = iterate(self.iditerator,
#                                            self.rulenumber)
#                             event = {"eventtype": "perm",
#                                      "startstate": [origuid,
#                                                     origgid, mode],
#                                      "endstate": [uid, gid, 420],
#                                      "filepath": self.plistpath}
            contents = readFile(self.plistpath, self.logger)
            contentstring = ""
            for line in contents:
                contentstring += line.strip()
            if not re.search(self.plistregex, contentstring):
                tmpfile = self.plistpath + ".tmp"
                if not writeFile(tmpfile, self.plistcontents, self.logger):
                    success = False
#                 elif not created1:
#                     self.iditerator += 1
#                     myid = iterate(self.iditerator, self.rulenumber)
#                     event = {"eventtype": "conf",
#                              "filepath": self.plistpath}
#                     self.statechglogger.recordchgevent(myid, event)
#                     self.statechglogger.recordfilechange(self.plistpath,
#                                                          tmpfile, myid)
#                     os.rename(tmpfile, self.plistpath)
#                     if uid and gid:
#                         os.chown(self.plistpath, uid, gid)
#                     os.chmod(self.plistpath, 420)
                else:
                    os.rename(tmpfile, self.plistpath)
                    if uid and gid:
                        os.chown(self.plistpath, uid, gid)
                    os.chmod(self.plistpath, 420)
        if not os.path.exists(self.daemonpath):
            if not createFile(self.daemonpath, self.logger):
                success = False
                self.detailedresults += "Unable to create the disablestorage python file\n"
        self.iditerator += 1
        myid = iterate(self.iditerator, self.rulenumber)
        event = {"eventtype": "creation",
                 "filepath": self.daemonpath}
        self.statechglogger.recordchgevent(myid, event)
        if os.path.exists(self.daemonpath):
            uid, gid = "", ""
            statdata = os.stat(self.daemonpath)
            mode = stat.S_IMODE(statdata.st_mode)
            ownergrp = getUserGroupName(self.daemonpath)
            owner = ownergrp[0]
            group = ownergrp[1]
            if grp.getgrnam("admin")[2] != "":
                gid = grp.getgrnam("admin")[2]
            if pwd.getpwnam("root")[2] != "":
                uid = pwd.getpwnam("root")[2]
            # if we didn't have to create the file then we want to record
            # incorrect permissions as a state change event
            if not created2:
                if mode != 509 or owner != "root" or group != "admin":
                    origuid = statdata.st_uid
                    origgid = statdata.st_gid
                    if gid:
                        if uid:
                            self.iditerator += 1
                            myid = iterate(self.iditerator,
                                           self.rulenumber)
                            event = {"eventtype": "perm",
                                     "startstate": [origuid,
                                                    origgid, mode],
                                     "endstate": [uid, gid, 509],
                                     "filepath": self.daemonpath}
            contents = readFile(self.daemonpath, self.logger)
            contentstring = ""
            for line in contents:
                contentstring += line
            if contentstring != self.daemoncontents:
                tmpfile = self.daemonpath + ".tmp"
                if writeFile(tmpfile, self.daemoncontents, self.logger):
                    if not created2:
                        self.iditerator += 1
                        myid = iterate(self.iditerator, self.rulenumber)
                        event = {"eventtype": "conf",
                                 "filepath": self.daemonpath}
                        self.statechglogger.recordchgevent(myid, event)
                        self.statechglogger.recordfilechange(self.daemonpath,
                                                             tmpfile, myid)
                        os.rename(tmpfile, self.daemonpath)
                        if uid and gid:
                            os.chown(self.daemonpath, uid, gid)
                        os.chmod(self.daemonpath, 509)
                    else:
                        os.rename(tmpfile, self.daemonpath)
                        if uid and gid:
                            os.chown(self.daemonpath, uid, gid)
                        os.chmod(self.daemonpath, 509)
                else:
                    success = False
            elif not checkPerms(self.daemonpath, [0, 0, 509], self.logger):
                if not setPerms(self.daemonpath, [0, 0, 509], self.logger):
                    success = False
        if re.search("^10.11", self.environ.getosver()):
            usb = "IOUSBMassStorageDriver"
        else:
            usb = "IOUSBMassStorageClass"
        cmd = check + "| grep " + usb
        self.ch.executeCommand(cmd)

        # if return code is 0, the kernel module is loaded, thus we need
        # to disable it
        if self.ch.getReturnCode() == 0:
            cmd = unload + filepath + usb + ".kext/"
            if not self.ch.executeCommand(cmd):
                debug += "Unable to disable USB\n"
                success = False
            else:
                self.iditerator += 1
                myid = iterate(self.iditerator, self.rulenumber)
                undo = load + filepath + usb + ".kext/"
                event = {"eventtype": "comm",
                         "command": undo}
                self.statechglogger.recordchgevent(myid, event)
        fw = "IOFireWireSerialBusProtocolTransport"
        cmd = check + "| grep " + fw
        self.ch.executeCommand(cmd)

        # if return code is 0, the kernel module is loaded, thus we need
        # to disable it
        if self.ch.getReturnCode() == 0:
            cmd = unload + filepath + fw + ".kext/"
            if not self.ch.executeCommand(cmd):
                debug += "Unable to disable Firewire\n"
                success = False
            else:
                self.iditerator += 1
                myid = iterate(self.iditerator, self.rulenumber)
                undo = load + filepath + fw + ".kext/"
                event = {"eventtype": "comm",
                         "command": undo}
                self.statechglogger.recordchgevent(myid, event)
        tb = "AppleThunderboltUTDM"
        cmd = check + "| grep " + tb
        self.ch.executeCommand(cmd)

        # if return code is 0, the kernel module is loaded, thus we need
        # to disable it
        if self.ch.getReturnCode() == 0:
            cmd = unload + "/System/Library/Extensions/" + tb + ".kext/"
            if not self.ch.executeCommand(cmd):
                debug += "Unable to disable Thunderbolt\n"
                success = False
            else:
                self.iditerator += 1
                myid = iterate(self.iditerator, self.rulenumber)
                undo = load + filepath + tb + ".kext/"
                event = {"eventtype": "comm",
                         "command": undo}
                self.statechglogger.recordchgevent(myid, event)
        sd = "AppleSDXC"
        cmd = check + "| grep " + sd
        self.ch.executeCommand(cmd)

        # if return code is 0, the kernel module is loaded, thus we need
        # to disable it
        if self.ch.getReturnCode() == 0:
            cmd = unload + "/System/Library/Extensions/" + sd + ".kext/"
            if not self.ch.executeCommand(cmd):
                debug += "Unable to disable SD Card functionality\n"
                success = False
            else:
                self.iditerator += 1
                myid = iterate(self.iditerator, self.rulenumber)
                undo = load + filepath + sd + ".kext/"
                event = {"eventtype": "comm",
                         "command": undo}
                self.statechglogger.recordchgevent(myid, event)
        cmd = ["/bin/launchctl", "load", self.plistpath]
        if not self.ch.executeCommand(cmd):
            debug += "Unable to load the launchctl job to regularly " + \
                "disable removeable storage.  May need to be done manually\n"
            success = False
        if debug:
            self.logger.log(LogPriority.DEBUG, debug)
        return success
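
A note on the permission values above: 509 is the decimal form of octal 0o775 (rwxrwxr-x), and [0, 0, 509] asks checkPerms for uid 0/gid 0 ownership with that mode. Below is a minimal sketch of the same exists-guarded install step using an octal literal for readability; install_daemon and its parameters are illustrative, not part of the STONIX rule:

import os

def install_daemon(tmpfile, daemonpath, uid=None, gid=None):
    # Only move the staged file into place if it actually exists.
    if not os.path.exists(tmpfile):
        raise FileNotFoundError(tmpfile)
    os.rename(tmpfile, daemonpath)
    if uid is not None and gid is not None:
        os.chown(daemonpath, uid, gid)
    os.chmod(daemonpath, 0o775)  # 509 decimal == 0o775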

Example 21

Project: dcos
Source File: __init__.py
View license
def build(package_store, name, variant, clean_after_build, recursive=False):
    assert isinstance(package_store, PackageStore)
    print("Building package {} variant {}".format(name, pkgpanda.util.variant_str(variant)))
    tmpdir = tempfile.TemporaryDirectory(prefix="pkgpanda_repo")
    repository = Repository(tmpdir.name)

    package_dir = package_store.get_package_folder(name)

    def src_abs(name):
        return package_dir + '/' + name

    def cache_abs(filename):
        return package_store.get_package_cache_folder(name) + '/' + filename

    # Build pkginfo over time, translating fields from buildinfo.
    pkginfo = {}

    # Build up the docker command arguments over time, translating fields as needed.
    cmd = DockerCmd()

    assert (name, variant) in package_store.packages, \
        "Programming error: name, variant should have been validated to be valid before calling build()."

    builder = IdBuilder(package_store.get_buildinfo(name, variant))
    final_buildinfo = dict()

    builder.add('name', name)
    builder.add('variant', pkgpanda.util.variant_str(variant))

    # Convert single_source -> sources
    if builder.has('sources'):
        if builder.has('single_source'):
            raise BuildError('Both sources and single_source cannot be specified at the same time')
        sources = builder.take('sources')
    elif builder.has('single_source'):
        sources = {name: builder.take('single_source')}
        builder.replace('single_source', 'sources', sources)
    else:
        builder.add('sources', {})
        sources = dict()
        print("NOTICE: No sources specified")

    final_buildinfo['sources'] = sources

    # Construct the source fetchers, gather the checkout ids from them
    checkout_ids = dict()
    fetchers = dict()
    try:
        for src_name, src_info in sorted(sources.items()):
            # TODO(cmaloney): Switch to a unified top level cache directory shared by all packages
            cache_dir = package_store.get_package_cache_folder(name) + '/' + src_name
            check_call(['mkdir', '-p', cache_dir])
            fetcher = get_src_fetcher(src_info, cache_dir, package_dir)
            fetchers[src_name] = fetcher
            checkout_ids[src_name] = fetcher.get_id()
    except ValidationError as ex:
        raise BuildError("Validation error when fetching sources for package: {}".format(ex))

    for src_name, checkout_id in checkout_ids.items():
        # NOTE: single_source buildinfo was expanded above so the src_name is
        # always correct here.
        # Make sure we never accidentally overwrite something which might be
        # important. Fields should match if specified (and that should be
        # tested at some point). For now, disallowing duplicate keys saves hassle.
        assert_no_duplicate_keys(checkout_id, final_buildinfo['sources'][src_name])
        final_buildinfo['sources'][src_name].update(checkout_id)

    # Add the sha1 of the buildinfo.json + build file to the build ids
    builder.update('sources', checkout_ids)
    build_script = src_abs(builder.take('build_script'))
    # TODO(cmaloney): Change dest name to build_script_sha1
    builder.replace('build_script', 'build', pkgpanda.util.sha1(build_script))
    builder.add('pkgpanda_version', pkgpanda.build.constants.version)

    extra_dir = src_abs("extra")
    # Add the "extra" folder inside the package as an additional source if it
    # exists
    if os.path.exists(extra_dir):
        extra_id = hash_folder(extra_dir)
        builder.add('extra_source', extra_id)
        final_buildinfo['extra_source'] = extra_id

    # Figure out the docker name.
    docker_name = builder.take('docker')
    cmd.container = docker_name

    # Add the id of the docker build environment to the build_ids.
    try:
        docker_id = get_docker_id(docker_name)
    except CalledProcessError:
        # docker pull the container and try again
        check_call(['docker', 'pull', docker_name])
        docker_id = get_docker_id(docker_name)

    builder.update('docker', docker_id)

    # TODO(cmaloney): The environment variables should be generated during build
    # not live in buildinfo.json.
    pkginfo['environment'] = builder.take('environment')

    # Whether pkgpanda should make sure a `/var/lib` state directory is available on the host
    pkginfo['state_directory'] = builder.take('state_directory')
    if pkginfo['state_directory'] not in [True, False]:
        raise BuildError("state_directory in buildinfo.json must be a boolean `true` or `false`")

    username = None
    if builder.has('username'):
        username = builder.take('username')
        if not isinstance(username, str):
            raise BuildError("username in buildinfo.json must be either not set (no user for this"
                             " package), or a user name string")
        try:
            pkgpanda.UserManagement.validate_username(username)
        except ValidationError as ex:
            raise BuildError("username in buildinfo.json didn't meet the validation rules. {}".format(ex))
        pkginfo['username'] = username

    group = None
    if builder.has('group'):
        group = builder.take('group')
        if not isinstance(group, str):
            raise BuildError("group in buildinfo.json must be either not set (use default group for this user)"
                             ", or group must be a string")
        try:
            pkgpanda.UserManagement.validate_group_name(group)
        except ValidationError as ex:
            raise BuildError("group in buildinfo.json didn't meet the validation rules. {}".format(ex))
        pkginfo['group'] = group

    # Packages need directories inside the fake install root (otherwise docker
    # will try making the directories on a readonly filesystem), so build the
    # install root now, and make the package directories in it as we go.
    install_dir = tempfile.mkdtemp(prefix="pkgpanda-")

    active_packages = list()
    active_package_ids = set()
    active_package_variants = dict()
    auto_deps = set()

    # Final package has the same requires as the build.
    requires = builder.take('requires')
    pkginfo['requires'] = requires

    if builder.has("sysctl"):
        pkginfo["sysctl"] = builder.take("sysctl")

    # TODO(cmaloney): Pull generating the full set of requires into a function.
    to_check = copy.deepcopy(requires)
    if type(to_check) != list:
        raise BuildError("`requires` in buildinfo.json must be an array of dependencies.")
    while to_check:
        requires_info = to_check.pop(0)
        requires_name, requires_variant = expand_require(requires_info)

        if requires_name in active_package_variants:
            # TODO(cmaloney): If one package depends on the <default>
            # variant of a package and 1+ others depends on a non-<default>
            # variant then update the dependency to the non-default variant
            # rather than erroring.
            if requires_variant != active_package_variants[requires_name]:
                # TODO(cmaloney): Make this contain the chains of
                # dependencies which contain the conflicting packages.
                # a -> b -> c -> d {foo}
                # e {bar} -> d {baz}
                raise BuildError(
                    "Dependncy on multiple variants of the same package {}. variants: {} {}".format(
                        requires_name,
                        requires_variant,
                        active_package_variants[requires_name]))

            # The package {requires_name, variant} is already a
            # dependency; don't process it again / move on to the next.
            continue

        active_package_variants[requires_name] = requires_variant

        # Figure out the last build of the dependency, add that as the
        # fully expanded dependency.
        requires_last_build = package_store.get_last_build_filename(requires_name, requires_variant)
        if not os.path.exists(requires_last_build):
            if recursive:
                # Build the dependency
                build(package_store, requires_name, requires_variant, clean_after_build, recursive)
            else:
                raise BuildError("No last build file found for dependency {} variant {}. Rebuild "
                                 "the dependency".format(requires_name, requires_variant))

        try:
            pkg_id_str = load_string(requires_last_build)
            auto_deps.add(pkg_id_str)
            pkg_buildinfo = package_store.get_buildinfo(requires_name, requires_variant)
            pkg_requires = pkg_buildinfo['requires']
            pkg_path = repository.package_path(pkg_id_str)
            pkg_tar = pkg_id_str + '.tar.xz'
            if not os.path.exists(package_store.get_package_cache_folder(requires_name) + '/' + pkg_tar):
                raise BuildError(
                    "The build tarball {} refered to by the last_build file of the dependency {} "
                    "variant {} doesn't exist. Rebuild the dependency.".format(
                        pkg_tar,
                        requires_name,
                        requires_variant))

            active_package_ids.add(pkg_id_str)

            # Mount the package into the docker container.
            cmd.volumes[pkg_path] = "/opt/mesosphere/packages/{}:ro".format(pkg_id_str)
            os.makedirs(os.path.join(install_dir, "packages/{}".format(pkg_id_str)))

            # Add the dependencies of the package to the set which will be
            # activated.
            # TODO(cmaloney): All these 'transitive' dependencies shouldn't
            # be available to the package being built, only what depends on
            # them directly.
            to_check += pkg_requires
        except ValidationError as ex:
            raise BuildError("validating package needed as dependency {0}: {1}".format(requires_name, ex)) from ex
        except PackageError as ex:
            raise BuildError("loading package needed as dependency {0}: {1}".format(requires_name, ex)) from ex

    # Add requires to the package id, calculate the final package id.
    # NOTE: active_packages isn't fully constructed here since we lazily load
    # packages not already in the repository.
    builder.update('requires', list(active_package_ids))
    version_extra = None
    if builder.has('version_extra'):
        version_extra = builder.take('version_extra')

    build_ids = builder.get_build_ids()
    version_base = hash_checkout(build_ids)
    version = None
    if builder.has('version_extra'):
        version = "{0}-{1}".format(version_extra, version_base)
    else:
        version = version_base
    pkg_id = PackageId.from_parts(name, version)

    # Everything must have been extracted by now. If it wasn't, then we just
    # had a hard error that it was set but not used, and it wasn't included
    # in the calculation of the PackageId.
    builder = None

    # Save the build_ids. Useful for verifying exactly what went into the
    # package build hash.
    final_buildinfo['build_ids'] = build_ids
    final_buildinfo['package_version'] = version

    # Save the package name and variant. The variant is used when installing
    # packages to validate dependencies.
    final_buildinfo['name'] = name
    final_buildinfo['variant'] = variant

    # If the package is already built, don't do anything.
    pkg_path = package_store.get_package_cache_folder(name) + '/{}.tar.xz'.format(pkg_id)

    # Done if it exists locally
    if exists(pkg_path):
        print("Package up to date. Not re-building.")

        # TODO(cmaloney): Updating / filling last_build should be moved out of
        # the build function.
        write_string(package_store.get_last_build_filename(name, variant), str(pkg_id))

        return pkg_path

    # Try downloading.
    dl_path = package_store.try_fetch_by_id(pkg_id)
    if dl_path:
        print("Package up to date. Not re-building. Downloaded from repository-url.")
        # TODO(cmaloney): Updating / filling last_build should be moved out of
        # the build function.
        write_string(package_store.get_last_build_filename(name, variant), str(pkg_id))
        print(dl_path, pkg_path)
        assert dl_path == pkg_path
        return pkg_path

    # Fall out and do the build since it couldn't be downloaded
    print("Unable to download from cache. Proceeding to build")

    print("Building package {} with buildinfo: {}".format(
        pkg_id,
        json.dumps(final_buildinfo, indent=2, sort_keys=True)))

    # Clean out src, result so later steps can use them freely for building.
    def clean():
        # Run a docker container to remove src/ and result/
        cmd = DockerCmd()
        cmd.volumes = {
            package_store.get_package_cache_folder(name): "/pkg/:rw",
        }
        cmd.container = "ubuntu:14.04.4"
        cmd.run("package-cleaner", ["rm", "-rf", "/pkg/src", "/pkg/result"])

    clean()

    # Only fresh builds are allowed which don't overlap existing artifacts.
    result_dir = cache_abs("result")
    if exists(result_dir):
        raise BuildError("result folder must not exist. It will be made when the package is "
                         "built. {}".format(result_dir))

    # 'mkpanda add' all implicit dependencies since we actually need to build.
    for dep in auto_deps:
        print("Auto-adding dependency: {}".format(dep))
        # NOTE: Not using the name pkg_id because that overrides the outer one.
        id_obj = PackageId(dep)
        add_package_file(repository, package_store.get_package_path(id_obj))
        package = repository.load(dep)
        active_packages.append(package)

    # Check out all the sources into their respective 'src/' folders.
    try:
        src_dir = cache_abs('src')
        if os.path.exists(src_dir):
            raise ValidationError(
                "'src' directory already exists, did you have a previous build? " +
                "Currently all builds must be from scratch. Support should be " +
                "added for re-using a src directory when possible. src={}".format(src_dir))
        os.mkdir(src_dir)
        for src_name, fetcher in sorted(fetchers.items()):
            root = cache_abs('src/' + src_name)
            os.mkdir(root)

            fetcher.checkout_to(root)
    except ValidationError as ex:
        raise BuildError("Validation error when fetching sources for package: {}".format(ex))

    # Activate the packages so that we have a proper path, environment
    # variables.
    # TODO(cmaloney): RAII type thing for temporary directory so if we
    # don't get all the way through things will be cleaned up?
    install = Install(
        root=install_dir,
        config_dir=None,
        rooted_systemd=True,
        manage_systemd=False,
        block_systemd=True,
        fake_path=True,
        manage_users=False,
        manage_state_dir=False)
    install.activate(active_packages)
    # Rewrite all the symlinks inside the active path because we will
    # be mounting the folder into a docker container, and the absolute
    # paths to the packages will change.
    # TODO(cmaloney): This isn't very clean, it would be much nicer to
    # just run pkgpanda inside the package.
    rewrite_symlinks(install_dir, repository.path, "/opt/mesosphere/packages/")

    print("Building package in docker")

    # TODO(cmaloney): Run as a specific non-root user, make it possible
    # for non-root to cleanup afterwards.
    # Run the build, prepping the environment as necessary.
    mkdir(cache_abs("result"))

    # Copy the build info to the resulting tarball
    write_json(cache_abs("src/buildinfo.full.json"), final_buildinfo)
    write_json(cache_abs("result/buildinfo.full.json"), final_buildinfo)

    write_json(cache_abs("result/pkginfo.json"), pkginfo)

    # Make the folder for the package we are building. If docker does it, it
    # gets auto-created with root permissions and we can't actually delete it.
    os.makedirs(os.path.join(install_dir, "packages", str(pkg_id)))

    # TODO(cmaloney): Disallow writing to well known files and directories?
    # Source we checked out
    cmd.volumes.update({
        # TODO(cmaloney): src should be read only...
        cache_abs("src"): "/pkg/src:rw",
        # The build script
        build_script: "/pkg/build:ro",
        # Getting the result out
        cache_abs("result"): "/opt/mesosphere/packages/{}:rw".format(pkg_id),
        install_dir: "/opt/mesosphere:ro"
    })

    if os.path.exists(extra_dir):
        cmd.volumes[extra_dir] = "/pkg/extra:ro"

    cmd.environment = {
        "PKG_VERSION": version,
        "PKG_NAME": name,
        "PKG_ID": pkg_id,
        "PKG_PATH": "/opt/mesosphere/packages/{}".format(pkg_id),
        "PKG_VARIANT": variant if variant is not None else "<default>",
        "NUM_CORES": multiprocessing.cpu_count()
    }

    try:
        # TODO(cmaloney): Run a wrapper which sources
        # /opt/mesosphere/environment then runs a build. Also should fix
        # ownership of /opt/mesosphere/packages/{pkg_id} post build.
        cmd.run("package-builder", [
            "/bin/bash",
            "-o", "nounset",
            "-o", "pipefail",
            "-o", "errexit",
            "/pkg/build"])
    except CalledProcessError as ex:
        raise BuildError("docker exited non-zero: {}\nCommand: {}".format(ex.returncode, ' '.join(ex.cmd)))

    # Clean up the temporary install dir used for dependencies.
    # TODO(cmaloney): Move to an RAII wrapper.
    check_call(['rm', '-rf', install_dir])

    print("Building package tarball")

    # Check for forbidden services before packaging the tarball:
    try:
        check_forbidden_services(cache_abs("result"), RESERVED_UNIT_NAMES)
    except ValidationError as ex:
        raise BuildError("Package validation failed: {}".format(ex))

    # TODO(cmaloney): Updating / filling last_build should be moved out of
    # the build function.
    write_string(package_store.get_last_build_filename(name, variant), str(pkg_id))

    # Bundle the artifacts into the pkgpanda package
    tmp_name = pkg_path + "-tmp.tar.xz"
    make_tar(tmp_name, cache_abs("result"))
    os.rename(tmp_name, pkg_path)
    print("Package built.")
    if clean_after_build:
        clean()
    return pkg_path
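
The build above uses os.path.exists as a cache guard in two places: it skips the build entirely when the package tarball is already on disk, and it refuses to build when a stale result directory is present. A minimal sketch of that guard pattern; cached_or_build and build_fn are hypothetical names, not pkgpanda APIs:

import os

def cached_or_build(pkg_path, result_dir, build_fn):
    # Reuse the artifact if a previous build already produced it.
    if os.path.exists(pkg_path):
        return pkg_path
    # Refuse to build on top of leftovers from an earlier run.
    if os.path.exists(result_dir):
        raise RuntimeError("result folder must not exist: %s" % result_dir)
    return build_fn()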

Example 22

Project: tp-libvirt
Source File: iface_options.py
View license
def run(test, params, env):
    """
    Test interface xml options.

    1. Prepare test environment, destroy or suspend a VM.
    2. Edit xml and start the domain.
    3. Perform test operation.
    4. Recover test environment.
    5. Confirm the test result.
    """
    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)
    virsh_dargs = {'debug': True, 'ignore_status': False}

    def create_iface_xml(iface_mac):
        """
        Create interface xml file
        """
        iface = Interface(type_name=iface_type)
        source = ast.literal_eval(iface_source)
        if source:
            iface.source = source
        iface.model = iface_model if iface_model else "virtio"
        iface.mac_address = iface_mac
        driver_dict = {}
        driver_host = {}
        driver_guest = {}
        if iface_driver:
            driver_dict = ast.literal_eval(iface_driver)
        if iface_driver_host:
            driver_host = ast.literal_eval(iface_driver_host)
        if iface_driver_guest:
            driver_guest = ast.literal_eval(iface_driver_guest)
        iface.driver = iface.new_driver(driver_attr=driver_dict,
                                        driver_host=driver_host,
                                        driver_guest=driver_guest)
        logging.debug("Create new interface xml: %s", iface)
        return iface

    def modify_iface_xml(update, status_error=False):
        """
        Modify interface xml options
        """
        vmxml = vm_xml.VMXML.new_from_dumpxml(vm_name)
        xml_devices = vmxml.devices
        iface_index = xml_devices.index(
            xml_devices.by_device_tag("interface")[0])
        iface = xml_devices[iface_index]
        if iface_model:
            iface.model = iface_model
        else:
            del iface.model
        if iface_type:
            iface.type_name = iface_type
        del iface.source
        source = ast.literal_eval(iface_source)
        if source:
            net_ifs = utils_net.get_net_if(state="UP")
            # Check whether the source device is valid; if it's not in
            # the host interface list, fall back to the first active
            # interface of the host
            if (iface.type_name == "direct" and
                    source.has_key('dev') and
                    source['dev'] not in net_ifs):
                logging.warn("Source device %s is not a interface"
                             " of host, reset to %s",
                             source['dev'], net_ifs[0])
                source['dev'] = net_ifs[0]
            iface.source = source
        backend = ast.literal_eval(iface_backend)
        if backend:
            iface.backend = backend
        driver_dict = {}
        driver_host = {}
        driver_guest = {}
        if iface_driver:
            driver_dict = ast.literal_eval(iface_driver)
        if iface_driver_host:
            driver_host = ast.literal_eval(iface_driver_host)
        if iface_driver_guest:
            driver_guest = ast.literal_eval(iface_driver_guest)
        iface.driver = iface.new_driver(driver_attr=driver_dict,
                                        driver_host=driver_host,
                                        driver_guest=driver_guest)
        if iface.address:
            del iface.address

        logging.debug("New interface xml file: %s", iface)
        if unprivileged_user:
            # Create disk image for unprivileged user
            disk_index = xml_devices.index(
                xml_devices.by_device_tag("disk")[0])
            disk_xml = xml_devices[disk_index]
            logging.debug("source: %s", disk_xml.source)
            disk_source = disk_xml.source.attrs["file"]
            cmd = ("cp -fZ {0} {1} && chown {2}:{2} {1}"
                   "".format(disk_source, dst_disk, unprivileged_user))
            utils.run(cmd)
            disk_xml.source = disk_xml.new_disk_source(
                attrs={"file": dst_disk})
            vmxml.devices = xml_devices
            # Remove all channels to avoid of permission problem
            channels = vmxml.get_devices(device_type="channel")
            for channel in channels:
                vmxml.del_device(channel)

            vmxml.xmltreefile.write()
            logging.debug("New VM xml: %s", vmxml)
            utils.run("chmod a+rw %s" % vmxml.xml)
            virsh.define(vmxml.xml, **virsh_dargs)
        # Try to modify interface xml by update-device or edit xml
        elif update:
            iface.xmltreefile.write()
            ret = virsh.update_device(vm_name, iface.xml,
                                      ignore_status=True)
            libvirt.check_exit_status(ret, status_error)
        else:
            vmxml.devices = xml_devices
            vmxml.xmltreefile.write()
            vmxml.sync()

    def check_offloads_option(if_name, driver_options, session=None):
        """
        Check interface offloads by ethtool output
        """
        offloads = {"csum": "tx-checksumming",
                    "gso": "generic-segmentation-offload",
                    "tso4": "tcp-segmentation-offload",
                    "tso6": "tx-tcp6-segmentation",
                    "ecn": "tx-tcp-ecn-segmentation",
                    "ufo": "udp-fragmentation-offload"}
        if session:
            ret, output = session.cmd_status_output("ethtool -k %s | head"
                                                    " -18" % if_name)
        else:
            out = utils.run("ethtool -k %s | head -18" % if_name)
            ret, output = out.exit_status, out.stdout
        if ret:
            raise error.TestFail("ethtool return error code")
        logging.debug("ethtool output: %s", output)
        for offload in driver_options.keys():
            if offloads.has_key(offload):
                if (output.count(offloads[offload]) and
                    not output.count("%s: %s" % (
                        offloads[offload], driver_options[offload]))):
                    raise error.TestFail("offloads option %s: %s isn't"
                                         " correct in ethtool output" %
                                         (offloads[offload],
                                          driver_options[offload]))

    def run_xml_test(iface_mac):
        """
        Test for interface options in vm xml
        """
        # Get the interface object according the mac address
        vmxml = vm_xml.VMXML.new_from_dumpxml(vm_name)
        iface_devices = vmxml.get_devices(device_type="interface")
        iface = None
        for iface_dev in iface_devices:
            if iface_dev.mac_address == iface_mac:
                iface = iface_dev
        if not iface:
            raise error.TestFail("Can't find interface with mac"
                                 " '%s' in vm xml" % iface_mac)
        driver_dict = {}
        if iface_driver:
            driver_dict = ast.literal_eval(iface_driver)
        for driver_opt in driver_dict.keys():
            if not driver_dict[driver_opt] == iface.driver.driver_attr[driver_opt]:
                raise error.TestFail("Can't see driver option %s=%s in vm xml"
                                     % (driver_opt, driver_dict[driver_opt]))
        if iface_target:
            if (not iface.target.has_key("dev") or
                    not iface.target["dev"].startswith(iface_target)):
                raise error.TestFail("Can't see device target dev in vm xml")
            # Check macvtap mode by ip link command
            if iface_target == "macvtap" and iface.source.has_key("mode"):
                cmd = "ip -d link show %s" % iface.target["dev"]
                output = utils.run(cmd).stdout
                logging.debug("ip link output: %s", output)
                mode = iface.source["mode"]
                if mode == "passthrough":
                    mode = "passthru"
                if not output.count("macvtap  mode %s" % mode):
                    raise error.TestFail("Failed to verify macvtap mode")

    def run_cmdline_test(iface_mac):
        """
        Test for qemu-kvm command line options
        """
        cmd = ("ps -ef | grep %s | grep -v grep " % vm_name)
        ret = utils.run(cmd)
        logging.debug("Command line %s", ret.stdout)
        if test_vhost_net:
            if not ret.stdout.count("vhost=on") and not rm_vhost_driver:
                raise error.TestFail("Can't see vhost options in"
                                     " qemu-kvm command line")

        if iface_model == "virtio":
            model_option = "device virtio-net-pci"
        else:
            model_option = "device rtl8139"
        iface_cmdline = re.findall(r"%s,(.+),mac=%s" %
                                   (model_option, iface_mac), ret.stdout)
        if not iface_cmdline:
            raise error.TestFail("Can't see %s with mac %s in command"
                                 " line" % (model_option, iface_mac))

        cmd_opt = {}
        for opt in iface_cmdline[0].split(','):
            tmp = opt.rsplit("=")
            cmd_opt[tmp[0]] = tmp[1]
        logging.debug("Command line options %s", cmd_opt)

        driver_dict = {}
        # Test <driver> xml options.
        if iface_driver:
            iface_driver_dict = ast.literal_eval(iface_driver)
            for driver_opt in iface_driver_dict.keys():
                if driver_opt == "name":
                    continue
                elif driver_opt == "txmode":
                    if iface_driver_dict["txmode"] == "iothread":
                        driver_dict["tx"] = "bh"
                    else:
                        driver_dict["tx"] = iface_driver_dict["txmode"]
                elif driver_opt == "queues":
                    driver_dict["mq"] = "on"
                    driver_dict["vectors"] = str(int(
                        iface_driver_dict["queues"]) * 2 + 2)
                else:
                    driver_dict[driver_opt] = iface_driver_dict[driver_opt]
        # Test <driver><host/><driver> xml options.
        if iface_driver_host:
            driver_dict.update(ast.literal_eval(iface_driver_host))
        # Test <driver><guest/><driver> xml options.
        if iface_driver_guest:
            driver_dict.update(ast.literal_eval(iface_driver_guest))

        for driver_opt in driver_dict.keys():
            if (not cmd_opt.has_key(driver_opt) or
                    not cmd_opt[driver_opt] == driver_dict[driver_opt]):
                raise error.TestFail("Can't see option '%s=%s' in qemu-kvm "
                                     " command line" %
                                     (driver_opt, driver_dict[driver_opt]))
        if test_backend:
            guest_pid = ret.stdout.rsplit()[1]
            cmd = "lsof %s | grep %s" % (backend["tap"], guest_pid)
            if utils.system(cmd, ignore_status=True):
                raise error.TestFail("Guest process didn't open backend file"
                                     " %s" % backend["tap"])
            cmd = "lsof %s | grep %s" % (backend["vhost"], guest_pid)
            if utils.system(cmd, ignore_status=True):
                raise error.TestFail("Guest process didn't open backend file"
                                     " %s" % backend["tap"])

    def get_guest_ip(session, mac):
        """
        Wrapper function to get guest ip address
        """
        utils_net.restart_guest_network(session, mac)
        # Wait until the IP address is ready
        utils_misc.wait_for(
            lambda: utils_net.get_guest_ip_addr(session, mac), 10)
        return utils_net.get_guest_ip_addr(session, mac)

    def check_user_network(session):
        """
        Check user network ip address on guest
        """
        vm_ips = []
        vm_ips.append(get_guest_ip(session, iface_mac_old))
        if attach_device:
            vm_ips.append(get_guest_ip(session, iface_mac))
        logging.debug("IP address on guest: %s", vm_ips)
        if len(vm_ips) != len(set(vm_ips)):
            raise error.TestFail("Duplicated IP address on guest. "
                                 "Check bug: https://bugzilla.redhat."
                                 "com/show_bug.cgi?id=1147238")

        for vm_ip in vm_ips:
            if vm_ip is None or not vm_ip.startswith("10.0.2."):
                raise error.TestFail("Found wrong IP address"
                                     " on guest")
        # Check gateway address
        gateway = utils_net.get_net_gateway(session.cmd_output)
        if gateway != "10.0.2.2":
            raise error.TestFail("The gateway on guest is not"
                                 " right")
        # Check dns server address
        ns_list = utils_net.get_net_nameserver(session.cmd_output)
        if "10.0.2.3" not in ns_list:
            raise error.TestFail("The dns server can't be found"
                                 " on guest")

    def check_mcast_network(session):
        """
        Check multicast ip address on guests
        """
        username = params.get("username")
        password = params.get("password")
        src_addr = ast.literal_eval(iface_source)['address']
        add_session = additional_vm.wait_for_serial_login(username=username,
                                                          password=password)
        vms_sess_dict = {vm_name: session,
                         additional_vm.name: add_session}

        # Check mcast address on host
        cmd = "netstat -g | grep %s" % src_addr
        if utils.run(cmd, ignore_status=True).exit_status:
            raise error.TestFail("Can't find multicast ip address"
                                 " on host")
        vms_ip_dict = {}
        # Get ip address on each guest
        for vms in vms_sess_dict.keys():
            vm_mac = vm_xml.VMXML.get_first_mac_by_name(vms)
            vm_ip = get_guest_ip(vms_sess_dict[vms], vm_mac)
            if not vm_ip:
                raise error.TestFail("Can't get multicast ip"
                                     " address on guest")
            vms_ip_dict.update({vms: vm_ip})
        if len(set(vms_ip_dict.values())) != len(vms_sess_dict):
            raise error.TestFail("Got duplicated multicast ip address")
        logging.debug("Found ips on guest: %s", vms_ip_dict)

        # Run omping server on host
        if not utils_misc.yum_install(["omping"]):
            raise error.TestError("Failed to install omping"
                                  " on host")
        cmd = ("iptables -F;omping -m %s %s" %
               (src_addr, "192.168.122.1 %s" %
                ' '.join(vms_ip_dict.values())))
        # Run a background job waiting for client connections
        bgjob = utils.AsyncJob(cmd)

        # Run omping client on guests
        for vms in vms_sess_dict.keys():
            # omping should be installed first
            if not utils_misc.yum_install(["omping"], vms_sess_dict[vms]):
                raise error.TestError("Failed to install omping"
                                      " on guest")
            cmd = ("iptables -F; omping -c 5 -T 5 -m %s %s" %
                   (src_addr, "192.168.122.1 %s" %
                    vms_ip_dict[vms]))
            ret, output = vms_sess_dict[vms].cmd_status_output(cmd)
            logging.debug("omping ret: %s, output: %s", ret, output)
            if (not output.count('multicast, xmt/rcv/%loss = 5/5/0%') or
                    not output.count('unicast, xmt/rcv/%loss = 5/5/0%')):
                raise error.TestFail("omping failed on guest")
        # Kill the background job
        bgjob.kill_func()

    status_error = "yes" == params.get("status_error", "no")
    start_error = "yes" == params.get("start_error", "no")
    unprivileged_user = params.get("unprivileged_user")

    # Interface specific attributes.
    iface_type = params.get("iface_type", "network")
    iface_source = params.get("iface_source", "{}")
    iface_driver = params.get("iface_driver")
    iface_model = params.get("iface_model", "virtio")
    iface_target = params.get("iface_target")
    iface_backend = params.get("iface_backend", "{}")
    iface_driver_host = params.get("iface_driver_host")
    iface_driver_guest = params.get("iface_driver_guest")
    attach_device = params.get("attach_iface_device")
    change_option = "yes" == params.get("change_iface_options", "no")
    update_device = "yes" == params.get("update_iface_device", "no")
    additional_guest = "yes" == params.get("additional_guest", "no")
    serial_login = "yes" == params.get("serial_login", "no")
    rm_vhost_driver = "yes" == params.get("rm_vhost_driver", "no")
    test_option_cmd = "yes" == params.get(
                      "test_iface_option_cmd", "no")
    test_option_xml = "yes" == params.get(
                      "test_iface_option_xml", "no")
    test_vhost_net = "yes" == params.get(
                     "test_vhost_net", "no")
    test_option_offloads = "yes" == params.get(
                           "test_option_offloads", "no")
    test_iface_user = "yes" == params.get(
                      "test_iface_user", "no")
    test_iface_mcast = "yes" == params.get(
                       "test_iface_mcast", "no")
    test_libvirtd = "yes" == params.get("test_libvirtd", "no")
    test_guest_ip = "yes" == params.get("test_guest_ip", "no")
    test_backend = "yes" == params.get("test_backend", "no")

    if iface_driver_host or iface_driver_guest or test_backend:
        if not libvirt_version.version_compare(1, 2, 8):
            raise error.TestNAError("Offloading/backend options not "
                                    "supported in this libvirt version")
    if iface_driver and "queues" in ast.literal_eval(iface_driver):
        if not libvirt_version.version_compare(1, 0, 6):
            raise error.TestNAError("Queues options not supported"
                                    " in this libvirt version")

    if unprivileged_user:
        if not libvirt_version.version_compare(1, 1, 1):
            raise error.TestNAError("qemu-bridge-helper not supported"
                                    " on this host")
        virsh_dargs["unprivileged_user"] = unprivileged_user
        # Create unprivileged user if needed
        cmd = ("grep {0} /etc/passwd || "
               "useradd {0}".format(unprivileged_user))
        utils.run(cmd)
        # Need another disk image for unprivileged user to access
        dst_disk = "/tmp/%s.img" % unprivileged_user

    # Destroy VM first
    if vm.is_alive():
        vm.destroy(gracefully=False)

    # Back up xml file.
    vmxml_backup = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)
    iface_mac_old = vm_xml.VMXML.get_first_mac_by_name(vm_name)
    # iface_mac will update if attach a new interface
    iface_mac = iface_mac_old
    # Additional vm for test
    additional_vm = None
    libvirtd = utils_libvirtd.Libvirtd()

    try:
        # Build the xml and run test.
        try:
            # Prepare interface backend files
            if test_backend:
                if not os.path.exists("/dev/vhost-net"):
                    utils.run("modprobe vhost-net")
                backend = ast.literal_eval(iface_backend)
                backend_tap = "/dev/net/tun"
                backend_vhost = "/dev/vhost-net"
                if not backend:
                    backend["tap"] = backend_tap
                    backend["vhost"] = backend_vhost
                if not start_error:
                    # Create backend files for normal test
                    if not os.path.exists(backend["tap"]):
                        os.rename(backend_tap, backend["tap"])
                    if not os.path.exists(backend["vhost"]):
                        os.rename(backend_vhost, backend["vhost"])
            # Edit the interface xml.
            if change_option:
                modify_iface_xml(update=False)

            if rm_vhost_driver:
                # Check vhost driver.
                kvm_version = os.uname()[2]
                driver_path = ("/lib/modules/%s/kernel/drivers/vhost/"
                               "vhost_net.ko" % kvm_version)
                driver_backup = driver_path + ".bak"
                cmd = ("modprobe -r {0}; lsmod | "
                       "grep {0}".format("vhost_net"))
                if not utils.system(cmd, ignore_status=True):
                    raise error.TestError("Failed to remove vhost_net driver")
                # Move the vhost_net driver
                if os.path.exists(driver_path):
                    os.rename(driver_path, driver_backup)
            else:
                # Load vhost_net driver by default
                cmd = "modprobe vhost_net"
                utils.system(cmd)

            # Attach an interface when the vm is shut off
            if attach_device == 'config':
                iface_mac = utils_net.generate_mac_address_simple()
                iface_xml_obj = create_iface_xml(iface_mac)
                iface_xml_obj.xmltreefile.write()
                ret = virsh.attach_device(vm_name, iface_xml_obj.xml,
                                          flagstr="--config",
                                          ignore_status=True)
                libvirt.check_exit_status(ret)

            # Clone additional vm
            if additional_guest:
                guest_name = "%s_%s" % (vm_name, '1')
                # Clone additional guest
                timeout = params.get("clone_timeout", 360)
                utils_libguestfs.virt_clone_cmd(vm_name, guest_name,
                                                True, timeout=timeout)
                additional_vm = vm.clone(guest_name)
                additional_vm.start()
                # additional_vm.wait_for_login()

            # Start the VM.
            if unprivileged_user:
                virsh.start(vm_name, **virsh_dargs)
                cmd = ("su - %s -c 'virsh console %s'"
                       % (unprivileged_user, vm_name))
                session = aexpect.ShellSession(cmd)
                session.sendline()
                remote.handle_prompts(session, params.get("username"),
                                      params.get("password"), "[\#\$]", 30)
                # Get ip address on guest
                if not get_guest_ip(session, iface_mac):
                    raise error.TestError("Can't get ip address on guest")
            else:
                # Will raise VMStartError exception if start fails
                vm.start()
                if serial_login:
                    session = vm.wait_for_serial_login()
                else:
                    session = vm.wait_for_login()
            if start_error:
                raise error.TestFail("VM started unexpectedly")

            # Attach an interface when the vm is running
            if attach_device == 'live':
                iface_mac = utils_net.generate_mac_address_simple()
                iface_xml_obj = create_iface_xml(iface_mac)
                iface_xml_obj.xmltreefile.write()
                ret = virsh.attach_device(vm_name, iface_xml_obj.xml,
                                          flagstr="--live",
                                          ignore_status=True)
                libvirt.check_exit_status(ret)
                # Need to sleep here for the attachment to take effect
                time.sleep(5)

            # Update interface options
            if update_device:
                modify_iface_xml(update=True, status_error=status_error)

            # Run tests for qemu-kvm command line options
            if test_option_cmd:
                run_cmdline_test(iface_mac)
            # Run tests for vm xml
            if test_option_xml:
                run_xml_test(iface_mac)
            # Run tests for offloads options
            if test_option_offloads:
                if iface_driver_host:
                    ifname_guest = utils_net.get_linux_ifname(
                        session, iface_mac)
                    check_offloads_option(
                        ifname_guest, ast.literal_eval(
                            iface_driver_host), session)
                if iface_driver_guest:
                    ifname_host = libvirt.get_ifname_host(vm_name,
                                                          iface_mac)
                    check_offloads_option(
                        ifname_host, ast.literal_eval(iface_driver_guest))

            if test_iface_user:
                # Test user type network
                check_user_network(session)
            if test_iface_mcast:
                # Test mcast type network
                check_mcast_network(session)
            # Check guest ip address
            if test_guest_ip:
                if not get_guest_ip(session, iface_mac):
                    raise error.TestFail("Guest can't get a"
                                         " valid ip address")

            session.close()
            # Restart libvirtd and guest, then test again
            if test_libvirtd:
                libvirtd.restart()
                vm.destroy()
                vm.start()
                if test_option_xml:
                    run_xml_test(iface_mac)

            # Detach hot/cold-plugged interface at last
            if attach_device:
                ret = virsh.detach_device(vm_name, iface_xml_obj.xml,
                                          flagstr="", ignore_status=True)
                libvirt.check_exit_status(ret)

        except virt_vm.VMStartError as e:
            logging.info(str(e))
            if not start_error:
                raise error.TestFail('VM failed to start\n%s' % e)

    finally:
        # Recover VM.
        logging.info("Restoring vm...")
        # Restore interface backend files
        if test_backend:
            if not os.path.exists(backend_tap):
                os.rename(backend["tap"], backend_tap)
            if not os.path.exists(backend_vhost):
                os.rename(backend["vhost"], backend_vhost)
        if rm_vhost_driver:
            # Restore vhost_net driver
            if os.path.exists(driver_backup):
                os.rename(driver_backup, driver_path)
        if unprivileged_user:
            virsh.remove_domain(vm_name, "--remove-all-storage",
                                **virsh_dargs)
        if additional_vm:
            virsh.remove_domain(additional_vm.name,
                                "--remove-all-storage")
            # Kill all omping server process on host
            utils.system("pidof omping && killall omping",
                         ignore_status=True)
        if vm.is_alive():
            vm.destroy(gracefully=False)
        vmxml_backup.sync()
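
This test restores host state with exists-guarded os.rename pairs: a file (the vhost_net driver, the backend device nodes) is moved aside only when no backup exists yet, and moved back only when the original is gone. A small sketch of that backup/restore idiom; backup_file and restore_file are illustrative helpers, not part of tp-libvirt:

import os

def backup_file(path):
    # Move path aside, but never clobber an existing backup.
    backup = path + ".bak"
    if os.path.exists(path) and not os.path.exists(backup):
        os.rename(path, backup)
    return backup

def restore_file(path):
    # Put the backup back only if the original is gone.
    backup = path + ".bak"
    if os.path.exists(backup) and not os.path.exists(path):
        os.rename(backup, path)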

Example 23

Project: pythonVSCode
Source File: runmod.py
View license
def __rope_start_everything():
    import os
    import sys
    import socket
    try:
        import pickle
    except ImportError:
        import cPickle as pickle
    import marshal
    import inspect
    import types
    import threading
    import rope.base.utils.pycompat as pycompat

    class _MessageSender(object):

        def send_data(self, data):
            pass

    class _SocketSender(_MessageSender):

        def __init__(self, port):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect(('127.0.0.1', port))
            self.my_file = s.makefile('wb')

        def send_data(self, data):
            if not self.my_file.closed:
                pickle.dump(data, self.my_file)

        def close(self):
            self.my_file.close()

    class _FileSender(_MessageSender):

        def __init__(self, file_name):
            self.my_file = open(file_name, 'wb')

        def send_data(self, data):
            if not self.my_file.closed:
                marshal.dump(data, self.my_file)

        def close(self):
            self.my_file.close()

    def _cached(func):
        cache = {}

        def newfunc(self, arg):
            if arg in cache:
                return cache[arg]
            result = func(self, arg)
            cache[arg] = result
            return result
        return newfunc

    class _FunctionCallDataSender(object):

        def __init__(self, send_info, project_root):
            self.project_root = project_root
            if send_info.isdigit():
                self.sender = _SocketSender(int(send_info))
            else:
                self.sender = _FileSender(send_info)

            def global_trace(frame, event, arg):
                # HACK: Ignoring out->in calls
                # This might lose some information
                if self._is_an_interesting_call(frame):
                    return self.on_function_call
            sys.settrace(global_trace)
            threading.settrace(global_trace)

        def on_function_call(self, frame, event, arg):
            if event != 'return':
                return
            args = []
            returned = ('unknown',)
            code = frame.f_code
            for argname in code.co_varnames[:code.co_argcount]:
                try:
                    argvalue = self._object_to_persisted_form(
                        frame.f_locals[argname])
                    args.append(argvalue)
                except (TypeError, AttributeError):
                    args.append(('unknown',))
            try:
                returned = self._object_to_persisted_form(arg)
            except (TypeError, AttributeError):
                pass
            try:
                data = (self._object_to_persisted_form(frame.f_code),
                        tuple(args), returned)
                self.sender.send_data(data)
            except TypeError:
                pass
            return self.on_function_call

        def _is_an_interesting_call(self, frame):
            #if frame.f_code.co_name in ['?', '<module>']:
            #    return False
            #return not frame.f_back or
            #    not self._is_code_inside_project(frame.f_back.f_code)
            if not self._is_code_inside_project(frame.f_code) and \
               (not frame.f_back or
                    not self._is_code_inside_project(frame.f_back.f_code)):
                return False
            return True

        def _is_code_inside_project(self, code):
            source = self._path(code.co_filename)
            return source is not None and os.path.exists(source) and \
                _realpath(source).startswith(self.project_root)

        @_cached
        def _get_persisted_code(self, object_):
            source = self._path(object_.co_filename)
            if not os.path.exists(source):
                raise TypeError('no source')
            return ('defined', _realpath(source), str(object_.co_firstlineno))

        @_cached
        def _get_persisted_class(self, object_):
            try:
                return ('defined', _realpath(inspect.getsourcefile(object_)),
                        object_.__name__)
            except (TypeError, AttributeError):
                return ('unknown',)

        def _get_persisted_builtin(self, object_):
            if isinstance(object_, pycompat.string_types):
                return ('builtin', 'str')
            if isinstance(object_, list):
                holding = None
                if len(object_) > 0:
                    holding = object_[0]
                return ('builtin', 'list',
                        self._object_to_persisted_form(holding))
            if isinstance(object_, dict):
                keys = None
                values = None
                if len(object_) > 0:
                    # @todo - fix it properly, why is __locals__ being
                    # duplicated ?
                    keys = [key for key in object_.keys() if key != '__locals__'][0]
                    values = object_[keys]
                return ('builtin', 'dict',
                        self._object_to_persisted_form(keys),
                        self._object_to_persisted_form(values))
            if isinstance(object_, tuple):
                objects = []
                if len(object_) < 3:
                    for holding in object_:
                        objects.append(self._object_to_persisted_form(holding))
                else:
                    objects.append(self._object_to_persisted_form(object_[0]))
                return tuple(['builtin', 'tuple'] + objects)
            if isinstance(object_, set):
                holding = None
                if len(object_) > 0:
                    for o in object_:
                        holding = o
                        break
                return ('builtin', 'set',
                        self._object_to_persisted_form(holding))
            return ('unknown',)

        def _object_to_persisted_form(self, object_):
            if object_ is None:
                return ('none',)
            if isinstance(object_, types.CodeType):
                return self._get_persisted_code(object_)
            if isinstance(object_, types.FunctionType):
                return self._get_persisted_code(object_.__code__)
            if isinstance(object_, types.MethodType):
                return self._get_persisted_code(object_.__func__.__code__)
            if isinstance(object_, types.ModuleType):
                return self._get_persisted_module(object_)
            if isinstance(object_, pycompat.string_types + (list, dict, tuple, set)):
                return self._get_persisted_builtin(object_)
            if isinstance(object_, type):
                return self._get_persisted_class(object_)
            return ('instance', self._get_persisted_class(type(object_)))

        @_cached
        def _get_persisted_module(self, object_):
            path = self._path(object_.__file__)
            if path and os.path.exists(path):
                return ('defined', _realpath(path))
            return ('unknown',)

        def _path(self, path):
            if path.endswith('.pyc'):
                path = path[:-1]
            if path.endswith('.py'):
                return path

        def close(self):
            self.sender.close()
            sys.settrace(None)

    def _realpath(path):
        return os.path.realpath(os.path.abspath(os.path.expanduser(path)))

    send_info = sys.argv[1]
    project_root = sys.argv[2]
    file_to_run = sys.argv[3]
    run_globals = globals()
    run_globals.update({'__name__': '__main__',
                        '__builtins__': __builtins__,
                        '__file__': file_to_run})

    if send_info != '-':
        data_sender = _FunctionCallDataSender(send_info, project_root)
    del sys.argv[1:4]
    pycompat.execfile(file_to_run, run_globals)
    if send_info != '-':
        data_sender.close()
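
The sender above repeatedly pairs a path-normalizing helper with an os.path.exists check before it trusts a source location (see _path, _is_code_inside_project and _get_persisted_code). A condensed sketch of that validate-before-use pattern; source_path is an illustrative name, not part of the rope runtime:

import os

def source_path(filename, project_root=None):
    # Map a compiled .pyc file back to its .py source.
    if filename.endswith('.pyc'):
        filename = filename[:-1]
    if not (filename.endswith('.py') and os.path.exists(filename)):
        return None
    real = os.path.realpath(os.path.abspath(os.path.expanduser(filename)))
    # Optionally require the source to live inside the project.
    if project_root is not None and not real.startswith(project_root):
        return None
    return real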

Example 24

Project: tp-qemu
Source File: ksm_overcommit.py
View license
def run(test, params, env):
    """
    Tests KSM (Kernel Shared Memory) capability by allocating and filling
    KVM guests' memory using various values. KVM sets the memory as
    MADV_MERGEABLE so all of a VM's memory can be merged. The workers in
    the guest write to a tmpfs filesystem, so allocations are not limited
    by the process's max memory, only by the VM's memory. Two test modes
    are supported - serial and parallel.

    Serial mode - uses multiple VMs, allocates memory per guest and always
                  verifies the correct amount of shared memory.
                  0) Prints out the setup and initialize guest(s)
                  1) Fills guest with the same number (S1)
                  2) Random fill on the first guest
                  3) Random fill of the remaining VMs one by one until the
                     memory is completely filled (KVM stops machines which
                     ask for additional memory until there is available
                     memory) (S2, shouldn't finish)
                  4) Destroy all VMs but the last one
                  5) Checks the last VM's memory for corruption
    Parallel mode - uses one VM with multiple allocator workers. Executes
                   scenarios in parallel to put more stress on the KVM.
                   0) Prints out the setup and initialize guest(s)
                   1) Fills memory with the same number (S1)
                   2) Fills memory with random numbers (S2)
                   3) Verifies all pages
                   4) Fills memory with the same number (S2)
                   5) Changes the last 96B (S3)

    Scenarios:
    S1) Fill all vms with the same value (all pages should be merged into 1)
    S2) Random fill (all pages should be split)
    S3) Fill last 96B (change only last 96B of each page; some pages will be
                      merged; there was a bug with data corruption)
    Every worker has a unique random key so we are able to verify the filled
    values.

    :param test: kvm test object.
    :param params: Dictionary with test parameters.
    :param env: Dictionary with the test environment.

    :param cfg: ksm_swap - use swap?
    :param cfg: ksm_overcommit_ratio - memory overcommit (serial mode only)
    :param cfg: ksm_parallel_ratio - number of workers (parallel mode only)
    :param cfg: ksm_host_reserve - override memory reserve on host in MB
    :param cfg: ksm_guest_reserve - override memory reserve on guests in MB
    :param cfg: ksm_mode - test mode {serial, parallel}
    :param cfg: ksm_perf_ratio - performance ratio, increase it when your
                                 machine is too slow
    """
    def _start_allocator(vm, session, timeout):
        """
        Execute ksm_overcommit_guest.py on guest, wait until it's initialized.

        :param vm: VM object.
        :param session: Remote session to a VM object.
        :param timeout: Timeout that will be used to verify if
                ksm_overcommit_guest.py started properly.
        """
        logging.debug("Starting ksm_overcommit_guest.py on guest %s", vm.name)
        session.sendline("python /tmp/ksm_overcommit_guest.py")
        try:
            session.read_until_last_line_matches(["PASS:", "FAIL:"], timeout)
        except aexpect.ExpectProcessTerminatedError, details:
            e_msg = ("Command ksm_overcommit_guest.py on vm '%s' failed: %s" %
                     (vm.name, str(details)))
            raise error.TestFail(e_msg)

    def _execute_allocator(command, vm, session, timeout):
        """
        Execute a given command on ksm_overcommit_guest.py main loop,
        indicating the vm the command was executed on.

        :param command: Command that will be executed.
        :param vm: VM object.
        :param session: Remote session to VM object.
        :param timeout: Timeout used to verify expected output.

        :return: Tuple (match index, data)
        """
        logging.debug("Executing '%s' on ksm_overcommit_guest.py loop, "
                      "vm: %s, timeout: %s", command, vm.name, timeout)
        session.sendline(command)
        try:
            (match, data) = session.read_until_last_line_matches(
                ["PASS:", "FAIL:"],
                timeout)
        except aexpect.ExpectProcessTerminatedError, details:
            e_msg = ("Failed to execute command '%s' on "
                     "ksm_overcommit_guest.py, vm '%s': %s" %
                     (command, vm.name, str(details)))
            raise error.TestFail(e_msg)
        return (match, data)

    def get_ksmstat():
        """
        Return the amount of memory shared by KSM, in MB

        :return: memory in MB
        """
        fpages = open('/sys/kernel/mm/ksm/pages_sharing')
        ksm_pages = int(fpages.read())
        fpages.close()
        return ((ksm_pages * 4096) / 1e6)

    def initialize_guests():
        """
        Initialize guests (fill their memories with specified patterns).
        """
        logging.info("Phase 1: filling guest memory pages")
        for session in lsessions:
            vm = lvms[lsessions.index(session)]

            logging.debug("Turning off swap on vm %s", vm.name)
            session.cmd("swapoff -a", timeout=300)

            # Start the allocator
            _start_allocator(vm, session, 60 * perf_ratio)

        # Execute allocator on guests
        for i in range(0, vmsc):
            vm = lvms[i]

            cmd = "mem = MemFill(%d, %s, %s)" % (ksm_size, skeys[i], dkeys[i])
            _execute_allocator(cmd, vm, lsessions[i], 60 * perf_ratio)

            cmd = "mem.value_fill(%d)" % skeys[0]
            _execute_allocator(cmd, vm, lsessions[i],
                               fill_base_timeout * 2 * perf_ratio)

            # Let ksm_overcommit_guest.py do its job
            # (until shared mem reaches expected value)
            shm = 0
            j = 0
            logging.debug("Target shared meminfo for guest %s: %s", vm.name,
                          ksm_size)
            while ((new_ksm and (shm < (ksm_size * (i + 1)))) or
                    (not new_ksm and (shm < (ksm_size)))):
                if j > 64:
                    logging.debug(utils_test.get_memory_info(lvms))
                    raise error.TestError("SHM didn't merge the memory until "
                                          "the DL on guest: %s" % vm.name)
                pause = ksm_size / 200 * perf_ratio
                logging.debug("Waiting %ds before proceeding...", pause)
                time.sleep(pause)
                if (new_ksm):
                    shm = get_ksmstat()
                else:
                    shm = vm.get_shared_meminfo()
                logging.debug("Shared meminfo for guest %s after "
                              "iteration %s: %s", vm.name, j, shm)
                j += 1

        # Keep some reserve
        pause = ksm_size / 200 * perf_ratio
        logging.debug("Waiting %ds before proceeding...", pause)
        time.sleep(pause)

        logging.debug(utils_test.get_memory_info(lvms))
        logging.info("Phase 1: PASS")

    def separate_first_guest():
        """
        Separate the memory of the first guest by generating a special random series
        """
        logging.info("Phase 2: Split the pages on the first guest")

        cmd = "mem.static_random_fill()"
        data = _execute_allocator(cmd, lvms[0], lsessions[0],
                                  fill_base_timeout * 2 * perf_ratio)[1]

        r_msg = data.splitlines()[-1]
        logging.debug("Return message of static_random_fill: %s", r_msg)
        out = int(r_msg.split()[4])
        logging.debug("Performance: %dMB * 1000 / %dms = %dMB/s", ksm_size,
                      out, (ksm_size * 1000 / out))
        logging.debug(utils_test.get_memory_info(lvms))
        logging.debug("Phase 2: PASS")

    def split_guest():
        """
        Sequential split of pages on guests up to memory limit
        """
        logging.info("Phase 3a: Sequential split of pages on guests up to "
                     "memory limit")
        last_vm = 0
        session = None
        vm = None
        for i in range(1, vmsc):
            # Check VMs
            for j in range(0, vmsc):
                if not lvms[j].is_alive():
                    e_msg = ("VM %d died while executing static_random_fill on"
                             " VM %d in allocator loop" % (j, i))
                    raise error.TestFail(e_msg)
            vm = lvms[i]
            session = lsessions[i]
            cmd = "mem.static_random_fill()"
            logging.debug("Executing %s on ksm_overcommit_guest.py loop, "
                          "vm: %s", cmd, vm.name)
            session.sendline(cmd)

            out = ""
            try:
                logging.debug("Watching host mem while filling vm %s memory",
                              vm.name)
                while (not out.startswith("PASS") and
                       not out.startswith("FAIL")):
                    if not vm.is_alive():
                        e_msg = ("VM %d died while executing "
                                 "static_random_fill on allocator loop" % i)
                        raise error.TestFail(e_msg)
                    free_mem = int(utils_memory.read_from_meminfo("MemFree"))
                    if (ksm_swap):
                        free_mem = (free_mem +
                                    int(utils_memory.read_from_meminfo("SwapFree")))
                    logging.debug("Free memory on host: %d", free_mem)

                    # We need to keep some memory for python to run.
                    if (free_mem < 64000) or (ksm_swap and
                                              free_mem < (450000 * perf_ratio)):
                        vm.pause()
                        for j in range(0, i):
                            lvms[j].destroy(gracefully=False)
                        time.sleep(20)
                        vm.resume()
                        logging.debug("Only %s free memory, killing %d guests",
                                      free_mem, (i - 1))
                        last_vm = i
                    out = session.read_nonblocking(0.1, 1)
                    time.sleep(2)
            except OSError:
                logging.debug("Only %s host free memory, killing %d guests",
                              free_mem, (i - 1))
                logging.debug("Stopping %s", vm.name)
                vm.pause()
                for j in range(0, i):
                    logging.debug("Destroying %s", lvms[j].name)
                    lvms[j].destroy(gracefully=False)
                time.sleep(20)
                vm.resume()
                last_vm = i

            if last_vm != 0:
                break
            logging.debug("Memory filled for guest %s", vm.name)

        logging.info("Phase 3a: PASS")

        logging.info("Phase 3b: Verify memory of the max stressed VM")
        for i in range(last_vm + 1, vmsc):
            lsessions[i].close()
            if i == (vmsc - 1):
                logging.debug(utils_test.get_memory_info([lvms[i]]))
            logging.debug("Destroying guest %s", lvms[i].name)
            lvms[i].destroy(gracefully=False)

        # Verify last machine with randomly generated memory
        cmd = "mem.static_random_verify()"
        _execute_allocator(cmd, lvms[last_vm], lsessions[last_vm],
                           (mem / 200 * 50 * perf_ratio))
        logging.debug(utils_test.get_memory_info([lvms[last_vm]]))

        lsessions[last_vm].cmd_output("die()", 20)
        lvms[last_vm].destroy(gracefully=False)
        logging.info("Phase 3b: PASS")

    def split_parallel():
        """
        Parallel page splitting
        """
        logging.info("Phase 1: parallel page splitting")
        # We have to wait until the allocator is finished (it waits 5 seconds
        # to clean the socket)

        session = lsessions[0]
        vm = lvms[0]
        for i in range(1, max_alloc):
            lsessions.append(vm.wait_for_login(timeout=360))

        session.cmd("swapoff -a", timeout=300)

        for i in range(0, max_alloc):
            # Start the allocator
            _start_allocator(vm, lsessions[i], 60 * perf_ratio)

        logging.info("Phase 1: PASS")

        logging.info("Phase 2a: Simultaneous merging")
        logging.debug("Memory used by allocator on guests = %dMB",
                      (ksm_size / max_alloc))

        for i in range(0, max_alloc):
            cmd = "mem = MemFill(%d, %s, %s)" % ((ksm_size / max_alloc),
                                                 skeys[i], dkeys[i])
            _execute_allocator(cmd, vm, lsessions[i], 60 * perf_ratio)

            cmd = "mem.value_fill(%d)" % (skeys[0])
            _execute_allocator(cmd, vm, lsessions[i],
                               fill_base_timeout * perf_ratio)

        # Wait until ksm_overcommit_guest.py merges pages (3 * ksm_size / 3)
        shm = 0
        i = 0
        logging.debug("Target shared memory size: %s", ksm_size)
        while (shm < ksm_size):
            if i > 64:
                logging.debug(utils_test.get_memory_info(lvms))
                raise error.TestError("SHM didn't merge the memory until DL")
            pause = ksm_size / 200 * perf_ratio
            logging.debug("Waiting %ds before proceed...", pause)
            time.sleep(pause)
            if (new_ksm):
                shm = get_ksmstat()
            else:
                shm = vm.get_shared_meminfo()
            logging.debug("Shared meminfo after attempt %s: %s", i, shm)
            i += 1

        logging.debug(utils_test.get_memory_info([vm]))
        logging.info("Phase 2a: PASS")

        logging.info("Phase 2b: Simultaneous spliting")
        # Actual splitting
        for i in range(0, max_alloc):
            cmd = "mem.static_random_fill()"
            data = _execute_allocator(cmd, vm, lsessions[i],
                                      fill_base_timeout * perf_ratio)[1]

            data = data.splitlines()[-1]
            logging.debug(data)
            out = int(data.split()[4])
            logging.debug("Performance: %dMB * 1000 / %dms = %dMB/s",
                          (ksm_size / max_alloc), out,
                          (ksm_size * 1000 / out / max_alloc))
        logging.debug(utils_test.get_memory_info([vm]))
        logging.info("Phase 2b: PASS")

        logging.info("Phase 2c: Simultaneous verification")
        for i in range(0, max_alloc):
            cmd = "mem.static_random_verify()"
            data = _execute_allocator(cmd, vm, lsessions[i],
                                      (mem / 200 * 50 * perf_ratio))[1]
        logging.info("Phase 2c: PASS")

        logging.info("Phase 2d: Simultaneous merging")
        # Actual merging
        for i in range(0, max_alloc):
            cmd = "mem.value_fill(%d)" % skeys[0]
            data = _execute_allocator(cmd, vm, lsessions[i],
                                      fill_base_timeout * 2 * perf_ratio)[1]
        logging.debug(utils_test.get_memory_info([vm]))
        logging.info("Phase 2d: PASS")

        logging.info("Phase 2e: Simultaneous verification")
        for i in range(0, max_alloc):
            cmd = "mem.value_check(%d)" % skeys[0]
            data = _execute_allocator(cmd, vm, lsessions[i],
                                      (mem / 200 * 50 * perf_ratio))[1]
        logging.info("Phase 2e: PASS")

        logging.info("Phase 2f: Simultaneous spliting last 96B")
        for i in range(0, max_alloc):
            cmd = "mem.static_random_fill(96)"
            data = _execute_allocator(cmd, vm, lsessions[i],
                                      fill_base_timeout * perf_ratio)[1]

            data = data.splitlines()[-1]
            out = int(data.split()[4])
            logging.debug("Performance: %dMB * 1000 / %dms = %dMB/s",
                          ksm_size / max_alloc, out,
                          (ksm_size * 1000 / out / max_alloc))

        logging.debug(utils_test.get_memory_info([vm]))
        logging.info("Phase 2f: PASS")

        logging.info("Phase 2g: Simultaneous verification last 96B")
        for i in range(0, max_alloc):
            cmd = "mem.static_random_verify(96)"
            _, data = _execute_allocator(cmd, vm, lsessions[i],
                                         (mem / 200 * 50 * perf_ratio))
        logging.debug(utils_test.get_memory_info([vm]))
        logging.info("Phase 2g: PASS")

        logging.debug("Cleaning up...")
        for i in range(0, max_alloc):
            lsessions[i].cmd_output("die()", 20)
        session.close()
        vm.destroy(gracefully=False)

    # Main test code
    logging.info("Starting phase 0: Initialization")
    if utils.run("ps -C ksmtuned", ignore_status=True).exit_status == 0:
        logging.info("Killing ksmtuned...")
        utils.run("killall ksmtuned")
    new_ksm = False
    if (os.path.exists("/sys/kernel/mm/ksm/run")):
        utils.run("echo 50 > /sys/kernel/mm/ksm/sleep_millisecs")
        utils.run("echo 5000 > /sys/kernel/mm/ksm/pages_to_scan")
        utils.run("echo 1 > /sys/kernel/mm/ksm/run")

        e_up = "/sys/kernel/mm/transparent_hugepage/enabled"
        e_rh = "/sys/kernel/mm/redhat_transparent_hugepage/enabled"
        if os.path.exists(e_up):
            utils.run("echo 'never' > %s" % e_up)
        if os.path.exists(e_rh):
            utils.run("echo 'never' > %s" % e_rh)
        new_ksm = True
    else:
        try:
            utils.run("modprobe ksm")
            utils.run("ksmctl start 5000 100")
        except error.CmdError, details:
            raise error.TestFail("Failed to load KSM: %s" % details)

    # host_reserve: mem reserve kept for the host system to run
    host_reserve = int(params.get("ksm_host_reserve", -1))
    if (host_reserve == -1):
        # default host_reserve = memory currently in use on the host +
        # one minimal guest (128MB); later we add 64MB per additional guest
        host_reserve = ((utils_memory.memtotal() -
                         utils_memory.read_from_meminfo("MemFree")) /
                        1024 + 128)
        # using default reserve
        _host_reserve = True
    else:
        _host_reserve = False

    # guest_reserve: mem reserve kept so the guest OS does not kill processes
    guest_reserve = int(params.get("ksm_guest_reserve", -1))
    if (guest_reserve == -1):
        # default guest_reserve = minimal_system_mem(256MB)
        # later we add tmpfs overhead
        guest_reserve = 256
        # using default reserve
        _guest_reserve = True
    else:
        _guest_reserve = False

    max_vms = int(params.get("max_vms", 2))
    overcommit = float(params.get("ksm_overcommit_ratio", 2.0))
    max_alloc = int(params.get("ksm_parallel_ratio", 1))

    # vmsc: count of all used VMs
    vmsc = int(overcommit) + 1
    vmsc = max(vmsc, max_vms)

    if (params['ksm_mode'] == "serial"):
        max_alloc = vmsc
        if _host_reserve:
            # First round of additional guest reserves
            host_reserve += vmsc * 64
            _host_reserve = vmsc

    host_mem = (int(utils_memory.memtotal()) / 1024 - host_reserve)

    ksm_swap = False
    if params.get("ksm_swap") == "yes":
        ksm_swap = True

    # Performance ratio
    perf_ratio = params.get("ksm_perf_ratio")
    if perf_ratio:
        perf_ratio = float(perf_ratio)
    else:
        perf_ratio = 1

    if (params['ksm_mode'] == "parallel"):
        vmsc = 1
        overcommit = 1
        mem = host_mem
        # 32bit system adjustment
        if "64" not in params.get("vm_arch_name"):
            logging.debug("Probably i386 guest architecture, "
                          "max allocator mem = 2G")
            # Guest can have more than 2G but
            # kvm mem + 1MB (allocator itself) can't
            if (host_mem > 3100):
                mem = 3100

        if os.popen("uname -i").readline().startswith("i386"):
            logging.debug("Host is i386 architecture, max guest mem is 2G")
            # Guest system with qemu overhead (64M) can't have more than 2G
            if mem > 3100 - 64:
                mem = 3100 - 64

    else:
        # mem: Memory of the guest systems. Maximum must be less than
        # host's physical ram
        mem = int(overcommit * host_mem / vmsc)

        # 32bit system adjustment
        if not params['image_name'].endswith("64"):
            logging.debug("Probably i386 guest architecture, "
                          "max allocator mem = 2G")
            # Guest can have more than 2G but
            # kvm mem + 1MB (allocator itself) can't
            if mem - guest_reserve - 1 > 3100:
                vmsc = int(math.ceil((host_mem * overcommit) /
                                     (3100 + guest_reserve)))
                if _host_reserve:
                    host_reserve += (vmsc - _host_reserve) * 64
                    host_mem -= (vmsc - _host_reserve) * 64
                    _host_reserve = vmsc
                mem = int(math.floor(host_mem * overcommit / vmsc))

        if os.popen("uname -i").readline().startswith("i386"):
            logging.debug("Host is i386 architecture, max guest mem is 2G")
            # Guest system with qemu overhead (64M) can't have more than 2G
            if mem > 3100 - 64:
                vmsc = int(math.ceil((host_mem * overcommit) /
                                     (3100 - 64.0)))
                if _host_reserve:
                    host_reserve += (vmsc - _host_reserve) * 64
                    host_mem -= (vmsc - _host_reserve) * 64
                    _host_reserve = vmsc
                mem = int(math.floor(host_mem * overcommit / vmsc))

    # 0.055 represents OS + TMPFS additional reserve per guest ram MB
    if _guest_reserve:
        guest_reserve += math.ceil(mem * 0.055)

    swap = int(utils_memory.read_from_meminfo("SwapTotal")) / 1024

    logging.debug("Overcommit = %f", overcommit)
    logging.debug("True overcommit = %f ", (float(vmsc * mem) /
                                            float(host_mem)))
    logging.debug("Host memory = %dM", host_mem)
    logging.debug("Guest memory = %dM", mem)
    logging.debug("Using swap = %s", ksm_swap)
    logging.debug("Swap = %dM", swap)
    logging.debug("max_vms = %d", max_vms)
    logging.debug("Count of all used VMs = %d", vmsc)
    logging.debug("Performance_ratio = %f", perf_ratio)

    # Generate unique keys for random series
    skeys = []
    dkeys = []
    for i in range(0, max(vmsc, max_alloc)):
        key = random.randrange(0, 255)
        while key in skeys:
            key = random.randrange(0, 255)
        skeys.append(key)

        key = random.randrange(0, 999)
        while key in dkeys:
            key = random.randrange(0, 999)
        dkeys.append(key)

    logging.debug("skeys: %s", skeys)
    logging.debug("dkeys: %s", dkeys)

    lvms = []
    lsessions = []

    # As we don't know the number and memory amount of VMs in advance,
    # we need to specify and create them here
    vm_name = params["main_vm"]
    params['mem'] = mem
    params['vms'] = vm_name
    # Associate pidfile name
    params['pid_' + vm_name] = utils_misc.generate_tmp_file_name(vm_name,
                                                                 'pid')
    if not params.get('extra_params'):
        params['extra_params'] = ' '
    params['extra_params_' + vm_name] = params.get('extra_params')
    params['extra_params_' + vm_name] += (" -pidfile %s" %
                                          (params.get('pid_' + vm_name)))
    params['extra_params'] = params.get('extra_params_' + vm_name)

    # ksm_size: amount of memory used by allocator
    ksm_size = mem - guest_reserve
    logging.debug("Memory used by allocator on guests = %dM", ksm_size)
    fill_base_timeout = ksm_size / 10

    # Creating the first guest
    env_process.preprocess_vm(test, params, env, vm_name)
    lvms.append(env.get_vm(vm_name))
    if not lvms[0]:
        raise error.TestError("VM object not found in environment")
    if not lvms[0].is_alive():
        raise error.TestError("VM seems to be dead; Test requires a living "
                              "VM")

    logging.debug("Booting first guest %s", lvms[0].name)

    lsessions.append(lvms[0].wait_for_login(timeout=360))
    # Associate vm PID
    try:
        tmp = open(params.get('pid_' + vm_name), 'r')
        params['pid_' + vm_name] = int(tmp.readline())
    except Exception:
        raise error.TestFail("Could not get PID of %s" % (vm_name))

    # Creating other guest systems
    for i in range(1, vmsc):
        vm_name = "vm" + str(i + 1)
        params['pid_' + vm_name] = utils_misc.generate_tmp_file_name(vm_name,
                                                                     'pid')
        params['extra_params_' + vm_name] = params.get('extra_params')
        params['extra_params_' + vm_name] += (" -pidfile %s" %
                                              (params.get('pid_' + vm_name)))
        params['extra_params'] = params.get('extra_params_' + vm_name)

        # Last VM is later used to run more allocators simultaneously
        lvms.append(lvms[0].clone(vm_name, params))
        env.register_vm(vm_name, lvms[i])
        params['vms'] += " " + vm_name

        logging.debug("Booting guest %s", lvms[i].name)
        lvms[i].create()
        if not lvms[i].is_alive():
            raise error.TestError("VM %s seems to be dead; Test requires a"
                                  "living VM" % lvms[i].name)

        lsessions.append(lvms[i].wait_for_login(timeout=360))
        try:
            tmp = open(params.get('pid_' + vm_name), 'r')
            params['pid_' + vm_name] = int(tmp.readline())
        except Exception:
            raise error.TestFail("Could not get PID of %s" % (vm_name))

    # Let guests rest a little bit :-)
    pause = vmsc * 2 * perf_ratio
    logging.debug("Waiting %ds before proceeding", pause)
    time.sleep(pause)
    logging.debug(utils_test.get_memory_info(lvms))

    # Copy ksm_overcommit_guest.py into guests
    shared_dir = os.path.dirname(data_dir.get_data_dir())
    vksmd_src = os.path.join(shared_dir, "scripts", "ksm_overcommit_guest.py")
    dst_dir = "/tmp"
    for vm in lvms:
        vm.copy_files_to(vksmd_src, dst_dir)
    logging.info("Phase 0: PASS")

    if params['ksm_mode'] == "parallel":
        logging.info("Starting KSM test parallel mode")
        split_parallel()
        logging.info("KSM test parallel mode: PASS")
    elif params['ksm_mode'] == "serial":
        logging.info("Starting KSM test serial mode")
        initialize_guests()
        separate_first_guest()
        split_guest()
        logging.info("KSM test serial mode: PASS")

Example 25

Project: cgat
Source File: runExpression.py
View license
def main(argv=None):
    """script main.

    parses command line options in sys.argv, unless *argv* is given.
    """

    if not argv:
        argv = sys.argv

    # setup command line parser
    parser = E.OptionParser(version="%prog version: $Id$",
                            usage=globals()["__doc__"])

    parser.add_option("-t", "--tags-tsv-file", dest="input_filename_tags",
                      type="string",
                      help="input file with tag counts [default=%default].")

    parser.add_option(
        "--result-tsv-file", dest="input_filename_result",
        type="string",
        help="input file with results (for plotdetagstats) "
        "[default=%default].")

    parser.add_option("-d", "--design-tsv-file", dest="input_filename_design",
                      type="string",
                      help="input file with experimental design "
                      "[default=%default].")

    parser.add_option("-o", "--outfile", dest="output_filename", type="string",
                      help="output filename [default=%default].")

    parser.add_option("-m", "--method", dest="method", type="choice",
                      choices=(
                          "deseq", "edger", "deseq2",
                          "ttest",
                          "mock", "summary",
                          "dump", "spike",
                          "plottagstats",
                          "plotdetagstats"),
                      help="differential expression method to apply "
                      "[default=%default].")

    parser.add_option("--deseq-dispersion-method",
                      dest="deseq_dispersion_method",
                      type="choice",
                      choices=("pooled", "per-condition", "blind"),
                      help="dispersion method for deseq [default=%default].")

    parser.add_option("--deseq-fit-type", dest="deseq_fit_type", type="choice",
                      choices=("parametric", "local"),
                      help="fit type for deseq [default=%default].")

    parser.add_option("--deseq-sharing-mode",
                      dest="deseq_sharing_mode",
                      type="choice",
                      choices=("maximum", "fit-only", "gene-est-only"),
                      help="deseq sharing mode [default=%default].")

    parser.add_option(
        "--edger-dispersion",
        dest="edger_dispersion", type="float",
        help="dispersion value for edgeR if there are no replicates "
        "[default=%default].")

    parser.add_option("-f", "--fdr", dest="fdr", type="float",
                      help="fdr to apply [default=%default].")

    parser.add_option("-p", "--pseudocounts", dest="pseudo_counts",
                      type="float",
                      help="pseudocounts to add for mock analyis "
                      "[default=%default].")

    parser.add_option("-R", "--output-R-code", dest="save_r_environment",
                      type="string",
                      help="save R environment [default=%default].")

    parser.add_option("-r", "--reference-group", dest="ref_group",
                      type="string",
                      help="Group to use as reference to compute "
                      "fold changes against [default=$default]")

    parser.add_option("--filter-min-counts-per-row",
                      dest="filter_min_counts_per_row",
                      type="int",
                      help="remove rows with less than this "
                      "number of counts in total [default=%default].")

    parser.add_option("--filter-min-counts-per-sample",
                      dest="filter_min_counts_per_sample",
                      type="int",
                      help="remove samples with a maximum count per sample of "
                      "less than this number   [default=%default].")

    parser.add_option("--filter-percentile-rowsums",
                      dest="filter_percentile_rowsums",
                      type="int",
                      help="remove percent of rows with "
                      "lowest total counts [default=%default].")
    parser.add_option("--deseq2-design-formula",
                      dest="model",
                      type="string",
                      help="Design formula for DESeq2")
    parser.add_option("--deseq2-contrasts",
                      dest="contrasts",
                      type="string",
                      help=("contrasts for post-hoc testing writen"
                            " variable:control:treatment,..."))
    parser.add_option("--deseq2-plot",
                      dest="plot",
                      type="int",
                      help=("draw plots during deseq2 analysis"))

    parser.set_defaults(
        input_filename_tags=None,
        input_filename_result=None,
        input_filename_design=None,
        output_filename=sys.stdout,
        method="deseq",
        fdr=0.1,
        deseq_dispersion_method="pooled",
        deseq_fit_type="parametric",
        deseq_sharing_mode="maximum",
        edger_dispersion=0.4,
        ref_group=None,
        save_r_environment=None,
        filter_min_counts_per_row=1,
        filter_min_counts_per_sample=10,
        filter_percentile_rowsums=0,
        pseudo_counts=0,
        spike_foldchange_max=4.0,
        spike_expression_max=5.0,
        spike_expression_bin_width=0.5,
        spike_foldchange_bin_width=0.5,
        spike_max_counts_per_bin=50,
        model=None,
        plot=1
    )

    # add common options (-h/--help, ...) and parse command line
    (options, args) = E.Start(parser, argv=argv, add_output_options=True)

    if options.input_filename_tags == "-":
        fh = tempfile.NamedTemporaryFile(delete=False)
        fh.write("".join([x for x in options.stdin]))
        fh.close()
        options.input_filename_tags = fh.name
    else:
        fh = None

    # load tag data and filter
    if options.method in ("deseq2", "deseq", "edger", "mock", "ttest"):
        assert options.input_filename_tags and os.path.exists(
            options.input_filename_tags)
        assert options.input_filename_design and os.path.exists(
            options.input_filename_design)

        Expression.loadTagData(options.input_filename_tags,
                               options.input_filename_design)

        nobservations, nsamples = Expression.filterTagData(
            filter_min_counts_per_row=options.filter_min_counts_per_row,
            filter_min_counts_per_sample=options.filter_min_counts_per_sample,
            filter_percentile_rowsums=options.filter_percentile_rowsums)

        if nobservations == 0:
            E.warn("no observations - no output")
            return

        if nsamples == 0:
            E.warn("no samples remain after filtering - no output")
            return

        sample_names = R('''colnames(countsTable)''')
        E.info("%i samples to test at %i observations: %s" %
               (nsamples, nobservations,
                ",".join(sample_names)))

    try:
        if options.method == "deseq2":
            Expression.runDESeq2(
                outfile=options.output_filename,
                outfile_prefix=options.output_filename_pattern,
                fdr=options.fdr,
                ref_group=options.ref_group,
                model=options.model,
                contrasts=options.contrasts,
                plot=options.plot
            )

        elif options.method == "deseq":
            Expression.runDESeq(
                outfile=options.output_filename,
                outfile_prefix=options.output_filename_pattern,
                fdr=options.fdr,
                dispersion_method=options.deseq_dispersion_method,
                fit_type=options.deseq_fit_type,
                sharing_mode=options.deseq_sharing_mode,
                ref_group=options.ref_group,
            )

        elif options.method == "edger":
            Expression.runEdgeR(
                outfile=options.output_filename,
                outfile_prefix=options.output_filename_pattern,
                fdr=options.fdr,
                ref_group=options.ref_group,
                dispersion=options.edger_dispersion)

        elif options.method == "mock":
            Expression.runMockAnalysis(
                outfile=options.output_filename,
                outfile_prefix=options.output_filename_pattern,
                ref_group=options.ref_group,
                pseudo_counts=options.pseudo_counts,
            )

        elif options.method == "summary":
            Expression.outputTagSummary(
                options.input_filename_tags,
                options.stdout,
                options.output_filename_pattern,
                filename_design=options.input_filename_design
            )

        elif options.method == "dump":
            assert options.input_filename_tags and os.path.exists(
                options.input_filename_tags)
            Expression.dumpTagData(options.input_filename_tags,
                                   options.input_filename_design,
                                   outfile=options.stdout)

        elif options.method == "plottagstats":
            assert options.input_filename_tags and os.path.exists(
                options.input_filename_tags)
            Expression.plotTagStats(
                options.input_filename_tags,
                options.input_filename_design,
                outfile_prefix=options.output_filename_pattern)

        elif options.method == "plotdetagstats":
            assert options.input_filename_result and os.path.exists(
                options.input_filename_result)
            Expression.plotDETagStats(
                options.input_filename_result,
                outfile_prefix=options.output_filename_pattern)

        elif options.method == "spike":
            Expression.outputSpikeIns(
                options.input_filename_tags,
                options.stdout,
                options.output_filename_pattern,
                filename_design=options.input_filename_design,
                foldchange_max=options.spike_foldchange_max,
                expression_max=options.spike_expression_max,
                max_counts_per_bin=options.spike_max_counts_per_bin,
                expression_bin_width=options.spike_expression_bin_width,
                foldchange_bin_width=options.spike_foldchange_bin_width,
            )

        elif options.method == "ttest":
            Expression.runTTest(
                outfile=options.output_filename,
                outfile_prefix=options.output_filename_pattern,
                fdr=options.fdr)

    except rpy2.rinterface.RRuntimeError:
        if options.save_r_environment:
            E.info("saving R image to %s" % options.save_r_environment)
            R['save.image'](options.save_r_environment)
        raise

    if fh and os.path.exists(fh.name):
        os.unlink(fh.name)

    if options.save_r_environment:
        R['save.image'](options.save_r_environment)

    E.Stop()
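
The cleanup at the end of this example ("if fh and os.path.exists(fh.name): os.unlink(fh.name)") shows a common temp-file idiom: spill stdin to a named temporary file so code that needs a real path can read it, then guard the unlink with os.path.exists so cleanup is safe even if something already removed the file. A minimal sketch, with a hypothetical process() consumer:

import os
import sys
import tempfile

def materialize_stdin(stream):
    """Write a text stream to a named temporary file and return its path."""
    fh = tempfile.NamedTemporaryFile(delete=False)
    fh.write(stream.read().encode("utf-8"))
    fh.close()
    return fh.name

path = materialize_stdin(sys.stdin)
try:
    pass  # process(path) would go here
finally:
    if os.path.exists(path):  # tolerate the file already being gone
        os.unlink(path)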

Example 26

Project: laikaboss
Source File: laika.py
View license
def main():
    # Define default configuration location

    parser = OptionParser(usage="usage: %prog [options] /path/to/file")
    parser.add_option("-d", "--debug",
                      action="store_true",
                      dest="debug",
                      help="enable debug messages to the console.")
    parser.add_option("-c", "--config-path",
                      action="store", type="string",
                      dest="config_path",
                      help="path to configuration for laikaboss framework.")
    parser.add_option("-o", "--out-path",
                      action="store", type="string",
                      dest="save_path",
                      help="Write all results to the specified path")
    parser.add_option("-s", "--source",
                      action="store", type="string",
                      dest="source",
                      help="Set the source (may affect dispatching) [default:laika]")
    parser.add_option("-p", "--num_procs",
                      action="store", type="int",
                      dest="num_procs",
                      default=8,
                      help="Specify the number of CPU's to use for a recursive scan. [default:8]")
    parser.add_option("-l", "--log",
                      action="store_true",
                      dest="log_result",
                      help="enable logging to syslog")
    parser.add_option("-j", "--log-json",
                      action="store", type="string",
                      dest="log_json",
                      help="enable logging JSON results to file")
    parser.add_option("-m", "--module",
                      action="store", type="string",
                      dest="scan_modules",
                      help="Specify individual module(s) to run and their arguments. If multiple, must be a space-separated list.")
    parser.add_option("--parent",
                      action="store", type="string",
                      dest="parent", default="",
                      help="Define the parent of the root object")
    parser.add_option("-e", "--ephID",
                      action="store", type="string",
                      dest="ephID", default="",
                      help="Specify an ephemeralID to send with the object")
    parser.add_option("--metadata",
                      action="store",
                      dest="ext_metadata",
                      help="Define metadata to add to the scan or specify a file containing the metadata.")
    parser.add_option("--size-limit",
                      action="store", type="int", default=10,
                      dest="sizeLimit",
                      help="Specify a size limit in MB (default: 10)")
    parser.add_option("--file-limit",
                      action="store", type="int", default=0,
                      dest="fileLimit",
                      help="Specify a limited number of files to scan (default: off)")
    parser.add_option("--progress",
                      action="store_true",
                      dest="progress",
                      default=False,
                      help="enable the progress bar")
    (options, args) = parser.parse_args()
    
    logger = logging.getLogger()

    if options.debug:
        # stdout is added by default, we'll capture this object here
        #lhStdout = logger.handlers[0]
        fileHandler = logging.FileHandler('laika-debug.log', 'w')
        formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
        fileHandler.setFormatter(formatter)
        logger.addHandler(fileHandler)
        # remove stdout from handlers so that debug info is only written to the file
        #logger.removeHandler(lhStdout)
        logging.basicConfig(level=logging.DEBUG)
        logger.setLevel(logging.DEBUG)

    global EXT_METADATA
    if options.ext_metadata:
        if os.path.exists(options.ext_metadata):
            with open(options.ext_metadata) as metafile:
                EXT_METADATA = json.loads(metafile.read())
        else:
            EXT_METADATA = json.loads(options.ext_metadata)
    else:
        EXT_METADATA = getConfig("ext_metadata")
    
    global EPHID
    if options.ephID:
        EPHID = options.ephID
    else:
        EPHID = getConfig("ephID")

    global SCAN_MODULES
    if options.scan_modules:
        SCAN_MODULES = options.scan_modules.split()
    else:
        SCAN_MODULES = None
    logging.debug("SCAN_MODULES: %s"  % (SCAN_MODULES))

    global PROGRESS_BAR
    if options.progress:
        PROGRESS_BAR = 1
    else:
        PROGRESS_BAR = strtobool(getConfig('progress_bar'))
    logging.debug("PROGRESS_BAR: %s"  % (PROGRESS_BAR))

    global LOG_RESULT
    if options.log_result:
        LOG_RESULT = 1
    else:
        LOG_RESULT = strtobool(getConfig('log_result'))
    logging.debug("LOG_RESULT: %s" % (LOG_RESULT))

    global LOG_JSON
    if options.log_json:
        LOG_JSON = options.log_json
    else:
        LOG_JSON = getConfig('log_json')

    global NUM_PROCS
    if options.num_procs:
        NUM_PROCS = options.num_procs
    else:
        NUM_PROCS = int(getConfig('num_procs'))
    logging.debug("NUM_PROCS: %s"  % (NUM_PROCS))

    global MAX_BYTES
    if options.sizeLimit:
        MAX_BYTES = options.sizeLimit * 1024 * 1024
    else:
        MAX_BYTES = int(getConfig('max_bytes'))
    logging.debug("MAX_BYTES: %s"  % (MAX_BYTES))

    global MAX_FILES
    if options.fileLimit:
        MAX_FILES = options.fileLimit
    else:
        MAX_FILES = int(getConfig('max_files'))
    logging.debug("MAX_FILES: %s"  % (MAX_FILES))

    global SOURCE
    if options.source:
        SOURCE = options.source
    else:
        SOURCE = getConfig('source')

    global SAVE_PATH
    if options.save_path:
        SAVE_PATH = options.save_path
    else:
        SAVE_PATH = getConfig('save_path')

    global CONFIG_PATH
    # Highest priority configuration is via argument
    if options.config_path:
        CONFIG_PATH = options.config_path
        logging.debug("using alternative config path: %s" % options.config_path)
        if not os.path.exists(options.config_path):
            error("the provided config path is not valid, exiting")
            return 1
    # Next, check to see if we're in the top level source directory (dev environment)
    elif os.path.exists(default_configs['dev_config_path']):
        CONFIG_PATH = default_configs['dev_config_path']
    # Next, check for an installed copy of the default configuration
    elif os.path.exists(default_configs['sys_config_path']):
        CONFIG_PATH = default_configs['sys_config_path']
    # Exit
    else:
        error('A valid framework configuration was not found in either of the following locations:\
\n%s\n%s' % (default_configs['dev_config_path'],default_configs['sys_config_path']))
        return 1
       

    # Check for stdin if no arguments were provided
    if len(args) == 0:

        DATA_PATH = []

        if not sys.stdin.isatty():
            while True:
                f = sys.stdin.readline().strip()
                if not f:
                    break
                else:
                    if not os.path.isfile(f):
                        error("One of the specified files does not exist: %s" % (f))
                        return 1
                    if os.path.isdir(f):
                        error("One of the files you specified is actually a directory: %s" % (f))
                        return 1
                    DATA_PATH.append(f)

        if not DATA_PATH:
            error("You must provide files via stdin when no arguments are provided")
            return 1
        logging.debug("Loaded %s files from stdin" % (len(DATA_PATH)))
    elif len(args) == 1:
        if os.path.isdir(args[0]):
            DATA_PATH = args[0]
        elif os.path.isfile(args[0]):
            DATA_PATH = [args[0]]
        else:
            error("File or directory does not exist: %s" % (args[0]))
            return 1
    else:
        for f in args:
            if not os.path.isfile(f):
                error("One of the specified files does not exist: %s" % (f))
                return 1
            if os.path.isdir(f):
                error("One of the files you specified is actually a directory: %s" % (f))
                return 1
        
        DATA_PATH = args

   
    tasks = multiprocessing.JoinableQueue()
    results = multiprocessing.Queue()
    
    fileList = []
    if type(DATA_PATH) is str:
        for root, dirs, files in os.walk(DATA_PATH):
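            # Hidden entries (leading dot) are skipped; assigning to
            # dirs[:] in place prunes hidden directories from the walk.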
            files = [f for f in files if not f[0] == '.']
            dirs[:] = [d for d in dirs if not d[0] == '.']
            for fname in files:
                fullpath = os.path.join(root, fname)
                if not os.path.islink(fullpath) and os.path.isfile(fullpath):
                    fileList.append(fullpath)
    else:
        fileList = DATA_PATH

    if MAX_FILES:
        fileList = fileList[:MAX_FILES]

    num_jobs = len(fileList)
    logging.debug("Loaded %s files for scanning" % (num_jobs))
    
    # Start consumers
    # If there are fewer files to process than processes, reduce the number of processes
    if num_jobs < NUM_PROCS:
        NUM_PROCS = num_jobs
    logging.debug("Starting %s processes" % (NUM_PROCS))
    consumers = [ Consumer(tasks, results)
                  for i in xrange(NUM_PROCS) ]
    try:
        
        for w in consumers:
            w.start()

        # Enqueue jobs
        for fname in fileList:
            tasks.put(fname)
        
        # Add a poison pill for each consumer
        for i in xrange(NUM_PROCS):
            tasks.put(None)

        if PROGRESS_BAR:
            monitor = QueueMonitor(tasks, num_jobs)
            monitor.start()

        # Wait for all of the tasks to finish
        tasks.join()
        if PROGRESS_BAR:
            monitor.join()

        while num_jobs:
            answer = zlib.decompress(results.get())
            print(answer)
            num_jobs -= 1

    except KeyboardInterrupt:
        error("Cancelled by user.. Shutting down.")
        for w in consumers:
            w.terminate()
            w.join()
        return None
    except:
        raise
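
The CONFIG_PATH selection above is a classic os.path.exists fallback chain: an explicit --config-path must exist or the program exits, otherwise the dev-checkout config is preferred over the installed system copy. A minimal sketch of the same precedence, assuming the two default locations are passed in:

import os

def resolve_config_path(cli_path, dev_path, sys_path):
    """Pick the first usable config: explicit CLI path (must exist),
    then the dev checkout copy, then the installed system copy."""
    if cli_path:
        if not os.path.exists(cli_path):
            raise SystemExit("the provided config path is not valid")
        return cli_path
    for candidate in (dev_path, sys_path):
        if os.path.exists(candidate):
            return candidate
    raise SystemExit("no valid framework configuration found in %s or %s"
                     % (dev_path, sys_path))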

Example 27

Project: bumpversion
Source File: __init__.py
View license
def main(original_args=None):

    positionals, args = split_args_in_optional_and_positional(
      sys.argv[1:] if original_args is None else original_args
    )

    if len(positionals[1:]) > 2:
        warnings.warn("Giving multiple files on the command line will be deprecated, please use [bumpversion:file:...] in a config file.", PendingDeprecationWarning)

    parser1 = argparse.ArgumentParser(add_help=False)

    parser1.add_argument(
        '--config-file', metavar='FILE',
        default=argparse.SUPPRESS, required=False,
        help='Config file to read most of the variables from (default: .bumpversion.cfg)')

    parser1.add_argument(
        '--verbose', action='count', default=0,
        help='Print verbose logging to stderr', required=False)

    parser1.add_argument(
        '--list', action='store_true', default=False,
        help='List machine readable information', required=False)

    parser1.add_argument(
        '--allow-dirty', action='store_true', default=False,
        help="Don't abort if working directory is dirty", required=False)

    known_args, remaining_argv = parser1.parse_known_args(args)

    logformatter = logging.Formatter('%(message)s')

    if len(logger.handlers) == 0:
        ch = logging.StreamHandler(sys.stderr)
        ch.setFormatter(logformatter)
        logger.addHandler(ch)

    if len(logger_list.handlers) == 0:
        ch2 = logging.StreamHandler(sys.stdout)
        ch2.setFormatter(logformatter)
        logger_list.addHandler(ch2)

    if known_args.list:
        logger_list.setLevel(1)

    log_level = {
        0: logging.WARNING,
        1: logging.INFO,
        2: logging.DEBUG,
    }.get(known_args.verbose, logging.DEBUG)

    logger.setLevel(log_level)

    logger.debug("Starting {}".format(DESCRIPTION))

    defaults = {}
    vcs_info = {}

    for vcs in VCS:
        if vcs.is_usable():
            vcs_info.update(vcs.latest_tag_info())

    if 'current_version' in vcs_info:
        defaults['current_version'] = vcs_info['current_version']

    config = RawConfigParser('')

    # don't transform keys to lowercase (which would be the default)
    config.optionxform = lambda option: option

    config.add_section('bumpversion')

    explicit_config = hasattr(known_args, 'config_file')

    if explicit_config:
        config_file = known_args.config_file
    elif not os.path.exists('.bumpversion.cfg') and \
            os.path.exists('setup.cfg'):
        config_file = 'setup.cfg'
    else:
        config_file = '.bumpversion.cfg'

    config_file_exists = os.path.exists(config_file)

    part_configs = {}

    files = []

    if config_file_exists:

        logger.info("Reading config file {}:".format(config_file))
        logger.info(io.open(config_file, 'rt', encoding='utf-8').read())

        config.readfp(io.open(config_file, 'rt', encoding='utf-8'))

        log_config = StringIO()
        config.write(log_config)

        if 'files' in dict(config.items("bumpversion")):
            warnings.warn(
                "'files =' configuration is will be deprecated, please use [bumpversion:file:...]",
                PendingDeprecationWarning
            )

        defaults.update(dict(config.items("bumpversion")))

        for listvaluename in ("serialize",):
            try:
                value = config.get("bumpversion", listvaluename)
                defaults[listvaluename] = list(filter(None, (x.strip() for x in value.splitlines())))
            except NoOptionError:
                pass  # no default value then ;)

        for boolvaluename in ("commit", "tag", "dry_run"):
            try:
                defaults[boolvaluename] = config.getboolean(
                    "bumpversion", boolvaluename)
            except NoOptionError:
                pass  # no default value then ;)

        for section_name in config.sections():

            section_name_match = re.compile("^bumpversion:(file|part):(.+)").match(section_name)

            if not section_name_match:
                continue

            section_prefix, section_value = section_name_match.groups()

            section_config = dict(config.items(section_name))

            if section_prefix == "part":

                ThisVersionPartConfiguration = NumericVersionPartConfiguration

                if 'values' in section_config:
                    section_config['values'] = list(filter(None, (x.strip() for x in section_config['values'].splitlines())))
                    ThisVersionPartConfiguration = ConfiguredVersionPartConfiguration

                part_configs[section_value] = ThisVersionPartConfiguration(**section_config)

            elif section_prefix == "file":

                filename = section_value

                if 'serialize' in section_config:
                    section_config['serialize'] = list(filter(None, (x.strip() for x in section_config['serialize'].splitlines())))

                section_config['part_configs'] = part_configs

                if not 'parse' in section_config:
                    section_config['parse'] = defaults.get("parse", '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)')

                if not 'serialize' in section_config:
                    section_config['serialize'] = defaults.get('serialize', [str('{major}.{minor}.{patch}')])

                if not 'search' in section_config:
                    section_config['search'] = defaults.get("search", '{current_version}')

                if not 'replace' in section_config:
                    section_config['replace'] = defaults.get("replace", '{new_version}')

                files.append(ConfiguredFile(filename, VersionConfig(**section_config)))

    else:
        message = "Could not read config file at {}".format(config_file)
        if explicit_config:
            raise argparse.ArgumentTypeError(message)
        else:
            logger.info(message)

    parser2 = argparse.ArgumentParser(prog='bumpversion', add_help=False, parents=[parser1])
    parser2.set_defaults(**defaults)

    parser2.add_argument('--current-version', metavar='VERSION',
                         help='Version that needs to be updated', required=False)
    parser2.add_argument('--parse', metavar='REGEX',
                         help='Regex parsing the version string',
                         default=defaults.get("parse", '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)'))
    parser2.add_argument('--serialize', metavar='FORMAT',
                         action=DiscardDefaultIfSpecifiedAppendAction,
                         help='How to format what is parsed back to a version',
                         default=defaults.get("serialize", [str('{major}.{minor}.{patch}')]))
    parser2.add_argument('--search', metavar='SEARCH',
                         help='Template for complete string to search',
                         default=defaults.get("search", '{current_version}'))
    parser2.add_argument('--replace', metavar='REPLACE',
                         help='Template for complete string to replace',
                         default=defaults.get("replace", '{new_version}'))

    known_args, remaining_argv = parser2.parse_known_args(args)

    defaults.update(vars(known_args))

    assert type(known_args.serialize) == list

    context = dict(list(time_context.items()) + list(prefixed_environ().items()) + list(vcs_info.items()))

    try:
        vc = VersionConfig(
            parse=known_args.parse,
            serialize=known_args.serialize,
            search=known_args.search,
            replace=known_args.replace,
            part_configs=part_configs,
        )
    except sre_constants.error as e:
        sys.exit(1)

    current_version = vc.parse(known_args.current_version) if known_args.current_version else None

    new_version = None

    if not 'new_version' in defaults and known_args.current_version:
        try:
            if current_version and len(positionals) > 0:
                logger.info("Attempting to increment part '{}'".format(positionals[0]))
                new_version = current_version.bump(positionals[0], vc.order())
                logger.info("Values are now: " + keyvaluestring(new_version._values))
                defaults['new_version'] = vc.serialize(new_version, context)
        except MissingValueForSerializationException as e:
            logger.info("Opportunistic finding of new_version failed: " + e.message)
        except IncompleteVersionRepresenationException as e:
            logger.info("Opportunistic finding of new_version failed: " + e.message)
        except KeyError as e:
            logger.info("Opportunistic finding of new_version failed")

    parser3 = argparse.ArgumentParser(
        prog='bumpversion',
        description=DESCRIPTION,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        conflict_handler='resolve',
        parents=[parser2],
    )

    parser3.set_defaults(**defaults)

    parser3.add_argument('--current-version', metavar='VERSION',
                         help='Version that needs to be updated',
                         required=not 'current_version' in defaults)
    parser3.add_argument('--dry-run', '-n', action='store_true',
                         default=False, help="Don't write any files, just pretend.")
    parser3.add_argument('--new-version', metavar='VERSION',
                         help='New version that should be in the files',
                         required=not 'new_version' in defaults)

    commitgroup = parser3.add_mutually_exclusive_group()

    commitgroup.add_argument('--commit', action='store_true', dest="commit",
                             help='Commit to version control', default=defaults.get("commit", False))
    commitgroup.add_argument('--no-commit', action='store_false', dest="commit",
                             help='Do not commit to version control', default=argparse.SUPPRESS)

    taggroup = parser3.add_mutually_exclusive_group()

    taggroup.add_argument('--tag', action='store_true', dest="tag", default=defaults.get("tag", False),
                          help='Create a tag in version control')
    taggroup.add_argument('--no-tag', action='store_false', dest="tag",
                          help='Do not create a tag in version control', default=argparse.SUPPRESS)

    parser3.add_argument('--tag-name', metavar='TAG_NAME',
                         help='Tag name (only works with --tag)',
                         default=defaults.get('tag_name', 'v{new_version}'))

    parser3.add_argument('--message', '-m', metavar='COMMIT_MSG',
                         help='Commit message',
                         default=defaults.get('message', 'Bump version: {current_version} → {new_version}'))


    file_names = []
    if 'files' in defaults:
        assert defaults['files'] != None
        file_names = defaults['files'].split(' ')

    parser3.add_argument('part',
                         help='Part of the version to be bumped.')
    parser3.add_argument('files', metavar='file',
                         nargs='*',
                         help='Files to change', default=file_names)

    args = parser3.parse_args(remaining_argv + positionals)

    if args.dry_run:
        logger.info("Dry run active, won't touch any files.")
    
    if args.new_version:
        new_version = vc.parse(args.new_version)

    logger.info("New version will be '{}'".format(args.new_version))

    file_names = file_names or positionals[1:]

    for file_name in file_names:
        files.append(ConfiguredFile(file_name, vc))

    for vcs in VCS:
        if vcs.is_usable():
            try:
                vcs.assert_nondirty()
            except WorkingDirectoryIsDirtyException as e:
                if not defaults['allow_dirty']:
                    logger.warn(
                        "{}\n\nUse --allow-dirty to override this if you know what you're doing.".format(e.message))
                    raise
            break
        else:
            vcs = None

    # make sure files exist and contain version string

    logger.info("Asserting files {} contain the version string:".format(", ".join([str(f) for f in files])))

    for f in files:
        f.should_contain_version(current_version, context)

    # change version string in files
    for f in files:
        f.replace(current_version, new_version, context, args.dry_run)

    commit_files = [f.path for f in files]

    config.set('bumpversion', 'new_version', args.new_version)

    for key, value in config.items('bumpversion'):
        logger_list.info("{}={}".format(key, value))

    config.remove_option('bumpversion', 'new_version')

    config.set('bumpversion', 'current_version', args.new_version)

    new_config = StringIO()

    try:
        write_to_config_file = (not args.dry_run) and config_file_exists

        logger.info("{} to config file {}:".format(
            "Would write" if not write_to_config_file else "Writing",
            config_file,
        ))

        config.write(new_config)
        logger.info(new_config.getvalue())

        if write_to_config_file:
            with io.open(config_file, 'wb') as f:
                f.write(new_config.getvalue().encode('utf-8'))

    except UnicodeEncodeError:
        warnings.warn(
            "Unable to write UTF-8 to config file, because of an old configparser version. "
            "Update with `pip install --upgrade configparser`."
        )

    if config_file_exists:
        commit_files.append(config_file)

    if not vcs:
        return

    assert vcs.is_usable(), "Did find '{}' unusable, unable to commit.".format(vcs.__name__)

    do_commit = (not args.dry_run) and args.commit
    do_tag = (not args.dry_run) and args.tag

    logger.info("{} {} commit".format(
        "Would prepare" if not do_commit else "Preparing",
        vcs.__name__,
    ))

    for path in commit_files:
        logger.info("{} changes in file '{}' to {}".format(
            "Would add" if not do_commit else "Adding",
            path,
            vcs.__name__,
        ))

        if do_commit:
            vcs.add_path(path)

    vcs_context = {
        "current_version": args.current_version,
        "new_version": args.new_version,
    }
    vcs_context.update(time_context)
    vcs_context.update(prefixed_environ())

    commit_message = args.message.format(**vcs_context)

    logger.info("{} to {} with message '{}'".format(
        "Would commit" if not do_commit else "Committing",
        vcs.__name__,
        commit_message,
    ))

    if do_commit:
        vcs.commit(message=commit_message)

    tag_name = args.tag_name.format(**vcs_context)
    logger.info("{} '{}' in {}".format(
        "Would tag" if not do_tag else "Tagging",
        tag_name,
        vcs.__name__
    ))

    if do_tag:
        vcs.tag(tag_name)
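
The config write near the end of this example shows a double gate built on os.path.exists: the file is rewritten only when it already existed on disk (the config_file_exists flag computed earlier in the function) and the run is not a dry run. A minimal sketch of the same gate, assuming a pre-rendered config string; write_config and config_text are hypothetical names, not bumpversion's:

import io
import os

def write_config(config_text, config_file, dry_run=False):
    # Persist the rewritten configuration only when the file already
    # existed and this is not a dry run; otherwise just report what
    # would have happened (hypothetical helper, not bumpversion's API).
    config_file_exists = os.path.exists(config_file)
    will_write = (not dry_run) and config_file_exists

    print("{} to config file {}".format(
        "Would write" if not will_write else "Writing", config_file))

    if will_write:
        with io.open(config_file, 'wb') as f:
            f.write(config_text.encode('utf-8'))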

Example 28

Project: laikaboss
Source File: laikad.py
View license
def main():
    '''Main program logic. Becomes the supervisor process.'''
    parser = OptionParser(usage="usage: %prog [options]\n"
        "Default settings in config file: laikad.conf")

    parser.add_option("-d", "--debug",
                      action="store_true", default=False,
                      dest="debug",
                      help="enable debug messages to the console.")
    parser.add_option("-s", "--scan-config",
                      action="store", type="string",
                      dest="laikaboss_config_path",
                      help="specify a path for laikaboss configuration")
    parser.add_option("-c", "--laikad-config",
                      action="store", type="string",
                      dest="laikad_config_path",
                      help="specify a path for laikad configuration")
    parser.add_option("-b", "--broker-backend",
                      action="store", type="string",
                      dest="broker_backend_address",
                      help="specify an address for the workers to connect to. "
                      "ex: tcp://*:5559")
    parser.add_option("-f", "--broker-frontend",
                      action="store", type="string",
                      dest="broker_frontend_address",
                      help="specify an address for clients to connect to. ex: "
                      "tcp://*:5558")
    parser.add_option("-w", "--worker-connect",
                      action="store", type="string",
                      dest="worker_connect_address",
                      help="specify an address for clients to connect to. ex: "
                      "tcp://localhost:5559")
    parser.add_option("-n", "--no-broker",
                      action="store_true", default=False,
                      dest="no_broker",
                      help="specify this option to disable the broker for this "
                      "instance.")
    parser.add_option("-i", "--id",
                      action="store", type="string",
                      dest="runas_uid",
                      help="specify a valid username to switch to after starting "
                      "as root.")
    parser.add_option("-p", "--processes",
                      action="store", type="int",
                      dest="num_procs",
                      help="specify the number of workers to launch with this "
                      "daemon")
    parser.add_option("-r", "--restart-after",
                      action="store", type="int",
                      dest="ttl",
                      help="restart worker after scanning this many items")
    parser.add_option("-t", "--restart-after-min",
                      action="store", type="int",
                      dest="time_ttl",
                      help="restart worker after scanning for this many "
                      "minutes.")
    parser.add_option("-a", "--async",
                      action="store_true", default=False,
                      dest="run_async",
                      help="enable async messages. "
                      "This will disable any responses back to the client.")
    parser.add_option("-g", "--grace-timeout",
                      action="store", type="int",
                      dest="gracetimeout",
                      help="when shutting down, the timeout to allow workers to"
                      " finish ongoing scans before being killed")
    (options, _) = parser.parse_args()

    # Set the configuration file path for laikad
    config_location = '/etc/laikaboss/laikad.conf'
    if options.laikad_config_path:
        config_location = options.laikad_config_path
        if not os.path.exists(options.laikad_config_path):
            print "the provided config path is not valid, exiting"
            return 1
    # Next, check to see if we're in the top level source directory (dev environment)
    elif os.path.exists(DEFAULT_CONFIGS['laikad_dev_config_path']):
        config_location = DEFAULT_CONFIGS['laikad_dev_config_path']
    # Next, check for an installed copy of the default configuration
    elif os.path.exists(DEFAULT_CONFIGS['laikad_sys_config_path']):
        config_location = DEFAULT_CONFIGS['laikad_sys_config_path']
    # Exit
    else:
        print 'A valid laikad configuration was not found in either of the following locations:\
\n%s\n%s' % (DEFAULT_CONFIGS['laikad_dev_config_path'],DEFAULT_CONFIGS['laikad_sys_config_path'])
        return 1
    
    # Read the laikad config file
    config_parser = ConfigParser()
    config_parser.read(config_location)

    # Parse through the config file and append each section to a single dict
    for section in config_parser.sections():
        CONFIGS.update(dict(config_parser.items(section)))

    # We need a default framework config at a minimum
    if options.laikaboss_config_path:
        laikaboss_config_path = options.laikaboss_config_path
        logging.debug("using alternative config path: %s" % options.laikaboss_config_path)
        if not os.path.exists(options.laikaboss_config_path):
            print "the provided config path is not valid, exiting"
            return 1
    #Next, check for a config path in the laikad config
    elif os.path.exists(get_option('configpath')):
        laikaboss_config_path = get_option('configpath')
    # Next, check to see if we're in the top level source directory (dev environment)
    elif os.path.exists(DEFAULT_CONFIGS['dev_config_path']):
        laikaboss_config_path = DEFAULT_CONFIGS['dev_config_path']
    # Next, check for an installed copy of the default configuration
    elif os.path.exists(DEFAULT_CONFIGS['sys_config_path']):
        laikaboss_config_path = DEFAULT_CONFIGS['sys_config_path']
    # Exit
    else:
        print 'A valid framework configuration was not found in either of the following locations:\
\n%s\n%s' % (DEFAULT_CONFIGS['dev_config_path'],DEFAULT_CONFIGS['sys_config_path'])
        return 1

    if options.num_procs:
        num_procs = options.num_procs
    else:
        num_procs = int(get_option('numprocs'))

    if options.ttl:
        ttl = options.ttl
    else:
        ttl = int(get_option('ttl'))

    if options.time_ttl:
        time_ttl = options.time_ttl
    else:
        time_ttl = int(get_option('time_ttl'))

    if options.broker_backend_address:
        broker_backend_address = options.broker_backend_address
    else:
        broker_backend_address = get_option('brokerbackend')

    if options.broker_frontend_address:
        broker_frontend_address = options.broker_frontend_address
    else:
        broker_frontend_address = get_option('brokerfrontend')

    if options.worker_connect_address:
        worker_connect_address = options.worker_connect_address
    else:
        worker_connect_address = get_option('workerconnect')

    if options.gracetimeout:
        gracetimeout = options.gracetimeout
    else:
        gracetimeout = int(get_option('gracetimeout'))

    if options.run_async:
        async = True
    else:
        async = strtobool(get_option('async'))
   
    logresult = strtobool(get_option('log_result'))

    # Get the UserID to run as, if it was not specified on the command line
    # we'll use the current user by default
    runas_uid = None
    runas_gid = None

    if options.runas_uid:
        from pwd import getpwnam
        runas_uid = getpwnam(options.runas_uid).pw_uid
        runas_gid = getpwnam(options.runas_uid).pw_gid

    if options.debug:
        logging.basicConfig(level=logging.DEBUG)

    # Lower privileges if a UID has been set
    try:
        if runas_uid:
            os.setgid(runas_gid)
            os.setuid(runas_uid)
    except OSError:
        print "Unable to set user ID to %i, defaulting to current user" % runas_uid

    # Add intercept for graceful shutdown
    def shutdown(signum, frame):
        '''Signal handler for shutting down supervisor gracefully'''
        logging.debug("Supervisor: shutdown handler triggered")
        global KEEP_RUNNING
        KEEP_RUNNING = False
    signal.signal(signal.SIGTERM, shutdown)
    signal.signal(signal.SIGINT, shutdown)

    # Start the broker
    broker_proc = None
    if not options.no_broker:
        if async:
            broker_proc = AsyncBroker(broker_backend_address, broker_frontend_address)
        else:
            broker_proc = SyncBroker(broker_backend_address, broker_frontend_address, gracetimeout)
        broker_proc.start()

    # Start the workers
    workers = []
    for _ in range(num_procs):
        worker_proc = Worker(laikaboss_config_path, worker_connect_address, ttl,
            time_ttl, logresult, int(get_option('workerpolltimeout')), gracetimeout)
        worker_proc.start()
        workers.append(worker_proc)

    while KEEP_RUNNING:
        # Ensure we have a broker
        if not options.no_broker and not broker_proc.is_alive():
            if async:
                broker_proc = AsyncBroker(broker_backend_address, broker_frontend_address)
            else:
                broker_proc = SyncBroker(broker_backend_address, broker_frontend_address,
                    gracetimeout)
            broker_proc.start()

        # Ensure we have living workers
        dead_workers = []
        for worker_proc in workers:
            if not worker_proc.is_alive():
                dead_workers.append(worker_proc)

        for worker_proc in dead_workers:
            workers.remove(worker_proc)
            new_proc = Worker(laikaboss_config_path, worker_connect_address, ttl, time_ttl,
                logresult, int(get_option('workerpolltimeout')), gracetimeout)
            new_proc.start()
            workers.append(new_proc)
            worker_proc.join()

        # Wait a little bit
        time.sleep(5)

    logging.debug("Supervisor: beginning graceful shutdown sequence")
    logging.info("Supervisor: giving workers %d second grace period", gracetimeout)
    time.sleep(gracetimeout)
    logging.info("Supervisor: terminating workers")
    for worker_proc in workers:
        if worker_proc.is_alive():
            os.kill(worker_proc.pid, signal.SIGKILL)
    for worker_proc in workers:
        worker_proc.join()
    if not options.no_broker:
        if broker_proc.is_alive():
            os.kill(broker_proc.pid, signal.SIGKILL)
        broker_proc.join()
    logging.debug("Supervisor: finished")

Example 29

Project: tracpy
Source File: plotting.py
View license
def hist(lonp, latp, fname, tind='final', which='contour', vmax=None,
         fig=None, ax=None, bins=(40, 40), N=10, grid=None, xlims=None,
         ylims=None, C=None, Title=None, weights=None,
         Label='Final drifter location (%)', isll=True, binscale=None):
    """
    Plot histogram of given track data at time index tind.

    Args:
        lonp,latp: Drifter track positions in lon/lat [time x ndrifters]
        fname: Plot name to save
        tind (Optional): Default is 'final', in which case the final
         position of each drifter in the array is found and plotted.
         Alternatively, a time index can be input and drifters at that time
         will be plotted. Note that once drifters hit the outer numerical
         boundary, they are nan'ed out so this may miss some drifters.
        which (Optional[str]): 'contour', 'pcolor', 'hexbin', 'hist2d' for
         type of plot used. Default 'contour'.
        bins (Optional): Number of bins used in histogram. Default (40, 40).
        N (Optional[int]): Number of contours to make. Default 10.
        grid (Optional): grid as read in by inout.readgrid()
        xlims (Optional): value limits on the x axis
        ylims (Optional): value limits on the y axis
        isll: Default True. Inputs are in lon/lat. If False, assume they
         are in projected coords.

    Note:
        Currently assuming we are plotting the final location of each drifter
        regardless of tind.
    """

    if grid is None:
        loc = 'http://barataria.tamu.edu:8080/thredds/dodsC/NcML/txla_nesting6.nc'
        grid = inout.readgrid(loc)

    if isll:  # if inputs are in lon/lat, change to projected x/y
        # Change positions from lon/lat to x/y
        xp, yp = grid.proj(lonp, latp)
        # Need to retain nan's since basemap changes them to values
        ind = np.isnan(lonp)
        xp[ind] = np.nan
        yp[ind] = np.nan
    else:
        xp = lonp
        yp = latp

    if fig is None:
        fig = plt.figure(figsize=(11, 10))
    background(grid)  # Plot coastline and such

    if tind == 'final':
        # Find final positions of drifters
        xpc, ypc = tools.find_final(xp, yp)
    elif isinstance(tind, int):
        xpc = xp[:, tind]
        ypc = yp[:, tind]
    else:  # just plot what is input if some other string
        xpc = xp.flatten()
        ypc = yp.flatten()

    if which == 'contour':

        # Info for 2d histogram
        H, xedges, yedges = np.histogram2d(xpc, ypc,
                                           range=[[grid.x_rho.min(),
                                                   grid.x_rho.max()],
                                                  [grid.y_rho.min(),
                                                   grid.y_rho.max()]],
                                           bins=bins)

        # Contour Plot
        XE, YE = np.meshgrid(op.resize(xedges, 0), op.resize(yedges, 0))
        d = (H/H.sum())*100
        # # from http://matplotlib.1069221.n5.nabble.com/question-about-contours-and-clim-td21111.html
        # locator = ticker.MaxNLocator(50) # if you want no more than 10 contours
        # locator.create_dummy_axis()
        # locator.set_bounds(0,1)#d.min(),d.max())
        # levs = locator()
        con = plt.contourf(XE, YE, d.T, N)  # ,levels=levs)#(0,15,30,45,60,75,90,105,120))
        con.set_cmap('YlOrRd')

        if Title is not None:
            plt.title(Title)

        # Horizontal colorbar below plot
        cax = fig.add_axes([0.3725, 0.25, 0.48, 0.02])  # colorbar axes
        cb = fig.colorbar(con, cax=cax, orientation='horizontal')
        cb.set_label('Final drifter location (percent)')

        # Save figure into a local directory called figures. Make directory
        # if it doesn't exist.
        if not os.path.exists('figures'):
            os.makedirs('figures')

        fig.savefig('figures/' + fname + 'histcon.png', bbox_inches='tight')

    elif which == 'pcolor':

        # Info for 2d histogram
        H, xedges, yedges = np.histogram2d(xpc, ypc,
                                           range=[[grid.x_rho.min(),
                                                   grid.x_rho.max()],
                                                  [grid.y_rho.min(),
                                                   grid.y_rho.max()]],
                                           bins=bins, weights=weights)

        # Pcolor plot
        # C is the z value plotted, and is normalized by the total number of
        # drifters
        if C is None:
            C = (H.T/H.sum())*100
        else:
            # or, provide some other weighting
            C = (H.T/C)*100

        p = plt.pcolor(xedges, yedges, C, cmap='YlOrRd')

        if Title is not None:
            plt.title(Title)

        # Set x and y limits
        if xlims is not None:
            plt.xlim(xlims)
        if ylims is not None:
            plt.ylim(ylims)

        # Horizontal colorbar below plot
        cax = fig.add_axes([0.3775, 0.25, 0.48, 0.02])  # colorbar axes
        cb = fig.colorbar(p, cax=cax, orientation='horizontal')
        cb.set_label('Final drifter location (percent)')

        # Save figure into a local directory called figures. Make directory
        # if it doesn't exist.
        if not os.path.exists('figures'):
            os.makedirs('figures')

        fig.savefig('figures/' + fname + 'histpcolor.png', bbox_inches='tight')
        # savefig('figures/' + fname + 'histpcolor.pdf',bbox_inches='tight')

    elif which == 'hexbin':

        if ax is None:
            ax = plt.gca()

        if C is None:
            # C with the reduce_C_function as sum is what makes it a percent
            C = np.ones(len(xpc))*(1./len(xpc))*100
        else:
            C = C*np.ones(len(xpc))*100
        hb = plt.hexbin(xpc, ypc, C=C, cmap='YlOrRd', gridsize=bins[0],
                    extent=(grid.x_psi.min(), grid.x_psi.max(),
                            grid.y_psi.min(), grid.y_psi.max()),
                    reduce_C_function=sum, vmax=vmax, axes=ax, bins=binscale)

        # Set x and y limits
        if xlims is not None:
            plt.xlim(xlims)
        if ylims is not None:
            plt.ylim(ylims)

        if Title is not None:
            ax.set_title(Title)

        # Want colorbar at the given location relative to axis so this works
        # regardless of # of subplots, so convert from axis to figure
        # coordinates. To do this, first convert from axis to display coords
        # transformations:
        # http://matplotlib.org/users/transforms_tutorial.html
        # axis: [x_left, y_bottom, width, height]
        ax_coords = [0.35, 0.25, 0.6, 0.02]
        # display: [x_left,y_bottom,x_right,y_top]
        disp_coords = ax.transAxes.transform([(ax_coords[0], ax_coords[1]),
                                              (ax_coords[0]+ax_coords[2],
                                               ax_coords[1]+ax_coords[3])])
        # inverter object to go from display coords to figure coords
        inv = fig.transFigure.inverted()
        # figure: [x_left,y_bottom,x_right,y_top]
        fig_coords = inv.transform(disp_coords)
        # actual desired figure coords. figure:
        # [x_left, y_bottom, width, height]
        fig_coords = [fig_coords[0, 0], fig_coords[0, 1], fig_coords[1, 0] -
                      fig_coords[0, 0], fig_coords[1, 1] - fig_coords[0, 1]]
        # Inlaid colorbar
        cax = fig.add_axes(fig_coords)

        # # Horizontal colorbar below plot
        # cax = fig.add_axes([0.3775, 0.25, 0.48, 0.02]) # colorbar axes
        cb = fig.colorbar(hb, cax=cax, orientation='horizontal')
        cb.set_label(Label)

        # Save figure into a local directory called figures. Make directory
        # if it doesn't exist.
        if not os.path.exists('figures'):
            os.makedirs('figures')

        fig.savefig('figures/' + fname + 'histhexbin.png', bbox_inches='tight')
        # savefig('figures/' + fname + 'histhexbin.pdf',bbox_inches='tight')

    elif which == 'hist2d':

        _, xedges, yedges, im = plt.hist2d(
            xpc, ypc, bins=40,
            range=[[grid.x_rho.min(), grid.x_rho.max()],
                   [grid.y_rho.min(), grid.y_rho.max()]],
            normed=True)
        plt.set_cmap('YlOrRd')
        # Set x and y limits
        if xlims is not None:
            plt.xlim(xlims)
        if ylims is not None:
            plt.ylim(ylims)

        # Horizontal colorbar below plot
        cax = fig.add_axes([0.3775, 0.25, 0.48, 0.02])  # colorbar axes
        cb = fig.colorbar(im, cax=cax, orientation='horizontal')
        cb.set_label('Final drifter location (percent)')

        # Save figure into a local directory called figures. Make directory
        # if it doesn't exist.
        if not os.path.exists('figures'):
            os.makedirs('figures')

        fig.savefig('figures/' + fname + 'hist2d.png', bbox_inches='tight')
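
Every branch of this example ends with the same exists-then-makedirs idiom before saving the figure. The check and the mkdir are not atomic, so a directory created by another process in between still raises OSError; a small race-tolerant sketch of the same idiom (on Python 3, os.makedirs(path, exist_ok=True) is equivalent):

import errno
import os

def ensure_dir(path):
    # Same effect as "if not os.path.exists(path): os.makedirs(path)",
    # but safe if the directory appears between the check and the call.
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

ensure_dir('figures')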

Example 31

Project: SublimePythonIDE
Source File: runmod.py
View license
def __rope_start_everything():
    import os
    import sys
    import socket
    import cPickle as pickle
    import marshal
    import inspect
    import types
    import threading

    class _MessageSender(object):

        def send_data(self, data):
            pass

    class _SocketSender(_MessageSender):

        def __init__(self, port):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect(('127.0.0.1', port))
            self.my_file = s.makefile('w')

        def send_data(self, data):
            if not self.my_file.closed:
                pickle.dump(data, self.my_file)

        def close(self):
            self.my_file.close()

    class _FileSender(_MessageSender):

        def __init__(self, file_name):
            self.my_file = open(file_name, 'wb')

        def send_data(self, data):
            if not self.my_file.closed:
                marshal.dump(data, self.my_file)

        def close(self):
            self.my_file.close()


    def _cached(func):
        cache = {}
        def newfunc(self, arg):
            if arg in cache:
                return cache[arg]
            result = func(self, arg)
            cache[arg] = result
            return result
        return newfunc

    class _FunctionCallDataSender(object):

        def __init__(self, send_info, project_root):
            self.project_root = project_root
            if send_info.isdigit():
                self.sender = _SocketSender(int(send_info))
            else:
                self.sender = _FileSender(send_info)

            def global_trace(frame, event, arg):
                # HACK: Ignoring out->in calls
                # This might lose some information
                if self._is_an_interesting_call(frame):
                    return self.on_function_call
            sys.settrace(global_trace)
            threading.settrace(global_trace)

        def on_function_call(self, frame, event, arg):
            if event != 'return':
                return
            args = []
            returned = ('unknown',)
            code = frame.f_code
            for argname in code.co_varnames[:code.co_argcount]:
                try:
                    args.append(self._object_to_persisted_form(frame.f_locals[argname]))
                except (TypeError, AttributeError):
                    args.append(('unknown',))
            try:
                returned = self._object_to_persisted_form(arg)
            except (TypeError, AttributeError):
                pass
            try:
                data = (self._object_to_persisted_form(frame.f_code),
                        tuple(args), returned)
                self.sender.send_data(data)
            except (TypeError):
                pass
            return self.on_function_call

        def _is_an_interesting_call(self, frame):
            #if frame.f_code.co_name in ['?', '<module>']:
            #    return False
            #return not frame.f_back or not self._is_code_inside_project(frame.f_back.f_code)

            if not self._is_code_inside_project(frame.f_code) and \
               (not frame.f_back or not self._is_code_inside_project(frame.f_back.f_code)):
                return False
            return True

        def _is_code_inside_project(self, code):
            source = self._path(code.co_filename)
            return source is not None and os.path.exists(source) and \
                   _realpath(source).startswith(self.project_root)

        @_cached
        def _get_persisted_code(self, object_):
            source = self._path(object_.co_filename)
            if not os.path.exists(source):
                raise TypeError('no source')
            return ('defined', _realpath(source), str(object_.co_firstlineno))

        @_cached
        def _get_persisted_class(self, object_):
            try:
                return ('defined', _realpath(inspect.getsourcefile(object_)),
                        object_.__name__)
            except (TypeError, AttributeError):
                return ('unknown',)

        def _get_persisted_builtin(self, object_):
            if isinstance(object_, (str, unicode)):
                return ('builtin', 'str')
            if isinstance(object_, list):
                holding = None
                if len(object_) > 0:
                    holding = object_[0]
                return ('builtin', 'list', self._object_to_persisted_form(holding))
            if isinstance(object_, dict):
                keys = None
                values = None
                if len(object_) > 0:
                    keys = object_.keys()[0]
                    values = object_[keys]
                return ('builtin', 'dict',
                        self._object_to_persisted_form(keys),
                        self._object_to_persisted_form(values))
            if isinstance(object_, tuple):
                objects = []
                if len(object_) < 3:
                    for holding in object_:
                        objects.append(self._object_to_persisted_form(holding))
                else:
                    objects.append(self._object_to_persisted_form(object_[0]))
                return tuple(['builtin', 'tuple'] + objects)
            if isinstance(object_, set):
                holding = None
                if len(object_) > 0:
                    for o in object_:
                        holding = o
                        break
                return ('builtin', 'set', self._object_to_persisted_form(holding))
            return ('unknown',)

        def _object_to_persisted_form(self, object_):
            if object_ is None:
                return ('none',)
            if isinstance(object_, types.CodeType):
                return self._get_persisted_code(object_)
            if isinstance(object_, types.FunctionType):
                return self._get_persisted_code(object_.func_code)
            if isinstance(object_, types.MethodType):
                return self._get_persisted_code(object_.im_func.func_code)
            if isinstance(object_, types.ModuleType):
                return self._get_persisted_module(object_)
            if isinstance(object_, (str, unicode, list, dict, tuple, set)):
                return self._get_persisted_builtin(object_)
            if isinstance(object_, (types.TypeType, types.ClassType)):
                return self._get_persisted_class(object_)
            return ('instance', self._get_persisted_class(type(object_)))

        @_cached
        def _get_persisted_module(self, object_):
            path = self._path(object_.__file__)
            if path and os.path.exists(path):
                return ('defined', _realpath(path))
            return ('unknown',)

        def _path(self, path):
            if path.endswith('.pyc'):
                path = path[:-1]
            if path.endswith('.py'):
                return path

        def close(self):
            self.sender.close()
            sys.settrace(None)

    def _realpath(path):
        return os.path.realpath(os.path.abspath(os.path.expanduser(path)))

    send_info = sys.argv[1]
    project_root = sys.argv[2]
    file_to_run = sys.argv[3]
    run_globals = globals()
    run_globals.update({'__name__': '__main__',
                        '__builtins__': __builtins__,
                        '__file__': file_to_run})
    if send_info != '-':
        data_sender = _FunctionCallDataSender(send_info, project_root)
    del sys.argv[1:4]
    execfile(file_to_run, run_globals)
    if send_info != '-':
        data_sender.close()
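
All of the exists checks in this example funnel through _path, which maps a compiled .pyc filename back to its .py source before os.path.exists is consulted; a code object is only reported as 'defined' when that source file is really on disk. A standalone sketch of those two steps; source_path and has_source are illustrative names, not rope's:

import os

def source_path(filename):
    # Map a compiled module filename back to its .py source, mirroring
    # _path above; anything that is not Python source yields None.
    if filename.endswith('.pyc'):
        filename = filename[:-1]
    if filename.endswith('.py'):
        return filename
    return None

def has_source(filename):
    # True only when the mapped .py source file really exists on disk.
    path = source_path(filename)
    return path is not None and os.path.exists(path)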

Example 32

Project: SublimeRope
Source File: runmod.py
View license
def __rope_start_everything():
    import os
    import sys
    import socket
    import cPickle as pickle
    import marshal
    import inspect
    import types
    import threading

    class _MessageSender(object):

        def send_data(self, data):
            pass

    class _SocketSender(_MessageSender):

        def __init__(self, port):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect(('127.0.0.1', port))
            self.my_file = s.makefile('w')

        def send_data(self, data):
            if not self.my_file.closed:
                pickle.dump(data, self.my_file)

        def close(self):
            self.my_file.close()

    class _FileSender(_MessageSender):

        def __init__(self, file_name):
            self.my_file = open(file_name, 'wb')

        def send_data(self, data):
            if not self.my_file.closed:
                marshal.dump(data, self.my_file)

        def close(self):
            self.my_file.close()


    def _cached(func):
        cache = {}
        def newfunc(self, arg):
            if arg in cache:
                return cache[arg]
            result = func(self, arg)
            cache[arg] = result
            return result
        return newfunc

    class _FunctionCallDataSender(object):

        def __init__(self, send_info, project_root):
            self.project_root = project_root
            if send_info.isdigit():
                self.sender = _SocketSender(int(send_info))
            else:
                self.sender = _FileSender(send_info)

            def global_trace(frame, event, arg):
                # HACK: Ignoring out->in calls
                # This might lose some information
                if self._is_an_interesting_call(frame):
                    return self.on_function_call
            sys.settrace(global_trace)
            threading.settrace(global_trace)

        def on_function_call(self, frame, event, arg):
            if event != 'return':
                return
            args = []
            returned = ('unknown',)
            code = frame.f_code
            for argname in code.co_varnames[:code.co_argcount]:
                try:
                    args.append(self._object_to_persisted_form(frame.f_locals[argname]))
                except (TypeError, AttributeError):
                    args.append(('unknown',))
            try:
                returned = self._object_to_persisted_form(arg)
            except (TypeError, AttributeError):
                pass
            try:
                data = (self._object_to_persisted_form(frame.f_code),
                        tuple(args), returned)
                self.sender.send_data(data)
            except (TypeError):
                pass
            return self.on_function_call

        def _is_an_interesting_call(self, frame):
            #if frame.f_code.co_name in ['?', '<module>']:
            #    return False
            #return not frame.f_back or not self._is_code_inside_project(frame.f_back.f_code)

            if not self._is_code_inside_project(frame.f_code) and \
               (not frame.f_back or not self._is_code_inside_project(frame.f_back.f_code)):
                return False
            return True

        def _is_code_inside_project(self, code):
            source = self._path(code.co_filename)
            return source is not None and os.path.exists(source) and \
                   _realpath(source).startswith(self.project_root)

        @_cached
        def _get_persisted_code(self, object_):
            source = self._path(object_.co_filename)
            if not os.path.exists(source):
                raise TypeError('no source')
            return ('defined', _realpath(source), str(object_.co_firstlineno))

        @_cached
        def _get_persisted_class(self, object_):
            try:
                return ('defined', _realpath(inspect.getsourcefile(object_)),
                        object_.__name__)
            except (TypeError, AttributeError):
                return ('unknown',)

        def _get_persisted_builtin(self, object_):
            if isinstance(object_, (str, unicode)):
                return ('builtin', 'str')
            if isinstance(object_, list):
                holding = None
                if len(object_) > 0:
                    holding = object_[0]
                return ('builtin', 'list', self._object_to_persisted_form(holding))
            if isinstance(object_, dict):
                keys = None
                values = None
                if len(object_) > 0:
                    keys = object_.keys()[0]
                    values = object_[keys]
                return ('builtin', 'dict',
                        self._object_to_persisted_form(keys),
                        self._object_to_persisted_form(values))
            if isinstance(object_, tuple):
                objects = []
                if len(object_) < 3:
                    for holding in object_:
                        objects.append(self._object_to_persisted_form(holding))
                else:
                    objects.append(self._object_to_persisted_form(object_[0]))
                return tuple(['builtin', 'tuple'] + objects)
            if isinstance(object_, set):
                holding = None
                if len(object_) > 0:
                    for o in object_:
                        holding = o
                        break
                return ('builtin', 'set', self._object_to_persisted_form(holding))
            return ('unknown',)

        def _object_to_persisted_form(self, object_):
            if object_ is None:
                return ('none',)
            if isinstance(object_, types.CodeType):
                return self._get_persisted_code(object_)
            if isinstance(object_, types.FunctionType):
                return self._get_persisted_code(object_.func_code)
            if isinstance(object_, types.MethodType):
                return self._get_persisted_code(object_.im_func.func_code)
            if isinstance(object_, types.ModuleType):
                return self._get_persisted_module(object_)
            if isinstance(object_, (str, unicode, list, dict, tuple, set)):
                return self._get_persisted_builtin(object_)
            if isinstance(object_, (types.TypeType, types.ClassType)):
                return self._get_persisted_class(object_)
            return ('instance', self._get_persisted_class(type(object_)))

        @_cached
        def _get_persisted_module(self, object_):
            path = self._path(object_.__file__)
            if path and os.path.exists(path):
                return ('defined', _realpath(path))
            return ('unknown',)

        def _path(self, path):
            if path.endswith('.pyc'):
                path = path[:-1]
            if path.endswith('.py'):
                return path

        def close(self):
            self.sender.close()
            sys.settrace(None)

    def _realpath(path):
        return os.path.realpath(os.path.abspath(os.path.expanduser(path)))

    send_info = sys.argv[1]
    project_root = sys.argv[2]
    file_to_run = sys.argv[3]
    run_globals = globals()
    run_globals.update({'__name__': '__main__',
                        '__builtins__': __builtins__,
                        '__file__': file_to_run})
    if send_info != '-':
        data_sender = _FunctionCallDataSender(send_info, project_root)
    del sys.argv[1:4]
    execfile(file_to_run, run_globals)
    if send_info != '-':
        data_sender.close()

Example 33

Project: tp-libvirt
Source File: libvirt_rng.py
View license
def run(test, params, env):
    """
    Test rng device options.

    1.Prepare test environment, destroy or suspend a VM.
    2.Edit xml and start the domain.
    3.Perform test operation.
    4.Recover test environment.
    5.Confirm the test result.
    """
    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)

    def modify_rng_xml(dparams, sync=True):
        """
        Modify interface xml options
        """
        rng_model = dparams.get("rng_model", "virtio")
        rng_rate = dparams.get("rng_rate")
        backend_model = dparams.get("backend_model", "random")
        backend_type = dparams.get("backend_type")
        backend_dev = dparams.get("backend_dev", "")
        backend_source_list = dparams.get("backend_source",
                                          "").split()
        backend_protocol = dparams.get("backend_protocol")
        vmxml = vm_xml.VMXML.new_from_dumpxml(vm_name)
        rng_xml = rng.Rng()
        rng_xml.rng_model = rng_model
        if rng_rate:
            rng_xml.rate = ast.literal_eval(rng_rate)
        backend = rng.Rng.Backend()
        backend.backend_model = backend_model
        if backend_type:
            backend.backend_type = backend_type
        if backend_dev:
            backend.backend_dev = backend_dev
        if backend_source_list:
            source_list = [ast.literal_eval(source) for source in
                           backend_source_list]
            backend.source = source_list
        if backend_protocol:
            backend.backend_protocol = backend_protocol
        rng_xml.backend = backend

        logging.debug("Rng xml: %s", rng_xml)
        if sync:
            vmxml.add_device(rng_xml)
            vmxml.xmltreefile.write()
            vmxml.sync()
        else:
            status = libvirt.exec_virsh_edit(
                vm_name, [(r":/<devices>/s/$/%s" %
                           re.findall(r"<rng.*<\/rng>",
                                      str(rng_xml), re.M
                                      )[0].replace("/", "\/"))])
            if not status:
                raise error.TestFail("Failed to edit vm xml")

    def check_qemu_cmd(dparams):
        """
        Verify qemu-kvm command line.
        """
        rng_model = dparams.get("rng_model", "virtio")
        rng_rate = dparams.get("rng_rate")
        backend_type = dparams.get("backend_type")
        backend_source_list = dparams.get("backend_source",
                                          "").split()
        cmd = ("ps -ef | grep %s | grep -v grep" % vm_name)
        chardev = src_host = src_port = None
        if backend_type == "tcp":
            chardev = "socket"
        elif backend_type == "udp":
            chardev = "udp"
        for bc_source in backend_source_list:
            source = ast.literal_eval(bc_source)
            if "mode" in source and source['mode'] == "connect":
                src_host = source['host']
                src_port = source['service']

        if chardev and src_host and src_port:
            cmd += (" | grep 'chardev %s,.*host=%s,port=%s'"
                    % (chardev, src_host, src_port))
        if rng_model == "virtio":
            cmd += (" | grep 'device virtio-rng-pci'")
        if rng_rate:
            rate = ast.literal_eval(rng_rate)
            cmd += (" | grep 'max-bytes=%s,period=%s'"
                    % (rate['bytes'], rate['period']))
        if utils.run(cmd, ignore_status=True).exit_status:
            raise error.TestFail("Cann't see rng option"
                                 " in command line")

    def check_host():
        """
        Check random device on host
        """
        backend_dev = params.get("backend_dev")
        if backend_dev:
            cmd = "lsof %s" % backend_dev
            ret = utils.run(cmd, ignore_status=True)
            if ret.exit_status or not ret.stdout.count("qemu"):
                raise error.TestFail("Failed to check random device"
                                     " on host, command output: %s",
                                     ret.stdout)

    def check_snapshot(bgjob=None):
        """
        Do snapshot operation and check the results
        """
        snapshot_name1 = "snap.s1"
        snapshot_name2 = "snap.s2"
        if not snapshot_vm_running:
            vm.destroy(gracefully=False)
        ret = virsh.snapshot_create_as(vm_name, snapshot_name1)
        libvirt.check_exit_status(ret)
        snap_lists = virsh.snapshot_list(vm_name)
        if snapshot_name not in snap_lists:
            raise error.TestFail("Snapshot %s doesn't exist"
                                 % snapshot_name)

        if snapshot_vm_running:
            options = "--force"
        else:
            options = ""
        ret = virsh.snapshot_revert(
            vm_name, ("%s %s" % (snapshot_name, options)))
        libvirt.check_exit_status(ret)
        ret = virsh.dumpxml(vm_name)
        if ret.stdout.count("<rng model="):
            raise error.TestFail("Found rng device in xml")

        if snapshot_with_rng:
            if vm.is_alive():
                vm.destroy(gracefully=False)
            if bgjob:
                bgjob.kill_func()
            modify_rng_xml(params, False)

        # Start the domain before disk-only snapshot
        if vm.is_dead():
            # Add random server
            if params.get("backend_type") == "tcp":
                cmd = "cat /dev/random | nc -4 -l localhost 1024"
                bgjob = utils.AsyncJob(cmd)
            vm.start()
            vm.wait_for_login().close()
        err_msgs = ("live disk snapshot not supported"
                    " with this QEMU binary")
        ret = virsh.snapshot_create_as(vm_name,
                                       "%s --disk-only"
                                       % snapshot_name2)
        if ret.exit_status:
            if ret.stderr.count(err_msgs):
                raise error.TestNAError(err_msgs)
            else:
                raise error.TestFail("Failed to create external snapshot")
        snap_lists = virsh.snapshot_list(vm_name)
        if snapshot_name2 not in snap_lists:
            raise error.TestFail("Failed to check snapshot list")

        ret = virsh.domblklist(vm_name)
        if not ret.stdout.count(snapshot_name2):
            raise error.TestFail("Failed to find snapshot disk")

    def check_guest(session):
        """
        Check random device on guest
        """
        rng_files = (
            "/sys/devices/virtual/misc/hw_random/rng_available",
            "/sys/devices/virtual/misc/hw_random/rng_current")
        rng_avail = session.cmd_output("cat %s" % rng_files[0]).strip()
        rng_currt = session.cmd_output("cat %s" % rng_files[1]).strip()
        logging.debug("rng avail:%s, current:%s", rng_avail, rng_currt)
        if not rng_currt.count("virtio") or rng_currt not in rng_avail:
            raise error.TestFail("Failed to check rng file on guest")

        # Read the random device
        cmd = ("dd if=/dev/hwrng of=rng.test count=100"
               " && rm -f rng.test")
        ret, output = session.cmd_status_output(cmd, timeout=120)
        if ret:
            raise error.TestFail("Failed to read the random device")
        rng_rate = params.get("rng_rate")
        if rng_rate:
            rate_bytes, rate_period = ast.literal_eval(rng_rate).values()
            rate_conf = float(rate_bytes) / (float(rate_period)/1000)
            ret = re.search(r"(\d+) bytes.*copied, (\d+.\d+) s",
                            output, re.M)
            if not ret:
                raise error.TestFail("Can't find rate from output")
            rate_real = float(ret.group(1)) / float(ret.group(2))
            logging.debug("Find rate: %s, config rate: %s",
                          rate_real, rate_conf)
            if rate_real > rate_conf * 1.2:
                raise error.TestFail("The rate of reading exceed"
                                     " the limitation of configuration")
        if device_num > 1:
            rng_dev = rng_avail.split()
            if len(rng_dev) != device_num:
                raise error.TestNAError("Multiple virtio-rng devices are not"
                                        " supported on this guest kernel. "
                                        "Bug: https://bugzilla.redhat.com/"
                                        "show_bug.cgi?id=915335")
            session.cmd("echo -n %s > %s" % (rng_dev[1], rng_files[1]))
            # Read the random device
            if session.cmd_status(cmd, timeout=120):
                raise error.TestFail("Failed to read the random device")

    start_error = "yes" == params.get("start_error", "no")

    test_host = "yes" == params.get("test_host", "no")
    test_guest = "yes" == params.get("test_guest", "no")
    test_qemu_cmd = "yes" == params.get("test_qemu_cmd", "no")
    test_snapshot = "yes" == params.get("test_snapshot", "no")
    snapshot_vm_running = "yes" == params.get("snapshot_vm_running",
                                              "no")
    snapshot_with_rng = "yes" == params.get("snapshot_with_rng", "no")
    snapshot_name = params.get("snapshot_name")
    device_num = int(params.get("device_num", 1))

    if device_num > 1 and not libvirt_version.version_compare(1, 2, 7):
        raise error.TestNAError("Multiple virtio-rng devices not "
                                "supported on this libvirt version")
    # Back up xml file.
    vmxml_backup = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)

    # Try to install rng-tools on the host; it can speed up the random rate.
    # If installation fails, ignore the error and continue the test.
    if utils_misc.yum_install(["rng-tools"], timeout=300):
        rngd_conf = "/etc/sysconfig/rngd"
        rngd_srv = "/usr/lib/systemd/system/rngd.service"
        if os.path.exists(rngd_conf):
            # For rhel6 host, add extraoptions
            with open(rngd_conf, 'w') as f_rng:
                f_rng.write('EXTRAOPTIONS="--rng-device /dev/urandom"')
        elif os.path.exists(rngd_srv):
            # For rhel7 host, modify start options
            rngd_srv_conf = "/etc/systemd/system/rngd.service"
            if not os.path.exists(rngd_srv_conf):
                shutil.copy(rngd_srv, rngd_srv_conf)
            utils.run("sed -i -e 's#^ExecStart=.*#ExecStart=/sbin/rngd"
                      " -f -r /dev/urandom -o /dev/random#' %s"
                      % rngd_srv_conf)
            utils.run('systemctl daemon-reload')
        utils.run("service rngd start")

    # Build the xml and run test.
    try:
        bgjob = None
        # Take snapshot if needed
        if snapshot_name:
            if snapshot_vm_running:
                vm.start()
                vm.wait_for_login().close()
            ret = virsh.snapshot_create_as(vm_name, snapshot_name)
            libvirt.check_exit_status(ret)

        # Destroy VM first
        if vm.is_alive():
            vm.destroy(gracefully=False)

        # Build vm xml.
        dparams = {}
        if device_num > 1:
            for i in xrange(device_num):
                dparams[i] = {"rng_model": params.get(
                    "rng_model_%s" % i, "virtio")}
                dparams[i].update({"backend_model": params.get(
                    "backend_model_%s" % i, "random")})
                bk_type = params.get("backend_type_%s" % i)
                if bk_type:
                    dparams[i].update({"backend_type": bk_type})
                bk_dev = params.get("backend_dev_%s" % i)
                if bk_dev:
                    dparams[i].update({"backend_dev": bk_dev})
                bk_src = params.get("backend_source_%s" % i)
                if bk_src:
                    dparams[i].update({"backend_source": bk_src})
                bk_pro = params.get("backend_protocol_%s" % i)
                if bk_pro:
                    dparams[i].update({"backend_protocol": bk_pro})
                modify_rng_xml(dparams[i], False)
        else:
            modify_rng_xml(params, not test_snapshot)

        try:
            # Add random server
            if params.get("backend_type") == "tcp":
                cmd = "cat /dev/random | nc -4 -l localhost 1024"
                bgjob = utils.AsyncJob(cmd)

            # Start the VM.
            vm.start()
            if start_error:
                raise error.TestFail("VM started unexpectedly")

            if test_qemu_cmd:
                if device_num > 1:
                    for i in xrange(device_num):
                        check_qemu_cmd(dparams[i])
                else:
                    check_qemu_cmd(params)
            if test_host:
                check_host()
            session = vm.wait_for_login()
            if test_guest:
                check_guest(session)
            session.close()

            if test_snapshot:
                check_snapshot(bgjob)
        except virt_vm.VMStartError as details:
            logging.info(str(details))
            if not start_error:
                raise error.TestFail('VM failed to start, '
                                     'please refer to https://bugzilla.'
                                     'redhat.com/show_bug.cgi?id=1220252:'
                                     '\n%s' % details)

    finally:
        # Delete snapshots.
        snapshot_lists = virsh.snapshot_list(vm_name)
        if len(snapshot_lists) > 0:
            libvirt.clean_up_snapshots(vm_name, snapshot_lists)
            for snapshot in snapshot_lists:
                virsh.snapshot_delete(vm_name, snapshot, "--metadata")

        # Recover VM.
        if vm.is_alive():
            vm.destroy(gracefully=False)
        logging.info("Restoring vm...")
        vmxml_backup.sync()
        if bgjob:
            bgjob.kill_func()

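A note on the pattern above: the rngd setup probes for distribution-specific files with os.path.exists (the sysconfig file on RHEL 6, the systemd unit on RHEL 7) and only copies the stock unit into /etc when no override exists yet. A minimal sketch of that copy-if-missing idea, with hypothetical paths standing in for the real rngd files:

import os
import shutil

def ensure_override(stock_unit, override_path):
    # hypothetical paths; mirrors the rngd_srv/rngd_srv_conf logic above
    if not os.path.exists(override_path):
        shutil.copy(stock_unit, override_path)
    return override_path
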
Example 34

Project: Haystack
Source File: haystack_pipeline_CORE.py
View license
def main():

    print '\n[H A Y S T A C K   P I P E L I N E]'
    print('\n-SELECTION OF HOTSPOTS OF VARIABILITY AND ENRICHED MOTIFS- [Luca Pinello - [email protected]]\n')
    print 'Version %s\n' % HAYSTACK_VERSION
    
    #mandatory
    parser = argparse.ArgumentParser(description='HAYSTACK Parameters')
    parser.add_argument('samples_filename_or_bam_folder', type=str,  help='A tab delimited file with, in each row, (1) a sample name, (2) the path to the corresponding bam filename, (3 optional) the path to the corresponding gene expression filename. Alternatively it is possible to specify a folder containing some .bam files to analyze.')
    parser.add_argument('genome_name', type=str,  help='Genome assembly to use from UCSC (for example hg19, mm9, etc.)')
    
    #optional
    parser.add_argument('--name',  help='Define a custom output filename for the report', default='')
    parser.add_argument('--output_directory',type=str, help='Output directory (default: current directory)',default='')
    parser.add_argument('--bin_size', type=int,help='bin size to use (default: 500bp)',default=500)
    parser.add_argument('--recompute_all',help='Ignore any file previously precalculated for the command haystack_hotspots',action='store_true')
    parser.add_argument('--depleted', help='Look for cell type specific regions with depletion of signal instead of enrichment',action='store_true')
    parser.add_argument('--input_is_bigwig', help='Use the bigwig format instead of the bam format for the input. Note: The files must have extension .bw',action='store_true')
    parser.add_argument('--disable_quantile_normalization',help='Disable quantile normalization (default: False)',action='store_true')
    parser.add_argument('--transformation',type=str,help='Variance stabilizing transformation among: none, log2, angle (default: angle)',default='angle',choices=['angle', 'log2', 'none'])
    parser.add_argument('--z_score_high', type=float,help='z-score value to select the specific regions (default: 1.5)',default=1.5)
    parser.add_argument('--z_score_low', type=float,help='z-score value to select the non-specific regions (default: 0.25)',default=0.25)
    parser.add_argument('--th_rpm',type=float,help='Percentile on the signal intensity to consider for the hotspots (default: 99)', default=99)
    parser.add_argument('--meme_motifs_filename', type=str, help='Motifs database in MEME format (default JASPAR CORE 2016)')
    parser.add_argument('--motif_mapping_filename', type=str, help='Custom motif to gene mapping file (the default is for JASPAR CORE 2016 database)')
    parser.add_argument('--plot_all',  help='Disable the filter on the TF activity and correlation (default z-score TF>0 and rho>0.3)',action='store_true')
    parser.add_argument('--n_processes',type=int, help='Specify the number of processes to use. The default is #cores available.',default=multiprocessing.cpu_count())
    parser.add_argument('--temp_directory',  help='Directory to store temporary files  (default: /tmp)', default='/tmp')
    parser.add_argument('--version',help='Print version and exit.',action='version', version='Version %s' % HAYSTACK_VERSION)
    
    args = parser.parse_args()
    args_dict=vars(args)
    # promote each parsed argument to a local variable via exec
    for key,value in args_dict.items():
            exec('%s=%s' %(key,repr(value)))
            
            
            
    if meme_motifs_filename:
        check_file(meme_motifs_filename)
    
    if motif_mapping_filename:
        check_file(motif_mapping_filename)
        
    if not os.path.exists(temp_directory):
        error('The folder specified with --temp_directory: %s does not exist!' % temp_directory)
        sys.exit(1)
    
    if input_is_bigwig:
            extension_to_check='.bw'
            info('Input is set to BigWig (.bw)')
    else:
            extension_to_check='.bam'
            info('Input is set to compressed SAM (.bam)')
    
    if name:
            directory_name='HAYSTACK_PIPELINE_RESULTS_on_%s' % name
    
    else:
            directory_name='HAYSTACK_PIPELINE_RESULTS'
    
    if output_directory:
            output_directory=os.path.join(output_directory,directory_name)
    else:
            output_directory=directory_name
    
    #check folder or sample filename
    
    USE_GENE_EXPRESSION=True
    
    if os.path.isfile(samples_filename_or_bam_folder):
            BAM_FOLDER=False
            bam_filenames=[]
            gene_expression_filenames=[]
            sample_names=[]
    
            with open(samples_filename_or_bam_folder) as infile:
                for line in infile:
    
                    if not line.strip():
                            continue
                    
                    if line.startswith('#'): #skip optional header line or empty lines
                            info('Skipping header/comment line: %s' % line)
                            continue
    
                    fields=line.strip().split()
                    n_fields=len(fields)
    
                    if n_fields==2:
    
                        USE_GENE_EXPRESSION=False
                        
                        sample_names.append(fields[0])
                        bam_filenames.append(fields[1])
    
                    elif n_fields==3:
    
                        USE_GENE_EXPRESSION=USE_GENE_EXPRESSION and True
    
                        sample_names.append(fields[0])
                        bam_filenames.append(fields[1])
                        gene_expression_filenames.append(fields[2])
                    else:
                        error('The samples file format is wrong!')
                        sys.exit(1)
            
    else:
            if os.path.exists(samples_filename_or_bam_folder):
                    BAM_FOLDER=True
                    USE_GENE_EXPRESSION=False
                    bam_filenames=glob.glob(os.path.join(samples_filename_or_bam_folder,'*'+extension_to_check))
    
                    if not bam_filenames:
                        error('No bam/bigwig files to analyze in %s. Exiting.' % samples_filename_or_bam_folder)
                        sys.exit(1)
                    
                    sample_names=[os.path.basename(bam_filename).replace(extension_to_check,'') for bam_filename in bam_filenames]
            else:
                    error("The file or folder %s doesn't exist. Exiting." % samples_filename_or_bam_folder)
                    sys.exit(1)
    
    
    #check all the files before starting
    info('Checking samples files location...')
    for bam_filename in bam_filenames:
            check_file(bam_filename)
    
    if USE_GENE_EXPRESSION:
        for gene_expression_filename in gene_expression_filenames:
                check_file(gene_expression_filename)
    
    if not os.path.exists(output_directory):
            os.makedirs(output_directory)
    
    #copy back the file used
    if not BAM_FOLDER:
            shutil.copy2(samples_filename_or_bam_folder,output_directory)
    
    #write hotspots conf files
    sample_names_hotspots_filename=os.path.join(output_directory,'sample_names_hotspots.txt')
    
    with open(sample_names_hotspots_filename,'w+') as outfile:
        for sample_name,bam_filename in zip(sample_names,bam_filenames):
            outfile.write('%s\t%s\n' % (sample_name, bam_filename))
    
    #write tf activity  conf files
    if USE_GENE_EXPRESSION:
            sample_names_tf_activity_filename=os.path.join(output_directory,'sample_names_tf_activity.txt')
    
            with open(sample_names_tf_activity_filename,'w+') as outfile:
                    for sample_name,gene_expression_filename in zip(sample_names,gene_expression_filenames):
                            outfile.write('%s\t%s\n' % (sample_name, gene_expression_filename))
    
            tf_activity_directory=os.path.join(output_directory,'HAYSTACK_TFs_ACTIVITY_PLANES')
    
    
    #CALL HAYSTACK HOTSPOTS
    cmd_to_run='haystack_hotspots %s %s --output_directory %s --bin_size %d %s %s %s %s %s %s %s %s' % \
                (sample_names_hotspots_filename, genome_name,output_directory,bin_size,
                 ('--recompute_all' if recompute_all else ''),
                 ('--depleted' if depleted else ''),
                 ('--input_is_bigwig' if input_is_bigwig else ''),
                 ('--disable_quantile_normalization' if disable_quantile_normalization else ''),
                 '--transformation %s' % transformation,
                 '--z_score_high %f' % z_score_high,
                 '--z_score_low %f' % z_score_low,
                 '--th_rpm %f' % th_rpm)
    print cmd_to_run
    sb.call(cmd_to_run ,shell=True,env=system_env)        
    
    #CALL HAYSTACK MOTIFS
    motif_directory=os.path.join(output_directory,'HAYSTACK_MOTIFS')
    for sample_name in sample_names:
        specific_regions_filename=os.path.join(output_directory,'HAYSTACK_HOTSPOTS','SPECIFIC_REGIONS','Regions_specific_for_%s*.bed' %sample_name)
        bg_regions_filename=glob.glob(os.path.join(output_directory,'HAYSTACK_HOTSPOTS','SPECIFIC_REGIONS','Background_for_%s*.bed' %sample_name))[0]
        #bg_regions_filename=glob.glob(specific_regions_filename.replace('Regions_specific','Background')[:-11]+'*.bed')[0] #lo zscore e' diverso...
        #print specific_regions_filename,bg_regions_filename
        cmd_to_run='haystack_motifs %s %s --bed_bg_filename %s --output_directory %s --name %s' % (specific_regions_filename,genome_name, bg_regions_filename,motif_directory, sample_name)
        
        if meme_motifs_filename:
             cmd_to_run+=' --meme_motifs_filename %s' % meme_motifs_filename
             
             
        if n_processes:
            cmd_to_run+=' --n_processes %d' % n_processes
            
        if temp_directory:
            cmd_to_run+=' --temp_directory %s' % temp_directory
            
            
        
        print cmd_to_run
        sb.call(cmd_to_run,shell=True,env=system_env)
    
        if USE_GENE_EXPRESSION:
                #CALL HAYSTACK TF ACTIVITY 
                motifs_output_folder=os.path.join(motif_directory,'HAYSTACK_MOTIFS_on_%s' % sample_name) 
                if os.path.exists(motifs_output_folder):
                    cmd_to_run='haystack_tf_activity_plane %s %s %s --output_directory %s'  %(motifs_output_folder,sample_names_tf_activity_filename,sample_name,tf_activity_directory)
                    
                    if motif_mapping_filename:
                        cmd_to_run+=' --motif_mapping_filename %s' %  motif_mapping_filename       
                    
                    if plot_all:
                        cmd_to_run+=' --plot_all'
                        
                    
                    print cmd_to_run
                    sb.call(cmd_to_run,shell=True,env=system_env) 

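The pipeline above creates its output directory with the classic check-then-create pair (if not os.path.exists(output_directory): os.makedirs(output_directory)). That sequence can race if two runs start at once, since the directory may appear between the check and the create. A hedged, stdlib-only variant that tolerates this:

import errno
import os

def makedirs_safe(path):
    # same intent as the exists()+makedirs() pair above, without the race
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

On Python 3, os.makedirs(path, exist_ok=True) achieves the same effect in one call.
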
Example 35

View license
def prep_group_analysis_workflow(model_df, pipeline_config_obj, \
    model_name, group_config_obj, resource_id, preproc_strat, \
    series_or_repeated_label):
    
    #
    # this function runs once per derivative type and preproc strat combo
    # during group analysis
    #

    import os

    import nipype.pipeline.engine as pe
    import nipype.interfaces.utility as util
    import nipype.interfaces.io as nio

    pipeline_ID = pipeline_config_obj.pipeline_name

    # get thresholds
    z_threshold = float(group_config_obj.z_threshold[0])

    p_threshold = float(group_config_obj.p_threshold[0])

    sub_id_label = group_config_obj.subject_id_label

    # determine if f-tests are included or not
    custom_confile = group_config_obj.custom_contrasts

    if ((custom_confile == None) or (custom_confile == '') or \
            ("None" in custom_confile) or ("none" in custom_confile)):

        if (len(group_config_obj.f_tests) == 0) or \
            (group_config_obj.f_tests == None):
            fTest = False
        else:
            fTest = True

    else:

        if not os.path.exists(custom_confile):
            errmsg = "\n[!] CPAC says: You've specified a custom contrasts " \
                     ".CSV file for your group model, but this file cannot " \
                     "be found. Please double-check the filepath you have " \
                     "entered.\n\nFilepath: %s\n\n" % custom_confile
            raise Exception(errmsg)

        with open(custom_confile,"r") as f:
            evs = f.readline()

        evs = evs.rstrip('\r\n').split(',')
        count_ftests = 0

        fTest = False

        for ev in evs:
            if "f_test" in ev:
                count_ftests += 1

        if count_ftests > 0:
            fTest = True


    # create path for output directory
    out_dir = os.path.join(group_config_obj.output_dir, \
        "group_analysis_results_%s" % pipeline_ID, \
        "group_model_%s" % model_name, resource_id, \
        series_or_repeated_label, preproc_strat)

    model_path = os.path.join(out_dir, 'model_files')

    # generate working directory for this output's group analysis run
    work_dir = os.path.join(c.workingDirectory, "group_analysis", model_name,\
        resource_id, series_or_repeated_label, preproc_strat)

    log_dir = os.path.join(out_dir, 'logs', resource_id, \
        'model_%s' % model_name)

    # create the actual directories
    if not os.path.isdir(model_path):
        try:
            os.makedirs(model_path)
        except Exception as e:
            err = "\n\n[!] Could not create the group analysis output " \
                  "directories.\n\nAttempted directory creation: %s\n\n" \
                  "Error details: %s\n\n" % (model_path, e)
            raise Exception(err)

    if not os.path.isdir(work_dir):
        try:
            os.makedirs(work_dir)
        except Exception as e:
            err = "\n\n[!] Could not create the group analysis working " \
                  "directories.\n\nAttempted directory creation: %s\n\n" \
                  "Error details: %s\n\n" % (work_dir, e)
            raise Exception(err)

    if not os.path.isdir(log_dir):
        try:
            os.makedirs(log_dir)
        except Exception as e:
            err = "\n\n[!] Could not create the group analysis logfile " \
                  "directories.\n\nAttempted directory creation: %s\n\n" \
                  "Error details: %s\n\n" % (log_dir, e)
            raise Exception(err)


    # create new subject list based on which subjects are left after checking
    # for missing outputs
    new_participant_list = []
    for part in list(model_df["Participant"]):
        # do this instead of using "set" just in case, to preserve order
        #   only reason there may be duplicates is because of multiple-series
        #   repeated measures runs
        if part not in new_participant_list:
            new_participant_list.append(part)

    new_sub_file = write_new_sub_file(model_path, \
                                      group_config_obj.participant_list, \
                                      new_participant_list)

    group_conf.update('participant_list',new_sub_file)


    # start processing the dataframe further
    design_formula = group_config_obj.design_formula

    # demean the motion params
    if ("MeanFD" in design_formula) or ("MeanDVARS" in design_formula):
        params = ["MeanFD_Power", "MeanFD_Jenkinson", "MeanDVARS"]
        for param in params:
            model_df[param] = model_df[param].astype(float)
            model_df[param] = model_df[param].sub(model_df[param].mean())


    # create 4D merged copefile, in the correct order, identical to design
    # matrix
    merge_outfile = model_name + "_" + resource_id + "_merged.nii.gz"
    merge_outfile = os.path.join(model_path, merge_outfile)

    merge_file = create_merged_copefile(list(model_df["Filepath"]), \
                                        merge_outfile)

    # create merged group mask
    if group_config_obj.mean_mask[0] == "Group Mask":
        merge_mask_outfile = os.path.basename(merge_file) + "_mask.nii.gz"
        merge_mask = create_merged_mask(merge_file, merge_mask_outfile)

    # calculate measure means, and demean
    if "Measure_Mean" in design_formula:
        model_df = calculate_measure_mean_in_df(model_df, merge_mask)

    # calculate custom ROIs, and demean (in workflow?)
    if "Custom_ROI_Mean" in design_formula:

        custom_roi_mask = group_config_obj.custom_roi_mask

        if (custom_roi_mask == None) or (custom_roi_mask == "None") or \
            (custom_roi_mask == "none") or (custom_roi_mask == ""):
            err = "\n\n[!] You included 'Custom_ROI_Mean' in your design " \
                  "formula, but you didn't supply a custom ROI mask file." \
                  "\n\nDesign formula: %s\n\n" % design_formula
            raise Exception(err)

        # make sure the custom ROI mask file is the same resolution as the
        # output files - if not, resample and warn the user
        roi_mask = check_mask_file_resolution(list(model_df["Raw_Filepath"])[0], \
                                              custom_roi_mask, model_path, \
                                              resource_id)

        # if using group merged mask, trim the custom ROI mask to be within
        # its constraints
        if merge_mask:
            output_mask = os.path.join(model_path, "group_masked_%s" \
                                       % os.path.basename(roi_mask))
            roi_mask = trim_mask(roi_mask, merge_mask, output_mask)

        # calculate
        model_df = calculate_custom_roi_mean_in_df(model_df, roi_mask)   

    


    # modeling group variances separately

    # add repeated measures 1's matrices

    # patsify model DF, drop columns not in design formula

    # process contrasts


        
    wf = pe.Workflow(name=resource_id)

    wf.base_dir = work_dir
    crash_dir = os.path.join(pipeline_config_obj.crashLogDirectory, \
                             "group_analysis", model_name)

    wf.config['execution'] = {'hash_method': 'timestamp', \
                              'crashdump_dir': crash_dir}








    if "Measure_Mean" in design_formula:
        measure_mean = pe.Node(util.Function(input_names=['model_df',
                                                          'merge_mask'],
                                       output_names=['model_df'],
                                       function=calculate_measure_mean_in_df),
                                       name='measure_mean')
        measure_mean.inputs.model_df = model_df

        wf.connect(merge_mask, "out_file", measure_mean, "merge_mask")


    if "Custom_ROI_Mean" in design_formula:
        roi_mean = pe.Node(util.Function())


    group_config_obj.custom_roi_mask
    






    #----------------

    import yaml
    import pandas as pd


    # load group analysis model configuration file
    try:
        with open(os.path.realpath(group_config_file),"r") as f:
            group_conf = Configuration(yaml.load(f))
    except Exception as e:
        err_string = "\n\n[!] CPAC says: Could not read group model " \
                     "configuration YML file. Ensure you have read access " \
                     "for the file and that it is formatted properly.\n\n" \
                     "Configuration file: %s\n\nError details: %s" \
                     % (group_config_file, e)
        raise Exception(err_string)


    # gather all of the information
    # - lists of all the participant unique IDs (participant_site_session) and
    # of all of the series IDs present in output_file_list
    # - also returns the pipeline ID
    new_participant_list, all_series_names, pipeline_ID = \
        gather_new_participant_list(output_path_file, output_file_list)

     

      

    # create the path string for the group analysis output
    #    replicate the directory path of one of the participant's output
    #    folder path to the derivative's file, but replace the participant ID
    #    with the group model name
    #        this is to ensure nothing gets overwritten between strategies
    #        or thresholds, etc.
    out_dir = os.path.dirname(output_file_list[0]).split(pipeline_ID + '/')
    out_dir = out_dir[1].split(out_dir[1].split("/")[-1])[0]
    out_dir = os.path.join(group_conf.output_dir, out_dir)
    out_dir = out_dir.replace(new_participant_list[0], \
                  'group_analysis_results_%s/_grp_model_%s' \
                  % (pipeline_ID, group_conf.model_name))

    # !!!!!!!!!!
    if (group_conf.repeated_measures == True) and (series_ids[0] != None):
        out_dir = out_dir.replace(series_ids[0] + "/", "multiple_series")

    # create model file output directories
    model_out_dir = os.path.join(group_conf.output_dir, \
        'group_analysis_results_%s/_grp_model_%s' \
        %(pipeline_ID, group_conf.model_name))

    mod_path = os.path.join(model_out_dir, 'model_files')

    if not os.path.isdir(mod_path):
        os.makedirs(mod_path)

    # current_mod_path = folder under
    #   "/gpa_output/_grp_model_{model name}/model_files/{current derivative}"
    current_mod_path = os.path.join(mod_path, resource)

    if not os.path.isdir(current_mod_path):
        os.makedirs(current_mod_path)

        
    # create new subject list based on which subjects are left after checking
    # for missing outputs
    new_sub_file = write_new_sub_file(current_mod_path, \
                       group_conf.subject_list, new_participant_list)

    group_conf.update('subject_list',new_sub_file)


    # create new design matrix with only the subjects that are left






    # Run 'create_fsl_model' script to extract phenotypic data from
    # the phenotypic file for each of the subjects in the subject list

    # get the motion statistics parameter file, if present
    # get the parameter file so it can be passed to create_fsl_model.py
    # so MeanFD or other measures can be included in the design matrix


    ''' okay, here we go... how are we handling series? because here it needs to take in '''
    ''' the appropriate series to get the appropriate parameter file ! ! ! '''

    ''' MAY HAVE TO GO BACK ON THIS, and just have one series sent in per this function...'''

    power_params_files = {}

    measure_list = ['MeanFD_Power', 'MeanFD_Jenkinson', 'MeanDVARS']

    for measure in measure_list:
    
        if measure in group_conf.design_formula:

            for series_id in all_series_names:

                parameter_file = os.path.join(c.outputDirectory, \
                                              pipeline_ID, \
                                              '%s%s_all_params.csv' % \
                                              (series_id.strip('_'), \
                                              threshold_val))

                if not os.path.exists(parameter_file):
                    err = "\n\n[!] CPAC says: Could not find or open the motion "\
                          "parameter file. This is necessary if you have " \
                          "included any of the MeanFD measures in your group " \
                          "model.\n\nThis file can usually be found in the " \
                          "output directory of your individual-level analysis " \
                          "runs. If it is not there, double-check to see if " \
                          "individual-level analysis had completed successfully."\
                          "\n\nPath not found: %s\n\n" % parameter_file
                    raise Exception(err)


                power_params_files[series_id] = parameter_file
                

            break
            
    else:
    
        power_params_files = None



    # path to the pipeline folder to be passed to create_fsl_model.py
    # so that certain files like output_means.csv can be accessed
    pipeline_path = os.path.join(c.outputDirectory, pipeline_ID)

    # generate working directory for this output's group analysis run
    workDir = '%s/group_analysis/%s/%s' % (c.workingDirectory, \
                                               group_conf.model_name, \
                                               resource)
            
    # this makes strgy_path basically the directory path of the folders after
    # the resource/derivative folder level         
    strgy_path = os.path.dirname(output_file_list[0]).split(resource)[1]

    # get rid of periods in the path
    for ch in ['.']:
        if ch in strgy_path:
            strgy_path = strgy_path.replace(ch, "")
                
    # create nipype-workflow-name-friendly strgy_path
    # (remove special characters)
    strgy_path_name = strgy_path.replace('/', "_")

    workDir = workDir + '/' + strgy_path_name



    # merge the subjects for this current output
    # then, take the group mask, and iterate over the list of subjects
    # to extract the mean of each subject using the group mask
    merge_output, merge_mask_output, merge_output_dir = \
        create_merged_files(workDir, resource, output_file_list)

    
    # CALCULATE THE MEANS of each output using the group mask
    derivative_means_dict, roi_means_dict = \
        calculate_output_means(resource, output_file_list, \
                               group_conf.mean_mask, \
                               group_conf.design_formula, \
                               group_conf.custom_roi_mask, pipeline_path, \
                               merge_output_dir, c.identityMatrix)


    measure_dict = {}

    # extract motion measures from CPAC-generated power params file
    if power_params_files != None:
        for param_file in power_params_files.values():
            new_measure_dict = get_measure_dict(param_file)
            measure_dict.update(new_measure_dict)


    # combine the motion measures dictionary with the measure_mean
    # dictionary (if it exists)
    if derivative_means_dict:
        measure_dict["Measure_Mean"] = derivative_means_dict

    # run create_fsl_model.py to generate the group analysis models
    
    from CPAC.utils import create_fsl_model, kill_me
    create_fsl_model.run(group_conf, resource, parameter_file, \
                             derivative_means_dict, roi_means_dict, \
                                 current_mod_path, True)


    # begin GA workflow setup

    if not os.path.exists(new_sub_file):
        raise Exception("path to input subject list %s is invalid" % new_sub_file)
        
    #if c.mixedScanAnalysis == True:
    #    wf = pe.Workflow(name = 'group_analysis/%s/grp_model_%s'%(resource, os.path.basename(model)))
    #else:

    wf = pe.Workflow(name = resource)

    wf.base_dir = workDir
    wf.config['execution'] = {'hash_method': 'timestamp', 'crashdump_dir': os.path.abspath(c.crashLogDirectory)}
    log_dir = os.path.join(group_conf.output_dir, 'logs', 'group_analysis', resource, 'model_%s' % (group_conf.model_name))
        

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    else:
        pass


    # gp_flow
    # Extracts the model files (.con, .grp, .mat, .fts) from the model
    # directory and sends them to the create_group_analysis workflow gpa_wf

    gp_flow = create_grp_analysis_dataflow("gp_dataflow_%s" % resource)
    gp_flow.inputs.inputspec.grp_model = os.path.join(mod_path, resource)
    gp_flow.inputs.inputspec.model_name = group_conf.model_name
    gp_flow.inputs.inputspec.ftest = fTest
  

    # gpa_wf
    # Creates the actual group analysis workflow

    gpa_wf = create_group_analysis(fTest, "gp_analysis_%s" % resource)

    gpa_wf.inputs.inputspec.merged_file = merge_output
    gpa_wf.inputs.inputspec.merge_mask = merge_mask_output

    gpa_wf.inputs.inputspec.z_threshold = z_threshold
    gpa_wf.inputs.inputspec.p_threshold = p_threshold
    gpa_wf.inputs.inputspec.parameters = (c.FSLDIR, 'MNI152')
    
   
    wf.connect(gp_flow, 'outputspec.mat',
               gpa_wf, 'inputspec.mat_file')
    wf.connect(gp_flow, 'outputspec.con',
               gpa_wf, 'inputspec.con_file')
    wf.connect(gp_flow, 'outputspec.grp',
                gpa_wf, 'inputspec.grp_file')
           
    if fTest:
        wf.connect(gp_flow, 'outputspec.fts',
                   gpa_wf, 'inputspec.fts_file')
        

    # ds
    # Creates the datasink node for group analysis
       
    ds = pe.Node(nio.DataSink(), name='gpa_sink')
     
    if 'sca_roi' in resource:
        out_dir = os.path.join(out_dir, \
            re.search('sca_roi_(\d)+',os.path.splitext(os.path.splitext(os.path.basename(output_file_list[0]))[0])[0]).group(0))
            
            
    if 'dr_tempreg_maps_zstat_files_to_standard_smooth' in resource:
        out_dir = os.path.join(out_dir, \
            re.search('temp_reg_map_z_(\d)+',os.path.splitext(os.path.splitext(os.path.basename(output_file_list[0]))[0])[0]).group(0))
            
            
    if 'centrality' in resource:
        names = ['degree_centrality_binarize', 'degree_centrality_weighted', \
                 'eigenvector_centrality_binarize', 'eigenvector_centrality_weighted', \
                 'lfcd_binarize', 'lfcd_weighted']

        for name in names:
            if name in os.path.basename(output_file_list[0]):
                out_dir = os.path.join(out_dir, name)
                break

    if 'tempreg_maps' in resource:
        out_dir = os.path.join(out_dir, \
            re.search('\w*[#]*\d+', os.path.splitext(os.path.splitext(os.path.basename(output_file_list[0]))[0])[0]).group(0))
        
#     if c.mixedScanAnalysis == True:
#         out_dir = re.sub(r'(\w)*scan_(\w)*(\d)*(\w)*[/]', '', out_dir)
              
    ds.inputs.base_directory = out_dir
    ds.inputs.container = ''
        
    ds.inputs.regexp_substitutions = [(r'(?<=rendered)(.)*[/]','/'),
                                      (r'(?<=model_files)(.)*[/]','/'),
                                      (r'(?<=merged)(.)*[/]','/'),
                                      (r'(?<=stats/clusterMap)(.)*[/]','/'),
                                      (r'(?<=stats/unthreshold)(.)*[/]','/'),
                                      (r'(?<=stats/threshold)(.)*[/]','/'),
                                      (r'_cluster(.)*[/]',''),
                                      (r'_slicer(.)*[/]',''),
                                      (r'_overlay(.)*[/]','')]
   

    ########datasink connections#########
    if fTest:
        wf.connect(gp_flow, 'outputspec.fts',
                   ds, '[email protected]') 
        
    wf.connect(gp_flow, 'outputspec.mat',
               ds, '[email protected]' )
    wf.connect(gp_flow, 'outputspec.con',
               ds, '[email protected]')
    wf.connect(gp_flow, 'outputspec.grp',
               ds, '[email protected]')
    wf.connect(gpa_wf, 'outputspec.merged',
               ds, 'merged')
    wf.connect(gpa_wf, 'outputspec.zstats',
               ds, 'stats.unthreshold')
    wf.connect(gpa_wf, 'outputspec.zfstats',
               ds,'[email protected]')
    wf.connect(gpa_wf, 'outputspec.fstats',
               ds,'[email protected]')
    wf.connect(gpa_wf, 'outputspec.cluster_threshold_zf',
               ds, 'stats.threshold')
    wf.connect(gpa_wf, 'outputspec.cluster_index_zf',
               ds,'stats.clusterMap')
    wf.connect(gpa_wf, 'outputspec.cluster_localmax_txt_zf',
               ds, '[email protected]')
    wf.connect(gpa_wf, 'outputspec.overlay_threshold_zf',
               ds, 'rendered')
    wf.connect(gpa_wf, 'outputspec.rendered_image_zf',
               ds, '[email protected]')
    wf.connect(gpa_wf, 'outputspec.cluster_threshold',
               ds,  '[email protected]')
    wf.connect(gpa_wf, 'outputspec.cluster_index',
               ds, '[email protected]')
    wf.connect(gpa_wf, 'outputspec.cluster_localmax_txt',
               ds, '[email protected]')
    wf.connect(gpa_wf, 'outputspec.overlay_threshold',
               ds, '[email protected]')
    wf.connect(gpa_wf, 'outputspec.rendered_image',
               ds, '[email protected]')
       
    ######################################

    # Run the actual group analysis workflow
    wf.run()

    
    print "\n\nWorkflow finished for model %s and resource %s\n\n" \
          % (os.path.basename(group_conf.output_dir), resource)

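Example 35 validates each required input (the custom contrasts CSV, the motion parameter files, the new subject list) with a separate os.path.exists check, each raising its own exception. A compact sketch of the same idea that collects every missing path before failing, so the user sees all problems in one pass (require_files is a hypothetical helper, not part of CPAC):

import os

def require_files(paths):
    # report all missing inputs at once instead of failing on the first
    missing = [p for p in paths if not os.path.exists(p)]
    if missing:
        raise IOError("Missing required file(s):\n  " + "\n  ".join(missing))
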
Example 36

Project: alignak
Source File: test_launch_daemons.py
View license
    def _run_daemons_and_test_api(self, ssl=False):
        """ Running all the Alignak daemons to check their correct launch and API

        :return:
        """
        req = requests.Session()

        # copy etc config files in test/cfg/run_test_launch_daemons and change folder
        # in the files for pid and log files
        if os.path.exists('./cfg/run_test_launch_daemons'):
            shutil.rmtree('./cfg/run_test_launch_daemons')

        shutil.copytree('../etc', './cfg/run_test_launch_daemons')
        files = ['cfg/run_test_launch_daemons/daemons/arbiterd.ini',
                 'cfg/run_test_launch_daemons/daemons/brokerd.ini',
                 'cfg/run_test_launch_daemons/daemons/pollerd.ini',
                 'cfg/run_test_launch_daemons/daemons/reactionnerd.ini',
                 'cfg/run_test_launch_daemons/daemons/receiverd.ini',
                 'cfg/run_test_launch_daemons/daemons/schedulerd.ini',
                 'cfg/run_test_launch_daemons/alignak.cfg',
                 'cfg/run_test_launch_daemons/arbiter/daemons/arbiter-master.cfg',
                 'cfg/run_test_launch_daemons/arbiter/daemons/broker-master.cfg',
                 'cfg/run_test_launch_daemons/arbiter/daemons/poller-master.cfg',
                 'cfg/run_test_launch_daemons/arbiter/daemons/reactionner-master.cfg',
                 'cfg/run_test_launch_daemons/arbiter/daemons/receiver-master.cfg',
                 'cfg/run_test_launch_daemons/arbiter/daemons/scheduler-master.cfg']
        replacements = {
            '/usr/local/var/run/alignak': '/tmp',
            '/usr/local/var/log/alignak': '/tmp',
            '%(workdir)s': '/tmp',
            '%(logdir)s': '/tmp',
            '%(etcdir)s': '/tmp'
        }
        if ssl:
            shutil.copy('./cfg/ssl/server.csr', '/tmp/')
            shutil.copy('./cfg/ssl/server.key', '/tmp/')
            shutil.copy('./cfg/ssl/server.pem', '/tmp/')
            # Set daemons configuration to use SSL
            print replacements
            replacements.update({
                'use_ssl=0': 'use_ssl=1',
                '#server_cert=': 'server_cert=',
                '#server_key=': 'server_key=',
                '#server_dh=': 'server_dh=',
                '#hard_ssl_name_check=0': 'hard_ssl_name_check=0',
                'certs/': '',
                'use_ssl	                0': 'use_ssl	                1'
            })
        for filename in files:
            lines = []
            with open(filename) as infile:
                for line in infile:
                    for src, target in replacements.iteritems():
                        line = line.replace(src, target)
                    lines.append(line)
            with open(filename, 'w') as outfile:
                for line in lines:
                    outfile.write(line)

        self.procs = {}
        satellite_map = {
            'arbiter': '7770', 'scheduler': '7768', 'broker': '7772',
            'poller': '7771', 'reactionner': '7769', 'receiver': '7773'
        }

        print("Cleaning pid and log files...")
        for daemon in ['arbiter', 'scheduler', 'broker', 'poller', 'reactionner', 'receiver']:
            if os.path.exists('/tmp/%sd.pid' % daemon):
                os.remove('/tmp/%sd.pid' % daemon)
                print("- removed /tmp/%sd.pid" % daemon)
            if os.path.exists('/tmp/%sd.log' % daemon):
                os.remove('/tmp/%sd.log' % daemon)
                print("- removed /tmp/%sd.log" % daemon)

        print("Launching the daemons...")
        for daemon in ['scheduler', 'broker', 'poller', 'reactionner', 'receiver']:
            args = ["../alignak/bin/alignak_%s.py" %daemon,
                    "-c", "./cfg/run_test_launch_daemons/daemons/%sd.ini" % daemon]
            self.procs[daemon] = \
                subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            sleep(1)
            print("- %s launched (pid=%d)" % (daemon, self.procs[daemon].pid))

        sleep(1)

        print("Testing daemons start")
        for name, proc in self.procs.items():
            ret = proc.poll()
            if ret is not None:
                print("*** %s exited on start!" % (name))
                for line in iter(proc.stdout.readline, b''):
                    print(">>> " + line.rstrip())
                for line in iter(proc.stderr.readline, b''):
                    print(">>> " + line.rstrip())
            self.assertIsNone(ret, "Daemon %s not started!" % name)
            print("%s running (pid=%d)" % (name, self.procs[name].pid))

        # Let the daemons start ...
        sleep(5)

        print("Testing pid files and log files...")
        for daemon in ['scheduler', 'broker', 'poller', 'reactionner', 'receiver']:
            self.assertTrue(os.path.exists('/tmp/%sd.pid' % daemon), '/tmp/%sd.pid does not exist!' % daemon)
            self.assertTrue(os.path.exists('/tmp/%sd.log' % daemon), '/tmp/%sd.log does not exist!' % daemon)

        sleep(1)

        print("Launching arbiter...")
        args = ["../alignak/bin/alignak_arbiter.py",
                "-c", "cfg/run_test_launch_daemons/daemons/arbiterd.ini",
                "-a", "cfg/run_test_launch_daemons/alignak.cfg"]
        self.procs['arbiter'] = \
            subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        print("%s launched (pid=%d)" % ('arbiter', self.procs['arbiter'].pid))

        sleep(5)

        name = 'arbiter'
        print("Testing Arbiter start %s" % name)
        ret = self.procs[name].poll()
        if ret is not None:
            print("*** %s exited on start!" % (name))
            for line in iter(self.procs[name].stdout.readline, b''):
                print(">>> " + line.rstrip())
            for line in iter(self.procs[name].stderr.readline, b''):
                print(">>> " + line.rstrip())
        self.assertIsNone(ret, "Daemon %s not started!" % name)
        print("%s running (pid=%d)" % (name, self.procs[name].pid))

        sleep(1)

        print("Testing pid files and log files...")
        for daemon in ['arbiter']:
            self.assertTrue(os.path.exists('/tmp/%sd.pid' % daemon), '/tmp/%sd.pid does not exist!' % daemon)
            self.assertTrue(os.path.exists('/tmp/%sd.log' % daemon), '/tmp/%sd.log does not exist!' % daemon)

        # Let the arbiter build and dispatch its configuration
        sleep(5)

        http = 'http'
        if ssl:
            http = 'https'

        print("Testing ping")
        for name, port in satellite_map.items():
            raw_data = req.get("%s://localhost:%s/ping" % (http, port), verify=False)
            data = raw_data.json()
            self.assertEqual(data, 'pong', "Daemon %s did not ping back!" % name)

        print("Testing ping with satellite SSL and client not SSL")
        if ssl:
            for name, port in satellite_map.items():
                raw_data = req.get("http://localhost:%s/ping" % port)
                self.assertEqual('The client sent a plain HTTP request, but this server only speaks HTTPS on this port.', raw_data.text)

        print("Testing get_satellite_list")
        raw_data = req.get("%s://localhost:%s/get_satellite_list" % (http,
                                                                     satellite_map['arbiter']), verify=False)
        expected_data ={"reactionner": ["reactionner-master"],
                        "broker": ["broker-master"],
                        "arbiter": ["arbiter-master"],
                        "scheduler": ["scheduler-master"],
                        "receiver": ["receiver-master"],
                        "poller": ["poller-master"]}
        data = raw_data.json()
        self.assertIsInstance(data, dict, "Data is not a dict!")
        for k, v in expected_data.iteritems():
            self.assertEqual(set(data[k]), set(v))

        print("Testing have_conf")
        for daemon in ['scheduler', 'broker', 'poller', 'reactionner', 'receiver']:
            raw_data = req.get("%s://localhost:%s/have_conf" % (http, satellite_map[daemon]), verify=False)
            data = raw_data.json()
            self.assertTrue(data, "Daemon %s has no conf!" % daemon)
            # TODO: test with magic_hash

        print("Testing api")
        name_to_interface = {'arbiter': ArbiterInterface,
                             'scheduler': SchedulerInterface,
                             'broker': BrokerInterface,
                             'poller': GenericInterface,
                             'reactionner': GenericInterface,
                             'receiver': ReceiverInterface}
        for name, port in satellite_map.items():
            raw_data = req.get("%s://localhost:%s/api" % (http, port), verify=False)
            data = raw_data.json()
            expected_data = set(name_to_interface[name](None).api())
            self.assertIsInstance(data, list, "Data is not a list!")
            self.assertEqual(set(data), expected_data, "Daemon %s has a bad API!" % name)

        print("Testing get_checks on scheduler")
        # TODO: if have poller running, the poller will get the checks before us
        #
        # We need to sleep 10s to be sure the first check can be launched now (check_interval = 5)
        # sleep(4)
        # raw_data = req.get("http://localhost:%s/get_checks" % satellite_map['scheduler'], params={'do_checks': True})
        # data = unserialize(raw_data.json(), True)
        # self.assertIsInstance(data, list, "Data is not a list!")
        # self.assertNotEqual(len(data), 0, "List is empty!")
        # for elem in data:
        #     self.assertIsInstance(elem, Check, "One elem of the list is not a Check!")

        print("Testing get_raw_stats")
        for name, port in satellite_map.items():
            raw_data = req.get("%s://localhost:%s/get_raw_stats" % (http, port), verify=False)
            data = raw_data.json()
            if name == 'broker':
                self.assertIsInstance(data, list, "Data is not a list!")
            else:
                self.assertIsInstance(data, dict, "Data is not a dict!")

        print("Testing what_i_managed")
        for name, port in satellite_map.items():
            raw_data = req.get("%s://localhost:%s/what_i_managed" % (http, port), verify=False)
            data = raw_data.json()
            self.assertIsInstance(data, dict, "Data is not a dict!")
            if name != 'arbiter':
                self.assertEqual(1, len(data), "The dict must have 1 key/value!")

        print("Testing get_external_commands")
        for name, port in satellite_map.items():
            raw_data = req.get("%s://localhost:%s/get_external_commands" % (http, port), verify=False)
            data = raw_data.json()
            self.assertIsInstance(data, list, "Data is not a list!")

        print("Testing get_log_level")
        for name, port in satellite_map.items():
            raw_data = req.get("%s://localhost:%s/get_log_level" % (http, port), verify=False)
            data = raw_data.json()
            self.assertIsInstance(data, unicode, "Data is not an unicode!")
            # TODO: the level returned does not seem to match the one defined in the *d.ini files

        print("Testing get_all_states")
        raw_data = req.get("%s://localhost:%s/get_all_states" % (http, satellite_map['arbiter']), verify=False)
        data = raw_data.json()
        self.assertIsInstance(data, dict, "Data is not a dict!")
        for daemon_type in data:
            daemons = data[daemon_type]
            print("Got Alignak state for: %ss / %d instances" % (daemon_type, len(daemons)))
            for daemon in daemons:
                print(" - %s: %s" % (daemon['%s_name' % daemon_type], daemon['alive']))
                self.assertTrue(daemon['alive'])
                self.assertFalse('realm' in daemon)
                self.assertTrue('realm_name' in daemon)

        print("Testing get_running_id")
        for name, port in satellite_map.items():
            raw_data = req.get("%s://localhost:%s/get_running_id" % (http, port), verify=False)
            data = raw_data.json()
            self.assertIsInstance(data, unicode, "Data is not an unicode!")

        print("Testing fill_initial_broks")
        raw_data = req.get("%s://localhost:%s/fill_initial_broks" % (http, satellite_map['scheduler']), params={'bname': 'broker-master'}, verify=False)
        data = raw_data.json()
        self.assertIsNone(data, "Data must be None!")

        print("Testing get_broks")
        for name in ['scheduler', 'poller']:
            raw_data = req.get("%s://localhost:%s/get_broks" % (http, satellite_map[name]),
                               params={'bname': 'broker-master'}, verify=False)
            data = raw_data.json()
            self.assertIsInstance(data, dict, "Data is not a dict!")

        print("Testing get_returns")
        # get_return requested by scheduler to poller daemons
        for name in ['reactionner', 'receiver', 'poller']:
            raw_data = req.get("%s://localhost:%s/get_returns" % (http, satellite_map[name]), params={'sched_id': 0}, verify=False)
            data = raw_data.json()
            self.assertIsInstance(data, list, "Data is not a list!")

        print("Testing signals")
        for name, proc in self.procs.items():
            # SIGUSR1: memory dump
            self.procs[name].send_signal(signal.SIGUSR1)
            time.sleep(0.5)
            # SIGUSR2: objects dump
            self.procs[name].send_signal(signal.SIGUSR2)
            # SIGHUP: reload configuration
            self.procs[name].send_signal(signal.SIGHUP)

            # Any other signal is considered a request to stop...

        for name, proc in self.procs.items():
            print("Asking %s to end..." % name)
            os.kill(self.procs[name].pid, signal.SIGTERM)

        time.sleep(1)

        for name, proc in self.procs.items():
            data = self._get_subproc_data(name)
            print("%s stdout:" % (name))
            for line in iter(proc.stdout.readline, b''):
                print(">>> " + line.rstrip())
            print("%s stderr:" % (name))
            for line in iter(proc.stderr.readline, b''):
                print(">>> " + line.rstrip())

        print("Done testing")

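The cleanup loop above stats each pid/log file twice, once in os.path.exists and once in os.remove. A hedged variant that performs a single remove and simply tolerates the file not being there:

import errno
import os

def remove_if_exists(path):
    # same effect as the exists()+remove() pair in the test above
    try:
        os.remove(path)
    except OSError as e:
        if e.errno != errno.ENOENT:
            raise
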
Example 37

View license
def main():


	if not os.path.exists('data/loc_single'):
		os.makedirs('data/loc_single')

 
	dbData = open('data/jazzPeople.nt', 'r')
	personNames = {}
	personBirthDates = {}
	personDeathDates = {}
	nameCollisons = {}
	matchesBothDate = []
	matchesBothDateURIs = []	
	matchesSingleDate = []
	
	foundCheckList = []
	possibleLOC={}
	allLOC = {}
	
	for line in dbData:
		
		quad = line.split()
		if quad[1] == '<http://xmlns.com/foaf/0.1/name>':
			name = ''
			name = " ".join(quad[2:])
			name = name[1:name[1:].find('@en')]	
			
			if len(name) < 5:
				print name, line
			
			name = name.replace('\\','')
			
			if personNames.has_key(name) == False:
			
				personNames[name] = quad[0]
				
			else:
			
				if personNames[name] != quad[0]:
					if nameCollisons.has_key(name):
					
						if quad[0] not in nameCollisons[name]:
							nameCollisons[name].append(quad[0])
					else:
						nameCollisons[name]=[quad[0]]			
						
					print "Name collision", name, line		
					print personNames[name], nameCollisons[name]
				
			addNames = []
			
			if name.find('"') != -1:
				 
				print name
				name = name.split('"')[0].strip() + ' ' + name.split('"')[2].strip()
				addNames.append(name)
			
			#we also want to pull their name from the URL, because that is often the most common variant of their name
			uri = quad[0]
			
			name = formatName(quad[0].split('/resource/')[len(quad[0].split('/resource/'))-1])			
			
			name = name.replace('\\','')
			
			addNames.append(name)
			#print name
			
			
			

			
			#remove any nick name and add that as well
			if name.find('"') != -1:

				
							
				print name
				 
				name = name.split('"')[0].strip() + ' ' + name.split('"')[2].strip()
				addNames.append(name)
			
	
			
			
			for aName in addNames:
				#is this name already in the lookup:
				print aName
				if personNames.has_key(aName):
					
					
					print "\t Name already in personNames"
					
					#yes, is it the same uir as this one?
					if personNames[aName] != quad[0]:
						
						print "\t Name has a different URI attached"
						
						#no, it is a new UIR, is it aleady in the collision lookup?
						if nameCollisons.has_key(aName):
							
							print "\t Name already in collision"
							
							#yes, is this URI already in it?
							if quad[0] not in nameCollisons[aName]:
								
								print "\t Different name, adding to it"
								#no, add it
								nameCollisons[aName].append(quad[0])
						
						else:
						
							#no, add a new array to the collison with it
							nameCollisons[aName] = [quad[0]]
							print "\t Creating new collision record"
				
				else:
					print "\t not yet in personNames, adding it"
					personNames[aName] = quad[0]
			

		
		if quad[1] == '<http://dbpedia.org/ontology/deathDate>':
			deathDate = ''
			deathDate = " ".join(quad[2:])
			deathDate = deathDate[1:deathDate[1:].find('-')+1]				
			
			if len(deathDate) != 4:
				print "Error death date: ", line
			else:
				personDeathDates[quad[0]] = deathDate
			
			#print deathDate
			
		if quad[1] == '<http://dbpedia.org/ontology/birthDate>':
			birthDate = ''
			birthDate = " ".join(quad[2:])
			birthDate = birthDate[1:birthDate[1:].find('-')+1]				
			if len(birthDate) != 4:
				print "Error birth date: ", line
			else:
				personBirthDates[quad[0]] = birthDate
			
 


 
	print len(personNames), len(personBirthDates), len(personDeathDates)
 
	
 
	temp = open("db_tmp.txt","w")
 	for key, value in personNames.iteritems():

			 
			
		line = key + ' ' + value
		
		if personBirthDates.has_key(value):
			line = line + ' ' + personBirthDates[value]
		if personDeathDates.has_key(value):
			line = line + ' ' + personDeathDates[value]				
		
		temp.writelines(line + "\n")
 
 
	
	for key, value in nameCollisons.iteritems():
		
 		
		for x in value:
		
			line = key + ' ' + x
		
			if personBirthDates.has_key(x):
				line = line + ' ' + personBirthDates[x]
			if personDeathDates.has_key(x):
				line = line + ' ' + personDeathDates[x]				
		
			temp.writelines(line + "\n")			
			print line
			

  
	locFile = open('data/personauthoritiesnames.nt.skos', 'r')
	

	
	counter = 0
	counterMatched = 0
	
 	print "building name list"
	locDebug = open("loc_tmp.txt","w")
	for line in locFile:
		
		counter = counter+1
		
		
		#if counter % 100000 == 0:
		#	print "processed " +  str(counter / 100000)  + "00k names"
		
		if counter % 1000000 == 0:
			print "processed", counter / 1000000, "million names!"
 
			
		quad = line.split();
		name = " ".join(quad[2:])
		name = name[1:name[1:].find('@EN')]			
		
		name = name.replace('?','')
		
		year = re.findall(r'\d{4}', name)
	
		born = 0
		died = 0
		possibleNames = []

		if len(year) != 0:

			if len(year) == 1 and name[len(name)-1:] != '-':

				if name.find(' b.') != -1:
					born = year[0]
					#print "Born : ",year[0]
				elif name.find(' d.') != -1:
					died = year[0]
					#print "died : ",year[0]
				elif name.find(' fl.') != -1:
					born = year[0]
					#print "born(flourished) : ",year[0]
				elif name.find('jin shi') != -1:
					born = year[0]
					#print "born(third stage) : ",year[0]
				elif name.find('ju ren') != -1:
					born = year[0]
					#print "born(second stage) : ",year[0]
				elif len(re.findall(r'\d{3}\-', name)) != 0:

					year = re.findall(r'\d{3}\-', name)
					born = year[0][0:3]
					#print "born : ", year[0][0:3]
					#now get the death year
					died = re.findall(r'\d{4}', name)[0]

				elif len(re.findall(r'\-\d{4}', name)) != 0:
					died = re.findall(r'\-\d{4}', name)[0][1:]

				elif name.find(' ca. ') != -1 or name.find(' ca ') != -1:
					born = year[0]
					#print "born(ca) : ",year[0]
				elif name.find(' b ') != -1:
					born = year[0]
					#print "Born : ",year[0]
				elif name.find(' d ') != -1:
					died = year[0]
					#print "died : ",year[0]
				elif name.find(' born ') != -1:
					born = year[0]
					#print "Born : ",year[0]
				elif name.find(' died ') != -1:
					died = year[0]
					#print "died : ",year[0]
				else:
					#print name, "\n"
					#print "error: cannot figure out this date, update the regex"
					#this covers roughly 90% of the cases; the rest is just straight up weird stuff, so just grab the date
					born = year[0]

				#print len(year)

			elif len(year) == 1 and name[len(name)-1:] == '-':
				born = year[0]

			elif len(year) == 2:
				born = year[0]
				died = year[1]
			elif len(year) == 3:
				#the source has "1999 or 2000 - ..."; take the first and last years
				born = year[0]
				died = year[2]
			elif len(year) == 4:
				#the source has something like "1999 or 2000 - 2001 or 2002"; take the first and last years
				born = year[0]
				died = year[3]

			else:
				print name, "Could not process date \n"
				sys.exit()

			#print name, born, died

		#else:

			#these people would have lived between 1 BCE and 999 AD; we currently do not care about them.
			#if len(re.findall(r'\d{3}', name)) != 0:
				#print name

		#personDates[quad[0]] = [born,died]

		#now process the name part

		#chop off everything from the first digit onwards, to get rid of any date
		if re.search(r'\d{1}',name) is not None:
			name = name[0:name.find(re.search(r'\d{1}',name).group())]
			name = name.strip()

		#now chop off anything past the second comma; it is not name material afterwards. Names with 3 commas are mostly "sir", "duke of earl" etc, which we do not care about
		if len(re.findall(',', name)) == 2 or len(re.findall(',', name)) == 3:
			name = name.split(',')[0] + ', ' + name.split(',')[1]
			#print name, '|', newname

		if name.find('\"') != -1:
			name = name.replace("\\",'')

		if len(re.findall(',', name)) == 1:
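			#e.g. 'Parker, Charlie' -> 'Charlie Parker' (the branch below flips "Last, First" into "First Last")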
		
			if name.find('(') == -1:
				#there is no parenthetical name

				newname = name.split(',')
				newname = newname[1] + ' ' + newname[0]
				#print name, '|', newname
				possibleNames.append(newname.strip())

				#we want to add that name, but also a version without the middle initial, if one is present
				if len(newname.split()) == 3 and (newname.split()[1][len(newname.split()[1])-1] == '.' or len(newname.split()[1]) == 1):
					newname = newname.split()[0] + ' ' + newname.split()[2]
					#print "\t" + newname
					possibleNames.append(newname.strip())

				#we also want to add a version that drops the first initial, for names with only an initial for the first name and a full middle name
				if len(newname.split()) == 3 and len(newname.split()[1]) > 2 and (newname.split()[0][len(newname.split()[0])-1] == '.' or len(newname.split()[0]) == 1):
					newname = newname.split()[1] + ' ' + newname.split()[2]
					#print "\t" + newname
					possibleNames.append(newname.strip())

			else:

				#they have parentheses in their name, meaning the long form of the name is contained in the parentheses
				newname = name.split(',')
				newname = newname[1] + ' ' + newname[0]

				#cut out the stuff before the paren
				newname = newname[newname.find('(')+1:]
				newname = newname.replace(')','')
				#print name, '|', newname
				possibleNames.append(newname.strip())

				#now also cut out the middle initial if it is there and add that version
				if len(newname.split()) == 3 and (newname.split()[1][len(newname.split()[1])-1] == '.' or len(newname.split()[1]) == 1):
					newname = newname.split()[0] + ' ' + newname.split()[2]
					#print "\t" + newname
					possibleNames.append(newname.strip())

		else:

			#so here we are... the depths of the quirks
			if name.find('(') != -1:

				#if the very first token is an initial, it is likely an abbreviated name and the full name is in the parens
				if len(name.split()[0])==2:
					if name.split()[0][1] == '.':
						newname = name.split('(')[1]
						newname = newname.replace(')','')
						possibleNames.append(newname.strip())

						#print name, '|', newname

				#if len(name.split()[len(name.split())-1])==2:
				#	if name.split()[len(name.split())-1][1] == '.':
				#		print name, '|'

				else:
					#this will be stuff like P-King (Musician) or Shyne (Rapper) -- nicknames we are interested in, so cut out the descriptor
					newname = name.split('(')[0].strip()

					#TODO: if we really care to take this further, here is a spot where we will lose some names;
					#the quirks get very specific and would need a lot more rules

					#print name, '|', newname
					possibleNames.append(newname.strip())

			else:
				#print name, '|'
				newname = name.strip()

				#single names here, add them in
				possibleNames.append(newname.strip())
					
	
		#print possibleNames

		#skip logic: ignore anyone born before 1875
		if int(born) != 0 and int(born) < 1875:
			continue

		for aPossible in possibleNames:
			if personNames.has_key(aPossible):

				#we have a match (!)

				#add all the URIs we are going to check into a list
				useURIs = []

				#the main one
				useURIs.append(personNames[aPossible])

				#check for collision names -- names that are the same but reflect different URIs
				if nameCollisons.has_key(aPossible):
					for collison in nameCollisons[aPossible]:
						useURIs.append(collison)

				for useURI in useURIs:

					locDebug.write(aPossible + ' ' + str(born) + ' ' + str(died) + "\n")

					if allLOC.has_key(aPossible):
						#it is in here already, see if it has this URI
						if quad[0] not in allLOC[aPossible]:
							allLOC[aPossible].append(quad[0])

					else:
						allLOC[aPossible] = [quad[0]]

					didMatched = False
				
					
					if personBirthDates.has_key(useURI) and personDeathDates.has_key(useURI):

						if int(born) != 0 and int(died) != 0 and int(personBirthDates[useURI]) != 0 and int(personDeathDates[useURI]) != 0:

							if (int(personBirthDates[useURI]) == int(born)) and (int(died) == int(personDeathDates[useURI])):

								if [useURI, quad[0]] not in matchesBothDate:

									didMatched = True
									counterMatched = counterMatched + 1
									matchesBothDate.append([useURI, quad[0]])
									foundCheckList.append(useURI)

									matchesBothDateURIs.append(useURI)

									#print aPossible, quad[0], born, died
									#print aPossible, useURI, personBirthDates[useURI], personDeathDates[useURI]

									continue
					
					
					


					#see if the birth years match
					if personBirthDates.has_key(useURI):
						if int(personBirthDates[useURI]) == int(born) and int(personBirthDates[useURI]) != 0 and int(born) != 0:

							if [useURI, quad[0]] not in matchesSingleDate:
								#print personNames[aPossible], '=', quad[0]
								didMatched = True
								counterMatched = counterMatched + 1
								matchesSingleDate.append([useURI, quad[0]])
								foundCheckList.append(useURI)

								#print aPossible, quad[0], born, "born match"
								#print aPossible, useURI, personBirthDates[useURI]

								continue

					#does it have a death date match?
					if personDeathDates.has_key(useURI):
						if int(personDeathDates[useURI]) == int(died) and int(personDeathDates[useURI]) != 0 and int(died) != 0:

							if [useURI, quad[0]] not in matchesSingleDate:
								#print personNames[aPossible], '=', quad[0]
								matchesSingleDate.append([useURI, quad[0]])
								didMatched = True
								counterMatched = counterMatched + 1
								foundCheckList.append(useURI)

								#print aPossible, quad[0], died, "death match"
								#print aPossible, useURI, personDeathDates[useURI]

								continue
						
								 
 
 
	locFile.close()
	locDebug.close()

	#we are now going to remove any matches from matchesSingleDate where there is a perfect date match already
	temp = []

	for aSingleDateMatch in matchesSingleDate:

		if aSingleDateMatch[0] not in matchesBothDateURIs:
			temp.append(aSingleDateMatch)
		else:

			for x in matchesBothDate:
				if x[0] == aSingleDateMatch[0]:
					print "Attempted Dupe", aSingleDateMatch
					print "With", x

	matchesSingleDate = list(temp)
	
	
	matchedSingle = []
	matchedMany = []
	matchedNone = []

	for key, value in personNames.iteritems():

		if value not in foundCheckList:
			#print "Not matched " + value  + ' ' +  key

			if allLOC.has_key(key):

				if len(allLOC[key]) == 1:
					#print "\tOnly one possible LOC match:" + allLOC[key][0]
					matchedSingle.append([value,allLOC[key][0]])
				else:
					#print "\t 1+ possible LOC match:", allLOC[key]
					matchedMany.append([value,allLOC[key]])

			else:
				matchedNone.append(value)

	print "\n****Collisions****\n"

	for key, value in nameCollisons.iteritems():

		for x in value:

			if x not in foundCheckList:
				#print "Not matched " + x  + ' ' +  key

				if allLOC.has_key(key):

					if len(allLOC[key]) == 1:
						#print "\tOnly one possible LOC match:" + allLOC[key][0]
						matchedSingle.append([x,allLOC[key][0]])
					else:
						#print "\t 1+ possible LOC match:", allLOC[key]
						matchedMany.append([x,allLOC[key]])

				else:
					matchedNone.append(x)
	
	#for key, value in possibleLOC.iteritems():

		#if len(value) == 1:
		
			#if value not in matches:
			#	matches.append(value)
		
			#print key, '=', value
		 
	
	
	#make sure there are no duplicates, i.e. the same DB-to-LOC pair appearing twice in the singles
	tempCopy = []

	for aSingle in matchedSingle:

		add = True

		for anotherSingle in tempCopy:

			if aSingle[0] == anotherSingle[0] and aSingle[1] == anotherSingle[1]:
				add = False

		if add:
			tempCopy.append(aSingle)

	matchedSingle = list(tempCopy)
	
	#now we are going to go through the singles and pull out anyone that has been added twice;
	#this can happen for common names born in the same year -- move them to the 1->many list
	matchedSingleCheck = []
	matchedSingleDupes = []
	for aSingle in matchedSingle:

		if aSingle[0] not in matchedSingleCheck:
			matchedSingleCheck.append(aSingle[0])
		else:

			print "Dupe in singles found:", aSingle
			matchedSingleDupes.append(aSingle[0])

	singleDupes = {}
	tempCopy = []
	print len(matchedSingle)
	for aSingle in matchedSingle:

		if aSingle[0] in matchedSingleDupes:

			if singleDupes.has_key(aSingle[0]):
				singleDupes[aSingle[0]].append(aSingle[1])
			else:
				singleDupes[aSingle[0]] = [aSingle[1]]

		else:
			tempCopy.append(aSingle)

	matchedSingle = list(tempCopy)

	print len(matchedSingle)
	print singleDupes

	#add them to the matchedMany list
	for key, value in singleDupes.iteritems():
		matchedMany.append([key,value])
	


	#we now need to do the same for matchesSingleDate: a record may have matched on a single date,
	#but that same date could also have matched other people
	matchesSingleDateCheck = []
	matchesSingleDateDupes = []
	for aSingle in matchesSingleDate:

		if aSingle[0] not in matchesSingleDateCheck:
			matchesSingleDateCheck.append(aSingle[0])
		else:

			print "Dupe in single date found:", aSingle
			matchesSingleDateDupes.append(aSingle[0])

	singleDateDupes = {}
	tempCopy = []
	print len(matchesSingleDate)
	for aSingle in matchesSingleDate:

		if aSingle[0] in matchesSingleDateDupes:

			if singleDateDupes.has_key(aSingle[0]):
				singleDateDupes[aSingle[0]].append(aSingle[1])
			else:
				singleDateDupes[aSingle[0]] = [aSingle[1]]

		else:
			tempCopy.append(aSingle)

	matchesSingleDate = list(tempCopy)
	print len(matchesSingleDate)

	#add them to the matchedMany list
	for key, value in singleDateDupes.iteritems():
		matchedMany.append([key,value])

	print singleDateDupes
 
	#TODO: this part needs to be fixed so the call to the LOC site is synchronous and waits for the file to be ready...

	matchedSingleJazz = []
	matchedSingleNoJazz = []

	#LOC ids that have already been assigned a jazz match, used to catch dupes
	matchedSingleJazzLOC = []

	for x in matchedSingle:

		url = x[1]
		id = formatName(url.split('/names/')[-1])
		foundJazz = False

		if not os.path.exists('data/loc_single/' + id + '.nt'):
			os.system('wget --output-document="data/loc_single/' + id + '.nt" "http://id.loc.gov/authorities/names/' + id + '.nt"')

			#sleep as a stopgap until the TODO above is fixed
			time.sleep(1.5)

		if os.path.exists('data/loc_single/' + id + '.nt'):

			f = open('data/loc_single/' + id + '.nt', 'r')

			for line in f:
				line = line.lower()
				if line.find('jazz') != -1 or line.find('music') != -1 or line.find('blues') != -1 or line.find('vocal') != -1:
					print line
					foundJazz = True

			f.close()

		else:
			print 'data/loc_single/' + id + '.nt does not exist'

		if id in matchedSingleJazzLOC:
			foundJazz = False
			print "Dupe detected trying to assign", x

		if foundJazz:
			matchedSingleJazz.append(x)
			matchedSingleJazzLOC.append(id)

		else:
			matchedSingleNoJazz.append(x)
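	#a possible shape for the synchronous fetch mentioned in the TODO above, using the
	#stdlib urllib2 instead of shelling out to wget -- a sketch, not part of the original run:
	#
	#	import urllib2
	#	def fetch_loc_record(loc_id):
	#		path = 'data/loc_single/' + loc_id + '.nt'
	#		if not os.path.exists(path):
	#			resp = urllib2.urlopen('http://id.loc.gov/authorities/names/' + loc_id + '.nt')
	#			out = open(path, 'w')
	#			out.write(resp.read())
	#			out.close()
	#		return path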
			
	

	
	
	
	
	print len(matchesBothDate), "BothDate Matches,", len(matchesSingleDate), "Single Date Matches,", len(matchedSingle), "Single LOC,", len(matchedMany), "Multiple LOC matches,", len(matchedNone), "No Matches"
	#print len(matches) + len(matchedSingle) + len(matchedMany), " matched out of a total of about ", len(personNames)

	print len(matchedSingle), "=", len(matchedSingleJazz), "keyword found and", len(matchedSingleNoJazz), "no keyword found"

	
	
	#make the sameAs files

	allLines = []

	temp = open("data/sameAs_perfect.nt","w")
	for value in matchesBothDate:

		line = value[0] + ' <http://www.w3.org/2002/07/owl#sameAs> ' + value[1] + " . \n"
		if line not in allLines:
			temp.write(line)
			allLines.append(line)
	temp.close()

	temp = open("data/sameAs_high.nt","w")
	for value in matchesSingleDate:

		line = value[0] + ' <http://www.w3.org/2002/07/owl#sameAs> ' + value[1] + " . \n"
		if line not in allLines:
			temp.write(line)
			allLines.append(line)
	temp.close()

	temp = open("data/sameAs_medium.nt","w")
	for value in matchedSingleJazz:

		line = value[0] + ' <http://www.w3.org/2002/07/owl#sameAs> ' + value[1] + " . \n"
		if line not in allLines:
			temp.write(line)
			allLines.append(line)
	temp.close()

	temp = open("data/sameAs_low.nt","w")
	for value in matchedSingleNoJazz:

		line = value[0] + ' <http://www.w3.org/2002/07/owl#sameAs> ' + value[1] + " . \n"
		if line not in allLines:
			temp.write(line)
			allLines.append(line)
	temp.close()

	temp = open("data/sameAs_many.nt","w")
	for value in matchedMany:

		for x in value[1]:
			temp.write(value[0] + ' <http://www.w3.org/2004/02/skos/core#closeMatch> ' + x + " . \n")
	temp.close()

	temp = open("data/sameAs_none.nt","w")
	for value in matchedNone:
		temp.write(value + ' <http://www.w3.org/2002/07/owl#sameAs> ' + '<none>' + " . \n")
	temp.close()
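
A caveat on the os.path.exists checks in the example above: testing for the cached .nt file and then opening it is a check-then-use sequence, which can race if anything else creates or deletes files under data/loc_single/ at the same time. An EAFP variant closes that window; a minimal sketch in the same Python 2 style (hypothetical helper, not from the project):

import errno

def read_if_exists(path):
    # try the open and handle the failure, instead of testing
    # os.path.exists() first and racing against deletion
    try:
        f = open(path, 'r')
    except IOError, e:
        if e.errno == errno.ENOENT:
            return None
        raise
    try:
        return f.read()
    finally:
        f.close()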

Example 38

Project: kay
Source File: media_compiler.py
View license
def compile_js_(tag_name, js_config, force):
  if IS_APPSERVER:
    return

  def needs_update(media_info):
    if js_config['tool'] != 'goog_calcdeps':
      # update if target file does not exist
      target_path = make_output_path_(js_config, js_config['subdir'],
                                      js_config['output_filename'])
      if not os.path.exists(target_path):
        return True

    # update if it lacks required info in _media.yaml
    last_info = media_info.get(js_config['subdir'], tag_name)
    if not last_info:
      return True
    last_config = last_info.get('config')
    if not last_config:
      return True

    # update if any configuration setting is changed
    if not equal_object_(last_config, js_config):
      return True

    if 'related_files' not in last_info:
      return True
    for path, mtime in last_info['related_files']:
      if mtime != os.path.getmtime(path):
        return True
      
  def jsminify(js_path):
    from StringIO import StringIO
    from kay.ext.media_compressor.jsmin import JavascriptMinify
    ifile = open(js_path)
    outs = StringIO()
    JavascriptMinify().minify(ifile, outs)
    ret = outs.getvalue()
    if len(ret) > 0 and ret[0] == '\n':
      ret = ret[1:]
    return ret

  def concat(js_path):
    print_status(" concat %s" % js_path)
    ifile = open(js_path)
    js = ifile.read()
    ifile.close()
    return js

  def goog_calcdeps():
    deps_config = copy.deepcopy(js_config['goog_common'])
    deps_config.update(js_config['goog_calcdeps'])

    if deps_config.get('method') not in \
          ['separate', 'concat', 'concat_refs', 'compile']:
      print_status("COMPILE_MEDIA_JS['goog_calcdeps']['method'] setting is"
                   " invalid; unknown method `%s'" % deps_config.get('method'))
      sys.exit(1)

    output_urls = []
    if deps_config['method'] == 'separate':
      source_files, output_urls = goog_calcdeps_separate(deps_config)
    elif deps_config['method'] == 'concat':
      source_files, output_urls = goog_calcdeps_concat(deps_config)
    elif deps_config['method'] == 'concat_refs':
      source_files, output_urls = goog_calcdeps_concat_refs(deps_config)
    elif deps_config['method'] == 'compile':
      source_files, output_urls = goog_calcdeps_compile(deps_config)
      source_files = [file[0] for file in source_files]

    related_files = union_list(source_files, 
                               [make_input_path_(path)
                                  for path in js_config['source_files']])
    related_file_info = [(path, os.path.getmtime(path))
                           for path in related_files]
    
    # create yaml info
    last_info = {'config': copy.deepcopy(js_config),
                 'related_files': related_file_info,
                 'result_urls': output_urls}
    media_info.set(js_config['subdir'], tag_name, last_info)
    media_info.save()

  def goog_calcdeps_separate(deps_config):
    source_files = goog_calcdeps_list(deps_config)
    (output_urls, extern_urls) = goog_calcdeps_copy_files(deps_config,
                                                          source_files)
    return (source_files, extern_urls + output_urls)

  def goog_calcdeps_concat(deps_config):
    source_files = goog_calcdeps_list(deps_config)
    (output_urls, extern_urls) = goog_calcdeps_concat_files(deps_config,
                                                            source_files)
    return (source_files, extern_urls + output_urls)

  def goog_calcdeps_concat_refs(deps_config):
    source_files = goog_calcdeps_list(deps_config)
    original_files = [make_input_path_(path)
                      for path in js_config['source_files']]
    ref_files = [path for path in source_files if path not in original_files]
    (output_urls, extern_urls) = goog_calcdeps_concat_files(deps_config,
                                                            ref_files)
    original_urls = [path[len(kay.PROJECT_DIR):] for path in original_files]
    return (source_files, extern_urls + output_urls + original_urls)

  def goog_calcdeps_compile(deps_config):
    comp_config = copy.deepcopy(js_config['goog_common'])
    comp_config.update(js_config['goog_compiler'])

    source_files = []
    extern_urls = []

    command = '%s -o compiled -c "%s" ' % (deps_config['path'],
                                                 comp_config['path'])
    for path in deps_config.get('search_paths', []):
      command += '-p %s ' % make_input_path_(path)
    for path in js_config['source_files']:
      path = make_input_path_(path)
      command += '-i %s ' % path
      source_files.append((path, os.path.getmtime(path)))

    if comp_config['level'] == 'minify':
      level = 'WHITESPACE_ONLY'
    elif comp_config['level'] == 'advanced':
      level = 'ADVANCED_OPTIMIZATIONS'
    else:
      level = 'SIMPLE_OPTIMIZATIONS'
    flags = '--compilation_level=%s' % level
#    for path in comp_config.get('externs', []):
#      flags += '--externs=%s ' % make_input_path_(path)
#    if comp_config.get('externs'):
#      flags += ' --externs=%s ' % " ".join(comp_config['externs'])
    command += '-f "%s" ' % flags
    print_status(command)
    command_output = os.popen(command).read()

    output_path = make_output_path_(js_config, js_config['subdir'],
                                    js_config['output_filename'])
    ofile = create_file_(output_path)
    try:
      for path in comp_config.get('externs', []):
        if re.match(r'^https?://', path):
          extern_urls.append(path)
          continue
        path = make_input_path_(path)
        ifile = open(path)
        try:
          ofile.write(ifile.read())
        finally:
          ifile.close()
        source_files.append((path, os.path.getmtime(path)))
      ofile.write(command_output)
    finally:
      ofile.close()
    return (source_files, extern_urls + [output_path[len(kay.PROJECT_DIR):]])

  def goog_calcdeps_list(deps_config):
    source_files = []

    command = '%s -o list ' % deps_config['path']
    for path in deps_config['search_paths']:
      command += '-p %s ' % make_input_path_(path)
    for path in js_config['source_files']:
      command += '-i %s ' % make_input_path_(path)
    print_status(command)
    command_output = os.popen(command).read()
    for path in command_output.split("\n"):
      if path == '': continue
      source_files.append(path)
    return source_files

  def goog_calcdeps_copy_files(deps_config, source_files):
    extern_urls = []
    output_urls = []

    output_dir_base = make_output_path_(js_config, 'separated_js')

    if not os.path.exists(output_dir_base):
      os.makedirs(output_dir_base)
    if not deps_config.get('use_dependency_file', True):
      output_path = os.path.join(output_dir_base, '__goog_nodeps.js')
      ofile = open(output_path, "w")
      output_urls.append(output_path[len(kay.PROJECT_DIR):])
      try:
        ofile.write('CLOSURE_NO_DEPS = true;')
      finally:
        ofile.close()

    output_dirs = {}
    search_paths = [make_input_path_(path)
                    for path in deps_config['search_paths']]
    for path in search_paths:
      output_dirs[path] = os.path.join(output_dir_base,
                                       md5.new(path).hexdigest())

    all_paths = [make_input_path_(path)
                 for path in deps_config.get('externs', [])]
    all_paths.extend(source_files)
    for path in all_paths:
      if re.match(r'^https?://', path):
        extern_urls.append(path)
        continue

      path = make_input_path_(path)
      output_path = os.path.join(output_dir_base, re.sub('^/', '', path))
      for dir in search_paths:
        if path[0:len(dir)] == dir:
          output_path = os.path.join(output_dirs[dir],
                                     re.sub('^/', '', path[len(dir):]))
          break
      output_dir = os.path.dirname(output_path)

      if not os.path.exists(output_dir):
        os.makedirs(output_dir)
      shutil.copy2(path, output_path)
      output_urls.append(output_path[len(kay.PROJECT_DIR):])
    return (output_urls, extern_urls)
    
  def goog_calcdeps_concat_files(deps_config, source_files):
    extern_urls = []

    output_path = make_output_path_(js_config, js_config['subdir'],
                                    js_config['output_filename'])
    ofile = create_file_(output_path)
    try:
      if not deps_config.get('use_dependency_file', True):
        ofile.write('CLOSURE_NO_DEPS = true;')
      all_paths = [make_input_path_(path)
                   for path in deps_config.get('externs', [])]
      all_paths.extend(source_files)
      for path in all_paths:
        if re.match(r'^https?://', path):
          extern_urls.append(path)
          continue
        ifile = open(make_input_path_(path))
        ofile.write(ifile.read())
        ifile.close()
    finally:
      ofile.close()

    return ([output_path[len(kay.PROJECT_DIR):]], extern_urls)

  selected_tool = js_config['tool']

  if selected_tool not in \
        (None, 'jsminify', 'concat', 'goog_calcdeps', 'goog_compiler'):
    print_status("COMPILE_MEDIA_JS['tool'] setting is invalid;"
                 " unknown tool `%s'" % selected_tool)
    sys.exit(1)

  global media_info
  if media_info is None:
    media_info = MediaInfo.load()

  if not force and not needs_update(media_info):
    print_status(' up to date.')
    return

  if selected_tool == 'goog_calcdeps':
    return goog_calcdeps()

  if selected_tool is None:
    last_info = {'config': copy.deepcopy(js_config),
                 'result_urls': ['/'+f for f in js_config['source_files']]}
    media_info.set(js_config['subdir'], tag_name, last_info)
    media_info.save()
    return

  dest_path = make_output_path_(js_config, js_config['subdir'],
                                js_config['output_filename'])
  ofile = create_file_(dest_path)
  try:
    if selected_tool == 'jsminify':
      for path in js_config['source_files']:
        src_path = make_input_path_(path)
        ofile.write(jsminify(src_path))
    elif selected_tool == 'concat':
      for path in js_config['source_files']:
        src_path = make_input_path_(path)
        ofile.write(concat(src_path))
  finally:
    ofile.close()
  
  if selected_tool == 'goog_compiler':
    comp_config = copy.deepcopy(js_config['goog_common'])
    comp_config.update(js_config['goog_compiler'])
    if comp_config['level'] == 'minify':
      level = 'WHITESPACE_ONLY'
    elif comp_config['level'] == 'advanced':
      level = 'ADVANCED_OPTIMIZATIONS'
    else:
      level = 'SIMPLE_OPTIMIZATIONS'
    command_args = '--compilation_level=%s' % level
    for path in js_config['source_files']:
      command_args += ' --js %s' % make_input_path_(path)
    command_args += ' --js_output_file %s' % dest_path
    command = 'java -jar %s %s' % (comp_config['path'], command_args)
    command_output = os.popen(command).read()

  info = copy.deepcopy(js_config)
  info['output_filename'] = make_output_path_(js_config, js_config['subdir'],
                                              js_config['output_filename'],
                                              relative=True)
  info['result_urls'] = ['/'+info['output_filename']]
  media_info.set(js_config['subdir'], tag_name, info)
  media_info.save()
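
A side note on the needs_update() helper above: for every tool except goog_calcdeps, the first check is os.path.exists on the output path, so a missing target forces a rebuild before any of the _media.yaml bookkeeping is consulted. The exists-then-getmtime staleness test can be distilled into a small standalone helper; the following sketch assumes a build is stale when the target is absent or older than any source (illustrative names, not kay API):

import os

def is_stale(target, sources):
    # rebuild when the target has never been produced...
    if not os.path.exists(target):
        return True
    # ...or when any source file is newer than the target
    target_mtime = os.path.getmtime(target)
    return any(os.path.getmtime(src) > target_mtime for src in sources)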

Example 39

Project: buildtools-BaseTools
Source File: GenFds.py
View license
def main():
    global Options
    Options = myOptionParser()

    global Workspace
    Workspace = ""
    ArchList = None
    ReturnCode = 0

    EdkLogger.Initialize()
    try:
        if Options.verbose != None:
            EdkLogger.SetLevel(EdkLogger.VERBOSE)
            GenFdsGlobalVariable.VerboseMode = True
            
        if Options.FixedAddress != None:
            GenFdsGlobalVariable.FixedLoadAddress = True
            
        if Options.quiet != None:
            EdkLogger.SetLevel(EdkLogger.QUIET)
        if Options.debug != None:
            EdkLogger.SetLevel(Options.debug + 1)
            GenFdsGlobalVariable.DebugLevel = Options.debug
        else:
            EdkLogger.SetLevel(EdkLogger.INFO)

        if (Options.Workspace == None):
            EdkLogger.error("GenFds", OPTION_MISSING, "WORKSPACE not defined",
                            ExtraData="Please use '-w' switch to pass it or set the WORKSPACE environment variable.")
        elif not os.path.exists(Options.Workspace):
            EdkLogger.error("GenFds", PARAMETER_INVALID, "WORKSPACE is invalid",
                            ExtraData="Please use '-w' switch to pass it or set the WORKSPACE environment variable.")
        else:
            Workspace = os.path.normcase(Options.Workspace)
            GenFdsGlobalVariable.WorkSpaceDir = Workspace
            if 'EDK_SOURCE' in os.environ.keys():
                GenFdsGlobalVariable.EdkSourceDir = os.path.normcase(os.environ['EDK_SOURCE'])
            if (Options.debug):
                GenFdsGlobalVariable.VerboseLogger( "Using Workspace:" + Workspace)
        os.chdir(GenFdsGlobalVariable.WorkSpaceDir)

        if (Options.filename):
            FdfFilename = Options.filename
            FdfFilename = GenFdsGlobalVariable.ReplaceWorkspaceMacro(FdfFilename)

            if FdfFilename[0:2] == '..':
                FdfFilename = os.path.realpath(FdfFilename)
            if not os.path.isabs (FdfFilename):
                FdfFilename = os.path.join(GenFdsGlobalVariable.WorkSpaceDir, FdfFilename)
            if not os.path.exists(FdfFilename):
                EdkLogger.error("GenFds", FILE_NOT_FOUND, ExtraData=FdfFilename)
            if os.path.normcase (FdfFilename).find(Workspace) != 0:
                EdkLogger.error("GenFds", FILE_NOT_FOUND, "FdfFile doesn't exist in Workspace!")

            GenFdsGlobalVariable.FdfFile = FdfFilename
            GenFdsGlobalVariable.FdfFileTimeStamp = os.path.getmtime(FdfFilename)
        else:
            EdkLogger.error("GenFds", OPTION_MISSING, "Missing FDF filename")

        if (Options.BuildTarget):
            GenFdsGlobalVariable.TargetName = Options.BuildTarget
        else:
            EdkLogger.error("GenFds", OPTION_MISSING, "Missing build target")

        if (Options.ToolChain):
            GenFdsGlobalVariable.ToolChainTag = Options.ToolChain
        else:
            EdkLogger.error("GenFds", OPTION_MISSING, "Missing tool chain tag")

        if (Options.activePlatform):
            ActivePlatform = Options.activePlatform
            ActivePlatform = GenFdsGlobalVariable.ReplaceWorkspaceMacro(ActivePlatform)

            if ActivePlatform[0:2] == '..':
                ActivePlatform = os.path.realpath(ActivePlatform)

            if not os.path.isabs (ActivePlatform):
                ActivePlatform = os.path.join(GenFdsGlobalVariable.WorkSpaceDir, ActivePlatform)

            if not os.path.exists(ActivePlatform)  :
                EdkLogger.error("GenFds", FILE_NOT_FOUND, "ActivePlatform doesn't exist!")

            if os.path.normcase (ActivePlatform).find(Workspace) != 0:
                EdkLogger.error("GenFds", FILE_NOT_FOUND, "ActivePlatform doesn't exist in Workspace!")

            ActivePlatform = ActivePlatform[len(Workspace):]
            if len(ActivePlatform) > 0 :
                if ActivePlatform[0] == '\\' or ActivePlatform[0] == '/':
                    ActivePlatform = ActivePlatform[1:]
            else:
                EdkLogger.error("GenFds", FILE_NOT_FOUND, "ActivePlatform doesn't exist!")
        else:
            EdkLogger.error("GenFds", OPTION_MISSING, "Missing active platform")

        GenFdsGlobalVariable.ActivePlatform = PathClass(NormPath(ActivePlatform), Workspace)

        BuildConfigurationFile = os.path.normpath(os.path.join(GenFdsGlobalVariable.WorkSpaceDir, "Conf/target.txt"))
        if os.path.isfile(BuildConfigurationFile) == True:
            TargetTxtClassObject.TargetTxtClassObject(BuildConfigurationFile)
        else:
            EdkLogger.error("GenFds", FILE_NOT_FOUND, ExtraData=BuildConfigurationFile)

        if Options.Macros:
            for Pair in Options.Macros:
                Pair.strip('"')
                List = Pair.split('=')
                if len(List) == 2:
                    if List[0].strip() == "EFI_SOURCE":
                        GlobalData.gEfiSource = List[1].strip()
                        GlobalData.gGlobalDefines["EFI_SOURCE"] = GlobalData.gEfiSource
                        continue
                    elif List[0].strip() == "EDK_SOURCE":
                        GlobalData.gEdkSource = List[1].strip()
                        GlobalData.gGlobalDefines["EDK_SOURCE"] = GlobalData.gEdkSource
                        continue
                    elif List[0].strip() in ["WORKSPACE", "TARGET", "TOOLCHAIN"]:
                        GlobalData.gGlobalDefines[List[0].strip()] = List[1].strip()
                    else:
                        GlobalData.gCommandLineDefines[List[0].strip()] = List[1].strip()
                else:
                    GlobalData.gCommandLineDefines[List[0].strip()] = "TRUE"
        os.environ["WORKSPACE"] = Workspace

        """call Workspace build create database"""
        BuildWorkSpace = WorkspaceDatabase(None)
        BuildWorkSpace.InitDatabase()
        
        #
        # Get files real name in workspace dir
        #
        GlobalData.gAllFiles = DirCache(Workspace)
        GlobalData.gWorkspace = Workspace

        if (Options.archList) :
            ArchList = Options.archList.split(',')
        else:
#            EdkLogger.error("GenFds", OPTION_MISSING, "Missing build ARCH")
            ArchList = BuildWorkSpace.BuildObject[GenFdsGlobalVariable.ActivePlatform, 'COMMON', Options.BuildTarget, Options.ToolChain].SupArchList

        TargetArchList = set(BuildWorkSpace.BuildObject[GenFdsGlobalVariable.ActivePlatform, 'COMMON', Options.BuildTarget, Options.ToolChain].SupArchList) & set(ArchList)
        if len(TargetArchList) == 0:
            EdkLogger.error("GenFds", GENFDS_ERROR, "Target ARCH %s not in platform supported ARCH %s" % (str(ArchList), str(BuildWorkSpace.BuildObject[GenFdsGlobalVariable.ActivePlatform, 'COMMON'].SupArchList)))
        
        for Arch in ArchList:
            GenFdsGlobalVariable.OutputDirFromDscDict[Arch] = NormPath(BuildWorkSpace.BuildObject[GenFdsGlobalVariable.ActivePlatform, Arch, Options.BuildTarget, Options.ToolChain].OutputDirectory)
            GenFdsGlobalVariable.PlatformName = BuildWorkSpace.BuildObject[GenFdsGlobalVariable.ActivePlatform, Arch, Options.BuildTarget, Options.ToolChain].PlatformName

        if (Options.outputDir):
            OutputDirFromCommandLine = GenFdsGlobalVariable.ReplaceWorkspaceMacro(Options.outputDir)
            if not os.path.isabs (OutputDirFromCommandLine):
                OutputDirFromCommandLine = os.path.join(GenFdsGlobalVariable.WorkSpaceDir, OutputDirFromCommandLine)
            for Arch in ArchList:
                GenFdsGlobalVariable.OutputDirDict[Arch] = OutputDirFromCommandLine
        else:
            for Arch in ArchList:
                GenFdsGlobalVariable.OutputDirDict[Arch] = os.path.join(GenFdsGlobalVariable.OutputDirFromDscDict[Arch], GenFdsGlobalVariable.TargetName + '_' + GenFdsGlobalVariable.ToolChainTag)

        for Key in GenFdsGlobalVariable.OutputDirDict:
            OutputDir = GenFdsGlobalVariable.OutputDirDict[Key]
            if OutputDir[0:2] == '..':
                OutputDir = os.path.realpath(OutputDir)

            if OutputDir[1] != ':':
                OutputDir = os.path.join (GenFdsGlobalVariable.WorkSpaceDir, OutputDir)

            if not os.path.exists(OutputDir):
                EdkLogger.error("GenFds", FILE_NOT_FOUND, ExtraData=OutputDir)
            GenFdsGlobalVariable.OutputDirDict[Key] = OutputDir

        """ Parse Fdf file, has to place after build Workspace as FDF may contain macros from DSC file """
        FdfParserObj = FdfParser.FdfParser(FdfFilename)
        FdfParserObj.ParseFile()

        if FdfParserObj.CycleReferenceCheck():
            EdkLogger.error("GenFds", FORMAT_NOT_SUPPORTED, "Cycle Reference Detected in FDF file")

        if (Options.uiFdName) :
            if Options.uiFdName.upper() in FdfParserObj.Profile.FdDict.keys():
                GenFds.OnlyGenerateThisFd = Options.uiFdName
            else:
                EdkLogger.error("GenFds", OPTION_VALUE_INVALID,
                                "No such an FD in FDF file: %s" % Options.uiFdName)

        if (Options.uiFvName) :
            if Options.uiFvName.upper() in FdfParserObj.Profile.FvDict.keys():
                GenFds.OnlyGenerateThisFv = Options.uiFvName
            else:
                EdkLogger.error("GenFds", OPTION_VALUE_INVALID,
                                "No such an FV in FDF file: %s" % Options.uiFvName)

        if (Options.uiCapName) :
            if Options.uiCapName.upper() in FdfParserObj.Profile.CapsuleDict.keys():
                GenFds.OnlyGenerateThisCap = Options.uiCapName
            else:
                EdkLogger.error("GenFds", OPTION_VALUE_INVALID,
                                "No such a Capsule in FDF file: %s" % Options.uiCapName)

        """Modify images from build output if the feature of loading driver at fixed address is on."""
        if GenFdsGlobalVariable.FixedLoadAddress:
            GenFds.PreprocessImage(BuildWorkSpace, GenFdsGlobalVariable.ActivePlatform)
        """Call GenFds"""
        GenFds.GenFd('', FdfParserObj, BuildWorkSpace, ArchList)

        """Generate GUID cross reference file"""
        GenFds.GenerateGuidXRefFile(BuildWorkSpace, ArchList)

        """Display FV space info."""
        GenFds.DisplayFvSpaceInfo(FdfParserObj)

    except FdfParser.Warning, X:
        EdkLogger.error(X.ToolName, FORMAT_INVALID, File=X.FileName, Line=X.LineNumber, ExtraData=X.Message, RaiseError = False)
        ReturnCode = FORMAT_INVALID
    except FatalError, X:
        if Options.debug != None:
            import traceback
            EdkLogger.quiet(traceback.format_exc())
        ReturnCode = X.args[0]
    except:
        import traceback
        EdkLogger.error(
                    "\nPython",
                    CODE_ERROR,
                    "Tools code failure",
                    ExtraData="Please send email to [email protected] for help, attaching following call stack trace!\n",
                    RaiseError=False
                    )
        EdkLogger.quiet(traceback.format_exc())
        ReturnCode = CODE_ERROR
    return ReturnCode
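
The path handling in main() above follows a recurring chain: expand workspace macros, absolutize with os.path.isabs/os.path.join, verify with os.path.exists, and only compare prefixes after os.path.normcase so that case differences on Windows cannot break the containment check. A condensed sketch of that chain (hypothetical error handling in place of EdkLogger):

import os
import sys

def resolve_in_workspace(path, workspace):
    # anchor relative paths against the workspace root
    if not os.path.isabs(path):
        path = os.path.join(workspace, path)
    if not os.path.exists(path):
        sys.exit("no such file: %s" % path)
    # normalize case before the prefix test, as GenFds does
    if os.path.normcase(path).find(os.path.normcase(workspace)) != 0:
        sys.exit("%s is outside the workspace %s" % (path, workspace))
    return path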

Example 40

Project: IDSDeathBlossom
Source File: IDSToolEnv.py
View license
    def run(self):
        pcaplist = []
        pcaplisttmp = []
        ignorelist = ""
        #loopcnt = 0
        self.currentts = time.strftime("%Y-%m-%d-T-%H-%M-%S", time.localtime())

        #Generate a specific run-id for this run in the format of runmode-timestamp
        if self.Runmode.conf.has_key("custom_runid") and self.Runmode.conf["custom_runid"]:
            self.runid = self.Runmode.conf["custom_runid"]
        else:
            self.runid = "%s-%s" %(str(self.Runmode.runmode),time.strftime("%Y-%m-%d-T-%H-%M-%S", time.localtime()))

        for engine in self.targets:
            e = self.EngineMgr.engines[engine]
            e.runid = self.runid
            e.db = self.db
            e.host = self.host

            #Global Override
            if self.Runmode.conf.has_key("glogoverride") and self.Runmode.conf.has_key("globallogdir"):
                e.conf["logdir"] = self.Runmode.conf["globallogdir"]

            #RunID Dir appended
            if self.Runmode.conf.has_key("appendrunid") and self.Runmode.conf["appendrunid"]:
                if e.conf["logdir"]:
                    e.conf["logdir"] = "%s/%s" % (e.conf["logdir"],self.runid)
                    if not os.path.exists(e.conf["logdir"]):
                        try:
                           os.mkdir(e.conf["logdir"])
                        except:
                           p_error("%s: failed to make directory %s\n" % (str(whoami()),e.conf["logdir"]))
                           sys.exit(1)

                self.Runmode.conf["globallogdir"] = "%s/%s" % (self.Runmode.conf["globallogdir"],self.runid)

                #No Reason to try and create again if we merged them
                if e.conf["logdir"] != self.Runmode.conf["globallogdir"] and not os.path.exists(self.Runmode.conf["globallogdir"]):
                    try:
                        os.mkdir(self.Runmode.conf["globallogdir"])
                    except:
                        p_error("%s: failed to make directory %s\n" % (str(whoami()),self.Runmode.conf["globallogdir"]))
                        sys.exit(1)

            #EngineID Dir Appended probably only makes sense for non-global log dir!?!?
            if self.Runmode.conf.has_key("appendengineid") and self.Runmode.conf["appendengineid"]:
                if e.conf["logdir"]:
                    e.conf["logdir"] = "%s/%s" % (e.conf["logdir"],e.conf["engine"])
                    if not os.path.exists(e.conf["logdir"]):
                        try:
                            os.mkdir(e.conf["logdir"])
                        except:
                            p_error("%s: failed to make directory %s\n" % (str(whoami()),e.conf["logdir"]))
                            sys.exit(1)

            # We setup defaults elsewhere TODO: cleaner version of this we end up setting twice.
            e.logfile = "%s/%s" % (e.conf['logdir'], e.conf['fastlog'])
            e.perflogfile = "%s/%s" % (e.conf['logdir'], e.perflog)

        # All the runmodes that don't compare the output of different engines should be in the following list
        # (all that can be executed independently)
        if self.Runmode.runmode in ["sanitize", "verify", "sidperfq"]:
            for engine in self.targets:
                e = self.EngineMgr.engines[engine]

                # And now execute the engine through the runmode
                if self.Runmode.runmode in ["sanitize", "verify"]:
                    e.run(self.Runmode.runmode)
                elif self.Runmode.runmode == "sidperfq":
                    # First check sperfsid
                    if self.Runmode.conf.has_key("sperfsid") and self.Runmode.conf["sperfsid"].isdigit():
                        e.run(self.Runmode.runmode)
                    else:
                        p_error("<%s><%s><%s>: sid provided via --sperfsid %s is invalid or None and/or --perfdb %s option was invalid or not provided" % (str(whoami()),str(lineno()),str(__file__),str(self.Runmode.conf["sperfsid"]),str(self.Runmode.conf["perfdb"])))
                        sys.exit(-19)
            if self.Runmode.runmode == "sanitize":
                self.SummaryHTMLSanitize(self.Runmode.conf['reportgroup'],self.Runmode.conf["globallogdir"],self.Runmode.conf["custom_runid"])

        # Comparison modes here
        elif self.Runmode.runmode == "comparefast":
            if self.Runmode.conf.has_key("cmpropts"):
                self.comparefast(self.Runmode.conf["cmpropts"])
            else:
                p_error("%s: cmpropts is a required argument for the comparefast runmode the options should be passed like --cmpropts=\"file1:mode1,file2:mode2\"")
                sys.exit(1)


        # The looping runmodes should go here
        elif self.Runmode.runmode in ["run","dumbfuzz","xtract","xtractall","rcomparefast"]:
            if self.Signature.conf.has_key("xtractignore") and self.Signature.conf["xtractignore"]:
                self.xignore = self.parse_xtract_ignore()
            else:
                self.xignore = []

            globlist = []

            # hack to get around those of us used to perl globbing.  Globs can be specified as a list
            if self.Pcap.conf.has_key("pcappath") and self.Pcap.conf["pcappath"]:
                globlist = get_glob_list(self.Pcap.conf["pcappath"])
            else:
                p_error("<%s><%s><%s> You must specify suppy a pcap file or a list of pcap files with --pcapppath wildcards are supported\n" % (str(whoami()),str(lineno()),str(__file__)))
                sys.exit(1)

            if self.Pcap.conf.has_key("pcapignore") and self.Pcap.conf["pcapignore"]:
                ignorelist = get_glob_list(self.Pcap.conf["pcapignore"])
            else:
                ignorelist = []

            for pcap in ignorelist:
                if pcap in globlist: globlist.remove(pcap)

            if not globlist:
                p_error("Pcap list empty...bailing")
                sys.exit(1)

            if self.Pcap.conf.has_key("sortpcaps") and self.Pcap.conf["sortpcaps"]:
                if self.Pcap.conf["sortpcaps"] == "size":
                    for pcapfile in globlist:
                        stats = os.stat(pcapfile)
                        pcap_tuple = stats.st_size, pcapfile
                        pcaplisttmp.append(pcap_tuple)
                        pcaplisttmp.sort()
                    for pcap_t in pcaplisttmp:
                        pcaplist.append(pcap_t[1])

                elif self.Pcap.conf["sortpcaps"] == "random":
                    random.shuffle(globlist, random.random)
                    pcaplist = globlist
                    p_debug(str(pcaplist))
                else:
                    pcaplist = globlist

            #The number of times we are going to loop through the tests: if it is a digit
            #we convert it to an int; if it is the string "forever" we leave it as a string --
            #in that case loopcnt will always compare less than the string (a Python 2 quirk)
            if self.Runmode.conf.has_key("loopnum"):
                if self.Runmode.conf["loopnum"].isdigit():
                    self.convloop = int(self.Runmode.conf["loopnum"])
                elif self.Runmode.conf["loopnum"] != None and self.Runmode.conf["loopnum"] == "forever":
                    self.convloop = self.Runmode.conf["loopnum"]
            else:
                p_debug("invalid loopnum... defaulting to 1")
                self.Runmode.conf["loopnum"] = 1
                self.convloop = 1

            p_info("looping %s times in runmode %s" % (str(self.convloop), self.Runmode.runmode))

            if self.Runmode.runmode in ["xtract","xtractall","run","dumbfuzz"]:
                for engine in self.targets:
                    loopcnt = 0
                    e = self.EngineMgr.engines[engine]
                    # Let each engine know the xignore list
                    e.xignore = self.xignore
                    #for loopcnt in range(0, int(self.convloop)):
                    while loopcnt < self.convloop:
                        p_info("run with success %i out of %s" % (loopcnt, str(self.convloop)))
                        for pcap in pcaplist:
                            self.sidd = {}
                            if self.Runmode.runmode == "run":
                                e.run_ids(pcap, "yes")
                            elif self.Runmode.runmode == "xtract":
                                e.run(self.Runmode.runmode, pcap)
                            elif self.Runmode.runmode == "xtractall":
                                e.run(self.Runmode.runmode, pcap)
                            elif self.Runmode.runmode == "dumbfuzz":
                                e.run(self.Runmode.runmode, pcap)
                        loopcnt += 1
            elif self.Runmode.runmode == "rcomparefast":
                if len(self.targets) != 2:
                    p_error("Error, \"%s\" requires 2 (and only 2) target engines. Got %d engines. Use -L to list the engines available. Exiting..." % (self.Runmode.runmode, len(self.targets)))
                    sys.exit(-21)
                # Recursive compare
                for loopcnt in range(0, int(self.convloop)):
                    p_info("run with success %i out of %s" % (loopcnt, str(self.convloop)))
                    for pcap in pcaplist:
                        self.sidd = {}
                        # TODO: Check that it passes only 2 target engines
                        self.rcomparefast(pcap)
            else:
                p_warn("No runmode selected" % self.Runmode.runmode)
        elif self.Runmode.runmode == "reportonly":
            if self.Runmode.conf.has_key("custom_runid") != True:
                p_error("You must specify a custom runid via --custom-runid for reportonly runmode")
                sys.exit(-20)
            elif self.Runmode.conf.has_key("reportonarr") != True:
                p_error("You must specify something to report on for reportonly runmode")
                sys.exit(-20)
            elif not self.Runmode.conf["reportonarr"]:
                p_error("You must specify something to report on for reportonly runmode")
                sys.exit(-20)
        else:
            p_error("Unknown runmode?? %s??" % self.Runmode.runmode)

        #once we are done looping gen perf report if option specified
        if self.Runmode.conf.has_key("reportonarr"):
            if "TopNWorstAll" in self.Runmode.conf["reportonarr"]:
                self.TopNWorstAll()
            if "TopNWorstCurrent" in self.Runmode.conf["reportonarr"]:
                self.TopNWorstCurrent()
            if "TopNWorstCurrentHTML" in self.Runmode.conf["reportonarr"]:
                self.TopNWorstCurrentHTML()
            if "LoadReportCurrent" in self.Runmode.conf["reportonarr"]:
                self.LoadReportCurrent()
            if "LoadReportCurrentHTMLMoloch" in self.Runmode.conf["reportonarr"]:
                self.LoadReportCurrentHTMLMoloch()

        if self.Runmode.conf.has_key("sqlquery") and self.Runmode.conf["sqlquery"] != "":
            self.queryDB(self.Runmode.conf["sqlquery"])
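
run() above guards each os.mkdir with an os.path.exists check and still wraps the call in try/except, because the directory can appear between the check and the mkdir. Attempting creation and tolerating only the "already exists" error expresses the same intent more compactly; a sketch in the same Python 2 style (hypothetical helper, not part of IDSDeathBlossom):

import errno
import os

def ensure_dir(path):
    # succeed whether or not the directory already exists,
    # without a separate os.path.exists() check
    try:
        os.makedirs(path)
    except OSError, e:
        if e.errno != errno.EEXIST:
            raise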

Example 41

Project: tp-libvirt
Source File: virsh_managedsave.py
View license
def run(test, params, env):
    """
    Test command: virsh managedsave.

    This command can save and destroy a
    running domain, so it can be restarted
    from the same state at a later time.
    """

    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)
    managed_save_file = "/var/lib/libvirt/qemu/save/%s.save" % vm_name

    # define function
    def vm_recover_check(option, libvirtd, check_shutdown=False):
        """
        Check if the vm can be recovered correctly.

        :param guest_name : Checked vm's name.
        :param option : managedsave command option.
        """
        # The vm should not be running at this point
        if vm.is_alive():
            raise error.TestFail("Guest should be inactive")
        # Check vm managed save state.
        ret = virsh.dom_list("--managed-save --inactive")
        vm_state1 = re.findall(r".*%s.*" % vm_name,
                               ret.stdout.strip())[0].split()[2]
        ret = virsh.dom_list("--managed-save --all")
        vm_state2 = re.findall(r".*%s.*" % vm_name,
                               ret.stdout.strip())[0].split()[2]
        if vm_state1 != "saved" or vm_state2 != "saved":
            raise error.TestFail("Guest state should be saved")

        virsh.start(vm_name)
        # This time the vm should be active
        if vm.is_dead():
            raise error.TestFail("Guest should be active")
        # Restart libvirtd and check vm status again.
        libvirtd.restart()
        if vm.is_dead():
            raise error.TestFail("Guest should be active after"
                                 " restarting libvirtd")
        # Check managed save file:
        if os.path.exists(managed_save_file):
            raise error.TestFail("Managed save image exist "
                                 "after starting the domain")
        if option:
            if option.count("running"):
                if vm.is_dead() or vm.is_paused():
                    raise error.TestFail("Guest state should be"
                                         " running after started"
                                         " because of '--running' option")
            elif option.count("paused"):
                if not vm.is_paused():
                    raise error.TestFail("Guest state should be"
                                         " paused after started"
                                         " because of '--paused' option")
        else:
            if params.get("paused_after_start_vm") == "yes":
                if not vm.is_paused():
                    raise error.TestFail("Guest state should be"
                                         " paused after started"
                                         " because of initia guest state")
        if check_shutdown:
            # Resume the domain.
            if vm.is_paused():
                vm.resume()
            vm.wait_for_login()
            # Shutdown and start the domain;
            # it should end up in running state and allow login.
            vm.shutdown()
            vm.wait_for_shutdown()
            vm.start()
            vm.wait_for_login()

    def vm_undefine_check(vm_name):
        """
        Check if vm can be undefined with manage-save option
        """
        #backup xml file
        xml_backup = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)
        if not os.path.exists(managed_save_file):
            raise error.TestFail("Can't find managed save image")
        #undefine domain with no options.
        if not virsh.undefine(vm_name, options=None,
                              ignore_status=True).exit_status:
            raise error.TestFail("Guest shouldn't be undefined"
                                 "while domain managed save image exists")
        #undefine domain with managed-save option.
        if virsh.undefine(vm_name, options="--managed-save",
                          ignore_status=True).exit_status:
            raise error.TestFail("Guest can't be undefine with "
                                 "managed-save option")

        if os.path.exists(managed_save_file):
            raise error.TestFail("Managed save image exists"
                                 " after undefining vm")
        #restore and start the vm.
        xml_backup.define()
        vm.start()

    def check_flags_parallel(virsh_cmd, bash_cmd, flags):
        """
        Run the commands parallel and check the output.
        """
        cmd = ("%s & %s" % (virsh_cmd, bash_cmd))
        ret = utils.run(cmd, ignore_status=True)
        output = ret.stdout.strip()
        logging.debug("check flags output: %s" % output)
        lines = re.findall(r"flags:.+%s" % flags, output, re.M)
        logging.debug("Find lines: %s" % lines)
        if not lines:
            raise error.TestFail("Checking flags %s failed" % flags)

        return ret

    def check_multi_guests(guests, start_delay, libvirt_guests):
        """
        Check start_delay option for multiple guests.
        """
        # Destroy vm first
        if vm.is_alive():
            vm.destroy(gracefully=False)
        # Clone given number of guests
        timeout = params.get("clone_timeout", 360)
        for i in range(int(guests)):
            dst_vm = "%s_%s" % (vm_name, i)
            utils_libguestfs.virt_clone_cmd(vm_name, dst_vm,
                                            True, timeout=timeout)
            virsh.start(dst_vm)

        # Wait 10 seconds for vm to start
        time.sleep(10)
        is_systemd = utils.run("cat /proc/1/comm").stdout.count("systemd")
        if is_systemd:
            libvirt_guests.restart()
            pattern = r'(.+ \d\d:\d\d:\d\d).+: Resuming guest.+done'
        else:
            ret = utils.run("service libvirt-guests restart | \
            awk '{ print strftime(\"%b %y %H:%M:%S\"), $0; fflush(); }'")
            pattern = r'(.+ \d\d:\d\d:\d\d)+ Resuming guest.+done'

        # libvirt-guests status command read messages from systemd
        # journal, in cases of messages are not ready in time,
        # add a time wait here.
        def wait_func():
            return libvirt_guests.raw_status().stdout.count("Resuming guest")

        utils_misc.wait_for(wait_func, 5)
        if is_systemd:
            ret = libvirt_guests.raw_status()
        logging.info("status output: %s", ret.stdout)
        resume_time = re.findall(pattern, ret.stdout, re.M)
        if not resume_time:
            raise error.TestFail("Can't see messages of resuming guest")

        # Convert time string to int
        resume_seconds = [time.mktime(time.strptime(
            tm, "%b %y %H:%M:%S")) for tm in resume_time]
        logging.info("Resume time in seconds: %s", resume_seconds)
        # Check if start_delay take effect
        for i in range(len(resume_seconds)-1):
            if resume_seconds[i+1] - resume_seconds[i] < int(start_delay):
                raise error.TestFail("Checking start_delay failed")

    def wait_for_state(vm_state):
        """
        Wait for vm state is ready.
        """
        utils_misc.wait_for(lambda: vm.state() == vm_state, 10)

    def check_guest_flags(bash_cmd, flags):
        """
        Check bypass_cache option for single guest.
        """
        # Drop caches.
        drop_caches()
        # Form the proper parallel command based on whether systemd is used
        is_systemd = utils.run("cat /proc/1/comm").stdout.count("systemd")
        if is_systemd:
            virsh_cmd_stop = "systemctl stop libvirt-guests"
            virsh_cmd_start = "systemctl start libvirt-guests"
        else:
            virsh_cmd_stop = "service libvirt-guests stop"
            virsh_cmd_start = "service libvirt-guests start"

        ret = check_flags_parallel(virsh_cmd_stop, bash_cmd %
                                   (managed_save_file, managed_save_file,
                                    "1", flags), flags)
        if is_systemd:
            ret = libvirt_guests.raw_status()
        logging.info("status output: %s", ret.stdout)
        if all(["Suspending %s" % vm_name not in ret.stdout,
                "stopped, with saved guests" not in ret.stdout]):
            raise error.TestFail("Can't see messages of suspending vm")
        # status command should return 3.
        if not is_systemd:
            ret = libvirt_guests.raw_status()
        if ret.exit_status != 3:
            raise error.TestFail("The exit code %s for libvirt-guests"
                                 " status is not correct" % ret)

        # Wait for VM in shut off state
        wait_for_state("shut off")
        check_flags_parallel(virsh_cmd_start, bash_cmd %
                             (managed_save_file, managed_save_file,
                              "0", flags), flags)
        # Wait for VM in running state
        wait_for_state("running")

    def vm_msave_remove_check(vm_name):
        """
        Check managed save remove command.
        """
        if not os.path.exists(managed_save_file):
            raise error.TestFail("Can't find managed save image")
        virsh.managedsave_remove(vm_name)
        if os.path.exists(managed_save_file):
            raise error.TestFail("Managed save image still exists")
        virsh.start(vm_name)
        # The domain state should be running
        if vm.state() != "running":
            raise error.TestFail("Guest state should be"
                                 " running after started")

    def vm_managedsave_loop(vm_name, loop_range, libvirtd):
        """
        Run a loop of managedsave command and check its result.
        """
        if vm.is_dead():
            virsh.start(vm_name)
        for i in range(int(loop_range)):
            logging.debug("Test loop: %s" % i)
            virsh.managedsave(vm_name)
            virsh.start(vm_name)
        # Check libvirtd status.
        if not libvirtd.is_running():
            raise error.TestFail("libvirtd is stopped after cmd")
        # Check vm status.
        if vm.state() != "running":
            raise error.TestFail("Guest isn't in running state")

    def build_vm_xml(vm_name, **dargs):
        """
        Build the new domain xml and define it.
        """
        try:
            # stop vm before doing any change to xml
            if vm.is_alive():
                vm.destroy(gracefully=False)
            vmxml = vm_xml.VMXML.new_from_dumpxml(vm_name)
            if dargs.get("cpu_mode"):
                if "cpu" in vmxml:
                    del vmxml.cpu
                cpuxml = vm_xml.VMCPUXML()
                cpuxml.mode = params.get("cpu_mode", "host-model")
                cpuxml.match = params.get("cpu_match", "exact")
                cpuxml.fallback = params.get("cpu_fallback", "forbid")
                cpu_topology = {}
                cpu_topology_sockets = params.get("cpu_topology_sockets")
                if cpu_topology_sockets:
                    cpu_topology["sockets"] = cpu_topology_sockets
                cpu_topology_cores = params.get("cpu_topology_cores")
                if cpu_topology_cores:
                    cpu_topology["cores"] = cpu_topology_cores
                cpu_topology_threads = params.get("cpu_topology_threads")
                if cpu_topology_threads:
                    cpu_topology["threads"] = cpu_topology_threads
                if cpu_topology:
                    cpuxml.topology = cpu_topology
                vmxml.cpu = cpuxml
                vmxml.vcpu = int(params.get("vcpu_nums"))
            if dargs.get("sec_driver"):
                seclabel_dict = {"type": "dynamic", "model": "selinux",
                                 "relabel": "yes"}
                vmxml.set_seclabel([seclabel_dict])

            vmxml.sync()
            vm.start()
        except Exception, e:
            logging.error(str(e))
            raise error.TestNAError("Build domain xml failed")

    status_error = ("yes" == params.get("status_error", "no"))
    vm_ref = params.get("managedsave_vm_ref", "name")
    libvirtd_state = params.get("libvirtd", "on")
    extra_param = params.get("managedsave_extra_param", "")
    progress = ("yes" == params.get("managedsave_progress", "no"))
    cpu_mode = "yes" == params.get("managedsave_cpumode", "no")
    test_undefine = "yes" == params.get("managedsave_undefine", "no")
    test_bypass_cache = "yes" == params.get("test_bypass_cache", "no")
    autostart_bypass_cache = params.get("autostart_bypass_cache", "")
    multi_guests = params.get("multi_guests", "")
    test_libvirt_guests = params.get("test_libvirt_guests", "")
    check_flags = "yes" == params.get("check_flags", "no")
    security_driver = params.get("security_driver", "")
    remove_after_cmd = "yes" == params.get("remove_after_cmd", "no")
    option = params.get("managedsave_option", "")
    check_shutdown = "yes" == params.get("shutdown_after_cmd", "no")
    pre_vm_state = params.get("pre_vm_state", "")
    move_saved_file = "yes" == params.get("move_saved_file", "no")
    test_loop_cmd = "yes" == params.get("test_loop_cmd", "no")
    if option:
        if not virsh.has_command_help_match('managedsave', option):
            # Older libvirt does not have this option
            raise error.TestNAError("Older libvirt does not"
                                    " handle arguments consistently")

    # Backup xml file.
    vmxml_backup = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)
    # Get the libvirtd service
    libvirtd = utils_libvirtd.Libvirtd()
    # Get config files.
    qemu_config = utils_config.LibvirtQemuConfig()
    libvirt_guests_config = utils_config.LibvirtGuestsConfig()
    # Get libvirt-guests service
    libvirt_guests = Factory.create_service("libvirt-guests")

    try:
        # Destroy vm first for setting configuration file
        if vm.state() == "running":
            vm.destroy(gracefully=False)
        # Prepare test environment.
        if libvirtd_state == "off":
            libvirtd.stop()
        if autostart_bypass_cache:
            ret = virsh.autostart(vm_name, "", ignore_status=True)
            libvirt.check_exit_status(ret)
            qemu_config.auto_start_bypass_cache = autostart_bypass_cache
            libvirtd.restart()
        if security_driver:
            qemu_config.security_driver = [security_driver]
        if test_libvirt_guests:
            if multi_guests:
                start_delay = params.get("start_delay", "20")
                libvirt_guests_config.START_DELAY = start_delay
            if check_flags:
                libvirt_guests_config.BYPASS_CACHE = "1"
            # The config file format should be "x=y" instead of "x = y"
            utils.run("sed -i -e 's/ = /=/g' "
                      "/etc/sysconfig/libvirt-guests")
            libvirt_guests.restart()

        # Change domain xml.
        if cpu_mode:
            build_vm_xml(vm_name, cpu_mode=True)
        if security_driver:
            build_vm_xml(vm_name, sec_driver=True)

        # Put the VM into the required state.
        if pre_vm_state == "transient":
            logging.info("Creating %s..." % vm_name)
            vmxml_for_test = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)
            if vm.is_alive():
                vm.destroy(gracefully=False)
            # Wait for VM to be in shut off state
            utils_misc.wait_for(lambda: vm.state() == "shut off", 10)
            vm.undefine()
            if virsh.create(vmxml_for_test.xml, ignore_status=True).exit_status:
                vmxml_backup.define()
                raise error.TestNAError("Cann't create the domain")

        # Wait for vm in stable state
        if params.get("start_vm") == "yes":
            if vm.state() == "shut off":
                vm.start()
                vm.wait_for_login()

        # run test case
        domid = vm.get_id()
        domuuid = vm.get_uuid()
        if vm_ref == "id":
            vm_ref = domid
        elif vm_ref == "uuid":
            vm_ref = domuuid
        elif vm_ref == "hex_id":
            vm_ref = hex(int(domid))
        elif vm_ref.count("invalid"):
            vm_ref = params.get(vm_ref)
        elif vm_ref == "name":
            vm_ref = vm_name

        # Ignore exception with "ignore_status=True"
        if progress:
            option += " --verbose"
        option += extra_param

        # For the bypass_cache test, run a shell command to check fd flags
        # while executing the managedsave command
        bash_cmd = ("let i=1; while((i++<400)); do if [ -e %s ]; then (cat /proc"
                    "/$(lsof -w %s|awk '/libvirt_i/{print $2}')/fdinfo/*%s* |"
                    "grep 'flags:.*%s') && break; else sleep 0.05; fi; done;")
        # Flags to check that bypass cache takes effect
        flags = "014"
        if test_bypass_cache:
            # Drop caches.
            drop_caches()
            virsh_cmd = "virsh managedsave %s %s" % (option, vm_name)
            check_flags_parallel(virsh_cmd, bash_cmd %
                                 (managed_save_file, managed_save_file,
                                  "1", flags), flags)
            # Wait for VM in shut off state
            wait_for_state("shut off")
            virsh_cmd = "virsh start %s %s" % (option, vm_name)
            check_flags_parallel(virsh_cmd, bash_cmd %
                                 (managed_save_file, managed_save_file,
                                  "0", flags), flags)
            # Wait for VM in running state
            wait_for_state("running")
        elif test_libvirt_guests:
            logging.debug("libvirt-guests status: %s", libvirt_guests.status())
            if multi_guests:
                check_multi_guests(multi_guests,
                                   start_delay, libvirt_guests)

            if check_flags:
                check_guest_flags(bash_cmd, flags)

        else:
            # Ensure VM is running
            utils_misc.wait_for(lambda: vm.state() == "running", 10)
            ret = virsh.managedsave(vm_ref, options=option, ignore_status=True)
            status = ret.exit_status
            # The progress information is output in the error message
            error_msg = ret.stderr.strip()
            if move_saved_file:
                cmd = "echo > %s" % managed_save_file
                utils.run(cmd)

            # recover libvirtd service start
            if libvirtd_state == "off":
                libvirtd.start()

            if status_error:
                if not status:
                    raise error.TestFail("Run successfully with wrong command!")
            else:
                if status:
                    raise error.TestFail("Run failed with right command")
                if progress:
                    if not error_msg.count("Managedsave:"):
                        raise error.TestFail("Got invalid progress output")
                if remove_after_cmd:
                    vm_msave_remove_check(vm_name)
                elif test_undefine:
                    vm_undefine_check(vm_name)
                elif autostart_bypass_cache:
                    libvirtd.stop()
                    virsh_cmd = ("(service libvirtd start)")
                    check_flags_parallel(virsh_cmd, bash_cmd %
                                         (managed_save_file, managed_save_file,
                                          "0", flags), flags)
                elif test_loop_cmd:
                    loop_range = params.get("loop_range", "20")
                    vm_managedsave_loop(vm_name, loop_range, libvirtd)
                else:
                    vm_recover_check(option, libvirtd, check_shutdown)
    finally:
        # Restore test environment.

        # Ensure libvirtd is started
        if not libvirtd.is_running():
            libvirtd.start()
        if vm.is_paused():
            virsh.resume(vm_name)
        elif vm.is_dead():
            vm.start()
        # Wait for VM in running state
        wait_for_state("running")
        if autostart_bypass_cache:
            virsh.autostart(vm_name, "--disable",
                            ignore_status=True)
        if vm.is_alive():
            vm.destroy(gracefully=False)
        # Wait for VM to be in shut off state
        utils_misc.wait_for(lambda: vm.state() == "shut off", 10)
        virsh.managedsave_remove(vm_name)
        vmxml_backup.sync()
        if multi_guests:
            for i in range(int(multi_guests)):
                virsh.remove_domain("%s_%s" % (vm_name, i),
                                    "--remove-all-storage")
        qemu_config.restore()
        libvirt_guests_config.restore()
        libvirtd.restart()
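
The helpers above lean on one recurring os.path.exists idiom: assert that an artifact exists before an operation and is gone afterwards (vm_msave_remove_check and vm_undefine_check are the clearest cases). Below is a minimal standalone sketch of that check, assuming nothing from the avocado/virttest framework; the path and the removal callable are illustrative stand-ins, not part of the test API:

import os

def check_file_removed(path, remove_func):
    """Run remove_func and verify the file at path disappears."""
    # Precondition: the artifact must exist before we try to remove it.
    if not os.path.exists(path):
        raise AssertionError("expected %s to exist before removal" % path)
    remove_func()
    # Postcondition: the artifact must be gone afterwards.
    if os.path.exists(path):
        raise AssertionError("%s still exists after removal" % path)

For instance, check_file_removed(managed_save_file, lambda: virsh.managedsave_remove(vm_name)) would express the first half of vm_msave_remove_check in a single call.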

Example 42

Project: mantaray
Source File: create_kml_from_exif_mr.py
View license
def create_kml_from_exif_mr(item_to_process, case_number, root_folder_path, evidence):
	print("The item to process is: " + item_to_process)
	print("The case_name is: " + case_number)
	print("The output folder is: " + root_folder_path)
	print("The evidence to process is: " + evidence)

	evidence_no_quotes = evidence
	evidence = '"' + evidence + '"'

	#create output folder path
	folder_path = root_folder_path + "/" + "KML_From_EXIF"
	check_for_folder(folder_path, "NONE")
	

	#open a log file for output
	log_file = folder_path + "/KML_From_EXIF_logfile.txt"
	outfile = open(log_file, 'wt+')

	#initialize variables
	files_of_interest = {}
	files_of_interest_list = []
	mount_point = "NONE"

	log_file3 = folder_path + "/" + case_number + "_files_to_exploit.xls"
	outfile3 = open(log_file3, 'wt+')

	#write out column headers to xls file
	outfile3.write("Name\tMD5\tFile Size (kb)\n")



	if(item_to_process == "Directory"):
		#select folder to process
		folder_process = evidence_no_quotes
	
		#set folder variable to "folder" since this is a folder and not a disk partition
		folder = "Directory"

		#call process subroutine
		process(folder_process, outfile, folder_path, folder, outfile3)

	elif(item_to_process == 'EnCase Logical Evidence File'):
		folder = "LEF"
		file_to_process = evidence
		mount_point = mount_encase_v6_l01(case_number, file_to_process, outfile)
		process(mount_point, outfile, folder_path, folder, outfile3)

		#umount
		if(os.path.exists(mount_point)):
			subprocess.call(['sudo umount -f ' + mount_point], shell=True)
			os.rmdir(mount_point)

	elif(item_to_process == 'Single File'):
		process_single_file(evidence_no_quotes, outfile, folder_path, "Single-File", outfile3)

	elif(item_to_process == 'Bit-Stream Image'):

		#select image to process
		Image_Path = evidence

		#get datetime
		now = datetime.datetime.now()

		#set Mount Point
		mount_point = "/mnt/" + now.strftime("%Y-%m-%d_%H_%M_%S")	

		#check to see if Image file is in Encase format
		if re.search(".E01", Image_Path):
			#strip out single quotes from the quoted path
			no_quotes_path = Image_Path.replace("'","")
			print("The no quotes path is: " + no_quotes_path)
			#call mount_ewf function
			cmd_false = "sudo gsettings set org.gnome.desktop.media-handling automount false && sudo gsettings set org.gnome.desktop.media-handling automount-open false"
			try:
				subprocess.call([cmd_false], shell=True)
			except:
				print("Autmount false failed")
			Image_Path = mount_ewf(Image_Path, outfile, mount_point)

		#call mmls function
		partition_info_dict, temp_time = mmls(outfile, Image_Path)
		#partition_info_dict_temp, temp_time = partition_info_dict

		#get filesize of mmls_output.txt
		file_size = os.path.getsize("/tmp/mmls_output_" + temp_time +".txt") 
		print("The filesize is: " + str(file_size))

		#if filesize of mmls output is 0 then run parted
		if(file_size == 0):
			print("mmls output was empty, running parted")
			outfile.write("mmls output was empty, running parted")
			#call parted function
			partition_info_dict, temp_time = parted(outfile, Image_Path)	

		else:
	
			#read through the mmls output and look for GUID Partition Tables (used on Macs)
			mmls_output_file = open("/tmp/mmls_output_" + temp_time + ".txt", 'r')
			for line in mmls_output_file:
				if re.search("GUID Partition Table", line):
					print("We found a GUID partition table, need to use parted")
					outfile.write("We found a GUID partition table, need to use parted\n")
					#call parted function
					partition_info_dict, temp_time = parted(outfile, Image_Path)

			#close file
			mmls_output_file.close()

		#loop through the dictionary containing the partition info (filesystem is VALUE, offset is KEY)
		#for key,value in partition_info_dict.items():
		for key,value in sorted(partition_info_dict.items()):

			#create output folder for processed files
			if not os.path.exists(folder_path + "/Processed_files_" + str(key)):
				os.mkdir(folder_path + "/Processed_files_" + str(key))

			#disable auto-mount in Nautilus - this stops a Nautilus window from popping up every time the mount command is executed
			cmd_false = "sudo gsettings set org.gnome.desktop.media-handling automount false && sudo gsettings set org.gnome.desktop.media-handling automount-open false"
			try:
				subprocess.call([cmd_false], shell=True)
			except:
				print("Autmount false failed")

			#call mount sub-routine
			success_code, loopback_device_mount = mount(value,key,Image_Path, outfile, mount_point)

			if(success_code):
				print("Could not mount partition with filesystem: " + value + " at offset:" + str(key))
				outfile.write("Could not mount partition with filesystem: " + value + " at offset:" + str(key))
			else:
		
				print("We just mounted filesystem: " + value + " at offset:" + str(key) + ". Scanning for files of interest.....\n")
				outfile.write("We just mounted filesystem: " + value + " at offset:" + str(key) + "\n")

				#call process subroutine
				process(mount_point, outfile, folder_path, key, outfile3)
			

				#unmount and remove mount points
				if(os.path.exists(mount_point)): 
					subprocess.call(['sudo umount -f ' + mount_point], shell=True)
					os.rmdir(mount_point)
				#unmount loopback device if this image was HFS+ - need to run losetup -d <loop_device> before unmounting
				if not (loopback_device_mount == "NONE"):
					losetup_d_command = "losetup -d " + loopback_device_mount
					subprocess.call([losetup_d_command], shell=True)

			#delete /tmp files created for processing bit-stream images
			if (os.path.exists("/tmp/mmls_output_" + temp_time + ".txt")):
				os.remove("/tmp/mmls_output_" + temp_time + ".txt")

	#write out list of filenames to end of output file so that user can create a filter for those filenames in Encase
	outfile3.write("\n\n******** LIST of FILENAMES of INTEREST ******************\n")
	#sort list so that all values are unique
	unique(files_of_interest_list) 
	for files in files_of_interest_list:
		outfile3.write(files + "\n")
	

	#program cleanup
	outfile.close()
	outfile3.close()

	#remove mount points created for this program
	if(os.path.exists(mount_point)):
		subprocess.call(['sudo umount -f ' + mount_point], shell=True)
		os.rmdir(mount_point)
	if(os.path.exists(mount_point+"_ewf")):
		subprocess.call(['sudo umount -f ' + mount_point + "_ewf"], shell=True)
		os.rmdir(mount_point+"_ewf")
	
	#convert outfile using unix2dos	
	#chdir to output folder
	os.chdir(folder_path)

	#run text files through unix2dos
	for root, dirs, files in os.walk(folder_path):
		for filenames in files:
			#get file extension
			fileName, fileExtension = os.path.splitext(filenames)
			if(fileExtension.lower() == ".txt"):
				full_path = os.path.join(root,filenames)
				quoted_full_path = "'" +full_path+"'"
				print("Running Unix2dos against file: " + filenames)
				unix2dos_command = "sudo unix2dos " + filenames
				subprocess.call([unix2dos_command], shell=True)

	#delete empty directories in output folder
	for root, dirs, files in os.walk(folder_path, topdown=False):	
		for directories in dirs:
			files = []
			dir_path = os.path.join(root,directories)
			files = os.listdir(dir_path)	
			if(len(files) == 0):
				os.rmdir(dir_path)

	#unmount and remove mount points
	if(mount_point != "NONE"):
		if(os.path.exists(mount_point+"_ewf")):
			subprocess.call(['sudo umount -f ' + mount_point + "_ewf"], shell=True)
			os.rmdir(mount_point+"_ewf")
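
Example 42 repeats the same unmount guard several times: test the mount point with os.path.exists, force-unmount it, then remove the directory. A hedged consolidation of that block into one helper; the sudo umount -f invocation is copied from the example and assumed to be available on the host:

import os
import subprocess

def unmount_and_remove(mount_point):
    # Only act when the mount point directory is actually present.
    if os.path.exists(mount_point):
        # List-form arguments sidestep the shell quoting that the
        # original shell=True calls depend on.
        subprocess.call(['sudo', 'umount', '-f', mount_point])
        os.rmdir(mount_point)

Calling unmount_and_remove(mount_point) followed by unmount_and_remove(mount_point + "_ewf") would replace the four near-identical cleanup blocks in the function above.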

Example 43

Project: stonix
Source File: ReqAuthSingleUserMode.py
View license
    def fix(self):
        '''
        The fix method will apply the required settings to the system. 
        self.rulesuccess will be updated if the rule does not succeed.
        Enter the correct config entry in EITHER /etc/inittab OR 
        /etc/default/sulogin OR /etc/ttys OR /etc/sysconfig/init to require 
        authentication with single-user mode.

        @author bemalmbe
        '''
        try:
            if not self.ci.getcurrvalue():
                return
            
            success = True
            self.detailedresults = ""
            
            #clear out event history so only the latest fix is recorded
            self.iditerator = 0
            eventlist = self.statechglogger.findrulechanges(self.rulenumber)
            for event in eventlist:
                self.statechglogger.deleteentry(event)
                    
            #there is no way to disable the requirement of a password for 
            #apt-get systems so no need to do anything 
            if not self.ph.manager == "apt-get" and not self.ph.manager == \
                                                                      "zypper":
                
                #solution for bsd
                if self.ph.manager == "freebsd":
                    fp = "/etc/ttys"
                    tfp = fp + ".tmp"
                    created = False
                    badfile = False
                    if not os.path.exists(fp):
                        createFile(fp)
                        created = True
                    if os.path.exists(fp):
                        if not created:
                            #we check if file was previously created above
                            #if so, we don't want to record a permission
                            #change event, since the undo will be file deletion
                            if not checkPerms(fp, [0, 0, 420], self.logger):
                                self.iditerator += 1
                                myid = iterate(self.iditerator, self.rulenumber)
                                if setPerms(fp, [0, 0, 420], self.logger, 
                                                    self.statechglogger, myid):
                                    self.detailedresults += "Successfully \
corrected permissions on file: " + fp + "\n"
                                else:
                                    self.detailedresults += "Was not able to \
successfully set permissions on file: " + fp + "\n"
                                    success = False
                                    
                        #read in file
                        contents = readFile(fp, self.logger)
                        tempstring = ""
                        for line in contents:
                            #search for any line beginning with tty
                            if re.search("^tty", line):
                                linesplit = line.split()
                                try:
                                    #replace the secure flag with insecure
                                    #on any tty line where it is set
                                    if linesplit[4] == "secure":
                                        linesplit[4] = "insecure"
                                        badfile = True
                                        tempstring += " ".join(linesplit) + "\n"
                                    else:
                                        tempstring += line
                                except IndexError:
                                    debug = traceback.format_exc() + "\n"
                                    debug += "Index out of range on line: " + line + "\n"
                                    self.logger.log(LogPriority.DEBUG, debug)
                            else:
                                tempstring += line
                        
                        #check to see if badfile is true which is set when
                        #checking contents of the file, if badfile is false
                        #we found everything in the file we needed, so no need
                        #to change
                        if badfile:
                            if writeFile(tfp, tempstring, self.logger):
                                self.iditerator += 1
                                myid = iterate(self.iditerator, 
                                                               self.rulenumber)
                                #if the file wasn't created earlier, then we
                                #will record the change event as a file change
                                if not created:
                                    event = {"eventtype":"conf",
                                             "filepath":fp}
                                    self.statechglogger.recordchgevent(myid, 
                                                                         event)
                                    self.statechglogger.recordfilechange(fp, 
                                                                     tfp, myid)
                                #if file was created earlier, then we will 
                                #record the change event as a file creation
                                #so that undo event will be a file deletion
                                else:
                                    event = {"eventtype":"conf",
                                             "filepath":fp}
                                    self.statechglogger.recordchgevent(myid,
                                                                         event)
                                self.detailedresults += "corrected contents \
and wrote to file: " + fp + "\n"
                                os.rename(tfp, fp)
                                os.chown(fp, 0, 0)
                                os.chmod(fp, 420)
                                resetsecon(fp)
                            else:
                                self.detailedresults += "Unable to \
successfully write the file: " + fp + "\n"
                                success = False
                                
                if self.ph.manager == "yum":
                    tempstring = ""
                    fp = "/etc/sysconfig/init"
                    tfp = fp + ".tmp"
                    created = False
                    badfile = False
                    if not os.path.exists(fp):
                        createFile(fp)
                        created = True
                    if os.path.exists(fp):
                        if not created:
                            #we check if file was previously created above
                            #if so, we don't want to record a permission
                            #change event, since the undo will be file deletion
                            if not checkPerms(fp, [0, 0, 420], self.logger):
                                self.iditerator += 1
                                myid = iterate(self.iditerator, self.rulenumber)
                                if setPerms(fp, [0, 0, 420], self.logger,
                                                    self.statechglogger, myid):
                                    self.detailedresults += "Successfully \
corrected permissions on file: " + fp + "\n"
                                else:
                                    self.detailedresults += "Was not able to \
successfully set permissions on file: " + fp + "\n"
                                    success = False
                        contents = readFile(fp, self.logger)
                        if contents:
                            linefound = False
                            for line in contents:
                                if re.search("^SINGLE", line.strip()):
                                    if re.search("=", line):
                                        temp = line.split("=")
                                        try:
                                            if temp[1].strip() == "/sbin/sulogin":
                                                tempstring += line
                                                linefound = True
                                        except IndexError:
                                            self.compliant = False
                                            debug = traceback.format_exc() + "\n"
                                            debug += "Index out of range on line: " + line + "\n"
                                            self.logger.log(LogPriority.DEBUG, debug)
                                else:
                                    tempstring += line
                            if not linefound:
                                badfile = True
                                tempstring += "SINGLE=/sbin/sulogin\n"
                                
                        #check to see if badfile is true which is set when
                        #checking contents of the file, if badfile is false
                        #we found everything in the file we needed, so no need
                        #to change
                        if badfile:
                            if writeFile(tfp, tempstring, self.logger):
                                self.iditerator += 1
                                myid = iterate(self.iditerator, 
                                                               self.rulenumber)
                                #if the file wasn't created earlier, then we
                                #will record the change event as a file change
                                if not created:
                                    event = {"eventtype":"conf",
                                             "filepath":fp}
                                    self.statechglogger.recordchgevent(myid, 
                                                                         event)
                                    self.statechglogger.recordfilechange(fp, 
                                                                     tfp, myid)
                                #if file was created earlier, then we will 
                                #record the change event as a file creation
                                #so that undo event will be a file deletion
                                else:
                                    event = {"eventtype":"creation",
                                             "filepath":fp}
                                    self.statechglogger.recordchgevent(myid,
                                                                         event)
                                self.detailedresults += "corrected contents \
and wrote to file: " + fp + "\n"
                                os.rename(tfp, fp)
                                os.chown(fp, 0, 0)
                                os.chmod(fp, 420)
                                resetsecon(fp)
                            else:
                                self.detailedresults += "Unable to \
successfully write the file: " + fp + "\n"
                                success = False
                                
                                
#                 if self.ph.manager == "zypper":
#                     tempstring = ""
#                     fp = "/etc/inittab"
#                     tfp = fp + ".tmp"
#                     created = False
#                     badfile = False
#                     if not os.path.exists(fp):
#                         self.createFile(fp)
#                         created = True
#                     if os.path.exists(fp):
#                         if not created:
#                             #we check if file was previously created above
#                             #if so, we don't want to record a permission
#                             #change event, since the undo will be file deletion
#                             if not checkPerms(fp, [0, 0, 420], self.logger):
#                                 self.iditerator += 1
#                                 myid = iterate(self.iditerator, self.rulenumber)
#                                 if setPerms(fp, [0, 0, 420], self.logger, "", 
#                                                     self.statechglogger, myid):
#                                     self.detailedresults += "Successfully \
# corrected permissions on file: " + fp + "\n"
#                                 else:
#                                     self.detailedresults += "Was not able to \
# successfully set permissions on file: " + fp + "\n"
#                                     success = False
#                         contents = readFile(fp, self.logger)
#                         if contents:
#                             linefound = False
#                             for line in contents:
#                                 if re.search("^~~:S:wait:/sbin/sulogin", line.strip()):
#                                     tempstring += line
#                                     linefound = True
#                                 else:
#                                     tempstring += line
#                             if not linefound:
#                                 badfile = True
#                                 tempstring += "~~:S:wait:/sbin/sulogin\n"
#                         #check to see if badfile is true which is set when
#                         #checking contents of the file, if badfile is false
#                         #we found everything in the file we needed, so no need
#                         #to change
#                         if badfile:
#                             if writeFile(tfp, self.logger, tempstring):
#                                 self.iditerator += 1
#                                 myid = iterate(self.iditerator, 
#                                                                self.rulenumber)
#                                 #if the file wasn't created earlier, then we
#                                 #will record the change event as a file change
#                                 if not created:
#                                     event = {"eventtype":"conf",
#                                              "filepath":fp}
#                                     self.statechglogger.recordchgevent(myid, 
#                                                                          event)
#                                     self.statechglogger.recordfilechange(fp, 
#                                                                      tfp, myid)
#                                 #if file was created earlier, then we will 
#                                 #record the change event as a file creation
#                                 #so that undo event will be a file deletion
#                                 else:
#                                     event = {"eventtype":"creation",
#                                              "filepath":fp}
#                                     self.statechglogger.recordchgevent(myid,
#                                                                          event)
#                                 self.detailedresults += "corrected contents \
# and wrote to file: " + fp + "\n"
#                                 os.rename(tfp, fp)
#                                 os.chown(fp, 0, 0)
#                                 os.chmod(fp, 420)
#                                 resetsecon(fp)
#                             else:
#                                 self.detailedresults += "Unable to \
# successfully write the file: " + fp + "\n"
#                                 success = False
                                
                #solution for solaris systems
                if self.ph.manager == "solaris":
                    tempstring = ""
                    fp = "/etc/default/sulogin"
                    tfp = fp + ".tmp"
                    created = False
                    badfile = False
                    
                    if not os.path.exists(fp):
                        createFile(fp)
                        created = True
                    if os.path.exists(fp):
                        if not created:
                            #we check if file was previously created above
                            #if so, we don't want to record a permission
                            #change event, since the undo will be file deletion
                            if not checkPerms(fp, [0, 0, 420], self.logger):
                                self.iditerator += 1
                                myid = iterate(self.iditerator, self.rulenumber)
                                if setPerms(fp, [0, 0, 420], self.logger,
                                                    self.statechglogger, myid):
                                    self.detailedresults += "Successfully \
corrected permissions on file: " + fp + "\n"
                                else:
                                    self.detailedresults += "Was not able to \
successfully set permissions on file: " + fp + "\n"
                                    success = False
                        contents = readFile(fp, self.logger)
                        if contents:
                            linefound = False
                            for line in contents:
                                if re.search("^PASSREQ", line.strip()):
                                    if re.search("=", line):
                                        temp = line.split("=")
                                        try:
                                            if temp[1].strip() == "YES":
                                                tempstring += line
                                                linefound = True
                                        except IndexError:
                                            debug = traceback.format_exc() + "\n"
                                            debug += "Index out of range on line: " + line + "\n"
                                            self.logger.log(LogPriority.DEBUG, debug)
                                else:
                                    tempstring += line
                            if not linefound:
                                badfile = True
                                tempstring += "PASSREQ=YES\n"
                        #check to see if badfile is true which is set when
                        #checking contents of the file, if badfile is false
                        #we found everything in the file we needed, so no need
                        #to change
                        if badfile:
                            if writeFile(tfp, tempstring, self.logger):
                                self.iditerator += 1
                                myid = iterate(self.iditerator, 
                                                               self.rulenumber)
                                #if the file wasn't created earlier, then we
                                #will record the change event as a file change
                                if not created:
                                    event = {"eventtype":"conf",
                                             "filepath":fp}
                                    self.statechglogger.recordchgevent(myid, 
                                                                         event)
                                    self.statechglogger.recordfilechange(fp, 
                                                                     tfp, myid)
                                #if file was created earlier, then we will 
                                #record the change event as a file creation
                                #so that undo event will be a file deletion
                                else:
                                    event = {"eventtype":"creation",
                                             "filepath":fp}
                                    self.statechglogger.recordchgevent(myid,
                                                                         event)
                                self.detailedresults += "corrected contents \
and wrote to file: " + fp + "\n"
                                os.rename(tfp, fp)
                                os.chown(fp, 0, 0)
                                os.chmod(fp, 420)
                                resetsecon(fp)
                            else:
                                self.detailedresults += "Unable to \
successfully write the file: " + fp + "\n"
                                success = False
            self.rulesuccess = success
        except (KeyboardInterrupt, SystemExit):
            # User initiated exit
            raise
        except Exception:
            self.rulesuccess = False
            self.detailedresults += "\n" + traceback.format_exc()
            self.logdispatch.log(LogPriority.ERROR, self.detailedresults)
        self.formatDetailedResults("fix", self.rulesuccess,
                                   self.detailedresults)
        self.logdispatch.log(LogPriority.INFO, self.detailedresults)
        return self.rulesuccess
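
Each platform branch in fix() opens with the same create-if-missing dance: os.path.exists, createFile, and a created flag that later decides whether the undo event is a file change or a file deletion. A condensed stdlib-only sketch of that idiom (createFile is a stonix helper, so a plain touch stands in for it here); note that the example's os.chmod(fp, 420) is the decimal spelling of octal 0o644:

import os

def ensure_file(path, mode=0o644):
    """Create path if it is absent; return True when newly created."""
    created = not os.path.exists(path)
    if created:
        # Touch the file, then set its permissions (0o644 == decimal 420).
        open(path, 'a').close()
        os.chmod(path, mode)
    return created

The returned flag plays the role of the example's created variable. Check-then-create is inherently racy (TOCTTOU); where that matters, open(path, 'x') creates the file atomically and raises FileExistsError if it already exists.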

Example 44

Project: dcos
Source File: __init__.py
View license
def build(package_store, name, variant, clean_after_build, recursive=False):
    assert isinstance(package_store, PackageStore)
    print("Building package {} variant {}".format(name, pkgpanda.util.variant_str(variant)))
    tmpdir = tempfile.TemporaryDirectory(prefix="pkgpanda_repo")
    repository = Repository(tmpdir.name)

    package_dir = package_store.get_package_folder(name)

    def src_abs(name):
        return package_dir + '/' + name

    def cache_abs(filename):
        return package_store.get_package_cache_folder(name) + '/' + filename

    # Build pkginfo over time, translating fields from buildinfo.
    pkginfo = {}

    # Build up the docker command arguments over time, translating fields as needed.
    cmd = DockerCmd()

    assert (name, variant) in package_store.packages, \
        "Programming error: name, variant should have been validated to be valid before calling build()."

    builder = IdBuilder(package_store.get_buildinfo(name, variant))
    final_buildinfo = dict()

    builder.add('name', name)
    builder.add('variant', pkgpanda.util.variant_str(variant))

    # Convert single_source -> sources
    if builder.has('sources'):
        if builder.has('single_source'):
            raise BuildError('Both sources and single_source cannot be specified at the same time')
        sources = builder.take('sources')
    elif builder.has('single_source'):
        sources = {name: builder.take('single_source')}
        builder.replace('single_source', 'sources', sources)
    else:
        builder.add('sources', {})
        sources = dict()
        print("NOTICE: No sources specified")

    final_buildinfo['sources'] = sources

    # Construct the source fetchers, gather the checkout ids from them
    checkout_ids = dict()
    fetchers = dict()
    try:
        for src_name, src_info in sorted(sources.items()):
            # TODO(cmaloney): Switch to a unified top level cache directory shared by all packages
            cache_dir = package_store.get_package_cache_folder(name) + '/' + src_name
            check_call(['mkdir', '-p', cache_dir])
            fetcher = get_src_fetcher(src_info, cache_dir, package_dir)
            fetchers[src_name] = fetcher
            checkout_ids[src_name] = fetcher.get_id()
    except ValidationError as ex:
        raise BuildError("Validation error when fetching sources for package: {}".format(ex))

    for src_name, checkout_id in checkout_ids.items():
        # NOTE: single_source buildinfo was expanded above so the src_name is
        # always correct here.
        # Make sure we never accidentally overwrite something which might be
        # important. Fields should match if specified (and that should be
        # tested at some point). For now, disallowing duplicate keys saves hassle.
        assert_no_duplicate_keys(checkout_id, final_buildinfo['sources'][src_name])
        final_buildinfo['sources'][src_name].update(checkout_id)

    # Add the sha1 of the buildinfo.json + build file to the build ids
    builder.update('sources', checkout_ids)
    build_script = src_abs(builder.take('build_script'))
    # TODO(cmaloney): Change dest name to build_script_sha1
    builder.replace('build_script', 'build', pkgpanda.util.sha1(build_script))
    builder.add('pkgpanda_version', pkgpanda.build.constants.version)

    extra_dir = src_abs("extra")
    # Add the "extra" folder inside the package as an additional source if it
    # exists
    if os.path.exists(extra_dir):
        extra_id = hash_folder(extra_dir)
        builder.add('extra_source', extra_id)
        final_buildinfo['extra_source'] = extra_id

    # Figure out the docker name.
    docker_name = builder.take('docker')
    cmd.container = docker_name

    # Add the id of the docker build environment to the build_ids.
    try:
        docker_id = get_docker_id(docker_name)
    except CalledProcessError:
        # docker pull the container and try again
        check_call(['docker', 'pull', docker_name])
        docker_id = get_docker_id(docker_name)

    builder.update('docker', docker_id)

    # TODO(cmaloney): The environment variables should be generated during build
    # not live in buildinfo.json.
    pkginfo['environment'] = builder.take('environment')

    # Whether pkgpanda should ensure a `/var/lib` state directory is available on the host
    pkginfo['state_directory'] = builder.take('state_directory')
    if pkginfo['state_directory'] not in [True, False]:
        raise BuildError("state_directory in buildinfo.json must be a boolean `true` or `false`")

    username = None
    if builder.has('username'):
        username = builder.take('username')
        if not isinstance(username, str):
            raise BuildError("username in buildinfo.json must be either not set (no user for this"
                             " package), or a user name string")
        try:
            pkgpanda.UserManagement.validate_username(username)
        except ValidationError as ex:
            raise BuildError("username in buildinfo.json didn't meet the validation rules. {}".format(ex))
        pkginfo['username'] = username

    group = None
    if builder.has('group'):
        group = builder.take('group')
        if not isinstance(group, str):
            raise BuildError("group in buildinfo.json must be either not set (use default group for this user)"
                             ", or group must be a string")
        try:
            pkgpanda.UserManagement.validate_group_name(group)
        except ValidationError as ex:
            raise BuildError("group in buildinfo.json didn't meet the validation rules. {}".format(ex))
        pkginfo['group'] = group

    # Packages need directories inside the fake install root (otherwise docker
    # will try making the directories on a readonly filesystem), so build the
    # install root now, and make the package directories in it as we go.
    install_dir = tempfile.mkdtemp(prefix="pkgpanda-")

    active_packages = list()
    active_package_ids = set()
    active_package_variants = dict()
    auto_deps = set()

    # Final package has the same requires as the build.
    requires = builder.take('requires')
    pkginfo['requires'] = requires

    if builder.has("sysctl"):
        pkginfo["sysctl"] = builder.take("sysctl")

    # TODO(cmaloney): Pull generating the full set of requires out into a function.
    to_check = copy.deepcopy(requires)
    if type(to_check) != list:
        raise BuildError("`requires` in buildinfo.json must be an array of dependencies.")
    while to_check:
        requires_info = to_check.pop(0)
        requires_name, requires_variant = expand_require(requires_info)

        if requires_name in active_package_variants:
            # TODO(cmaloney): If one package depends on the <default>
            # variant of a package and 1+ others depends on a non-<default>
            # variant then update the dependency to the non-default variant
            # rather than erroring.
            if requires_variant != active_package_variants[requires_name]:
                # TODO(cmaloney): Make this contain the chains of
                # dependencies which contain the conflicting packages.
                # a -> b -> c -> d {foo}
                # e {bar} -> d {baz}
                raise BuildError(
                    "Dependncy on multiple variants of the same package {}. variants: {} {}".format(
                        requires_name,
                        requires_variant,
                        active_package_variants[requires_name]))

            # The package {requires_name, variant} is already a
            # dependency; don't process it again / move on to the next.
            continue

        active_package_variants[requires_name] = requires_variant

        # Figure out the last build of the dependency, add that as the
        # fully expanded dependency.
        requires_last_build = package_store.get_last_build_filename(requires_name, requires_variant)
        if not os.path.exists(requires_last_build):
            if recursive:
                # Build the dependency
                build(package_store, requires_name, requires_variant, clean_after_build, recursive)
            else:
                raise BuildError("No last build file found for dependency {} variant {}. Rebuild "
                                 "the dependency".format(requires_name, requires_variant))

        try:
            pkg_id_str = load_string(requires_last_build)
            auto_deps.add(pkg_id_str)
            pkg_buildinfo = package_store.get_buildinfo(requires_name, requires_variant)
            pkg_requires = pkg_buildinfo['requires']
            pkg_path = repository.package_path(pkg_id_str)
            pkg_tar = pkg_id_str + '.tar.xz'
            if not os.path.exists(package_store.get_package_cache_folder(requires_name) + '/' + pkg_tar):
                raise BuildError(
                    "The build tarball {} refered to by the last_build file of the dependency {} "
                    "variant {} doesn't exist. Rebuild the dependency.".format(
                        pkg_tar,
                        requires_name,
                        requires_variant))

            active_package_ids.add(pkg_id_str)

            # Mount the package into the docker container.
            cmd.volumes[pkg_path] = "/opt/mesosphere/packages/{}:ro".format(pkg_id_str)
            os.makedirs(os.path.join(install_dir, "packages/{}".format(pkg_id_str)))

            # Add the dependencies of the package to the set which will be
            # activated.
            # TODO(cmaloney): All these 'transitive' dependencies shouldn't
            # be available to the package being built, only what depends on
            # them directly.
            to_check += pkg_requires
        except ValidationError as ex:
            raise BuildError("validating package needed as dependency {0}: {1}".format(requires_name, ex)) from ex
        except PackageError as ex:
            raise BuildError("loading package needed as dependency {0}: {1}".format(requires_name, ex)) from ex

    # Add requires to the package id, calculate the final package id.
    # NOTE: active_packages isn't fully constructed here since we lazily load
    # packages not already in the repository.
    builder.update('requires', list(active_package_ids))
    version_extra = None
    if builder.has('version_extra'):
        version_extra = builder.take('version_extra')

    build_ids = builder.get_build_ids()
    version_base = hash_checkout(build_ids)
    version = None
    if version_extra:
        version = "{0}-{1}".format(version_extra, version_base)
    else:
        version = version_base
    pkg_id = PackageId.from_parts(name, version)

    # Everything must have been extracted by now. If it wasn't, then we just
    # had a hard error that it was set but not used, as well as didn't include
    # it in the calculation of the PackageId.
    builder = None

    # Save the build_ids. Useful for verifying exactly what went into the
    # package build hash.
    final_buildinfo['build_ids'] = build_ids
    final_buildinfo['package_version'] = version

    # Save the package name and variant. The variant is used when installing
    # packages to validate dependencies.
    final_buildinfo['name'] = name
    final_buildinfo['variant'] = variant

    # If the package is already built, don't do anything.
    pkg_path = package_store.get_package_cache_folder(name) + '/{}.tar.xz'.format(pkg_id)

    # Done if it exists locally
    if exists(pkg_path):
        print("Package up to date. Not re-building.")

        # TODO(cmaloney): Updating / filling last_build should be moved out of
        # the build function.
        write_string(package_store.get_last_build_filename(name, variant), str(pkg_id))

        return pkg_path

    # Try downloading.
    dl_path = package_store.try_fetch_by_id(pkg_id)
    if dl_path:
        print("Package up to date. Not re-building. Downloaded from repository-url.")
        # TODO(cmaloney): Updating / filling last_build should be moved out of
        # the build function.
        write_string(package_store.get_last_build_filename(name, variant), str(pkg_id))
        print(dl_path, pkg_path)
        assert dl_path == pkg_path
        return pkg_path

    # Fall out and do the build since it couldn't be downloaded
    print("Unable to download from cache. Proceeding to build")

    print("Building package {} with buildinfo: {}".format(
        pkg_id,
        json.dumps(final_buildinfo, indent=2, sort_keys=True)))

    # Clean out src, result so later steps can use them freely for building.
    def clean():
        # Run a docker container to remove src/ and result/
        cmd = DockerCmd()
        cmd.volumes = {
            package_store.get_package_cache_folder(name): "/pkg/:rw",
        }
        cmd.container = "ubuntu:14.04.4"
        cmd.run("package-cleaner", ["rm", "-rf", "/pkg/src", "/pkg/result"])

    clean()

    # Only fresh builds are allowed which don't overlap existing artifacts.
    result_dir = cache_abs("result")
    if exists(result_dir):
        raise BuildError("result folder must not exist. It will be made when the package is "
                         "built. {}".format(result_dir))

    # 'mkpanda add' all implicit dependencies since we actually need to build.
    for dep in auto_deps:
        print("Auto-adding dependency: {}".format(dep))
        # NOTE: Not using the name pkg_id because that overrides the outer one.
        id_obj = PackageId(dep)
        add_package_file(repository, package_store.get_package_path(id_obj))
        package = repository.load(dep)
        active_packages.append(package)

    # Check out all the sources into their respective 'src/' folders.
    try:
        src_dir = cache_abs('src')
        if os.path.exists(src_dir):
            raise ValidationError(
                "'src' directory already exists, did you have a previous build? " +
                "Currently all builds must be from scratch. Support should be " +
                "added for re-using a src directory when possible. src={}".format(src_dir))
        os.mkdir(src_dir)
        for src_name, fetcher in sorted(fetchers.items()):
            root = cache_abs('src/' + src_name)
            os.mkdir(root)

            fetcher.checkout_to(root)
    except ValidationError as ex:
        raise BuildError("Validation error when fetching sources for package: {}".format(ex))

    # Activate the packages so that we have a proper path, environment
    # variables.
    # TODO(cmaloney): RAII type thing for temporary directory so if we
    # don't get all the way through things will be cleaned up?
    install = Install(
        root=install_dir,
        config_dir=None,
        rooted_systemd=True,
        manage_systemd=False,
        block_systemd=True,
        fake_path=True,
        manage_users=False,
        manage_state_dir=False)
    install.activate(active_packages)
    # Rewrite all the symlinks inside the active path because we will
    # be mounting the folder into a docker container, and the absolute
    # paths to the packages will change.
    # TODO(cmaloney): This isn't very clean, it would be much nicer to
    # just run pkgpanda inside the package.
    rewrite_symlinks(install_dir, repository.path, "/opt/mesosphere/packages/")

    print("Building package in docker")

    # TODO(cmaloney): Run as a specific non-root user, make it possible
    # for non-root to cleanup afterwards.
    # Run the build, prepping the environment as necessary.
    mkdir(cache_abs("result"))

    # Copy the build info to the resulting tarball
    write_json(cache_abs("src/buildinfo.full.json"), final_buildinfo)
    write_json(cache_abs("result/buildinfo.full.json"), final_buildinfo)

    write_json(cache_abs("result/pkginfo.json"), pkginfo)

    # Make the folder for the package we are building. If docker does it, it
    # gets auto-created with root permissions and we can't actually delete it.
    os.makedirs(os.path.join(install_dir, "packages", str(pkg_id)))

    # TODO(cmaloney): Disallow writing to well known files and directories?
    # Source we checked out
    cmd.volumes.update({
        # TODO(cmaloney): src should be read only...
        cache_abs("src"): "/pkg/src:rw",
        # The build script
        build_script: "/pkg/build:ro",
        # Getting the result out
        cache_abs("result"): "/opt/mesosphere/packages/{}:rw".format(pkg_id),
        install_dir: "/opt/mesosphere:ro"
    })

    if os.path.exists(extra_dir):
        cmd.volumes[extra_dir] = "/pkg/extra:ro"

    cmd.environment = {
        "PKG_VERSION": version,
        "PKG_NAME": name,
        "PKG_ID": pkg_id,
        "PKG_PATH": "/opt/mesosphere/packages/{}".format(pkg_id),
        "PKG_VARIANT": variant if variant is not None else "<default>",
        "NUM_CORES": multiprocessing.cpu_count()
    }

    try:
        # TODO(cmaloney): Run a wrapper which sources
        # /opt/mesosphere/environment then runs a build. Also should fix
        # ownership of /opt/mesosphere/packages/{pkg_id} post build.
        cmd.run("package-builder", [
            "/bin/bash",
            "-o", "nounset",
            "-o", "pipefail",
            "-o", "errexit",
            "/pkg/build"])
    except CalledProcessError as ex:
        raise BuildError("docker exited non-zero: {}\nCommand: {}".format(ex.returncode, ' '.join(ex.cmd)))

    # Clean up the temporary install dir used for dependencies.
    # TODO(cmaloney): Move to an RAII wrapper.
    check_call(['rm', '-rf', install_dir])

    print("Building package tarball")

    # Check for forbidden services before packaging the tarball:
    try:
        check_forbidden_services(cache_abs("result"), RESERVED_UNIT_NAMES)
    except ValidationError as ex:
        raise BuildError("Package validation failed: {}".format(ex))

    # TODO(cmaloney): Updating / filling last_build should be moved out of
    # the build function.
    write_string(package_store.get_last_build_filename(name, variant), str(pkg_id))

    # Bundle the artifacts into the pkgpanda package
    tmp_name = pkg_path + "-tmp.tar.xz"
    make_tar(tmp_name, cache_abs("result"))
    os.rename(tmp_name, pkg_path)
    print("Package built.")
    if clean_after_build:
        clean()
    return pkg_path
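
The cache-hit check at the top of this function is a recurring os.path.exists
idiom: compute the artifact path first, then skip the expensive work when the
file is already there. A minimal standalone sketch of the same pattern
(hypothetical names; the real build() also tries a remote fetch before falling
back to a build):

import os

def build_if_needed(pkg_path, do_build):
    # Skip the expensive build when the artifact is already cached locally.
    if os.path.exists(pkg_path):
        return pkg_path
    tmp_path = pkg_path + '-tmp'
    do_build(tmp_path)
    # Build to a temporary name and rename at the end, so an interrupted
    # build never leaves a half-written file that a later exists() check
    # would mistake for a finished package.
    os.rename(tmp_path, pkg_path)
    return pkg_path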

Example 45

Project: dl4mt-c2c
Source File: nmt.py
View license
def train(
      dim_word=100,
      dim_word_src=200,
      enc_dim=1000,
      dec_dim=1000,  # the number of LSTM units
      patience=-1,  # early stopping patience
      max_epochs=5000,
      finish_after=-1,  # finish after this many updates
      decay_c=0.,  # L2 regularization penalty
      alpha_c=0.,  # alignment regularization
      clip_c=-1.,  # gradient clipping threshold
      lrate=0.01,  # learning rate
      n_words_src=100000,  # source vocabulary size
      n_words=100000,  # target vocabulary size
      maxlen=100,  # maximum source sequence length
      maxlen_trg=None,  # maximum target sequence length
      maxlen_sample=1000,
      optimizer='rmsprop',
      batch_size=16,
      valid_batch_size=16,
      sort_size=20,
      save_path=None,
      save_file_name='model',
      save_best_models=0,
      dispFreq=100,
      validFreq=100,
      saveFreq=1000,   # save the parameters after every saveFreq updates
      sampleFreq=-1,
      verboseFreq=10000,
      datasets=[
          'data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.en.tok',
          '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.fr.tok'],
      valid_datasets=['../data/dev/newstest2011.en.tok',
                      '../data/dev/newstest2011.fr.tok'],
      dictionaries=[
          '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.en.tok.pkl',
          '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.fr.tok.pkl'],
      source_word_level=0,
      target_word_level=0,
      use_dropout=False,
      re_load=False,
      re_load_old_setting=False,
      uidx=None,
      eidx=None,
      cidx=None,
      layers=None,
      save_every_saveFreq=0,
      save_burn_in=20000,
      use_bpe=0,
      gru='gru',
      init_params=None,
      build_model=None,
      build_sampler=None,
      gen_sample=None,
      **kwargs
    ):

    if gru not in "gru lngru".split():
        raise

    print "GRU:", gru

    if maxlen_trg is None:
        maxlen_trg = maxlen * 10
    # Model options
    model_options = locals().copy()
    del model_options['init_params']
    del model_options['build_model']
    del model_options['build_sampler']
    del model_options['gen_sample']

    # load dictionaries and invert them
    worddicts = [None] * len(dictionaries)
    worddicts_r = [None] * len(dictionaries)
    for ii, dd in enumerate(dictionaries):
        with open(dd, 'rb') as f:
            worddicts[ii] = cPickle.load(f)
        worddicts_r[ii] = dict()
        for kk, vv in worddicts[ii].iteritems():
            worddicts_r[ii][vv] = kk

    print 'Building model'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    file_name = '%s%s.npz' % (save_path, save_file_name)
    best_file_name = '%s%s.best.npz' % (save_path, save_file_name)
    opt_file_name = '%s%s%s.npz' % (save_path, save_file_name, '.grads')
    best_opt_file_name = '%s%s%s.best.npz' % (save_path, save_file_name, '.grads')
    model_name = '%s%s.pkl' % (save_path, save_file_name)
    params = init_params(model_options)
    cPickle.dump(model_options, open(model_name, 'wb'))
    history_errs = []

    # reload options
    if re_load and os.path.exists(file_name):
        print 'You are reloading your experiment.. do not panic dude..'
        if re_load_old_setting:
            with open(model_name, 'rb') as f:
                model_options = cPickle.load(f)
        params = load_params(file_name, params)
        # reload history
        model = numpy.load(file_name)
        history_errs = list(model['history_errs'])
        if uidx is None:
            uidx = model['uidx']
        if eidx is None:
            eidx = model['eidx']
        if cidx is None:
            cidx = model['cidx']
    else:
        if uidx is None:
            uidx = 0
        if eidx is None:
            eidx = 0
        if cidx is None:
            cidx = 0

    print 'Loading data'
    train = TextIterator(source=datasets[0],
                         target=datasets[1],
                         source_dict=dictionaries[0],
                         target_dict=dictionaries[1],
                         n_words_source=n_words_src,
                         n_words_target=n_words,
                         source_word_level=source_word_level,
                         target_word_level=target_word_level,
                         batch_size=batch_size,
                         sort_size=sort_size)
    valid = TextIterator(source=valid_datasets[0],
                         target=valid_datasets[1],
                         source_dict=dictionaries[0],
                         target_dict=dictionaries[1],
                         n_words_source=n_words_src,
                         n_words_target=n_words,
                         source_word_level=source_word_level,
                         target_word_level=target_word_level,
                         batch_size=valid_batch_size,
                         sort_size=sort_size)

    # create shared variables for parameters
    tparams = init_tparams(params)

    trng, use_noise, \
        x, x_mask, y, y_mask, \
        opt_ret, \
        cost = \
        build_model(tparams, model_options)
    inps = [x, x_mask, y, y_mask]

    print 'Building sampler...\n',
    f_init, f_next = build_sampler(tparams, model_options, trng, use_noise)
    #print 'Done'

    # before any regularizer
    print 'Building f_log_probs...',
    f_log_probs = theano.function(inps, cost, profile=profile)
    print 'Done'
    if re_load:
        use_noise.set_value(0.)
        valid_errs = pred_probs(f_log_probs, prepare_data,
                                model_options, valid, verboseFreq=verboseFreq)
        valid_err = valid_errs.mean()

        if numpy.isnan(valid_err):
            import ipdb
            ipdb.set_trace()

        print 'Reload sanity check: Valid ', valid_err

    cost = cost.mean()

    # apply L2 regularization on weights
    if decay_c > 0.:
        decay_c = theano.shared(numpy.float32(decay_c), name='decay_c')
        weight_decay = 0.
        for kk, vv in tparams.iteritems():
            weight_decay += (vv ** 2).sum()
        weight_decay *= decay_c
        cost += weight_decay

    # regularize the alpha weights
    if alpha_c > 0. and not model_options['decoder'].endswith('simple'):
        alpha_c = theano.shared(numpy.float32(alpha_c), name='alpha_c')
        alpha_reg = alpha_c * (
            (tensor.cast(y_mask.sum(0) // x_mask.sum(0), 'float32')[:, None] -
             opt_ret['dec_alphas'].sum(0))**2).sum(1).mean()
        cost += alpha_reg

    # after all regularizers - compile the computational graph for cost
    print 'Building f_cost...',
    f_cost = theano.function(inps, cost, profile=profile)
    print 'Done'

    print 'Computing gradient...',
    grads = tensor.grad(cost, wrt=itemlist(tparams))
    print 'Done'

    if clip_c > 0:
        grads, not_finite, clipped = gradient_clipping(grads, tparams, clip_c)
    else:
        not_finite = 0
        clipped = 0

    # compile the optimizer, the actual computational graph is compiled here
    lr = tensor.scalar(name='lr')
    print 'Building optimizers...',
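    # `optimizer` holds the optimizer's function name (e.g. 'rmsprop'), so
    # eval(optimizer) resolves it to the corresponding compile function.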
    if re_load and os.path.exists(file_name):
        if clip_c > 0:
            f_grad_shared, f_update, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost,
                                                                  not_finite=not_finite, clipped=clipped,
                                                                  file_name=opt_file_name)
        else:
            f_grad_shared, f_update, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost,
                                                                  file_name=opt_file_name)
    else:
        if clip_c > 0:
            f_grad_shared, f_update, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost,
                                                                  not_finite=not_finite, clipped=clipped)
        else:
            f_grad_shared, f_update, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost)
    print 'Done'

    print 'Optimization'
    best_p = None
    bad_counter = 0

    if validFreq == -1:
        validFreq = len(train[0]) / batch_size
    if saveFreq == -1:
        saveFreq = len(train[0]) / batch_size

    # Training loop
    ud_start = time.time()
    estop = False

    if re_load:
        print "Checkpointed minibatch number: %d" % cidx
        for cc in xrange(cidx):
            if numpy.mod(cc, 1000)==0:
                print "Jumping [%d / %d] examples" % (cc, cidx)
            train.next()

    for epoch in xrange(max_epochs):
        time0 = time.time()
        n_samples = 0
        NaN_grad_cnt = 0
        NaN_cost_cnt = 0
        clipped_cnt = 0
        if re_load:
            re_load = 0
        else:
            cidx = 0

        for x, y in train:
            cidx += 1
            uidx += 1
            use_noise.set_value(1.)

            x, x_mask, y, y_mask, n_x = prepare_data(x, y, maxlen=maxlen,
                                                     maxlen_trg=maxlen_trg,
                                                     n_words_src=n_words_src,
                                                     n_words=n_words)
            if x is None:
                print 'Minibatch with zero sample under length ', maxlen
                uidx -= 1
                uidx = max(uidx, 0)
                continue

            n_samples += n_x

            # compute cost, grads and copy grads to shared variables
            if clip_c > 0:
                cost, not_finite, clipped = f_grad_shared(x, x_mask, y, y_mask)
            else:
                cost = f_grad_shared(x, x_mask, y, y_mask)

            if clipped:
                clipped_cnt += 1

            # check for bad numbers, usually we remove non-finite elements
            # and continue training - but not done here
            if numpy.isnan(cost) or numpy.isinf(cost):
                NaN_cost_cnt += 1

            if not_finite:
                NaN_grad_cnt += 1
                continue

            # do the update on parameters
            f_update(lrate)

            if numpy.isnan(cost) or numpy.isinf(cost):
                continue

            if float(NaN_grad_cnt) > max_epochs * 0.5 or float(NaN_cost_cnt) > max_epochs * 0.5:
                print 'Too many NaNs, abort training'
                return 1., 1., 1.

            # verbose
            if numpy.mod(uidx, dispFreq) == 0:
                ud = time.time() - ud_start
                wps = n_samples / float(time.time() - time0)
                print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost, 'NaN_in_grad', NaN_grad_cnt,\
                      'NaN_in_cost', NaN_cost_cnt, 'Gradient_clipped', clipped_cnt, 'UD ', ud, "%.2f sentence/s" % wps
                ud_start = time.time()

            # generate some samples with the model and display them
            if numpy.mod(uidx, sampleFreq) == 0 and sampleFreq != -1:
                # FIXME: random selection?
                for jj in xrange(numpy.minimum(5, x.shape[1])):
                    stochastic = True
                    use_noise.set_value(0.)
                    sample, score = gen_sample(tparams, f_init, f_next,
                                               x[:, jj][:, None],
                                               model_options, trng=trng, k=1,
                                               maxlen=maxlen_sample,
                                               stochastic=stochastic,
                                               argmax=False)
                    print
                    print 'Source ', jj, ': ',
                    if source_word_level:
                        for vv in x[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[0]:
                                if use_bpe:
                                    print (worddicts_r[0][vv]).replace('@@', ''),
                                else:
                                    print worddicts_r[0][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        source_ = []
                        for vv in x[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[0]:
                                source_.append(worddicts_r[0][vv])
                            else:
                                source_.append('UNK')
                        print "".join(source_)
                    print 'Truth ', jj, ' : ',
                    if target_word_level:
                        for vv in y[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                if use_bpe:
                                    print (worddicts_r[1][vv]).replace('@@', ''),
                                else:
                                    print worddicts_r[1][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        truth_ = []
                        for vv in y[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                truth_.append(worddicts_r[1][vv])
                            else:
                                truth_.append('UNK')
                        print "".join(truth_)
                    print 'Sample ', jj, ': ',
                    if stochastic:
                        ss = sample
                    else:
                        score = score / numpy.array([len(s) for s in sample])
                        ss = sample[score.argmin()]
                    if target_word_level:
                        for vv in ss:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                if use_bpe:
                                    print (worddicts_r[1][vv]).replace('@@', ''),
                                else:
                                    print worddicts_r[1][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        sample_ = []
                        for vv in ss:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                sample_.append(worddicts_r[1][vv])
                            else:
                                sample_.append('UNK')
                        print "".join(sample_)
                    print

            # validate model on validation set and early stop if necessary
            if numpy.mod(uidx, validFreq) == 0:
                use_noise.set_value(0.)
                valid_errs = pred_probs(f_log_probs, prepare_data,
                                        model_options, valid, verboseFreq=verboseFreq)
                valid_err = valid_errs.mean()
                history_errs.append(valid_err)

                if uidx == 0 or valid_err <= numpy.array(history_errs).min():
                    best_p = unzip(tparams)
                    best_optp = unzip(toptparams)
                    bad_counter = 0

                if saveFreq != validFreq and save_best_models:
                    numpy.savez(best_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                cidx=cidx, **best_p)
                    numpy.savez(best_opt_file_name, **best_optp)

                if len(history_errs) > patience and valid_err >= \
                        numpy.array(history_errs)[:-patience].min() and patience != -1:
                    bad_counter += 1
                    if bad_counter > patience:
                        print 'Early Stop!'
                        estop = True
                        break

                if numpy.isnan(valid_err):
                    import ipdb
                    ipdb.set_trace()

                print 'Valid ', valid_err

            # save the best model so far
            if numpy.mod(uidx, saveFreq) == 0:
                print 'Saving...',

                if not os.path.exists(save_path):
                    os.mkdir(save_path)

                params = unzip(tparams)
                optparams = unzip(toptparams)
                numpy.savez(file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                            cidx=cidx, **params)
                numpy.savez(opt_file_name, **optparams)

                if save_every_saveFreq and (uidx >= save_burn_in):
                    this_file_name = '%s%s.%d.npz' % (save_path, save_file_name, uidx)
                    this_opt_file_name = '%s%s%s.%d.npz' % (save_path, save_file_name, '.grads', uidx)
                    numpy.savez(this_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                cidx=cidx, **params)
                    numpy.savez(this_opt_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                cidx=cidx, **optparams)
                    if best_p is not None and saveFreq != validFreq:
                        this_best_file_name = '%s%s.%d.best.npz' % (save_path, save_file_name, uidx)
                        numpy.savez(this_best_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                    cidx=cidx, **best_p)
                print 'Done...',
                print 'Saved to %s' % file_name

            # finish after this many updates
            if uidx >= finish_after and finish_after != -1:
                print 'Finishing after %d iterations!' % uidx
                estop = True
                break

        print 'Seen %d samples' % n_samples
        eidx += 1

        if estop:
            break

    use_noise.set_value(0.)
    valid_err = pred_probs(f_log_probs, prepare_data,
                           model_options, valid).mean()

    print 'Valid ', valid_err

    params = unzip(tparams)
    optparams = unzip(toptparams)
    file_name = '%s%s.%d.npz' % (save_path, save_file_name, uidx)
    opt_file_name = '%s%s%s.%d.npz' % (save_path, save_file_name, '.grads', uidx)
    numpy.savez(file_name, history_errs=history_errs, uidx=uidx, eidx=eidx, cidx=cidx, **params)
    numpy.savez(opt_file_name, **optparams)
    if best_p is not None and saveFreq != validFreq:
        best_file_name = '%s%s.%d.best.npz' % (save_path, save_file_name, uidx)
        best_opt_file_name = '%s%s%s.%d.best.npz' % (save_path, save_file_name, '.grads',uidx)
        numpy.savez(best_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx, cidx=cidx, **best_p)
        numpy.savez(best_opt_file_name, **best_optp)

    return valid_err
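
Both save paths in this example use check-then-create (os.path.exists followed
by os.makedirs or os.mkdir). That pattern has a small race: another process can
create the directory between the check and the mkdir. A hedged sketch of a
tolerant variant (hypothetical helper; on Python 3.2+ os.makedirs(path,
exist_ok=True) does this directly):

import os

def ensure_dir(path):
    # Tolerate the exists/mkdir race: if another process wins and creates
    # the directory first, makedirs raises OSError, and we only re-raise
    # when the path still is not a directory.
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError:
            if not os.path.isdir(path):
                raise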

Example 46

Project: mysql-utilities
Source File: failover_daemon.py
View license
    def test_failover_daemon_nodetach(self, test_case):
        """Tests failover daemon with --nodetach option.

        test_case[in]     Test case.
        """
        server = test_case[0]
        cmd = test_case[1]
        kill_daemon = test_case[2]
        logfile = test_case[3]
        comment = test_case[4]
        key_phrase = test_case[5]
        unregister = test_case[6]
        server_version = server.get_version()

        if unregister:
            # Unregister any failover instance from server
            try:
                server.exec_query("DROP TABLE IF EXISTS "
                                  "mysql.failover_console")
            except UtilError:
                pass

        # Since this test case expects the daemon to stop, we can launch it
        # via a subprocess and wait for it to finish.
        if self.debug:
            print(comment)
            print("# COMMAND: {0}".format(cmd))

        # Cleanup in case previous test case failed
        if os.path.exists(self.failover_dir):
            try:
                os.system("rmdir {0}".format(self.failover_dir))
            except OSError:
                pass

        # Launch the console in stealth mode
        proc, f_out = self.start_process(cmd)

        # Wait for console to load
        if self.debug:
            print("# Waiting for console to start.")
        i = 1
        time.sleep(1)
        while proc.poll() is not None:
            time.sleep(1)
            i += 1
            if i > _TIMEOUT:
                if self.debug:
                    print("# Timeout console to start.")
                raise MUTLibError("{0}: failed - timeout waiting for "
                                  "console to start.".format(comment))

        # Wait for the failover daemon to register on master and start
        # its monitoring process
        phrase = "Failover daemon started"
        if self.debug:
            print("Waiting for failover daemon to register master and start "
                  "its monitoring process")

        # Wait for logfile file to be created
        if self.debug:
            print("# Waiting for logfile to be created.")
        for i in range(_TIMEOUT):
            if os.path.exists(logfile):
                break
            else:
                time.sleep(1)
        else:
            raise MUTLibError("{0}: failed - timeout waiting for "
                              "logfile '{1}' to be "
                              "created.".format(comment, logfile))

        i = 0
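        # Tail the log: readline() returns '' at EOF, so sleep and retry
        # until the phrase shows up or _TIMEOUT idle seconds pass.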
        with open(logfile, "r") as f:
            while i < _TIMEOUT:
                line = f.readline()
                if not line:
                    i += 1
                    time.sleep(1)
                elif phrase in line:
                    break
            else:
                if self.debug:
                    print("# Timeout waiting for failover daemon to register "
                          "master and start its monitoring process")
                raise MUTLibError("{0}: failed - timeout waiting for daemon "
                                  "to register master and start its "
                                  "monitoring process".format(comment))

        # Now, kill the master
        res = server.show_server_variable("pid_file")
        pid_file = open(res[0][1])
        pid = int(pid_file.readline().strip("\n"))
        if self.debug:
            print("# Terminating server {0} via pid = {1}".format(server.port,
                                                                  pid))
        pid_file.close()

        # Get server datadir to clean directory after kill.
        res = server.show_server_variable("datadir")
        datadir = res[0][1]

        # Stop the server
        server.disconnect()
        self.kill(pid)

        # Need to wait until the process is really dead.
        if self.debug:
            print("# Waiting for master to stop.")
        i = 0
        while self.is_process_alive(pid, int(server.port) - 1,
                                    int(server.port) + 1):
            time.sleep(1)
            i += 1
            if i > _TIMEOUT:
                if self.debug:
                    print("# Timeout master to fail.")
                raise MUTLibError("{0}: failed - timeout waiting for "
                                  "master to end.".format(comment))

        # Remove server from the list (and clean data directory).
        if self.debug:
            print("# Removing server name '{0}'.".format(server.role))
        delete_directory(datadir)
        self.servers.remove_server(server.role)

        # Now wait for interval to occur.
        if self.debug:
            print("# Waiting for failover to complete.")
        i = 0
        while not os.path.isdir(self.failover_dir):
            time.sleep(5)
            i += 1
            if i > _TIMEOUT:
                if self.debug:
                    print("# Timeout daemon failover.")
                raise MUTLibError("{0}: failed - timeout waiting for "
                                  "exec_post_fail.".format(comment))

        # Need to poll here and wait for daemon to really end.
        ret_val = self.stop_process(proc, f_out, kill_daemon)
        # Wait for daemon to end
        if self.debug:
            print("# Waiting for daemon to end.")
        i = 0
        while proc.poll() is None:
            time.sleep(1)
            i += 1
            if i > _TIMEOUT:
                if self.debug:
                    print("# Timeout daemon to end.")
                raise MUTLibError("{0}: failed - timeout waiting for "
                                  "daemon to end.".format(comment))

        if self.debug:
            print("# Return code from daemon termination = "
                  "{0}".format(ret_val))

        # Check result code from stop_process then read the log to find the
        # key phrase.
        found_row = True
        if key_phrase is not None:
            found_row = False
            with open(logfile, "r") as f:
                rows = f.readlines()
                if self.debug:
                    print("# Looking in log for: {0}".format(key_phrase))
                for row in rows:
                    if key_phrase in row:
                        found_row = True
                        if self.debug:
                            print("# Found in row = '{0}'."
                                  "".format(row[:len(row) - 1]))
            if not found_row:
                print("# ERROR: Cannot find entry in log:")
                for row in rows:
                    print(row)

        # Find MySQL and Utilities versions in the log
        found_row = False
        with open(logfile, "r") as f:
            rows = f.readlines()

            # Find MySQL Utilities version in the log
            if self.debug:
                print("# Looking in log for: {0}"
                      "".format(_UTILITIES_VERSION_PHRASE))
            for row in rows:
                if _UTILITIES_VERSION_PHRASE in row:
                    found_row = True
                    if self.debug:
                        print("# Found in row = '{0}'."
                              "".format(row[:-1]))
                    break

            # Find MySQL server version in the log
            host_port = "{host}:{port}".format(
                **get_connection_dictionary(server))
            key_phrase = MSG_MYSQL_VERSION.format(server=host_port,
                                                  version=server_version)

            if self.debug:
                print("# Looking in log for: {0}".format(key_phrase))
            for row in rows:
                if key_phrase in row:
                    found_row = True
                    if self.debug:
                        print("# Found in row = '{0}'."
                              "".format(row[:-1]))
                    break

            if not found_row:
                print("# ERROR: Cannot find entry in log:")
                for row in rows:
                    print(row)

        # Cleanup after test case
        try:
            os.unlink(logfile)
        except OSError:
            pass

        if os.path.exists(self.failover_dir):
            try:
                os.system("rmdir {0}".format(self.failover_dir))
            except OSError:
                pass

        return comment, found_row
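
The logfile wait in this test shows another common shape for os.path.exists:
poll in a loop with a sleep until a file appears, and give up after a timeout.
A standalone sketch (hypothetical helper, not part of the test suite):

import os
import time

def wait_for_file(path, timeout=60, interval=1.0):
    # Poll until the file exists; return False if the deadline passes first.
    deadline = time.time() + timeout
    while time.time() < deadline:
        if os.path.exists(path):
            return True
        time.sleep(interval)
    return False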

Example 47

Project: tractor
Source File: scuss.py
View license
def main():
    import optparse

    parser = optparse.OptionParser('%prog [options]')
    parser.add_option('-o', dest='outfn', help='Output filename (FITS table)')
    parser.add_option('-i', dest='imgfn', help='Image input filename')
    parser.add_option('-f', dest='flagfn', help='Flags input filename')
    parser.add_option('-p', dest='psffn', help='PsfEx input filename')
    parser.add_option('-s', dest='postxt', help='Source positions input text file')
    parser.add_option('-S', dest='statsfn', help='Output image statistics filename (FITS table); optional')

    parser.add_option('-g', dest='gaussianpsf', action='store_true',
                      default=False,
                      help='Use multi-Gaussian approximation to PSF?')
    
    parser.add_option('-P', dest='plotbase', default='scuss',
                      help='Plot base filename (default: %default)')
    opt,args = parser.parse_args()

    # Check command-line arguments
    if len(args):
        print('Extra arguments:', args)
        parser.print_help()
        sys.exit(-1)
    for fn,name,exists in [(opt.outfn, 'output filename (-o)', False),
                           (opt.imgfn, 'image filename (-i)', True),
                           (opt.flagfn, 'flag filename (-f)', True),
                           (opt.psffn, 'PSF filename (-p)', True),
                           (opt.postxt, 'Source positions filename (-s)', True),
                           ]:
        if fn is None:
            print('Must specify', name)
            sys.exit(-1)
        if exists and not os.path.exists(fn):
            print('Input file', fn, 'does not exist')
            sys.exit(-1)

    # outfn = 'tractor-scuss.fits'
    # imstatsfn = 'tractor-scuss-imstats.fits'
    

    # Read inputs
    print('Reading input image', opt.imgfn)
    img = fitsio.read(opt.imgfn)
    print('Read img', img.shape, img.dtype)
    H,W = img.shape
    
    posfn = opt.postxt + '.fits'
    if not os.path.exists(posfn):
        from astrometry.util.fits import streaming_text_table
        hdr = 'ra dec x y objid sdss_psfmag_u sdss_psfmagerr_u'
        d = np.float64
        f = np.float32
        types = [str,str,d,d,str,f,f]
        print('Reading positions', opt.postxt)
        T = streaming_text_table(opt.postxt, headerline=hdr, coltypes=types)
        T.writeto(posfn)
        print('Wrote', posfn)

    print('Reading positions', posfn)
    T = fits_table(posfn)
    print('Read', len(T), 'source positions')
    
    print('Reading flags', opt.flagfn)
    flag = fitsio.read(opt.flagfn)
    print('Read flag', flag.shape, flag.dtype)

    print('Reading PSF', opt.psffn)
    psf = PsfEx(opt.psffn, W, H)

    if opt.gaussianpsf:
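        # Cache the fitted multi-Gaussian PSF parameters alongside the PsfEx
        # file so later runs can skip the (slow) fit.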
        picpsffn = opt.psffn + '.pickle'
        if not os.path.exists(picpsffn):
            psf.savesplinedata = True
            print('Fitting PSF model...')
            psf.ensureFit()
            pickle_to_file(psf.splinedata, picpsffn)
            print('Wrote', picpsffn)
        else:
            print('Reading PSF model parameters from', picpsffn)
            data = unpickle_from_file(picpsffn)
            print('Fitting PSF...')
            psf.fitSavedData(*data)
            
    print('Computing image sigma...')
    plo,phi = [np.percentile(img[flag == 0], p) for p in [25,75]]
    # Wikipedia says: IQR -> sigma (for a Gaussian, IQR = 2 * 0.6745 * sigma):
    sigma = (phi - plo) / (0.6745 * 2)
    print('Sigma:', sigma)
    invvar = np.zeros_like(img) + (1./sigma**2)
    invvar[flag != 0] = 0.
    
    # Estimate sky level from median -- not actually necessary, since
    # we will fit it below...
    med = np.median(img[flag == 0])
    
    band = 'u'
    
    # We will break the image into cells for speed -- save the
    # original full-size inputs here.
    fullinvvar = invvar
    fullimg  = img
    fullflag = flag
    fullpsf  = psf
    fullT = T

    # We add a margin around each cell: we want the sources within the
    # cell, so we must include a margin of image pixels touched by
    # those sources, plus an additional margin of sources that touch
    # those pixels.
    margin = 10 # pixels
    # Number of cells to split the image into
    nx = 10
    ny = 10
    # cell positions
    XX = np.round(np.linspace(0, W, nx+1)).astype(int)
    YY = np.round(np.linspace(0, H, ny+1)).astype(int)
    
    results = []

    # Image statistics
    imstats = fits_table()
    imstats.xlo = np.zeros(((len(YY)-1)*(len(XX)-1)), int)
    imstats.xhi = np.zeros_like(imstats.xlo)
    imstats.ylo = np.zeros_like(imstats.xlo)
    imstats.yhi = np.zeros_like(imstats.xlo)
    imstats.ninbox = np.zeros_like(imstats.xlo)
    imstats.ntotal = np.zeros_like(imstats.xlo)
    imstatkeys = ['imchisq', 'imnpix', 'sky']
    for k in imstatkeys:
        imstats.set(k, np.zeros(len(imstats)))
    
    # Plots:
    ps = PlotSequence(opt.plotbase)
    
    # Loop over cells...
    celli = -1
    for yi,(ylo,yhi) in enumerate(zip(YY, YY[1:])):
        for xi,(xlo,xhi) in enumerate(zip(XX, XX[1:])):
            celli += 1
            imstats.xlo[celli] = xlo
            imstats.xhi[celli] = xhi
            imstats.ylo[celli] = ylo
            imstats.yhi[celli] = yhi
            print()
            print('Doing image cell %i: x=[%i,%i), y=[%i,%i)' % (celli, xlo,xhi,ylo,yhi))
            # We will fit for sources in the [xlo,xhi), [ylo,yhi) box.
            # We add a margin in the image around that ROI
            # Beyond that, we add a margin of extra sources
    
            # image region: [ix0,ix1)
            ix0 = max(0, xlo - margin)
            ix1 = min(W, xhi + margin)
            iy0 = max(0, ylo - margin)
            iy1 = min(H, yhi + margin)
            S = (slice(iy0, iy1), slice(ix0, ix1))

            img = fullimg[S]
            invvar = fullinvvar[S]

            if not opt.gaussianpsf:
                # Instantiate pixelized PSF at this cell center.
                pixpsf = fullpsf.instantiateAt((xlo+xhi)/2., (ylo+yhi)/2.)
                print('Pixpsf:', pixpsf.shape)
                psf = PixelizedPSF(pixpsf)
            else:
                psf = fullpsf
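            # ShiftedPsf offsets PSF evaluation by the subimage origin
            # (ix0, iy0), so cell-relative pixel positions map to the
            # right place in the full frame.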
            psf = ShiftedPsf(fullpsf, ix0, iy0)
            
            # sources nearby
            x0 = max(0, xlo - margin*2)
            x1 = min(W, xhi + margin*2)
            y0 = max(0, ylo - margin*2)
            y1 = min(H, yhi + margin*2)
            
            # (SCUSS uses FITS pixel indexing, so -1)
            J = np.flatnonzero((fullT.x-1 >= x0) * (fullT.x-1 < x1) *
                               (fullT.y-1 >= y0) * (fullT.y-1 < y1))
            T = fullT[J].copy()
            T.row = J
    
            # Remember which sources are within the cell (not the margin)
            T.inbounds = ((T.x-1 >= xlo) * (T.x-1 < xhi) *
                          (T.y-1 >= ylo) * (T.y-1 < yhi))
            # Shift source positions so they are correct for this subimage (cell)
            T.x -= ix0
            T.y -= iy0
    
            imstats.ninbox[celli] = sum(T.inbounds)
            imstats.ntotal[celli] = len(T)
    
            # print 'Image subregion:', img.shape
            print('Number of sources in ROI:', sum(T.inbounds))
            print('Number of sources in ROI + margin:', len(T))
            #print 'Source positions: x', T.x.min(), T.x.max(), 'y', T.y.min(), T.y.max()

            # Create tractor.Image object
            tim = Image(data=img, invvar=invvar, psf=psf, wcs=NullWCS(),
                        sky=ConstantSky(med), photocal=LinearPhotoCal(1., band=band),
                        name=opt.imgfn, domask=False)
    
            # Create tractor catalog objects
            cat = []
            for i in range(len(T)):
                # -1: SCUSS, apparently, uses FITS pixel conventions.
                src = PointSource(PixPos(T.x[i] - 1, T.y[i] - 1),
                                  Fluxes(**{band:100.}))
                cat.append(src)

            # Create Tractor object.
            tractor = Tractor([tim], cat)

            # print 'All params:'
            # tractor.printThawedParams()
            t0 = Time()
            tractor.freezeParamsRecursive('*')
            tractor.thawPathsTo('sky')
            tractor.thawPathsTo(band)
            # print 'Fitting params:'
            # tractor.printThawedParams()

            # Forced photometry
            ims0,ims1,IV,fs = tractor.optimize_forced_photometry(
                minsb=1e-3*sigma, mindlnp=1., sky=True, minFlux=None, variance=True,
                fitstats=True, shared_params=False)
            
            print('Forced photometry took', Time()-t0)
            
            # print 'Fit params:'
            # tractor.printThawedParams()

            # Record results
            T.set('tractor_%s_counts' % band, np.array([src.getBrightness().getBand(band) for src in cat]))
            T.set('tractor_%s_counts_invvar' % band, IV)
            T.cell = np.zeros(len(T), int) + celli
            if fs is not None:
                # Per-source stats
                for k in ['prochi2', 'pronpix', 'profracflux', 'proflux', 'npix']:
                    T.set(k, getattr(fs, k))
                # Per-image stats
                for k in imstatkeys:
                    X = getattr(fs, k)
                    imstats.get(k)[celli] = X[0]
            results.append(T)

            # Make plots for the first N cells
            if celli >= 10:
                continue
    
            mod = tractor.getModelImage(0)
            ima = dict(interpolation='nearest', origin='lower',
                       vmin=med + -2. * sigma, vmax=med + 5. * sigma)
            plt.clf()
            plt.imshow(img, **ima)
            plt.title('Data')
            ps.savefig()
            
            plt.clf()
            plt.imshow(mod, **ima)
            plt.title('Model')
            ps.savefig()
            
            noise = np.random.normal(scale=sigma, size=img.shape)
            plt.clf()
            plt.imshow(mod + noise, **ima)
            plt.title('Model + noise')
            ps.savefig()
            
            chi = (img - mod) * tim.getInvError()
            plt.clf()
            plt.imshow(chi, interpolation='nearest', origin='lower', vmin=-5, vmax=5)
            plt.title('Chi')
            ps.savefig()
    

    # Merge results from the cells
    TT = merge_tables(results)
    # Cut to just the sources within the cells
    TT.cut(TT.inbounds)
    TT.delete_column('inbounds')
    # Sort them back into original order
    TT.cut(np.argsort(TT.row))
    #TT.delete_column('row')
    TT.writeto(opt.outfn)
    print('Wrote results to', opt.outfn)
    
    if opt.statsfn:
        imstats.writeto(opt.statsfn)
        print('Wrote image statistics to', opt.statsfn)

    plot_results(opt.outfn, ps)

Example 48

Project: url-abuse
Source File: __init__.py
View license
def create_app(configfile=None):
    app = Flask(__name__)
    handler = RotatingFileHandler('urlabuse.log', maxBytes=10000, backupCount=5)
    handler.setFormatter(Formatter('%(asctime)s %(message)s'))
    app.wsgi_app = ReverseProxied(app.wsgi_app)
    app.logger.addHandler(handler)
    app.logger.setLevel(logging.INFO)
    Bootstrap(app)
    q = Queue(connection=conn)

    # Mail Config
    app.config['MAIL_SERVER'] = 'localhost'
    app.config['MAIL_PORT'] = 25
    mail = Mail(app)

    app.config['SECRET_KEY'] = 'devkey'
    app.config['BOOTSTRAP_SERVE_LOCAL'] = True
    app.config['configfile'] = config_path

    parser = configparser.SafeConfigParser()
    parser.read(app.config['configfile'])

    replacelist = make_dict(parser, 'replacelist')
    auth_users = prepare_auth()
    ignorelist = [i.strip()
                  for i in parser.get('abuse', 'ignore').split('\n')
                  if len(i.strip()) > 0]
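    # Automatically send an abuse report once a URL has been submitted this many times.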
    autosend_threshold = 5

    def _get_user_ip(request):
        ip = request.headers.get('X-Forwarded-For')
        if ip is None:
            ip = request.remote_addr
        return ip

    @app.route('/', methods=['GET', 'POST'])
    def index():
        form = URLForm()
        return render_template('index.html', form=form)

    @app.route('/urlreport', methods=['GET'])
    def url_report():
        return render_template('url-report.html')

    @app.errorhandler(404)
    def page_not_found(e):
        ip = request.headers.get('X-Forwarded-For')
        if ip is None:
            ip = request.remote_addr
        if request.path != '/_result/':
            app.logger.info('404 of {} on {}'.format(ip, request.path))
        return render_template('404.html'), 404

    def authenticate():
        """Sends a 401 response that enables basic auth"""
        return Response('Could not verify your access level for that URL.\n'
                        'You have to login with proper credentials', 401,
                        {'WWW-Authenticate': 'Basic realm="Login Required"'})

    def check_auth(username, password):
        """This function is called to check if a username /
        password combination is valid.
        """
        if auth_users is None:
            return False
        else:
            db_pass = auth_users.get(username)
            return db_pass == password

    @app.route('/login', methods=['GET', 'POST'])
    def login():
        auth = request.authorization
        if not auth or not check_auth(auth.username, auth.password):
            return authenticate()
        return redirect(url_for('index'))

    @app.route("/_result/<job_key>", methods=['GET'])
    def check_valid(job_key):
        if job_key is None:
            return json.dumps(None), 200
        job = Job.fetch(job_key, connection=conn)
        if job.is_finished:
            return json.dumps(job.result), 200
        else:
            return json.dumps("Nay!"), 202

    @app.route('/start', methods=['POST'])
    def run_query():
        data = json.loads(request.data)
        url = data["url"]
        ip = _get_user_ip(request)
        app.logger.info('{} {}'.format(ip, url))
        if get_submissions(url) >= autosend_threshold:
            send(url, '', True)
        is_valid = q.enqueue_call(func=is_valid_url, args=(url,), result_ttl=500)
        return is_valid.get_id()

    @app.route('/urls', methods=['POST'])
    def urls():
        data = json.loads(request.data)
        url = data["url"]
        u = q.enqueue_call(func=url_list, args=(url,), result_ttl=500)
        return u.get_id()

    @app.route('/resolve', methods=['POST'])
    def resolve():
        data = json.loads(request.data)
        url = data["url"]
        u = q.enqueue_call(func=dns_resolve, args=(url,), result_ttl=500)
        return u.get_id()

    @app.route('/phishtank', methods=['POST'])
    def phishtank():
        data = json.loads(request.data)
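        # The key file doubles as a feature toggle: if it is absent, the
        # PhishTank lookup stays disabled.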
        if not os.path.exists('phishtank.key'):
            return None
        url = parser.get("PHISHTANK", "url")
        key = open('phishtank.key', 'r').readline().strip()
        query = data["query"]
        u = q.enqueue_call(func=phish_query, args=(url, key, query,), result_ttl=500)
        return u.get_id()

    @app.route('/virustotal_report', methods=['POST'])
    def vt():
        data = json.loads(request.data)
        if not os.path.exists('virustotal.key'):
            return None
        url = parser.get("VIRUSTOTAL", "url_report")
        url_up = parser.get("VIRUSTOTAL", "url_upload")
        key = open('virustotal.key', 'r').readline().strip()
        query = data["query"]
        u = q.enqueue_call(func=vt_query_url, args=(url, url_up, key, query,), result_ttl=500)
        return u.get_id()

    @app.route('/googlesafebrowsing', methods=['POST'])
    def gsb():
        data = json.loads(request.data)
        if not os.path.exists('googlesafebrowsing.key'):
            return None
        url = parser.get("GOOGLESAFEBROWSING", "url")
        key = open('googlesafebrowsing.key', 'r').readline().strip()
        url = url.format(key)
        query = data["query"]
        u = q.enqueue_call(func=gsb_query, args=(url, query,), result_ttl=500)
        return u.get_id()

    @app.route('/urlquery', methods=['POST'])
    def urlquery():
        data = json.loads(request.data)
        if not os.path.exists('urlquery.key'):
            return None
        url = parser.get("URLQUERY", "url")
        key = open('urlquery.key', 'r').readline().strip()
        query = data["query"]
        u = q.enqueue_call(func=urlquery_query, args=(url, key, query,), result_ttl=500)
        return u.get_id()

    @app.route('/ticket', methods=['POST'])
    def ticket():
        if not request.authorization:
            return ''
        data = json.loads(request.data)
        server = parser.get("SPHINX", "server")
        port = int(parser.get("SPHINX", "port"))
        url = parser.get("ITS", "url")
        query = data["query"]
        u = q.enqueue_call(func=sphinxsearch, args=(server, port, url, query,),
                           result_ttl=500)
        return u.get_id()

    @app.route('/whois', methods=['POST'])
    def whoismail():
        if not request.authorization:
            return ''
        server = parser.get("WHOIS", "server")
        port = parser.getint("WHOIS", "port")
        data = json.loads(request.data)
        query = data["query"]
        u = q.enqueue_call(func=whois, args=(server, port, query, ignorelist, replacelist),
                           result_ttl=500)
        return u.get_id()

    @app.route('/eupi', methods=['POST'])
    def eu():
        data = json.loads(request.data)
        if not os.path.exists('eupi.key'):
            return None
        url = parser.get("EUPI", "url")
        key = open('eupi.key', 'r').readline().strip()
        query = data["query"]
        u = q.enqueue_call(func=eupi, args=(url, key, query,), result_ttl=500)
        return u.get_id()

    @app.route('/pdnscircl', methods=['POST'])
    def dnscircl():
        url = parser.get("PDNS_CIRCL", "url")
        user, password = open('pdnscircl.key', 'r').readlines()
        data = json.loads(request.data)
        query = data["query"]
        u = q.enqueue_call(func=pdnscircl, args=(url, user.strip(), password.strip(),
                                                 query,), result_ttl=500)
        return u.get_id()

    @app.route('/bgpranking', methods=['POST'])
    def bgpr():
        data = json.loads(request.data)
        query = data["query"]
        u = q.enqueue_call(func=bgpranking, args=(query,), result_ttl=500)
        return u.get_id()

    @app.route('/psslcircl', methods=['POST'])
    def sslcircl():
        url = parser.get("PSSL_CIRCL", "url")
        user, password = open('psslcircl.key', 'r').readlines()
        data = json.loads(request.data)
        query = data["query"]
        u = q.enqueue_call(func=psslcircl, args=(url, user.strip(), password.strip(),
                                                 query,), result_ttl=500)
        return u.get_id()

    @app.route('/get_cache', methods=['POST'])
    def get_cache():
        data = json.loads(request.data)
        url = data["query"]
        data = cached(url)
        dumped = json.dumps(data, sort_keys=True, indent=4, separators=(',', ': '))
        return dumped

    def digest(data):
        to_return = ''
        all_mails = set()
        for entry in data:
            for url, info in list(entry.items()):
                to_return += '\n{}\n'.format(url)
                if info.get('whois'):
                    all_mails.update(info.get('whois'))
                    to_return += '\tContacts: {}\n'.format(', '.join(info.get('whois')))
                if info.get('vt') and len(info.get('vt')) == 4:
                    vtstuff = info.get('vt')
                    to_return += '\t{} out of {} positive detections in VT - {}\n'.format(
                        vtstuff[2], vtstuff[3], vtstuff[1])
                if info.get('gsb'):
                    to_return += '\tKnown as malicious on Google Safe Browsing: {}\n'.format(info.get('gsb'))
                if info.get('phishtank'):
                    to_return += '\tKnown as malicious on PhishTank\n'
                if info.get('dns'):
                    ipv4, ipv6 = info.get('dns')
                    if ipv4 is not None:
                        for ip in ipv4:
                            to_return += '\t' + ip + '\n'
                            data = info[ip]
                            if data.get('bgp'):
                                to_return += '\t\t(PTR: {}) is announced by {} ({}).\n'.format(*(data.get('bgp')[:3]))
                            if data.get('whois'):
                                all_mails.update(data.get('whois'))
                                to_return += '\t\tContacts: {}\n'.format(', '.join(data.get('whois')))
                    if ipv6 is not None:
                        for ip in ipv6:
                            to_return += '\t' + ip + '\n'
                            data = info[ip]
                            if data.get('whois'):
                                all_mails.update(data.get('whois'))
                                to_return += '\t\tContacts: {}\n'.format(', '.join(data.get('whois')))
            to_return += '\tAll contacts: {}\n'.format(', '.join(all_mails))
        return to_return

    def send(url, ip='', autosend=False):
        if not get_mail_sent(url):
            set_mail_sent(url)
            data = cached(url)
            if not autosend:
                subject = 'URL Abuse report from ' + ip
            else:
                subject = 'URL Abuse report sent automatically'
            msg = Message(subject, sender='[email protected]', recipients=["[email protected]"])
            msg.body = digest(data)
            msg.body += '\n\n'
            msg.body += json.dumps(data, sort_keys=True, indent=4, separators=(',', ': '))
            mail.send(msg)

    @app.route('/submit', methods=['POST'])
    def send_mail():
        data = json.loads(request.data)
        url = data["url"]
        if not get_mail_sent(url):
            ip = _get_user_ip(request)
            send(url, ip)
        return redirect(url_for('index'))

    return app
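
Several endpoints above gate themselves on the presence of an API key file
(phishtank.key, virustotal.key, and so on), so os.path.exists acts as a
feature toggle. A minimal sketch of that idiom (hypothetical helper; the app
reads each key inline):

import os

def load_key(key_file):
    # A missing key file means the corresponding lookup stays disabled.
    if not os.path.exists(key_file):
        return None
    with open(key_file, 'r') as f:
        return f.readline().strip()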

Example 49

Project: Python-mode-klen
Source File: runmod.py
View license
def __rope_start_everything():
    import os
    import sys
    import socket
    import pickle
    import marshal
    import inspect
    import types
    import threading

    class _MessageSender(object):

        def send_data(self, data):
            pass

    class _SocketSender(_MessageSender):

        def __init__(self, port):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect(('127.0.0.1', port))
            self.my_file = s.makefile('wb')

        def send_data(self, data):
            if not self.my_file.closed:
                pickle.dump(data, self.my_file)

        def close(self):
            self.my_file.close()

    class _FileSender(_MessageSender):

        def __init__(self, file_name):
            self.my_file = open(file_name, 'wb')

        def send_data(self, data):
            if not self.my_file.closed:
                marshal.dump(data, self.my_file)

        def close(self):
            self.my_file.close()


    def _cached(func):
        cache = {}
        def newfunc(self, arg):
            if arg in cache:
                return cache[arg]
            result = func(self, arg)
            cache[arg] = result
            return result
        return newfunc

    class _FunctionCallDataSender(object):

        def __init__(self, send_info, project_root):
            self.project_root = project_root
            if send_info.isdigit():
                self.sender = _SocketSender(int(send_info))
            else:
                self.sender = _FileSender(send_info)

            def global_trace(frame, event, arg):
                # HACK: Ignoring out->in calls
                # This might lose some information
                if self._is_an_interesting_call(frame):
                    return self.on_function_call
            sys.settrace(global_trace)
            threading.settrace(global_trace)

        def on_function_call(self, frame, event, arg):
            if event != 'return':
                return
            args = []
            returned = ('unknown',)
            code = frame.f_code
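            # The first co_argcount entries of co_varnames are the
            # function's positional arguments.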
            for argname in code.co_varnames[:code.co_argcount]:
                try:
                    args.append(self._object_to_persisted_form(frame.f_locals[argname]))
                except (TypeError, AttributeError):
                    args.append(('unknown',))
            try:
                returned = self._object_to_persisted_form(arg)
            except (TypeError, AttributeError):
                pass
            try:
                data = (self._object_to_persisted_form(frame.f_code),
                        tuple(args), returned)
                self.sender.send_data(data)
            except (TypeError):
                pass
            return self.on_function_call

        def _is_an_interesting_call(self, frame):
            #if frame.f_code.co_name in ['?', '<module>']:
            #    return False
            #return not frame.f_back or not self._is_code_inside_project(frame.f_back.f_code)

            if not self._is_code_inside_project(frame.f_code) and \
               (not frame.f_back or not self._is_code_inside_project(frame.f_back.f_code)):
                return False
            return True

        def _is_code_inside_project(self, code):
            source = self._path(code.co_filename)
            return source is not None and os.path.exists(source) and \
                   _realpath(source).startswith(self.project_root)

        @_cached
        def _get_persisted_code(self, object_):
            source = self._path(object_.co_filename)
            if not os.path.exists(source):
                raise TypeError('no source')
            return ('defined', _realpath(source), str(object_.co_firstlineno))

        @_cached
        def _get_persisted_class(self, object_):
            try:
                return ('defined', _realpath(inspect.getsourcefile(object_)),
                        object_.__name__)
            except (TypeError, AttributeError):
                return ('unknown',)

        def _get_persisted_builtin(self, object_):
            if isinstance(object_, str):
                return ('builtin', 'str')
            if isinstance(object_, list):
                holding = None
                if len(object_) > 0:
                    holding = object_[0]
                return ('builtin', 'list', self._object_to_persisted_form(holding))
            if isinstance(object_, dict):
                keys = None
                values = None
                if len(object_) > 0:
                    keys = list(object_.keys())[0]
                    values = object_[keys]
                    if values == object_ and len(object_) > 1:
                        keys = list(object_.keys())[1]
                        values = object_[keys]
                return ('builtin', 'dict',
                        self._object_to_persisted_form(keys),
                        self._object_to_persisted_form(values))
            if isinstance(object_, tuple):
                objects = []
                if len(object_) < 3:
                    for holding in object_:
                        objects.append(self._object_to_persisted_form(holding))
                else:
                    objects.append(self._object_to_persisted_form(object_[0]))
                return tuple(['builtin', 'tuple'] + objects)
            if isinstance(object_, set):
                holding = None
                if len(object_) > 0:
                    for o in object_:
                        holding = o
                        break
                return ('builtin', 'set', self._object_to_persisted_form(holding))
            return ('unknown',)

        def _object_to_persisted_form(self, object_):
            if object_ is None:
                return ('none',)
            if isinstance(object_, types.CodeType):
                return self._get_persisted_code(object_)
            if isinstance(object_, types.FunctionType):
                return self._get_persisted_code(object_.__code__)
            if isinstance(object_, types.MethodType):
                return self._get_persisted_code(object_.__func__.__code__)
            if isinstance(object_, types.ModuleType):
                return self._get_persisted_module(object_)
            if isinstance(object_, (str, list, dict, tuple, set)):
                return self._get_persisted_builtin(object_)
            if isinstance(object_, type):
                return self._get_persisted_class(object_)
            return ('instance', self._get_persisted_class(type(object_)))

        @_cached
        def _get_persisted_module(self, object_):
            path = self._path(object_.__file__)
            if path and os.path.exists(path):
                return ('defined', _realpath(path))
            return ('unknown',)

        def _path(self, path):
            if path.endswith('.pyc'):
                path = path[:-1]
            if path.endswith('.py'):
                return path

        def close(self):
            self.sender.close()
            sys.settrace(None)

    def _realpath(path):
        return os.path.realpath(os.path.abspath(os.path.expanduser(path)))

    send_info = sys.argv[1]
    project_root = sys.argv[2]
    file_to_run = sys.argv[3]
    run_globals = globals()
    run_globals.update({'__name__': '__main__',
                        '__builtins__': __builtins__,
                        '__file__': file_to_run})
    if send_info != '-':
        data_sender = _FunctionCallDataSender(send_info, project_root)
    del sys.argv[1:4]
    with open(file_to_run) as file:
        exec(compile(file.read(), file_to_run, 'exec'), run_globals)
    if send_info != '-':
        data_sender.close()
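
The script expects three positional arguments: send_info (a TCP port or a file name), the project root, and the file to run. When send_info is a port, call data is streamed as pickled (function, args, returned) triples; when it is a file name, marshal is used instead. A minimal receiver sketch for the socket path, with a hypothetical port, assuming the listener is started before something like `python runmod.py 8001 /path/to/project script.py`:

import pickle
import socket

server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(('127.0.0.1', 8001))
server.listen(1)
conn, _ = server.accept()
reader = conn.makefile('rb')
try:
    while True:
        # each record is a (function, args, returned) triple of persisted forms
        print(pickle.load(reader))
except EOFError:
    pass  # the sender closed its end of the stream
finally:
    reader.close()
    conn.close()
    server.close()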

Example 50

View license
@error.context_aware
def run(test, params, env):
    """
    KVM migration with destination problems.
    Contains a group of tests that exercise qemu behavior when problems
    happen on the destination side.

    Each test is described in the docstring of its test class below.

    Test needs params: nettype = bridge.

    :param test: kvm test object.
    :param params: Dictionary with test parameters.
    :param env: Dictionary with the test environment.
    """
    login_timeout = int(params.get("login_timeout", 360))
    mig_timeout = float(params.get("mig_timeout", "3600"))
    mig_protocol = params.get("migration_protocol", "tcp")

    test_rand = None
    mount_path = None
    while mount_path is None or os.path.exists(mount_path):
        test_rand = utils.generate_random_string(3)
        mount_path = ("%s/ni_mount_%s" %
                      (data_dir.get_data_dir(), test_rand))

    mig_dst = os.path.join(mount_path, "mig_dst")

    migration_exec_cmd_src = params.get("migration_exec_cmd_src",
                                        "gzip -c > %s")
    migration_exec_cmd_src = (migration_exec_cmd_src % (mig_dst))

    class MiniSubtest(object):

        def __new__(cls, *args, **kargs):
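            # Instantiating the class runs the whole subtest: test() executes
            # here in __new__, clean() is guaranteed to run afterwards even if
            # test() raised, and the captured exc_info is re-raised with its
            # traceback (Python 2 three-argument raise).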
            self = super(MiniSubtest, cls).__new__(cls)
            ret = None
            exc_info = None
            if args is None:
                args = []
            try:
                try:
                    ret = self.test(*args, **kargs)
                except Exception:
                    exc_info = sys.exc_info()
            finally:
                if hasattr(self, "clean"):
                    try:
                        self.clean()
                    except Exception:
                        if exc_info is None:
                            raise
                    if exc_info:
                        raise exc_info[0], exc_info[1], exc_info[2]
            return ret

    def control_service(session, service, init_service, action, timeout=60):
        """
        Start service on guest.

        :param vm: Virtual machine for vm.
        :param service: service to stop.
        :param action: action with service (start|stop|restart)
        :param init_service: name of service for old service control.
        """
        status = utils_misc.get_guest_service_status(session, service,
                                                     service_former=init_service)
        if action == "start" and status == "active":
            logging.debug("%s already started, no need start it again.",
                          service)
            return
        if action == "stop" and status == "inactive":
            logging.debug("%s already stopped, no need stop it again.",
                          service)
            return
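        # Prefer systemd; if `systemctl --version` fails, fall back to the
        # legacy SysV `service` command using the old init service name.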
        try:
            session.cmd("systemctl --version", timeout=timeout)
            session.cmd("systemctl %s %s.service" % (action, service),
                        timeout=timeout)
        except:
            session.cmd("service %s %s" % (init_service, action),
                        timeout=timeout)

    def set_nfs_server(vm, share_cfg):
        """
        Start nfs server on guest.

        :param vm: Virtual machine for vm.
        """
        session = vm.wait_for_login(timeout=login_timeout)
        cmd = "echo '%s' > /etc/exports" % (share_cfg)
        control_service(session, "nfs-server", "nfs", "stop")
        session.cmd(cmd)
        control_service(session, "nfs-server", "nfs", "start")
        session.cmd("iptables -F")
        session.close()

    def umount(mount_path):
        """
        Umount nfs server mount_path

        :param mount_path: path where nfs dir will be placed.
        """
        utils.run("umount -f %s" % (mount_path))

    def create_file_disk(dst_path, size):
        """
        Create file with size and create there ext3 filesystem.

        :param dst_path: Path to file.
        :param size: Size of file in MB
        """
        utils.run("dd if=/dev/zero of=%s bs=1M count=%s" % (dst_path, size))
        utils.run("mkfs.ext3 -F %s" % (dst_path))

    def mount(disk_path, mount_path, options=None):
        """
        Mount Disk to path

        :param disk_path: Path to disk
        :param mount_path: Path where disk will be mounted.
        :param options: String with options for mount
        """
        if options is None:
            options = ""

        utils.run("mount %s %s %s" % (options, disk_path, mount_path))

    def find_disk_vm(vm, disk_serial):
        """
        Find the disk on the VM whose id ends with disk_serial.

        :param vm: VM where to find the disk.
        :param disk_serial: suffix of the disk id.

        :return: string Disk path
        """
        session = vm.wait_for_login(timeout=login_timeout)

        disk_path = os.path.join("/", "dev", "disk", "by-id")
        disks = session.cmd("ls %s" % disk_path).split("\n")
        session.close()
        disk = filter(lambda x: x.endswith(disk_serial), disks)
        if not disk:
            return None
        return os.path.join(disk_path, disk[0])

    def prepare_disk(vm, disk_path, mount_path):
        """
        Create Ext3 on disk a send there data from main disk.

        :param vm: VM where to find a disk.
        :param disk_path: Path to disk in guest system.
        """
        session = vm.wait_for_login(timeout=login_timeout)
        session.cmd("mkfs.ext3 -F %s" % (disk_path))
        session.cmd("mount %s %s" % (disk_path, mount_path))
        session.close()

    def disk_load(vm, src_path, dst_path, copy_timeout=None, dsize=None):
        """
        Start disk load. Cyclic copy from src_path to dst_path.

        :param vm: VM where to find a disk.
        :param src_path: Source of data
        :param dst_path: Path to destination
        :param copy_timeout: Timeout for copy
        :param dsize: Size of data block which is periodical copied.
        """
        if dsize is None:
            dsize = 100
        session = vm.wait_for_login(timeout=login_timeout)
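        # Endless dd loop in the background; the PID is parsed from the
        # shell job-control line, e.g. "[1] 12345".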
        cmd = ("nohup /bin/bash -c 'while true; do dd if=%s of=%s bs=1M "
               "count=%s; done;' 2> /dev/null &" % (src_path, dst_path, dsize))
        pid = re.search(r"\[.+\] (.+)",
                        session.cmd_output(cmd, timeout=copy_timeout))
        return pid.group(1)

    class IscsiServer_tgt(object):

        """
        Class to set up and start an iscsi server.
        """

        def __init__(self):
            self.server_name = "autotest_guest_" + test_rand
            self.user = "user1"
            self.passwd = "pass"
            self.config = """
<target %s:dev01>
    backing-store %s
    incominguser %s %s
</target>
"""

        def set_iscsi_server(self, vm_ds, disk_path, disk_size):
            """
            Set up the iscsi server.

            :param vm_ds: VM where the iscsi server should be started.
            :param disk_path: path where the disk should be placed.
            :param disk_size: size of the new disk.
            """
            session = vm_ds.wait_for_login(timeout=login_timeout)

            session.cmd("dd if=/dev/zero of=%s bs=1M count=%s" % (disk_path,
                                                                  disk_size))
            status, output = session.cmd_status_output("setenforce 0")
            if status not in [0, 127]:
                logging.warn("Function setenforce fails.\n %s" % (output))

            config = self.config % (self.server_name, disk_path,
                                    self.user, self.passwd)
            cmd = "cat > /etc/tgt/conf.d/virt.conf << EOF" + config + "EOF"
            control_service(session, "tgtd", "tgtd", "stop")
            session.sendline(cmd)
            control_service(session, "tgtd", "tgtd", "start")
            session.cmd("iptables -F")
            session.close()

        def find_disk(self):
            disk_path = os.path.join("/", "dev", "disk", "by-path")
            disks = utils.run("ls %s" % disk_path).stdout.split("\n")
            disk = filter(lambda x: self.server_name in x, disks)
            if not disk:
                return None
            return os.path.join(disk_path, disk[0].strip())

        def connect(self, vm_ds):
            """
            Connect to iscsi server on guest.

            :param vm_ds: Guest where is iscsi server running.

            :return: path where disk is connected.
            """
            ip_dst = vm_ds.get_address()
            utils.run("iscsiadm -m discovery -t st -p %s" % (ip_dst))

            server_ident = ('iscsiadm -m node --targetname "%s:dev01"'
                            ' --portal %s' % (self.server_name, ip_dst))
            utils.run("%s --op update --name node.session.auth.authmethod"
                      " --value CHAP" % (server_ident))
            utils.run("%s --op update --name node.session.auth.username"
                      " --value %s" % (server_ident, self.user))
            utils.run("%s --op update --name node.session.auth.password"
                      " --value %s" % (server_ident, self.passwd))
            utils.run("%s --login" % (server_ident))
            time.sleep(1.0)
            return self.find_disk()

        def disconnect(self):
            server_ident = ('iscsiadm -m node --targetname "%s:dev01"' %
                            (self.server_name))
            utils.run("%s --logout" % (server_ident))

    class IscsiServer(object):

        """
        Iscsi server implementation interface.
        """

        def __init__(self, iscsi_type, *args, **kargs):
            if iscsi_type == "tgt":
                self.ic = IscsiServer_tgt(*args, **kargs)
            else:
                raise NotImplementedError()

        def __getattr__(self, name):
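            # Delegate any unknown attribute to the concrete backend (tgt for
            # now), so IscsiServer exposes the backend's full interface.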
            if self.ic:
                return self.ic.__getattribute__(name)
            raise AttributeError("Cannot find attribute %s in class" % name)

    class test_read_only_dest(MiniSubtest):

        """
        Migration to read-only destination by using a migration to file.

        1) Start a guest with an NFS server.
        2) Configure the NFS share as read-only.
        3) Mount the read-only share on the host.
        4) Start a second guest and try to migrate to the read-only dest.

        result) Migration should fail with an error message about the
        read-only destination.
        """

        def test(self):
            if params.get("nettype") != "bridge":
                raise error.TestNAError("Unable start test without params"
                                        " nettype=bridge.")

            vm_ds = env.get_vm("virt_test_vm2_data_server")
            vm_guest = env.get_vm("virt_test_vm1_guest")
            ro_timeout = int(params.get("read_only_timeout", "480"))
            exp_str = r".*Read-only file system.*"
            utils.run("mkdir -p %s" % (mount_path))

            vm_ds.verify_alive()
            vm_guest.create()
            vm_guest.verify_alive()

            set_nfs_server(vm_ds, "/mnt *(ro,async,no_root_squash)")

            mount_src = "%s:/mnt" % (vm_ds.get_address())
            mount(mount_src, mount_path,
                  "-o hard,timeo=14,rsize=8192,wsize=8192")
            vm_guest.migrate(mig_timeout, mig_protocol,
                             not_wait_for_migration=True,
                             migration_exec_cmd_src=migration_exec_cmd_src,
                             env=env)

            if not utils_misc.wait_for(lambda: process_output_check(
                                       vm_guest.process, exp_str),
                                       timeout=ro_timeout, first=2):
                raise error.TestFail("The Read-only file system warning not"
                                     " come in time limit.")

        def clean(self):
            if os.path.exists(mig_dst):
                os.remove(mig_dst)
            if os.path.exists(mount_path):
                umount(mount_path)
                os.rmdir(mount_path)

    class test_low_space_dest(MiniSubtest):

        """
        Migrate to a destination with low disk space.

        1) Start guest.
        2) Create a disk with low space.
        3) Try to migrate to the disk.

        result) Migration should fail with a "No space left on device" warning.
        """

        def test(self):
            self.disk_path = None
            while self.disk_path is None or os.path.exists(self.disk_path):
                self.disk_path = ("%s/disk_%s" %
                                  (test.tmpdir, utils.generate_random_string(3)))

            disk_size = utils.convert_data_size(params.get("disk_size", "10M"),
                                                default_sufix='M')
            disk_size /= 1024 * 1024    # To MB.

            exp_str = r".*gzip: stdout: No space left on device.*"
            vm_guest = env.get_vm("virt_test_vm1_guest")
            utils.run("mkdir -p %s" % (mount_path))

            vm_guest.verify_alive()
            vm_guest.wait_for_login(timeout=login_timeout)

            create_file_disk(self.disk_path, disk_size)
            mount(self.disk_path, mount_path, "-o loop")

            vm_guest.migrate(mig_timeout, mig_protocol,
                             not_wait_for_migration=True,
                             migration_exec_cmd_src=migration_exec_cmd_src,
                             env=env)

            if not utils_misc.wait_for(lambda: process_output_check(
                                       vm_guest.process, exp_str),
                                       timeout=60, first=1):
                raise error.TestFail("The migration to destination with low "
                                     "storage space didn't fail as it should.")

        def clean(self):
            if os.path.exists(mount_path):
                umount(mount_path)
                os.rmdir(mount_path)
            if os.path.exists(self.disk_path):
                os.remove(self.disk_path)

    class test_extensive_io(MiniSubtest):

        """
        Abstract class: migrate after extensive I/O. It only defines the
        basic functionality and the interface for the derived tests.

        1) Start ds_guest which starts the data server.
        2) Create a disk for data stress in ds_guest.
        3) Share and prepare the disk from ds_guest.
        4) Mount the disk to mount_path.
        5) Create a disk for the second guest in the mounted path.
        6) Start the second guest with the prepared disk.
        7) Start stress on the prepared disk on the second guest.
        8) Wait a few seconds.
        9) Restart the data server.
        10) Migrate the second guest.

        result) Migration should be successful.
        """

        def test(self):
            self.copier_pid = None
            if params.get("nettype") != "bridge":
                raise error.TestNAError("Unable start test without params"
                                        " nettype=bridge.")

            self.disk_serial = params.get("drive_serial_image2_vm1",
                                          "nfs-disk-image2-vm1")
            self.disk_serial_src = params.get("drive_serial_image1_vm1",
                                              "root-image1-vm1")
            self.guest_mount_path = params.get("guest_disk_mount_path", "/mnt")
            self.copy_timeout = int(params.get("copy_timeout", "1024"))

            self.copy_block_size = params.get("copy_block_size", "100M")
            self.copy_block_size = utils.convert_data_size(
                self.copy_block_size,
                "M")
            self.disk_size = "%s" % (self.copy_block_size * 1.4)
            self.copy_block_size /= 1024 * 1024

            self.server_recover_timeout = (
                int(params.get("server_recover_timeout", "240")))

            utils.run("mkdir -p %s" % (mount_path))

            self.test_params()
            self.config()

            self.vm_guest_params = params.copy()
            self.vm_guest_params["images_base_dir_image2_vm1"] = mount_path
            self.vm_guest_params["image_name_image2_vm1"] = "ni_mount_%s/test" % (test_rand)
            self.vm_guest_params["image_size_image2_vm1"] = self.disk_size
            self.vm_guest_params = self.vm_guest_params.object_params("vm1")
            self.image2_vm_guest_params = (self.vm_guest_params.
                                           object_params("image2"))

            env_process.preprocess_image(test,
                                         self.image2_vm_guest_params,
                                         env)
            self.vm_guest.create(params=self.vm_guest_params)

            self.vm_guest.verify_alive()
            self.vm_guest.wait_for_login(timeout=login_timeout)
            self.workload()

            self.restart_server()

            self.vm_guest.migrate(mig_timeout, mig_protocol, env=env)

            try:
                self.vm_guest.verify_alive()
                self.vm_guest.wait_for_login(timeout=login_timeout)
            except aexpect.ExpectTimeoutError:
                raise error.TestFail("Migration should be successful.")

        def test_params(self):
            """
            Test specific params. Could be implemented in inherited class.
            """
            pass

        def config(self):
            """
            Test specific config.
            """
            raise NotImplementedError()

        def workload(self):
            disk_path = find_disk_vm(self.vm_guest, self.disk_serial)
            if disk_path is None:
                raise error.TestFail("It was impossible to find disk on VM")

            prepare_disk(self.vm_guest, disk_path, self.guest_mount_path)

            disk_path_src = find_disk_vm(self.vm_guest, self.disk_serial_src)
            dst_path = os.path.join(self.guest_mount_path, "test.data")
            self.copier_pid = disk_load(self.vm_guest, disk_path_src, dst_path,
                                        self.copy_timeout, self.copy_block_size)

        def restart_server(self):
            raise NotImplementedError()

        def clean_test(self):
            """
            Test specific cleanup.
            """
            pass

        def clean(self):
            if self.copier_pid:
                try:
                    if self.vm_guest.is_alive():
                        session = self.vm_guest.wait_for_login(timeout=login_timeout)
                        session.cmd("kill -9 %s" % (self.copier_pid))
                except:
                    logging.warn("It was impossible to stop copier. Something "
                                 "probably happened with GUEST or NFS server.")

            if params.get("kill_vm") == "yes":
                if self.vm_guest.is_alive():
                    self.vm_guest.destroy()
                    utils_misc.wait_for(lambda: self.vm_guest.is_dead(), 30,
                                        2, 2, "Waiting for the guest to die.")
                qemu_img = qemu_storage.QemuImg(self.image2_vm_guest_params,
                                                mount_path,
                                                None)
                qemu_img.check_image(self.image2_vm_guest_params,
                                     mount_path)

            self.clean_test()

    class test_extensive_io_nfs(test_extensive_io):

        """
        Migrate after extensive I/O over NFS.

        1) Start ds_guest which starts an NFS server.
        2) Create a disk for data stress in ds_guest.
        3) Share the disk over NFS.
        4) Mount the disk to mount_path.
        5) Create a disk for the second guest in the mounted path.
        6) Start the second guest with the prepared disk.
        7) Start stress on the prepared disk on the second guest.
        8) Wait a few seconds.
        9) Restart the NFS server.
        10) Migrate the second guest.

        result) Migration should be successful.
        """

        def config(self):
            vm_ds = env.get_vm("virt_test_vm2_data_server")
            self.vm_guest = env.get_vm("vm1")
            self.image2_vm_guest_params = None
            self.copier_pid = None
            self.qemu_img = None

            vm_ds.verify_alive()
            self.control_session_ds = vm_ds.wait_for_login(timeout=login_timeout)

            set_nfs_server(vm_ds, "/mnt *(rw,async,no_root_squash)")

            mount_src = "%s:/mnt" % (vm_ds.get_address())
            mount(mount_src, mount_path,
                  "-o hard,timeo=14,rsize=8192,wsize=8192")

        def restart_server(self):
            time.sleep(10)  # Wait a while until the copy starts working.
            control_service(self.control_session_ds, "nfs-server",
                            "nfs", "stop")  # Stop NFS server
            time.sleep(5)
            control_service(self.control_session_ds, "nfs-server",
                            "nfs", "start")  # Start NFS server

            """
            Touch waits until all previous requests are invalidated
            (NFS grace period). Without the grace period, qemu startup
            takes too long and the timers for machine creation die.
            """
            qemu_img = qemu_storage.QemuImg(self.image2_vm_guest_params,
                                            mount_path,
                                            None)
            utils.run("touch %s" % (qemu_img.image_filename),
                      self.server_recover_timeout)

        def clean_test(self):
            if os.path.exists(mount_path):
                umount(mount_path)
                os.rmdir(mount_path)

    class test_extensive_io_iscsi(test_extensive_io):

        """
        Migrate after extensive I/O over iscsi.

        1) Start ds_guest which starts an iscsi server.
        2) Create a disk for data stress in ds_guest.
        3) Share the disk over iscsi.
        4) Connect to the disk on the host.
        5) Prepare a partition on the disk.
        6) Mount the disk to mount_path.
        7) Create a disk for the second guest in the mounted path.
        8) Start the second guest with the prepared disk.
        9) Start stress on the prepared disk on the second guest.
        10) Wait a few seconds.
        11) Restart the iscsi server.
        12) Migrate the second guest.

        result) Migration should be successful.
        """

        def test_params(self):
            self.iscsi_variant = params.get("iscsi_variant", "tgt")
            self.ds_disk_path = os.path.join(self.guest_mount_path, "test.img")

        def config(self):
            vm_ds = env.get_vm("virt_test_vm2_data_server")
            self.vm_guest = env.get_vm("vm1")
            self.image2_vm_guest_params = None
            self.copier_pid = None
            self.qemu_img = None

            vm_ds.verify_alive()
            self.control_session_ds = vm_ds.wait_for_login(timeout=login_timeout)

            self.isci_server = IscsiServer("tgt")
            disk_path = os.path.join(self.guest_mount_path, "disk1")
            self.isci_server.set_iscsi_server(vm_ds, disk_path,
                                              (int(float(self.disk_size) * 1.1) / (1024 * 1024)))
            self.host_disk_path = self.isci_server.connect(vm_ds)

            utils.run("mkfs.ext3 -F %s" % (self.host_disk_path))
            mount(self.host_disk_path, mount_path)

        def restart_server(self):
            time.sleep(10)  # Wait a while until the copy starts working.
            control_service(self.control_session_ds, "tgtd",
                            "tgtd", "stop", 240)  # Stop Iscsi server
            time.sleep(5)
            control_service(self.control_session_ds, "tgtd",
                            "tgtd", "start", 240)  # Start Iscsi server

            """
            Wait until the iscsi server is accessible again after
            the restart.
            """
            qemu_img = qemu_storage.QemuImg(self.image2_vm_guest_params,
                                            mount_path,
                                            None)
            utils.run("touch %s" % (qemu_img.image_filename),
                      self.server_recover_timeout)

        def clean_test(self):
            if os.path.exists(mount_path):
                umount(mount_path)
                os.rmdir(mount_path)
            if os.path.exists(self.host_disk_path):
                self.isci_server.disconnect()

    test_type = params.get("test_type")
    if (test_type in locals()):
        tests_group = locals()[test_type]
        tests_group()
    else:
        raise error.TestFail("Test group '%s' is not defined in"
                             " migration_with_dst_problem test" % test_type)