logging.info

Here are examples of the Python API logging.info taken from open source projects.
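
A minimal, self-contained sketch of typical logging.info usage may help before reading the project code below; the configuration choices and message values here are illustrative only and are not taken from any of the listed projects.

import logging

# Configure the root logger once, early in the program;
# the level and format are just example choices.
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s %(levelname)s %(message)s")

# Emit an informational message. Passing the arguments separately
# (rather than pre-formatting with %) defers string interpolation
# until the logging module knows the record will actually be emitted.
logging.info("Running backup of %s:%s", "localhost", 27017)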

200 Examples

Example 1

Project: mongodb_consistent_backup
Source File: Backup.py
    def run(self):
        # TODO: would be nice to have this code look like the pseudocode below (functions do the work) so it's readable
        """
        self.log(version_message,INFO)
        self.lock()
        self.start_timer()
        if not self.is_sharded():
            self.exec_unsharded()
        else:
            self.exec_sharded()
        self.stop_timer()
        self.archive()
        self.upload()
        self.notify()
        if self.db:
            self.db.close()
        self.log(backup_complete_message,INFO)
        """
        logging.info("Starting %s version %s (git commit hash: %s)" % (self.program_name, self.version, self.git_commit))

        # noinspection PyBroadException
        try:
            self._lock = Lock(self.lock_file)
        except Exception:
            logging.fatal("Could not acquire lock! Is another %s process running? Exiting" % self.program_name)
            sys.exit(1)

        if not self.is_sharded:
            logging.info("Running backup of %s:%s in replset mode" % (self.host, self.port))

            self.archiver_threads = 1

            # get shard secondary
            try:
                self.replset = Replset(
                    self.db,
                    self.user,
                    self.password,
                    self.authdb,
                    self.max_repl_lag_secs
                )
                secondary    = self.replset.find_secondary()
                replset_name = secondary['replSet']

                self.secondaries[replset_name] = secondary
            except Exception, e:
                self.exception("Problem getting shard secondaries! Error: %s" % e)

            try:
                self.mongodumper = Dumper(
                    self.secondaries,
                    self.backup_root_directory,
                    self.backup_binary,
                    self.dump_gzip,
                    self.user,
                    self.password,
                    self.authdb,
                    None,
                    self.verbose
                )
                self.mongodumper.run()
            except Exception, e:
                self.exception("Problem performing replset mongodump! Error: %s" % e)

        else:
            logging.info("Running backup of %s:%s in sharded mode" % (self.host, self.port))

            # connect to balancer and stop it
            try:
                self.sharding = Sharding(
                    self.db,
                    self.user,
                    self.password,
                    self.authdb,
                    self.balancer_wait_secs,
                    self.balancer_sleep
                )
                self.sharding.get_start_state()
            except Exception, e:
                self.exception("Problem connecting to the balancer! Error: %s" % e)

            # get shard secondaries
            try:
                self.replset_sharded = ReplsetSharded(
                    self.sharding,
                    self.db,
                    self.user,
                    self.password,
                    self.authdb,
                    self.max_repl_lag_secs
                )
                self.secondaries = self.replset_sharded.find_secondaries()
            except Exception, e:
                self.exception("Problem getting shard secondaries! Error: %s" % e)

            # stop the balancer
            try:
                self.sharding.stop_balancer()
            except Exception, e:
                self.exception("Problem stopping the balancer! Error: %s" % e)

            # start the oplog tailer threads
            if self.no_oplog_tailer:
                logging.warning("Oplog tailing disabled! Skipping")
            else:
                try:
                    self.oplogtailer = OplogTailer(
                        self.secondaries,
                        self.backup_name,
                        self.backup_root_directory,
                        self.dump_gzip,
                        self.user,
                        self.password,
                        self.authdb
                    )
                    self.oplogtailer.run()
                except Exception, e:
                    self.exception("Failed to start oplog tailing threads! Error: %s" % e)

            # start the mongodumper threads
            try:
                self.mongodumper = Dumper(
                    self.secondaries, 
                    self.backup_root_directory,
                    self.backup_binary,
                    self.dump_gzip,
                    self.user,
                    self.password,
                    self.authdb,
                    self.sharding.get_config_server(),
                    self.verbose
                )
                self.mongodumper_summary = self.mongodumper.run()
            except Exception, e:
                self.exception("Problem performing mongodumps! Error: %s" % e)

            # stop the oplog tailing threads:
            if self.oplogtailer:
                self.oplog_summary = self.oplogtailer.stop()

            # set balancer back to original value
            try:
                self.sharding.restore_balancer_state()
            except Exception, e:
                self.exception("Problem restoring balancer lock! Error: %s" % e)

            # resolve/merge tailed oplog into mongodump oplog.bson to a consistent point for all shards
            if self.oplogtailer:
                self.oplog_resolver = OplogResolver(self.oplog_summary, self.mongodumper_summary, self.dump_gzip,
                                                    self.resolver_threads)
                self.oplog_resolver.run()

        # archive (and optionally compress) backup directories to archive files (threaded)
        if self.no_archiver:
            logging.warning("Archiving disabled! Skipping")
        else:
            try:
                self.archiver = Archiver(self.backup_root_directory, self.no_archiver_gzip, self.archiver_threads, self.verbose)
                self.archiver.run()
            except Exception, e:
                self.exception("Problem performing archiving! Error: %s" % e)

        self.end_time = time()
        self.backup_duration = self.end_time - self.start_time

        # AWS S3 secure multipart uploader (optional)
        if self.upload_s3_bucket_name and self.upload_s3_bucket_prefix and self.upload_s3_access_key and self.upload_s3_secret_key:
            try:
                self.uploader_s3 = UploadS3(
                    self.backup_root_directory,
                    self.backup_root_subdirectory,
                    self.upload_s3_bucket_name,
                    self.upload_s3_bucket_prefix,
                    self.upload_s3_access_key,
                    self.upload_s3_secret_key,
                    self.upload_s3_remove_uploaded,
                    self.upload_s3_url,
                    self.upload_s3_threads,
                    self.upload_s3_chunk_size_mb
                )
                self.uploader_s3.run()
            except Exception, e:
                self.exception("Problem performing AWS S3 multipart upload! Error: %s" % e)

        # send notifications of backup state
        if self.notify_nsca:
            try:
                self.notify_nsca.notify(self.notify_nsca.success, "%s: backup '%s' succeeded in %s secs" % (
                    self.program_name,
                    self.backup_name,
                    self.backup_duration
                ))
            except Exception, e:
                self.exception("Problem running NSCA notifier! Error: %s" % e)

        if self.db:
            self.db.close()

        self._lock.release()

        logging.info("Backup completed in %s sec" % self.backup_duration)

Example 2

Project: PHEnix
Source File: run_snp_pipeline.py
def main(args):

    if args.get("version") is None:
        args["version"] = get_version()

    make_aux = False
    if args["workflow"] and args["input"]:
        logging.info("PIPELINE_START")
        workflow_config = pipeline(args["workflow"], args["input"])
        try:
            args["r1"] = workflow_config["r1"]
            args["r2"] = workflow_config["r2"]
            args["outdir"] = workflow_config["outdir"]
            args["config"] = workflow_config["config"]
            args["reference"] = workflow_config["reference"]
            args["sample_name"] = workflow_config["sample_name"]
            make_aux = True
        except KeyError:
            logging.critical("Could not find parameters in %s", args["input"])
            return 5

    logging.info("Initialising data matrix.")

    if args["outdir"] is None:
        sys.stdout.write("Please provide output directory.")
        return -1
    elif not os.path.exists(args["outdir"]):
        os.makedirs(args["outdir"])

    # If config is specified, then load data from that.
    if args["config"]:
        load_config(args)

    mapper = None
    if args["mapper"]:
        mapper = map_fac(mapper=args["mapper"], custom_options=args["mapper_options"])

    variant = None
    if args["variant"]:
        variant = variant_fac(variant=args["variant"], custom_options=args["variant_options"])

    if args["annotators"]:
        args["annotators"] = make_annotators(args["annotators"])

    if args["filters"]:
        try:
            if isinstance(args["filters"], str):
                args["filters"] = str_to_filters(args["filters"])
            elif isinstance(args["filters"], dict):
                args["filters"] = make_filters(args["filters"])
            else:
                logging.warn("Unknown filters specified: %s", args["filters"])
        except Exception:
            logging.error("Failed to recognise and create filters.")
            return 3

    if args["bam"] is not None:
        logging.info("Found BAM file: %s", args["bam"])
        bam_file = args["bam"]
    elif args["vcf"] is None and mapper is not None:
        logging.info("Mapping data file with %s.", args["mapper"])
        bam_file = os.path.join(args["outdir"], "%s.bam" % args["sample_name"])
        success = mapper.make_bam(ref=args["reference"],
                                  R1=args["r1"],
                                  R2=args["r2"],
                                  out_file=bam_file,
                                  sample_name=args["sample_name"],
                                  make_aux=make_aux)

        if not success:
            logging.warn("Could not map reads to the reference. Aborting.")
            return 1
    else:
        bam_file = None

    logging.info("Creating digitised variants with %s.", args["variant"])
    if args["vcf"]:
        vcf_file = args["vcf"]
    elif bam_file is not None:
        vcf_file = os.path.join(args["outdir"], "%s.vcf" % args["sample_name"])

        if variant and not variant.make_vcf(ref=args["reference"], bam=bam_file, vcf_file=vcf_file, make_aux=make_aux):
            logging.error("VCF was not created.")
            return 2

        # Remove BAM file if it was generated and not kept.
        if not args["bam"] and not args["keep_temp"]:
            logging.debug("Removing BAM file: %s", bam_file)
            os.unlink(bam_file)
            os.unlink("%s.bai" % bam_file)
    else:
        vcf_file = None

    annotators_metadata = []
    if args["annotators"] and vcf_file:
        logging.info("Annotating")
        for annotator in args["annotators"]:
            # TODO: This iterates over the VCF for each annotator. Not good.
            annotator.annotate(vcf_path=vcf_file)

            meta = annotator.get_meta()

            if meta:
                annotators_metadata.append(meta)

    if args["filters"] and vcf_file:
        logging.info("Applying filters: %s", [str(f) for f in args["filters"]])
        var_set = VariantSet(vcf_file, filters=args["filters"], reference=args["reference"])

        var_set.add_metadata(mapper.get_meta())
        var_set.add_metadata(variant.get_meta())

        if args.get("version") is not None:
            var_set.add_metadata(OrderedDict({"PHEnix-Version":(args["version"],)}))

        for annotator_md in annotators_metadata:
            var_set.add_metadata(annotator_md)

        final_vcf = os.path.join(args["outdir"], "%s.filtered.vcf" % args["sample_name"])
        var_set.filter_variants(out_vcf=final_vcf)

        if not args["vcf"] and not args["keep_temp"]:
            logging.debug("Removing VCF file: %s", vcf_file)
            os.unlink(vcf_file)

    if args["workflow"] and args["input"]:
        component_complete = os.path.join(args["outdir"], "ComponentComplete.txt")
        open(component_complete, 'a').close()

        logging.info("PIPELINE_END")

    return 0
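
Before moving on, note a difference between the first two examples: Example 1 pre-formats its messages with the % operator before calling logging.info, while Example 2 passes the format arguments to logging.info directly. Both produce the same text, but the second form defers string interpolation until the logging module has decided the record will actually be emitted. A small, hypothetical comparison:

import logging

# With the root level at WARNING, INFO records are discarded.
logging.basicConfig(level=logging.WARNING)

sample_name = "sample_01"  # illustrative value

# Eager: the message string is built first, then thrown away.
logging.info("Mapping data file for %s." % sample_name)

# Lazy: interpolation is skipped entirely because INFO is disabled.
logging.info("Mapping data file for %s.", sample_name)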

Example 3

Project: autotest
Source File: server_job.py
    def run(self, cleanup=False, install_before=False, install_after=False,
            collect_crashdumps=True, namespace={}, control=None,
            control_file_dir=None, only_collect_crashinfo=False):
        # for a normal job, make sure the uncollected logs file exists
        # for a crashinfo-only run it should already exist, bail out otherwise
        created_uncollected_logs = False
        if self.resultdir and not os.path.exists(self._uncollected_log_file):
            if only_collect_crashinfo:
                # if this is a crashinfo-only run, and there were no existing
                # uncollected logs, just bail out early
                logging.info("No existing uncollected logs, "
                             "skipping crashinfo collection")
                return
            else:
                log_file = open(self._uncollected_log_file, "w")
                pickle.dump([], log_file)
                log_file.close()
                created_uncollected_logs = True

        # use a copy so changes don't affect the original dictionary
        namespace = namespace.copy()
        machines = self.machines
        if control is None:
            if self.control is None:
                control = ''
            else:
                control = self._load_control_file(self.control)
        if control_file_dir is None:
            control_file_dir = self.resultdir

        self.aborted = False
        namespace['machines'] = machines
        namespace['args'] = self.args
        namespace['job'] = self
        namespace['ssh_user'] = self._ssh_user
        namespace['ssh_port'] = self._ssh_port
        namespace['ssh_pass'] = self._ssh_pass
        test_start_time = int(time.time())

        if self.resultdir:
            os.chdir(self.resultdir)
            # touch status.log so that the parser knows a job is running here
            open(self.get_status_log_path(), 'a').close()
            self.enable_external_logging()

        collect_crashinfo = True
        temp_control_file_dir = None
        try:
            try:
                if install_before and machines:
                    self._execute_code(INSTALL_CONTROL_FILE, namespace)

                if only_collect_crashinfo:
                    return

                # determine the dir to write the control files to
                cfd_specified = (control_file_dir and control_file_dir is not
                                 self._USE_TEMP_DIR)
                if cfd_specified:
                    temp_control_file_dir = None
                else:
                    temp_control_file_dir = tempfile.mkdtemp(
                        suffix='temp_control_file_dir')
                    control_file_dir = temp_control_file_dir
                server_control_file = os.path.join(control_file_dir,
                                                   self._control_filename)
                client_control_file = os.path.join(control_file_dir,
                                                   CLIENT_CONTROL_FILENAME)
                if self._client:
                    namespace['control'] = control
                    utils.open_write_close(client_control_file, control)
                    shutil.copyfile(CLIENT_WRAPPER_CONTROL_FILE,
                                    server_control_file)
                else:
                    utils.open_write_close(server_control_file, control)
                logging.info("Processing control file")
                self._execute_code(server_control_file, namespace)
                logging.info("Finished processing control file")

                # no error occurred, so we don't need to collect crashinfo
                collect_crashinfo = False
            except Exception, e:
                try:
                    logging.exception(
                        'Exception escaped control file, job aborting:')
                    self.record('INFO', None, None, str(e),
                                {'job_abort_reason': str(e)})
                except:
                    pass  # don't let logging exceptions here interfere
                raise
        finally:
            if temp_control_file_dir:
                # Clean up temp directory used for copies of the control files
                try:
                    shutil.rmtree(temp_control_file_dir)
                except Exception, e:
                    logging.warn('Could not remove temp directory %s: %s',
                                 temp_control_file_dir, e)

            if machines and (collect_crashdumps or collect_crashinfo):
                namespace['test_start_time'] = test_start_time
                if collect_crashinfo:
                    # includes crashdumps
                    self._execute_code(CRASHINFO_CONTROL_FILE, namespace)
                else:
                    self._execute_code(CRASHDUMPS_CONTROL_FILE, namespace)
            if self._uncollected_log_file and created_uncollected_logs:
                os.remove(self._uncollected_log_file)
            self.disable_external_logging()
            if cleanup and machines:
                self._execute_code(CLEANUP_CONTROL_FILE, namespace)
            if install_after and machines:
                self._execute_code(INSTALL_CONTROL_FILE, namespace)
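
One more logging call worth pointing out in Example 3: inside the except block it uses logging.exception rather than logging.info. logging.exception logs at ERROR level and appends the active traceback to the record, which is exactly what you want when a control file raises. A minimal, hypothetical sketch:

import logging

logging.basicConfig(level=logging.INFO)

try:
    raise ValueError("control file failed")  # stand-in for _execute_code()
except Exception:
    # Equivalent to logging.error(...) plus the current traceback.
    logging.exception('Exception escaped control file, job aborting:')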

Example 4

def run(test, params, env):
    """
    Test virsh {at|de}tach-disk command.

    The command can attach new disk/detach disk.
    1. Prepare test environment, destroy or suspend a VM.
    2. Perform virsh attach/detach-disk operation.
    3. Recover test environment.
    4. Confirm the test result.
    """

    def check_vm_partition(vm, device, os_type, target_name, old_parts):
        """
        Check VM disk's partition.

        :param vm. VM guest.
        :param os_type. VM's operating system type.
        :param target_name. Device target type.
        :return: True if check successfully.
        """
        logging.info("Checking VM partittion...")
        if vm.is_dead():
            vm.start()
        try:
            attached = False
            if os_type == "linux":
                session = vm.wait_for_login()
                new_parts = libvirt.get_parts_list(session)
                added_parts = list(set(new_parts).difference(set(old_parts)))
                logging.debug("Added parts: %s" % added_parts)
                for i in range(len(added_parts)):
                    if device == "disk":
                        if target_name.startswith("vd"):
                            if added_parts[i].startswith("vd"):
                                attached = True
                        elif target_name.startswith("hd") or target_name.startswith("sd"):
                            if added_parts[i].startswith("sd"):
                                attached = True
                    elif device == "cdrom":
                        if added_parts[i].startswith("sr"):
                            attached = True
                session.close()
            return attached
        except (remote.LoginError, virt_vm.VMError, aexpect.ShellError), e:
            logging.error(str(e))
            return False

    def acpiphp_module_modprobe(vm, os_type):
        """
        Add acpiphp module if VM's os type is rhel5.*

        :param vm. VM guest.
        :param os_type. VM's operating system type.
        :return: True if operate successfully.
        """
        if vm.is_dead():
            vm.start()
        try:
            if os_type == "linux":
                session = vm.wait_for_login()
                s_rpm, _ = session.cmd_status_output(
                    "rpm --version")
                # If status is different from 0, this
                # guest OS doesn't support the rpm package
                # manager
                if s_rpm:
                    session.close()
                    return True
                _, o_vd = session.cmd_status_output(
                    "rpm -qa | grep redhat-release")
                if o_vd.find("5Server") != -1:
                    s_mod, o_mod = session.cmd_status_output(
                        "modprobe acpiphp")
                    del o_mod
                    if s_mod != 0:
                        session.close()
                        return False
                session.close()
            return True
        except (remote.LoginError, virt_vm.VMError, aexpect.ShellError), e:
            logging.error(str(e))
            return False

    vm_ref = params.get("at_dt_disk_vm_ref", "name")
    at_options = params.get("at_dt_disk_at_options", "")
    dt_options = params.get("at_dt_disk_dt_options", "")
    pre_vm_state = params.get("at_dt_disk_pre_vm_state", "running")
    status_error = "yes" == params.get("status_error", 'no')
    no_attach = params.get("at_dt_disk_no_attach", 'no')
    os_type = params.get("os_type", "linux")

    # Get test command.
    test_cmd = params.get("at_dt_disk_test_cmd", "attach-disk")

    # Disk specific attributes.
    device = params.get("at_dt_disk_device", "disk")
    device_source_name = params.get("at_dt_disk_device_source", "attach.img")
    device_source_format = params.get("at_dt_disk_device_source_format", "raw")
    device_target = params.get("at_dt_disk_device_target", "vdd")
    source_path = "yes" == params.get("at_dt_disk_device_source_path", "yes")
    create_img = "yes" == params.get("at_dt_disk_create_image", "yes")
    test_twice = "yes" == params.get("at_dt_disk_test_twice", "no")
    test_type = "yes" == params.get("at_dt_disk_check_type", "no")
    test_audit = "yes" == params.get("at_dt_disk_check_audit", "no")
    test_block_dev = "yes" == params.get("at_dt_disk_iscsi_device", "no")
    test_logcial_dev = "yes" == params.get("at_dt_disk_logical_device", "no")
    restart_libvirtd = "yes" == params.get("at_dt_disk_restart_libvirtd", "no")
    vg_name = params.get("at_dt_disk_vg", "vg_test_0")
    lv_name = params.get("at_dt_disk_lv", "lv_test_0")
    serial = params.get("at_dt_disk_serial", "")
    address = params.get("at_dt_disk_address", "")
    address2 = params.get("at_dt_disk_address2", "")
    cache_options = params.get("cache_options", "")
    time_sleep = params.get("time_sleep", 3)
    if serial:
        at_options += (" --serial %s" % serial)
    if address2:
        at_options_twice = at_options + (" --address %s" % address2)
    if address:
        at_options += (" --address %s" % address)
    if cache_options:
        if cache_options.count("directsync"):
            if not libvirt_version.version_compare(1, 0, 0):
                raise error.TestNAError("'directsync' cache option doesn't support in"
                                        " current libvirt version.")
        at_options += (" --cache %s" % cache_options)

    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)

    # Start vm and get all partitions in vm.
    if vm.is_dead():
        vm.start()
    session = vm.wait_for_login()
    old_parts = libvirt.get_parts_list(session)
    session.close()
    vm.destroy(gracefully=False)

    # Back up xml file.
    backup_xml = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)

    # Create virtual device file.
    device_source_path = os.path.join(test.tmpdir, device_source_name)
    if test_block_dev:
        device_source = libvirt.setup_or_cleanup_iscsi(True)
        if not device_source:
            # We should skip this case
            raise error.TestNAError("Can not get iscsi device name in host")
        if test_logcial_dev:
            lv_utils.vg_create(vg_name, device_source)
            device_source = libvirt.create_local_disk("lvm",
                                                      size="10M",
                                                      vgname=vg_name,
                                                      lvname=lv_name)
            logging.debug("New created volume: %s", lv_name)
    else:
        if source_path and create_img:
            device_source = libvirt.create_local_disk(
                "file", path=device_source_path,
                size="1G", disk_format=device_source_format)
        else:
            device_source = device_source_name

    # If we are testing audit, we need to start the audit service first.
    if test_audit:
        auditd_service = Factory.create_service("auditd")
        if not auditd_service.status():
            auditd_service.start()
        logging.info("Auditd service status: %s" % auditd_service.status())

    # If we are testing cdrom device, we need to detach hdc in VM first.
    if device == "cdrom":
        if vm.is_alive():
            vm.destroy(gracefully=False)
        s_detach = virsh.detach_disk(vm_name, device_target, "--config")
        if not s_detach:
            logging.error("Detach hdc failed before test.")

    # If we are testing detach-disk, we need to attach certain device first.
    if test_cmd == "detach-disk" and no_attach != "yes":
        s_attach = virsh.attach_disk(vm_name, device_source, device_target,
                                     "--driver qemu --config").exit_status
        if s_attach != 0:
            logging.error("Attaching device failed before testing detach-disk")

        if test_twice:
            device_target2 = params.get("at_dt_disk_device_target2",
                                        device_target)
            device_source = libvirt.create_local_disk(
                "file", path=device_source_path,
                size="1", disk_format=device_source_format)
            s_attach = virsh.attach_disk(vm_name, device_source, device_target2,
                                         "--driver qemu --config").exit_status
            if s_attach != 0:
                logging.error("Attaching device failed before testing "
                              "detach-disk test_twice")

    vm.start()
    vm.wait_for_login()

    # Add acpiphp module before testing if VM's os type is rhel5.*
    if not acpiphp_module_modprobe(vm, os_type):
        raise error.TestError("Add acpiphp module failed before test.")

    # Turn VM into certain state.
    if pre_vm_state == "paused":
        logging.info("Suspending %s..." % vm_name)
        if vm.is_alive():
            vm.pause()
    elif pre_vm_state == "shut off":
        logging.info("Shuting down %s..." % vm_name)
        if vm.is_alive():
            vm.destroy(gracefully=False)

    # Get disk count before test.
    disk_count_before_cmd = vm_xml.VMXML.get_disk_count(vm_name)

    # Test.
    domid = vm.get_id()
    domuuid = vm.get_uuid()

    # Confirm how to reference a VM.
    if vm_ref == "name":
        vm_ref = vm_name
    elif vm_ref.find("invalid") != -1:
        vm_ref = params.get(vm_ref)
    elif vm_ref == "id":
        vm_ref = domid
    elif vm_ref == "hex_id":
        vm_ref = hex(int(domid))
    elif vm_ref == "uuid":
        vm_ref = domuuid
    else:
        vm_ref = ""

    if test_cmd == "attach-disk":
        status = virsh.attach_disk(vm_ref, device_source, device_target,
                                   at_options, debug=True).exit_status
    elif test_cmd == "detach-disk":
        status = virsh.detach_disk(vm_ref, device_target, dt_options,
                                   debug=True).exit_status

    if restart_libvirtd:
        libvirtd_serv = utils_libvirtd.Libvirtd()
        libvirtd_serv.restart()

    if test_twice:
        device_target2 = params.get("at_dt_disk_device_target2", device_target)
        device_source = libvirt.create_local_disk(
            "file", path=device_source_path,
            size="1G", disk_format=device_source_format)
        if test_cmd == "attach-disk":
            if address2:
                at_options = at_options_twice
            status = virsh.attach_disk(vm_ref, device_source,
                                       device_target2, at_options,
                                       debug=True).exit_status
        elif test_cmd == "detach-disk":
            status = virsh.detach_disk(vm_ref, device_target2, dt_options,
                                       debug=True).exit_status

    # Resume the guest after the command. On newer libvirt this used to be a
    # bug and has been fixed; the change in the xml file is done after the
    # guest is resumed.
    if pre_vm_state == "paused":
        vm.resume()

    # Check audit log
    check_audit_after_cmd = True
    if test_audit:
        grep_audit = ('grep "%s" /var/log/audit/audit.log'
                      % test_cmd.split("-")[0])
        cmd = (grep_audit + ' | ' + 'grep "%s" | tail -n1 | grep "res=success"'
               % device_source)
        if utils.run(cmd).exit_status:
            logging.error("Audit check failed")
            check_audit_after_cmd = False

    # Need wait a while for xml to sync
    time.sleep(float(time_sleep))
    # Check disk count after command.
    check_count_after_cmd = True
    disk_count_after_cmd = vm_xml.VMXML.get_disk_count(vm_name)
    if test_cmd == "attach-disk":
        if disk_count_after_cmd == disk_count_before_cmd:
            check_count_after_cmd = False
    elif test_cmd == "detach-disk":
        if disk_count_after_cmd < disk_count_before_cmd:
            check_count_after_cmd = False

    # Recover VM state.
    if pre_vm_state == "shut off":
        vm.start()

    # Check in VM after command.
    check_vm_after_cmd = True
    check_vm_after_cmd = check_vm_partition(vm, device, os_type,
                                            device_target, old_parts)

    # Check disk type after attach.
    check_disk_type = True
    if test_type:
        if test_block_dev:
            check_disk_type = vm_xml.VMXML.check_disk_type(vm_name,
                                                           device_source,
                                                           "block")
        else:
            check_disk_type = vm_xml.VMXML.check_disk_type(vm_name,
                                                           device_source,
                                                           "file")
    # Check disk serial after attach.
    check_disk_serial = True
    if serial:
        disk_serial = vm_xml.VMXML.get_disk_serial(vm_name, device_target)
        if serial != disk_serial:
            check_disk_serial = False

    # Check disk address after attach.
    check_disk_address = True
    if address:
        disk_address = vm_xml.VMXML.get_disk_address(vm_name, device_target)
        if address != disk_address:
            check_disk_address = False

    # Check multifunction address after attach.
    check_disk_address2 = True
    if address2:
        disk_address2 = vm_xml.VMXML.get_disk_address(vm_name, device_target2)
        if address2 != disk_address2:
            check_disk_address2 = False

    # Check disk cache option after attach.
    check_cache_after_cmd = True
    if cache_options:
        disk_cache = vm_xml.VMXML.get_disk_attr(vm_name, device_target,
                                                "driver", "cache")
        if cache_options == "default":
            if disk_cache is not None:
                check_cache_after_cmd = False
        elif disk_cache != cache_options:
            check_cache_after_cmd = False

    # Eject cdrom test
    eject_cdrom = "yes" == params.get("at_dt_disk_eject_cdrom", "no")
    save_vm = "yes" == params.get("at_dt_disk_save_vm", "no")
    save_file = os.path.join(test.tmpdir, "vm.save")
    try:
        if eject_cdrom:
            eject_params = {'type_name': "file", 'device_type': "cdrom",
                            'target_dev': "hdc", 'target_bus': "ide"}
            eject_xml = libvirt.create_disk_xml(eject_params)
            logging.debug("Eject CDROM by XML: %s", open(eject_xml).read())
            # Run the command twice to make sure the cdrom tray opens first #BZ892289
            # Open tray
            virsh.attach_device(domainarg=vm_name, filearg=eject_xml, debug=True)
            # Add time sleep between two attach commands.
            if time_sleep:
                time.sleep(float(time_sleep))
            # Eject cdrom
            result = virsh.attach_device(domainarg=vm_name, filearg=eject_xml,
                                         debug=True)
            if result.exit_status != 0:
                raise error.TestFail("Eject CDROM failed")
            if vm_xml.VMXML.check_disk_exist(vm_name, device_source):
                raise error.TestFail("Find %s after do eject" % device_source)
        # Save and restore VM
        if save_vm:
            result = virsh.save(vm_name, save_file, debug=True)
            libvirt.check_exit_status(result)
            result = virsh.restore(save_file, debug=True)
            libvirt.check_exit_status(result)
            if vm_xml.VMXML.check_disk_exist(vm_name, device_source):
                raise error.TestFail("Find %s after do restore" % device_source)

        # Destroy VM.
        vm.destroy(gracefully=False)

        # Check disk count after VM shutdown (with --config).
        check_count_after_shutdown = True
        inactive_vmxml = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)
        disk_count_after_shutdown = len(inactive_vmxml.get_disk_all())
        if test_cmd == "attach-disk":
            if disk_count_after_shutdown == disk_count_before_cmd:
                check_count_after_shutdown = False
        elif test_cmd == "detach-disk":
            if disk_count_after_shutdown < disk_count_before_cmd:
                check_count_after_shutdown = False

    finally:
        # Recover VM.
        if vm.is_alive():
            vm.destroy(gracefully=False)
        backup_xml.sync()
        if os.path.exists(save_file):
            os.remove(save_file)
        if test_block_dev:
            if test_logcial_dev:
                libvirt.delete_local_disk("lvm", vgname=vg_name, lvname=lv_name)
                lv_utils.vg_remove(vg_name)
                utils.system("pvremove %s" % device_source, ignore_status=True)
            libvirt.setup_or_cleanup_iscsi(False)
        else:
            libvirt.delete_local_disk("file", device_source)

    # Check results.
    if status_error:
        if not status:
            raise error.TestFail("virsh %s exit with unexpected value."
                                 % test_cmd)
    else:
        if status:
            raise error.TestFail("virsh %s failed." % test_cmd)
        if test_cmd == "attach-disk":
            if at_options.count("config"):
                if not check_count_after_shutdown:
                    raise error.TestFail("Cannot see config attached device "
                                         "in xml file after VM shutdown.")
                if not check_disk_serial:
                    raise error.TestFail("Serial set failed after attach")
                if not check_disk_address:
                    raise error.TestFail("Address set failed after attach")
                if not check_disk_address2:
                    raise error.TestFail("Address(multifunction) set failed"
                                         " after attach")
            else:
                if not check_count_after_cmd:
                    raise error.TestFail("Cannot see device in xml file"
                                         " after attach.")
                if not check_vm_after_cmd:
                    raise error.TestFail("Cannot see device in VM after"
                                         " attach.")
                if not check_disk_type:
                    raise error.TestFail("Check disk type failed after"
                                         " attach.")
                if not check_audit_after_cmd:
                    raise error.TestFail("Audit hotplug failure after attach")
                if not check_cache_after_cmd:
                    raise error.TestFail("Check cache failure after attach")
                if at_options.count("persistent"):
                    if not check_count_after_shutdown:
                        raise error.TestFail("Cannot see device attached "
                                             "with persistent after "
                                             "VM shutdown.")
                else:
                    if check_count_after_shutdown:
                        raise error.TestFail("See non-config attached device "
                                             "in xml file after VM shutdown.")
        elif test_cmd == "detach-disk":
            if dt_options.count("config"):
                if check_count_after_shutdown:
                    raise error.TestFail("See config detached device in "
                                         "xml file after VM shutdown.")
            else:
                if check_count_after_cmd:
                    raise error.TestFail("See device in xml file "
                                         "after detach.")
                if check_vm_after_cmd:
                    raise error.TestFail("See device in VM after detach.")
                if not check_audit_after_cmd:
                    raise error.TestFail("Audit hotunplug failure "
                                         "after detach")

                if dt_options.count("persistent"):
                    if check_count_after_shutdown:
                        raise error.TestFail("See device deattached "
                                             "with persistent after "
                                             "VM shutdown.")
                else:
                    if not check_count_after_shutdown:
                        raise error.TestFail("See non-config detached "
                                             "device in xml file after "
                                             "VM shutdown.")

        else:
            raise error.TestError("Unknown command %s." % test_cmd)

Example 5

def run(test, params, env):
    """
    Test virsh {at|de}tach-disk command for lxc.

    The command can attach new disk/detach disk.
    1. Prepare test environment, destroy or suspend a VM.
    2. Perform virsh attach/detach-disk operation.
    3. Recover test environment.
    4. Confirm the test result.
    """

    vm_ref = params.get("at_dt_disk_vm_ref", "name")
    at_options = params.get("at_dt_disk_at_options", "")
    dt_options = params.get("at_dt_disk_dt_options", "")
    pre_vm_state = params.get("at_dt_disk_pre_vm_state", "running")
    status_error = "yes" == params.get("status_error", 'no')
    no_attach = params.get("at_dt_disk_no_attach", 'no')

    # Get test command.
    test_cmd = params.get("at_dt_disk_test_cmd", "attach-disk")

    # Disk specific attributes.
    device_source = params.get("at_dt_disk_device_source", "/dev/sdc1")
    device_target = params.get("at_dt_disk_device_target", "vdd")
    test_twice = "yes" == params.get("at_dt_disk_test_twice", "no")
    test_audit = "yes" == params.get("at_dt_disk_check_audit", "no")
    serial = params.get("at_dt_disk_serial", "")
    address = params.get("at_dt_disk_address", "")
    address2 = params.get("at_dt_disk_address2", "")
    if serial:
        at_options += (" --serial %s" % serial)
    if address2:
        at_options_twice = at_options + (" --address %s" % address2)
    if address:
        at_options += (" --address %s" % address)

    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)
    if vm.is_alive():
        vm.destroy(gracefully=False)
    # Back up xml file.
    backup_xml = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)

    # Create virtual device file if user doesn't prepare a partition.
    test_block_dev = False
    if device_source.count("ENTER"):
        device_source = libvirt.setup_or_cleanup_iscsi(True)
        test_block_dev = True
        if not device_source:
            # We should skip this case
            raise error.TestNAError("Can not get iscsi device name in host")

    if vm.is_alive():
        vm.destroy(gracefully=False)

    # If we are testing audit, we need to start the audit service first.
    if test_audit:
        auditd_service = Factory.create_service("auditd")
        if not auditd_service.status():
            auditd_service.start()
        logging.info("Auditd service status: %s" % auditd_service.status())

    # If we are testing detach-disk, we need to attach certain device first.
    if test_cmd == "detach-disk" and no_attach != "yes":
        s_attach = virsh.attach_disk(vm_name, device_source, device_target,
                                     "--config").exit_status
        if s_attach != 0:
            logging.error("Attaching device failed before testing detach-disk")

        if test_twice:
            device_target2 = params.get("at_dt_disk_device_target2",
                                        device_target)
            s_attach = virsh.attach_disk(vm_name, device_source,
                                         device_target2,
                                         "--config").exit_status
            if s_attach != 0:
                logging.error("Attaching device failed before testing "
                              "detach-disk test_twice")

    vm.start()

    # Turn VM into certain state.
    if pre_vm_state == "paused":
        logging.info("Suspending %s..." % vm_name)
        if vm.is_alive():
            vm.pause()
    elif pre_vm_state == "shut off":
        logging.info("Shuting down %s..." % vm_name)
        if vm.is_alive():
            vm.destroy(gracefully=False)

    # Get disk count before test.
    disk_count_before_cmd = vm_xml.VMXML.get_disk_count(vm_name)

    # Test.
    domid = vm.get_id()
    domuuid = vm.get_uuid()

    # Confirm how to reference a VM.
    if vm_ref == "name":
        vm_ref = vm_name
    elif vm_ref.find("invalid") != -1:
        vm_ref = params.get(vm_ref)
    elif vm_ref == "id":
        vm_ref = domid
    elif vm_ref == "hex_id":
        vm_ref = hex(int(domid))
    elif vm_ref == "uuid":
        vm_ref = domuuid
    else:
        vm_ref = ""

    if test_cmd == "attach-disk":
        status = virsh.attach_disk(vm_ref, device_source, device_target,
                                   at_options, debug=True).exit_status
    elif test_cmd == "detach-disk":
        status = virsh.detach_disk(vm_ref, device_target, dt_options,
                                   debug=True).exit_status
    if test_twice:
        device_target2 = params.get("at_dt_disk_device_target2", device_target)
        if test_cmd == "attach-disk":
            if address2:
                at_options = at_options_twice
            status = virsh.attach_disk(vm_ref, device_source,
                                       device_target2, at_options,
                                       debug=True).exit_status
        elif test_cmd == "detach-disk":
            status = virsh.detach_disk(vm_ref, device_target2, dt_options,
                                       debug=True).exit_status

    # Resume the guest after the command. On newer libvirt this used to be a
    # bug and has been fixed; the change in the xml file is done after the
    # guest is resumed.
    if pre_vm_state == "paused":
        vm.resume()

    # Check audit log
    check_audit_after_cmd = True
    if test_audit:
        grep_audit = ('grep "%s" /var/log/audit/audit.log'
                      % test_cmd.split("-")[0])
        cmd = (grep_audit + ' | ' + 'grep "%s" | tail -n1 | grep "res=success"'
               % device_source)
        if utils.run(cmd).exit_status:
            logging.error("Audit check failed")
            check_audit_after_cmd = False

    # Check disk count after command.
    check_count_after_cmd = True
    disk_count_after_cmd = vm_xml.VMXML.get_disk_count(vm_name)
    if test_cmd == "attach-disk":
        if disk_count_after_cmd == disk_count_before_cmd:
            check_count_after_cmd = False
    elif test_cmd == "detach-disk":
        if disk_count_after_cmd < disk_count_before_cmd:
            check_count_after_cmd = False

    # Recover VM state.
    if pre_vm_state == "shut off":
        vm.start()

    # Check disk type after attach.
    check_disk_type = True
    try:
        check_disk_type = vm_xml.VMXML.check_disk_type(vm_name,
                                                       device_source,
                                                       "block")
    except xcepts.LibvirtXMLError:
        # No disk found
        check_disk_type = False

    # Check disk serial after attach.
    check_disk_serial = True
    if serial:
        disk_serial = vm_xml.VMXML.get_disk_serial(vm_name, device_target)
        if serial != disk_serial:
            check_disk_serial = False

    # Check disk address after attach.
    check_disk_address = True
    if address:
        disk_address = vm_xml.VMXML.get_disk_address(vm_name, device_target)
        if utils_test.canonicalize_disk_address(address) !=\
           utils_test.canonicalize_disk_address(disk_address):
            check_disk_address = False

    # Check multifunction address after attach.
    check_disk_address2 = True
    if address2:
        disk_address2 = vm_xml.VMXML.get_disk_address(vm_name, device_target2)
        if utils_test.canonicalize_disk_address(address2) !=\
           utils_test.canonicalize_disk_address(disk_address2):
            check_disk_address2 = False

    # Destroy VM.
    vm.destroy(gracefully=False)

    # Check disk count after VM shutdown (with --config).
    check_count_after_shutdown = True
    disk_count_after_shutdown = vm_xml.VMXML.get_disk_count(vm_name)
    if test_cmd == "attach-disk":
        if disk_count_after_shutdown == disk_count_before_cmd:
            check_count_after_shutdown = False
    elif test_cmd == "detach-disk":
        if disk_count_after_shutdown < disk_count_before_cmd:
            check_count_after_shutdown = False

    # Recover VM.
    if vm.is_alive():
        vm.destroy(gracefully=False)
    backup_xml.sync()
    if test_block_dev:
        libvirt.setup_or_cleanup_iscsi(False)

    # Check results.
    if status_error:
        if not status:
            raise error.TestFail("virsh %s exit with unexpected value."
                                 % test_cmd)
    else:
        if status:
            raise error.TestFail("virsh %s failed." % test_cmd)
        if test_cmd == "attach-disk":
            if at_options.count("config"):
                if not check_count_after_shutdown:
                    raise error.TestFail("Cannot see config attached device "
                                         "in xml file after VM shutdown.")
                if not check_disk_serial:
                    raise error.TestFail("Serial set failed after attach")
                if not check_disk_address:
                    raise error.TestFail("Address set failed after attach")
                if not check_disk_address2:
                    raise error.TestFail("Address(multifunction) set failed"
                                         " after attach")
            else:
                if not check_count_after_cmd:
                    raise error.TestFail("Cannot see device in xml file"
                                         " after attach.")
                if not check_disk_type:
                    raise error.TestFail("Check disk type failed after"
                                         " attach.")
                if not check_audit_after_cmd:
                    raise error.TestFail("Audit hotplug failure after attach")
                if at_options.count("persistent"):
                    if not check_count_after_shutdown:
                        raise error.TestFail("Cannot see device attached "
                                             "with persistent after "
                                             "VM shutdown.")
                else:
                    if check_count_after_shutdown:
                        raise error.TestFail("See non-config attached device "
                                             "in xml file after VM shutdown.")
        elif test_cmd == "detach-disk":
            if dt_options.count("config"):
                if check_count_after_shutdown:
                    raise error.TestFail("See config detached device in "
                                         "xml file after VM shutdown.")
            else:
                if check_count_after_cmd:
                    raise error.TestFail("See device in xml file "
                                         "after detach.")
                if not check_audit_after_cmd:
                    raise error.TestFail("Audit hotunplug failure "
                                         "after detach")

                if dt_options.count("persistent"):
                    if check_count_after_shutdown:
                        raise error.TestFail("See device deattached "
                                             "with persistent after "
                                             "VM shutdown.")
                else:
                    if not check_count_after_shutdown:
                        raise error.TestFail("See non-config detached "
                                             "device in xml file after "
                                             "VM shutdown.")

        else:
            raise error.TestError("Unknown command %s." % test_cmd)

Example 6

def run(test, params, env):
    """
    Test command: virsh domif-setlink and domif-getlink.

    The command sets and gets the link state of a virtual interface.
    1. Prepare test environment.
    2. Perform virsh domif-setlink and domif-getlink operation.
    3. Recover test environment.
    4. Confirm the test result.
    """

    def domif_setlink(vm, device, operation, options):
        """
        Set the domain link state

        :param vm : domain name
        :param device : domain virtual interface
        :param operation : domain virtual interface state
        :param options : some options like --config

        """

        return virsh.domif_setlink(vm, device, operation, options, debug=True)

    def domif_getlink(vm, device, options):
        """
        Get the domain link state

        :param vm : domain name
        :param device : domain virtual interface
        :param options : some options like --config

        """

        return virsh.domif_getlink(vm, device, options,
                                   ignore_status=True, debug=True)

    def guest_cmd_check(cmd, session, pattern):
        """
        Check cmd output with pattern in session
        """
        try:
            cmd_status, output = session.cmd_status_output(cmd, timeout=10)
            logging.info("exit: %s, output: %s",
                         cmd_status, output)
            return re.search(pattern, output)
        except (aexpect.ShellTimeoutError, aexpect.ShellStatusError), e:
            logging.debug(e)
            return re.search(pattern, str(e))

    def guest_if_state(if_name, session):
        """
        Get the domain link state from the guest
        """
        # Get link state by ethtool
        cmd = "ethtool %s" % if_name
        pattern = "Link detected: ([a-zA-Z]+)"
        ret = guest_cmd_check(cmd, session, pattern)
        if ret:
            return ret.group(1) == "yes"
        else:
            return False

    def check_update_device(vm, if_name, session):
        """
        Change link state by update-device command and check the results
        """
        vmxml = vm_xml.VMXML.new_from_dumpxml(vm.name)

        # Get interface xml object
        iface = vmxml.get_devices(device_type="interface")[0]
        if iface.address:
            del iface.address

        # Change link state to up
        iface.link_state = "up"
        iface.xmltreefile.write()
        ret = virsh.update_device(vm.name, iface.xml,
                                  ignore_status=True, debug=True)
        if ret.exit_status:
            logging.error("Failed to update device to up state")
            return False
        if not guest_if_state(if_name, session):
            logging.error("Guest link should be up now")
            return False

        # Change link state to down
        iface.link_state = "down"
        iface.xmltreefile.write()
        ret = virsh.update_device(vm.name, iface.xml,
                                  ignore_status=True, debug=True)
        if ret.exit_status:
            logging.error("Failed to update device to down state")
            return False
        if guest_if_state(if_name, session):
            logging.error("Guest link should be down now")
            return False

        # Passed all test
        return True

    vm_name = []
    # vm_name list: first element for original name in config
    vm_name.append(params.get("main_vm", "avocado-vt-vm1"))
    vm = env.get_vm(vm_name[0])
    options = params.get("if_options", "--config")
    start_vm = params.get("start_vm", "no")
    domain = params.get("domain", "name")
    if_device = params.get("if_device", "net")
    if_name = params.get("if_name", "vnet0")
    if_operation = params.get("if_operation", "up")
    status_error = params.get("status_error", "no")
    mac_address = vm.get_virsh_mac_address(0)
    check_link_state = "yes" == params.get("check_link_state", "no")
    check_link_by_update_device = "yes" == params.get(
        "excute_update_device", "no")
    device = "vnet0"
    username = params.get("username")
    password = params.get("password")

    # Back up xml file.
    vm_xml_file = os.path.join(test.tmpdir, "vm.xml")
    virsh.dumpxml(vm_name[0], extra="--inactive", to_file=vm_xml_file)

    # Vm status
    if start_vm == "yes" and vm.is_dead():
        vm.start()

    elif start_vm == "no" and vm.is_alive():
        vm.destroy()

    # vm_name list: second element for 'domain' in virsh command
    if domain == "ID":
        # Get ID for the running domain
        vm_name.append(vm.get_id())
    elif domain == "UUID":
        # Get UUID for the domain
        vm_name.append(vm.get_uuid())
    elif domain == "no_match_UUID":
        # Generate a random UUID
        vm_name.append(uuid.uuid1())
    elif domain == "no_match_name":
        # Generate a random string as domain name
        vm_name.append(utils_misc.generate_random_string(6))
    elif domain == " ":
        # Set domain name empty
        vm_name.append("''")
    else:
        # Set domain name
        vm_name.append(vm_name[0])

    try:
        # Test device net or mac address
        if if_device == "net" and vm.is_alive():
            device = if_name
            # Get all vm's interface device
            device = vm_xml.VMXML.get_net_dev(vm_name[0])[0]

        elif if_device == "mac":
            device = mac_address

        # Test non-existent device
        if if_device == "no_exist_net":
            device = "vnet-1"
        elif if_device == "no_exist_mac":
            # Generate random mac address for negative test
            device = utils_net.VirtIface.complete_mac_address("01:02")
        elif if_device == " ":
            device = "''"

        # Setlink operation
        result = domif_setlink(vm_name[1], device, if_operation, options)
        status = result.exit_status
        logging.info("Setlink done")

        # Getlink operation
        get_result = domif_getlink(vm_name[1], device, options)
        getlink_output = get_result.stdout.strip()

        # Check the getlink command output
        if status_error == "no":
            if not re.search(if_operation, getlink_output):
                raise error.TestFail("Getlink result should "
                                     "equal with setlink operation")

        logging.info("Getlink done")
        # If --config is given, restart the vm and then test the link status
        if options == "--config" and vm.is_alive():
            vm.destroy()
            vm.start()
            logging.info("Restart VM")

        elif start_vm == "no":
            vm.start()

        error_msg = None
        if status_error == "no":
            # Log in to the vm via serial console to check the link status
            session = vm.wait_for_serial_login(username=username,
                                               password=password)
            guest_if_name = utils_net.get_linux_ifname(session, mac_address)

            # Check link state in guest
            if check_link_state:
                if (if_operation == "up" and
                        not guest_if_state(guest_if_name, session)):
                    error_msg = "Link state should be up in guest"
                if (if_operation == "down" and
                        guest_if_state(guest_if_name, session)):
                    error_msg = "Link state should be down in guest"

            # Test of setting link state by update_device command
            if check_link_by_update_device:
                if not check_update_device(vm, guest_if_name, session):
                    error_msg = "Check update_device failed"

            # Set the link up so the host can connect to the vm
            domif_setlink(vm_name[0], device, "up", "")
            if not utils_misc.wait_for(
                    lambda: guest_if_state(guest_if_name, session), 5):
                error_msg = "Link state isn't up in guest"

            # Ignore status of this one
            cmd = 'ifdown %s' % guest_if_name
            pattern = "Device '%s' successfully disconnected" % guest_if_name
            guest_cmd_check(cmd, session, pattern)

            cmd = 'ifup %s' % guest_if_name
            pattern = "Determining IP information for %s" % guest_if_name
            pattern += "|Connection successfully activated"
            if not guest_cmd_check(cmd, session, pattern):
                error_msg = ("Could not bring up interface %s inside guest"
                             % guest_if_name)
        else:  # negative test
            # stop guest, so state is always consistent on next start
            vm.destroy()

        if error_msg:
            raise error.TestFail(error_msg)

        # Check status_error
        if status_error == "yes":
            if status:
                logging.info("Expected error (negative testing). Output: %s",
                             result.stderr.strip())

            else:
                raise error.TestFail("Unexpected return code %d "
                                     "(negative testing)" % status)
        elif status_error != "no":
            raise error.TestError("Invalid value for status_error '%s' "
                                  "(must be 'yes' or 'no')" % status_error)
    finally:
        # Recover VM.
        if vm.is_alive():
            vm.destroy(gracefully=False)
        virsh.undefine(vm_name[0])
        virsh.define(vm_xml_file)
        os.remove(vm_xml_file)
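
A quick aside on the logging.info calls in this example: the messages are either passed as already-formatted strings (for instance logging.info("Setlink done")) or built with the % operator before the call. The logging module also accepts the format arguments separately, deferring interpolation until a handler actually emits the record. A minimal, self-contained sketch of the two spellings follows; the interface name below is a placeholder for illustration, not a value taken from the test:

import logging

logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s %(levelname)s %(message)s")

if_name = "vnet0"  # placeholder value, for illustration only

# Eager formatting: the message string is built before logging.info runs.
logging.info("Setlink done for device %s" % if_name)

# Lazy formatting: logging interpolates the arguments only if the record
# is actually handled, so no work is wasted when INFO is filtered out.
logging.info("Setlink done for device %s", if_name)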

Example 7

Project: tp-libvirt
Source File: virsh_undefine.py
View license
def run(test, params, env):
    """
    Test virsh undefine command.

    Undefine an inactive domain, or convert persistent to transient.
    1.Prepare test environment.
    2.Backup the VM's information to a xml file.
    3.When the libvirtd == "off", stop the libvirtd service.
    4.Perform virsh undefine operation.
    5.Recover test environment (libvirtd service, VM).
    6.Confirm the test result.
    """

    vm_ref = params.get("undefine_vm_ref", "vm_name")
    extra = params.get("undefine_extra", "")
    option = params.get("undefine_option", "")
    libvirtd_state = params.get("libvirtd", "on")
    status_error = ("yes" == params.get("status_error", "no"))
    undefine_twice = ("yes" == params.get("undefine_twice", 'no'))
    local_ip = params.get("local_ip", "LOCAL.EXAMPLE.COM")
    local_pwd = params.get("local_pwd", "password")
    remote_ip = params.get("remote_ip", "REMOTE.EXAMPLE.COM")
    remote_user = params.get("remote_user", "user")
    remote_pwd = params.get("remote_pwd", "password")
    remote_prompt = params.get("remote_prompt", "#")
    pool_type = params.get("pool_type")
    pool_name = params.get("pool_name", "test")
    pool_target = params.get("pool_target")
    volume_size = params.get("volume_size", "1G")
    vol_name = params.get("vol_name", "test_vol")
    emulated_img = params.get("emulated_img", "emulated_img")
    emulated_size = "%sG" % (int(volume_size[:-1]) + 1)
    disk_target = params.get("disk_target", "vdb")
    wipe_data = "yes" == params.get("wipe_data", "no")
    if wipe_data:
        option += " --wipe-storage"

    vm_name = params.get("main_vm", "avocado-vt-vm1")
    vm = env.get_vm(vm_name)
    vm_id = vm.get_id()
    vm_uuid = vm.get_uuid()

    # polkit acl related params
    uri = params.get("virsh_uri")
    unprivileged_user = params.get('unprivileged_user')
    if unprivileged_user:
        if unprivileged_user.count('EXAMPLE'):
            unprivileged_user = 'testacl'

    if not libvirt_version.version_compare(1, 1, 1):
        if params.get('setup_libvirt_polkit') == 'yes':
            raise error.TestNAError("API acl test not supported in current"
                                    " libvirt version.")

    # Back up xml file. A Xen host has no guest xml file to define a guest.
    backup_xml = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)

    # Confirm how to reference a VM.
    if vm_ref == "vm_name":
        vm_ref = vm_name
    elif vm_ref == "id":
        vm_ref = vm_id
    elif vm_ref == "hex_vm_id":
        vm_ref = hex(int(vm_id))
    elif vm_ref == "uuid":
        vm_ref = vm_uuid
    elif vm_ref.find("invalid") != -1:
        vm_ref = params.get(vm_ref)

    volume = None
    pvtest = None
    status3 = None

    elems = backup_xml.xmltreefile.findall('/devices/disk/source')
    existing_images = [elem.get('file') for elem in elems]

    # Backup images since remove-all-storage could remove existing libvirt
    # managed guest images
    if existing_images and option.count("remove-all-storage"):
        for img in existing_images:
            backup_img = img + '.bak'
            logging.info('Backup %s to %s', img, backup_img)
            shutil.copyfile(img, backup_img)

    try:
        save_file = "/var/lib/libvirt/qemu/save/%s.save" % vm_name
        if option.count("managedsave") and vm.is_alive():
            virsh.managedsave(vm_name)

        if not vm.is_lxc():
            snp_list = virsh.snapshot_list(vm_name)
            if option.count("snapshot"):
                snp_file_list = []
                if not len(snp_list):
                    virsh.snapshot_create(vm_name)
                    logging.debug("Create a snapshot for test!")
                else:
                    # Backup snapshots for domain
                    for snp_item in snp_list:
                        tmp_file = os.path.join(test.tmpdir, snp_item + ".xml")
                        virsh.snapshot_dumpxml(vm_name, snp_item, to_file=tmp_file)
                        snp_file_list.append(tmp_file)
            else:
                if len(snp_list):
                    raise error.TestNAError("This domain has snapshot(s), "
                                            "cannot be undefined!")
        if option.count("remove-all-storage"):
            pvtest = utlv.PoolVolumeTest(test, params)
            pvtest.pre_pool(pool_name, pool_type, pool_target, emulated_img,
                            emulated_size=emulated_size)
            new_pool = libvirt_storage.PoolVolume(pool_name)
            if not new_pool.create_volume(vol_name, volume_size):
                raise error.TestFail("Creation of volume %s failed." % vol_name)
            volumes = new_pool.list_volumes()
            volume = volumes[vol_name]
            virsh.attach_disk(vm_name, volume, disk_target, "--config")

        # Turn libvirtd into certain state.
        if libvirtd_state == "off":
            utils_libvirtd.libvirtd_stop()

        # Test virsh undefine command.
        output = ""
        if vm_ref != "remote":
            vm_ref = "%s %s" % (vm_ref, extra)
            cmdresult = virsh.undefine(vm_ref, option,
                                       unprivileged_user=unprivileged_user,
                                       uri=uri,
                                       ignore_status=True, debug=True)
            status = cmdresult.exit_status
            output = cmdresult.stdout.strip()
            if status:
                logging.debug("Error status, command output: %s",
                              cmdresult.stderr.strip())
            if undefine_twice:
                status2 = virsh.undefine(vm_ref,
                                         ignore_status=True).exit_status
        else:
            if remote_ip.count("EXAMPLE.COM") or local_ip.count("EXAMPLE.COM"):
                raise error.TestNAError("remote_ip and/or local_ip parameters"
                                        " not changed from default values")
            try:
                local_user = params.get("username", "root")
                uri = libvirt_vm.complete_uri(local_ip)
                # setup ssh auto login from remote machine to test machine
                # for the command to execute remotely
                ssh_key.setup_remote_ssh_key(remote_ip, remote_user,
                                             remote_pwd, hostname2=local_ip,
                                             user2=local_user,
                                             password2=local_pwd)
                session = remote.remote_login("ssh", remote_ip, "22",
                                              remote_user, remote_pwd,
                                              remote_prompt)
                cmd_undefine = "virsh -c %s undefine %s" % (uri, vm_name)
                status, output = session.cmd_status_output(cmd_undefine)
                logging.info("Undefine output: %s", output)
            except (process.CmdError, remote.LoginError, aexpect.ShellError), de:
                logging.error("Detail: %s", de)
                status = 1

        # Recover libvirtd state.
        if libvirtd_state == "off":
            utils_libvirtd.libvirtd_start()

        # Shutdown VM.
        if virsh.domain_exists(vm.name):
            try:
                if vm.is_alive():
                    vm.destroy(gracefully=False)
            except process.CmdError, detail:
                logging.error("Detail: %s", detail)

        # After vm.destroy, virsh.domain_exists may still return True for a
        # moment due to a timing issue, which would make the test fail.
        time.sleep(2)
        # Check if VM exists.
        vm_exist = virsh.domain_exists(vm_name)

        # Check if xml file exists.
        xml_exist = False
        if vm.is_qemu() and os.path.exists("/etc/libvirt/qemu/%s.xml" % vm_name):
            xml_exist = True
        if vm.is_lxc() and os.path.exists("/etc/libvirt/lxc/%s.xml" % vm_name):
            xml_exist = True
        if vm.is_xen() and os.path.exists("/etc/xen/%s" % vm_name):
            xml_exist = True

        # Check if the save file exists when the managedsave option is used
        save_exist = os.path.exists(save_file)

        # Check if the volume still exists when remove-all-storage is used
        volume_exist = volume and os.path.exists(volume)

        # Test define with acl control and recover domain.
        if params.get('setup_libvirt_polkit') == 'yes':
            if virsh.domain_exists(vm.name):
                virsh.undefine(vm_ref, ignore_status=True)
            cmd = "chmod 666 %s" % backup_xml.xml
            process.run(cmd, ignore_status=False, shell=True)
            s_define = virsh.define(backup_xml.xml,
                                    unprivileged_user=unprivileged_user,
                                    uri=uri, ignore_status=True, debug=True)
            status3 = s_define.exit_status

    finally:
        # Recover main VM.
        try:
            backup_xml.sync()
        except LibvirtXMLError:
            # sync() tries to undefine and then re-define the xml, but this
            # test may already have undefined the domain, which can make
            # sync() error out
            backup_xml.define()

        # Recover existing guest images
        if existing_images and option.count("remove-all-storage"):
            for img in existing_images:
                backup_img = img + '.bak'
                logging.info('Recover image %s to %s', backup_img, img)
                shutil.move(backup_img, img)

        # Clean up pool
        if pvtest:
            pvtest.cleanup_pool(pool_name, pool_type,
                                pool_target, emulated_img)
        # Recover VM snapshots.
        if option.count("snapshot") and (not vm.is_lxc()):
            logging.debug("Recover snapshots for domain!")
            for file_item in snp_file_list:
                virsh.snapshot_create(vm_name, file_item)

    # Check results.
    if status_error:
        if not status:
            raise error.TestFail("virsh undefine return unexpected result.")
        if params.get('setup_libvirt_polkit') == 'yes':
            if status3 == 0:
                raise error.TestFail("virsh define with false acl permission" +
                                     " should failed.")
    else:
        if status:
            raise error.TestFail("virsh undefine failed.")
        if undefine_twice:
            if not status2:
                raise error.TestFail("Undefine the same VM twice succeeded.")
        if vm_exist:
            raise error.TestFail("VM still exists after undefine.")
        if xml_exist:
            raise error.TestFail("Xml file still exists after undefine.")
        if option.count("managedsave") and save_exist:
            raise error.TestFail("Save file still exists after undefine.")
        if option.count("remove-all-storage") and volume_exist:
            raise error.TestFail("Volume file '%s' still exists after"
                                 " undefine." % volume)
        if wipe_data and option.count("remove-all-storage"):
            if not output.count("Wiping volume '%s'" % disk_target):
                raise error.TestFail("Command didn't wipe volume storage!")
        if params.get('setup_libvirt_polkit') == 'yes':
            if status3:
                raise error.TestFail("virsh define with right acl permission" +
                                     " should succeeded")

Example 8

View license
def run(test, params, env):
    """
    Test disk encryption option.

    1.Prepare test environment, destroy or suspend a VM.
    2.Prepare pool, volume.
    3.Edit disks xml and start the domain.
    4.Perform test operation.
    5.Recover test environment.
    6.Confirm the test result.
    """

    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)
    virsh_dargs = {'debug': True, 'ignore_status': True}

    def create_pool(p_name, p_type, p_target):
        """
        Define and start a pool.

        :param p_name. Pool name.
        :param p_type. Pool type.
        :param p_target. Pool target path.
        """
        p_xml = pool_xml.PoolXML(pool_type=p_type)
        p_xml.name = p_name
        p_xml.target_path = p_target

        if not os.path.exists(p_target):
            os.mkdir(p_target)
        p_xml.xmltreefile.write()
        ret = virsh.pool_define(p_xml.xml, **virsh_dargs)
        libvirt.check_exit_status(ret)
        ret = virsh.pool_build(p_name, **virsh_dargs)
        libvirt.check_exit_status(ret)
        ret = virsh.pool_start(p_name, **virsh_dargs)
        libvirt.check_exit_status(ret)

    def create_vol(p_name, p_format, vol_params):
        """
        Create volume.

        :param p_name. Pool name.
        :param vol_params. Volume parameters dict.
        :return: True if create successfully.
        """
        volxml = vol_xml.VolXML()
        v_xml = volxml.new_vol(**vol_params)
        v_xml.encryption = volxml.new_encryption(
            **{"format": p_format})
        v_xml.xmltreefile.write()
        ret = virsh.vol_create(p_name, v_xml.xml, **virsh_dargs)
        libvirt.check_exit_status(ret)

    def check_in_vm(vm, target, old_parts):
        """
        Check mount/read/write disk in VM.
        :param vm. VM guest.
        :param target. Disk dev in VM.
        :return: True if check successfully.
        """
        try:
            session = vm.wait_for_login()
            rpm_stat = session.cmd_status("rpm -q parted || "
                                          "yum install -y parted", 300)
            if rpm_stat != 0:
                raise error.TestFail("Failed to query/install parted, make sure"
                                     " that you have usable repo in guest")

            new_parts = libvirt.get_parts_list(session)
            added_parts = list(set(new_parts).difference(set(old_parts)))
            logging.info("Added parts:%s", added_parts)
            if len(added_parts) != 1:
                logging.error("The number of new partitions is invalid in VM")
                return False

            added_part = None
            if target.startswith("vd"):
                if added_parts[0].startswith("vd"):
                    added_part = added_parts[0]
            elif target.startswith("hd"):
                if added_parts[0].startswith("sd"):
                    added_part = added_parts[0]

            if not added_part:
                logging.error("Cann't see added partition in VM")
                return False

            libvirt.mk_part("/dev/%s" % added_part, size="10M", session=session)
            # Run partprobe to make the change take effect.
            process.run("partprobe", ignore_status=True, shell=True)
            libvirt.mkfs("/dev/%s1" % added_part, "ext3", session=session)

            cmd = ("mount /dev/%s1 /mnt && echo '123' > /mnt/testfile"
                   " && cat /mnt/testfile && umount /mnt" % added_part)
            s, o = session.cmd_status_output(cmd)
            logging.info("Check disk operation in VM:\n%s", o)
            session.close()
            if s != 0:
                return False
            return True
        except (remote.LoginError, virt_vm.VMError, aexpect.ShellError), e:
            logging.error(str(e))
            return False

    # Disk specific attributes.
    device = params.get("virt_disk_device", "disk")
    device_target = params.get("virt_disk_device_target", "vdd")
    device_type = params.get("virt_disk_device_type", "file")
    device_bus = params.get("virt_disk_device_bus", "virtio")

    # Pool/Volume options.
    pool_name = params.get("pool_name")
    pool_type = params.get("pool_type")
    pool_target = params.get("pool_target")
    volume_name = params.get("vol_name")
    volume_alloc = params.get("vol_alloc")
    volume_cap_unit = params.get("vol_cap_unit")
    volume_cap = params.get("vol_cap")
    volume_target_path = params.get("target_path")
    volume_target_format = params.get("target_format")
    volume_target_encypt = params.get("target_encypt", "")
    volume_target_label = params.get("target_label")

    status_error = "yes" == params.get("status_error")

    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)

    # Start vm and get all partitions in vm.
    if vm.is_dead():
        vm.start()
    session = vm.wait_for_login()
    old_parts = libvirt.get_parts_list(session)
    session.close()
    vm.destroy(gracefully=False)

    # Back up xml file.
    vmxml_backup = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)

    try:
        # Prepare the disk.
        sec_uuid = []
        create_pool(pool_name, pool_type, pool_target)
        vol_params = {"name": volume_name, "capacity": int(volume_cap),
                      "allocation": int(volume_alloc), "format":
                      volume_target_format, "path": volume_target_path,
                      "label": volume_target_label,
                      "capacity_unit": volume_cap_unit}
        create_vol(pool_name, volume_target_encypt, vol_params)

        # Add disk xml.
        vmxml = vm_xml.VMXML.new_from_dumpxml(vm_name)

        disk_xml = Disk(type_name=device_type)
        disk_xml.device = device
        if device_type == "file":
            dev_attrs = "file"
        elif device_type == "dir":
            dev_attrs = "dir"
        else:
            dev_attrs = "dev"
        disk_xml.source = disk_xml.new_disk_source(
            **{"attrs": {dev_attrs: volume_target_path}})
        disk_xml.driver = {"name": "qemu", "type": volume_target_format,
                           "cache": "none"}
        disk_xml.target = {"dev": device_target, "bus": device_bus}

        v_xml = vol_xml.VolXML.new_from_vol_dumpxml(volume_name, pool_name)
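        # Remember the volume's secret UUID so it can be undefined
        # during cleanup.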
        sec_uuid.append(v_xml.encryption.secret["uuid"])
        if not status_error:
            logging.debug("vol info -- format: %s, type: %s, uuid: %s",
                          v_xml.encryption.format,
                          v_xml.encryption.secret["type"],
                          v_xml.encryption.secret["uuid"])
            disk_xml.encryption = disk_xml.new_encryption(
                **{"encryption": v_xml.encryption.format, "secret": {
                    "type": v_xml.encryption.secret["type"],
                    "uuid": v_xml.encryption.secret["uuid"]}})

        # Sync VM xml.
        vmxml.add_device(disk_xml)
        vmxml.sync()

        try:
            # Start the VM and check status.
            vm.start()
            if status_error:
                raise error.TestFail("VM started unexpectedly.")

            if not check_in_vm(vm, device_target, old_parts):
                raise error.TestFail("Check encryption disk in VM failed")
        except virt_vm.VMStartError, e:
            if status_error:
                logging.debug("VM failed to start as expected."
                              "Error: %s" % str(e))
                pass
            else:
                raise error.TestFail("VM failed to start."
                                     "Error: %s" % str(e))

    finally:
        # Recover VM.
        if vm.is_alive():
            vm.destroy(gracefully=False)
        logging.info("Restoring vm...")
        vmxml_backup.sync()

        # Clean up pool, vol
        for i in sec_uuid:
            virsh.secret_undefine(i, **virsh_dargs)
            virsh.vol_delete(volume_name, pool_name, **virsh_dargs)
        if pool_name in virsh.pool_state_dict():
            virsh.pool_destroy(pool_name, **virsh_dargs)
            virsh.pool_undefine(pool_name, **virsh_dargs)
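
This example and the previous one follow the same recovery discipline: dump the inactive domain XML before touching anything, then restore it in a finally block no matter how the test body ends. The skeleton below strips that pattern down to its essentials; the virttest import path is my assumption about where the vm_xml helper lives in avocado-vt, and the test body is elided:

import logging

# Assumed import path for the avocado-vt XML helper used above.
from virttest.libvirt_xml import vm_xml


def run(test, params, env):
    """Skeleton of the backup/try/finally pattern shared by the examples."""
    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)

    # Back up the inactive xml before the test mutates the definition.
    vmxml_backup = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)
    try:
        pass  # test body: edit the xml, start the VM, run the checks
    finally:
        # Recover VM.
        if vm.is_alive():
            vm.destroy(gracefully=False)
        logging.info("Restoring vm...")
        vmxml_backup.sync()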

Example 9

View license
def run(test, params, env):
    """
    Test multiple disks attachment.

    1.Prepare test environment, destroy or suspend a VM.
    2.Perform 'qemu-img create' operation.
    3.Edit disks xml and start the domain.
    4.Perform test operation.
    5.Recover test environment.
    6.Confirm the test result.
    """
    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)
    virsh_dargs = {'debug': True, 'ignore_status': True}

    def check_disk_order(targets_name):
        """
        Check VM disk's order on pci bus.

        :param targets_name. Disks target list.
        :return: True if check successfully.
        """
        logging.info("Checking VM disks order...")
        xml = vm_xml.VMXML.new_from_dumpxml(vm_name)
        disk_list = xml.devices.by_device_tag("disk")
        slot_dict = {}
        # Get the disks pci slot.
        for disk in disk_list:
            if 'virtio' == disk.target['bus']:
                slot_dict[disk.target['dev']] = int(
                    disk.address.attrs['slot'], base=16)
        # The disks' order on the pci bus should match the order of
        # their target names.
        s_dev = sorted(slot_dict.keys())
        s_slot = sorted(slot_dict.values())
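        # If the order is correct, the i-th smallest target name must sit
        # in the i-th smallest pci slot.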
        for i in range(len(s_dev)):
            if s_dev[i] in targets_name and slot_dict[s_dev[i]] != s_slot[i]:
                return False
        return True

    def setup_nfs_disk(disk_name, disk_type, disk_format="raw"):
        """
        Setup nfs disk.
        """
        mount_src = os.path.join(test.tmpdir, "nfs-export")
        if not os.path.exists(mount_src):
            os.mkdir(mount_src)
        mount_dir = os.path.join(test.tmpdir, "nfs-mount")

        if disk_type in ["file", "floppy", "iso"]:
            disk_path = "%s/%s" % (mount_src, disk_name)
            device_source = libvirt.create_local_disk(disk_type, disk_path, "2",
                                                      disk_format=disk_format)
            # Format the disk.
            if disk_type in ["file", "floppy"]:
                cmd = ("mkfs.ext3 -F %s && setsebool virt_use_nfs true"
                       % device_source)
                if utils.run(cmd, ignore_status=True).exit_status:
                    raise error.TestNAError("Format disk failed")

        nfs_params = {"nfs_mount_dir": mount_dir, "nfs_mount_options": "ro",
                      "nfs_mount_src": mount_src, "setup_local_nfs": "yes",
                      "export_options": "rw,no_root_squash"}

        nfs_obj = nfs.Nfs(nfs_params)
        nfs_obj.setup()
        if not nfs_obj.mount():
            return None

        disk = {"disk_dev": nfs_obj, "format": "nfs", "source":
                "%s/%s" % (mount_dir, os.path.split(device_source)[-1])}

        return disk

    def prepare_disk(path, disk_format):
        """
        Prepare the disk for a given disk format.
        """
        disk = {}
        # Check if we are testing with a non-existent disk.
        if os.path.split(path)[-1].startswith("notexist."):
            disk.update({"format": disk_format,
                         "source": path})

        elif disk_format == "scsi":
            scsi_option = params.get("virt_disk_device_scsi_option", "")
            disk_source = libvirt.create_scsi_disk(scsi_option)
            if disk_source:
                disk.update({"format": "scsi",
                             "source": disk_source})
            else:
                raise error.TestNAError("Get scsi disk failed")

        elif disk_format in ["iso", "floppy"]:
            disk_path = libvirt.create_local_disk(disk_format, path)
            disk.update({"format": disk_format,
                         "source": disk_path})
        elif disk_format == "nfs":
            nfs_disk_type = params.get("nfs_disk_type", None)
            disk.update(setup_nfs_disk(os.path.split(path)[-1], nfs_disk_type))

        elif disk_format == "iscsi":
            # Create iscsi device if needed.
            image_size = params.get("image_size", "2G")
            device_source = libvirt.setup_or_cleanup_iscsi(
                is_setup=True, is_login=True, image_size=image_size)
            logging.debug("iscsi dev name: %s", device_source)

            # Format the disk and make file system.
            libvirt.mk_part(device_source)
            # Run partprobe to make the change take effect.
            utils.run("partprobe", ignore_status=True)
            libvirt.mkfs("%s1" % device_source, "ext3")
            device_source += "1"
            disk.update({"format": disk_format,
                         "source": device_source})
        elif disk_format in ["raw", "qcow2"]:
            disk_size = params.get("virt_disk_device_size", "1")
            device_source = libvirt.create_local_disk(
                "file", path, disk_size, disk_format=disk_format)
            disk.update({"format": disk_format,
                         "source": device_source})

        return disk

    def check_disk_format(targets_name, targets_format):
        """
        Check VM disk's type.

        :param targets_name. Device name list.
        :param targets_format. Device format list.
        :return: True if check successfully.
        """
        logging.info("Checking VM disks type... ")
        for tn, tf in zip(targets_name, targets_format):
            disk_format = vm_xml.VMXML.get_disk_attr(vm_name, tn,
                                                     "driver", "type")
            if disk_format not in [None, tf]:
                return False
        return True

    def check_vm_partitions(devices, targets_name, exists=True):
        """
        Check VM disk's partition.

        :return: True if check successfully.
        """
        logging.info("Checking VM partittion...")
        try:
            session = vm.wait_for_login()
            for i in range(len(devices)):
                if devices[i] == "cdrom":
                    s, o = session.cmd_status_output(
                        "ls /dev/sr0 && mount /dev/sr0 /mnt &&"
                        " ls /mnt && umount /mnt")
                    logging.info("cdrom devices in VM:\n%s", o)
                elif devices[i] == "floppy":
                    s, o = session.cmd_status_output(
                        "modprobe floppy && ls /dev/fd0")
                    logging.info("floppy devices in VM:\n%s", o)
                else:
                    if targets_name[i] == "hda":
                        target = "sda"
                    else:
                        target = targets_name[i]
                    s, o = session.cmd_status_output(
                        "grep %s /proc/partitions" % target)
                    logging.info("Disk devices in VM:\n%s", o)
                if s != 0:
                    if exists:
                        session.close()
                        return False
                else:
                    if not exists:
                        session.close()
                        return False
            session.close()
            return True
        except (remote.LoginError, virt_vm.VMError, aexpect.ShellError), e:
            logging.error(str(e))
            return False

    def check_vm_block_size(targets_name, log_size, phy_size):
        """
        Check VM disk's blocksize.

        :param log_size. Device logical block size.
        :param phy_size. Device physical block size.
        :return: True if check successfully.
        """
        logging.info("Checking VM block size...")
        try:
            session = vm.wait_for_login()
            for target in targets_name:
                cmd = "cat /sys/block/%s/queue/" % target
                s, o = session.cmd_status_output("%slogical_block_size"
                                                 % cmd)
                logging.debug("logical block size in VM:\n%s", o)
                if s != 0 or o.strip() != log_size:
                    session.close()
                    return False
                s, o = session.cmd_status_output("%sphysical_block_size"
                                                 % cmd)
                logging.debug("physical block size in VM:\n%s", o)
                if s != 0 or o.strip() != phy_size:
                    session.close()
                    return False
            session.close()
            return True
        except (remote.LoginError, virt_vm.VMError, aexpect.ShellError), e:
            logging.error(str(e))
            return False

    def check_readonly(targets_name):
        """
        Check disk readonly option.
        """
        logging.info("Checking disk readonly option...")
        try:
            session = vm.wait_for_login()
            for target in targets_name:
                if target == "hdc":
                    mount_cmd = "mount /dev/cdrom /mnt"
                elif target == "fda":
                    mount_cmd = "modprobe floppy && mount /dev/fd0 /mnt"
                else:
                    mount_cmd = "mount /dev/%s /mnt" % target
                cmd = ("(%s && ls /mnt || exit 1) && (echo "
                       "'test' > /mnt/test || umount /mnt)" % mount_cmd)
                s, o = session.cmd_status_output(cmd)
                logging.debug("cmd exit: %s, output: %s", s, o)
                if s:
                    session.close()
                    return False
            session.close()
            return True
        except (remote.LoginError, virt_vm.VMError, aexpect.ShellError), e:
            logging.error(str(e))
            return False

    def check_bootorder_snapshot(disk_name):
        """
        Check VM disk's bootorder option with snapshot.

        :param disk_name. The target disk to be checked.
        """
        logging.info("Checking diskorder option with snapshot...")
        snapshot1 = "s1"
        snapshot2 = "s2"
        snapshot2_file = os.path.join(test.tmpdir, "s2")
        ret = virsh.snapshot_create(vm_name, "", **virsh_dargs)
        libvirt.check_exit_status(ret)

        ret = virsh.snapshot_create_as(vm_name, "%s --disk-only" % snapshot1,
                                       **virsh_dargs)
        libvirt.check_exit_status(ret)

        ret = virsh.snapshot_dumpxml(vm_name, snapshot1)
        libvirt.check_exit_status(ret)

        cmd = "echo \"%s\" | grep %s.%s" % (ret.stdout, disk_name, snapshot1)
        if utils.run(cmd, ignore_status=True).exit_status:
            raise error.TestError("Check snapshot disk failed")

        ret = virsh.snapshot_create_as(vm_name,
                                       "%s --memspec file=%s,snapshot=external"
                                       % (snapshot2, snapshot2_file),
                                       **virsh_dargs)
        libvirt.check_exit_status(ret)

        ret = virsh.dumpxml(vm_name)
        libvirt.check_exit_status(ret)

        cmd = ("echo \"%s\" | grep -A 16 %s.%s | grep \"boot order='%s'\""
               % (ret.stdout, disk_name, snapshot2, bootorder))
        if utils.run(cmd, ignore_status=True).exit_status:
            raise error.TestError("Check snapshot disk with bootorder failed")

        snap_lists = virsh.snapshot_list(vm_name)
        if snapshot1 not in snap_lists or snapshot2 not in snap_lists:
            raise error.TestError("Check snapshot list failed")

        # Check virsh save command after snapshot.
        save_file = "/tmp/%s.save" % vm_name
        ret = virsh.save(vm_name, save_file, **virsh_dargs)
        libvirt.check_exit_status(ret)

        # Check virsh restore command after snapshot.
        ret = virsh.restore(save_file, **virsh_dargs)
        libvirt.check_exit_status(ret)

        # Passed all tests.
        os.remove(save_file)

    def check_boot_console(bootorders):
        """
        Get console output and check bootorder.
        """
        # Get console output.
        vm.serial_console.read_until_output_matches(
            ["Hard Disk"], utils_misc.strip_console_codes)
        output = vm.serial_console.get_stripped_output()
        logging.debug("serial output: %s", output)
        lines = re.findall(r"^Booting from (.+)...", output, re.M)
        logging.debug("lines: %s", lines)
        if len(lines) != len(bootorders):
            return False
        for i in range(len(bootorders)):
            if lines[i] != bootorders[i]:
                return False

        return True

    def check_disk_save_restore(save_file, device_targets,
                                startup_policy):
        """
        Check domain save and restore operation.
        """
        # Save the domain.
        ret = virsh.save(vm_name, save_file,
                         **virsh_dargs)
        libvirt.check_exit_status(ret)

        # Restore the domain.
        restore_error = False
        # Check disk startup policy option
        if "optional" in startup_policy:
            os.remove(disks[0]["source"])
            restore_error = True
        ret = virsh.restore(save_file, **virsh_dargs)
        libvirt.check_exit_status(ret, restore_error)
        if restore_error:
            return

        # Connect to the domain and check disk.
        try:
            session = vm.wait_for_login()
            cmd = ("ls /dev/%s && mkfs.ext3 -F /dev/%s && mount /dev/%s"
                   " /mnt && ls /mnt && touch /mnt/test && umount /mnt"
                   % (device_targets[0], device_targets[0], device_targets[0]))
            s, o = session.cmd_status_output(cmd)
            if s:
                session.close()
                raise error.TestError("Failed to read/write disk in VM:"
                                      " %s" % o)
            session.close()
        except (remote.LoginError, virt_vm.VMError, aexpect.ShellError), e:
            raise error.TestError(str(e))

    def check_dom_iothread():
        """
        Check iothread by qemu-monitor-command.
        """
        ret = virsh.qemu_monitor_command(vm_name,
                                         '{"execute": "query-iothreads"}',
                                         "--pretty")
        libvirt.check_exit_status(ret)
        logging.debug("Domain iothreads: %s", ret.stdout)
        iothreads_ret = json.loads(ret.stdout)
        if len(iothreads_ret['return']) != int(dom_iothreads):
            raise error.TestFail("Failed to check domain iothreads")

    status_error = "yes" == params.get("status_error", "no")
    define_error = "yes" == params.get("define_error", "no")
    dom_iothreads = params.get("dom_iothreads")

    # Disk specific attributes.
    devices = params.get("virt_disk_device", "disk").split()
    device_source_names = params.get("virt_disk_device_source").split()
    device_targets = params.get("virt_disk_device_target", "vda").split()
    device_formats = params.get("virt_disk_device_format", "raw").split()
    device_types = params.get("virt_disk_device_type", "file").split()
    device_bus = params.get("virt_disk_device_bus", "virtio").split()
    driver_options = params.get("driver_option", "").split()
    device_bootorder = params.get("virt_disk_boot_order", "").split()
    device_readonly = params.get("virt_disk_option_readonly",
                                 "no").split()
    device_attach_error = params.get("disks_attach_error", "").split()
    device_attach_option = params.get("disks_attach_option", "").split(';')
    device_address = params.get("virt_disk_addr_options", "").split()
    startup_policy = params.get("virt_disk_device_startuppolicy", "").split()
    bootorder = params.get("disk_bootorder", "")
    bootdisk_target = params.get("virt_disk_bootdisk_target", "vda")
    bootdisk_bus = params.get("virt_disk_bootdisk_bus", "virtio")
    bootdisk_driver = params.get("virt_disk_bootdisk_driver", "")
    serial = params.get("virt_disk_serial", "")
    wwn = params.get("virt_disk_wwn", "")
    vendor = params.get("virt_disk_vendor", "")
    product = params.get("virt_disk_product", "")
    add_disk_driver = params.get("add_disk_driver")
    iface_driver = params.get("iface_driver_option", "")
    bootdisk_snapshot = params.get("bootdisk_snapshot", "")
    snapshot_option = params.get("snapshot_option", "")
    snapshot_error = "yes" == params.get("snapshot_error", "no")
    add_usb_device = "yes" == params.get("add_usb_device", "no")
    input_usb_address = params.get("input_usb_address", "")
    hub_usb_address = params.get("hub_usb_address", "")
    hotplug = "yes" == params.get(
        "virt_disk_device_hotplug", "no")
    device_at_dt_disk = "yes" == params.get("virt_disk_at_dt_disk", "no")
    device_with_source = "yes" == params.get(
        "virt_disk_with_source", "yes")
    virtio_scsi_controller = "yes" == params.get(
        "virtio_scsi_controller", "no")
    virtio_scsi_controller_driver = params.get(
        "virtio_scsi_controller_driver", "")
    source_path = "yes" == params.get(
        "virt_disk_device_source_path", "yes")
    check_patitions = "yes" == params.get(
        "virt_disk_check_partitions", "yes")
    check_patitions_hotunplug = "yes" == params.get(
        "virt_disk_check_partitions_hotunplug", "yes")
    test_slots_order = "yes" == params.get(
        "virt_disk_device_test_order", "no")
    test_disks_format = "yes" == params.get(
        "virt_disk_device_test_format", "no")
    test_block_size = "yes" == params.get(
        "virt_disk_device_test_block_size", "no")
    test_file_img_on_disk = "yes" == params.get(
        "test_file_image_on_disk", "no")
    test_with_boot_disk = "yes" == params.get(
        "virt_disk_with_boot_disk", "no")
    test_disk_option_cmd = "yes" == params.get(
        "test_disk_option_cmd", "no")
    test_disk_type_dir = "yes" == params.get(
        "virt_disk_test_type_dir", "no")
    test_disk_bootorder = "yes" == params.get(
        "virt_disk_test_bootorder", "no")
    test_disk_bootorder_snapshot = "yes" == params.get(
        "virt_disk_test_bootorder_snapshot", "no")
    test_boot_console = "yes" == params.get(
        "virt_disk_device_boot_console", "no")
    test_disk_readonly = "yes" == params.get(
        "virt_disk_device_test_readonly", "no")
    test_disk_snapshot = "yes" == params.get(
        "virt_disk_test_snapshot", "no")
    test_disk_save_restore = "yes" == params.get(
        "virt_disk_test_save_restore", "no")
    test_bus_device_option = "yes" == params.get(
        "test_bus_option_cmd", "no")
    snapshot_before_start = "yes" == params.get(
        "snapshot_before_start", "no")

    if dom_iothreads:
        if not libvirt_version.version_compare(1, 2, 8):
            raise error.TestNAError("iothreads not supported for"
                                    " this libvirt version")

    if test_block_size:
        logical_block_size = params.get("logical_block_size")
        physical_block_size = params.get("physical_block_size")

    if any([test_boot_console, add_disk_driver]):
        if vm.is_dead():
            vm.start()
        session = vm.wait_for_login()
        if test_boot_console:
            # Setting console to kernel parameters
            vm.set_kernel_console("ttyS0", "115200")
        if add_disk_driver:
            # Ignore errors here
            session.cmd("dracut --force --add-drivers '%s'"
                        % add_disk_driver)
        session.close()
        vm.shutdown()

    # Destroy VM.
    if vm.is_alive():
        vm.destroy(gracefully=False)

    # Back up xml file.
    vmxml_backup = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)

    # Get device path.
    device_source_path = ""
    if source_path:
        device_source_path = test.virtdir

    # Prepare test environment.
    qemu_config = LibvirtQemuConfig()
    if test_disks_format:
        qemu_config.allow_disk_format_probing = True
        utils_libvirtd.libvirtd_restart()

    # Create virtual device file.
    disks = []
    try:
        for i in range(len(device_source_names)):
            if test_disk_type_dir:
                # If we are testing the disk type 'dir' option,
                # there is no need to create a disk image
                disks.append({"format": "dir",
                              "source": device_source_names[i]})
            else:
                path = "%s/%s.%s" % (device_source_path,
                                     device_source_names[i], device_formats[i])
                disk = prepare_disk(path, device_formats[i])
                if disk:
                    disks.append(disk)

    except Exception, e:
        logging.error(repr(e))
        for img in disks:
            if "disk_dev" in img:
                if img["format"] == "nfs":
                    img["disk_dev"].cleanup()
            else:
                if img["format"] == "iscsi":
                    libvirt.setup_or_cleanup_iscsi(is_setup=False)
                if img["format"] not in ["dir", "scsi"]:
                    os.remove(img["source"])
        raise error.TestNAError("Creating disk failed")

    # Build disks xml.
    disks_xml = []
    # Additional disk images.
    disks_img = []
    vmxml = vm_xml.VMXML.new_from_dumpxml(vm_name)
    try:
        for i in range(len(disks)):
            disk_xml = Disk(type_name=device_types[i])
            # If we are testing image file on iscsi disk,
            # mount the disk and then create the image.
            if test_file_img_on_disk:
                mount_path = "/tmp/diskimg"
                if utils.run("mkdir -p %s && mount %s %s"
                             % (mount_path, disks[i]["source"],
                                mount_path), ignore_status=True).exit_status:
                    raise error.TestNAError("Prepare disk failed")
                disk_path = "%s/%s.qcow2" % (mount_path, device_source_names[i])
                disk_source = libvirt.create_local_disk("file", disk_path, "1",
                                                        disk_format="qcow2")
                disks_img.append({"format": "qcow2",
                                  "source": disk_source, "path": mount_path})
            else:
                disk_source = disks[i]["source"]

            disk_xml.device = devices[i]

            if device_with_source:
                if device_types[i] == "file":
                    dev_attrs = "file"
                elif device_types[i] == "dir":
                    dev_attrs = "dir"
                else:
                    dev_attrs = "dev"
                source_dict = {dev_attrs: disk_source}
                if len(startup_policy) > i:
                    source_dict.update({"startupPolicy": startup_policy[i]})
                disk_xml.source = disk_xml.new_disk_source(
                    **{"attrs": source_dict})

            if len(device_bootorder) > i:
                disk_xml.boot = device_bootorder[i]

            if test_block_size:
                disk_xml.blockio = {"logical_block_size": logical_block_size,
                                    "physical_block_size": physical_block_size}

            if wwn != "":
                disk_xml.wwn = wwn
            if serial != "":
                disk_xml.serial = serial
            if vendor != "":
                disk_xml.vendor = vendor
            if product != "":
                disk_xml.product = product

            disk_xml.target = {"dev": device_targets[i], "bus": device_bus[i]}
            if len(device_readonly) > i:
                disk_xml.readonly = "yes" == device_readonly[i]

            # Add driver options from parameters
            driver_dict = {"name": "qemu"}
            if len(driver_options) > i:
                for driver_option in driver_options[i].split(','):
                    if driver_option != "":
                        d = driver_option.split('=')
                        driver_dict.update({d[0].strip(): d[1].strip()})
            disk_xml.driver = driver_dict

            # Add disk address from parameters.
            if len(device_address) > i:
                addr_dict = {}
                for addr_option in device_address[i].split(','):
                    if addr_option != "":
                        d = addr_option.split('=')
                        addr_dict.update({d[0].strip(): d[1].strip()})
                disk_xml.address = disk_xml.new_disk_address(
                    **{"attrs": addr_dict})

            logging.debug("disk xml: %s", disk_xml)
            if hotplug:
                disks_xml.append(disk_xml)
            else:
                vmxml.add_device(disk_xml)

        # If we want to test with a bootdisk,
        # just edit the bootdisk xml.
        if test_with_boot_disk:
            xml_devices = vmxml.devices
            disk_index = xml_devices.index(xml_devices.by_device_tag("disk")[0])
            disk = xml_devices[disk_index]
            if bootorder != "":
                disk.boot = bootorder
                osxml = vm_xml.VMOSXML()
                osxml.type = vmxml.os.type
                osxml.arch = vmxml.os.arch
                osxml.machine = vmxml.os.machine
                if test_boot_console:
                    osxml.loader = "/usr/share/seabios/bios.bin"
                    osxml.bios_useserial = "yes"
                    osxml.bios_reboot_timeout = "-1"

                del vmxml.os
                vmxml.os = osxml
            driver_dict = {"name": disk.driver["name"],
                           "type": disk.driver["type"]}
            if bootdisk_driver != "":
                for driver_option in bootdisk_driver.split(','):
                    if driver_option != "":
                        d = driver_option.split('=')
                        driver_dict.update({d[0].strip(): d[1].strip()})
            disk.driver = driver_dict

            if iface_driver != "":
                driver_dict = {"name": "vhost"}
                for driver_option in iface_driver.split(','):
                    if driver_option != "":
                        d = driver_option.split('=')
                        driver_dict.update({d[0].strip(): d[1].strip()})
                iface_list = xml_devices.by_device_tag("interface")[0]
                iface_index = xml_devices.index(iface_list)
                iface = xml_devices[iface_index]
                iface.driver = iface.new_driver(**{"driver_attr": driver_dict})
                iface.model = "virtio"
                del iface.address

            if bootdisk_snapshot != "":
                disk.snapshot = bootdisk_snapshot

            disk.target = {"dev": bootdisk_target, "bus": bootdisk_bus}
            device_source = disk.source.attrs["file"]
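            # Remember the boot disk image path; the bootorder-with-snapshot
            # check reuses it later.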

            del disk.address
            vmxml.devices = xml_devices
            vmxml.define()

        # Add virtio_scsi controller.
        if virtio_scsi_controller:
            scsi_controller = Controller("controller")
            scsi_controller.type = "scsi"
            scsi_controller.index = "0"
            ctl_model = params.get("virtio_scsi_controller_model")
            if ctl_model:
                scsi_controller.model = ctl_model
            if virtio_scsi_controller_driver != "":
                driver_dict = {}
                for driver_option in virtio_scsi_controller_driver.split(','):
                    if driver_option != "":
                        d = driver_option.split('=')
                        driver_dict.update({d[0].strip(): d[1].strip()})
                scsi_controller.driver = driver_dict
            vmxml.del_controller("scsi")
            vmxml.add_device(scsi_controller)

        # Test usb devices.
        usb_devices = {}
        if add_usb_device:
            # Delete all usb devices first.
            controllers = vmxml.get_devices(device_type="controller")
            for ctrl in controllers:
                if ctrl.type == "usb":
                    vmxml.del_device(ctrl)

            inputs = vmxml.get_devices(device_type="input")
            for input in inputs:
                if input.type_name == "tablet":
                    vmxml.del_device(input)

            # Add new usb controllers.
            usb_controller1 = Controller("controller")
            usb_controller1.type = "usb"
            usb_controller1.index = "0"
            usb_controller1.model = "piix3-uhci"
            vmxml.add_device(usb_controller1)
            usb_controller2 = Controller("controller")
            usb_controller2.type = "usb"
            usb_controller2.index = "1"
            usb_controller2.model = "ich9-ehci1"
            vmxml.add_device(usb_controller2)

            input_obj = Input("tablet")
            input_obj.input_bus = "usb"
            addr_dict = {}
            if input_usb_address != "":
                for addr_option in input_usb_address.split(','):
                    if addr_option != "":
                        d = addr_option.split('=')
                        addr_dict.update({d[0].strip(): d[1].strip()})
            if addr_dict:
                input_obj.address = input_obj.new_input_address(
                    **{"attrs": addr_dict})
            vmxml.add_device(input_obj)
            usb_devices.update({"input": addr_dict})

            hub_obj = Hub("usb")
            addr_dict = {}
            if hub_usb_address != "":
                for addr_option in hub_usb_address.split(','):
                    if addr_option != "":
                        d = addr_option.split('=')
                        addr_dict.update({d[0].strip(): d[1].strip()})
            if addr_dict:
                hub_obj.address = hub_obj.new_hub_address(
                    **{"attrs": addr_dict})
            vmxml.add_device(hub_obj)
            usb_devices.update({"hub": addr_dict})

        if dom_iothreads:
            # Delete the cputune/iothreadids sections, as they may conflict
            # with domain iothreads
            del vmxml.cputune
            del vmxml.iothreadids
            vmxml.iothreads = int(dom_iothreads)

        # After compose the disk xml, redefine the VM xml.
        vmxml.sync()

        # Test snapshot before vm start.
        if test_disk_snapshot:
            if snapshot_before_start:
                ret = virsh.snapshot_create_as(vm_name, "s1 %s"
                                               % snapshot_option)
                libvirt.check_exit_status(ret, snapshot_error)

        # Start the VM.
        vm.start()
        if status_error:
            raise error.TestFail("VM started unexpectedly")

        # Hotplug the disks.
        if device_at_dt_disk:
            for i in range(len(disks)):
                attach_option = ""
                if len(device_attach_option) > i:
                    attach_option = device_attach_option[i]
                ret = virsh.attach_disk(vm_name, disks[i]["source"],
                                        device_targets[i],
                                        attach_option)
                libvirt.check_exit_status(ret)

        elif hotplug:
            for i in range(len(disks_xml)):
                disks_xml[i].xmltreefile.write()
                attach_option = ""
                if len(device_attach_option) > i:
                    attach_option = device_attach_option[i]
                ret = virsh.attach_device(vm_name, disks_xml[i].xml,
                                          flagstr=attach_option)
                attach_error = False
                if len(device_attach_error) > i:
                    attach_error = "yes" == device_attach_error[i]
                libvirt.check_exit_status(ret, attach_error)

    except virt_vm.VMStartError as details:
        if not status_error:
            raise error.TestFail('VM failed to start:\n%s' % details)
    except xcepts.LibvirtXMLError:
        if not define_error:
            raise error.TestFail("Failed to define VM")
    else:
        # VM is started, perform the tests.
        if test_slots_order:
            if not check_disk_order(device_targets):
                raise error.TestFail("Disks slots order error in domain xml")

        if test_disks_format:
            if not check_disk_format(device_targets, device_formats):
                raise error.TestFail("Disks type error in VM xml")

        if test_boot_console:
            # Check if disks bootorder is as expected.
            expected_order = params.get("expected_order").split(',')
            if not check_boot_console(expected_order):
                raise error.TestFail("Test VM bootorder failed")

        if test_block_size:
            # Check disk block size in VM.
            if not check_vm_block_size(device_targets,
                                       logical_block_size, physical_block_size):
                raise error.TestFail("Test disk block size in VM failed")

        if test_disk_option_cmd:
            # Check if disk options take effect in the qemu command line.
            cmd = ("ps -ef | grep %s | grep -v grep " % vm_name)
            if test_with_boot_disk:
                d_target = bootdisk_target
            else:
                d_target = device_targets[0]
                if device_with_source:
                    cmd += (" | grep %s" %
                            (device_source_names[0].replace(',', ',,')))
            io = vm_xml.VMXML.get_disk_attr(vm_name, d_target, "driver", "io")
            if io:
                cmd += " | grep aio=%s" % io
            ioeventfd = vm_xml.VMXML.get_disk_attr(vm_name, d_target,
                                                   "driver", "ioeventfd")
            if ioeventfd:
                cmd += " | grep ioeventfd=%s" % ioeventfd
            event_idx = vm_xml.VMXML.get_disk_attr(vm_name, d_target,
                                                   "driver", "event_idx")
            if event_idx:
                cmd += " | grep event_idx=%s" % event_idx

            discard = vm_xml.VMXML.get_disk_attr(vm_name, d_target,
                                                 "driver", "discard")
            if discard:
                cmd += " | grep discard=%s" % discard
            copy_on_read = vm_xml.VMXML.get_disk_attr(vm_name, d_target,
                                                      "driver", "copy_on_read")
            if copy_on_read:
                cmd += " | grep copy-on-read=%s" % copy_on_read

            iothread = vm_xml.VMXML.get_disk_attr(vm_name, d_target,
                                                  "driver", "iothread")
            if iothread:
                cmd += " | grep iothread=iothread%s" % iothread

            if serial != "":
                cmd += " | grep serial=%s" % serial
            if wwn != "":
                cmd += " | grep -E \"wwn=(0x)?%s\"" % wwn
            if vendor != "":
                cmd += " | grep vendor=%s" % vendor
            if product != "":
                cmd += " | grep \"product=%s\"" % product

            num_queues = ""
            ioeventfd = ""
            if virtio_scsi_controller_driver != "":
                for driver_option in virtio_scsi_controller_driver.split(','):
                    if driver_option != "":
                        d = driver_option.split('=')
                        if d[0].strip() == "queues":
                            num_queues = d[1].strip()
                        elif d[0].strip() == "ioeventfd":
                            ioeventfd = d[1].strip()
            if num_queues != "":
                cmd += " | grep num_queues=%s" % num_queues
            if ioeventfd:
                cmd += " | grep ioeventfd=%s" % ioeventfd

            iface_event_idx = ""
            if iface_driver != "":
                for driver_option in iface_driver.split(','):
                    if driver_option != "":
                        d = driver_option.split('=')
                        if d[0].strip() == "event_idx":
                            iface_event_idx = d[1].strip()
            if iface_event_idx != "":
                cmd += " | grep virtio-net-pci,event_idx=%s" % iface_event_idx

            if utils.run(cmd, ignore_status=True).exit_status:
                raise error.TestFail("Check disk driver option failed")

        if test_disk_snapshot:
            ret = virsh.snapshot_create_as(vm_name, "s1 %s" % snapshot_option)
            libvirt.check_exit_status(ret, snapshot_error)

        # Check the disk bootorder.
        if test_disk_bootorder:
            for i in range(len(device_targets)):
                if len(device_attach_error) > i:
                    if device_attach_error[i] == "yes":
                        continue
                if device_bootorder[i] != vm_xml.VMXML.get_disk_attr(
                        vm_name, device_targets[i], "boot", "order"):
                    raise error.TestFail("Check bootorder failed")

        # Check disk bootorder with snapshot.
        if test_disk_bootorder_snapshot:
            disk_name = os.path.splitext(device_source)[0]
            check_bootorder_snapshot(disk_name)

        # Check disk readonly option.
        if test_disk_readonly:
            if not check_readonly(device_targets):
                raise error.TestFail("Checking disk readonly option failed")

        # Check disk bus device option in qemu command line.
        if test_bus_device_option:
            cmd = ("ps -ef | grep %s | grep -v grep " % vm_name)
            dev_bus = int(vm_xml.VMXML.get_disk_attr(vm_name, device_targets[0],
                                                     "address", "bus"), 16)
            if device_bus[0] == "virtio":
                pci_slot = int(vm_xml.VMXML.get_disk_attr(vm_name, device_targets[0],
                                                          "address", "slot"), 16)
                if devices[0] == "lun":
                    device_option = "scsi=on"
                else:
                    device_option = "scsi=off"
                cmd += (" | grep virtio-blk-pci,%s,bus=pci.%x,addr=0x%x"
                        % (device_option, dev_bus, pci_slot))
            if device_bus[0] in ["ide", "sata", "scsi"]:
                dev_unit = int(vm_xml.VMXML.get_disk_attr(vm_name, device_targets[0],
                                                          "address", "unit"), 16)
                dev_id = vm_xml.VMXML.get_disk_attr(vm_name, device_targets[0],
                                                    "alias", "name")
            if device_bus[0] == "ide":
                check_cmd = "/usr/libexec/qemu-kvm -device ? 2>&1 |grep -E 'ide-cd|ide-hd'"
                if utils.run(check_cmd, ignore_status=True).exit_status:
                    raise error.TestNAError("ide-cd/ide-hd not supported by this qemu-kvm")

                if devices[0] == "cdrom":
                    device_option = "ide-cd"
                else:
                    device_option = "ide-hd"
                cmd += (" | grep %s,bus=ide.%d,unit=%d,drive=drive-%s,id=%s"
                        % (device_option, dev_bus, dev_unit, dev_id, dev_id))
            if device_bus[0] == "sata":
                cmd += (" | grep 'device ahci,.*,bus=pci.%s'" % dev_bus)
            if device_bus[0] == "scsi":
                if devices[0] == "lun":
                    device_option = "scsi-block"
                elif devices[0] == "cdrom":
                    device_option = "scsi-cd"
                else:
                    device_option = "scsi-hd"
                cmd += (" | grep %s,bus=scsi%d.%d,.*drive=drive-%s,id=%s"
                        % (device_option, dev_bus, dev_unit, dev_id, dev_id))
            if device_bus[0] == "usb":
                dev_port = vm_xml.VMXML.get_disk_attr(vm_name, device_targets[0],
                                                      "address", "port")
                dev_id = vm_xml.VMXML.get_disk_attr(vm_name, device_targets[0],
                                                    "alias", "name")
                if devices[0] == "disk":
                    cmd += (" | grep usb-storage,bus=usb%s.0,port=%s,"
                            "drive=drive-%s,id=%s"
                            % (dev_bus, dev_port, dev_id, dev_id))
                if usb_devices.has_key("input"):
                    cmd += (" | grep usb-tablet,id=input[0-9],bus=usb.%s,"
                            "port=%s" % (usb_devices["input"]["bus"],
                                         usb_devices["input"]["port"]))
                if usb_devices.has_key("hub"):
                    cmd += (" | grep usb-hub,id=hub0,bus=usb.%s,"
                            "port=%s" % (usb_devices["hub"]["bus"],
                                         usb_devices["hub"]["port"]))

            if utils.run(cmd, ignore_status=True).exit_status:
                raise error.TestFail("Cann't see disk option"
                                     " in command line")

        if dom_iothreads:
            check_dom_iothread()

        # Check in VM after command.
        if check_patitions:
            if not check_vm_partitions(devices,
                                       device_targets):
                raise error.TestFail("Cann't see device in VM")

        # Check disk save and restore.
        if test_disk_save_restore:
            save_file = "/tmp/%s.save" % vm_name
            check_disk_save_restore(save_file, device_targets,
                                    startup_policy)
            if os.path.exists(save_file):
                os.remove(save_file)

        # If we are testing hotplug, detach the disk at the end.
        if device_at_dt_disk:
            for i in range(len(disks)):
                dt_options = ""
                if devices[i] == "cdrom":
                    dt_options = "--config"
                ret = virsh.detach_disk(vm_name, device_targets[i],
                                        dt_options, **virsh_dargs)
                libvirt.check_exit_status(ret)
            # Check disks in VM after hotunplug.
            if check_patitions_hotunplug:
                if not check_vm_partitions(devices,
                                           device_targets, False):
                    raise error.TestFail("See device in VM after hotunplug")

        elif hotplug:
            for i in range(len(disks_xml)):
                if len(device_attach_error) > i:
                    if device_attach_error[i] == "yes":
                        continue
                ret = virsh.detach_device(vm_name, disks_xml[i].xml,
                                          flagstr=attach_option, **virsh_dargs)
                os.remove(disks_xml[i].xml)
                libvirt.check_exit_status(ret)

            # Check disks in VM after hotunplug.
            if check_patitions_hotunplug:
                if not check_vm_partitions(devices,
                                           device_targets, False):
                    raise error.TestFail("See device in VM after hotunplug")

    finally:
        # Delete snapshots.
        libvirt.clean_up_snapshots(vm_name, domxml=vmxml_backup)

        # Recover VM.
        if vm.is_alive():
            vm.destroy(gracefully=False)
        vmxml_backup.sync("--snapshots-metadata")

        # Restore qemu_config file.
        qemu_config.restore()
        utils_libvirtd.libvirtd_restart()

        for img in disks_img:
            os.remove(img["source"])
            if os.path.exists(img["path"]):
                utils.run("umount %s && rmdir %s"
                          % (img["path"], img["path"]), ignore_status=True)

        for img in disks:
            if img.has_key("disk_dev"):
                if img["format"] == "nfs":
                    img["disk_dev"].cleanup()

                del img["disk_dev"]
            else:
                if img["format"] == "scsi":
                    libvirt.delete_scsi_disk()
                elif img["format"] == "iscsi":
                    libvirt.setup_or_cleanup_iscsi(is_setup=False)
                elif img["format"] not in ["dir"]:
                    if os.path.exists(img["source"]):
                        os.remove(img["source"])

Example 10

Project: tp-qemu
Source File: qmp_command.py
View license
@error.context_aware
def run(test, params, env):
    """
    Test qmp event notification, this case will:
    1) Start VM with qmp enable.
    2) Connect to qmp port then run qmp_capabilities command.
    3) Initiate the qmp command defined in config (qmp_cmd)
    4) Verify that qmp command works as designed.

    :param test: QEMU test object
    :param params: Dictionary with the test parameters
    :param env: Dictionary with the test environment.
    """
    def check_result(qmp_o, output=None, exception_list=""):
        """
        Check the test result in different ways according to
        result_check.
        result_check = equal, will compare cmd_return_value with the qmp
                       command output.
        result_check = contain, will try to find cmd_return_value in the qmp
                       command output.
        result_check = m_equal_q, will compare key values in the monitor
                       command output and the qmp command output.
        result_check = m_in_q, will try to find the monitor command output's
                       key values in the qmp command output.
        result_check = m_format_q, will try to match the output's format
                       with the check pattern.

        :param qmp_o: output from pre_cmd, qmp_cmd or post_cmd.
        :param output: output from pre_cmd, qmp_cmd or post_cmd, or an
                       expected result set in the config file.
        :param exception_list: elements that do not need to be checked.
        """
        if result_check == "equal":
            value = output
            if value != str(qmp_o):
                raise error.TestFail("QMP command return value does not match "
                                     "the expect result. Expect result: '%s'\n"
                                     "Actual result: '%s'" % (value, qmp_o))
        elif result_check == "contain":
            values = output.split(';')
            for value in values:
                if value in exception_list:
                    continue
                if value.strip() not in str(qmp_o):
                    raise error.TestFail("QMP command output does not contain "
                                         "expect result. Expect result: '%s'\n"
                                         "Actual result: '%s'"
                                         % (value, qmp_o))
        elif result_check == "not_contain":
            values = output.split(';')
            for value in values:
                if value in exception_list:
                    continue
                if value in str(qmp_o):
                    raise error.TestFail("QMP command output contains unexpect"
                                         " result. Unexpect result: '%s'\n"
                                         "Actual result: '%s'"
                                         % (value, qmp_o))
        elif result_check == "m_equal_q":
            msg = "QMP command ouput is not equal to in human monitor command."
            msg += "\nQMP command output: '%s'" % qmp_o
            msg += "\nHuman command output: '%s'" % output
            res = output.splitlines(True)
            if type(qmp_o) != type(res):
                len_o = 1
            else:
                len_o = len(qmp_o)
            if len(res) != len_o:
                if res[0].startswith(' '):
                    raise error.TestFail("Human command starts with ' ', "
                                         "there is probably some garbage in "
                                         "the output.\n" + msg)
                res_tmp = []
                # (qemu) info block in RHEL7 is divided into 3 lines
                for line in res:
                    if not line.startswith(' '):
                        res_tmp.append(line)
                    else:
                        res_tmp[-1] += line
                res = res_tmp
                if len(res) != len_o:
                    raise error.TestFail(msg)
            re_str = r'([^ \t\n\r\f\v=]*)=([^ \t\n\r\f\v=]*)'
            for i in range(len(res)):
                if qmp_cmd == "query-version":
                    version = qmp_o['qemu']
                    version = "%s.%s.%s" % (version['major'], version['minor'],
                                            version['micro'])
                    package = qmp_o['package']
                    re_str = r"([0-9]+\.[0-9]+\.[0-9]+)\s*(\(\S*\))?"
                    hmp_version, hmp_package = re.findall(re_str, res[i])[0]
                    if not hmp_package:
                        hmp_package = package
                    hmp_package = hmp_package.strip()
                    package = package.strip()
                    hmp_version = hmp_version.strip()
                    if version != hmp_version or package != hmp_package:
                        raise error.TestFail(msg)
                else:
                    matches = re.findall(re_str, res[i])
                    for key, val in matches:
                        if key in exception_list:
                            continue
                        if '0x' in val:
                            val = long(val, 16)
                            if val != qmp_o[i][key]:
                                msg += "\nValue in human monitor: '%s'" % val
                                msg += "\nValue in qmp: '%s'" % qmp_o[i][key]
                                raise error.TestFail(msg)
                        elif qmp_cmd == "query-block":
                            cmp_str = "u'%s': u'%s'" % (key, val)
                            cmp_s = "u'%s': %s" % (key, val)
                            if '0' == val:
                                cmp_str_b = "u'%s': False" % key
                            elif '1' == val:
                                cmp_str_b = "u'%s': True" % key
                            else:
                                cmp_str_b = cmp_str
                            if (cmp_str not in str(qmp_o[i]) and
                                    cmp_str_b not in str(qmp_o[i]) and
                                    cmp_s not in str(qmp_o[i])):
                                msg += ("\nCan not find '%s', '%s' or '%s' in "
                                        " QMP command output."
                                        % (cmp_s, cmp_str_b, cmp_str))
                                raise error.TestFail(msg)
                        elif qmp_cmd == "query-balloon":
                            if (int(val) * 1024 * 1024 != qmp_o[key] and
                                    val not in str(qmp_o[key])):
                                msg += ("\n'%s' is not in QMP command output"
                                        % val)
                                raise error.TestFail(msg)
                        else:
                            if (val not in str(qmp_o[i][key]) and
                                    str(bool(int(val))) not in str(qmp_o[i][key])):
                                msg += ("\n'%s' is not in QMP command output"
                                        % val)
                                raise error.TestFail(msg)
        elif result_check == "m_in_q":
            res = output.splitlines(True)
            msg = "Key value from human monitor command is not in"
            msg += "QMP command output.\nQMP command output: '%s'" % qmp_o
            msg += "\nHuman monitor command output '%s'" % output
            for i in range(len(res)):
                params = res[i].rstrip().split()
                for param in params:
                    if param.rstrip() in exception_list:
                        continue
                    try:
                        str_o = str(qmp_o.values())
                    except AttributeError:
                        str_o = qmp_o
                    if param.rstrip() not in str(str_o):
                        msg += "\nKey value is '%s'" % param.rstrip()
                        raise error.TestFail(msg)
        elif result_check == "m_format_q":
            match_flag = True
            for i in qmp_o:
                if output is None:
                    raise error.TestError("QMP output pattern is missing")
                if re.match(output.strip(), str(i)) is None:
                    match_flag = False
            if not match_flag:
                msg = "Output does not match the pattern: '%s'" % output
                raise error.TestFail(msg)

    def qmp_cpu_check(output):
        """ qmp_cpu test check """
        last_cpu = int(params['smp']) - 1
        for out in output:
            cpu = out.get('CPU')
            if cpu is None:
                raise error.TestFail("'CPU' index is missing in QMP output "
                                     "'%s'" % out)
            else:
                current = out.get('current')
                if current is None:
                    raise error.TestFail("'current' key is missing in QMP "
                                         "output '%s'" % out)
                elif cpu < last_cpu:
                    if current is False:
                        pass
                    else:
                        raise error.TestFail("Attribute 'current' should be "
                                             "'False', but is '%s' instead.\n"
                                             "'%s'" % (current, out))
                elif cpu == last_cpu:
                    if current is True:
                        pass
                    else:
                        raise error.TestFail("Attribute 'current' should be "
                                             "'True', but is '%s' instead.\n"
                                             "'%s'" % (current, out))
                elif cpu <= last_cpu:
                    continue
                else:
                    raise error.TestFail("Incorrect CPU index '%s' (corrupted "
                                         "or higher than no_cpus).\n%s"
                                         % (cpu, out))

    qemu_binary = utils_misc.get_qemu_binary(params)
    if not utils_misc.qemu_has_option("qmp", qemu_binary):
        raise error.TestNAError("Host qemu does not support qmp.")

    vm = env.get_vm(params["main_vm"])
    vm.verify_alive()

    session = vm.wait_for_login(timeout=int(params.get("login_timeout", 360)))

    module = params.get("modprobe_module")
    if module:
        error.context("modprobe the module %s" % module, logging.info)
        session.cmd("modprobe %s" % module)

    qmp_ports = vm.get_monitors_by_type('qmp')
    if qmp_ports:
        qmp_port = qmp_ports[0]
    else:
        raise error.TestError("Incorrect configuration, no QMP monitor found.")
    hmp_ports = vm.get_monitors_by_type('human')
    if hmp_ports:
        hmp_port = hmp_ports[0]
    else:
        raise error.TestError("Incorrect configuration, no QMP monitor found.")
    callback = {"host_cmd": utils.system_output,
                "guest_cmd": session.get_command_output,
                "monitor_cmd": hmp_port.send_args_cmd,
                "qmp_cmd": qmp_port.send_args_cmd}

    def send_cmd(cmd):
        """ Helper to execute command on ssh/host/monitor """
        if cmd_type in callback.keys():
            return callback[cmd_type](cmd)
        else:
            raise error.TestError("cmd_type is not supported")

    pre_cmd = params.get("pre_cmd")
    qmp_cmd = params.get("qmp_cmd")
    cmd_type = params.get("event_cmd_type")
    post_cmd = params.get("post_cmd")
    result_check = params.get("cmd_result_check")
    cmd_return_value = params.get("cmd_return_value")
    exception_list = params.get("exception_list", "")

    # HOOKs
    if result_check == 'qmp_cpu':
        pre_cmd = "cpu index=%d" % (int(params['smp']) - 1)

    # Pre command
    if pre_cmd is not None:
        error.context("Run prepare command '%s'." % pre_cmd, logging.info)
        pre_o = send_cmd(pre_cmd)
        logging.debug("Pre-command: '%s'\n Output: '%s'", pre_cmd, pre_o)
    try:
        # Testing command
        error.context("Run qmp command '%s'." % qmp_cmd, logging.info)
        output = qmp_port.send_args_cmd(qmp_cmd)
        logging.debug("QMP command: '%s' \n Output: '%s'", qmp_cmd, output)
    except qemu_monitor.QMPCmdError, err:
        if params.get("negative_test") == 'yes':
            logging.debug("Negative QMP command: '%s'\n output:'%s'", qmp_cmd,
                          err)
            if params.get("negative_check_pattern"):
                check_pattern = params.get("negative_check_pattern")
                if check_pattern not in str(err):
                    raise error.TestFail("'%s' not in exception '%s'"
                                         % (check_pattern, err))
        else:
            raise error.TestFail(err)
    except qemu_monitor.MonitorProtocolError, err:
        raise error.TestFail(err)
    except Exception, err:
        raise error.TestFail(err)

    # Post command
    if post_cmd is not None:
        error.context("Run post command '%s'." % post_cmd, logging.info)
        post_o = send_cmd(post_cmd)
        logging.debug("Post-command: '%s'\n Output: '%s'", post_cmd, post_o)

    if result_check is not None:
        txt = "Verify that qmp command '%s' works as designed." % qmp_cmd
        error.context(txt, logging.info)
        if result_check == 'qmp_cpu':
            qmp_cpu_check(output)
        elif result_check == "equal" or result_check == "contain":
            check_result(output, cmd_return_value, exception_list)
        elif result_check == "m_format_q":
            check_result(output, cmd_return_value, exception_list)
        elif 'post' in result_check:
            result_check = result_check.split('_', 1)[1]
            check_result(post_o, cmd_return_value, exception_list)
        else:
            check_result(output, post_o, exception_list)
    session.close()
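
The check_result helper above dispatches on result_check; the following is a reduced, self-contained sketch of just the 'equal'/'contain'/'not_contain' branches, using plain strings and dicts instead of real QMP output objects. It is a hypothetical simplification, not the test's actual helper.

import logging

logging.basicConfig(level=logging.INFO)


def check_result(qmp_o, output, result_check, exception_list=""):
    """Simplified equal/contain/not_contain dispatch from the test above."""
    if result_check == "equal":
        if output != str(qmp_o):
            raise AssertionError("Expected '%s', got '%s'" % (output, qmp_o))
    elif result_check in ("contain", "not_contain"):
        for value in output.split(';'):
            value = value.strip()
            if value in exception_list:
                continue
            found = value in str(qmp_o)
            if result_check == "contain" and not found:
                raise AssertionError("'%s' not found in QMP output" % value)
            if result_check == "not_contain" and found:
                raise AssertionError("'%s' unexpectedly in QMP output" % value)
    logging.info("QMP output passed the '%s' check", result_check)


# usage: check that two expected substrings appear in a query-status reply
check_result({'status': 'running', 'singlestep': False},
             "running;singlestep", "contain")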

Example 11

Project: tp-qemu
Source File: rv_connect.py
View license
def launch_rv(client_vm, guest_vm, params):
    """
    Launches rv_binary with args based on the spice configuration
    inside client_session, in the background.
    remote-viewer will try to connect from vm1 to vm2.

    :param client_vm: client vm object
    :param guest_vm: guest vm object
    :param params: dictionary with the test parameters
    """
    rv_binary = params.get("rv_binary", "remote-viewer")
    rv_ld_library_path = params.get("rv_ld_library_path")
    display = params.get("display")

    proxy = params.get("spice_proxy", None)
    if proxy:
        try:
            socket.inet_aton(params.get("proxy_ip", None))
        except socket.error:
            raise error.TestNAError('Parameter proxy_ip not changed from default values')

    host_ip = utils_net.get_host_ip_address(params)
    host_port = None
    if guest_vm.get_spice_var("listening_addr") == "ipv6":
        host_ip = ("[" + utils_misc.convert_ipv4_to_ipv6(host_ip) +
                   "]")
    host_tls_port = None

    disable_audio = params.get("disable_audio", "no")
    full_screen = params.get("full_screen")

    check_spice_info = params.get("spice_info")
    ssltype = params.get("ssltype")
    test_type = params.get("test_type")

    # cmd var keeps final remote-viewer command line
    # to be executed on client
    cmd = rv_binary
    if client_vm.params.get("os_type") != "windows":
        cmd = cmd + " --display=:0.0"

    # If qemu_ticket is set, set the password
    #  of the VM using the qemu-monitor
    ticket = None
    ticket_send = params.get("spice_password_send")
    qemu_ticket = params.get("qemu_password")
    if qemu_ticket:
        guest_vm.monitor.cmd("set_password spice %s" % qemu_ticket)
        logging.info("Sending to qemu monitor: set_password spice %s"
                     % qemu_ticket)

    gencerts = params.get("gencerts")
    certdb = params.get("certdb")
    smartcard = params.get("smartcard")
    host_subj = None
    cacert = None

    rv_parameters_from = params.get("rv_parameters_from", "cmd")
    if rv_parameters_from == 'file':
        cmd += " ~/rv_file.vv"

    client_session = client_vm.wait_for_login(
        timeout=int(params.get("login_timeout", 360)))

    if display == "spice":

        ticket = guest_vm.get_spice_var("spice_password")

        if guest_vm.get_spice_var("spice_ssl") == "yes":

            # client needs cacert file
            cacert = "%s/%s" % (guest_vm.get_spice_var("spice_x509_prefix"),
                                guest_vm.get_spice_var("spice_x509_cacert_file"))
            client_session.cmd("rm -rf %s && mkdir -p %s" % (
                               guest_vm.get_spice_var("spice_x509_prefix"),
                               guest_vm.get_spice_var("spice_x509_prefix")))
            remote.copy_files_to(client_vm.get_address(), 'scp',
                                 params.get("username"),
                                 params.get("password"),
                                 params.get("shell_port"),
                                 cacert, cacert)

            host_tls_port = guest_vm.get_spice_var("spice_tls_port")
            host_port = guest_vm.get_spice_var("spice_port")

            # cacert subj is in the format used to create the certificate
            # (with '/' delimiter); remote-viewer needs a ',' delimiter,
            # and the first character ('/') must be removed
            host_subj = guest_vm.get_spice_var("spice_x509_server_subj")
            host_subj = host_subj.replace('/', ',')[1:]
            if ssltype == "invalid_explicit_hs":
                host_subj = "Invalid Explicit HS"
            else:
                host_subj += host_ip

            # If it's invalid implicit, a remote-viewer connection
            # will be attempted with the hostname, since ssl certs were
            # generated with the ip address
            hostname = socket.gethostname()
            if ssltype == "invalid_implicit_hs":
                spice_url = " spice://%s?tls-port=%s\&port=%s" % (hostname,
                                                                  host_tls_port,
                                                                  host_port)
            else:
                spice_url = " spice://%s?tls-port=%s\&port=%s" % (host_ip,
                                                                  host_tls_port,
                                                                  host_port)

            if rv_parameters_from == "menu":
                line = spice_url
            elif rv_parameters_from == "file":
                pass
            else:
                cmd += spice_url

            if not rv_parameters_from == "file":
                cmd += " --spice-ca-file=%s" % cacert

            if (params.get("spice_client_host_subject") == "yes" and not
                    rv_parameters_from == "file"):
                cmd += " --spice-host-subject=\"%s\"" % host_subj

        else:
            host_port = guest_vm.get_spice_var("spice_port")
            if rv_parameters_from == "menu":
                # line to be sent through monitor once r-v is started
                # without spice url
                line = "spice://%s?port=%s" % (host_ip, host_port)
            elif rv_parameters_from == "file":
                pass
            else:
                cmd += " spice://%s?port=%s" % (host_ip, host_port)

    elif display == "vnc":
        raise NotImplementedError("remote-viewer vnc")

    else:
        raise Exception("Unsupported display value")

    # Check to see if the test is using the full screen option.
    if full_screen == "yes" and not rv_parameters_from == "file":
        logging.info("Remote Viewer Set to use Full Screen")
        cmd += " --full-screen"

    if disable_audio == "yes":
        logging.info("Remote Viewer Set to disable audio")
        cmd += " --spice-disable-audio"

    # Check to see if the test is using a smartcard.
    if smartcard == "yes":
        logging.info("remote viewer Set to use a smartcard")
        if not rv_parameters_from == "file":
            cmd += " --spice-smartcard"

        if certdb is not None:
            logging.debug("Remote Viewer set to use the following certificate"
                          " database: " + certdb)
            cmd += " --spice-smartcard-db " + certdb

        if gencerts is not None:
            logging.debug("Remote Viewer set to use the following certs: " +
                          gencerts)
            cmd += " --spice-smartcard-certificates " + gencerts

    if client_vm.params.get("os_type") == "linux":
        cmd = "nohup " + cmd + " &> /dev/null &"  # Launch it on background
        if rv_ld_library_path:
            cmd = "export LD_LIBRARY_PATH=" + rv_ld_library_path + ";" + cmd

    if rv_parameters_from == "file":
        print "Generating file"
        utils_spice.gen_rv_file(params, guest_vm, host_subj, cacert)
        print "Uploading file to client"
        client_vm.copy_files_to("rv_file.vv", "~/rv_file.vv")

    # Launching the actual set of commands
    try:
        if rv_ld_library_path:
            print_rv_version(client_session, "LD_LIBRARY_PATH=/usr/local/lib " + rv_binary)
        else:
            print_rv_version(client_session, rv_binary)

    except (ShellStatusError, ShellProcessTerminatedError):
        # Sometimes it fails with a status error; ignore it and continue.
        # It's not that important to have the printed versions in the log.
        logging.debug("Ignoring a Status Exception that occurs from calling "
                      "print versions of remote-viewer or spice-gtk")

    logging.info("Launching %s on the client (virtual)", cmd)

    if proxy:
        if "http" in proxy:
            split = proxy.split('//')[1].split(':')
        else:
            split = proxy.split(':')
        host_ip = split[0]
        if len(split) > 1:
            host_port = split[1]
        else:
            host_port = "3128"
        if rv_parameters_from != "file":
            client_session.cmd("export SPICE_PROXY=%s" % proxy)

    if not params.get("rv_verify") == "only":
        try:
            client_session.cmd(cmd)
        except ShellStatusError:
            logging.debug("Ignoring a status exception, will check connection"
                          "of remote-viewer later")

        # Send command line through monitor since url was not provided
        if rv_parameters_from == "menu":
            utils_spice.wait_timeout(1)
            str_input(client_vm, line)

        # The client waits for user entry (authentication) if spice_password
        # is set. Use the qemu monitor password if set; otherwise, if set,
        # try the normal password.
        if qemu_ticket:
            # Wait for remote-viewer to launch
            utils_spice.wait_timeout(5)
            str_input(client_vm, qemu_ticket)
        elif ticket:
            if ticket_send:
                ticket = ticket_send

            utils_spice.wait_timeout(5)  # Wait for remote-viewer to launch
            str_input(client_vm, ticket)

        utils_spice.wait_timeout(5)  # Wait for the connection to be established

    is_rv_connected = True

    try:
        utils_spice.verify_established(client_vm, host_ip,
                                       host_port, rv_binary,
                                       host_tls_port,
                                       params.get("spice_secure_channels",
                                                  None))
    except utils_spice.RVConnectError:
        if test_type == "negative":
            logging.info("remote-viewer connection failed as expected")
            if ssltype in ("invalid_implicit_hs", "invalid_explicit_hs"):
                # Check the qemu process output to verify what is expected
                qemulog = guest_vm.process.get_output()
                if "SSL_accept failed" in qemulog:
                    return
                else:
                    raise error.TestFail("SSL_accept failed not shown in qemu" +
                                         "process as expected.")
            is_rv_connected = False
        else:
            raise error.TestFail("remote-viewer connection failed")

    if test_type == "negative" and is_rv_connected:
        raise error.TestFail("remote-viewer connection was established when" +
                             " it was supposed to be unsuccessful")

    # Get spice info
    output = guest_vm.monitor.cmd("info spice")
    logging.debug("INFO SPICE")
    logging.debug(output)

    # Check to see if ipv6 address is reported back from qemu monitor
    if (check_spice_info == "ipv6"):
        logging.info("Test to check if ipv6 address is reported"
                     " back from the qemu monitor")
        # Remove brackets from ipv6 host ip
        if (host_ip[1:len(host_ip) - 1] in output):
            logging.info("Reported ipv6 address found in output from"
                         " 'info spice'")
        else:
            raise error.TestFail("ipv6 address not found from qemu monitor"
                                 " command: 'info spice'")
    else:
        logging.info("Not checking the value of 'info spice'"
                     " from the qemu monitor")

    # Prevent remote-viewer from being killed after the test finishes
    if client_vm.params.get("os_type") == "linux":
        cmd = "disown -ar"
    client_session.cmd_output(cmd)
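
When a SPICE proxy is configured, launch_rv above re-derives the host and port to verify against from the proxy string rather than from the guest. The sketch below isolates that parsing, assuming proxies written as 'http://host:port' or 'host[:port]' and the same 3128 default used in the example.

import logging

logging.basicConfig(level=logging.INFO)


def split_proxy(proxy):
    """Return (host, port) from a proxy string, defaulting the port to 3128."""
    if "http" in proxy:
        parts = proxy.split('//')[1].split(':')
    else:
        parts = proxy.split(':')
    host = parts[0]
    port = parts[1] if len(parts) > 1 else "3128"
    logging.info("Connection will be verified against proxy %s:%s", host, port)
    return host, port


# usage
split_proxy("http://proxy.example.com:8080")   # -> ("proxy.example.com", "8080")
split_proxy("10.0.0.1")                        # -> ("10.0.0.1", "3128")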

Example 12

Project: virt-test
Source File: virt.py
View license
    def run_once(self, params):
        # Convert params to a Params object
        params = utils_params.Params(params)

        # If a dependency test prior to this test has failed, let's fail
        # it right away as TestNA.
        if params.get("dependency_failed") == 'yes':
            raise error.TestNAError("Test dependency failed")

        # Report virt test version
        logging.info(version.get_pretty_version_info())
        # Report the parameters we've received and write them as keyvals
        logging.debug("Test parameters:")
        keys = params.keys()
        keys.sort()
        for key in keys:
            logging.debug("    %s = %s", key, params[key])
            self.write_test_keyval({key: params[key]})

        # Set the log file dir for the logging mechanism used by kvm_subprocess
        # (this must be done before unpickling env)
        utils_misc.set_log_file_dir(self.debugdir)

        # Open the environment file
        custom_env_path = params.get("custom_env_path", "")
        if custom_env_path:
            env_path = custom_env_path
        else:
            env_path = params.get("vm_type")
        env_filename = os.path.join(self.bindir, "backends", env_path,
                                    params.get("env", "env"))
        env = utils_env.Env(env_filename, self.env_version)
        other_subtests_dirs = params.get("other_tests_dirs", "")

        test_passed = False
        t_type = None

        try:
            try:
                try:
                    subtest_dirs = []
                    bin_dir = self.bindir

                    for d in other_subtests_dirs.split():
                        # Replace split char.
                        d = os.path.join(*d.split("/"))
                        subtestdir = os.path.join(bin_dir, d, "tests")
                        if not os.path.isdir(subtestdir):
                            raise error.TestError("Directory %s not"
                                                  " exist." % (subtestdir))
                        subtest_dirs += data_dir.SubdirList(subtestdir,
                                                            bootstrap.test_filter)

                    # Verify that we have the corresponding source file for it
                    for generic_subdir in asset.get_test_provider_subdirs('generic'):
                        subtest_dirs += data_dir.SubdirList(generic_subdir,
                                                            bootstrap.test_filter)

                    for specific_subdir in asset.get_test_provider_subdirs(params.get("vm_type")):
                        subtest_dirs += data_dir.SubdirList(specific_subdir,
                                                            bootstrap.test_filter)

                    subtest_dir = None

                    # Get the test routine corresponding to the specified
                    # test type
                    logging.debug("Searching for test modules that match "
                                  "'type = %s' and 'provider = %s' "
                                  "on this cartesian dict",
                                  params.get("type"), params.get("provider", None))

                    t_types = params.get("type").split()
                    provider = params.get("provider", None)
                    if provider is not None:
                        subtest_dirs = [d for d in subtest_dirs if provider in d]
                    # Make sure we can load provider_lib in tests
                    for s in subtest_dirs:
                        if os.path.dirname(s) not in sys.path:
                            sys.path.insert(0, os.path.dirname(s))

                    test_modules = {}
                    for t_type in t_types:
                        for d in subtest_dirs:
                            module_path = os.path.join(d, "%s.py" % t_type)
                            if os.path.isfile(module_path):
                                subtest_dir = d
                                break
                        if subtest_dir is None:
                            msg = ("Could not find test file %s.py on tests"
                                   "dirs %s" % (t_type, subtest_dirs))
                            raise error.TestError(msg)
                        # Load the test module
                        f, p, d = imp.find_module(t_type, [subtest_dir])
                        test_modules[t_type] = imp.load_module(t_type, f, p, d)
                        f.close()

                    # Preprocess
                    try:
                        params = env_process.preprocess(self, params, env)
                    finally:
                        env.save()

                    # Run the test function
                    for t_type in t_types:
                        test_module = test_modules[t_type]
                        run_func = utils_misc.get_test_entrypoint_func(
                            t_type, test_module)
                        try:
                            run_func(self, params, env)
                            self.verify_background_errors()
                        finally:
                            env.save()
                    test_passed = True
                    error_message = funcatexit.run_exitfuncs(env, t_type)
                    if error_message:
                        raise error.TestWarn("funcatexit failed with: %s"
                                             % error_message)

                except Exception, e:
                    if t_type is not None:
                        error_message = funcatexit.run_exitfuncs(env, t_type)
                        if error_message:
                            logging.error(error_message)
                    logging.error("Test failed: %s: %s",
                                  e.__class__.__name__, e)
                    try:
                        env_process.postprocess_on_error(
                            self, params, env)
                    finally:
                        env.save()
                    raise

            finally:
                # Postprocess
                try:
                    try:
                        env_process.postprocess(self, params, env)
                    except Exception, e:
                        if test_passed:
                            raise
                        logging.error("Exception raised during "
                                      "postprocessing: %s", e)
                finally:
                    env.save()

        except Exception, e:
            if params.get("abort_on_error") != "yes":
                raise
            # Abort on error
            logging.info("Aborting job (%s)", e)
            if params.get("vm_type") == "qemu":
                for vm in env.get_all_vms():
                    if vm.is_dead():
                        continue
                    logging.info("VM '%s' is alive.", vm.name)
                    for m in vm.monitors:
                        logging.info(
                            "'%s' has a %s monitor unix socket at: %s",
                            vm.name, m.protocol, m.filename)
                    logging.info(
                        "The command line used to start '%s' was:\n%s",
                        vm.name, vm.make_qemu_command())
                raise error.JobError("Abort requested (%s)" % e)

Example 13

Project: virt-test
Source File: standalone_test.py
View license
    def run_once(self):
        params = self.params

        # If a dependency test prior to this test has failed, let's fail
        # it right away as TestNA.
        if params.get("dependency_failed") == 'yes':
            raise error.TestNAError("Test dependency failed")

        # Report virt test version
        logging.info(version.get_pretty_version_info())
        # Report the parameters we've received and write them as keyvals
        logging.info("Starting test %s", self.tag)
        logging.debug("Test parameters:")
        keys = params.keys()
        keys.sort()
        for key in keys:
            logging.debug("    %s = %s", key, params[key])

        # Warn of this special condition in related location in output & logs
        if os.getuid() == 0 and params.get('nettype', 'user') == 'user':
            logging.warning("")
            logging.warning("Testing with nettype='user' while running "
                            "as root may produce unexpected results!!!")
            logging.warning("")

        # Open the environment file
        env_filename = os.path.join(
            data_dir.get_backend_dir(params.get("vm_type")),
            params.get("env", "env"))
        env = utils_env.Env(env_filename, self.env_version)

        test_passed = False
        t_types = None
        t_type = None

        try:
            try:
                try:
                    subtest_dirs = []

                    other_subtests_dirs = params.get("other_tests_dirs", "")
                    for d in other_subtests_dirs.split():
                        d = os.path.join(*d.split("/"))
                        subtestdir = os.path.join(self.bindir, d, "tests")
                        if not os.path.isdir(subtestdir):
                            raise error.TestError("Directory %s does not "
                                                  "exist" % (subtestdir))
                        subtest_dirs += data_dir.SubdirList(subtestdir,
                                                            bootstrap.test_filter)

                    provider = params.get("provider", None)

                    if provider is None:
                        # Verify that we have the corresponding source file for it
                        for generic_subdir in asset.get_test_provider_subdirs('generic'):
                            subtest_dirs += data_dir.SubdirList(generic_subdir,
                                                                bootstrap.test_filter)

                        for specific_subdir in asset.get_test_provider_subdirs(params.get("vm_type")):
                            subtest_dirs += data_dir.SubdirList(specific_subdir,
                                                                bootstrap.test_filter)
                    else:
                        provider_info = asset.get_test_provider_info(provider)
                        for key in provider_info['backends']:
                            subtest_dirs += data_dir.SubdirList(
                                provider_info['backends'][key]['path'],
                                bootstrap.test_filter)

                    subtest_dir = None

                    # Get the test routine corresponding to the specified
                    # test type
                    logging.debug("Searching for test modules that match "
                                  "'type = %s' and 'provider = %s' "
                                  "on this cartesian dict",
                                  params.get("type"), params.get("provider", None))

                    t_types = params.get("type").split()
                    # Make sure we can load provider_lib in tests
                    for s in subtest_dirs:
                        if os.path.dirname(s) not in sys.path:
                            sys.path.insert(0, os.path.dirname(s))

                    test_modules = {}
                    for t_type in t_types:
                        for d in subtest_dirs:
                            module_path = os.path.join(d, "%s.py" % t_type)
                            if os.path.isfile(module_path):
                                logging.debug("Found subtest module %s",
                                              module_path)
                                subtest_dir = d
                                break
                        if subtest_dir is None:
                            msg = ("Could not find test file %s.py on test"
                                   "dirs %s" % (t_type, subtest_dirs))
                            raise error.TestError(msg)
                        # Load the test module
                        f, p, d = imp.find_module(t_type, [subtest_dir])
                        test_modules[t_type] = imp.load_module(t_type, f, p, d)
                        f.close()

                    # Preprocess
                    try:
                        params = env_process.preprocess(self, params, env)
                    finally:
                        env.save()

                    # Run the test function
                    for t_type in t_types:
                        test_module = test_modules[t_type]
                        run_func = utils_misc.get_test_entrypoint_func(
                            t_type, test_module)
                        try:
                            run_func(self, params, env)
                            self.verify_background_errors()
                        finally:
                            env.save()
                    test_passed = True
                    error_message = funcatexit.run_exitfuncs(env, t_type)
                    if error_message:
                        raise error.TestWarn("funcatexit failed with: %s"
                                             % error_message)

                except Exception, e:
                    if (t_type is not None):
                        error_message = funcatexit.run_exitfuncs(env, t_type)
                        if error_message:
                            logging.error(error_message)
                    try:
                        env_process.postprocess_on_error(self, params, env)
                    finally:
                        env.save()
                    raise

            finally:
                # Postprocess
                try:
                    try:
                        env_process.postprocess(self, params, env)
                    except Exception, e:
                        if test_passed:
                            raise
                        logging.error("Exception raised during "
                                      "postprocessing: %s", e)
                finally:
                    env.save()

        except Exception, e:
            if params.get("abort_on_error") != "yes":
                raise
            # Abort on error
            logging.info("Aborting job (%s)", e)
            if params.get("vm_type") == "qemu":
                for vm in env.get_all_vms():
                    if vm.is_dead():
                        continue
                    logging.info("VM '%s' is alive.", vm.name)
                    for m in vm.monitors:
                        logging.info("It has a %s monitor unix socket at: %s",
                                     m.protocol, m.filename)
                    logging.info("The command line used to start it was:\n%s",
                                 vm.make_qemu_command())
                raise error.JobError("Abort requested (%s)" % e)

        return test_passed
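
Both run_once variants start by reporting the virt-test version and dumping the parameters sorted by key at DEBUG level. A standalone sketch of that reporting loop, with a plain dict standing in for the cartesian Params object and a made-up test tag:

import logging

logging.basicConfig(level=logging.DEBUG)

# A plain dict standing in for the Params object used above.
params = {"vm_type": "qemu", "type": "boot", "env": "env"}

logging.info("Starting test %s", "boot.example")   # tag is a made-up example
logging.debug("Test parameters:")
for key in sorted(params.keys()):
    logging.debug("    %s = %s", key, params[key])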

Example 14

Project: virt-test
Source File: unattended_install.py
View license
@error.context_aware
def run(test, params, env):
    """
    Unattended install test:
    1) Starts a VM with an appropriated setup to start an unattended OS install.
    2) Wait until the install reports to the install watcher its end.

    :param test: QEMU test object.
    :param params: Dictionary with the test parameters.
    :param env: Dictionary with test environment.
    """
    @error.context_aware
    def copy_images():
        error.base_context("Copy image from NFS after installation failure")
        image_copy_on_error = params.get("image_copy_on_error", "no")
        if image_copy_on_error == "yes":
            logging.info("Running image_copy to copy pristine image from NFS.")
            try:
                error.context("Quit qemu-kvm before copying guest image")
                vm.monitor.quit()
            except Exception, e:
                logging.warn(e)
            from virttest import utils_test
            error.context("Copy image from NFS Server")
            utils_test.run_image_copy(test, params, env)

    src = params.get('images_good')
    base_dir = params.get("images_base_dir", data_dir.get_data_dir())
    dst = storage.get_image_filename(params, base_dir)
    if params.get("storage_type") == "iscsi":
        dd_cmd = "dd if=/dev/zero of=%s bs=1M count=1" % dst
        txt = "iscsi used, need destroy data in %s" % dst
        txt += " by command: %s" % dd_cmd
        logging.info(txt)
        utils.system(dd_cmd)
    image_name = os.path.basename(dst)
    mount_point = params.get("dst_dir")
    if mount_point and src:
        funcatexit.register(env, params.get("type"), copy_file_from_nfs, src,
                            dst, mount_point, image_name)

    vm = env.get_vm(params["main_vm"])
    local_dir = params.get("local_dir")
    if local_dir:
        local_dir = utils_misc.get_path(test.bindir, local_dir)
    else:
        local_dir = test.bindir
    if params.get("copy_to_local"):
        for param in params.get("copy_to_local").split():
            l_value = params.get(param)
            if l_value:
                need_copy = True
                nfs_link = utils_misc.get_path(test.bindir, l_value)
                i_name = os.path.basename(l_value)
                local_link = os.path.join(local_dir, i_name)
                if os.path.isfile(local_link):
                    file_hash = utils.hash_file(local_link, "md5")
                    expected_hash = utils.hash_file(nfs_link, "md5")
                    if file_hash == expected_hash:
                        need_copy = False
                if need_copy:
                    msg = "Copy %s to %s in local host." % (i_name, local_link)
                    error.context(msg, logging.info)
                    utils.get_file(nfs_link, local_link)
                    params[param] = local_link

    unattended_install_config = UnattendedInstallConfig(test, params, vm)
    unattended_install_config.setup()

    # params passed explicitly, because they may have been updated by
    # unattended install config code, such as when params['url'] == auto
    vm.create(params=params)

    post_finish_str = params.get("post_finish_str",
                                 "Post set up finished")
    install_timeout = int(params.get("install_timeout", 4800))

    migrate_background = params.get("migrate_background") == "yes"
    if migrate_background:
        mig_timeout = float(params.get("mig_timeout", "3600"))
        mig_protocol = params.get("migration_protocol", "tcp")

    logging.info("Waiting for installation to finish. Timeout set to %d s "
                 "(%d min)", install_timeout, install_timeout / 60)
    error.context("waiting for installation to finish")

    start_time = time.time()

    try:
        serial_name = vm.serial_ports[0]
    except IndexError:
        raise virt_vm.VMConfigMissingError(vm.name, "serial")

    log_file = utils_misc.get_path(test.debugdir,
                                   "serial-%s-%s.log" % (serial_name,
                                                         vm.name))
    logging.debug("Monitoring serial console log for completion message: %s",
                  log_file)
    serial_log_msg = ""
    serial_read_fails = 0

    # As the install process starts, we may need to collect information from
    # the image, so use the test case instead of this simple function in the
    # following code.
    if mount_point and src:
        funcatexit.unregister(env, params.get("type"), copy_file_from_nfs,
                              src, dst, mount_point, image_name)

    send_key_timeout = int(params.get("send_key_timeout", 60))
    while (time.time() - start_time) < install_timeout:
        try:
            vm.verify_alive()
            if (params.get("send_key_at_install") and
                    (time.time() - start_time) < send_key_timeout):
                vm.send_key(params.get("send_key_at_install"))
        # Due to a race condition, sometimes we might get a MonitorError
        # before the VM gracefully shuts down, so let's capture MonitorErrors.
        except (virt_vm.VMDeadError, qemu_monitor.MonitorError), e:
            if params.get("wait_no_ack", "no") == "yes":
                break
            else:
                # Print out the original exception before copying images.
                logging.error(e)
                copy_images()
                raise e

        try:
            test.verify_background_errors()
        except Exception, e:
            copy_images()
            raise e

        # To work around the try/except/finally limitation in old versions of Python
        try:
            serial_log_msg = open(log_file, 'r').read()
        except IOError:
            # Only make noise after several failed reads
            serial_read_fails += 1
            if serial_read_fails > 10:
                logging.warn("Can not read from serial log file after %d tries",
                             serial_read_fails)

        if (params.get("wait_no_ack", "no") == "no" and
                (post_finish_str in serial_log_msg)):
            break

        # Because libvirt automatically starts the guest after import,
        # we only need to wait for a successful login.
        if params.get("medium") == "import":
            try:
                vm.login()
                break
            except (remote.LoginError, Exception), e:
                pass

        if migrate_background:
            vm.migrate(timeout=mig_timeout, protocol=mig_protocol)
        else:
            time.sleep(1)
    else:
        logging.warn("Timeout elapsed while waiting for install to finish ")
        copy_images()
        raise error.TestFail("Timeout elapsed while waiting for install to "
                             "finish")

    logging.debug('cleaning up threads and mounts that may be active')
    global _url_auto_content_server_thread
    global _url_auto_content_server_thread_event
    if _url_auto_content_server_thread is not None:
        _url_auto_content_server_thread_event.set()
        _url_auto_content_server_thread.join(3)
        _url_auto_content_server_thread = None
        utils_disk.cleanup(unattended_install_config.cdrom_cd1_mount)

    global _unattended_server_thread
    global _unattended_server_thread_event
    if _unattended_server_thread is not None:
        _unattended_server_thread_event.set()
        _unattended_server_thread.join(3)
        _unattended_server_thread = None

    global _syslog_server_thread
    global _syslog_server_thread_event
    if _syslog_server_thread is not None:
        _syslog_server_thread_event.set()
        _syslog_server_thread.join(3)
        _syslog_server_thread = None

    time_elapsed = time.time() - start_time
    logging.info("Guest reported successful installation after %d s (%d min)",
                 time_elapsed, time_elapsed / 60)

    if params.get("shutdown_cleanly", "yes") == "yes":
        shutdown_cleanly_timeout = int(params.get("shutdown_cleanly_timeout",
                                                  120))
        logging.info("Wait for guest to shutdown cleanly")
        if params.get("medium", "cdrom") == "import":
            vm.shutdown()
        try:
            if utils_misc.wait_for(vm.is_dead, shutdown_cleanly_timeout, 1, 1):
                logging.info("Guest managed to shutdown cleanly")
        except qemu_monitor.MonitorError, e:
            logging.warning("Guest apparently shut down, but got a "
                            "monitor error: %s", e)

Example 15

Project: avocado-vt
Source File: test.py
View license
    def _runTest(self):
        params = self.params

        # If a dependency test prior to this test has failed, let's fail
        # it right away as TestNA.
        if params.get("dependency_failed") == 'yes':
            raise error.TestNAError("Test dependency failed")

        # Report virt test version
        logging.info(version.get_pretty_version_info())
        # Report the parameters we've received and write them as keyvals
        logging.debug("Test parameters:")
        keys = params.keys()
        keys.sort()
        for key in keys:
            logging.debug("    %s = %s", key, params[key])

        # Warn about this special condition in the relevant locations in output & logs
        if os.getuid() == 0 and params.get('nettype', 'user') == 'user':
            logging.warning("")
            logging.warning("Testing with nettype='user' while running "
                            "as root may produce unexpected results!!!")
            logging.warning("")

        # Find the test
        subtest_dirs = []
        test_filter = bootstrap.test_filter

        other_subtests_dirs = params.get("other_tests_dirs", "")
        for d in other_subtests_dirs.split():
            d = os.path.join(*d.split("/"))
            subtestdir = os.path.join(self.bindir, d, "tests")
            if not os.path.isdir(subtestdir):
                raise error.TestError("Directory %s does not "
                                      "exist" % subtestdir)
            subtest_dirs += data_dir.SubdirList(subtestdir,
                                                test_filter)

        provider = params.get("provider", None)

        if provider is None:
            # Verify that we have the corresponding source file for it
            generic_subdirs = asset.get_test_provider_subdirs(
                'generic')
            for generic_subdir in generic_subdirs:
                subtest_dirs += data_dir.SubdirList(generic_subdir,
                                                    test_filter)
            specific_subdirs = asset.get_test_provider_subdirs(
                params.get("vm_type"))
            for specific_subdir in specific_subdirs:
                subtest_dirs += data_dir.SubdirList(
                    specific_subdir, bootstrap.test_filter)
        else:
            provider_info = asset.get_test_provider_info(provider)
            for key in provider_info['backends']:
                subtest_dirs += data_dir.SubdirList(
                    provider_info['backends'][key]['path'],
                    bootstrap.test_filter)

        subtest_dir = None

        # Get the test routine corresponding to the specified
        # test type
        logging.debug("Searching for test modules that match "
                      "'type = %s' and 'provider = %s' "
                      "on this cartesian dict",
                      params.get("type"),
                      params.get("provider", None))

        t_types = params.get("type").split()
        # Make sure we can load provider_lib in tests
        for s in subtest_dirs:
            if os.path.dirname(s) not in sys.path:
                sys.path.insert(0, os.path.dirname(s))

        test_modules = {}
        for t_type in t_types:
            for d in subtest_dirs:
                module_path = os.path.join(d, "%s.py" % t_type)
                if os.path.isfile(module_path):
                    logging.debug("Found subtest module %s",
                                  module_path)
                    subtest_dir = d
                    break
            if subtest_dir is None:
                msg = ("Could not find test file %s.py on test"
                       "dirs %s" % (t_type, subtest_dirs))
                raise error.TestError(msg)
            # Load the test module
            f, p, d = imp.find_module(t_type, [subtest_dir])
            test_modules[t_type] = imp.load_module(t_type, f, p, d)
            f.close()

        # TODO: the environment file is deprecated code, and should be removed
        # in future versions. Right now, it's being created on an Avocado temp
        # dir that is only persisted during the runtime of one job, which is
        # different from the original idea of the environment file (which was
        # to persist information across virt-test/avocado-vt job runs)
        env_filename = os.path.join(data_dir.get_tmp_dir(),
                                    params.get("env", "env"))
        env = utils_env.Env(env_filename, self.env_version)
        self.runner_queue.put({"func_at_exit": cleanup_env,
                               "args": (env_filename, self.env_version),
                               "once": True})

        test_passed = False
        t_type = None

        try:
            try:
                try:
                    # Preprocess
                    try:
                        params = env_process.preprocess(self, params, env)
                    finally:
                        self.__safe_env_save(env)

                    # Run the test function
                    for t_type in t_types:
                        test_module = test_modules[t_type]
                        run_func = utils_misc.get_test_entrypoint_func(
                            t_type, test_module)
                        try:
                            run_func(self, params, env)
                            self.verify_background_errors()
                        finally:
                            self.__safe_env_save(env)
                    test_passed = True
                    error_message = funcatexit.run_exitfuncs(env, t_type)
                    if error_message:
                        raise error.TestWarn("funcatexit failed with: %s" %
                                             error_message)

                except Exception:
                    if t_type is not None:
                        error_message = funcatexit.run_exitfuncs(env, t_type)
                        if error_message:
                            logging.error(error_message)
                    try:
                        env_process.postprocess_on_error(self, params, env)
                    finally:
                        self.__safe_env_save(env)
                    raise

            finally:
                # Postprocess
                try:
                    try:
                        params['test_passed'] = str(test_passed)
                        env_process.postprocess(self, params, env)
                    except Exception, e:
                        if test_passed:
                            raise
                        logging.error("Exception raised during "
                                      "postprocessing: %s", e)
                finally:
                    if self.__safe_env_save(env):
                        env.destroy()   # Force-clean as it can't be stored

        except Exception, e:
            if params.get("abort_on_error") != "yes":
                raise
            # Abort on error
            logging.info("Aborting job (%s)", e)
            if params.get("vm_type") == "qemu":
                for vm in env.get_all_vms():
                    if vm.is_dead():
                        continue
                    logging.info("VM '%s' is alive.", vm.name)
                    for m in vm.monitors:
                        logging.info("It has a %s monitor unix socket at: %s",
                                     m.protocol, m.filename)
                    logging.info("The command line used to start it was:\n%s",
                                 vm.make_create_command())
                raise error.JobError("Abort requested (%s)" % e)

        return test_passed
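
The logging-specific part of _runTest is the banner-plus-parameter dump at the top: the harness version goes out at INFO, while every cartesian parameter is echoed at DEBUG. A compact sketch of that idiom with illustrative names and values:

import logging

def log_params(params, version="unknown"):
    """Report a version banner at INFO and each test parameter at DEBUG."""
    logging.info("Test harness version: %s", version)
    logging.debug("Test parameters:")
    for key in sorted(params):
        logging.debug("    %s = %s", key, params[key])

logging.basicConfig(level=logging.DEBUG)
log_params({"type": "boot", "vm_type": "qemu"}, version="1.0")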

Example 16

Project: avocado-vt
Source File: virt.py
View license
    def run_once(self, params):
        # Convert params to a Params object
        params = utils_params.Params(params)

        # If a dependency test prior to this test has failed, let's fail
        # it right away as TestNA.
        if params.get("dependency_failed") == 'yes':
            raise error.TestNAError("Test dependency failed")

        # Report virt test version
        logging.info(version.get_pretty_version_info())
        # Report the parameters we've received and write them as keyvals
        logging.debug("Test parameters:")
        keys = params.keys()
        keys.sort()
        for key in keys:
            logging.debug("    %s = %s", key, params[key])
            self.write_test_keyval({key: params[key]})

        # Set the log file dir for the logging mechanism used by kvm_subprocess
        # (this must be done before unpickling env)
        utils_misc.set_log_file_dir(self.debugdir)

        # Open the environment file
        custom_env_path = params.get("custom_env_path", "")
        if custom_env_path:
            env_path = custom_env_path
        else:
            env_path = params.get("vm_type")
        env_filename = os.path.join(self.bindir, "backends", env_path,
                                    params.get("env", "env"))
        env = utils_env.Env(env_filename, self.env_version)
        other_subtests_dirs = params.get("other_tests_dirs", "")

        test_passed = False
        t_type = None

        try:
            try:
                try:
                    subtest_dirs = []
                    bin_dir = self.bindir

                    for d in other_subtests_dirs.split():
                        # Replace split char.
                        d = os.path.join(*d.split("/"))
                        subtestdir = os.path.join(bin_dir, d, "tests")
                        if not os.path.isdir(subtestdir):
                            raise error.TestError("Directory %s not"
                                                  " exist." % (subtestdir))
                        subtest_dirs += data_dir.SubdirList(subtestdir,
                                                            bootstrap.test_filter)

                    # Verify that we have the corresponding source file for it
                    for generic_subdir in asset.get_test_provider_subdirs('generic'):
                        subtest_dirs += data_dir.SubdirList(generic_subdir,
                                                            bootstrap.test_filter)

                    for specific_subdir in asset.get_test_provider_subdirs(params.get("vm_type")):
                        subtest_dirs += data_dir.SubdirList(specific_subdir,
                                                            bootstrap.test_filter)

                    subtest_dir = None

                    # Get the test routine corresponding to the specified
                    # test type
                    logging.debug("Searching for test modules that match "
                                  "'type = %s' and 'provider = %s' "
                                  "on this cartesian dict",
                                  params.get("type"), params.get("provider", None))

                    t_types = params.get("type").split()
                    provider = params.get("provider", None)
                    if provider is not None:
                        subtest_dirs = [
                            d for d in subtest_dirs if provider in d]
                    # Make sure we can load provider_lib in tests
                    for s in subtest_dirs:
                        if os.path.dirname(s) not in sys.path:
                            sys.path.insert(0, os.path.dirname(s))

                    test_modules = {}
                    for t_type in t_types:
                        for d in subtest_dirs:
                            module_path = os.path.join(d, "%s.py" % t_type)
                            if os.path.isfile(module_path):
                                subtest_dir = d
                                break
                        if subtest_dir is None:
                            msg = ("Could not find test file %s.py on tests"
                                   "dirs %s" % (t_type, subtest_dirs))
                            raise error.TestError(msg)
                        # Load the test module
                        f, p, d = imp.find_module(t_type, [subtest_dir])
                        test_modules[t_type] = imp.load_module(t_type, f, p, d)
                        f.close()

                    # Preprocess
                    try:
                        params = env_process.preprocess(self, params, env)
                    finally:
                        env.save()

                    # Run the test function
                    for t_type in t_types:
                        test_module = test_modules[t_type]
                        run_func = utils_misc.get_test_entrypoint_func(
                            t_type, test_module)
                        try:
                            run_func(self, params, env)
                            self.verify_background_errors()
                        finally:
                            env.save()
                    test_passed = True
                    error_message = funcatexit.run_exitfuncs(env, t_type)
                    if error_message:
                        raise error.TestWarn("funcatexit failed with: %s"
                                             % error_message)

                except Exception, e:
                    if t_type is not None:
                        error_message = funcatexit.run_exitfuncs(env, t_type)
                        if error_message:
                            logging.error(error_message)
                    logging.error("Test failed: %s: %s",
                                  e.__class__.__name__, e)
                    try:
                        env_process.postprocess_on_error(
                            self, params, env)
                    finally:
                        env.save()
                    raise

            finally:
                # Postprocess
                try:
                    try:
                        env_process.postprocess(self, params, env)
                    except Exception, e:
                        if test_passed:
                            raise
                        logging.error("Exception raised during "
                                      "postprocessing: %s", e)
                finally:
                    env.save()

        except Exception, e:
            if params.get("abort_on_error") != "yes":
                raise
            # Abort on error
            logging.info("Aborting job (%s)", e)
            if params.get("vm_type") == "qemu":
                for vm in env.get_all_vms():
                    if vm.is_dead():
                        continue
                    logging.info("VM '%s' is alive.", vm.name)
                    for m in vm.monitors:
                        logging.info(
                            "'%s' has a %s monitor unix socket at: %s",
                            vm.name, m.protocol, m.filename)
                    logging.info(
                        "The command line used to start '%s' was:\n%s",
                        vm.name, vm.make_create_command())
                raise error.JobError("Abort requested (%s)" % e)
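
run_once follows the same structure as the previous example; the detail worth isolating is how a failure is logged with its exception class before the error propagates, while cleanup still runs. A small sketch under assumed names (run_step_with_cleanup is ours, not avocado-vt API):

import logging

def run_step_with_cleanup(step, cleanup):
    """Run `step`, log any failure with its exception class at ERROR, and
    always run `cleanup`, mirroring the nested try/finally chains above."""
    try:
        step()
    except Exception as exc:
        logging.error("Test failed: %s: %s", exc.__class__.__name__, exc)
        raise
    finally:
        cleanup()

logging.basicConfig(level=logging.INFO)
try:
    run_step_with_cleanup(lambda: 1 / 0, lambda: logging.info("env saved"))
except ZeroDivisionError as exc:
    logging.info("Aborting job (%s)", exc)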

Example 17

View license
    @commands.command_add('resetnetwork')
    def resetnetwork_cmd(self, data):
        """
        Configure network & hostname, restarting network.
        """

        os_mod = self.detect_os()
        if not os_mod:
            raise SystemError("Couldn't figure out my OS")

        xs_handle = pyxenstore.Handle()

        try:
            hostname = xs_handle.read(XENSTORE_HOSTNAME_PATH)
            logging.info('hostname: %r (from xenstore)' % hostname)
        except pyxenstore.NotFoundError:
            try:
                hostname = os_mod.network.get_hostname()
            except:
                hostname = _get_hostname()

            if not hostname:
                hostname = DEFAULT_HOSTNAME
            logging.info('hostname: %r (default)' % hostname)

        interfaces = []

        try:
            entries = xs_handle.entries(XENSTORE_INTERFACE_PATH)
        except pyxenstore.NotFoundError:
            entries = []

        for entry in entries:
            data = xs_handle.read(XENSTORE_INTERFACE_PATH + '/' + entry)
            data = anyjson.deserialize(data)
            interfaces.append(data)
            logging.info('interface %s: %r' % (entry, data))

        del xs_handle

        # Normalize interfaces data. It can come in a couple of different
        # (similar) formats, none of which are convenient.
        by_macaddr = dict([(mac, (up, name))
                           for name, up, mac in agentlib.get_interfaces()])

        config = {}

        for interface in interfaces:
            ifconfig = {}

            mac = interface.get('mac')
            if not mac:
                raise RuntimeError('No MAC found in config')

            # by_macaddr is keyed using lower case hexadecimal
            mac = mac.lower()

            ifconfig['mac'] = mac

            # 'label' used to be the method to determine which interface
            # this configuration applies to, but 'mac' is safer to use.
            # 'label' is now only used for printing a comment in the
            # generated configuration to make it easier to tell interfaces apart.
            if mac not in by_macaddr:
                raise RuntimeError('Unknown interface MAC %s' % mac)

            ifconfig['label'] = interface.get('label')

            up, ifname = by_macaddr[mac]

            # Record if the interface is up already
            ifconfig['up'] = up

            # List of IPv4 and IPv6 addresses
            ip4s = interface.get('ips', [])
            ip6s = interface.get('ip6s', [])
            if not ip4s and not ip6s:
                raise RuntimeError('No IPs found for interface')

            # Gateway (especially IPv6) can be tied to an interface
            gateway4 = interface.get('gateway')
            gateway6 = interface.get('gateway6')

            # Filter out any IPs that aren't enabled
            for ip in ip4s + ip6s:
                try:
                    ip['enabled'] = int(ip.get('enabled', 0))
                except ValueError:
                    raise RuntimeError("Invalid value %r for 'enabled' key" %
                                       ip.get('enabled'))

            ip4s = filter(lambda i: i['enabled'], ip4s)
            ip6s = filter(lambda i: i['enabled'], ip6s)

            # Validate and normalize IPv4 and IPv6 addresses
            for ip in ip4s:
                if 'ip' not in ip:
                    raise RuntimeError("Missing 'ip' key for IPv4 address")
                if 'netmask' not in ip:
                    raise RuntimeError("Missing 'netmask' key for IPv4 address")

                # Rename 'ip' to 'address' to be more specific
                ip['address'] = ip.pop('ip')
                ip['prefixlen'] = NETMASK_TO_PREFIXLEN[ip['netmask']]

            for ip in ip6s:
                if 'ip' not in ip and 'address' not in ip:
                    raise RuntimeError("Missing 'ip' or 'address' key for IPv6 address")
                if 'netmask' not in ip:
                    raise RuntimeError("Missing 'netmask' key for IPv6 address")

                if 'gateway' in ip:
                    # FIXME: Should we fail if gateway6 is already set?
                    gateway6 = ip.pop('gateway')

                # FIXME: Should we fail if both 'ip' and 'address' are
                # specified but differ?

                # Rename 'ip' to 'address' to be more specific
                if 'address' not in ip:
                    ip['address'] = ip.pop('ip')

                # Rename 'netmask' to 'prefixlen' to be more accurate
                ip['prefixlen'] = ip.pop('netmask')

            ifconfig['ip4s'] = ip4s
            ifconfig['ip6s'] = ip6s

            ifconfig['gateway4'] = gateway4
            ifconfig['gateway6'] = gateway6

            # Routes are optional
            routes = interface.get('routes', [])

            # Validate and normalize routes
            for route in routes:
                if 'route' not in route:
                    raise RuntimeError("Missing 'route' key for route")
                if 'netmask' not in route:
                    raise RuntimeError("Missing 'netmask' key for route")
                if 'gateway' not in route:
                    raise RuntimeError("Missing 'gateway' key for route")

                # Rename 'route' to 'network' to be more specific
                route['network'] = route.pop('route')
                route['prefixlen'] = NETMASK_TO_PREFIXLEN[route['netmask']]

            ifconfig['routes'] = routes

            ifconfig['dns'] = interface.get('dns', [])

            config[ifname] = ifconfig

        # TODO: Should we fail if there isn't at least one gateway specified?
        #if not gateway4 and not gateway6:
        #    raise RuntimeError('No gateway found for public interface')

        return os_mod.network.configure_network(hostname, config)
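
This handler interpolates values into the message with the % operator before calling logging.info; the stdlib also accepts the format string and its arguments separately, deferring the formatting until the record is actually emitted. A toy comparison of the two spellings with made-up values:

import logging

logging.basicConfig(level=logging.INFO)

hostname = "guest-01"
interface = {"mac": "aa:bb:cc:dd:ee:ff", "ips": [{"ip": "10.0.0.2"}]}

# Eager: the message is built by % before logging.info is called.
logging.info('hostname: %r (from xenstore)' % hostname)

# Lazy: logging formats the message only if the record passes the level filter.
logging.info('interface %s: %r', 'eth0', interface)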

Example 18

Project: RMG-Py
Source File: pdep.py
View license
    def update(self, reactionModel, pdepSettings):
        """
        Regenerate the :math:`k(T,P)` values for this partial network if the
        network is marked as invalid.
        """
        from rmgpy.kinetics import Arrhenius, KineticsData, MultiArrhenius
        from rmgpy.pdep.collision import SingleExponentialDown
        from rmgpy.pdep.reaction import fitInterpolationModel
        
        # Get the parameters for the pressure dependence calculation
        job = pdepSettings
        job.network = self
        outputDirectory = pdepSettings.outputFile
        
        Tmin = job.Tmin.value_si
        Tmax = job.Tmax.value_si
        Pmin = job.Pmin.value_si
        Pmax = job.Pmax.value_si
        Tlist = job.Tlist.value_si
        Plist = job.Plist.value_si
        maximumGrainSize = job.maximumGrainSize.value_si if job.maximumGrainSize is not None else 0.0
        minimumGrainCount = job.minimumGrainCount
        method = job.method
        interpolationModel = job.interpolationModel
        activeJRotor = job.activeJRotor
        activeKRotor = job.activeKRotor
        rmgmode = job.rmgmode
        
        # Figure out which configurations are isomers, reactant channels, and product channels
        self.updateConfigurations(reactionModel)

        # Make sure we have high-P kinetics for all path reactions
        for rxn in self.pathReactions:
            if rxn.kinetics is None and rxn.reverse.kinetics is None:
                raise PressureDependenceError('Path reaction {0} with no high-pressure-limit kinetics encountered in PDepNetwork #{1:d}.'.format(rxn, self.index))
            elif rxn.kinetics is not None and rxn.kinetics.isPressureDependent():
                raise PressureDependenceError('Pressure-dependent kinetics encountered for path reaction {0} in PDepNetwork #{1:d}.'.format(rxn, self.index))
        
        # Do nothing if the network is already valid
        if self.valid: return
        # Do nothing if there are no explored wells
        if len(self.explored) == 0 and len(self.source) > 1: return
        # Log the network being updated
        logging.info("Updating {0:s}".format(self))

        # Generate states data for unimolecular isomers and reactants if necessary
        for isomer in self.isomers:
            spec = isomer.species[0]
            if not spec.hasStatMech(): spec.generateStatMech()
        for reactants in self.reactants:
            for spec in reactants.species:
                if not spec.hasStatMech(): spec.generateStatMech()
        # Also generate states data for any path reaction reactants, so we can
        # always apply the ILT method in the direction the kinetics are known
        for reaction in self.pathReactions:
            for spec in reaction.reactants:
                if not spec.hasStatMech(): spec.generateStatMech()
        # While we don't need the frequencies for product channels, we do need
        # the E0, so create a conformer object with the E0 for the product
        # channel species if necessary
        for products in self.products:
            for spec in products.species:
                if spec.conformer is None:
                    spec.conformer = Conformer(E0=spec.getThermoData().E0)
        
        # Determine transition state energies on potential energy surface
        # In the absence of any better information, we simply set it to
        # be the reactant ground-state energy + the activation energy
        # Note that we need Arrhenius kinetics in order to do this
        for rxn in self.pathReactions:
            if rxn.kinetics is None:
                raise Exception('Path reaction "{0}" in PDepNetwork #{1:d} has no kinetics!'.format(rxn, self.index))
            elif isinstance(rxn.kinetics, KineticsData):
                if len(rxn.reactants) == 1:
                    kunits = 's^-1'
                elif len(rxn.reactants) == 2:
                    kunits = 'm^3/(mol*s)'
                elif len(rxn.reactants) == 3:
                    kunits = 'm^6/(mol^2*s)'
                else:
                    kunits = ''
                rxn.kinetics = Arrhenius().fitToData(Tlist=rxn.kinetics.Tdata.value_si, klist=rxn.kinetics.kdata.value_si, kunits=kunits)
            elif isinstance(rxn.kinetics, MultiArrhenius):
                logging.info('Converting multiple kinetics to a single Arrhenius expression for reaction {rxn}'.format(rxn=rxn))
                rxn.kinetics = rxn.kinetics.toArrhenius(Tmin=Tmin, Tmax=Tmax)
            elif not isinstance(rxn.kinetics, Arrhenius):
                raise Exception('Path reaction "{0}" in PDepNetwork #{1:d} has invalid kinetics type "{2!s}".'.format(rxn, self.index, rxn.kinetics.__class__))
            rxn.fixBarrierHeight(forcePositive=True)
            E0 = sum([spec.conformer.E0.value_si for spec in rxn.reactants]) + rxn.kinetics.Ea.value_si
            rxn.transitionState = rmgpy.species.TransitionState(
                conformer = Conformer(E0=(E0*0.001,"kJ/mol")),
            )

        # Set collision model
        bathGas = [spec for spec in reactionModel.core.species if not spec.reactive]
        self.bathGas = {}
        for spec in bathGas:
            # is this really the only/best way to weight them? And what is alpha0?
            self.bathGas[spec] = 1.0 / len(bathGas)
            spec.collisionModel = SingleExponentialDown(alpha0=(4.86,'kcal/mol'))

        # Save input file
        if not self.label: self.label = str(self.index)
        job.saveInputFile(os.path.join(outputDirectory, 'pdep', 'network{0:d}_{1:d}.py'.format(self.index, len(self.isomers))))
        
        self.printSummary(level=logging.INFO)

        # Calculate the rate coefficients
        self.initialize(Tmin, Tmax, Pmin, Pmax, maximumGrainSize, minimumGrainCount, activeJRotor, activeKRotor, rmgmode)
        K = self.calculateRateCoefficients(Tlist, Plist, method)

        # Generate PDepReaction objects
        configurations = []
        configurations.extend([isom.species[:] for isom in self.isomers])
        configurations.extend([reactant.species[:] for reactant in self.reactants])
        configurations.extend([product.species[:] for product in self.products])
        j = configurations.index(self.source)

        for i in range(K.shape[2]):
            if i != j:
                # Find the path reaction
                netReaction = None
                for r in self.netReactions:
                    if r.hasTemplate(configurations[j], configurations[i]):
                        netReaction = r
                # If net reaction does not already exist, make a new one
                if netReaction is None:
                    netReaction = PDepReaction(
                        reactants=configurations[j],
                        products=configurations[i],
                        network=self,
                        kinetics=None
                    )
                    netReaction = reactionModel.makeNewPDepReaction(netReaction)
                    self.netReactions.append(netReaction)

                    # Place the net reaction in the core or edge if necessary
                    # Note that leak reactions are not placed in the edge
                    if all([s in reactionModel.core.species for s in netReaction.reactants]) and all([s in reactionModel.core.species for s in netReaction.products]):
                        reactionModel.addReactionToCore(netReaction)
                    else:
                        reactionModel.addReactionToEdge(netReaction)

                # Set/update the net reaction kinetics using interpolation model
                Tdata = job.Tlist.value_si
                Pdata = job.Plist.value_si
                kdata = K[:,:,i,j].copy()
                order = len(netReaction.reactants)
                kdata *= 1e6 ** (order-1)
                kunits = {1: 's^-1', 2: 'cm^3/(mol*s)', 3: 'cm^6/(mol^2*s)'}[order]
                netReaction.kinetics = job.fitInterpolationModel(Tlist, Plist, kdata, kunits)

                # Check: For each net reaction that has a path reaction, make
                # sure the k(T,P) values for the net reaction do not exceed
                # the k(T) values of the path reaction
                # Only check the k(T,P) value at the highest P and lowest T,
                # as this is the one most likely to be in the high-pressure 
                # limit
                t = 0; p = len(Plist) - 1
                for pathReaction in self.pathReactions:
                    if pathReaction.isIsomerization():
                        # Don't check isomerization reactions, since their
                        # k(T,P) values potentially contain both direct and
                        # well-skipping contributions, and therefore could be
                        # significantly larger than the direct k(T) value
                        # (This can also happen for association/dissociation
                        # reactions, but the effect is generally not too large)
                        continue
                    if pathReaction.reactants == netReaction.reactants and pathReaction.products == netReaction.products:
                        kinf = pathReaction.kinetics.getRateCoefficient(Tlist[t])
                        if K[t,p,i,j] > 2 * kinf: # To allow for a small discretization error
                            logging.warning('k(T,P) for net reaction {0} exceeds high-P k(T) by {1:g} at {2:g} K, {3:g} bar'.format(netReaction, K[t,p,i,j] / kinf, Tlist[t], Plist[p]/1e5))
                            logging.info('    k(T,P) = {0:9.2e}    k(T) = {1:9.2e}'.format(K[t,p,i,j], kinf))
                        break
                    elif pathReaction.products == netReaction.reactants and pathReaction.reactants == netReaction.products:
                        kinf = pathReaction.kinetics.getRateCoefficient(Tlist[t]) / pathReaction.getEquilibriumConstant(Tlist[t])
                        if K[t,p,i,j] > 2 * kinf: # To allow for a small discretization error
                            logging.warning('k(T,P) for net reaction {0} exceeds high-P k(T) by {1:g} at {2:g} K, {3:g} bar'.format(netReaction, K[t,p,i,j] / kinf, Tlist[t], Plist[p]/1e5))           
                            logging.info('    k(T,P) = {0:9.2e}    k(T) = {1:9.2e}'.format(K[t,p,i,j], kinf))
                        break
        
        # Delete intermediate arrays to conserve memory
        self.cleanup()
        
        # We're done processing this network, so mark it as valid
        self.valid = True
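
The consistency check at the end pairs a logging.warning for the anomaly with a logging.info carrying the numbers in fixed-width scientific notation. A reduced sketch of just that reporting step, with the chemistry stripped out and hypothetical inputs (check_rate and the factor-of-two tolerance are ours):

import logging

def check_rate(k_net, k_inf, tolerance=2.0):
    """Warn when a net rate coefficient exceeds its high-pressure limit,
    echoing the k(T,P) vs. k(T) check above."""
    if k_net > tolerance * k_inf:
        logging.warning('k(T,P) exceeds high-P k(T) by a factor of %g',
                        k_net / k_inf)
        logging.info('    k(T,P) = %9.2e    k(T) = %9.2e', k_net, k_inf)

logging.basicConfig(level=logging.INFO)
check_rate(3.2e7, 1.1e6)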

Example 19

Project: cmonkey2
Source File: stringdb.py
View license
def get_network_factory(organism_code, filename, weight, sep='\t',
                        normalized=False):
    """STRING network factory from preprocessed edge file
    (protein1, protein2, combined_score), scores are already
    normalized to 1000.
    This is the standard factory method used for Microbes.
    """
    def can_add_edge(node1, node2, thesaurus, cano_genes):
        """check whether we can add the edge
            deprecated on 2/18/15 by keep_node object in read_edges2
            In principle, this could be replaced by an object that
            stores the keep_node instance variable.
        """
        if cano_genes is not None:
            return (node1 in thesaurus and node2 in thesaurus
                    and thesaurus[node1] in cano_genes and thesaurus[node2] in cano_genes)
        else:
            return node1 in thesaurus and node2 in thesaurus

    def read_edges2(filename, organism, ratios):
        """just read a preprocessed file, much faster to debug"""
        logging.info("stringdb.read_edges2()")
        dfile = util.read_dfile(filename, sep)
        logging.info("Finished loading %s", filename)
        result = []
        max_score = 0.0
        thesaurus = organism.thesaurus()
        if ratios:
            gene_lut = {}
            for row_name in ratios.row_names:
                if row_name in thesaurus:
                    gene_lut[thesaurus[row_name]] = row_name
                gene_lut[row_name] = row_name  # A node should always map to itself
            cano_genes = gene_lut.keys()
        else:
            gene_lut = None
            cano_genes = None

        num_ignored = 0
        keep_node = {}  # Big speedup: used to search thesaurus and cano_genes only once per gene
        idx = 1  # Used to display progress
        total_nodes = 0
        nodes_not_in_thesaurus = 0
        nodes_not_in_cano_genes = 0

        for line in dfile.lines:
            # This can be slow, so display progress every 5%
            frac = idx % (len(dfile.lines)/20)
            idx += 1
            if frac == 0:
                logging.info("Processing network %d%%", round(100*float(idx)/len(dfile.lines)))

            node1 = patches.patch_string_gene(organism_code, line[0])
            node2 = patches.patch_string_gene(organism_code, line[1])
            for node in (node1, node2):
                if not node in keep_node:
                    if cano_genes is not None:
                        keep_node[node] = node in thesaurus and thesaurus[node] in cano_genes
                    else:
                        keep_node[node] = node in thesaurus
                    if not keep_node[node]:
                        if not node in thesaurus:
                            nodes_not_in_thesaurus += 1
                        elif not thesaurus[node] in cano_genes:
                            nodes_not_in_cano_genes += 1

                    # Add this node to the lut if it is not already there.
                    if (not gene_lut is None) and (not node in gene_lut):
                        gene_lut[node] = node
                        if node in thesaurus:
                            gene_lut[thesaurus[node]] = node
                total_nodes += 1

            score = float(line[2])
            max_score = max(score, max_score)

            if keep_node[node1] and keep_node[node2]:
                #2/18/15 SD.  Translate nodes into names in ratio rows using gene_lut
                #   This will let the ratios matrix define how the genes are named
                if gene_lut is None:
                    new_edge = (node1, node2, score)
                else:
                    new_edge = (gene_lut[node1], gene_lut[node2], score)
                #logging.info("Adding edge %s - %s - %f", new_edge[0], new_edge[1], new_edge[2])
                result.append(new_edge)
            else:
                num_ignored += 1

        # Warnings
        if nodes_not_in_thesaurus > 0:
            logging.warn('%d (out of %d) nodes not found in synonyms', nodes_not_in_thesaurus, total_nodes)
        if nodes_not_in_cano_genes > 0:
            logging.warn('%d (out of %d) nodes not found in canonical gene names', nodes_not_in_cano_genes, total_nodes)

        if not normalized:
            result = normalize_edges_to_max_score(result, max_score)

        logging.info("stringdb.read_edges2(), %d edges read, %d edges ignored",
                     len(result), num_ignored)

        return result

    def make_network(organism, ratios=None, check_size=False):
        """make network"""
        return network.Network.create("STRING",
                                      read_edges2(filename, organism, ratios),
                                      weight,
                                      organism, ratios)

    return make_network
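
The read loop above reports progress at roughly 5% intervals by testing the line index against a modulo step. A self-contained sketch of only the progress-logging part (process_edges is our name, and the max(..., 1) guard against inputs shorter than 20 lines is an addition):

import logging

def process_edges(lines):
    """Log progress roughly every 5% of the input, as read_edges2 does."""
    step = max(len(lines) // 20, 1)  # guard against very short inputs
    for idx, line in enumerate(lines, start=1):
        if idx % step == 0:
            logging.info("Processing network %d%%",
                         round(100.0 * idx / len(lines)))
        # parse `line` here

logging.basicConfig(level=logging.INFO)
process_edges(["protein1\tprotein2\t850"] * 100)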

Example 20

Project: ssbench
Source File: reporter.py
View license
    def calculate_scenario_stats(self, nth_pctile=95, format_numbers=True):
        """Compute various statistics from worker job result dicts.

        :param nth_pctile: Use this percentile when calculating the stats
        :param format_numbers: Should various floating-point numbers be
        formatted as strings or left full-precision floats
        :returns: A stats python dict which looks something like:
            SERIES_STATS = {
                'min': 1.1,
                'max': 1.1,
                'avg': 1.1,
                'std_dev': 1.1,
                'median': 1.1,
            }
            {
                'agg_stats': {
                    'worker_count': 1,
                    'start': 1.1,
                    'stop': 1.1,
                    'req_count': 1,
                    'retries': 0,
                    'errors' : 0,
                    'avg_req_per_sec': 1.1, # req_count / (stop - start)?
                    'retry_rate': 0.0,
                    'first_byte_latency': SERIES_STATS,
                    'last_byte_latency': SERIES_STATS,
                },
                'worker_stats': {
                    1: {  # keys are worker_ids
                        'start': 1.1,
                        'stop': 1.1,
                        'req_count': 1,
                        'retries': 0,
                        'retry_rate': 0.0,
                        'errors': 0,
                        'avg_req_per_sec': 1.1, # req_count / (stop - start)?
                        'first_byte_latency': SERIES_STATS,
                        'last_byte_latency': SERIES_STATS,
                    },
                    # ...
                },
                'op_stats': {
                    CREATE_OBJECT: { # keys are CRUD constants: CREATE_OBJECT, READ_OBJECT, etc.
                        'req_count': 1, # num requests of this CRUD type
                        'avg_req_per_sec': 1.1, # total_requests / sum(last_byte_latencies)
                        'first_byte_latency': SERIES_STATS,
                        'last_byte_latency': SERIES_STATS,
                        'size_stats': {
                            'small': { # keys are size_str values
                                'req_count': 1, # num requests of this type and size
                                'retries': 0, # num of retries
                                'avg_req_per_sec': 1.1, # total_requests / sum(last_byte_latencies)
                                'errors': 0,
                                'retry_rate': 0.0,
                                'first_byte_latency': SERIES_STATS,
                                'last_byte_latency': SERIES_STATS,
                            },
                            # ...
                        },
                    },
                    # ...
                },
                'size_stats': {
                    'small': { # keys are size_str values
                        'req_count': 1, # num requests of this size (for all CRUD types)
                        'retries': 0, # num of retries
                        'acutual_request_count': 1, # num requests includes retries
                        'avg_req_per_sec': 1.1, # total_requests / sum(last_byte_latencies)
                        'errors': 0,
                        'retry_rate': 0.0,
                        'first_byte_latency': SERIES_STATS,
                        'last_byte_latency': SERIES_STATS,
                    },
                    # ...
                },
                'time_series': {
                    'start': 1, # epoch time of first data point
                    'data': [
                        1, # number of requests finishing during this second
                        # ...
                    ],
                },
            }
        """
        # Each result looks like:
        # {
        #   'worker_id': 1,
        #   'type': 'get_object',
        #   'size': 4900000,
        #   'size_str': 'large',
        #   'first_byte_latency': 0.9137639999389648,
        #   'last_byte_latency': 0.913769006729126,
        #   'retries': 1
        #   'completed_at': 1324372892.360802,
        # }
        # OR
        # {
        #   'worker_id': 1,
        #   'type': 'get_object',
        #   'size_str': 'large'
        #   'completed_at': 1324372892.360802,
        #   'retries': 1
        #   'exception': '...',
        # }
        logging.info('Calculating statistics...')
        agg_stats = dict(start=2 ** 32, stop=0, req_count=0)
        op_stats = {}
        for crud_type in [ssbench.CREATE_OBJECT, ssbench.READ_OBJECT,
                          ssbench.UPDATE_OBJECT, ssbench.DELETE_OBJECT]:
            op_stats[crud_type] = dict(
                req_count=0, avg_req_per_sec=0,
                size_stats=OrderedDict.fromkeys(
                    self.scenario.sizes_by_name.keys()))

        req_completion_seconds = {}
        start_time = 0
        completion_time_max = 0
        completion_time_min = 2 ** 32
        stats = dict(
            nth_pctile=nth_pctile,
            agg_stats=agg_stats,
            worker_stats={},
            op_stats=op_stats,
            size_stats=OrderedDict.fromkeys(
                self.scenario.sizes_by_name.keys()))
        for results in self.unpacker:
            skipped = 0
            for result in results:
                try:
                    res_completed_at = result['completed_at']
                    res_completion_time = int(res_completed_at)
                    res_worker_id = result['worker_id']
                    res_type = result['type']
                    res_size_str = result['size_str']
                except KeyError as err:
                    logging.info('Skipped result with missing keys (%r): %r',
                                 err, result)
                    skipped += 1
                    continue

                try:
                    res_exception = result['exception']
                except KeyError:
                    try:
                        res_last_byte_latency = result['last_byte_latency']
                    except KeyError:
                        logging.info('Skipped result with missing'
                                     ' last_byte_latency key: %r',
                                     result)
                        skipped += 1
                        continue
                    if res_completion_time < completion_time_min:
                        completion_time_min = res_completion_time
                        start_time = (
                            res_completion_time - res_last_byte_latency)
                    if res_completion_time > completion_time_max:
                        completion_time_max = res_completion_time
                    req_completion_seconds[res_completion_time] = \
                        1 + req_completion_seconds.get(res_completion_time, 0)
                    result['start'] = res_completed_at - res_last_byte_latency
                else:
                    # report log exceptions
                    logging.warn('calculate_scenario_stats: exception from '
                                 'worker %d: %s',
                                 res_worker_id, res_exception)
                    try:
                        res_traceback = result['traceback']
                    except KeyError:
                        logging.warn('traceback missing')
                    else:
                        logging.info(res_traceback)

                # Stats per-worker
                if res_worker_id not in stats['worker_stats']:
                    stats['worker_stats'][res_worker_id] = {}
                self._add_result_to(stats['worker_stats'][res_worker_id],
                                    result)

                # Stats per-file-size
                try:
                    val = stats['size_stats'][res_size_str]
                except KeyError:
                    stats['size_stats'][res_size_str] = {}
                else:
                    if not val:
                        stats['size_stats'][res_size_str] = {}
                self._add_result_to(stats['size_stats'][res_size_str],
                                    result)

                self._add_result_to(agg_stats, result)

                type_stats = op_stats[res_type]
                self._add_result_to(type_stats, result)

                # Stats per-operation-per-file-size
                try:
                    val = type_stats['size_stats'][res_size_str]
                except KeyError:
                    type_stats['size_stats'][res_size_str] = {}
                else:
                    if not val:
                        type_stats['size_stats'][res_size_str] = {}
                self._add_result_to(
                    type_stats['size_stats'][res_size_str], result)
            if skipped > 0:
                logging.warn("Total number of results skipped: %d", skipped)

        agg_stats['worker_count'] = len(stats['worker_stats'].keys())
        self._compute_req_per_sec(agg_stats)
        self._compute_retry_rate(agg_stats)
        self._compute_latency_stats(agg_stats, nth_pctile, format_numbers)

        jobs_per_worker = []
        for worker_stats in stats['worker_stats'].values():
            jobs_per_worker.append(worker_stats['req_count'])
            self._compute_req_per_sec(worker_stats)
            self._compute_retry_rate(worker_stats)
            self._compute_latency_stats(worker_stats, nth_pctile,
                                        format_numbers)
        stats['jobs_per_worker_stats'] = self._series_stats(jobs_per_worker,
                                                            nth_pctile,
                                                            format_numbers)
        logging.debug('Jobs per worker stats:\n' +
                      pformat(stats['jobs_per_worker_stats']))

        for op_stats_dict in op_stats.itervalues():
            if op_stats_dict['req_count']:
                self._compute_req_per_sec(op_stats_dict)
                self._compute_retry_rate(op_stats_dict)
                self._compute_latency_stats(op_stats_dict, nth_pctile,
                                            format_numbers)
                for size_str, size_stats in \
                        op_stats_dict['size_stats'].iteritems():
                    if size_stats:
                        self._compute_req_per_sec(size_stats)
                        self._compute_retry_rate(size_stats)
                        self._compute_latency_stats(size_stats, nth_pctile,
                                                    format_numbers)
                    else:
                        op_stats_dict['size_stats'].pop(size_str)
        for size_str, size_stats in stats['size_stats'].iteritems():
            if size_stats:
                self._compute_req_per_sec(size_stats)
                self._compute_retry_rate(size_stats)
                self._compute_latency_stats(size_stats, nth_pctile,
                                            format_numbers)
            else:
                stats['size_stats'].pop(size_str)
        time_series_data = [req_completion_seconds.get(t, 0)
                            for t in range(completion_time_min,
                                           completion_time_max + 1)]
        stats['time_series'] = dict(start=completion_time_min,
                                    start_time=start_time,
                                    stop=completion_time_max,
                                    data=time_series_data)

        return stats
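
A recurring pattern here is logging each malformed result at INFO as it is skipped and then emitting a single WARNING with the total once the batch is done. A trimmed sketch of that pattern (filter_results and the required-key list are illustrative):

import logging

def filter_results(results, required=("worker_id", "type", "completed_at")):
    """Drop malformed result dicts, logging each skip at INFO and the total
    count at WARNING, in the spirit of calculate_scenario_stats."""
    kept, skipped = [], 0
    for result in results:
        missing = [key for key in required if key not in result]
        if missing:
            logging.info('Skipped result with missing keys (%r): %r',
                         missing, result)
            skipped += 1
            continue
        kept.append(result)
    if skipped:
        logging.warning('Total number of results skipped: %d', skipped)
    return kept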

Example 21

Project: astor
Source File: rtrip.py
View license
def convert(srctree, dsttree=dsttree, readonly=False, dumpall=False):
    """Walk the srctree, and convert/copy all python files
    into the dsttree

    """

    allow_ast_comparison()

    parse_file = code_to_ast.parse_file
    find_py_files = code_to_ast.find_py_files
    srctree = os.path.normpath(srctree)

    if not readonly:
        dsttree = os.path.normpath(dsttree)
        logging.info('')
        logging.info('Trashing ' + dsttree)
        shutil.rmtree(dsttree, True)

    unknown_src_nodes = set()
    unknown_dst_nodes = set()
    badfiles = set()
    broken = []
    # TODO: When issue #26 is resolved, remove UnicodeDecodeError
    handled_exceptions = SyntaxError, UnicodeDecodeError

    oldpath = None

    allfiles = find_py_files(srctree, None if readonly else dsttree)
    for srcpath, fname in allfiles:
        # Create destination directory
        if not readonly and srcpath != oldpath:
            oldpath = srcpath
            if srcpath >= srctree:
                dstpath = srcpath.replace(srctree, dsttree, 1)
                if not dstpath.startswith(dsttree):
                    raise ValueError("%s not a subdirectory of %s" %
                                     (dstpath, dsttree))
            else:
                assert srctree.startswith(srcpath)
                dstpath = dsttree
            os.makedirs(dstpath)

        srcfname = os.path.join(srcpath, fname)
        logging.info('Converting %s' % srcfname)
        try:
            srcast = parse_file(srcfname)
        except handled_exceptions:
            badfiles.add(srcfname)
            continue

        dsttxt = to_source(srcast)

        if not readonly:
            dstfname = os.path.join(dstpath, fname)
            try:
                with open(dstfname, 'w') as f:
                    f.write(dsttxt)
            except UnicodeEncodeError:
                badfiles.add(dstfname)

        # As a sanity check, make sure that ASTs themselves
        # round-trip OK
        try:
            dstast = ast.parse(dsttxt) if readonly else parse_file(dstfname)
        except SyntaxError:
            dstast = []
        unknown_src_nodes.update(strip_tree(srcast))
        unknown_dst_nodes.update(strip_tree(dstast))
        if dumpall or srcast != dstast:
            srcdump = dump_tree(srcast)
            dstdump = dump_tree(dstast)
            bad = srcdump != dstdump
            logging.warning('    calculating dump -- %s' %
                            ('bad' if bad else 'OK'))
            if bad:
                broken.append(srcfname)
            if dumpall or bad:
                if not readonly:
                    try:
                        with open(dstfname[:-3] + '.srcdmp', 'w') as f:
                            f.write(srcdump)
                    except UnicodeEncodeError:
                        badfiles.add(dstfname[:-3] + '.srcdmp')
                    try:
                        with open(dstfname[:-3] + '.dstdmp', 'w') as f:
                            f.write(dstdump)
                    except UnicodeEncodeError:
                        badfiles.add(dstfname[:-3] + '.dstdmp')
                elif dumpall:
                    sys.stdout.write('\n\nAST:\n\n    ')
                    sys.stdout.write(srcdump.replace('\n', '\n    '))
                    sys.stdout.write('\n\nDecompile:\n\n    ')
                    sys.stdout.write(dsttxt.replace('\n', '\n    '))
                    sys.stdout.write('\n\nNew AST:\n\n    ')
                    sys.stdout.write('(same as old)' if dstdump == srcdump
                                     else dstdump.replace('\n', '\n    '))
                    sys.stdout.write('\n')

    if badfiles:
        logging.warning('\nFiles not processed due to syntax errors:')
        for fname in sorted(badfiles):
            logging.warning('    %s' % fname)
    if broken:
        logging.warning('\nFiles failed to round-trip to AST:')
        for srcfname in broken:
            logging.warning('    %s' % srcfname)

    ok_to_strip = 'col_offset _precedence _use_parens lineno _p_op _pp'
    ok_to_strip = set(ok_to_strip.split())
    bad_nodes = (unknown_dst_nodes | unknown_src_nodes) - ok_to_strip
    if bad_nodes:
        logging.error('\nERROR -- UNKNOWN NODES STRIPPED: %s' % bad_nodes)
    logging.info('\n')
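
convert interleaves a per-file INFO message with a deferred WARNING block listing everything that failed, so the summary survives even in a long, noisy log. A trimmed sketch of that shape, with the AST round-trip replaced by a caller-supplied function (convert_all is our name):

import logging

def convert_all(paths, convert):
    """Log each conversion at INFO, collect failures, and report them in one
    WARNING block at the end, as rtrip.convert does."""
    failed = set()
    for path in paths:
        logging.info('Converting %s', path)
        try:
            convert(path)
        except SyntaxError:
            failed.add(path)
    if failed:
        logging.warning('Files not processed due to syntax errors:')
        for path in sorted(failed):
            logging.warning('    %s', path)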

Example 22

Project: radiotool
Source File: retarget.py
View license
def retarget(songs, duration, music_labels=None, out_labels=None,
             out_penalty=None, volume=None, volume_breakpoints=None,
             springs=None, constraints=None,
             min_beats=None, max_beats=None,
             fade_in_len=3.0, fade_out_len=5.0,
             **kwargs):
    """Retarget a song to a duration given input and output labels on
    the music.

    Suppose you like one section of a song, say, the guitar solo, and
    you want to create a three minute long version of the solo.
    Suppose the guitar solo occurs from the 150 second mark to the 200
    second mark in the original song.

    You can set the label the guitar solo with 'solo' and the rest of
    the song with 'other' by crafting the ``music_labels`` input
    function. And you can set the ``out_labels`` function to give you
    nothing but solo::

        def labels(t):
            if 150 < t < 200:
                return 'solo'
            return 'other'

        def target(t): return 'solo'

        song = Song("sweet-rock-song.wav")

        composition, info = retarget(song, 180,
            music_labels=labels, out_labels=target)

        composition.export(filename="super-long-solo")

    You can achieve much more complicated retargetings by adjusting
    the ``music_labels``, ``out_labels`` and ``out_penalty`` functions,
    but this should give you a basic sense of how to use the
    ``retarget`` function.

    :param songs: Song (or list of songs) to retarget
    :type songs: :py:class:`radiotool.composer.Song` or list
    :param duration: Duration of retargeted song (in seconds)
    :type duration: float
    :param music_labels: A function that takes a time (in seconds) and
        returns the label (str) of the input music at that time
    :type music_labels: function
    :param out_labels: A function that takes a time (in seconds) and
        returns the desired label (str) of the output music at that
        time
    :type out_labels: function
    :param out_penalty: A function that takes a time (in seconds) and
        returns the penalty for not matching the correct output label
        at that time (default is 1.0)
    :type out_penalty: function
    :returns: Composition of retargeted song, and dictionary of
        information about the retargeting
    :rtype: (:py:class:`radiotool.composer.Composition`, dict)
    """

    # get song analysis
    if isinstance(songs, Track):
        songs = [songs]
    multi_songs = len(songs) > 1

    analyses = [s.analysis for s in songs]

    # generate labels for every beat in the input and output
    beat_lengths = [a[BEAT_DUR_KEY] for a in analyses]
    beats = [a["beats"] for a in analyses]

    beat_length = np.mean(beat_lengths)
    logging.info("Beat lengths of songs: {} (mean: {})".
                 format(beat_lengths, beat_length))

    if out_labels is not None:
        target = [out_labels(i) for i in np.arange(0, duration, beat_length)]
    else:
        target = ["" for i in np.arange(0, duration, beat_length)]

    if music_labels is not None:
        if not multi_songs:
            music_labels = [music_labels]
            music_labels = [item for sublist in music_labels
                            for item in sublist]
        if len(music_labels) != len(songs):
            raise ArgumentException("Did not specify {} sets of music labels".
                                    format(len(songs)))
        start = [[music_labels[i](j) for j in b] for i, b in enumerate(beats)]
    else:
        start = [["" for i in b] for b in beats]

    if out_penalty is not None:
        pen = np.array([out_penalty(i) for i in np.arange(
            0, duration, beat_length)])
    else:
        pen = np.array([1 for i in np.arange(0, duration, beat_length)])

    # if we're using a valence/arousal constraint, we need these values
    in_vas = kwargs.pop('music_va', None)
    if in_vas is not None:
        if not multi_songs:
            in_vas = [in_vas]
            in_vas = [item for sublist in in_vas for item in sublist]
        if len(in_vas) != len(songs):
            raise ArgumentException("Did not specify {} sets of v/a labels".
                                    format(len(songs)))
        for i, in_va in enumerate(in_vas):
            if callable(in_va):
                in_va = np.array([in_va(j) for j in beats[i]])
            in_vas[i] = in_va

    target_va = kwargs.pop('out_va', None)
    if callable(target_va):
        target_va = np.array(
            [target_va(i) for i in np.arange(0, duration, beat_length)])

    # set constraints
    if constraints is None:
        min_pause_len = 20.
        max_pause_len = 35.
        min_pause_beats = int(np.ceil(min_pause_len / beat_length))
        max_pause_beats = int(np.floor(max_pause_len / beat_length))

        constraints = [(
            rt_constraints.PauseConstraint(
                min_pause_beats, max_pause_beats,
                to_penalty=1.4, between_penalty=.05, unit="beats"),
            rt_constraints.PauseEntryVAChangeConstraint(target_va, .005),
            rt_constraints.PauseExitVAChangeConstraint(target_va, .005),
            rt_constraints.TimbrePitchConstraint(
                context=0, timbre_weight=1.5, chroma_weight=1.5),
            rt_constraints.EnergyConstraint(penalty=0.5),
            rt_constraints.MinimumLoopConstraint(8),
            rt_constraints.ValenceArousalConstraint(
                in_va, target_va, pen * .125),
            rt_constraints.NoveltyVAConstraint(in_va, target_va, pen),
        ) for in_va in in_vas]
    else:
        max_pause_beats = 0
        if len(constraints) > 0:
            if isinstance(constraints[0], rt_constraints.Constraint):
                constraints = [constraints]

    pipelines = [rt_constraints.ConstraintPipeline(constraints=c_set)
                 for c_set in constraints]

    trans_costs = []
    penalties = []
    all_beat_names = []

    for i, song in enumerate(songs):
        (trans_cost, penalty, bn) = pipelines[i].apply(song, len(target))
        trans_costs.append(trans_cost)
        penalties.append(penalty)
        all_beat_names.append(bn)

    logging.info("Combining tables")
    total_music_beats = int(np.sum([len(b) for b in beats]))
    total_beats = total_music_beats + max_pause_beats

    # combine transition cost tables

    trans_cost = np.ones((total_beats, total_beats)) * np.inf
    sizes = [len(b) for b in beats]
    idx = 0
    for i, size in enumerate(sizes):
        trans_cost[idx:idx + size, idx:idx + size] =\
            trans_costs[i][:size, :size]
        idx += size

    trans_cost[:total_music_beats, total_music_beats:] =\
        np.vstack([tc[:len(beats[i]), len(beats[i]):]
                   for i, tc in enumerate(trans_costs)])

    trans_cost[total_music_beats:, :total_music_beats] =\
        np.hstack([tc[len(beats[i]):, :len(beats[i])]
                  for i, tc in enumerate(trans_costs)])

    trans_cost[total_music_beats:, total_music_beats:] =\
        trans_costs[0][len(beats[0]):, len(beats[0]):]

    # combine penalty tables
    penalty = np.empty((total_beats, penalties[0].shape[1]))

    penalty[:total_music_beats, :] =\
        np.vstack([p[:len(beats[i]), :] for i, p in enumerate(penalties)])

    penalty[total_music_beats:, :] = penalties[0][len(beats[0]):, :]

    logging.info("Building cost table")

    # compute the dynamic programming table (prev python method)
    # cost, prev_node = _build_table(analysis, duration, start, target, pen)

    # first_pause = 0
    # if max_pause_beats > 0:
    first_pause = total_music_beats

    if min_beats is None:
        min_beats = 0
    elif min_beats == 'default':
        min_beats = int(20. / beat_length)

    if max_beats is None:
        max_beats = -1
    elif max_beats == 'default':
        max_beats = int(90. / beat_length)
        max_beats = min(max_beats, penalty.shape[1])

    tc2 = np.nan_to_num(trans_cost)
    pen2 = np.nan_to_num(penalty)

    beat_names = []
    for i, bn in enumerate(all_beat_names):
        for b in bn:
            if not str(b).startswith('p'):
                beat_names.append((i, float(b)))
    beat_names.extend([('p', i) for i in xrange(max_pause_beats)])

    result_labels = []

    logging.info("Running optimization (full backtrace, memory efficient)")
    logging.info("\twith min_beats(%d) and max_beats(%d) and first_pause(%d)" %
                 (min_beats, max_beats, first_pause))

    song_starts = [0]
    for song in songs:
        song_starts.append(song_starts[-1] + len(song.analysis["beats"]))
    song_ends = np.array(song_starts[1:], dtype=np.int32)
    song_starts = np.array(song_starts[:-1], dtype=np.int32)

    t1 = time.clock()
    path_i, path_cost = build_table_full_backtrace(
        tc2, pen2, song_starts, song_ends,
        first_pause=first_pause, max_beats=max_beats, min_beats=min_beats)
    t2 = time.clock()
    logging.info("Built table (full backtrace) in {} seconds"
                 .format(t2 - t1))

    path = []
    if max_beats == -1:
        max_beats = min_beats + 1

    first_pause_full = max_beats * first_pause
    n_beats = first_pause
    for i in path_i:
        if i >= first_pause_full:
            path.append(('p', i - first_pause_full))
            result_labels.append(None)
            # path.append('p' + str(i - first_pause_full))
        else:
            path.append(beat_names[i % n_beats])
            song_i = path[-1][0]
            beat_name = path[-1][1]
            result_labels.append(
                start[song_i][np.where(np.array(beats[song_i]) ==
                              beat_name)[0][0]])
            # path.append(float(beat_names[i % n_beats]))

    # else:
    #     print("Running optimization (fast, full table)")
    #     # this won't work right now- needs to be updated
    #     # with the multi-song approach

    #     # fortran method
    #     t1 = time.clock()
    #     cost, prev_node = build_table(tc2, pen2)
    #     t2 = time.clock()
    #     print("Built table (fortran) in {} seconds".format(t2 - t1))
    #     res = cost[:, -1]
    #     best_idx = N.argmin(res)
    #     if N.isfinite(res[best_idx]):
    #         path, path_cost, path_i = _reconstruct_path(
    #             prev_node, cost, beat_names, best_idx, N.shape(cost)[1] - 1)
    #         # path_i = [beat_names.index(x) for x in path]
    #     else:
    #         # throw an exception here?
    #         return None

    #     path = []
    #     result_labels = []
    #     if max_pause_beats == 0:
    #         n_beats = total_music_beats
    #         first_pause = n_beats
    #     else:
    #         n_beats = first_pause
    #     for i in path_i:
    #         if i >= first_pause:
    #             path.append(('p', i - first_pause))
    #             result_labels.append(None)
    #         else:
    #             path.append(beat_names[i % n_beats])
    #             song_i = path[-1][0]
    #             beat_name = path[-1][1]
    #             result_labels.append(
    #                 start[song_i][N.where(N.array(beats[song_i]) ==
    #                               beat_name)[0][0]])

    # return a radiotool Composition
    logging.info("Generating audio")
    (comp, cf_locations, result_full_labels,
     cost_labels, contracted, result_volume) =\
        _generate_audio(
            songs, beats, path, path_cost, start,
            volume=volume,
            volume_breakpoints=volume_breakpoints,
            springs=springs,
            fade_in_len=fade_in_len, fade_out_len=fade_out_len)

    info = {
        "beat_length": beat_length,
        "contracted": contracted,
        "cost": np.sum(path_cost) / len(path),
        "path": path,
        "path_i": path_i,
        "target_labels": target,
        "result_labels": result_labels,
        "result_full_labels": result_full_labels,
        "result_volume": result_volume,
        "transitions": [Label("crossfade", loc) for loc in cf_locations],
        "path_cost": cost_labels
    }

    return comp, info
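
The ``music_labels`` / ``out_labels`` pattern described in the docstring boils down to sampling ordinary Python functions once per beat, exactly as the body does with np.arange(0, duration, beat_length). A minimal, self-contained sketch of that sampling step (the 0.5 s beat length and the 150-200 s label window are illustrative assumptions):

import numpy as np

def labels(t):
    # hypothetical input labels: 'solo' between 150 s and 200 s
    return 'solo' if 150 < t < 200 else 'other'

def target(t):
    return 'solo'

duration = 180.0      # desired output length in seconds (assumed)
beat_length = 0.5     # mean beat length in seconds (assumed)

# one desired label per output beat, mirroring the retarget() body
out_beats = np.arange(0, duration, beat_length)
target_labels = [target(t) for t in out_beats]
print(len(target_labels), target_labels[0])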

Example 23

Project: radiotool
Source File: retarget.py
View license
def _generate_audio(songs, beats, new_beats, new_beats_cost, music_labels,
                    volume=None, volume_breakpoints=None,
                    springs=None, fade_in_len=3.0, fade_out_len=5.0):
    # assuming same sample rate for all songs

    logging.info("Building volume")
    if volume is not None and volume_breakpoints is not None:
        raise Exception("volume and volume_breakpoints cannot both be defined")
    if volume_breakpoints is None:
        if volume is None:
            volume = 1.0
        volume_array = np.array([volume])

    if volume_breakpoints is not None:
        volume_array = volume_breakpoints.to_array(songs[0].samplerate)

    result_volume = np.zeros(volume_array.shape)

    min_channels = min([x.channels for x in songs])

    comp = Composition(channels=min_channels)

    # currently assuming no transitions between different songs

    beat_length = np.mean([song.analysis[BEAT_DUR_KEY]
                          for song in songs])

    audio_segments = []
    segment_song_indicies = [new_beats[0][0]]
    current_seg = [0, 0]
    if new_beats[0][0] == 'p':
        current_seg = 'p'

    for i, (song_i, b) in enumerate(new_beats):
        if segment_song_indicies[-1] != song_i:
            segment_song_indicies.append(song_i)

        if current_seg == 'p' and song_i != 'p':
            current_seg = [i, i]
        elif current_seg != 'p' and song_i == 'p':
            audio_segments.append(current_seg)
            current_seg = 'p'
        elif current_seg != 'p':
            current_seg[1] = i
    if current_seg != 'p':
        audio_segments.append(current_seg)

    segment_song_indicies = [x for x in segment_song_indicies if x != 'p']

    beats = [np.array(b) for b in beats]
    score_start = 0
    current_loc = 0.0
    last_segment_beat = 0

    comp.add_tracks(songs)

    all_cf_locations = []

    aseg_fade_ins = []

    logging.info("Building audio")
    for (aseg, song_i) in zip(audio_segments, segment_song_indicies):
        segments = []
        # TODO: is this +1 correct?
        starts = np.array([x[1] for x in new_beats[aseg[0]:aseg[1] + 1]])

        bis = [np.nonzero(beats[song_i] == b)[0][0] for b in starts]
        dists = np.zeros(len(starts))
        durs = np.zeros(len(starts))

        for i, beat in enumerate(starts):
            if i < len(bis) - 1:
                if bis[i] + 1 != bis[i + 1]:
                    dists[i + 1] = 1
            if bis[i] + 1 >= len(beats[song_i]):
                # use the average beat duration if we don't know
                # how long the beat is supposed to be
                logging.warning("USING AVG BEAT DURATION IN SYNTHESIS -\
                    POTENTIALLY NOT GOOD")
                durs[i] = songs[song_i].analysis[BEAT_DUR_KEY]
            else:
                durs[i] = beats[song_i][bis[i] + 1] - beats[song_i][bis[i]]

        # add pause duration to current location
        # current_loc +=\
            # (aseg[0] - last_segment_beat) *\
            #      song.analysis[BEAT_DUR_KEY]

        # catch up to the pause
        current_loc = max(
            aseg[0] * beat_length,
            current_loc)

        last_segment_beat = aseg[1] + 1

        cf_durations = []
        seg_start = starts[0]
        seg_start_loc = current_loc

        cf_locations = []

        segment_starts = [0]
        try:
            segment_starts.extend(np.where(dists == 1)[0])
        except:
            pass

        # print "segment starts", segment_starts

        for i, s_i in enumerate(segment_starts):
            if i == len(segment_starts) - 1:
                # last segment?
                seg_duration = np.sum(durs[s_i:])
            else:
                next_s_i = segment_starts[i + 1]
                seg_duration = np.sum(durs[s_i:next_s_i])

                cf_durations.append(durs[next_s_i])
                cf_locations.append(current_loc + seg_duration)

            seg_music_location = starts[s_i]

            seg = Segment(songs[song_i], current_loc,
                          seg_music_location, seg_duration)

            segments.append(seg)

            # update location for next segment
            current_loc += seg_duration

        # for i, start in enumerate(starts):
        #     dur = durs[i]
        #     current_loc += dur
        #     if i == 0 or dists[i - 1] == 0:
        #         pass
        #         # dur = durs[i]
        #         # current_loc += dur
        #     else:
        #         seg = Segment(song, seg_start_loc, seg_start,
        #                       current_loc - seg_start_loc)
        #         print "segment duration", current_loc - seg_start_loc
        #         segments.append(seg)

        #         # track = Track(wav_fn, t["name"])
        #         # comp.add_track(track)
        #         # dur = durs[i]
        #         cf_durations.append(dur)
        #         cf_locations.append(current_loc)

        #         seg_start_loc = current_loc
        #         seg_start = start

        #         # current_loc += dur

        # last_seg = Segment(song, seg_start_loc, seg_start,
        #     current_loc - seg_start_loc)
        # segments.append(last_seg)

        comp.add_segments(segments)

        if segments[-1].comp_location + segments[-1].duration >\
                len(volume_array):

            diff = len(volume_array) -\
                (segments[-1].comp_location + segments[-1].duration)
            new_volume_array =\
                np.ones(segments[-1].comp_location + segments[-1].duration) *\
                volume_array[-1]
            new_volume_array[:len(volume_array)] = volume_array
            volume_array = new_volume_array
            result_volume = np.zeros(new_volume_array.shape)

        for i, seg in enumerate(segments[:-1]):
            logging.info("%s %s %s", cf_durations[i], seg.duration_in_seconds,
                         segments[i + 1].duration_in_seconds)
            rawseg = comp.cross_fade(seg, segments[i + 1], cf_durations[i])

            # decrease volume along crossfades
            volume_frames = volume_array[
                rawseg.comp_location:rawseg.comp_location + rawseg.duration]
            raw_vol = RawVolume(rawseg, volume_frames)
            comp.add_dynamic(raw_vol)

            result_volume[rawseg.comp_location:
                          rawseg.comp_location + rawseg.duration] =\
                volume_frames

        s0 = segments[0]
        sn = segments[-1]

        if fade_in_len is not None:
            fi_len = min(fade_in_len, s0.duration_in_seconds)
            fade_in_len_samps = fi_len * s0.track.samplerate
            fade_in = comp.fade_in(s0, fi_len, fade_type="linear")
            aseg_fade_ins.append(fade_in)
        else:
            fade_in = None

        if fade_out_len is not None:
            fo_len = min(fade_out_len, sn.duration_in_seconds)
            fade_out_len_samps = fo_len * sn.track.samplerate
            fade_out = comp.fade_out(sn, fo_len, fade_type="exponential")
        else:
            fade_out = None

        prev_end = 0.0

        for seg in segments:
            volume_frames = volume_array[
                seg.comp_location:seg.comp_location + seg.duration]

            # this can happen on the final segment:
            if len(volume_frames) == 0:
                volume_frames = np.array([prev_end] * seg.duration)
            elif len(volume_frames) < seg.duration:
                delta = [volume_frames[-1]] *\
                    (seg.duration - len(volume_frames))
                volume_frames = np.r_[volume_frames, delta]
            raw_vol = RawVolume(seg, volume_frames)
            comp.add_dynamic(raw_vol)

            try:
                result_volume[seg.comp_location:
                              seg.comp_location + seg.duration] = volume_frames
            except ValueError:
                diff = (seg.comp_location + seg.duration) - len(result_volume)
                result_volume = np.r_[result_volume, np.zeros(diff)]
                result_volume[seg.comp_location:
                              seg.comp_location + seg.duration] = volume_frames

            if len(volume_frames) != 0:
                prev_end = volume_frames[-1]

            # vol = Volume.from_segment(seg, volume)
            # comp.add_dynamic(vol)

        if fade_in is not None:
            result_volume[s0.comp_location:
                          s0.comp_location + fade_in_len_samps] *=\
                fade_in.to_array(channels=1).flatten()
        if fade_out is not None:
            result_volume[sn.comp_location + sn.duration - fade_out_len_samps:
                          sn.comp_location + sn.duration] *=\
                fade_out.to_array(channels=1).flatten()

        all_cf_locations.extend(cf_locations)

    # result labels
    label_time = 0.0
    pause_len = beat_length
    # pause_len = song.analysis[BEAT_DUR_KEY]
    result_full_labels = []
    prev_label = -1
    for beat_i, (song_i, beat) in enumerate(new_beats):
        if song_i == 'p':
            current_label = None
            if current_label != prev_label:
                result_full_labels.append(Label("pause", label_time))
            prev_label = None

            # label_time += pause_len
            # catch up
            label_time = max(
                (beat_i + 1) * pause_len,
                label_time)
        else:
            beat_i = np.where(np.array(beats[song_i]) == beat)[0][0]
            next_i = beat_i + 1
            current_label = music_labels[song_i][beat_i]
            if current_label != prev_label:
                if current_label is None:
                    result_full_labels.append(Label("none", label_time))
                else:
                    result_full_labels.append(Label(current_label, label_time))
            prev_label = current_label

            if (next_i >= len(beats[song_i])):
                logging.warning("USING AVG BEAT DURATION - "
                                "POTENTIALLY NOT GOOD")
                label_time += songs[song_i].analysis[BEAT_DUR_KEY]
            else:
                label_time += beats[song_i][next_i] - beat

    # result costs
    cost_time = 0.0
    result_cost = []
    for i, (song_i, b) in enumerate(new_beats):
        result_cost.append(Label(new_beats_cost[i], cost_time))

        if song_i == 'p':
            # cost_time += pause_len
            # catch up
            cost_time = max(
                (i + 1) * pause_len,
                cost_time)
        else:
            beat_i = np.where(np.array(beats[song_i]) == b)[0][0]
            next_i = beat_i + 1

            if (next_i >= len(beats[song_i])):
                cost_time += songs[song_i].analysis[BEAT_DUR_KEY]
            else:
                cost_time += beats[song_i][next_i] - b

    logging.info("Contracting pause springs")
    contracted = []
    min_contraction = 0.5
    if springs is not None:
        offset = 0.0
        for spring in springs:
            contracted_time, contracted_dur = comp.contract(
                spring.time - offset, spring.duration,
                min_contraction=min_contraction)
            if contracted_dur > 0:
                logging.info("Contracted", contracted_time,
                             "at", contracted_dur)

                # move all the volume frames back
                c_time_samps = contracted_time * segments[0].track.samplerate
                c_dur_samps = contracted_dur * segments[0].track.samplerate
                result_volume = np.r_[
                    result_volume[:c_time_samps],
                    result_volume[c_time_samps + c_dur_samps:]]

                # can't move anything EARLIER than contracted_time

                new_cf = []
                for cf in all_cf_locations:
                    if cf > contracted_time:
                        new_cf.append(
                            max(cf - contracted_dur, contracted_time))
                    else:
                        new_cf.append(cf)
                all_cf_locations = new_cf

                # for lab in result_full_labels:
                #     if lab.time > contracted_time + contracted_dur:
                #         lab.time -= contracted_dur

                first_label = True
                for lab_i, lab in enumerate(result_full_labels):
                    # is this contracted in a pause that already started?
                    # if lab_i + 1 < len(result_full_labels):
                    #     next_lab = result_full_labels[lab_i + 1]
                    #     if lab.time < contracted_time <= next_lab.time:
                    #         first_label = False

                    # if lab.time > contracted_time:
                    #     # TODO: fix this hack
                    #     if lab.name == "pause" and first_label:
                    #         pass
                    #     else:
                    #         lab.time -= contracted_dur
                    #     first_label = False

                    try:
                        if lab.time == contracted_time and\
                            result_full_labels[lab_i + 1].time -\
                                contracted_dur == lab.time:

                            logging.warning("LABEL HAS ZERO LENGTH", lab)
                    except:
                        pass

                    if lab.time > contracted_time:
                        logging.info("\tcontracting label", lab)
                        lab.time = max(
                            lab.time - contracted_dur, contracted_time)
                        # lab.time -= contracted_dur
                        logging.info("\t\tto", lab)

                new_result_cost = []
                for cost_lab in result_cost:
                    if cost_lab.time <= contracted_time:
                        # cost is before contracted time
                        new_result_cost.append(cost_lab)
                    elif contracted_time < cost_lab.time <=\
                            contracted_time + contracted_dur:
                        # cost is during contracted time
                        # remove these labels
                        if cost_lab.name > 0:
                            logging.warning("DELETING nonzero cost label",
                                            cost_lab.name, cost_lab.time)
                    else:
                        # cost is after contracted time
                        cost_lab.time = max(
                            cost_lab.time - contracted_dur, contracted_time)
                        # cost_lab.time -= contracted_dur
                        new_result_cost.append(cost_lab)

                # new_result_cost = []
                # first_label = True
                # # TODO: also this hack. bleh.
                # for cost_lab in result_cost:
                #     if cost_lab.time < contracted_time:
                #         new_result_cost.append(cost_lab)
                #     elif cost_lab.time > contracted_time and\
                #             cost_lab.time <= contracted_time +\
                #                contracted_dur:
                #         if first_label:
                #             cost_lab.time = contracted_time
                #             new_result_cost.append(cost_lab)
                #         elif cost_lab.name > 0:
                #             print "DELETING nonzero cost label:",\
                #                 cost_lab.name, cost_lab.time
                #         first_label = False
                #     elif cost_lab.time > contracted_time + contracted_dur:
                #         cost_lab.time -= contracted_dur
                #         new_result_cost.append(cost_lab)
                #         first_label = False
                result_cost = new_result_cost

                contracted.append(
                    Spring(contracted_time + offset, contracted_dur))
                offset += contracted_dur

    for fade in aseg_fade_ins:
        for spring in contracted:
            if (spring.time - 1 <
                    fade.comp_location_in_seconds <
                    spring.time + spring.duration + 1):

                result_volume[
                    fade.comp_location:
                    fade.comp_location + fade.duration] /=\
                    fade.to_array(channels=1).flatten()

                fade.fade_type = "linear"
                fade.duration_in_seconds = 2.0
                result_volume[
                    fade.comp_location:
                    fade.comp_location + fade.duration] *=\
                    fade.to_array(channels=1).flatten()

                logging.info("Changing fade at {}".format(
                    fade.comp_location_in_seconds))

    # for seg in comp.segments:
    #     print seg.comp_location, seg.duration
    # print
    # for dyn in comp.dynamics:
    #     print dyn.comp_location, dyn.duration

    # add all the segments to the composition
    # comp.add_segments(segments)

    # all_segs = []

    # for i, seg in enumerate(segments[:-1]):
    #     rawseg = comp.cross_fade(seg, segments[i + 1], cf_durations[i])
    #     all_segs.extend([seg, rawseg])

    #     # decrease volume along crossfades
    #     rawseg.track.frames *= music_volume

    # all_segs.append(segments[-1])

    # add dynamic for music
    # vol = Volume(song, 0.0,
    #     (last_seg.comp_location + last_seg.duration) /
    #        float(song.samplerate),
    #     volume)
    # comp.add_dynamic(vol)

    # cf durs?
    # durs

    return (comp, all_cf_locations, result_full_labels,
            result_cost, contracted, result_volume)
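
The logging calls in _generate_audio pass their values as extra arguments rather than pre-formatting them; unlike print, logging.info treats its first argument as a %-style format string and interpolates the remaining arguments lazily, only when a handler actually emits the record. A short sketch of the two spellings (the values are illustrative):

import logging

logging.basicConfig(level=logging.INFO)

contracted_time, contracted_dur = 12.5, 3.0  # illustrative values

# lazy %-style formatting: interpolation happens only if the record is emitted
logging.info("Contracted %s at %s", contracted_time, contracted_dur)

# eager formatting also works, at the cost of always building the string
logging.info("Contracted {} at {}".format(contracted_time, contracted_dur))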

Example 24

Project: reprozip
Source File: traceutils.py
View license
def combine_traces(traces, target):
    """Combines multiple trace databases into one.

    The runs from the original traces are appended ('run_id' field gets
    translated to avoid conflicts).

    :param traces: List of trace database filenames.
    :type traces: [Path]
    :param target: Directory where to write the new database and associated
        configuration file.
    :type target: Path
    """
    # We are probably overwriting one of the traces we're reading, so write to
    # a temporary file first then move it
    fd, output = Path.tempfile('.sqlite3', 'reprozip_combined_')
    if PY3:
        # On PY3, connect() only accepts unicode
        conn = sqlite3.connect(str(output))
    else:
        conn = sqlite3.connect(output.path)
    os.close(fd)
    conn.row_factory = sqlite3.Row

    # Create the schema
    create_schema(conn)

    # Temporary database with lookup tables
    conn.execute(
        '''
        ATTACH DATABASE '' AS maps;
        ''')
    conn.execute(
        '''
        CREATE TABLE maps.map_runs(
            old INTEGER NOT NULL,
            new INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT
            );
        ''')
    conn.execute(
        '''
        CREATE TABLE maps.map_processes(
            old INTEGER NOT NULL,
            new INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT
            );
        ''')

    # Do the merge
    for other in traces:
        logging.info("Attaching database %s", other)

        # Attach the other trace
        conn.execute(
            '''
            ATTACH DATABASE ? AS trace;
            ''',
            (str(other),))

        # Add runs to lookup table
        conn.execute(
            '''
            INSERT INTO maps.map_runs(old)
            SELECT DISTINCT run_id AS old
            FROM trace.processes
            ORDER BY run_id;
            ''')

        logging.info(
            "%d rows in maps.map_runs",
            list(conn.execute('SELECT COUNT(*) FROM maps.map_runs;'))[0][0])

        # Add processes to lookup table
        conn.execute(
            '''
            INSERT INTO maps.map_processes(old)
            SELECT id AS old
            FROM trace.processes
            ORDER BY id;
            ''')

        logging.info(
            "%d rows in maps.map_processes",
            list(conn.execute('SELECT COUNT(*) FROM maps.map_processes;'))
            [0][0])

        # processes
        logging.info("Insert processes...")
        conn.execute(
            '''
            INSERT INTO processes(id, run_id, parent,
                                       timestamp, is_thread, exitcode)
            SELECT p.new AS id, r.new AS run_id, parent,
                   timestamp, is_thread, exitcode
            FROM trace.processes t
            INNER JOIN maps.map_runs r ON t.run_id = r.old
            INNER JOIN maps.map_processes p ON t.id = p.old
            ORDER BY t.id;
            ''')

        # opened_files
        logging.info("Insert opened_files...")
        conn.execute(
            '''
            INSERT INTO opened_files(run_id, name, timestamp,
                                     mode, is_directory, process)
            SELECT r.new AS run_id, name, timestamp,
                   mode, is_directory, p.new AS process
            FROM trace.opened_files t
            INNER JOIN maps.map_runs r ON t.run_id = r.old
            INNER JOIN maps.map_processes p ON t.process = p.old
            ORDER BY t.id;
            ''')

        # executed_files
        logging.info("Insert executed_files...")
        conn.execute(
            '''
            INSERT INTO executed_files(name, run_id, timestamp, process,
                                       argv, envp, workingdir)
            SELECT name, r.new AS run_id, timestamp, p.new AS process,
                   argv, envp, workingdir
            FROM trace.executed_files t
            INNER JOIN maps.map_runs r ON t.run_id = r.old
            INNER JOIN maps.map_processes p ON t.process = p.old
            ORDER BY t.id;
            ''')

        # Flush maps
        conn.execute(
            '''
            DELETE FROM maps.map_runs;
            ''')
        conn.execute(
            '''
            DELETE FROM maps.map_processes;
            ''')

        # Detach
        conn.execute(
            '''
            DETACH DATABASE trace;
            ''')

    conn.execute(
        '''
        DETACH DATABASE maps;
        ''')

    conn.commit()
    conn.close()

    # Move database to final destination
    if not target.exists():
        target.mkdir()
    output.move(target / 'trace.sqlite3')
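
A usage sketch for combine_traces, under the assumption that the Path class above is rpaths.Path (the package reprozip uses for path handling) and that the trace files exist; the file names are hypothetical:

from rpaths import Path  # assumed import; combine_traces expects Path objects

traces = [Path('run1/trace.sqlite3'), Path('run2/trace.sqlite3')]
combine_traces(traces, Path('combined'))
# writes combined/trace.sqlite3, with run_id values translated to avoid conflicts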

Example 25

Project: viewfinder
Source File: build_archive_op.py
View license
  @gen.coroutine
  def _BuildArchive(self):
    """Drive overall archive process as outlined in class header comment."""

    logging.info('building archive for user: %d' % self._user_id)

    # Prepare temporary destination folder (delete any existing one; we always start from scratch).
    self._ResetArchiveDir()

    # Copy in base assets and javascript which will drive browser experience of content for users.
    proc = process.Subprocess(['cp',
                               '-R',
                               os.path.join(self._offboarding_assets_dir_path, 'web_code'),
                               self._content_dir_path])
    code = yield gen.Task(proc.set_exit_callback)
    if code != 0:
      logging.error('Error copying offboarding assets: %d' % code)
      raise IOError()

    # Top level iteration is over viewpoints.
    # For each viewpoint,
    #    iterate over activities and collect photos/episodes as needed.
    #    Build various 'tables' in json format:
    #        Activity, Comment, Episode, Photo, ...
    #
    viewpoints_dict = yield _QueryFollowedForArchive(self._client, self._user_id)
    viewpoint_ids = [viewpoint['viewpoint_id'] for viewpoint in viewpoints_dict['viewpoints']]
    followers_dict = yield _QueryViewpointsForArchive(self._client,
                                                           self._user_id,
                                                           viewpoint_ids,
                                                           get_followers=True)
    for viewpoint, followers in zip(viewpoints_dict['viewpoints'], followers_dict['viewpoints']):
      viewpoint['followers'] = followers
    # Query user info for all users referenced by any of the viewpoints.
    users_to_query = list({f['follower_id'] for vp in followers_dict['viewpoints'] for f in vp['followers']})
    users_dict = yield _QueryUsersForArchive(self._client, self._user_id, users_to_query)
    top_level_metadata_dict = dict(viewpoints_dict.items() + users_dict.items())

    # Write the top level metadata to the root of the archive.
    # TODO(mike): Consider moving this IO to thread pool to avoid blocking on main thread.
    with open(os.path.join(self._content_dir_path, 'viewpoints.jsn'), mode='wb') as f:
      # Need to set metadata as variable for JS code.
      f.write("viewfinder.jsonp_data =")
      json.dump(top_level_metadata_dict, f)

    # Now, process each viewpoint.
    for vp_dict in top_level_metadata_dict['viewpoints']:
      if Follower.REMOVED not in vp_dict['labels']:
        yield self._ProcessViewpoint(vp_dict)

    # Now, generate user specific view file: index.html.
    # This is the file that the user will open to launch the web client view of their data.
    recipient_user = yield gen.Task(User.Query, self._client, self._user_id, None)
    user_info = {'user_id' : recipient_user.user_id,
                 'name' : recipient_user.name,
                 'email' : recipient_user.email,
                 'phone' : recipient_user.phone,
                 'default_viewpoint_id' : recipient_user.private_vp_id
                 }
    view_local = ResourcesManager().Instance().GenerateTemplate('view_local.html',
                                                                user_info=user_info,
                                                                viewpoint_id=None)
    with open(os.path.join(self._content_dir_path, 'index.html'), mode='wb') as f:
      f.write(view_local)

    with open(os.path.join(self._content_dir_path, 'README.txt'), mode='wb') as f:
      f.write("This Viewfinder archive contains both a readable local HTML file " +
              "and backup folders including all photos included in those conversations.\n")

    # Exec zip command relative to the parent of content dir so that paths in zip are relative to that.
    proc = process.Subprocess(['zip',
                               '-r',
                               BuildArchiveOperation._ZIP_FILE_NAME,
                               BuildArchiveOperation._CONTENT_DIR_NAME],
                              cwd=self._temp_dir_path)
    code = yield gen.Task(proc.set_exit_callback)
    if code != 0:
      logging.error('Error creating offboarding zip file: %d' % code)
      raise IOError()

    # Key is: "{user_id}/{timestamp}_{random}/Viewfinder.zip"
    # timestamp is utc unix timestamp.
    s3_key = '%d/%d_%d/Viewfinder.zip' % (self._user_id,
                               calendar.timegm(datetime.datetime.utcnow().utctimetuple()),
                               int(random.random() * 1000000))

    if options.options.fileobjstore:
      # Next, upload this to S3 (really fileobjstore in this case).
      with open(self._zip_file_path, mode='rb') as f:
        s3_data = f.read()
      yield gen.Task(self._user_zips_obj_store.Put, s3_key, s3_data)
    else:
      # Running against AWS S3, so use awscli to upload zip file into S3.
      s3_path = 's3://' + ObjectStore.USER_ZIPS_BUCKET + '/' + s3_key

      # Use awscli to copy file into S3.
      proc = process.Subprocess(['aws', 's3', 'cp', self._zip_file_path, s3_path, '--region', 'us-east-1'],
                                stdout=process.Subprocess.STREAM,
                                stderr=process.Subprocess.STREAM,
                                env={'AWS_ACCESS_KEY_ID': GetSecret('aws_access_key_id'),
                                     'AWS_SECRET_ACCESS_KEY': GetSecret('aws_secret_access_key')})

      result, error, code = yield [
        gen.Task(proc.stdout.read_until_close),
        gen.Task(proc.stderr.read_until_close),
        gen.Task(proc.set_exit_callback)
      ]

      if code != 0:
        logging.error("%d = 'aws s3 cp %s %s': %s" % (code, self._zip_file_path, s3_path, error))
        if result and len(result) > 0:
          logging.info("aws result: %s" % result)
        raise IOError()

    # Generate signed URL to S3 for given user zip.  Only allow link to live for 3 days.
    s3_url = self._user_zips_obj_store.GenerateUrl(s3_key,
                                                   cache_control='private,max-age=%d' %
                                                                 self._S3_ZIP_FILE_ACCESS_EXPIRATION,
                                                   expires_in=3 * self._S3_ZIP_FILE_ACCESS_EXPIRATION)
    logging.info('user zip uploaded: %s' % s3_url)

    # Finally, send the user an email with the link to download the zip files just uploaded to s3.
    email_args = {'from': EmailManager.Instance().GetInfoAddress(),
                  'to': self._email,
                  'subject': 'Your Viewfinder archive download is ready'}

    fmt_args = {'archive_url': s3_url,
                'hello_name': recipient_user.given_name or recipient_user.name}
    email_args['text'] = ResourcesManager.Instance().GenerateTemplate('user_zip.email', is_html=False, **fmt_args)
    yield gen.Task(EmailManager.Instance().SendEmail, description='user archive zip', **email_args)

Example 26

Project: viewfinder
Source File: fetch_contacts_op.py
View license
  @gen.coroutine
  def _FetchFacebookContacts(self):
    """Do Facebook specific data gathering and checking.
    Queries Facebook graph API for friend list using the identity's access token.
    """
    @gen.coroutine
    def _DetermineFacebookRankings():
      """Uses The tags from friends and the authors of the
      photos are used to determine friend rank for facebook contacts. The
      basic algorithm is:

      sorted([sum(exp_decay(pc.time) * strength(pc)) for pc in photos])

      A 'pc' is a photo connection. There are three types, ordered by
      the 'strength' they impart in the summation equation:
        - from: the poster of a photo (strength=1.0)
        - tag: another user tagged in the photo (strength=1.0)
        - like: a facebook user who 'liked' the photo (strength=0.25)
      Exponential decay uses _FACEBOOK_CONNECTION_HALF_LIFE for half life.

      The rankings are returned as a dictionary of
      identity ('FacebookGraph:<id>') => rank.
      """
      logging.info('determining facebook contact rankings for identity %r...' % self._identity)
      http_client = httpclient.AsyncHTTPClient()
      friends = dict()  # facebook id => connection strength
      likes = dict()
      now = util.GetCurrentTimestamp()

      def _ComputeScore(create_iso8601, conn_type):
        """Computes the strength of a photo connection based on the time
        that's passed and the connection type.
        """
        decay = 0.001  # default is 1/1000th
        if create_iso8601:
          dt = iso8601.parse_date(create_iso8601)
          create_time = calendar.timegm(dt.utctimetuple())
          decay = math.exp(-math.log(2) * (now - create_time) /
                            FetchContactsOperation._FACEBOOK_CONNECTION_HALF_LIFE)
        return decay * FetchContactsOperation._PHOTO_CONNECTION_STRENGTHS[conn_type]

      # Construct the URL that will kick things off.
      url = FetchContactsOperation._FACEBOOK_PHOTOS_URL + '?' + \
          urllib.urlencode({'access_token': self._identity.access_token,
                            'format': 'json', 'limit': FetchContactsOperation._MAX_FETCH_COUNT})
      while True:
        logging.info('querying next %d Facebook photos for user %d' %
                     (FetchContactsOperation._MAX_FETCH_COUNT, self._user_id))
        response = yield gen.Task(http_client.fetch, url, method='GET')
        response_dict = www_util.ParseJSONResponse(response)
        for p_dict in response_dict['data']:
          created_time = p_dict.get('created_time', None)
          if p_dict.get('from', None) and p_dict['from']['id']:
            from_id = p_dict['from']['id']
            friends[from_id] = friends.get(from_id, 0.0) + \
                _ComputeScore(created_time, 'from')

          if p_dict.get('tags', None):
            for tag in p_dict['tags']['data']:
              if tag.get('id', None) is not None:
                friends[tag['id']] = friends.get(tag['id'], 0.0) + \
                    _ComputeScore(tag.get('created_time', None), 'tag')

          if p_dict.get('likes', None):
            for like in p_dict['likes']['data']:
              if like.get('id', None) is not None:
                likes[like['id']] = likes.get(like['id'], 0.0) + \
                    _ComputeScore(created_time, 'like')

        if (len(response_dict['data']) == FetchContactsOperation._MAX_FETCH_COUNT and
            response_dict.has_key('paging') and response_dict['paging'].has_key('next')):
          url = response_dict['paging']['next']
        else:
          for fb_id in friends.keys():
            friends[fb_id] += likes.get(fb_id, 0.0)
          ranked_friends = sorted(friends.items(), key=itemgetter(1), reverse=True)
          logging.info('successfully ranked %d Facebook contacts for user %d' %
                       (len(ranked_friends), self._user_id))
          raise gen.Return(dict([('FacebookGraph:%s' % fb_id, rank) for rank, (fb_id, _) in \
                                izip(xrange(len(ranked_friends)), ranked_friends)]))

    logging.info('fetching Facebook contacts for identity %r...' % self._identity)
    http_client = httpclient.AsyncHTTPClient()
    # Track fetched contacts regardless of rank in order to dedup contacts retrieved from Facebook.
    rankless_ids = set()

    # First get the rankings and then fetch the contacts.
    rankings = yield _DetermineFacebookRankings()
    url = FetchContactsOperation._FACEBOOK_FRIENDS_URL + '?' + \
        urllib.urlencode({'fields': 'first_name,name,last_name',
                          'access_token': self._identity.access_token,
                          'format': 'json', 'limit': FetchContactsOperation._MAX_FETCH_COUNT})
    retries = 0
    while True:
      if retries >= FetchContactsOperation._MAX_FETCH_RETRIES:
        raise TooManyRetriesError('failed to fetch contacts %d times; aborting' % retries)
      logging.info('fetching next %d Facebook contacts for user %d' %
                   (FetchContactsOperation._MAX_FETCH_COUNT, self._user_id))
      response = yield gen.Task(http_client.fetch, url, method='GET')
      try:
        response_dict = www_util.ParseJSONResponse(response)
      except Exception as exc:
        logging.warning('failed to fetch Facebook contacts: %s' % exc)
        retries += 1
        continue

      for c_dict in response_dict['data']:
        if c_dict.has_key('id'):
          ident = 'FacebookGraph:%s' % c_dict['id']

          # Skip contact if name is not present, or is empty.
          name = c_dict.get('name', None)
          if name:
            names = {'name': name,
                     'given_name': c_dict.get('first_name', None),
                     'family_name': c_dict.get('last_name', None)}

            # Check to see if we've already processed an identical contact.
            rankless_id = Contact.CalculateContactEncodedDigest(identities_properties=[(ident, None)], **names)
            if rankless_id in rankless_ids:
              # Duplicate among fetched contacts. Skip it.
              continue
            else:
              rankless_ids.add(rankless_id)

            rank = rankings[ident] if ident in rankings else None
            fetched_contact = Contact.CreateFromKeywords(self._user_id,
                                                         [(ident, None)],
                                                         self._notify_timestamp,
                                                         Contact.FACEBOOK,
                                                         rank=rank,
                                                         **names)
            self._fetched_contacts[fetched_contact.contact_id] = fetched_contact

      # Prepare to fetch next batch.
      if (len(response_dict['data']) == FetchContactsOperation._MAX_FETCH_COUNT and
          response_dict.has_key('paging') and response_dict['paging'].has_key('next')):
        retries = 0
        url = response_dict['paging']['next']
      else:
        break
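
The ranking docstring's sum(exp_decay(pc.time) * strength(pc)) term can be written out directly. A minimal sketch with the connection strengths quoted in the docstring and an assumed half-life (the project's _FACEBOOK_CONNECTION_HALF_LIFE value is not shown here):

import math
import time

HALF_LIFE = 180 * 24 * 3600  # assumed half-life in seconds (illustrative)
STRENGTHS = {'from': 1.0, 'tag': 1.0, 'like': 0.25}  # per the docstring

def connection_score(create_time, conn_type, now=None):
    """exp_decay(pc.time) * strength(pc) for one photo connection."""
    if now is None:
        now = time.time()
    decay = math.exp(-math.log(2) * (now - create_time) / HALF_LIFE)
    return decay * STRENGTHS[conn_type]

now = time.time()
# a photo posted exactly one half-life ago contributes half the base strength
print(connection_score(now - HALF_LIFE, 'from', now))  # ~0.5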

Example 27

View license
    def minion(self, storage=None, *args, **xargs):
        self.app = self.Verum.app(self.parent.PluginFolder, None)
        # set storage
        if storage is None:
            storage = self.parent.storage
        self.app.set_interface(storage)

        # Check until stopped
        while not self.shutdown:
            # Check to see if it's the same day, if it is, sleep for a while, otherwise run the import
            delta = datetime.utcnow() - self.today
            if delta.days <= 0:
                time.sleep(SLEEP_TIME)
            else:
                logging.info("Starting daily {0} enrichment.".format(NAME))

                # Get the file
                r = requests.get(FEED)

                # split it out
                feed = r.text.split("\n")

                # Create list of IPs for cymru enrichment
                ips = set()

                for row in feed:
                    # Parse date
                    l = row.find("Feed generated at:")
                    if l > -1:
                        dt = row[l+18:].strip()
                        dt = dateutil.parser.parse(dt).strftime("%Y-%m-%dT%H:%M:%SZ")
                        continue
                    row = row.split(",")

                    # if it's a record, parse the record
                    if len(row) == 6:
                        try:
                            # split out sub values
                            # row[0] -> domain
                            row[1] = row[1].split("|")  # ip
                            row[2] = row[2].split("|")  # nameserver domain
                            row[3] = row[3].split("|")  # nameserver ip
                            row[4] = row[4][26:-22]  # malware
                            # row[5] -> source

                            # Validate data in row
                            ext = tldextract.extract(row[0])
                            if not ext.domain or not ext.suffix:
                                # domain is not legitimate
                                continue
                            l = list()
                            for ip in row[1]:
                                try:
                                    _ = ipaddress.ip_address(unicode(ip))
                                    l.append(ip)
                                except:
                                    pass
                            row[1] = copy.deepcopy(l)
                            l = list()
                            for domain in row[2]:
                                ext = tldextract.extract(domain)
                                if ext.domain and ext.suffix:
                                    l.append(domain)
                            row[2] = copy.deepcopy(l)
                            l = list()
                            for ip in row[3]:
                                try:
                                    _ = ipaddress.ip_address(unicode(ip))
                                    l.append(ip)
                                except:
                                    pass
                            row[3] = copy.deepcopy(l)

                            # add the ips to the set of ips
                            ips = ips.union(set(row[1])).union(set(row[3]))

                            g = nx.MultiDiGraph()

                            # Add indicator to graph
                            ## (Must account for the different types of indicators)
                            target_uri = "class=attribute&key={0}&value={1}".format('domain', row[0]) 
                            g.add_node(target_uri, {
                                'class': 'attribute',
                                'key': 'domain',
                                "value": row[0],
                                "start_time": dt,
                                "uri": target_uri
                            })


                            # Threat node
                            threat_uri = "class=attribute&key={0}&value={1}".format("malware", row[4]) 
                            g.add_node(threat_uri, {
                                'class': 'attribute',
                                'key': "malware",
                                "value": row[4],
                                "start_time": dt,
                                "uri": threat_uri
                            })

                            # Threat Edge
                            edge_attr = {
                                "relationship": "describedBy",
                                "origin": row[5],
                                "start_time": dt
                            }
                            source_hash = uuid.uuid3(uuid.NAMESPACE_URL, target_uri)
                            dest_hash = uuid.uuid3(uuid.NAMESPACE_URL, threat_uri)
                            edge_uri = "source={0}&destionation={1}".format(str(source_hash), str(dest_hash))
                            rel_chain = "relationship"
                            while rel_chain in edge_attr:
                                edge_uri = edge_uri + "&{0}={1}".format(rel_chain,edge_attr[rel_chain])
                                rel_chain = edge_attr[rel_chain]
                            if "origin" in edge_attr:
                                edge_uri += "&{0}={1}".format("origin", edge_attr["origin"])
                            edge_attr["uri"] = edge_uri
                            g.add_edge(target_uri, threat_uri, edge_uri, edge_attr)                        

                            # for each IP associated with the domain, connect it to the target
                            for ip in row[1]:
                                # Create IP node
                                target_ip_uri = "class=attribute&key={0}&value={1}".format("ip", ip) 
                                g.add_node(target_ip_uri, {
                                    'class': 'attribute',
                                    'key': "ip",
                                    "value": ip,
                                    "start_time": dt,
                                    "uri": target_ip_uri
                                })

                                # ip Edge
                                edge_attr = {
                                    "relationship": "describedBy",
                                    "origin": row[5],
                                    "start_time": dt,
                                }
                                source_hash = uuid.uuid3(uuid.NAMESPACE_URL, target_uri)
                                dest_hash = uuid.uuid3(uuid.NAMESPACE_URL, target_ip_uri)
                                edge_uri = "source={0}&destionation={1}".format(str(source_hash), str(dest_hash))
                                rel_chain = "relationship"
                                while rel_chain in edge_attr:
                                    edge_uri = edge_uri + "&{0}={1}".format(rel_chain,edge_attr[rel_chain])
                                    rel_chain = edge_attr[rel_chain]
                                if "origin" in edge_attr:
                                    edge_uri += "&{0}={1}".format("origin", edge_attr["origin"])
                                edge_attr["uri"] = edge_uri
                                g.add_edge(target_uri, target_ip_uri, edge_uri, edge_attr)


                            for nameserver in row[2]:
                                # Create nameserver node
                                ns_uri = "class=attribute&key={0}&value={1}".format("domain", nameserver) 
                                g.add_node(ns_uri, {
                                    'class': 'attribute',
                                    'key': "domain",
                                    "value": nameserver,
                                    "start_time": dt,
                                    "uri": ns_uri
                                })

                                # nameserver Edge
                                edge_attr = {
                                    "relationship": "describedBy",
                                    "origin": row[5],
                                    "start_time": dt,
                                    'describedBy': 'nameserver'
                                }
                                source_hash = uuid.uuid3(uuid.NAMESPACE_URL, target_uri)
                                dest_hash = uuid.uuid3(uuid.NAMESPACE_URL, ns_uri)
                                edge_uri = "source={0}&destionation={1}".format(str(source_hash), str(dest_hash))
                                rel_chain = "relationship"
                                while rel_chain in edge_attr:
                                    edge_uri = edge_uri + "&{0}={1}".format(rel_chain,edge_attr[rel_chain])
                                    rel_chain = edge_attr[rel_chain]
                                if "origin" in edge_attr:
                                    edge_uri += "&{0}={1}".format("origin", edge_attr["origin"])
                                edge_attr["uri"] = edge_uri
                                g.add_edge(target_uri, ns_uri, edge_uri, edge_attr)

                            # if the number of NS IPs is a multiple of the # of NS's, we'll assume each NS gets some of the ips
                            if len(row[2]) and len(row[3]) % len(row[2]) == 0:
                                for i in range(len(row[2])):
                                    for j in range(len(row[3])/len(row[2])):
                                        # Create NS IP node
                                        ns_ip_uri = "class=attribute&key={0}&value={1}".format("ip", row[3][i*len(row[3])/len(row[2]) + j]) 
                                        g.add_node(ns_ip_uri, {
                                            'class': 'attribute',
                                            'key': "ip",
                                            "value": ip,
                                            "start_time": dt,
                                            "uri": ns_ip_uri
                                        })

                                        # create NS uri
                                        ns_uri = "class=attribute&key={0}&value={1}".format("domain", row[2][i]) 


                                        # link NS to IP
                                        edge_attr = {
                                            "relationship": "describedBy",
                                            "origin": row[5],
                                            "start_time": dt
                                        }
                                        source_hash = uuid.uuid3(uuid.NAMESPACE_URL, ns_uri)
                                        dest_hash = uuid.uuid3(uuid.NAMESPACE_URL, ns_ip_uri)
                                        edge_uri = "source={0}&destionation={1}".format(str(source_hash), str(dest_hash))
                                        rel_chain = "relationship"
                                        while rel_chain in edge_attr:
                                            edge_uri = edge_uri + "&{0}={1}".format(rel_chain,edge_attr[rel_chain])
                                            rel_chain = edge_attr[rel_chain]
                                        if "origin" in edge_attr:
                                            edge_uri += "&{0}={1}".format("origin", edge_attr["origin"])
                                        edge_attr["uri"] = edge_uri
                                        g.add_edge(ns_uri, ns_ip_uri, edge_uri, edge_attr)

                            # otherwise we'll attach each IP to each NS
                            else:
                                for ip in row[3]:
                                    # Create NS IP node
                                    ns_ip_uri = "class=attribute&key={0}&value={1}".format("ip", ip) 
                                    g.add_node(ns_ip_uri, {
                                        'class': 'attribute',
                                        'key': "ip",
                                        "value": ip,
                                        "start_time": dt,
                                        "uri": ns_ip_uri
                                    })
                                    
                                    for ns in row[2]:
                                        # create NS uri
                                        ns_uri = "class=attribute&key={0}&value={1}".format("domain", ns)

                                        # link NS to IP
                                        edge_attr = {
                                            "relationship": "describedBy",
                                            "origin": row[5],
                                            "start_time": dt
                                        }
                                        source_hash = uuid.uuid3(uuid.NAMESPACE_URL, ns_uri)
                                        dest_hash = uuid.uuid3(uuid.NAMESPACE_URL, ns_ip_uri)
                                        edge_uri = "source={0}&destionation={1}".format(str(source_hash), str(dest_hash))
                                        rel_chain = "relationship"
                                        while rel_chain in edge_attr:
                                            edge_uri = edge_uri + "&{0}={1}".format(rel_chain,edge_attr[rel_chain])
                                            rel_chain = edge_attr[rel_chain]
                                        if "origin" in edge_attr:
                                            edge_uri += "&{0}={1}".format("origin", edge_attr["origin"])
                                        edge_attr["uri"] = edge_uri
                                        g.add_edge(ns_uri, ns_ip_uri, edge_uri, edge_attr)

                            # classify malicious and merge with current graph
                            g = self.Verum.merge_graphs(g, self.app.classify.run({'key': 'domain', 'value': row[0], 'classification': 'malice'}))

                            # enrich depending on type
                            for domain in [row[0]] + row[2]:
                                try:
                                    g = self.Verum.merge_graphs(g, self.app.run_enrichments(domain, "domain", names=['TLD Enrichment']))
                                    g = self.Verum.merge_graphs(g, self.app.run_enrichments(domain, "domain", names=['IP Whois Enrichment']))
                                except Exception as e:
                                    logging.info("Enrichment of {0} failed due to {1}.".format(domain, e))
                                    #print "Enrichment of {0} failed due to {1}.".format(domain, e)  # DEBUG
                                    #raise
                                    pass
                            for ip in row[1] + row[3]:
                                try:
                                    g = self.Verum.merge_graphs(g, self.app.run_enrichments(ip, "ip", names=[u'Maxmind ASN Enrichment']))
                                except Exception as e:
                                    logging.info("Enrichment of {0} failed due to {1}.".format(ip, e))
                                    pass

                            try:
                                self.app.store_graph(self.Verum.remove_non_ascii_from_graph(g))
                            except:
                                print g.nodes(data=True)  # DEBUG
                                print g.edges(data=True)  # DEBUG
                                raise

                            # Do cymru enrichment
                            if len(ips) >= 50:
                                # validate IPs
                                ips2 = set()
                                for ip in ips:
                                    try:
                                        _ = ipaddress.ip_address(unicode(ip))
                                        ips2.add(ip)
                                    except:
                                        pass
                                ips = ips2
                                del(ips2)
                                try:
                                    self.app.store_graph(self.app.run_enrichments(ips, 'ip', names=[u'Cymru Enrichment']))
                                    #print "Cymru enrichment complete."
                                except Exception as e:
                                    logging.info("Cymru enrichment of {0} IPs failed due to {1}.".format(len(ips), e))
                                    #print "Cymru enrichment of {0} IPs failed due to {1}.".format(len(ips), e)  # DEBUG
                                    pass
                                ips = set()

                        except Exception as e:
                            print row
                            print e
                            raise

                # Record the current UTC time as the new "today" marker
                self.today = datetime.utcnow()

                logging.info("Daily {0} enrichment complete.".format(NAME))
                print "Daily {0} enrichment complete.".format(NAME)  # DEBUG

Example 28

Project: starcheat
Source File: mainwindow.py
View license
    def __init__(self):
        """Display the main starcheat window."""
        # check for new starcheat version online in a separate thread
        update_result = [None]
        update_thread = Thread(target=update_check_worker, args=[update_result], daemon=True)
        update_thread.start()

        self.app = QApplication(sys.argv)
        self.window = StarcheatMainWindow(self)
        self.ui = qt_mainwindow.Ui_MainWindow()
        self.ui.setupUi(self.window)

        logging.info("Main window init")

        self.players = None
        self.filename = None

        self.item_browser = None
        # remember the last selected item browser category
        self.remember_browser = "<all>"
        self.options_dialog = None
        self.preview_armor = True
        self.preview_bg = "#ffffff"

        # connect action menu
        self.ui.actionSave.triggered.connect(self.save)
        self.ui.actionReload.triggered.connect(self.reload)
        self.ui.actionOpen.triggered.connect(self.open_file)
        self.ui.actionQuit.triggered.connect(self.app.closeAllWindows)
        self.ui.actionOptions.triggered.connect(self.new_options_dialog)
        self.ui.actionItemBrowser.triggered.connect(self.new_item_browser)
        self.ui.actionAbout.triggered.connect(self.new_about_dialog)
        self.ui.actionMods.triggered.connect(self.new_mods_dialog)
        self.ui.actionImageBrowser.triggered.connect(self.new_image_browser_dialog)

        self.ui.actionExportPlayerBinary.triggered.connect(self.export_save)
        self.ui.actionExportPlayerJSON.triggered.connect(self.export_json)
        self.ui.actionImportPlayerBinary.triggered.connect(self.import_save)
        self.ui.actionImportPlayerJSON.triggered.connect(self.import_json)

        # set up bag tables
        bags = ("wieldable", "head", "chest", "legs", "back", "main_bag",
                "action_bar", "object_bag", "tile_bag", "essentials", "mouse")
        for bag in bags:
            logging.debug("Setting up %s bag", bag)
            self.bag_setup(getattr(self.ui, bag), bag)

        self.preview_setup()

        # signals
        self.ui.blueprints_button.clicked.connect(self.new_blueprint_edit)
        self.ui.appearance_button.clicked.connect(self.new_appearance_dialog)
        self.ui.techs_button.clicked.connect(self.new_techs_dialog)
        self.ui.quests_button.clicked.connect(self.new_quests_dialog)
        self.ui.ship_button.clicked.connect(self.new_ship_dialog)

        self.ui.name.textChanged.connect(self.set_name)
        self.ui.male.clicked.connect(self.set_gender)
        self.ui.female.clicked.connect(self.set_gender)
        self.ui.description.textChanged.connect(self.set_description)
        self.ui.pixels.valueChanged.connect(self.set_pixels)

        self.ui.health.valueChanged.connect(lambda: self.set_stat_slider("health"))
        self.ui.energy.valueChanged.connect(lambda: self.set_stat_slider("energy"))
        self.ui.health_button.clicked.connect(lambda: self.max_stat("health"))
        self.ui.energy_button.clicked.connect(lambda: self.max_stat("energy"))

        self.ui.copy_uuid_button.clicked.connect(self.copy_uuid)

        self.window.setWindowModified(False)

        logging.debug("Showing main window")
        self.window.show()

        # launch first setup if we need to
        if not new_setup_dialog(self.window):
            logging.error("Config/index creation failed")
            return
        logging.info("Starbound folder: %s", Config().read("starbound_folder"))

        logging.info("Checking assets hash")
        if not check_index_valid(self.window):
            logging.error("Index creation failed")
            return

        logging.info("Loading assets database")
        self.assets = Assets(Config().read("assets_db"),
                             Config().read("starbound_folder"))
        self.items = self.assets.items()

        # populate species combobox
        for species in self.assets.species().get_species_list():
            self.ui.race.addItem(species)
        self.ui.race.currentTextChanged.connect(self.update_species)

        # populate game mode combobox
        for mode in sorted(self.assets.player().mode_types.values()):
            self.ui.game_mode.addItem(mode)
        self.ui.game_mode.currentTextChanged.connect(self.set_game_mode)

        # launch open file dialog
        self.player = None
        logging.debug("Open file dialog")
        open_player = self.open_file()
        # we *need* at least an initial save file
        if not open_player:
            logging.warning("No player file selected")
            return

        self.ui.name.setFocus()

        # block for update check result (should be ready now)
        update_thread.join()
        if update_result[0]:
            update_check_dialog(self.window, update_result[0])

        sys.exit(self.app.exec_())
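
Example 28 starts its version check on a daemon thread before building the Qt window and only joins the thread at the very end, so the network lookup overlaps with UI setup. update_check_worker and update_check_dialog belong to the starcheat code base and are not shown here; a self-contained sketch of the same hand-off pattern, with an assumed worker, might look like this.

import logging
from threading import Thread

def update_check_worker(result):
    """Hypothetical worker: look up the latest version and store it in the shared list."""
    try:
        result[0] = None  # e.g. "1.2.3" if a real HTTP lookup found a newer release
    except Exception as e:
        logging.info("Update check failed: %s", e)

update_result = [None]
update_thread = Thread(target=update_check_worker, args=[update_result], daemon=True)
update_thread.start()

# ... do other startup work while the check runs in the background ...

update_thread.join()
if update_result[0]:
    logging.info("New version available: %s", update_result[0])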

Example 29

Project: gfw-api
Source File: discovery.py
View license
def createResource(http, baseUrl, model, requestBuilder,
                   developerKey, resourceDesc, futureDesc, schema):

  class Resource(object):
    """A class for interacting with a resource."""

    def __init__(self):
      self._http = http
      self._baseUrl = baseUrl
      self._model = model
      self._developerKey = developerKey
      self._requestBuilder = requestBuilder

  def createMethod(theclass, methodName, methodDesc, futureDesc):
    methodName = _fix_method_name(methodName)
    pathUrl = methodDesc['path']
    httpMethod = methodDesc['httpMethod']
    methodId = methodDesc['id']

    mediaPathUrl = None
    accept = []
    maxSize = 0
    if 'mediaUpload' in methodDesc:
      mediaUpload = methodDesc['mediaUpload']
      mediaPathUrl = mediaUpload['protocols']['simple']['path']
      mediaResumablePathUrl = mediaUpload['protocols']['resumable']['path']
      accept = mediaUpload['accept']
      maxSize = _media_size_to_long(mediaUpload.get('maxSize', ''))

    if 'parameters' not in methodDesc:
      methodDesc['parameters'] = {}
    for name in STACK_QUERY_PARAMETERS:
      methodDesc['parameters'][name] = {
          'type': 'string',
          'location': 'query'
          }

    if httpMethod in ['PUT', 'POST', 'PATCH']:
      methodDesc['parameters']['body'] = {
          'description': 'The request body.',
          'type': 'object',
          'required': True,
          }
      if 'mediaUpload' in methodDesc:
        methodDesc['parameters']['media_body'] = {
            'description': 'The filename of the media request body.',
            'type': 'string',
            'required': False,
            }
        methodDesc['parameters']['body']['required'] = False

    argmap = {} # Map from method parameter name to query parameter name
    required_params = [] # Required parameters
    repeated_params = [] # Repeated parameters
    pattern_params = {}  # Parameters that must match a regex
    query_params = [] # Parameters that will be used in the query string
    path_params = {} # Parameters that will be used in the base URL
    param_type = {} # The type of the parameter
    enum_params = {} # Allowable enumeration values for each parameter


    if 'parameters' in methodDesc:
      for arg, desc in methodDesc['parameters'].iteritems():
        param = key2param(arg)
        argmap[param] = arg

        if desc.get('pattern', ''):
          pattern_params[param] = desc['pattern']
        if desc.get('enum', ''):
          enum_params[param] = desc['enum']
        if desc.get('required', False):
          required_params.append(param)
        if desc.get('repeated', False):
          repeated_params.append(param)
        if desc.get('location') == 'query':
          query_params.append(param)
        if desc.get('location') == 'path':
          path_params[param] = param
        param_type[param] = desc.get('type', 'string')

    for match in URITEMPLATE.finditer(pathUrl):
      for namematch in VARNAME.finditer(match.group(0)):
        name = key2param(namematch.group(0))
        path_params[name] = name
        if name in query_params:
          query_params.remove(name)

    def method(self, **kwargs):
      for name in kwargs.iterkeys():
        if name not in argmap:
          raise TypeError('Got an unexpected keyword argument "%s"' % name)

      for name in required_params:
        if name not in kwargs:
          raise TypeError('Missing required parameter "%s"' % name)

      for name, regex in pattern_params.iteritems():
        if name in kwargs:
          if isinstance(kwargs[name], basestring):
            pvalues = [kwargs[name]]
          else:
            pvalues = kwargs[name]
          for pvalue in pvalues:
            if re.match(regex, pvalue) is None:
              raise TypeError(
                  'Parameter "%s" value "%s" does not match the pattern "%s"' %
                  (name, pvalue, regex))

      for name, enums in enum_params.iteritems():
        if name in kwargs:
          if kwargs[name] not in enums:
            raise TypeError(
                'Parameter "%s" value "%s" is not an allowed value in "%s"' %
                (name, kwargs[name], str(enums)))

      actual_query_params = {}
      actual_path_params = {}
      for key, value in kwargs.iteritems():
        to_type = param_type.get(key, 'string')
        # For repeated parameters we cast each member of the list.
        if key in repeated_params and type(value) == type([]):
          cast_value = [_cast(x, to_type) for x in value]
        else:
          cast_value = _cast(value, to_type)
        if key in query_params:
          actual_query_params[argmap[key]] = cast_value
        if key in path_params:
          actual_path_params[argmap[key]] = cast_value
      body_value = kwargs.get('body', None)
      media_filename = kwargs.get('media_body', None)

      if self._developerKey:
        actual_query_params['key'] = self._developerKey

      headers = {}
      headers, params, query, body = self._model.request(headers,
          actual_path_params, actual_query_params, body_value)

      expanded_url = uritemplate.expand(pathUrl, params)
      url = urlparse.urljoin(self._baseUrl, expanded_url + query)

      resumable = None
      multipart_boundary = ''

      if media_filename:
        # Convert a simple filename into a MediaUpload object.
        if isinstance(media_filename, basestring):
          (media_mime_type, encoding) = mimetypes.guess_type(media_filename)
          if media_mime_type is None:
            raise UnknownFileType(media_filename)
          if not mimeparse.best_match([media_mime_type], ','.join(accept)):
            raise UnacceptableMimeTypeError(media_mime_type)
          media_upload = MediaFileUpload(media_filename, media_mime_type)
        elif isinstance(media_filename, MediaUpload):
          media_upload = media_filename
        else:
          raise TypeError(
              'media_filename must be str or MediaUpload. Got %s' % type(media_filename))

        if media_upload.resumable():
          resumable = media_upload

        # Check the maxSize
        if maxSize > 0 and media_upload.size() > maxSize:
          raise MediaUploadSizeError("Media larger than: %s" % maxSize)

        # Use the media path uri for media uploads
        if media_upload.resumable():
          expanded_url = uritemplate.expand(mediaResumablePathUrl, params)
        else:
          expanded_url = uritemplate.expand(mediaPathUrl, params)
        url = urlparse.urljoin(self._baseUrl, expanded_url + query)

        if body is None:
          # This is a simple media upload
          headers['content-type'] = media_upload.mimetype()
          expanded_url = uritemplate.expand(mediaResumablePathUrl, params)
          if not media_upload.resumable():
            body = media_upload.getbytes(0, media_upload.size())
        else:
          # This is a multipart/related upload.
          msgRoot = MIMEMultipart('related')
          # msgRoot should not write out its own headers
          setattr(msgRoot, '_write_headers', lambda self: None)

          # attach the body as one part
          msg = MIMENonMultipart(*headers['content-type'].split('/'))
          msg.set_payload(body)
          msgRoot.attach(msg)

          # attach the media as the second part
          msg = MIMENonMultipart(*media_upload.mimetype().split('/'))
          msg['Content-Transfer-Encoding'] = 'binary'

          if media_upload.resumable():
            # This is a multipart resumable upload, where a multipart payload
            # looks like this:
            #
            #  --===============1678050750164843052==
            #  Content-Type: application/json
            #  MIME-Version: 1.0
            #
            #  {'foo': 'bar'}
            #  --===============1678050750164843052==
            #  Content-Type: image/png
            #  MIME-Version: 1.0
            #  Content-Transfer-Encoding: binary
            #
            #  <BINARY STUFF>
            #  --===============1678050750164843052==--
            #
            # In the case of resumable multipart media uploads, the <BINARY
            # STUFF> is large and will be spread across multiple PUTs.  What we
            # do here is compose the multipart message with a random payload in
            # place of <BINARY STUFF> and then split the resulting content into
            # two pieces, text before <BINARY STUFF> and text after <BINARY
            # STUFF>. The text after <BINARY STUFF> is the multipart boundary.
            # In apiclient.http the HttpRequest will send the text before
            # <BINARY STUFF>, then send the actual binary media in chunks, and
            # then will send the multipart delimiter.

            payload = hex(random.getrandbits(300))
            msg.set_payload(payload)
            msgRoot.attach(msg)
            body = msgRoot.as_string()
            body, _ = body.split(payload)
            resumable = media_upload
          else:
            payload = media_upload.getbytes(0, media_upload.size())
            msg.set_payload(payload)
            msgRoot.attach(msg)
            body = msgRoot.as_string()

          multipart_boundary = msgRoot.get_boundary()
          headers['content-type'] = ('multipart/related; '
                                     'boundary="%s"') % multipart_boundary

      logging.info('URL being requested: %s' % url)
      return self._requestBuilder(self._http,
                                  self._model.response,
                                  url,
                                  method=httpMethod,
                                  body=body,
                                  headers=headers,
                                  methodId=methodId,
                                  resumable=resumable)

    docs = [methodDesc.get('description', DEFAULT_METHOD_DOC), '\n\n']
    if len(argmap) > 0:
      docs.append('Args:\n')
    for arg in argmap.iterkeys():
      if arg in STACK_QUERY_PARAMETERS:
        continue
      repeated = ''
      if arg in repeated_params:
        repeated = ' (repeated)'
      required = ''
      if arg in required_params:
        required = ' (required)'
      paramdesc = methodDesc['parameters'][argmap[arg]]
      paramdoc = paramdesc.get('description', 'A parameter')
      paramtype = paramdesc.get('type', 'string')
      docs.append('  %s: %s, %s%s%s\n' % (arg, paramtype, paramdoc, required,
                                          repeated))
      enum = paramdesc.get('enum', [])
      enumDesc = paramdesc.get('enumDescriptions', [])
      if enum and enumDesc:
        docs.append('    Allowed values\n')
        for (name, desc) in zip(enum, enumDesc):
          docs.append('      %s - %s\n' % (name, desc))

    setattr(method, '__doc__', ''.join(docs))
    setattr(theclass, methodName, method)

  def createNextMethodFromFuture(theclass, methodName, methodDesc, futureDesc):
    """ This is a legacy method, as only Buzz and Moderator use the future.json
    functionality for generating _next methods. It will be kept around as long
    as those API versions are around, but no new APIs should depend upon it.
    """
    methodName = _fix_method_name(methodName)
    methodId = methodDesc['id'] + '.next'

    def methodNext(self, previous):
      """Retrieve the next page of results.

      Takes a single argument, 'previous', which is the result
      of the last call, and returns the next set of items
      in the collection.

      Returns:
        None if there are no more items in the collection.
      """
      if futureDesc['type'] != 'uri':
        raise UnknownLinkType(futureDesc['type'])

      try:
        p = previous
        for key in futureDesc['location']:
          p = p[key]
        url = p
      except (KeyError, TypeError):
        return None

      url = _add_query_parameter(url, 'key', self._developerKey)

      headers = {}
      headers, params, query, body = self._model.request(headers, {}, {}, None)

      logging.info('URL being requested: %s' % url)
      resp, content = self._http.request(url, method='GET', headers=headers)

      return self._requestBuilder(self._http,
                                  self._model.response,
                                  url,
                                  method='GET',
                                  headers=headers,
                                  methodId=methodId)

    setattr(theclass, methodName, methodNext)

  def createNextMethod(theclass, methodName, methodDesc, futureDesc):
    methodName = _fix_method_name(methodName)
    methodId = methodDesc['id'] + '.next'

    def methodNext(self, previous_request, previous_response):
      """Retrieves the next page of results.

      Args:
        previous_request: The request for the previous page.
        previous_response: The response from the request for the previous page.

      Returns:
        A request object that you can call 'execute()' on to request the next
        page. Returns None if there are no more items in the collection.
      """
      # Retrieve nextPageToken from previous_response
      # Use as pageToken in previous_request to create new request.

      if 'nextPageToken' not in previous_response:
        return None

      request = copy.copy(previous_request)

      pageToken = previous_response['nextPageToken']
      parsed = list(urlparse.urlparse(request.uri))
      q = parse_qsl(parsed[4])

      # Find and remove old 'pageToken' value from URI
      newq = [(key, value) for (key, value) in q if key != 'pageToken']
      newq.append(('pageToken', pageToken))
      parsed[4] = urllib.urlencode(newq)
      uri = urlparse.urlunparse(parsed)

      request.uri = uri

      logging.info('URL being requested: %s' % uri)

      return request

    setattr(theclass, methodName, methodNext)


  # Add basic methods to Resource
  if 'methods' in resourceDesc:
    for methodName, methodDesc in resourceDesc['methods'].iteritems():
      if futureDesc:
        future = futureDesc['methods'].get(methodName, {})
      else:
        future = None
      createMethod(Resource, methodName, methodDesc, future)

  # Add in nested resources
  if 'resources' in resourceDesc:

    def createResourceMethod(theclass, methodName, methodDesc, futureDesc):
      methodName = _fix_method_name(methodName)

      def methodResource(self):
        return createResource(self._http, self._baseUrl, self._model,
                              self._requestBuilder, self._developerKey,
                              methodDesc, futureDesc, schema)

      setattr(methodResource, '__doc__', 'A collection resource.')
      setattr(methodResource, '__is_resource__', True)
      setattr(theclass, methodName, methodResource)

    for methodName, methodDesc in resourceDesc['resources'].iteritems():
      if futureDesc and 'resources' in futureDesc:
        future = futureDesc['resources'].get(methodName, {})
      else:
        future = {}
      createResourceMethod(Resource, methodName, methodDesc, future)

  # Add <m>_next() methods to Resource
  if futureDesc and 'methods' in futureDesc:
    for methodName, methodDesc in futureDesc['methods'].iteritems():
      if 'next' in methodDesc and methodName in resourceDesc['methods']:
        createNextMethodFromFuture(Resource, methodName + '_next',
                         resourceDesc['methods'][methodName],
                         methodDesc['next'])
  # Add _next() methods
  # Look for response bodies in schema that contain nextPageToken, and methods
  # that take a pageToken parameter.
  if 'methods' in resourceDesc:
    for methodName, methodDesc in resourceDesc['methods'].iteritems():
      if 'response' in methodDesc:
        responseSchema = methodDesc['response']
        if '$ref' in responseSchema:
          responseSchema = schema[responseSchema['$ref']]
        hasNextPageToken = 'nextPageToken' in responseSchema.get('properties',
                                                                 {})
        hasPageToken = 'pageToken' in methodDesc.get('parameters', {})
        if hasNextPageToken and hasPageToken:
          createNextMethod(Resource, methodName + '_next',
                           resourceDesc['methods'][methodName],
                           methodName)

  return Resource()
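
The createNextMethod closure above pages through results by copying the previous request, replacing the pageToken query parameter with the nextPageToken from the previous response, and logging the new URL at INFO level. The example is Python 2 (urlparse/urllib); a Python 3 rendering of just the token swap, with a plain string URI standing in for the apiclient request object, could look like the following.

import logging
from urllib.parse import urlparse, parse_qsl, urlencode, urlunparse

def next_page_uri(uri, previous_response):
    """Return the URI for the next page, or None when there is no nextPageToken."""
    token = previous_response.get('nextPageToken')
    if token is None:
        return None
    parsed = list(urlparse(uri))
    # Drop any stale pageToken and append the new one (index 4 is the query string).
    query = [(k, v) for k, v in parse_qsl(parsed[4]) if k != 'pageToken']
    query.append(('pageToken', token))
    parsed[4] = urlencode(query)
    new_uri = urlunparse(parsed)
    logging.info('URL being requested: %s', new_uri)
    return new_uri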

Example 30

Project: dpxdt
Source File: api.py
View license
@app.route('/api/report_run', methods=['POST'])
@auth.build_api_access_required
@utils.retryable_transaction()
def report_run():
    """Reports data for a run for a release candidate."""
    build = g.build
    release, run = _get_or_create_run(build)

    db.session.refresh(run, lockmode='update')

    current_url = request.form.get('url', type=str)
    current_image = request.form.get('image', type=str)
    current_log = request.form.get('log', type=str)
    current_config = request.form.get('config', type=str)

    ref_url = request.form.get('ref_url', type=str)
    ref_image = request.form.get('ref_image', type=str)
    ref_log = request.form.get('ref_log', type=str)
    ref_config = request.form.get('ref_config', type=str)

    diff_failed = request.form.get('diff_failed', type=str)
    diff_image = request.form.get('diff_image', type=str)
    diff_log = request.form.get('diff_log', type=str)

    distortion = request.form.get('distortion', default=None, type=float)
    run_failed = request.form.get('run_failed', type=str)

    if current_url:
        run.url = current_url
    if current_image:
        run.image = current_image
    if current_log:
        run.log = current_log
    if current_config:
        run.config = current_config
    if current_image or current_log or current_config:
        logging.info('Saving run data: build_id=%r, release_name=%r, '
                     'release_number=%d, run_name=%r, url=%r, '
                     'image=%r, log=%r, config=%r, run_failed=%r',
                     build.id, release.name, release.number, run.name,
                     run.url, run.image, run.log, run.config, run_failed)

    if ref_url:
        run.ref_url = ref_url
    if ref_image:
        run.ref_image = ref_image
    if ref_log:
        run.ref_log = ref_log
    if ref_config:
        run.ref_config = ref_config
    if ref_image or ref_log or ref_config:
        logging.info('Saved reference data: build_id=%r, release_name=%r, '
                     'release_number=%d, run_name=%r, ref_url=%r, '
                     'ref_image=%r, ref_log=%r, ref_config=%r',
                     build.id, release.name, release.number, run.name,
                     run.ref_url, run.ref_image, run.ref_log, run.ref_config)

    if diff_image:
        run.diff_image = diff_image
    if diff_log:
        run.diff_log = diff_log
    if distortion:
        run.distortion = distortion

    if diff_image or diff_log:
        logging.info('Saved pdiff: build_id=%r, release_name=%r, '
                     'release_number=%d, run_name=%r, diff_image=%r, '
                     'diff_log=%r, diff_failed=%r, distortion=%r',
                     build.id, release.name, release.number, run.name,
                     run.diff_image, run.diff_log, diff_failed, distortion)

    if run.image and run.diff_image:
        run.status = models.Run.DIFF_FOUND
    elif run.image and run.ref_image and not run.diff_log:
        run.status = models.Run.NEEDS_DIFF
    elif run.image and run.ref_image and not diff_failed:
        run.status = models.Run.DIFF_NOT_FOUND
    elif run.image and not run.ref_config:
        run.status = models.Run.NO_DIFF_NEEDED
    elif run_failed or diff_failed:
        run.status = models.Run.FAILED
    else:
        # NOTE: Intentionally do not transition state here in the default case.
        # We allow multiple background workers to be writing to the same Run in
        # parallel updating its various properties.
        pass

    # TODO: Verify the build has access to both the current_image and
    # the reference_sha1sum so they can't make a diff from a black image
    # and still see private data in the diff image.

    if run.status == models.Run.NEEDS_DIFF:
        task_id = '%s:%s:%s' % (run.id, run.image, run.ref_image)
        logging.info('Enqueuing pdiff task=%r', task_id)

        work_queue.add(
            constants.PDIFF_QUEUE_NAME,
            payload=dict(
                build_id=build.id,
                release_name=release.name,
                release_number=release.number,
                run_name=run.name,
                run_sha1sum=run.image,
                reference_sha1sum=run.ref_image,
            ),
            build_id=build.id,
            release_id=release.id,
            run_id=run.id,
            source='report_run',
            task_id=task_id)

    # Flush the run so querying for Runs in _check_release_done_processing
    # will find the new run too and we won't deadlock.
    db.session.add(run)
    db.session.flush()

    _check_release_done_processing(release)
    db.session.commit()

    signals.run_updated_via_api.send(
        app, build=build, release=release, run=run)

    logging.info('Updated run: build_id=%r, release_name=%r, '
                 'release_number=%d, run_name=%r, status=%r',
                 build.id, release.name, release.number, run.name, run.status)

    return flask.jsonify(success=True)
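
The logging.info calls in this handler pass their values as separate arguments ('... build_id=%r ...', build.id, ...) rather than pre-formatting the message, so the %-interpolation is deferred until the logging framework actually emits the record. A short comparison of the two styles, with made-up values, is below.

import logging

build_id, release_name, run_name = 17, "v2", "home-page"

# Eager: the message string is always built, even when INFO is disabled.
logging.info('Updated run: build_id=%r, release_name=%r, run_name=%r'
             % (build_id, release_name, run_name))

# Lazy: arguments are passed through and interpolated only if the record is emitted.
logging.info('Updated run: build_id=%r, release_name=%r, run_name=%r',
             build_id, release_name, run_name)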

Example 31

Project: SERT
Source File: train.py
View license
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--data',
                        type=argparse_utils.existing_file_path, required=True)
    parser.add_argument('--meta',
                        type=argparse_utils.existing_file_path, required=True)

    parser.add_argument('--type', choices=MODELS, required=True)

    parser.add_argument('--iterations',
                        type=argparse_utils.positive_int, default=1)

    parser.add_argument('--batch_size',
                        type=argparse_utils.positive_int, default=1024)

    parser.add_argument('--word_representation_size',
                        type=argparse_utils.positive_int, default=300)
    parser.add_argument('--representation_initializer',
                        type=argparse_utils.existing_file_path, default=None)

    # Specific to VectorSpaceLanguageModel.
    parser.add_argument('--entity_representation_size',
                        type=argparse_utils.positive_int, default=None)
    parser.add_argument('--num_negative_samples',
                        type=argparse_utils.positive_int, default=None)
    parser.add_argument('--one_hot_classes',
                        action='store_true',
                        default=False)

    parser.add_argument('--regularization_lambda',
                        type=argparse_utils.ratio, default=0.01)

    parser.add_argument('--model_output', type=str, required=True)

    args = parser.parse_args()

    if args.entity_representation_size is None:
        args.entity_representation_size = args.word_representation_size

    args.type = MODELS[args.type]

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    logging_utils.log_module_info(theano, lasagne, np, scipy)

    # Load data.
    logging.info('Loading data from %s.', args.data)
    data_sets = np.load(args.data)

    if 'w_train' in data_sets and not args.ignore_weights:
        w_train = data_sets['w_train']
    else:
        logging.warning('No weights found in data set; '
                        'assuming uniform instance weighting.')

        w_train = np.ones(data_sets['x_train'].shape[0], dtype=np.float32)

    training_set = (data_sets['x_train'], data_sets['y_train'][()], w_train)
    validation_set = (data_sets['x_validate'], data_sets['y_validate'][()])

    logging.info('Training instances: %s (%s) %s (%s) %s (%s)',
                 training_set[0].shape, training_set[0].dtype,
                 training_set[1].shape, training_set[1].dtype,
                 training_set[2].shape, training_set[2].dtype)
    logging.info('Validation instances: %s (%s) %s (%s)',
                 validation_set[0].shape, validation_set[0].dtype,
                 validation_set[1].shape, validation_set[1].dtype)

    num_entities = training_set[1].shape[1]
    assert num_entities > 1

    if args.one_hot_classes:
        logging.info('Transforming y-values to one-hot values.')

        if not scipy.sparse.issparse(training_set[1]) or \
           not scipy.sparse.issparse(validation_set[1]):
            raise RuntimeError(
                'Argument --one_hot_classes expects sparse truth values.')

        y_train, (x_train, w_train) = sparse_to_one_hot_multiple(
            training_set[1], training_set[0], training_set[2])

        training_set = (x_train, y_train, w_train)

        y_validate, (x_validate,) = sparse_to_one_hot_multiple(
            validation_set[1], validation_set[0])

        validation_set = (x_validate, y_validate)

    logging.info('Loading meta-data from %s.', args.meta)
    with open(args.meta, 'rb') as f:
        # We do not load the remaining of the vocabulary.
        data_args, words, tokens = (pickle.load(f) for _ in range(3))

        vocabulary_size = len(words)

    representations = lasagne.init.GlorotUniform().sample(
        (vocabulary_size, args.word_representation_size))

    if args.representation_initializer:
        # This way of creating the dictionary ignores duplicate words in
        # the representation initializer.
        representation_lookup = dict(
            embedding_utils.load_binary_representations(
                args.representation_initializer, tokens))

        representation_init_count = 0

        for word, meta in words.items():
            if word.lower() in representation_lookup:
                representations[meta.id] = \
                    representation_lookup[word.lower()]

                representation_init_count += 1

        logging.info('Initialized representations from '
                     'pre-learned collection for %d words (%.2f%%).',
                     representation_init_count,
                     (representation_init_count /
                      float(len(words))) * 100.0)

    # Allow GC to clear memory.
    del words
    del tokens

    model_options = {
        'batch_size': args.batch_size,
        'window_size': data_args.window_size,
        'representations_init': representations,
        'regularization_lambda': args.regularization_lambda,
        'training_set': training_set,
        'validation_set': validation_set,
    }

    if args.type == models.LanguageModel:
        model_options.update(
            output_layer_size=num_entities)
    elif args.type == models.VectorSpaceLanguageModel:
        entity_representations = lasagne.init.GlorotUniform().sample(
            (num_entities, args.entity_representation_size))

        model_options.update(
            entity_representations_init=entity_representations,
            num_negative_samples=args.num_negative_samples)

    # Construct neural net.
    model = args.type(**model_options)

    train(model, args.iterations, args.model_output,
          abort_threshold=1e-5,
          early_stopping=False,
          additional_args=[args])
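
Example 31 hands logging setup to logging_utils.configure_logging(args), which lives in the SERT code base and is not shown here. A minimal stand-in that honours the same --loglevel string argument could use logging.basicConfig; the format string below is an arbitrary choice.

import argparse
import logging

def configure_logging(args):
    """Hypothetical stand-in: map the --loglevel string onto a root logger level."""
    level = getattr(logging, args.loglevel.upper(), logging.INFO)
    logging.basicConfig(
        level=level,
        format='%(asctime)s %(levelname)s %(name)s: %(message)s')

parser = argparse.ArgumentParser()
parser.add_argument('--loglevel', type=str, default='INFO')
args = parser.parse_args([])  # empty list: use defaults for this sketch
configure_logging(args)
logging.info('Logging configured at level %s.', args.loglevel)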

Example 32

Project: csvdedupe
Source File: csvlink.py
View license
    def main(self):

        data_1 = {}
        data_2 = {}
        # import the specified CSV file

        data_1 = csvhelpers.readData(self.input_1, self.field_names_1,
                                     prefix='input_1')
        data_2 = csvhelpers.readData(self.input_2, self.field_names_2,
                                     prefix='input_2')

        # sanity check for provided field names in CSV file
        for field in self.field_names_1:
            if field not in list(data_1.values())[0]:
                raise self.parser.error(
                    "Could not find field '" + field + "' in input")

        for field in self.field_names_2:
            if field not in list(data_2.values())[0]:
                raise self.parser.error(
                    "Could not find field '" + field + "' in input")

        if self.field_names_1 != self.field_names_2:
            for record_id, record in data_2.items():
                remapped_record = {}
                for new_field, old_field in zip(self.field_names_1,
                                                self.field_names_2):
                    remapped_record[new_field] = record[old_field]
                data_2[record_id] = remapped_record

        logging.info('imported %d rows from file 1', len(data_1))
        logging.info('imported %d rows from file 2', len(data_2))

        logging.info('using fields: %s' % [field['field']
                                           for field in self.field_definition])

        # If --skip_training has been selected, and we have a settings cache still
        # persisting from the last run, use it in this next run.
        # __Note:__ if you want to add more training data, don't use skip training
        if self.skip_training and os.path.exists(self.settings_file):

            # Load our deduper from the last training session cache.
            logging.info('reading from previous training cache %s'
                                                          % self.settings_file)
            with open(self.settings_file, 'rb') as f:
                deduper = dedupe.StaticRecordLink(f)
                

            fields = {variable.field for variable in deduper.data_model.primary_fields}
            (nonexact_1,
             nonexact_2,
             exact_pairs) = exact_matches(data_1, data_2, fields)
            
            
        else:
            # # Create a new deduper object and pass our data model to it.
            deduper = dedupe.RecordLink(self.field_definition)

            fields = {variable.field for variable in deduper.data_model.primary_fields}
            (nonexact_1,
             nonexact_2,
             exact_pairs) = exact_matches(data_1, data_2, fields)

            # Set up our data sample
            logging.info('taking a sample of %d possible pairs', self.sample_size)
            deduper.sample(nonexact_1, nonexact_2, self.sample_size)

            # Perform standard training procedures
            self.dedupe_training(deduper)

        # ## Blocking

        logging.info('blocking...')

        # ## Clustering

        # Find the threshold that will maximize a weighted average of our precision and recall. 
        # When we set the recall weight to 2, we are saying we care twice as much
        # about recall as we do precision.
        #
        # If we had more data, we would not pass in all the blocked data into
        # this function but a representative sample.

        logging.info('finding a good threshold with a recall_weight of %s' %
                     self.recall_weight)
        threshold = deduper.threshold(data_1, data_2,
                                      recall_weight=self.recall_weight)

        # `duplicateClusters` will return sets of record IDs that dedupe
        # believes are all referring to the same entity.

        logging.info('clustering...')
        clustered_dupes = deduper.match(data_1, data_2, threshold)

        clustered_dupes.extend(exact_pairs)

        logging.info('# duplicate sets %s' % len(clustered_dupes))

        write_function = csvhelpers.writeLinkedResults
        # write out our results

        if self.output_file:
            if sys.version < '3' :
                with open(self.output_file, 'wb') as output_file:
                    write_function(clustered_dupes, self.input_1, self.input_2,
                                   output_file, self.inner_join)
            else :
                with open(self.output_file, 'w') as output_file:
                    write_function(clustered_dupes, self.input_1, self.input_2,
                                   output_file, self.inner_join)
        else:
            write_function(clustered_dupes, self.input_1, self.input_2,
                           sys.stdout, self.inner_join)
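
The csvlink example branches on --skip_training plus an existing settings file: it either reloads a StaticRecordLink from the training cache or trains a fresh RecordLink, logging which path it took. A stripped-down sketch of that decision is below; it assumes the dedupe package used by csvdedupe is installed and leaves the training step to a caller-supplied callable.

import logging
import os

import dedupe

def load_or_train(settings_file, field_definition, skip_training, train):
    """Reload a cached linker when allowed, otherwise train a new one."""
    if skip_training and os.path.exists(settings_file):
        logging.info('reading from previous training cache %s', settings_file)
        with open(settings_file, 'rb') as f:
            return dedupe.StaticRecordLink(f)
    logging.info('training a new model with %d field definitions', len(field_definition))
    linker = dedupe.RecordLink(field_definition)
    train(linker)  # caller-supplied training routine, e.g. self.dedupe_training above
    return linker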

Example 33

Project: dcos
Source File: test_3dt.py
View license
@retrying.retry(wait_fixed=2000, stop_max_delay=LATENCY * 1000)
def test_3dt_bundle_download_and_extract(cluster):
    """
    test bundle download and validate zip file
    """

    bundles = _get_bundle_list(cluster)
    assert bundles

    expected_common_files = ['dmesg-0.output.gz', 'opt/mesosphere/active.buildinfo.full.json.gz', '3dt-health.json']

    # these files are expected to be in archive for a master host
    expected_master_files = ['dcos-mesos-master.service.gz'] + expected_common_files

    # for agent host
    expected_agent_files = ['dcos-mesos-slave.service.gz'] + expected_common_files

    # for public agent host
    expected_public_agent_files = ['dcos-mesos-slave-public.service.gz'] + expected_common_files

    def _read_from_zip(z: zipfile.ZipFile, item: str, to_json=True):
        # raises KeyError if item is not in zipfile.
        item_content = z.read(item).decode()

        if to_json:
            # raises ValueError if cannot deserialize item_content.
            return json.loads(item_content)

        return item_content

    def _get_3dt_health(z: zipfile.ZipFile, item: str):
        # try to load 3dt health report and validate the report is for this host
        try:
            _health_report = _read_from_zip(z, item)
        except KeyError:
            # the expected entry was not found in the archive; log the archive contents
            # and try to read the diagnostics job logs instead.

            # namelist() gets a list of all items in a zip archive.
            logging.info(z.namelist())

            # summaryErrorsReport.txt and summaryReport.txt are diagnostic job log files.
            for log in ('summaryErrorsReport.txt', 'summaryReport.txt'):
                try:
                    log_data = _read_from_zip(z, log, to_json=False)
                    logging.info("{}:\n{}".format(log, log_data))
                except KeyError:
                    logging.info("Could not read {}".format(log))
            raise

        except ValueError:
            logging.info("Could not deserialize 3dt-health")
            raise

        return _health_report

    with tempfile.TemporaryDirectory() as tmp_dir:
        download_base_url = '/system/health/v1/report/diagnostics/serve'
        for bundle in bundles:
            bundle_full_location = os.path.join(tmp_dir, bundle)
            with open(bundle_full_location, 'wb') as f:
                r = cluster.get(path=os.path.join(download_base_url, bundle), stream=True)
                for chunk in r.iter_content(1024):
                    f.write(chunk)

            # validate bundle zip file.
            assert zipfile.is_zipfile(bundle_full_location)
            z = zipfile.ZipFile(bundle_full_location)

            # get a list of all files in a zip archive.
            archived_items = z.namelist()

            # make sure all required log files for master node are in place.
            for master_ip in cluster.masters:
                master_folder = master_ip + '_master/'

                # try to load 3dt health report and validate the report is for this host
                health_report = _get_3dt_health(z, master_folder + '3dt-health.json')
                assert 'ip' in health_report
                assert health_report['ip'] == master_ip

                # make sure systemd unit output is correct and does not contain error message
                gzipped_unit_output = z.open(master_folder + 'dcos-mesos-master.service.gz')
                verify_unit_response(gzipped_unit_output)

                for expected_master_file in expected_master_files:
                    expected_file = master_folder + expected_master_file
                    assert expected_file in archived_items, 'expecting {} in {}'.format(expected_file, archived_items)

            # make sure all required log files for agent node are in place.
            for slave_ip in cluster.slaves:
                agent_folder = slave_ip + '_agent/'

                # try to load 3dt health report and validate the report is for this host
                health_report = _get_3dt_health(z, agent_folder + '3dt-health.json')
                assert 'ip' in health_report
                assert health_report['ip'] == slave_ip

                # make sure systemd unit output is correct and does not contain error message
                gzipped_unit_output = z.open(agent_folder + 'dcos-mesos-slave.service.gz')
                verify_unit_response(gzipped_unit_output)

                for expected_agent_file in expected_agent_files:
                    expected_file = agent_folder + expected_agent_file
                    assert expected_file in archived_items, 'expecting {} in {}'.format(expected_file, archived_items)

            # make sure all required log files for public agent node are in place.
            for public_slave_ip in cluster.public_slaves:
                agent_public_folder = public_slave_ip + '_agent_public/'

                # try to load 3dt health report and validate the report is for this host
                health_report = _get_3dt_health(z, agent_public_folder + '3dt-health.json')
                assert 'ip' in health_report
                assert health_report['ip'] == public_slave_ip

                # make sure systemd unit output is correct and does not contain error message
                gzipped_unit_output = z.open(agent_public_folder + 'dcos-mesos-slave-public.service.gz')
                verify_unit_response(gzipped_unit_output)

                for expected_public_agent_file in expected_public_agent_files:
                    expected_file = agent_public_folder + expected_public_agent_file
                    assert expected_file in archived_items, ('expecting {} in {}'.format(expected_file, archived_items))
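
The _get_3dt_health helper above falls back to logging the archive's namelist() and the diagnostics job logs when the expected JSON entry is missing. The same read-and-log-on-failure idea, reusable against any zipfile.ZipFile, can be compressed into a small helper like this.

import json
import logging
import zipfile

def read_json_member(z: zipfile.ZipFile, member: str):
    """Read a JSON member from the archive, logging the contents list if it is missing."""
    try:
        return json.loads(z.read(member).decode())
    except KeyError:
        logging.info("%s not found; archive contains: %s", member, z.namelist())
        raise
    except ValueError:
        logging.info("Could not deserialize %s", member)
        raise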

Example 34

Project: django-extensions
Source File: reset_db.py
View license
    @signalcommand
    def handle(self, *args, **options):
        """
        Resets the database for this project.

        Note: Transaction wrappers are in reverse as a workaround for
        autocommit; anybody know how to do this the right way?
        """

        if args:
            raise CommandError("reset_db takes no arguments")

        router = options.get('router')
        dbinfo = settings.DATABASES.get(router)
        if dbinfo is None:
            raise CommandError("Unknown database router %s" % router)

        engine = dbinfo.get('ENGINE').split('.')[-1]

        user = password = database_name = database_host = database_port = ''
        if engine == 'mysql':
            (user, password, database_name, database_host, database_port) = parse_mysql_cnf(dbinfo)

        user = options.get('user') or dbinfo.get('USER') or user
        password = options.get('password') or dbinfo.get('PASSWORD') or password
        owner = options.get('owner') or user

        database_name = options.get('dbname') or dbinfo.get('NAME') or database_name
        if database_name == '':
            raise CommandError("You need to specify DATABASE_NAME in your Django settings file.")

        database_host = dbinfo.get('HOST') or database_host
        database_port = dbinfo.get('PORT') or database_port

        verbosity = int(options.get('verbosity', 1))
        if options.get('interactive'):
            confirm = input("""
You have requested a database reset.
This will IRREVERSIBLY DESTROY
ALL data in the database "%s".
Are you sure you want to do this?

Type 'yes' to continue, or 'no' to cancel: """ % (database_name,))
        else:
            confirm = 'yes'

        if confirm != 'yes':
            print("Reset cancelled.")
            return

        if engine in ('sqlite3', 'spatialite'):
            import os
            try:
                logging.info("Unlinking %s database" % engine)
                os.unlink(database_name)
            except OSError:
                pass

        elif engine in ('mysql',):
            import MySQLdb as Database
            kwargs = {
                'user': user,
                'passwd': password,
            }
            if database_host.startswith('/'):
                kwargs['unix_socket'] = database_host
            else:
                kwargs['host'] = database_host

            if database_port:
                kwargs['port'] = int(database_port)

            connection = Database.connect(**kwargs)
            drop_query = 'DROP DATABASE IF EXISTS `%s`' % database_name
            utf8_support = options.get('no_utf8_support', False) and '' or 'CHARACTER SET utf8'
            create_query = 'CREATE DATABASE `%s` %s' % (database_name, utf8_support)
            logging.info('Executing... "' + drop_query + '"')
            connection.query(drop_query)
            logging.info('Executing... "' + create_query + '"')
            connection.query(create_query)

        elif engine in ('postgresql', 'postgresql_psycopg2', 'postgis'):
            if engine == 'postgresql' and django.VERSION < (1, 9):
                import psycopg as Database  # NOQA
            elif engine in ('postgresql', 'postgresql_psycopg2', 'postgis'):
                import psycopg2 as Database  # NOQA

            conn_params = {'database': 'template1'}
            if user:
                conn_params['user'] = user
            if password:
                conn_params['password'] = password
            if database_host:
                conn_params['host'] = database_host
            if database_port:
                conn_params['port'] = database_port

            connection = Database.connect(**conn_params)
        connection.set_isolation_level(0)  # 0 == ISOLATION_LEVEL_AUTOCOMMIT
            cursor = connection.cursor()

            if options.get('close_sessions'):
                close_sessions_query = """
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = '%s';
                """ % database_name
                logging.info('Executing... "' + close_sessions_query.strip() + '"')
                try:
                    cursor.execute(close_sessions_query)
                except Database.ProgrammingError as e:
                    logging.exception("Error: %s" % str(e))

            drop_query = "DROP DATABASE \"%s\";" % database_name
            logging.info('Executing... "' + drop_query + '"')
            try:
                cursor.execute(drop_query)
            except Database.ProgrammingError as e:
                logging.exception("Error: %s" % str(e))

            create_query = "CREATE DATABASE \"%s\"" % database_name
            if owner:
                create_query += " WITH OWNER = \"%s\" " % owner
            create_query += " ENCODING = 'UTF8'"

            if engine == 'postgis' and django.VERSION < (1, 9):
                # For PostGIS 1.5, fetch template name if it exists
                from django.contrib.gis.db.backends.postgis.base import DatabaseWrapper
                postgis_template = DatabaseWrapper(dbinfo).template_postgis
                if postgis_template is not None:
                    create_query += ' TEMPLATE = %s' % postgis_template

            if settings.DEFAULT_TABLESPACE:
                create_query += ' TABLESPACE = %s;' % settings.DEFAULT_TABLESPACE
            else:
                create_query += ';'

            logging.info('Executing... "' + create_query + '"')
            cursor.execute(create_query)

        else:
            raise CommandError("Unknown database engine %s" % engine)

        if verbosity >= 2 or options.get('interactive'):
            print("Reset successful.")

Example 35

View license
    def get_results(self):
        """Execute API calls to the timeseries data and tweet data we need for analysis. Perform analysis
        as we go because we often need results for next steps."""
        ######################
        # (1) Get the timeline
        ######################
        logging.info("retrieving timeline counts")
        results_timeseries = Results( self.user
            , self.password
            , self.stream_url
            , self.options.paged
            , self.options.output_file_path
            , pt_filter=self.options.filter
            , max_results=int(self.options.max)
            , start=self.options.start
            , end=self.options.end
            , count_bucket=self.options.count_bucket
            , show_query=self.options.query
            , search_v2=self.options.search_v2
            )
        # sort by date
        res_timeseries = sorted(results_timeseries.get_time_series(), key = itemgetter(0))
        # calculate total time interval span
        time_min_date = min(res_timeseries, key = itemgetter(2))[2]
        time_max_date = max(res_timeseries, key = itemgetter(2))[2]
        time_min = float(calendar.timegm(time_min_date.timetuple()))
        time_max = float(calendar.timegm(time_max_date.timetuple()))
        time_span = time_max - time_min
        logging.debug("time_min = {}, time_max = {}, time_span = {}".format(time_min, time_max, time_span))
        # create a simple object to hold our data 
        ts = TimeSeries()
        ts.dates = []
        ts.x = []
        ts.counts = []
        # load and format data
        for i in res_timeseries:
            ts.dates.append(i[2])
            ts.counts.append(float(i[1]))
            # create an independent variable in the interval [0.0,1.0]
            ts.x.append((calendar.timegm(datetime.datetime.strptime(i[0], DATE_FMT).timetuple()) - time_min)/time_span)
        logging.info("read {} time items from search API".format(len(ts.dates)))
        if len(ts.dates) < 35:
            logging.warn("peak detection with with fewer than ~35 points is unreliable!")
        logging.debug('dates: ' + ','.join(map(str, ts.dates[:10])) + "...")
        logging.debug('counts: ' + ','.join(map(str, ts.counts[:10])) + "...")
        logging.debug('indep var: ' + ','.join(map(str, ts.x[:10])) + "...")
        ######################
        # (1.1) Get a second timeline?
        ######################
        if self.options.second_filter is not None:
            logging.info("retrieving second timeline counts")
            results_timeseries = Results( self.user
                , self.password
                , self.stream_url
                , self.options.paged
                , self.options.output_file_path
                , pt_filter=self.options.second_filter
                , max_results=int(self.options.max)
                , start=self.options.start
                , end=self.options.end
                , count_bucket=self.options.count_bucket
                , show_query=self.options.query
                , search_v2=self.options.search_v2
                )
            # sort by date
            second_res_timeseries = sorted(results_timeseries.get_time_series(), key = itemgetter(0))
            if len(second_res_timeseries) != len(res_timeseries):
                logging.error("time series of different sizes not allowed")
            else:
                ts.second_counts = []
                # load and format data
                for i in second_res_timeseries:
                    ts.second_counts.append(float(i[1]))
                logging.info("read {} time items from search API".format(len(ts.second_counts)))
                logging.debug('second counts: ' + ','.join(map(str, ts.second_counts[:10])) + "...")
        ######################
        # (2) Detrend and remove prominent period
        ######################
        logging.info("detrending timeline counts")
        no_trend = signal.detrend(np.array(ts.counts))
        # determine period of data
        df = (ts.dates[1] - ts.dates[0]).total_seconds()
        if df == 86400:
            # day counts, average over week
            n_buckets = 7
            n_avgs = {i:[] for i in range(n_buckets)}
            for t,c in zip(ts.dates, no_trend):
                n_avgs[t.weekday()].append(c)
        elif df == 3600:
            # hour counts, average over day
            n_buckets = 24
            n_avgs = {i:[] for i in range(n_buckets)}
            for t,c in zip(ts.dates, no_trend):
                n_avgs[t.hour].append(c)
        elif df == 60:
            # minute counts; average over day
            n_buckets = 24*60
            n_avgs = {i:[] for i in range(n_buckets)}
            for t,c in zip(ts.dates, no_trend):
                n_avgs[t.minute].append(c)
        else:
            sys.stderr.write("Weird interval problem! Exiting.\n")
            logging.error("Weird interval problem! Exiting.\n")
            sys.exit()
        logging.info("averaging over periods of {} buckets".format(n_buckets))
        # remove upper outliers from averages 
        df_avg_all = {i:np.average(n_avgs[i]) for i in range(n_buckets)}
        logging.debug("bucket averages: {}".format(','.join(map(str, [df_avg_all[i] for i in df_avg_all]))))
        n_avgs_remove_outliers = {i: [j for j in n_avgs[i] 
            if  abs(j - df_avg_all[i])/df_avg_all[i] < (1. + OUTLIER_FRAC) ]
            for i in range(n_buckets)}
        df_avg = {i:np.average(n_avgs_remove_outliers[i]) for i in range(n_buckets)}
        logging.debug("bucket averages w/o outliers: {}".format(','.join(map(str, [df_avg[i] for i in df_avg]))))

        # flatten cycle
        ts.counts_no_cycle_trend = np.array([no_trend[i] - df_avg[ts.dates[i].hour] for i in range(len(ts.counts))])
        logging.debug('no trend: ' + ','.join(map(str, ts.counts_no_cycle_trend[:10])) + "...")

        ######################
        # (3) Moving average 
        ######################
        ts.moving = np.convolve(ts.counts, np.ones((N_MOVING,))/N_MOVING, mode='valid')
        logging.debug('moving ({}): '.format(N_MOVING) + ','.join(map(str, ts.moving[:10])) + "...")

        ######################
        # (4) Peak detection
        ######################
        peakind = signal.find_peaks_cwt(ts.counts_no_cycle_trend, np.arange(MIN_PEAK_WIDTH, MAX_PEAK_WIDTH), min_snr = MIN_SNR)
        n_peaks = min(MAX_N_PEAKS, len(peakind))
        logging.debug('peaks ({}): '.format(n_peaks) + ','.join(map(str, peakind)))
        logging.debug('peaks ({}): '.format(n_peaks) + ','.join(map(str, [ts.dates[i] for i in peakind])))
        
        # top peaks determined by peak volume, better way?
        # peak detector algorithm:
        #      * middle of peak (of unknown width)
        #      * finds peaks up to MAX_PEAK_WIDTH wide
        #
        #   algorithm for getting peak start, peak and end parameters:
        #      find max, find fwhm, 
        #      find start, step past peak, keep track of volume and peak height, 
        #      stop at end of period or when timeseries turns upward
    
        peaks = []
        for i in peakind:
            # find the first max in the possible window
            i_start = max(0, i - SEARCH_PEAK_WIDTH)
            i_finish = min(len(ts.counts) - 1, i + SEARCH_PEAK_WIDTH)
            p_max = max(ts.counts[i_start:i_finish])
            h_max = p_max/2.
            # i_max not center
            i_max = i_start + ts.counts[i_start:i_finish].index(p_max)
            i_start, i_finish = i_max, i_max
            # start at peak, and go back and forward to find start and end
            while i_start >= 1:
                if (ts.counts[i_start - 1] <= h_max or 
                        ts.counts[i_start - 1] >= ts.counts[i_start] or
                        i_start - 1 <= 0):
                    break
                i_start -= 1
            while i_finish < len(ts.counts) - 1:
                if (ts.counts[i_finish + 1] <= h_max or
                        ts.counts[i_finish + 1] >= ts.counts[i_finish] or
                        i_finish + 1 >= len(ts.counts)):
                    break
                i_finish += 1
            # i is center of peak so balance window
            delta_i = max(1, i - i_start)
            if i_finish - i > delta_i:
                delta_i = i_finish - i
            # final est of start and finish
            i_finish = min(len(ts.counts) - 1, i + delta_i)
            i_start = max(0, i - delta_i)
            p_volume = sum(ts.counts[i_start:i_finish])
            peaks.append([ i , p_volume , (i, i_start, i_max, i_finish
                                            , h_max  , p_max, p_volume
                                            , ts.dates[i_start], ts.dates[i_max], ts.dates[i_finish])])
        # top n_peaks by volume
        top_peaks = sorted(peaks, key = itemgetter(1))[-n_peaks:]
        # re-sort peaks by date
        ts.top_peaks = sorted(top_peaks, key = itemgetter(0))
        logging.debug('top peaks ({}): '.format(len(ts.top_peaks)) + ','.join(map(str, ts.top_peaks[:4])) + "...")
    
        ######################
        # (5) high/low frequency 
        ######################
        ts.cycle, ts.trend = sm.tsa.filters.hpfilter(np.array(ts.counts))
        logging.debug('cycle: ' + ','.join(map(str, ts.cycle[:10])) + "...")
        logging.debug('trend: ' + ','.join(map(str, ts.trend[:10])) + "...")
    
        ######################
        # (6) n-grams for top peaks
        ######################
        ts.topics = []
        if self.options.get_topics:
            logging.info("retrieving tweets for peak topics")
            for a in ts.top_peaks:
                # start at peak
                ds = datetime.datetime.strftime(a[2][8], DATE_FMT2)
                # estimate how long to get TWEET_SAMPLE tweets
                # a[2][5] is max tweets per period
                if a[2][5] > 0:
                    est_periods = float(TWEET_SAMPLE)/a[2][5]
                else:
                    logging.warn("peak with zero max tweets ({}), setting est_periods to 1".format(a))
                    est_periods = 1
                # df comes from above, in seconds
                # time resolution is hours
                est_time = max(int(est_periods * df), 60)
                logging.debug("est_periods={}, est_time={}".format(est_periods, est_time))
                #
                if a[2][8] + datetime.timedelta(seconds=est_time) < a[2][9]:
                    de = datetime.datetime.strftime(a[2][8] + datetime.timedelta(seconds=est_time), DATE_FMT2)
                elif a[2][8] < a[2][9]:
                    de = datetime.datetime.strftime(a[2][9], DATE_FMT2)
                else:
                    de = datetime.datetime.strftime(a[2][8] + datetime.timedelta(seconds=60), DATE_FMT2)
                logging.info("retreive data for peak index={} in date range [{},{}]".format(a[0], ds, de))
                res = Results(
                    self.user
                    , self.password
                    , self.stream_url
                    , self.options.paged
                    , self.options.output_file_path
                    , pt_filter=self.options.filter
                    , max_results=int(self.options.max)
                    , start=ds
                    , end=de
                    , count_bucket=None
                    , show_query=self.options.query
                    , search_v2=self.options.search_v2
                    , hard_max = TWEET_SAMPLE
                    )
                logging.info("retrieved {} records".format(len(res)))
                n_grams_counts = list(res.get_top_grams(n=self.token_list_size))
                ts.topics.append(n_grams_counts)
                logging.debug('n_grams for peak index={}: '.format(a[0]) + ','.join(
                    map(str, [i[4].encode("utf-8","ignore") for i in n_grams_counts][:10])) + "...")
        return ts
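
Example 35 interleaves logging.info progress messages with logging.debug dumps of intermediate arrays while it detrends the counts, smooths them with a moving average, and runs wavelet peak detection. Here is a compressed sketch of that pipeline on synthetic data; the window size and peak widths stand in for the project's module constants (N_MOVING, MIN_PEAK_WIDTH, MAX_PEAK_WIDTH), which are defined elsewhere in its source.

import logging
import numpy as np
from scipy import signal

logging.basicConfig(level=logging.DEBUG, format="%(levelname)s %(message)s")

# Synthetic hourly counts with two bumps (illustrative data only).
counts = np.concatenate([np.ones(50), 5 * np.ones(10), np.ones(50),
                         8 * np.ones(10), np.ones(50)])

no_trend = signal.detrend(counts)                       # remove linear trend
window = 5                                              # stand-in for N_MOVING
moving = np.convolve(no_trend, np.ones(window) / window, mode="valid")
peak_idx = signal.find_peaks_cwt(moving, np.arange(3, 20))

logging.info("read {} count buckets".format(len(counts)))
logging.debug("moving ({}): ".format(window) + ",".join(map(str, moving[:10])) + "...")
logging.info("detected {} candidate peaks at indices {}".format(len(peak_idx), list(peak_idx)))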

Example 36

Project: MyLife
Source File: dailymail.py
View license
	def send(self, is_intro_email=False, force_send=False, date=None):

		try:
			now = datetime.datetime.now()
			settings = Settings.get()

			if is_intro_email:
				current_time = now
				logging.info('Sending intro email to %s' % settings.email_address)
			else:
				current_time, id, name, offset = self.get_time_in_timezone(settings)

				if current_time.hour != settings.email_hour and not force_send:
					logging.info('Current time for %s is %s, not sending email now, will send at %02d:00' % (name, current_time, settings.email_hour))
					return


			if date and force_send:
				today = date #Allow overriding this stuff
			else:
				today = current_time.date()
			
			if self.check_if_intro_email_sent_today(today) and not force_send:
				logging.info('Already sent the intro email today, skipping this email for now')	
				return	


			#Check if we've already sent an email
			slug = Slug.query(Slug.date == today).get()

			if slug and not force_send:
				msg = 'Tried to send another email on %s, already sent %s' % (date, slug.slug)
				log_error('Tried to send email again', msg)
				raise Exception(msg)

			if not slug:
				slug_id = self.get_slug()

				slug = Slug(slug=slug_id, date=today)

				slug.put()

			subject = "It's %s, %s %s - How did your day go?" % (today.strftime('%A'), today.strftime("%b"), today.day)
			app_id = app_identity.get_application_id()

			sender = "MyLife <post+%[email protected]%s.appspotmail.com>" % (slug.slug, app_id)

			message = mail.EmailMessage(sender=sender, subject=subject)

			message.to = settings.email_address
			if not settings.email_address:
				log_error('Missing To in daily email', 'There is no configured email address in your settings. Please visit your settings page to configure it so we can send you your daily email.')
				return

			message.body = """
Just reply to this email with your entry.

OLD_POST

	""".replace('APP_ID', app_id)

			message.html = """
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN" "http://www.=
w3.org/TR/html4/loose.dtd">
<html>
	<head>
		<title></title>
	</head>
	<body>
	Just reply to this email with your entry.
	<br>
	<br>
OLD_POST
	</body>
</html>
""".replace('APP_ID', app_id)

			if is_intro_email:
				intro_msg = "Welcome to MyLife. We've sent you this email immediately so you can try the system out. In the future we will email you once a day and include an old post in each email. You can configure when you get your email and which email address we should use on the settings page."
				message.html = message.html.replace('OLD_POST', intro_msg + '<br><br>')
				message.body = message.body.replace('OLD_POST', intro_msg + '\r\n\r\n')
			else:
				#Lets try to put in an old post...
				old_post, old_type = self.get_old_post(today)

				if old_post and settings.include_old_post_in_entry:
					old_post_text = 'Remember this? One %s ago you wrote:\r\n\r\n' % old_type
					old_post_text += old_post.text.rstrip() + '\r\n\r\n'

					message.body = re.sub(r'OLD_POST\r?\n', old_post_text, message.body)

					old_post_text = re.sub(r'\r?\n', '<br>', old_post_text)
					message.html = re.sub(r'OLD_POST\r?\n', old_post_text, message.html)
				else:
					message.body = re.sub('OLD_POST\r?\n', '', message.body)
					message.html = re.sub('OLD_POST\r?\n', '', message.html)

			message.send()
			
			if is_intro_email:
				logging.info('Sent intro email')
			else:
				if old_post:
					logging.info('Sent daily email to %s, using old post from %s' % (message.to, old_post.date))
				else:
					logging.info('Sent daily email to %s, could not find old post' % message.to)

			return 'Email sent'
		except:
			log_error('Failed to send daily email', traceback.format_exc(6))
			return 'Failed sending email: %s' % traceback.format_exc(6)
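
The send() method above relies on logging.info both to explain why an email was skipped (wrong hour, already sent) and to confirm what was sent, while the outer try/except logs the full traceback on failure. A small stand-alone sketch of that decision logging, with the App Engine mail call replaced by a placeholder and an invented address:

import datetime
import logging
import traceback

logging.basicConfig(level=logging.INFO)

def send_daily_email(address, email_hour=20):
    try:
        now = datetime.datetime.now()
        if now.hour != email_hour:
            logging.info('Current time is %s, not sending email now, will send at %02d:00'
                         % (now, email_hour))
            return
        # Placeholder for the actual mail.EmailMessage(...).send() call.
        logging.info('Sent daily email to %s' % address)
        return 'Email sent'
    except Exception:
        logging.error('Failed to send daily email\n%s' % traceback.format_exc())
        return 'Failed sending email'

send_daily_email('user@example.com')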

Example 37

Project: GrepBugs
Source File: grepbugs.py
View license
def local_scan(srcdir, repo='none', account='local_scan', project='none', default_branch='none', no_reports=False):
	"""
	Perform a scan of local files
	"""
	# new scan so new scan_id
	scan_id = str(uuid.uuid1())
	clocsql = '/tmp/gb.cloc.' + scan_id + '.sql'
	basedir = os.path.dirname(os.path.abspath(__file__)) + '/' + srcdir.rstrip('/')
	logging.info('Using grep binary ' + grepbin)
	logging.info('Starting local scan with scan id ' + scan_id)

	# get db connection
	if 'mysql' == gbconfig.get('database', 'database'):
		try:
			import MySQLdb
			mysqldb  = MySQLdb.connect(host=gbconfig.get('database', 'host'), user=gbconfig.get('database', 'dbuname'), passwd=gbconfig.get('database', 'dbpword'), db=gbconfig.get('database', 'dbname'))
			mysqlcur = mysqldb.cursor()
		except Exception as e:
			print 'Error connecting to MySQL! See log file for details.'
			logging.debug('Error connecting to MySQL: ' + str(e))
			sys.exit(1)

	try:
		db  = lite.connect(dbfile)
		cur = db.cursor()

	except lite.Error as e:
		print 'Error connecting to db file! See log file for details.'
		logging.debug('Error connecting to db file: ' + str(e))
		sys.exit(1)
	except Exception as e:
		print 'CRITICAL: Unhandled exception occurred! Quitters gonna quit! See log file for details.'
		logging.critical('Unhandled exception: ' + str(e))
		sys.exit(1)

	if args.u == True:
		print 'Scanning with existing rules set'
		logging.info('Scanning with existing rules set')
	else:
		# get latest greps
		download_rules()

	# prep db for capturing scan results
	try:
		# clean database
		cur.execute("DROP TABLE IF EXISTS metadata;");
		cur.execute("DROP TABLE IF EXISTS t;");
		cur.execute("VACUUM");

		# update database with new project info
		if 'none' == project:
			project = srcdir

		# query database
		params     = [repo, account, project]
		if 'mysql' == gbconfig.get('database', 'database'):
			mysqlcur.execute("SELECT project_id FROM projects WHERE repo=%s AND account=%s AND project=%s LIMIT 1;", params)
			rows = mysqlcur.fetchall()
		else:
			cur.execute("SELECT project_id FROM projects WHERE repo=? AND account=? AND project=? LIMIT 1;", params)
			rows = cur.fetchall()

		# assume new project by default
		newproject = True

		for row in rows:
			# not so fast, not a new project
			newproject = False
			project_id = row[0]

		if True == newproject:
			project_id = str(uuid.uuid1())
			params     = [project_id, repo, account, project, default_branch]
			if 'mysql' == gbconfig.get('database', 'database'):
				mysqlcur.execute("INSERT INTO projects (project_id, repo, account, project, default_branch) VALUES (%s, %s, %s, %s, %s);", params)
			else:
				cur.execute("INSERT INTO projects (project_id, repo, account, project, default_branch) VALUES (?, ?, ?, ?, ?);", params)

		# update database with new scan info
		params  = [scan_id, project_id]
		if 'mysql' == gbconfig.get('database', 'database'):
			mysqlcur.execute("INSERT INTO scans (scan_id, project_id) VALUES (%s, %s);", params)
			mysqldb.commit()
		else:
			cur.execute("INSERT INTO scans (scan_id, project_id) VALUES (?, ?);", params)
			db.commit()

	except Exception as e:
		print 'CRITICAL: Unhandled exception occurred! Quitters gonna quit! See log file for details.'
		logging.critical('Unhandled exception: ' + str(e))
		sys.exit(1)

	# execute cloc to get sql output
	try:
		print 'counting source files...'
		logging.info('Running cloc for sql output.')
		return_code = call(["cloc", "--skip-uniqueness", "--quiet", "--sql=" + clocsql, "--sql-project=" + srcdir, srcdir])
		if 0 != return_code:
			logging.debug('WARNING: cloc did not run normally. return code: ' + str(return_code))

		# run sql script generated by cloc to save output to database
		f = open(clocsql, 'r')
		cur.executescript(f.read())
		db.commit()
		f.close()
		os.remove(clocsql)

	except Exception as e:
		print 'Error executing cloc sql! Aborting scan! See log file for details.'
		logging.debug('Error executing cloc sql (scan aborted). It is possible there were no results from running cloc.: ' + str(e))
		return scan_id

	# query cloc results
	cur.execute("SELECT Language, count(File), SUM(nBlank), SUM(nComment), SUM(nCode) FROM t GROUP BY Language ORDER BY Language;")
	
	rows    = cur.fetchall()
	cloctxt =  '-------------------------------------------------------------------------------' + "\n"
	cloctxt += 'Language                     files          blank        comment           code' + "\n"
	cloctxt += '-------------------------------------------------------------------------------' + "\n"
	
	sum_files   = 0
	sum_blank   = 0
	sum_comment = 0
	sum_code    = 0

	for row in rows:
		cloctxt += '{0:20}  {1:>12}  {2:>13} {3:>14} {4:>14}'.format(str(row[0]), str(row[1]), str(row[2]), str(row[3]), str(row[4])) + "\n"
		sum_files   += row[1]
		sum_blank   += row[2]
		sum_comment += row[3]
		sum_code    += row[4]
	
	cloctxt += '-------------------------------------------------------------------------------' + "\n"
	cloctxt += '{0:20}  {1:>12}  {2:>13} {3:>14} {4:>14}'.format('Sum', str(sum_files), str(sum_blank), str(sum_comment), str(sum_code)) + "\n"
	cloctxt += '-------------------------------------------------------------------------------' + "\n"

	# save the cloc txt output to the scans table
	try:
		params = [cloctxt, scan_id]
		if 'mysql' == gbconfig.get('database', 'database'):
			mysqlcur.execute("UPDATE scans SET date_time=NOW(), cloc_out=%s WHERE scan_id=%s;", params)
			mysqldb.commit()
		else:
			cur.execute("UPDATE scans SET cloc_out=? WHERE scan_id=?;", params)
			db.commit()

	except Exception as e:
		print 'Error saving cloc txt! Aborting scan! See log file for details.'
		logging.debug('Error saving cloc txt (scan aborted): ' + str(e))
		return scan_id

	# load json data
	try:
		logging.info('Reading grep rules from json file.')
		json_file = open(gbfile, "r")
		greps     = json.load(json_file)
		json_file.close()
	except Exception as e:
		print 'CRITICAL: Unhandled exception occurred! Quitters gonna quit! See log file for details.'
		logging.critical('Unhandled exception: ' + str(e))
		sys.exit(1)

	# query database
	cur.execute("SELECT DISTINCT Language FROM t ORDER BY Language;")
	rows = cur.fetchall()

	# grep all the bugs and output to file
	print 'grepping for bugs...'
	logging.info('Start grepping for bugs.')

	# get cloc extensions and create extension array
	clocext  = ''
	proc     = subprocess.Popen([clocbin, "--show-ext"], stdout=subprocess.PIPE)
	ext      = proc.communicate()
	extarray = str(ext[0]).split("\n")
	
	# override some extensions
	extarray.append('inc -> PHP')
	
	# loop through languages identified by cloc
	for row in rows:
		count = 0
		# loop through all grep rules for each language identified by cloc
		for i in range(0, len(greps)):
				# if the language matches a language in the gb rules file then do stuff
				if row[0] == greps[i]['language']:

					# get all applicable extensions based on language
					extensions = []
					for ii in range(0, len(extarray)):
						lang = str(extarray[ii]).split("->")
						if len(lang) > 1:							
							if str(lang[1]).strip() == greps[i]['language']:
								extensions.append(str(lang[0]).strip())

					# search with regex, filter by extensions, and capture result
					result = ''
					filter = []

					# build filter by extension
					for e in extensions:
						filter.append('--include=*.' + e)

					try:
						proc   = subprocess.Popen([grepbin, "-n", "-r", "-P"] +  filter + [greps[i]['regex'], srcdir], stdout=subprocess.PIPE)
						result = proc.communicate()

						if len(result[0]):	
							# update database with new results info
							result_id = str(uuid.uuid1())
							params    = [result_id, scan_id, greps[i]['language'], greps[i]['id'], greps[i]['regex'], greps[i]['description']]
							if 'mysql' == gbconfig.get('database', 'database'):
								mysqlcur.execute("INSERT INTO results (result_id, scan_id, language, regex_id, regex_text, description) VALUES (%s, %s, %s, %s, %s, %s);", params)
								mysqldb.commit()
							else:
								cur.execute("INSERT INTO results (result_id, scan_id, language, regex_id, regex_text, description) VALUES (?, ?, ?, ?, ?, ?);", params)
								db.commit()

							perline = str(result[0]).split("\n")
							for r in range(0, len(perline) - 1):
								try:
									rr = str(perline[r]).replace(basedir, '').split(':', 1)
									# update database with new results_detail info
									result_detail_id = str(uuid.uuid1())
									code             = str(rr[1]).split(':', 1)
									params           = [result_detail_id, result_id, rr[0], code[0], str(code[1]).strip()]

									if 'mysql' == gbconfig.get('database', 'database'):
										mysqlcur.execute("INSERT INTO results_detail (result_detail_id, result_id, file, line, code) VALUES (%s, %s, %s, %s, %s);", params)
										mysqldb.commit()
									else:
										cur.execute("INSERT INTO results_detail (result_detail_id, result_id, file, line, code) VALUES (?, ?, ?, ?, ?);", params)
										db.commit()

								except lite.Error, e:
									print 'SQL error! See log file for details.'
									logging.debug('SQL error with params ' + str(params) + ' and error ' + str(e))
								except Exception as e:
									print 'Error parsing result! See log file for details.'
									logging.debug('Error parsing result: ' + str(e))
							
					except Exception as e:
						print 'Error calling grep! See log file for details'
						logging.debug('Error calling grep: ' + str(e))

	params = [project_id]
	if 'mysql' == gbconfig.get('database', 'database'):
		mysqlcur.execute("UPDATE projects SET last_scan=NOW() WHERE project_id=%s;", params)
		mysqldb.commit()
		mysqldb.close()
	else:
		cur.execute("UPDATE projects SET last_scan=datetime('now') WHERE project_id=?;", params)
		db.commit()
		db.close()

	if not no_reports:
		html_report(scan_id)

	return scan_id
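
Example 37 writes its progress to a log file configured elsewhere in grepbugs.py, mixing logging.info status lines with logging.debug / logging.critical for failures around subprocess calls. A reduced sketch of the grep step follows; the log file name, search pattern and search path are illustrative.

import logging
import subprocess
import uuid

logging.basicConfig(filename='gb-demo.log', level=logging.INFO)

scan_id = str(uuid.uuid1())
logging.info('Starting local scan with scan id ' + scan_id)

try:
    # Recursive grep with line numbers; pattern and path are placeholders.
    proc = subprocess.Popen(['grep', '-n', '-r', 'eval(', '.'],
                            stdout=subprocess.PIPE)
    result = proc.communicate()
    logging.info('grep returned ' + str(len(result[0])) + ' bytes of matches')
except Exception as e:
    logging.debug('Error calling grep: ' + str(e))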

Example 38

Project: TADbit
Source File: model_and_analyze.py
View license
def get_options():
    """
    parse option from call

    """
    parser = ArgumentParser(
        usage="%(prog)s [options] [--cfg CONFIG_PATH]",
        formatter_class=lambda prog: HelpFormatter(prog, width=95,
                                                   max_help_position=27))
    glopts = parser.add_argument_group('General arguments')
    taddet = parser.add_argument_group('TAD detection arguments')
    optimo = parser.add_argument_group('Optimization of IMP arguments')
    modelo = parser.add_argument_group('Modeling with optimal IMP arguments')
    descro = parser.add_argument_group('Descriptive, optional arguments')
    analyz = parser.add_argument_group('Output arguments')

    ## Define analysis actions:
    actions = {0  : "do nothing",
               1  : "column filtering",
               2  : "TAD borders",
               3  : "TAD alignment",
               4  : "optimization plot",
               5  : "correlation real/models",
               6  : "z-score plot",
               7  : "constraints",
               8  : "objective function",
               9  : "centroid",
               10 : "consistency",
               11 : "density",
               12 : "contact map",
               13 : "walking angle",
               14 : "persistence length",
               15 : "accessibility",
               16 : "interaction"}

    parser.add_argument('--usage', dest='usage', action="store_true",
                        default=False,
                        help='''show detailed usage documentation, with examples
                        and exit''')
    parser.add_argument('--cfg', dest='cfg', metavar="PATH", action='store',
                      default=None, type=str,
                      help='path to a configuration file with predefined ' +
                      'parameters')
    parser.add_argument('--analyze_only', dest='analyze_only',
                        action='store_true', default=False,
                        help=('load precomputed models in outdir, ' +
                              'skip optimization, modeling'))
    parser.add_argument('--optimize_only', dest='optimize_only', default=False,
                        action='store_true',
                        help='do the optimization of the region and exit')
    parser.add_argument('--tad_only', dest='tad_only', action="store_true",
                        default=False,
                        help='[%(default)s] exit after searching for TADs')
    parser.add_argument('--ncpus', dest='ncpus', metavar="INT", default=1,
                        type=int, help='[%(default)s] Number of CPUs to use')

    #########################################
    # GENERAL
    glopts.add_argument(
        '--root_path', dest='root_path', metavar="PATH",
        default='', type=str,
        help=('path to search for data files (just pass file name ' +
              'in "data")'))
    glopts.add_argument('--data', dest='data', metavar="PATH", nargs='+',
                        type=str,
                        help='''path to file(s) with Hi-C data matrix. If many,
                        experiments will be summed up. I.e.: --data
                        replicate_1.txt replicate_2.txt''')
    glopts.add_argument('--xname', dest='xname', metavar="STR", nargs='+',
                        default=[], type=str,
                        help='''[file name] experiment name(s). Use same order
                        as data.''')
    glopts.add_argument('--norm', dest='norm', metavar="PATH", nargs='+',
                        type=str,
                        help='path to file(s) with normalized Hi-C data matrix.')
    glopts.add_argument('--nodiag', dest='nodiag', action='store_true',
                        help='''If the matrix does not contain self interacting
                        bins (only zeroes in the diagonal)''')
    glopts.add_argument('--filt', dest='filt', metavar='INT', default=90,
                        help='''Filter out column with more than a given
                        percentage of zeroes''')
    glopts.add_argument('--crm', dest='crm', metavar="NAME",
                        help='chromosome name')
    glopts.add_argument('--beg', dest='beg', metavar="INT", type=float,
                        default=None,
                        help='genomic coordinate from which to start modeling')
    glopts.add_argument('--end', dest='end', metavar="INT", type=float,
                        help='genomic coordinate where to end modeling')
    glopts.add_argument('--res', dest='res', metavar="INT", type=int,
                        help='resolution of the Hi-C experiment')
    glopts.add_argument('--outdir', dest='outdir', metavar="PATH",
                        default=None,
                        help='out directory for results')

    #########################################
    # TADs
    taddet.add_argument('--tad', dest='tad', action="store_true", default=False,
                        help='[%(default)s] search for TADs in experiments')
    taddet.add_argument('--centromere', dest='centromere', action="store_true",
                        default=False,
                        help='[%(default)s] search for centromeric region')
    taddet.add_argument('--group', dest='group', nargs='+', type=int,
                        default=0, metavar='INT',
                        help='''[all together] How to group Hi-C experiments for
                        the detection of TAD borders. I.e.: "--exp_group 2 2 1"
                        first 2 experiments used together, next 2 also, and last
                        alone (batch_mode option used)''')

    #########################################
    # MODELING
    modelo.add_argument('--nmodels_mod', dest='nmodels_mod', metavar="INT",
                        default='5000', type=int,
                        help=('[%(default)s] number of models to generate for' +
                              ' modeling'))
    modelo.add_argument('--nkeep_mod', dest='nkeep_mod', metavar="INT",
                        default='1000', type=int,
                        help=('[%(default)s] number of models to keep for ' +
                        'modeling'))

    #########################################
    # OPTIMIZATION
    optimo.add_argument('--maxdist', action='store', metavar="LIST",
                        default='400', dest='maxdist',
                        help='range of numbers for maxdist' +
                        ', i.e. 400:1000:100 -- or just a number')
    optimo.add_argument('--upfreq', dest='upfreq', metavar="LIST",
                        default='0',
                        help='range of numbers for upfreq' +
                        ', i.e. 0:1.2:0.3 --  or just a number')
    optimo.add_argument('--lowfreq', dest='lowfreq', metavar="LIST",
                        default='0',
                        help='range of numbers for lowfreq' +
                        ', i.e. -1.2:0:0.3 -- or just a number')
    optimo.add_argument('--scale', dest='scale', metavar="LIST",
                        default="0.01",
                        help='[%(default)s] range of numbers to be tested as ' +
                        'optimal scale value, i.e. 0.005:0.01:0.001 -- Can ' +
                        'also pass only one number')
    optimo.add_argument('--dcutoff', dest='dcutoff', metavar="LIST",
                        default="2",
                        help='[%(default)s] range of numbers to be tested as ' +
                        'optimal distance cutoff parameter (distance, in ' +
                        'number of beads, from which to consider 2 beads as ' +
                        'being close), i.e. 1:5:0.5 -- Can also pass only one' +
                        ' number')
    optimo.add_argument('--nmodels_opt', dest='nmodels_opt', metavar="INT",
                        default='500', type=int,
                        help='[%(default)s] number of models to generate for ' +
                        'optimization')
    optimo.add_argument('--nkeep_opt', dest='nkeep_opt', metavar="INT",
                        default='100', type=int,
                        help='[%(default)s] number of models to keep for ' +
                        'optimization')
    optimo.add_argument('--force_opt', dest='optimize_from_scratch',
                        action="store_true", default=False,
                        help='''[%(default)s] do not take into account previous
                        optimizations. Useful for running in parallel in a
                        cluster for example.''')

    #########################################
    # DESCRIPTION
    descro.add_argument('--species', dest='species', metavar="STRING",
                        default='UNKNOWN',
                        help='species name, with no spaces, i.e.: homo_sapiens')
    descro.add_argument('--cell', dest='cell', metavar="STRING",
                        help='cell type name')
    descro.add_argument('--exp_type', dest='exp_type', metavar="STRING",
                        help='experiment type name (i.e.: Hi-C)')
    descro.add_argument('--assembly', dest='assembly', metavar="STRING",
                        default=None,
                        help='''NCBI ID of the original assembly
                        (i.e.: NCBI36 for human)''')
    descro.add_argument('--enzyme', dest='enzyme', metavar="STRING",
                        default=None,
                        help='''name of the enzyme used to digest
                        chromatin (i.e. HindIII)''')
    descro.add_argument('--identifier', dest='identifier', metavar="STRING",
                        default=None,
                        help='''NCBI identifier of the experiment''')
    descro.add_argument('--project', dest='project', metavar="STRING",
                        default=None,
                        help='''project name''')


    #########################################
    # OUTPUT
    analyz.add_argument('--analyze', dest='analyze', nargs='+',
                        choices=range(len(actions)), type=int,
                        default=range(2, len(actions)), metavar='INT',
                        help=('''[%s] list of numbers representing the
                        analysis to be done. Choose between:
                        %s''' % (' '.join([str(i) for i in range(
                                  2, len(actions))]),
                                 '\n'.join(['%s) %s' % (k, actions[k])
                                            for k in actions]))))
    analyz.add_argument('--not_write_cmm', dest='not_write_cmm',
                        default=False, action='store_true',
                        help='''[%(default)s] do not generate cmm files for each
                        model (Chimera input)''')
    analyz.add_argument('--not_write_xyz', dest='not_write_xyz',
                        default=False, action='store_true',
                        help='''[%(default)s] do not generate xyz files for each
                        model (3D coordinates)''')
    analyz.add_argument('--not_write_json', dest='not_write_json',
                        default=False, action='store_true',
                        help='''[%(default)s] do not generate json file.''')

    parser.add_argument_group(optimo)
    parser.add_argument_group(modelo)
    parser.add_argument_group(descro)
    parser.add_argument_group(analyz)
    opts = parser.parse_args()


    if opts.usage:
        print __doc__
        exit()

    log = '\tSummary of arguments:\n'
    # merge opts with CFG file and write summary
    args = reduce(lambda x, y: x + y, [i.strip('-').split('=')
                                       for i in sys.argv])
    new_opts = {}
    if opts.cfg:
        for line in open(opts.cfg):
            if not '=' in line:
                continue
            if line.startswith('#'):
                continue
            key, value = line.split('#')[0].strip().split('=')
            key = key.strip()
            value = value.strip()
            if value == 'True':
                value = True
            elif value == 'False':
                value = False
            elif key in ['data', 'norm', 'xname', 'group', 'analyze']:
                new_opts.setdefault(key, []).extend(value.split())
                continue
            new_opts[key] = value
    # bad key in configuration file
    opts.__dict__['description'] = {}
    for bad_k in set(new_opts.keys()) - set(opts.__dict__.keys()):
        sys.stderr.write('WARNING: parameter "%s" not recognized (used as description)\n' % (bad_k))
        try:
            opts.__dict__['description'][bad_k] = int(new_opts[bad_k])
        except ValueError:
            opts.__dict__['description'][bad_k] = new_opts[bad_k]
    for key in sorted(opts.__dict__.keys()):
        if key in args:
            log += '  * Command setting   %13s to %s\n' % (
                key, opts.__dict__[key])
        elif key in new_opts:
            opts.__dict__[key] = new_opts[key]
            log += '  - Config. setting   %13s to %s\n' % (
                key, new_opts[key])
        else:
            log += '  o Default setting   %13s to %s\n' % (
                key, opts.__dict__[key])

    # rename analysis actions
    for i, j in enumerate(opts.analyze):
        opts.analyze[i] = actions[int(j)]

    if not opts.data and not opts.norm:
        sys.stderr.write('MISSING data')
        exit(parser.print_help())
    if not opts.outdir:
        sys.stderr.write('MISSING outdir')
        exit(parser.print_help())
    if not opts.crm:
        sys.stderr.write('MISSING crm NAME')
        exit(parser.print_help())
    if not opts.res:
        sys.stderr.write('MISSING resolution')
        exit(parser.print_help())
    if not opts.analyze_only:
        if not opts.maxdist:
            sys.stderr.write('MISSING maxdist')
            exit(parser.print_help())
        if not opts.lowfreq:
            sys.stderr.write('MISSING lowfreq')
            exit(parser.print_help())
        if not opts.upfreq:
            sys.stderr.write('MISSING upfreq')
            exit(parser.print_help())

    if not opts.beg and not opts.tad_only:
        sys.stderr.write('WARNING: no begin coordinate given')
    if not opts.end and not opts.tad_only:
        sys.stderr.write('WARNING: no end coordinate given')

    # groups for TAD detection
    if not opts.data:
        opts.data = [None] * len(opts.norm)
    else:
        opts.norm = [None] * len(opts.data)
    if not opts.group:
        opts.group = [len(opts.data)]
    else:
        opts.group = [int(i) for i in opts.group]

    if sum(opts.group) > len(opts.data):
        logging.info('ERROR: Number of experiments in groups larger than ' +
                     'the number of Hi-C data files given.')
        exit()

    # this option should stay like this for now
    # opts.scale = '0.01'

    # switch to number
    opts.nmodels_mod = int(opts.nmodels_mod)
    opts.nkeep_mod   = int(opts.nkeep_mod  )
    opts.nmodels_opt = int(opts.nmodels_opt)
    opts.nkeep_opt   = int(opts.nkeep_opt  )
    opts.ncpus       = int(opts.ncpus      )
    opts.res         = int(opts.res        )

    # TODO: UNDER TEST
    opts.container   = None #['cylinder', 1000, 5000, 100]

    # do the division to bins
    if not opts.tad_only:
        try:
            opts.beg = int(float(opts.beg) / opts.res)
            opts.end = int(float(opts.end) / opts.res)
            if opts.end - opts.beg <= 2:
                raise Exception('"beg" and "end" parameter should be given in ' +
                                'genomic coordinates, not bin')
        except TypeError:
            pass

    # Create out-directory
    name = '{0}_{1}_{2}'.format(opts.crm, opts.beg, opts.end)
    if not os.path.exists(os.path.join(opts.outdir, name)):
        os.makedirs(os.path.join(opts.outdir, name))

    # write version log
    if not os.path.exists(os.path.join(opts.outdir,
                                       'TADbit_and_dependencies_versions.log')):
        vlog = os.path.join(opts.outdir, 'TADbit_and_dependencies_versions.log')
        vlog = open(vlog, 'w')
        vlog.write(get_dependencies_version())
        vlog.close()

    # write log
    if opts.optimize_only:
        log_format = '[OPTIMIZATION {}_{}_{}_{}_{}]   %(message)s'.format(
            opts.maxdist, opts.upfreq, opts.lowfreq, opts.scale, opts.dcutoff)
    elif opts.analyze_only:
        log_format = '[ANALYZE]   %(message)s'
    elif opts.tad_only:
        log_format = '[TAD]   %(message)s'
    else:
        log_format = '[DEFAULT]   %(message)s'
    try:
        logging.basicConfig(filename=os.path.join(opts.outdir, name, name + '.log'),
                            level=logging.INFO, format=log_format)
    except IOError:
        logging.basicConfig(filename=os.path.join(opts.outdir, name, name + '.log2'),
                            level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info(('\n' + log_format.replace('   %(message)s', '')
                  ).join(log.split('\n')))

    # update path to Hi-C data adding root directory
    if opts.root_path and opts.data[0]:
        for i in xrange(len(opts.data)):
            logging.info(os.path.join(opts.root_path, opts.data[i]))
            opts.data[i] = os.path.join(opts.root_path, opts.data[i])

    # update path to Hi-C norm adding root directory
    if opts.root_path and opts.norm[0]:
        for i in xrange(len(opts.norm)):
            logging.info(os.path.join(opts.root_path, opts.norm[i]))
            opts.norm[i] = os.path.join(opts.root_path, opts.norm[i])

    return opts
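
A notable detail of Example 38 is the logging setup near the end of get_options(): logging.basicConfig writes to a per-run log file with a mode-specific format string, and a StreamHandler is added so the same messages also reach the console. A minimal sketch of that dual file/console setup is shown below; the directory, run name and format string are placeholders.

import logging
import os

outdir, name = 'results', 'chr1_0_100'   # illustrative paths
if not os.path.exists(os.path.join(outdir, name)):
    os.makedirs(os.path.join(outdir, name))

log_format = '[DEFAULT]   %(message)s'
logging.basicConfig(filename=os.path.join(outdir, name, name + '.log'),
                    level=logging.INFO, format=log_format)
# echo everything to the console as well
logging.getLogger().addHandler(logging.StreamHandler())

logging.info('logging to ' + os.path.join(outdir, name, name + '.log'))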

Example 39

Project: ganeti
Source File: bootstrap.py
View license
def MasterFailover(no_voting=False):
  """Failover the master node.

  This checks that we are not already the master, and will cause the
  current master to cease being master, and the non-master to become
  new master.

  Note: The call to MasterFailover from lib/client/gnt_cluster.py checks that
  a majority of nodes are healthy and responding before calling this. If this
  function is called from somewhere else, the caller should also verify that a
  majority of nodes are healthy.

  @type no_voting: boolean
  @param no_voting: force the operation without remote nodes agreement
                      (dangerous)

  @returns: the pair of an exit code and warnings to display
  """
  sstore = ssconf.SimpleStore()

  old_master, new_master = ssconf.GetMasterAndMyself(sstore)
  node_names = sstore.GetNodeList()
  mc_list = sstore.GetMasterCandidates()

  if old_master == new_master:
    raise errors.OpPrereqError("This command must be run on the node"
                               " where you want the new master to be."
                               " %s is already the master" %
                               old_master, errors.ECODE_INVAL)

  if new_master not in mc_list:
    mc_no_master = [name for name in mc_list if name != old_master]
    raise errors.OpPrereqError("This node is not among the nodes marked"
                               " as master candidates. Only these nodes"
                               " can become masters. Current list of"
                               " master candidates is:\n"
                               "%s" % ("\n".join(mc_no_master)),
                               errors.ECODE_STATE)

  if not no_voting:
    vote_list = _GatherMasterVotes(node_names)
    if vote_list:
      voted_master = vote_list[0][0]
      if voted_master != old_master:
        raise errors.OpPrereqError("I have a wrong configuration, I believe"
                                   " the master is %s but the other nodes"
                                   " voted %s. Please resync the configuration"
                                   " of this node." %
                                   (old_master, voted_master),
                                   errors.ECODE_STATE)
  # end checks

  rcode = 0
  warnings = []

  logging.info("Setting master to %s, old master: %s", new_master, old_master)

  try:
    # Forcefully start WConfd so that we can access the configuration
    result = utils.RunCmd([pathutils.DAEMON_UTIL,
                           "start", constants.WCONFD, "--force-node",
                           "--no-voting", "--yes-do-it"])
    if result.failed:
      raise errors.OpPrereqError("Could not start the configuration daemon,"
                                 " command %s had exitcode %s and error %s" %
                                 (result.cmd, result.exit_code, result.output),
                                 errors.ECODE_NOENT)

    # instantiate a real config writer, as we now know we have the
    # configuration data
    livelock = utils.livelock.LiveLock("bootstrap_failover")
    cfg = config.GetConfig(None, livelock, accept_foreign=True)

    old_master_node = cfg.GetNodeInfoByName(old_master)
    if old_master_node is None:
      raise errors.OpPrereqError("Could not find old master node '%s' in"
                                 " cluster configuration." % old_master,
                                 errors.ECODE_NOENT)

    cluster_info = cfg.GetClusterInfo()
    new_master_node = cfg.GetNodeInfoByName(new_master)
    if new_master_node is None:
      raise errors.OpPrereqError("Could not find new master node '%s' in"
                                 " cluster configuration." % new_master,
                                 errors.ECODE_NOENT)

    cluster_info.master_node = new_master_node.uuid
    # this will also regenerate the ssconf files, since we updated the
    # cluster info
    cfg.Update(cluster_info, logging.error)

    # if cfg.Update worked, then it means the old master daemon won't be
    # able now to write its own config file (we rely on locking in both
    # backend.UploadFile() and ConfigWriter._Write(); hence the next
    # step is to kill the old master

    logging.info("Stopping the master daemon on node %s", old_master)

    runner = rpc.BootstrapRunner()
    master_params = cfg.GetMasterNetworkParameters()
    master_params.uuid = old_master_node.uuid
    ems = cfg.GetUseExternalMipScript()
    result = runner.call_node_deactivate_master_ip(old_master,
                                                   master_params, ems)

    msg = result.fail_msg
    if msg:
      warning = "Could not disable the master IP: %s" % (msg,)
      logging.warning("%s", warning)
      warnings.append(warning)

    result = runner.call_node_stop_master(old_master)
    msg = result.fail_msg
    if msg:
      warning = ("Could not disable the master role on the old master"
                 " %s, please disable manually: %s" % (old_master, msg))
      logging.error("%s", warning)
      warnings.append(warning)
  except errors.ConfigurationError, err:
    logging.error("Error while trying to set the new master: %s",
                  str(err))
    return 1, warnings
  finally:
    # stop WConfd again:
    result = utils.RunCmd([pathutils.DAEMON_UTIL, "stop", constants.WCONFD])
    if result.failed:
      warning = ("Could not stop the configuration daemon,"
                 " command %s had exitcode %s and error %s"
                 % (result.cmd, result.exit_code, result.output))
      logging.error("%s", warning)
      rcode = 1

  logging.info("Checking master IP non-reachability...")

  master_ip = sstore.GetMasterIP()
  total_timeout = 30

  # Here we have a phase where no master should be running
  def _check_ip(expected):
    if netutils.TcpPing(master_ip, constants.DEFAULT_NODED_PORT) != expected:
      raise utils.RetryAgain()

  try:
    utils.Retry(_check_ip, (1, 1.5, 5), total_timeout, args=[False])
  except utils.RetryTimeout:
    warning = ("The master IP is still reachable after %s seconds,"
               " continuing but activating the master IP on the current"
               " node will probably fail" % total_timeout)
    logging.warning("%s", warning)
    warnings.append(warning)
    rcode = 1

  if jstore.CheckDrainFlag():
    logging.info("Undraining job queue")
    jstore.SetDrainFlag(False)

  logging.info("Starting the master daemons on the new master")

  result = rpc.BootstrapRunner().call_node_start_master_daemons(new_master,
                                                                no_voting)
  msg = result.fail_msg
  if msg:
    logging.error("Could not start the master role on the new master"
                  " %s, please check: %s", new_master, msg)
    rcode = 1

  # Finally verify that the new master managed to set up the master IP
  # and warn if it didn't.
  try:
    utils.Retry(_check_ip, (1, 1.5, 5), total_timeout, args=[True])
  except utils.RetryTimeout:
    warning = ("The master IP did not come up within %s seconds; the"
               " cluster should still be working and reachable via %s,"
               " but not via the master IP address"
               % (total_timeout, new_master))
    logging.warning("%s", warning)
    warnings.append(warning)
    rcode = 1

  logging.info("Master failed over from %s to %s", old_master, new_master)
  return rcode, warnings
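
MasterFailover consistently passes format arguments to the logging calls themselves (logging.info("... %s", new_master, old_master)) and collects non-fatal warnings in a list that is both logged and returned to the caller. Below is a toy sketch of that warn-and-continue pattern; the node names and the failure condition are invented for illustration.

import logging

logging.basicConfig(level=logging.INFO)

def failover(old_master, new_master):
    rcode, warnings = 0, []
    logging.info("Setting master to %s, old master: %s", new_master, old_master)

    master_ip_disabled = False  # pretend the RPC call to the old master failed
    if not master_ip_disabled:
        warning = "Could not disable the master IP: timed out"
        logging.warning("%s", warning)
        warnings.append(warning)
        rcode = 1

    logging.info("Master failed over from %s to %s", old_master, new_master)
    return rcode, warnings

print(failover("node1", "node2"))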

Example 40

Project: ldpush
Source File: junos.py
View license
  def _SetConfig(self, destination_file, data, canary, skip_show_compare=False,
                 skip_commit_check=False, get_rollback_patch=False):
    copied = False

    file_ptr = tempfile.NamedTemporaryFile()
    rollback_patch_ptr = tempfile.NamedTemporaryFile()
    rollback_patch = None
    # Setting the file name based upon if we are trying to copy a file or
    # we are trying to copy a config into the control plane.
    if destination_file in self.NON_FILE_DESTINATIONS:
      file_name = os.path.basename(file_ptr.name)
      if get_rollback_patch:
        rollback_patch = os.path.basename(rollback_patch_ptr.name)
    else:
      file_name = destination_file
      logging.info('Remote file path: %s', file_name)

    try:
      file_ptr.write(data)
      file_ptr.flush()
    except IOError:
      raise exceptions.SetConfigError('Could not open temporary file %r' %
                                      file_ptr.name)
    result = base_device.SetConfigResult()
    try:
      # Copy the file to the remote device.
      try:
        self._SendFileViaSftp(local_filename=file_ptr.name,
                              remote_filename=file_name)
        copied = True
      except (paramiko.SFTPError, IOError) as e:
        # _SendFileViaSftp puts the normalized destination path in e.args[1].
        msg = 'SFTP failed (filename %r to device %s(%s):%s): %s: %s' % (
            file_ptr.name, self.host, self.loopback_ipv4, e.args[1],
            e.__class__.__name__, e.args[0])
        raise exceptions.SetConfigError(msg)

      if not self._ChecksumsMatch(local_file_name=file_ptr.name,
                                  remote_file_name=file_name):
        raise exceptions.SetConfigError(
            'Local and remote file checksum mismatch.')

      if self.CONFIG_RUNNING == destination_file:
        operation = 'replace'
      elif self.CONFIG_STARTUP == destination_file:
        operation = 'override'
      elif self.CONFIG_PATCH == destination_file:
        operation = 'patch'
      else:
        result.transcript = 'SetConfig uploaded the file successfully.'
        return result
      if canary:
        logging.debug('Canary syntax checking configuration file %r.',
                      file_name)
        result = self._JunosLoad(operation, file_name, canary=True,
                                 skip_show_compare=skip_show_compare,
                                 skip_commit_check=skip_commit_check)
      else:
        logging.debug('Setting destination %r with configuration file %r.',
                      destination_file, file_name)
        result = self._JunosLoad(operation, file_name,
                                 skip_show_compare=skip_show_compare,
                                 skip_commit_check=skip_commit_check,
                                 rollback_patch=rollback_patch)

        if rollback_patch:
          try:
            self._GetFileViaSftp(local_filename=rollback_patch_ptr.name,
                                 remote_filename=rollback_patch)
            result.rollback_patch = rollback_patch_ptr.read()
          except (paramiko.SFTPError, IOError) as e:
            # _GetFileViaSftp puts the normalized source path in e.args[1].
            result.transcript += (
                'SFTP rollback patch retrieval failed '
                '(filename %r from device %s(%s):%s): %s: %s' % (
                    rollback_patch_ptr.name, self.host, self.loopback_ipv4,
                    e.args[1], e.__class__.__name__, e.args[0]))

      # Return the diagnostic results as the (optional) result.
      return result

    finally:
      local_delete_exception = None
      # Unlink the original temporary file.
      try:
        logging.info('Deleting the file on the local machine: %s',
                     file_ptr.name)
        file_ptr.close()
      except IOError:
        local_delete_exception = exceptions.SetConfigError(
            'Could not close temporary file.')

      local_rollback_patch_delete_exception = None
      # Unlink the rollback patch temporary file.
      try:
        logging.info('Deleting the file on the local machine: %s',
                     rollback_patch_ptr.name)
        rollback_patch_ptr.close()
      except IOError:
        local_rollback_patch_delete_exception = exceptions.SetConfigError(
            'Could not close temporary rollback patch file.')

      # If we copied the file to the router and we were pushing a configuration,
      # delete the temporary file off the router.
      if copied and destination_file in self.NON_FILE_DESTINATIONS:
        logging.info('Deleting file on the router: %s', file_name)
        self.Cmd('file delete ' + file_name)

      # Delete any rollback patch file too.
      if rollback_patch:
        logging.info('Deleting patch on the router: %s', rollback_patch)
        self.Cmd('file delete ' + rollback_patch)

      # If we got an exception on the local file delete, but did not get a
      # (more important) exception on the remote delete, raise the local delete
      # exception.
      #
      # pylint is confused by the re-raising
      # pylint: disable=raising-bad-type
      if local_delete_exception is not None:
        raise local_delete_exception
      if local_rollback_patch_delete_exception is not None:
        raise local_rollback_patch_delete_exception
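
In Example 40 the finally block logs every cleanup action (closing temporary files, deleting files on the router) before performing it, so the transcript shows exactly what was removed even when the push fails. A trimmed sketch of that idea, limited to a local temporary file and with an invented payload:

import logging
import tempfile

logging.basicConfig(level=logging.INFO)

file_ptr = tempfile.NamedTemporaryFile()
try:
    file_ptr.write(b'set system host-name demo\n')   # illustrative config data
    file_ptr.flush()
    logging.info('Remote file path: %s', file_ptr.name)
    # ... copy the file to the device and load it here ...
finally:
    logging.info('Deleting the file on the local machine: %s', file_ptr.name)
    file_ptr.close()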

Example 41

View license
  def Prepare(self, vm):
    """Prepares the DB and everything for the AWS-RDS provider.

    Args:
      vm: The VM to be used as the test client.

    """
    logging.info('Preparing MySQL Service benchmarks for RDS.')

    # TODO: Refactor the RDS DB instance creation and deletion logic out
    # to a new class called RDSDBInstance that Inherits from
    # perfkitbenchmarker.resource.BaseResource.
    # And do the same for GCP.

    # First is to create another subnet in the same VPC as the VM but in a
    # different zone. RDS requires two subnets in two different zones to create
    # a DB instance, EVEN IF you do not specify multi-AZ in your DB creation
    # request.

    # Get a list of zones and pick one that's different from the zone VM is in.
    new_subnet_zone = None
    get_zones_cmd = util.AWS_PREFIX + ['ec2', 'describe-availability-zones']
    stdout, _, _ = vm_util.IssueCommand(get_zones_cmd)
    response = json.loads(stdout)
    all_zones = response['AvailabilityZones']
    for zone in all_zones:
      if zone['ZoneName'] != vm.zone:
        new_subnet_zone = zone['ZoneName']
        break

    if new_subnet_zone is None:
      raise DBStatusQueryError('Cannot find a zone to create the required '
                               'second subnet for the DB instance.')

    # Now create a new subnet in the zone that's different from where the VM is
    logging.info('Creating a second subnet in zone %s', new_subnet_zone)
    new_subnet = aws_network.AwsSubnet(new_subnet_zone, vm.network.vpc.id,
                                       '10.0.1.0/24')
    new_subnet.Create()
    logging.info('Successfully created a new subnet, subnet id is: %s',
                 new_subnet.id)
    # Remember this so we can cleanup properly.
    vm.extra_subnet_for_db = new_subnet

    # Now we can create a new DB subnet group that has two subnets in it.
    db_subnet_group_name = 'pkb%s' % FLAGS.run_uri
    create_db_subnet_group_cmd = util.AWS_PREFIX + [
        'rds',
        'create-db-subnet-group',
        '--db-subnet-group-name', db_subnet_group_name,
        '--db-subnet-group-description', 'pkb_subnet_group_for_db',
        '--subnet-ids', vm.network.subnet.id, new_subnet.id]
    stdout, stderr, _ = vm_util.IssueCommand(create_db_subnet_group_cmd)
    logging.info('Created a DB subnet group, stdout is:\n%s\nstderr is:\n%s',
                 stdout, stderr)
    vm.db_subnet_group_name = db_subnet_group_name

    # Open up TCP port 3306 in the VPC's security group; we need that to
    # connect to the DB.
    open_port_cmd = util.AWS_PREFIX + [
        'ec2',
        'authorize-security-group-ingress',
        '--group-id', vm.group_id,
        '--source-group', vm.group_id,
        '--protocol', 'tcp',
        '--port', MYSQL_PORT]
    stdout, stderr, _ = vm_util.IssueCommand(open_port_cmd)
    logging.info('Granted DB port ingress, stdout is:\n%s\nstderr is:\n%s',
                 stdout, stderr)

    # Finally, it's time to create the DB instance!
    vm.db_instance_id = 'pkb-DB-%s' % FLAGS.run_uri
    db_class = \
        RDS_CORE_TO_DB_CLASS_MAP['%s' % FLAGS.mysql_svc_db_instance_cores]
    vm.db_instance_master_user = MYSQL_ROOT_USER
    vm.db_instance_master_password = _GenerateRandomPassword()

    create_db_cmd = util.AWS_PREFIX + [
        'rds',
        'create-db-instance',
        '--db-instance-identifier', vm.db_instance_id,
        '--db-instance-class', db_class,
        '--engine', RDS_DB_ENGINE,
        '--engine-version', RDS_DB_ENGINE_VERSION,
        '--storage-type', RDS_DB_STORAGE_TYPE_GP2,
        '--allocated-storage', RDS_DB_STORAGE_GP2_SIZE,
        '--vpc-security-group-ids', vm.group_id,
        '--master-username', vm.db_instance_master_user,
        '--master-user-password', vm.db_instance_master_password,
        '--availability-zone', vm.zone,
        '--db-subnet-group-name', vm.db_subnet_group_name]

    status_query_cmd = util.AWS_PREFIX + [
        'rds',
        'describe-db-instances',
        '--db-instance-id', vm.db_instance_id]

    stdout, stderr, _ = vm_util.IssueCommand(create_db_cmd)
    logging.info('Request to create the DB has been issued, stdout:\n%s\n'
                 'stderr:%s\n', stdout, stderr)
    response = json.loads(stdout)

    db_creation_status = _RDSParseDBInstanceStatus(response)

    for status_query_count in xrange(1, DB_STATUS_QUERY_LIMIT + 1):
      if db_creation_status == 'available':
        break

      if db_creation_status not in RDS_DB_CREATION_PENDING_STATUS:
        raise DBStatusQueryError('Invalid status in DB creation response. '
                                 ' stdout is\n%s, stderr is\n%s' % (
                                     stdout, stderr))

      logging.info('Querying db creation status, current state is %s, query '
                   'count is %d', db_creation_status, status_query_count)
      time.sleep(DB_STATUS_QUERY_INTERVAL)

      stdout, stderr, _ = vm_util.IssueCommand(status_query_cmd)
      response = json.loads(stdout)
      db_creation_status = _RDSParseDBInstanceStatus(response)
    else:
      raise DBStatusQueryError('DB creation timed out; we have '
                               'waited at least %s * %s seconds.' % (
                                   DB_STATUS_QUERY_INTERVAL,
                                   DB_STATUS_QUERY_LIMIT))

    # The DB has been created; now get the endpoint address. On RDS, you always
    # connect with a DNS name, and if you do that from an EC2 VM, the DNS name
    # resolves to an internal IP address of the DB.
    if 'DBInstance' in response:
      vm.db_instance_address = response['DBInstance']['Endpoint']['Address']
    else:
      if 'DBInstances' in response:
        vm.db_instance_address = \
            response['DBInstances'][0]['Endpoint']['Address']

    logging.info('Successfully created an RDS DB instance. Address is %s',
                 vm.db_instance_address)
    logging.info('Complete output is:\n %s', response)
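
The creation step above relies on Python's for/else: the else clause fires only when the loop exhausts its query budget without breaking on 'available'. A minimal sketch of that polling pattern, where issue_command, parse_status, and the limits are hypothetical stand-ins for the benchmark's helpers and constants:

import json
import logging
import time

def wait_until_available(issue_command, status_query_cmd, parse_status,
                         interval_secs=15, max_queries=200):
    # Sketch only: every argument is a hypothetical stand-in.
    for query_count in range(1, max_queries + 1):
        stdout, _, _ = issue_command(status_query_cmd)
        status = parse_status(json.loads(stdout))
        if status == 'available':
            logging.info('DB became available after %d queries', query_count)
            break
        logging.info('DB status is %s, query count is %d', status, query_count)
        time.sleep(interval_secs)
    else:  # the loop never hit break, so we never saw 'available'
        raise RuntimeError('DB creation timed out after %d queries' % max_queries)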

Example 42

View license
def _ProcessMultiStreamResults(start_times, latencies, sizes, operation,
                               all_sizes, results, metadata=None):
  """Read and process results from the api_multistream worker process.

  Results will be reported per-object size and combined for all
  objects.

  Args:
    start_times: a list of numpy arrays. Operation start times, as
      POSIX timestamps.
    latencies: a list of numpy arrays. Operation durations, in seconds.
    sizes: a list of numpy arrays. Object sizes used in each
      operation, in bytes.
    operation: 'upload' or 'download'. The operation the results are from.
    all_sizes: a sequence of integers. all object sizes in the
      distribution used, in bytes.
    results: a list to append Sample objects to.
    metadata: dict. Base sample metadata
  """

  num_streams = FLAGS.object_storage_streams_per_vm * FLAGS.num_vms

  assert len(start_times) == num_streams
  assert len(latencies) == num_streams
  assert len(sizes) == num_streams

  if metadata is None:
    metadata = {}
  metadata['num_streams'] = num_streams
  metadata['objects_per_stream'] = (
      FLAGS.object_storage_multistream_objects_per_stream)

  num_records = sum((len(start_time) for start_time in start_times))
  logging.info('Processing %s total operation records', num_records)

  stop_times = [start_time + latency
                for start_time, latency in zip(start_times, latencies)]

  last_start_time = max((start_time[0] for start_time in start_times))
  first_stop_time = min((stop_time[-1] for stop_time in stop_times))

  # Compute how well our synchronization worked
  first_start_time = min((start_time[0] for start_time in start_times))
  last_stop_time = max((stop_time[-1] for stop_time in stop_times))
  start_gap = last_start_time - first_start_time
  stop_gap = last_stop_time - first_stop_time
  if ((start_gap + stop_gap) / (last_stop_time - first_start_time) <
      MULTISTREAM_STREAM_GAP_THRESHOLD):
    logging.info(
        'First stream started %s seconds before last stream started', start_gap)
    logging.info(
        'Last stream ended %s seconds after first stream ended', stop_gap)
  else:
    logging.warning(
        'Difference between first and last stream start/end times was %s and '
        '%s, which is more than %s of the benchmark time %s.',
        start_gap, stop_gap, MULTISTREAM_STREAM_GAP_THRESHOLD,
        (last_stop_time - first_start_time))
    metadata['stream_gap_above_threshold'] = True

  # Find the indexes in each stream where all streams are active,
  # following Python's [inclusive, exclusive) index convention.
  active_start_indexes = []
  for start_time in start_times:
    for i in xrange(len(start_time)):
      if start_time[i] >= last_start_time:
        active_start_indexes.append(i)
        break
  active_stop_indexes = []
  for stop_time in stop_times:
    for i in xrange(len(stop_time) - 1, -1, -1):
      if stop_time[i] <= first_stop_time:
        active_stop_indexes.append(i + 1)
        break
  active_latencies = [
      latencies[i][active_start_indexes[i]:active_stop_indexes[i]]
      for i in xrange(num_streams)]
  active_sizes = [
      sizes[i][active_start_indexes[i]:active_stop_indexes[i]]
      for i in xrange(num_streams)]

  all_active_latencies = np.concatenate(active_latencies)
  all_active_sizes = np.concatenate(active_sizes)

  # Don't publish the full distribution in the metadata because doing
  # so might break regexp-based parsers that assume that all metadata
  # values are simple Python objects. However, do add an
  # 'object_size_B' metadata field even for the full results because
  # searching metadata is easier when all records with the same metric
  # name have the same set of metadata fields.
  distribution_metadata = metadata.copy()
  distribution_metadata['object_size_B'] = 'distribution'

  latency_prefix = 'Multi-stream %s latency' % operation
  logging.info('Processing %s multi-stream %s results for the full '
               'distribution.', len(all_active_latencies), operation)
  _AppendPercentilesToResults(
      results,
      all_active_latencies,
      latency_prefix,
      LATENCY_UNIT,
      distribution_metadata)

  # Publish by-size and full-distribution stats even if there's only
  # one size in the distribution, because it simplifies postprocessing
  # of results.
  for size in all_sizes:
    this_size_metadata = metadata.copy()
    this_size_metadata['object_size_B'] = size
    logging.info('Processing multi-stream %s results for object size %s',
                 operation, size)
    _AppendPercentilesToResults(
        results,
        all_active_latencies[all_active_sizes == size],
        latency_prefix,
        LATENCY_UNIT,
        this_size_metadata)

  # Throughput metrics
  total_active_times = [np.sum(latency) for latency in active_latencies]
  active_durations = [stop_times[i][active_stop_indexes[i] - 1] -
                      start_times[i][active_start_indexes[i]]
                      for i in xrange(num_streams)]
  total_active_sizes = [np.sum(size) for size in active_sizes]
  # 'net throughput (with gap)' is computed by taking the throughput
  # for each stream (total # of bytes transmitted / (stop_time -
  # start_time)) and then adding the per-stream throughputs. 'net
  # throughput' is the same, but replacing (stop_time - start_time)
  # with the sum of all of the operation latencies for that thread, so
  # we only divide by the time that stream was actually transmitting.
  results.append(sample.Sample(
      'Multi-stream ' + operation + ' net throughput',
      np.sum((size / active_time * 8
              for size, active_time
              in zip(total_active_sizes, total_active_times))),
      'bit / second', metadata=distribution_metadata))
  results.append(sample.Sample(
      'Multi-stream ' + operation + ' net throughput (with gap)',
      np.sum((size / duration * 8
              for size, duration in zip(total_active_sizes, active_durations))),
      'bit / second', metadata=distribution_metadata))
  results.append(sample.Sample(
      'Multi-stream ' + operation + ' net throughput (simplified)',
      sum([np.sum(size) for size in sizes]) /
      (last_stop_time - first_start_time) * 8,
      'bit / second', metadata=distribution_metadata))

  # QPS metrics
  results.append(sample.Sample(
      'Multi-stream ' + operation + ' QPS (any stream active)',
      num_records / (last_stop_time - first_start_time), 'operation / second',
      metadata=distribution_metadata))
  results.append(sample.Sample(
      'Multi-stream ' + operation + ' QPS (all streams active)',
      len(all_active_latencies) / (first_stop_time - last_start_time),
      'operation / second', metadata=distribution_metadata))

  # Statistics about benchmarking overhead
  gap_time = sum((active_duration - active_time
                  for active_duration, active_time
                  in zip(active_durations, total_active_times)))
  results.append(sample.Sample(
      'Multi-stream ' + operation + ' total gap time',
      gap_time, 'second', metadata=distribution_metadata))
  results.append(sample.Sample(
      'Multi-stream ' + operation + ' gap time proportion',
      gap_time / (first_stop_time - last_start_time) * 100.0,
      'percent', metadata=distribution_metadata))
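
A tiny worked example of the two throughput definitions described in the comments above (the numbers are invented): each stream moves 1,000,000 bytes, and the 'with gap' variant divides by the full active window rather than by the time spent transmitting, so it can only be lower.

import logging

total_active_sizes = [1000000, 1000000]  # bytes moved per stream (invented)
total_active_times = [8.0, 9.0]          # sum of operation latencies per stream
active_durations = [10.0, 12.0]          # last stop - first start per stream

net = sum(size / t * 8 for size, t in zip(total_active_sizes, total_active_times))
net_with_gap = sum(size / d * 8 for size, d in zip(total_active_sizes, active_durations))

logging.info('net throughput: %.0f bit/s', net)                      # ~1888889
logging.info('net throughput (with gap): %.0f bit/s', net_with_gap)  # ~1466667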

Example 43

Project: immuno
Source File: analyze_cohort.py
View license
def generate_mutation_counts(
        mutation_files,
        hla_types,
        genes_expressed,
        max_peptide_length=31,
        skip_identifiers = {},
        output_file=None):
    """
    Returns dictionary that maps each patient ID to a tuple with six fields:
        - total number of mutated epitopes across all transcripts
        - number of mutated genes
        - number of mutated genes with MHC binding mutated epitope
        - number of mutated epitopes which are predicted to bind to an MHC
          allele
        - number of mutated genes with at least one immunogenic mutated
          epitope
        - number of mutated epitopes which are predicted to be immunogenic
          (MHC binder + non-self)
    """
    mutation_counts = OrderedDict()
    n = len(mutation_files)
    for i, (patient_id, vcf_df) in enumerate(mutation_files.iteritems()):
        if patient_id in skip_identifiers:
            logging.info("Skipping patient ID %s", patient_id)
            continue
        hla_allele_names = hla_types[patient_id]
        logging.info(
            "Processing %s (#%d/%d) with HLA alleles %s",
            patient_id, i + 1, n, hla_allele_names)

        if not args.quiet:
            print vcf_df

        try:
            transcripts_df, raw_genomic_mutation_df, variant_report = (
                expand_transcripts(
                    vcf_df,
                    patient_id,
                    max_peptide_length=max_peptide_length))
        except KeyboardInterrupt:
            raise
        except:
            logging.warning("Failed to apply mutations for %s", patient_id)
            raise

        # print each genetic mutation applied to each possible transcript
        # and either why it failed or what protein mutation resulted
        if not args.quiet:
            print_mutation_report(
                patient_id,
                variant_report,
                raw_genomic_mutation_df,
                transcripts_df)
            logging.info(
                "Calling MHC binding predictor for %s (#%d/%d)",
                patient_id, i + 1, n)

        def make_mhc_predictor():
            if args.netmhc_cons:
                return ConsensusBindingPredictor(hla_allele_names)
            else:
                return PanBindingPredictor(hla_allele_names)

        # If we want to read scored_epitopes from a CSV file, do that.
        if args.debug_scored_epitopes_csv:
            csv_file = args.debug_scored_epitopes_csv
            if isfile(csv_file):
                scored_epitopes = pd.read_csv(csv_file)
            else:
                mhc = make_mhc_predictor()
                scored_epitopes = mhc.predict(transcripts_df,
                        mutation_window_size=9)
                scored_epitopes.to_csv(csv_file)
        else:
            mhc = make_mhc_predictor()
            scored_epitopes = mhc.predict(transcripts_df,
                    mutation_window_size=9)

        if not args.quiet:
            print scored_epitopes

        imm = ImmunogenicityPredictor(
            alleles=hla_allele_names,
            binding_threshold=args.binding_threshold)
        scored_epitopes = imm.predict(scored_epitopes)
        scored_epitopes.to_csv("scored_epitopes.csv")
        scored_epitopes = pd.read_csv("scored_epitopes.csv")

        grouped = scored_epitopes.groupby(["Gene", "GeneMutationInfo"])
        n_coding_mutations = len(grouped)
        n_epitopes = 0
        n_ligand_mutations = 0
        n_ligands = 0
        n_immunogenic_mutations = 0
        n_immunogenic_epitopes = 0
        for (gene, mut), group in grouped:
            start_mask = group.EpitopeStart < group.MutationEnd
            stop_mask = group.EpitopeEnd > group.MutationStart
            mutated_subset = group[start_mask & stop_mask]
            # we might have duplicate epitopes from multiple transcripts, so
            # drop them
            n_curr_epitopes = len(mutated_subset.groupby(['Epitope']))
            n_epitopes += n_curr_epitopes
            below_threshold_mask = \
                mutated_subset.MHC_IC50 <= args.binding_threshold
            ligands = mutated_subset[below_threshold_mask]
            n_curr_ligands = len(ligands.groupby(['Epitope']))
            n_ligands += n_curr_ligands
            n_ligand_mutations += (n_curr_ligands) > 0
            thymic_deletion_mask = \
                np.array(ligands.ThymicDeletion).astype(bool)
            immunogenic_epitopes = ligands[~thymic_deletion_mask]
            curr_immunogenic_epitopes = immunogenic_epitopes.groupby(['Epitope']).first()
            n_immunogenic_epitopes += len(curr_immunogenic_epitopes)
            n_immunogenic_mutations += len(curr_immunogenic_epitopes) > 0
            logging.info(("%s %s: epitopes %s, ligands %d, imm %d"),
                         gene,
                         mut,
                         n_curr_epitopes,
                         n_curr_ligands,
                         len(curr_immunogenic_epitopes),
                        )
        result_tuple = (
            n_coding_mutations,
            n_epitopes,
            n_ligand_mutations,
            n_ligands,
            n_immunogenic_mutations,
            n_immunogenic_epitopes,
        )
        if output_file:
            data_string = ",".join(str(d) for d in result_tuple)
            output_file.write("%s,%s\n" % (patient_id, data_string))
            output_file.flush()
        mutation_counts[patient_id] = result_tuple
    return mutation_counts
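
The inner groupby above keeps an epitope only if its half-open [EpitopeStart, EpitopeEnd) interval overlaps the mutation's [MutationStart, MutationEnd) interval. A minimal pandas sketch of that overlap mask, with invented data:

import logging
import pandas as pd

df = pd.DataFrame({
    'Epitope': ['AAA', 'BBB', 'CCC'],
    'EpitopeStart': [0, 5, 12],
    'EpitopeEnd':   [9, 14, 21],
})
mutation_start, mutation_end = 10, 11

# Two half-open intervals overlap iff each starts before the other ends.
overlap = (df.EpitopeStart < mutation_end) & (df.EpitopeEnd > mutation_start)
mutated = df[overlap]
logging.info('%d of %d epitopes overlap the mutation', len(mutated), len(df))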

Example 44

Project: immuno
Source File: load_file.py
View license
def expand_transcripts(
        vcf_df, patient_id, min_peptide_length=9, max_peptide_length=31):
    """
    Applies genomic variants to all possible transcripts.

    Parameters
    --------

    vcf_df : DataFrame
        Required to have basic variant columns (chr, pos, ref, alt)

    patient_id : str

    min_peptide_length : int

    max_peptide_length : int
    """

    assert len(vcf_df) > 0, "No mutation entries for %s" % patient_id
    logging.info("Expanding transcripts from %d variants for %s", len(vcf_df), patient_id)
    vcf_df['chr'] = vcf_df.chr.map(normalize_chromosome_name)

    # annotate genomic mutations into all the possible
    # known transcripts they might be on
    transcripts_df = annotation.annotate_vcf_transcripts(vcf_df)

    assert len(transcripts_df) > 0, \
        "No annotated mutation entries for %s" % patient_id
    logging.info(
        "Annotated input %s has %d possible transcripts",
         patient_id,
         len(transcripts_df))

    new_rows = []

    group_cols = ['chr','pos', 'ref', 'alt', 'stable_id_transcript']

    seen_source_sequences = set([])

    # for each genetic variant in the source file,
    # we're going to print a string describing either the resulting
    # protein variant or whatever error prevented us from getting a result
    variant_report = OrderedDict()

    for (chromosome, pos, ref, alt, transcript_id), group in \
            transcripts_df.groupby(group_cols):
        mutation_description = "chr%s %s" % (
            chromosome,
            gene_mutation_description(pos, ref, alt),
        )
        key = (mutation_description, transcript_id)

        def skip(msg, *args):
            msg = msg % args
            logging.info(
                "Skipping %s on %s: %s",
                mutation_description,
                transcript_id,
                msg)
            variant_report[key] = msg

        def error(msg, *args):
            msg = msg % args
            logging.warning(
                "Error in %s on %s: %s",
                mutation_description,
                transcript_id,
                msg)
            variant_report[key] = msg

        def success(row):
            new_rows.append(row)
            msg = "SUCCESS: Gene = %s, Mutation = %s, SourceSequence = %s<%d>" \
                % (row['Gene'], row['PeptideMutationInfo'],
                    row['SourceSequence'], len(row['SourceSequence']))
            variant_report[key] = msg

        if chromosome.upper().startswith("M"):
            skip("Mitochondrial DNA is insane, don't even bother")
            continue
        elif ref == alt:
            skip("Not a variant, since ref %s matches alt %s", ref, alt)
            continue

        padding = max_peptide_length - 1
        if transcript_id:
            seq, start, stop, annot = \
                peptide_from_transcript_variant(
                    transcript_id, pos, ref, alt,
                    padding = padding)
        else:
            error("Skipping due to invalid transcript ID")
            continue

        if not seq:
            error(annot)
        else:
            starts_with = [s for s in seen_source_sequences if s.startswith(
                seq)]
            if any(starts_with):
                msg = "Already seen %d sequence(s) starting with %s<%d>" % (
                    len(starts_with), seq, len(seq))
                lengths = [("<%d>" % len(s)) for s in starts_with]
                msg += " (" + ', '.join(lengths) + ")"
                skip(msg)

                # Log the actual seen sequences
                for i, s in enumerate(starts_with):
                    logging.info("Sequence #%d (already seen): %s<%d>" % (i + 1,
                        s, len(s)))

                continue
            else:
                seen_source_sequences.add(seq)

            if '*' in seq:
                error(
                    "Found stop codon in peptide %s from transcript_id %s",
                    seq,
                    transcript_id)
            elif not is_valid_peptide(seq):
                error(
                    "Invalid peptide sequence for transcript_id %s: %s",
                    transcript_id,
                    seq)
            elif len(seq) < min_peptide_length:
                skip(
                    "Truncated peptide (len %d) too short for transcript %s",
                    len(seq),
                    transcript_id)
            else:
                row = deepcopy(group.irow(0))
                row['SourceSequence'] = seq
                row['MutationStart'] = start
                row['MutationEnd'] = stop
                gene_mutation_info = "chr%s %s" % (
                    chromosome,
                    gene_mutation_description(pos, ref, alt))
                row['GeneMutationInfo'] = gene_mutation_info
                row['PeptideMutationInfo'] = annot
                try:
                    gene = gene_names.transcript_id_to_gene_name(transcript_id)
                except:
                    gene = gene_names.transcript_id_to_gene_id(transcript_id)

                row['Gene'] = gene
                success(row)

    assert len(new_rows) > 0, "No mutations!"
    peptides = pd.DataFrame.from_records(new_rows)
    peptides['GeneInfo'] = peptides['info']
    peptides['TranscriptId'] = peptides['stable_id_transcript']

    transcripts_df = transcripts_df.merge(peptides)
    logging.info(
        "Generated %d peptides from %s",
        len(transcripts_df),
        patient_id
    )

    # drop verbose or uninteresting columns from VCF
    dumb_fields = (
        'description_gene',
        'filter',
        'qual',
        'id',
        'name',
        'info',
        'stable_id_transcript'
    )
    for dumb_field in dumb_fields:
        if dumb_field in transcripts_df.columns:
            transcripts_df = transcripts_df.drop(dumb_field, axis = 1)
    return transcripts_df, vcf_df, variant_report
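
Each variant outcome above goes through small closures (skip, error, success) that both log the message and record it in variant_report under a (mutation, transcript) key, so the report and the log never drift apart. A minimal sketch of that pattern with invented items:

import logging
from collections import OrderedDict

def process(items):
    report = OrderedDict()
    for key, value in items:
        def skip(msg, *args):
            text = msg % args
            logging.info('Skipping %s: %s', key, text)
            report[key] = text

        def success(text):
            logging.info('SUCCESS for %s: %s', key, text)
            report[key] = text

        if value is None:
            skip('no value for %s', key)
            continue
        success('processed value %r' % (value,))
    return report

process([('a', 1), ('b', None)])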

Example 45

Project: pepdata
Source File: mhc.py
View license
@memoize
def load_dataframe(
        mhc_class=None,  # 1, 2, or None for neither
        hla=None,
        exclude_hla=None,
        human=True,
        peptide_length=None,
        assay_method=None,
        assay_group=None,
        only_standard_amino_acids=True,
        reduced_alphabet=None,  # 20 letter AA strings -> simpler alphabet
        warn_bad_lines=True,
        nrows=None):
    """
    Load IEDB MHC data without aggregating multiple entries for the same epitope

    Parameters
    ----------
    mhc_class : {None, 1, 2}
        Restrict to MHC Class I or Class II (or None for neither)

    hla : regex pattern, optional
        Restrict results to specific HLA type used in assay

    exclude_hla : regex pattern, optional
        Exclude certain HLA types

    human : bool
        Restrict to human samples (default True)

    peptide_length: int, optional
        Restrict epitopes to amino acid strings of given length

    assay_method : string, optional
        Limit to assay methods which contain the given string

    assay_group : string, optional
        Limit to assay groups which contain the given string

    only_standard_amino_acids : bool, optional
        Drop sequences which use non-standard amino acids, anything outside
        the core 20, such as X or U (default = True)

    reduced_alphabet : dictionary, optional
        Remap amino acid letters to some other alphabet

    warn_bad_lines : bool, optional
        The full MHC ligand dataset seems to contain several dozen lines with
        too many fields. This currently results in a lot of warning messages
        from Pandas, which you can turn off with this option (default = True)

    nrows : int, optional
        Don't load the full IEDB dataset but instead read only the first nrows
    """
    df = pd.read_csv(
            local_path(),
            header=[0, 1],
            skipinitialspace=True,
            nrows=nrows,
            low_memory=False,
            error_bad_lines=False,
            encoding="latin-1",
            warn_bad_lines=warn_bad_lines)

    # Sometimes the IEDB seems to put in an extra comma in the
    # header line, which creates an unnamed column of NaNs.
    # To deal with this, drop any columns which are all NaN
    df = df.dropna(axis=1, how="all")

    n = len(df)

    epitope_column_key = ("Epitope", "Description")
    mhc_allele_column_key = ("MHC", "Allele Name")

    epitopes = df[epitope_column_key] = df[epitope_column_key].str.upper()

    null_epitope_seq = epitopes.isnull()
    n_null = null_epitope_seq.sum()
    if n_null > 0:
        logging.info("Dropping %d null sequences", n_null)

    mask = ~null_epitope_seq

    if only_standard_amino_acids:
        # if have rare or unknown amino acids, drop the sequence
        bad_epitope_seq = \
            epitopes.str.contains(bad_amino_acids, na=False).astype("bool")
        n_bad = bad_epitope_seq.sum()
        if n_bad > 0:
            logging.info("Dropping %d bad sequences", n_bad)

        mask &= ~bad_epitope_seq

    if human:
        mask &= df[mhc_allele_column_key].str.startswith("HLA").astype("bool")

    if mhc_class == 1:
        mask &= df["MHC"]["MHC allele class"] == "I"
    elif mhc_class == 2:
        mask &= df["MHC"]["MHC allele class"] == "II"

    if hla:
        mask &= df[mhc_allele_column_key].str.contains(hla, na=False)

    if exclude_hla:
        mask &= ~(df[mhc_allele_column_key].str.contains(exclude_hla, na=False))

    if assay_group:
        mask &= df["Assay"]["Assay Group"].str.contains(assay_group)

    if assay_method:
        mask &= df["Assay"]["Method/Technique"].str.contains(assay_method)

    if peptide_length:
        assert peptide_length > 0
        mask &= df[epitope_column_key].str.len() == peptide_length

    df = df[mask].copy()

    logging.info("Returning %d / %d entries after filtering", len(df), n)

    if reduced_alphabet:
        epitopes = df[epitope_column_key]
        df["Epitope"]["Original Sequence"] = epitopes
        reduced_epitopes = epitopes.map(
            make_alphabet_transformer(reduced_alphabet))
        df[epitope_column_key] = reduced_epitopes
    return df
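
The loader above accumulates one boolean mask, AND-ing in each optional filter and logging how many rows were dropped, then applies the mask once at the end. A minimal sketch of that incremental-mask pattern, with invented column names:

import logging
import pandas as pd

def filter_entries(df, peptide_length=None, hla=None):
    # Sketch only: 'sequence' and 'allele' are invented column names.
    n = len(df)
    mask = df['sequence'].notnull()
    logging.info('Dropping %d null sequences', (~mask).sum())
    if peptide_length:
        mask &= df['sequence'].str.len() == peptide_length
    if hla:
        mask &= df['allele'].str.contains(hla, na=False)
    df = df[mask]
    logging.info('Returning %d / %d entries after filtering', len(df), n)
    return df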

Example 46

Project: pepdata
Source File: tcell.py
View license
@memoize
def load_dataframe(
        mhc_class=None,  # 1, 2, or None for neither
        hla=None,
        exclude_hla=None,
        human=True,
        peptide_length=None,
        assay_method=None,
        assay_group=None,
        only_standard_amino_acids=True,
        reduced_alphabet=None,  # 20 letter AA strings -> simpler alphabet
        nrows=None):
    """
    Load IEDB T-cell data without aggregating multiple entries for same epitope

    Parameters
    ----------
    mhc_class: {None, 1, 2}
        Restrict to MHC Class I or Class II (or None for neither)

    hla: regex pattern, optional
        Restrict results to specific HLA type used in assay

    exclude_hla: regex pattern, optional
        Exclude certain HLA types

    human: bool
        Restrict to human samples (default True)

    peptide_length: int, optional
        Restrict epitopes to amino acid strings of given length

    assay_method: string, optional
        Only collect results with assay methods containing the given string

    assay_group: string, optional
        Only collect results with assay groups containing the given string

    only_standard_amino_acids : bool, optional
        Drop sequences which use non-standard amino acids, anything outside
        the core 20, such as X or U (default = True)

    reduced_alphabet: dictionary, optional
        Remap amino acid letters to some other alphabet

    nrows: int, optional
        Don't load the full IEDB dataset but instead read only the first nrows
    """
    path = local_path()
    df = pd.read_csv(
            path,
            skipinitialspace=True,
            nrows=nrows,
            low_memory=False,
            error_bad_lines=False,
            encoding="latin-1")

    # Sometimes the IEDB seems to put in an extra comma in the
    # header line, which creates an unnamed column of NaNs.
    # To deal with this, drop any columns which are all NaN
    df = df.dropna(axis=1, how="all")

    n = len(df)

    epitopes = df["Epitope Linear Sequence"].str.upper()

    null_epitope_seq = epitopes.isnull()
    n_null = null_epitope_seq.sum()

    if n_null > 0:
        logging.info("Dropping %d null sequences", n_null)

    mask = ~null_epitope_seq

    if only_standard_amino_acids:
        # if have rare or unknown amino acids, drop the sequence
        bad_epitope_seq = \
            epitopes.str.contains(bad_amino_acids, na=False).astype("bool")
        n_bad = bad_epitope_seq.sum()
        if n_bad > 0:
            logging.info("Dropping %d bad sequences", n_bad)

        mask &= ~bad_epitope_seq

    if human:
        organism = df['Host Organism Name']
        mask &= organism.str.startswith('Homo sapiens', na=False).astype('bool')

    # Match known alleles such as "HLA-A*02:01",
    # broader groupings such as "HLA-A2"
    # and unknown alleles of the MHC-1 listed either as
    #  "HLA-Class I,allele undetermined"
    #  or
    #  "Class I,allele undetermined"
    mhc = df['MHC Allele Name']

    if mhc_class is not None:
        # Since MHC classes can be specified as either strings ("I") or
        # integers, standardize them to strings
        if mhc_class == 1:
            mhc_class = "I"
        elif mhc_class == 2:
            mhc_class = "II"
        if mhc_class not in {"I", "II"}:
            raise ValueError("Invalid MHC class: %s" % mhc_class)
        allele_dict = load_alleles_dict()
        mhc_class_mask = [False] * len(df)
        for i, allele_name in enumerate(mhc):
            allele_object = allele_dict.get(allele_name)
            if allele_object and allele_object.mhc_class == mhc_class:
                mhc_class_mask[i] = True
        mask &= np.array(mhc_class_mask)

    if hla:
        mask &= df["MHC Allele Name"].str.contains(hla, na=False)

    if exclude_hla:
        mask &= ~(df["MHC Allele Name"].str.contains(exclude_hla, na=False))

    if assay_group:
        mask &= df["Assay Group"].str.contains(assay_group)

    if assay_method:
        mask &= df["Method/Technique"].str.contains(assay_method)

    if peptide_length:
        assert peptide_length > 0
        mask &= df["Epitope Linear Sequence"].str.len() == peptide_length

    df = df[mask]

    logging.info("Returning %d / %d entries after filtering", len(df), n)

    if reduced_alphabet:
        epitopes = df["Epitope Linear Sequence"]
        df["Epitope Linear Sequence"] = \
            epitopes.map(make_alphabet_transformer(reduced_alphabet))
        df["Epitope Original Sequence"] = epitopes
    return df
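
Unlike the ligand loader in Example 45, the class filter here resolves each allele name through a lookup table and builds the mask element by element. A minimal sketch of that step, with an invented allele-to-class mapping:

import logging
import numpy as np
import pandas as pd

allele_classes = {'HLA-A*02:01': 'I', 'HLA-DRB1*01:01': 'II'}  # invented lookup table

mhc = pd.Series(['HLA-A*02:01', 'HLA-DRB1*01:01', 'HLA-A*02:01', 'undetermined'])
wanted_class = 'I'

mask = np.array([allele_classes.get(name) == wanted_class for name in mhc])
logging.info('%d / %d entries match MHC class %s', mask.sum(), len(mhc), wanted_class)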

Example 47

Project: stanford-ctc
Source File: sgd.py
View license
    def run(self,data_dict,alis,keys,sizes):
        """
        Runs stochastic gradient descent with nesterov acceleration.
        Model is objective.  
        """
        
        # momentum setup
        momIncrease = 10
        mom = 0.5

        # randomly select minibatch
        random.shuffle(keys)

        for k in keys:
            self.it += 1

            if self.it > momIncrease:
                mom = self.momentum

            mb_data = data_dict[k]
            if mb_data.shape[1] > self.maxBatch:
                logging.info("SKIPPING utt exceeds batch length\
                        (Utterance length %d)." % mb_data.shape[1])
                continue

            mb_labels = np.array(alis[k],dtype=np.int32)

            if mb_data.shape[1] < mb_labels.shape[0]:
                logging.info("SKIPPING utt frames less than label length "
                       "(Utterance length %d, Num Labels %d)."
                       ""%(mb_data.shape[1],mb_labels.shape[0]))
                continue


            if self.optimizer == 'nesterov':
                # w = w+mom*velocity (evaluate gradient at future point)
                self.model.updateParams(mom,self.velocity)

            cost,grad,skip = self.model.costAndGrad(mb_data,mb_labels)

            if self.optimizer == 'nesterov':
                # undo update
                # w = w-mom*velocity
                self.model.updateParams(-mom,self.velocity)

            # Compute norm of all parameters as one vector
            gnorm = 0.0
            for dw,db in grad:
                gnorm += dw.euclid_norm()**2
                gnorm += db.euclid_norm()**2
            gnorm = np.sqrt(gnorm)

            if skip:
                logging.info("SKIPPING: Key=%s, Cost=%f, SeqLen=%d, NumFrames=%d."
                             % (k, cost, mb_labels.shape[0], mb_data.shape[1]))
                continue

            if np.isfinite(cost):
                # compute exponentially weighted cost
                if self.it > 1:
                    self.expcost.append(.01*cost + .99*self.expcost[-1])
                else:
                    self.expcost.append(cost)
                self.costt.append(cost)

                if self.model.reg > 0.0:
                    rc = self.model.regcost
                    if len(self.regcost) > 0:
                        self.regcost.append(0.01*rc + 0.99*self.regcost[-1])
                    else:
                        self.regcost.append(rc)

            # velocity = mom*velocity - alpha*grad
            if self.optimizer == 'nesterov':
                alph = self.alpha
                if gnorm > self.maxGNorm:
                    alph *= (self.maxGNorm/gnorm)
                for vs,gs in zip(self.velocity,grad):
                    vw,vb = vs 
                    dw,db = gs
                    vw.mult(mom)
                    vb.mult(mom)
                    vw.add_mult(dw,alpha=-alph)
                    vb.add_mult(db,alpha=-alph)
                update = self.velocity
                scale = 1.0

            elif self.optimizer == 'adagrad':
                delta = 1e-10
                for gts,gs in zip(self.gradt,grad):
                    dwt,dbt = gts 
                    dw,db = gs
                    gamma = 1. - 1./(1e-2*self.it+1)
                    cm.add_pow(dwt,dw,2,alpha=gamma,target=dwt)
                    cm.add_pow(dbt,db,2,alpha=gamma,target=dbt)
                    dwt.add(delta,target=dwt)
                    dbt.add(delta,target=dbt)
                    cm.mult_pow(dw,dwt,-0.5,target=dw)
                    cm.mult_pow(db,dbt,-0.5,target=db)
                    dwt.add(-delta,target=dwt)
                    dbt.add(-delta,target=dbt)
                update = grad
                scale = -self.alpha

            # update params
            self.model.updateParams(scale, update)

            if self.it%1 == 0:
                print ("Iter %d : Cost=%.4f, ExpCost=%.4f, GradNorm=%.4f, "
                       "SeqLen=%d, NumFrames=%d.")%(self.it,cost,
                       self.expcost[-1],gnorm,mb_labels.shape[0],
                       mb_data.shape[1])
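
The Nesterov branch above evaluates the gradient at the lookahead point w + mom*velocity, rescales the learning rate when the gradient norm exceeds maxGNorm, and then applies velocity = mom*velocity - alpha*grad. A minimal numpy sketch of one such step for a single parameter vector (names and defaults are invented):

import logging
import numpy as np

def nesterov_step(w, velocity, grad_fn, alpha=0.01, mom=0.9, max_gnorm=100.0):
    # Sketch only: grad_fn is a hypothetical callable returning the gradient at a point.
    grad = grad_fn(w + mom * velocity)        # evaluate gradient at the lookahead point
    gnorm = np.sqrt(np.sum(grad ** 2))
    if gnorm > max_gnorm:                     # scale the step when the gradient is too large
        alpha *= max_gnorm / gnorm
    velocity = mom * velocity - alpha * grad  # velocity = mom*velocity - alpha*grad
    w = w + velocity
    logging.info('GradNorm=%.4f', gnorm)
    return w, velocity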

Example 48

Project: pythonect
Source File: eval.py
View license
def eval(source, globals_={}, locals_={}):
    """Evaluate Pythonect code in the context of globals and locals.

    Args:
        source: A string representing a Pythonect code or a networkx.DiGraph() as
            returned by parse()
        globals: A dictionary.
        locals: Any mapping.

    Returns:
        The return value is the result of the evaluated code.

    Raises:
        SyntaxError: An error occurred parsing the code.
    """

    return_value = None

    # Meaningful program?

    if source != "pass":

        logging.info('Program is meaningful')

        return_value = []

        return_values = []

        globals_values = []

        locals_values = []

        tasks = []

        reduces = {}

        logging.debug('Evaluating %s with globals_ = %s and locals_ %s' % (source, globals_, locals_))

        if not isinstance(source, networkx.DiGraph):

            logging.info('Parsing program...')

            graph = parse(source)

        else:

            logging.info('Program is already parsed! Using source AS IS')

            graph = source

        root_nodes = sorted([node for node, degree in graph.in_degree().items() if degree == 0])

        if not root_nodes:

            cycles = networkx.simple_cycles(graph)

            if cycles:

                logging.info('Found cycles: %s in graph, using nodes() 1st node (i.e. %s) as root node' % (cycles, graph.nodes()[0]))

                root_nodes = [graph.nodes()[0]]

        logging.info('There are %d root node(s)' % len(root_nodes))

        logging.debug('Root node(s) are: %s' % root_nodes)

        # Extend Python's __builtin__ with Pythonect's `lang`

        start_globals_ = __extend_builtins(globals_)

        logging.debug('Initial globals_:\n%s' % pprint.pformat(start_globals_))

        # Default input

        start_globals_['_'] = start_globals_.get('_', locals_.get('_', None))

        logging.info('_ equal %s', start_globals_['_'])

        # Execute Pythonect program

        pool = __create_pool(globals_, locals_)

        # N-1

        for root_node in root_nodes[1:]:

            if globals_.get('__IN_EVAL__', None) is None and not _is_referencing_underscore(graph, root_node):

                # Reset '_'

                globals_['_'] = locals_['_'] = None

            if globals_.get('__IN_EVAL__', None) is None:

                globals_['__IN_EVAL__'] = True

            temp_globals_ = copy.copy(globals_)

            temp_locals_ = copy.copy(locals_)

            task_result = pool.apply_async(_run, args=(graph, root_node, temp_globals_, temp_locals_, {}, None, False))

            tasks.append((task_result, temp_locals_, temp_globals_))

        # 1

        if globals_.get('__IN_EVAL__', None) is None and not _is_referencing_underscore(graph, root_nodes[0]):

            # Reset '_'

            globals_['_'] = locals_['_'] = None

        if globals_.get('__IN_EVAL__', None) is None:

            globals_['__IN_EVAL__'] = True

        result = _run(graph, root_nodes[0], globals_, locals_, {}, None, False)

        # 1

        for expr_return_value in result:

            globals_values.append(globals_)

            locals_values.append(locals_)

            return_values.append([expr_return_value])

        # N-1

        for (task_result, task_locals_, task_globals_) in tasks:

            return_values.append(task_result.get())

            locals_values.append(task_locals_)

            globals_values.append(task_globals_)

        # Reduce + _PythonectResult Grouping

        for item in return_values:

            # Is there _PythonectResult in item list?

            for sub_item in item:

                if isinstance(sub_item, _PythonectResult):

                    # 1st Time?

                    if sub_item.values['node'] not in reduces:

                        reduces[sub_item.values['node']] = []

                        # Add Place holder to mark the position in the return value list

                        return_value.append(_PythonectLazyRunner(sub_item.values['node']))

                    reduces[sub_item.values['node']] = reduces[sub_item.values['node']] + [sub_item.values]

                else:

                    return_value.append(sub_item)

        # Any _PythonectLazyRunner's?

        if reduces:

            for return_item_idx in xrange(0, len(return_value)):

                if isinstance(return_value[return_item_idx], _PythonectLazyRunner):

                    # Swap list[X] with list[X.go(reduces)]

                    return_value[return_item_idx] = pool.apply_async(return_value[return_item_idx].go, args=(graph, reduces))

            return_value = __resolve_and_merge_results(return_value)

        # [...] ?

        if return_value:

            # Single return value? (e.g. [1])

            if len(return_value) == 1:

                return_value = return_value[0]

            # Update globals_ and locals_

#            globals_, locals_ = __merge_all_globals_and_locals(globals_, locals_, globals_values, {}, locals_values, {})

        # Set `return value` as `_`

        globals_['_'] = locals_['_'] = return_value

        if globals_.get('__IN_EVAL__', None) is not None:

            del globals_['__IN_EVAL__']

        pool.close()

        pool.join()

        pool.terminate()

    return return_value
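
Evaluation above starts from the graph's root nodes (those with in-degree 0) and, when the graph is one big cycle, falls back to an arbitrary node. A minimal sketch of that root-finding step against the networkx 2.x API:

import logging
import networkx as nx

def find_root_nodes(graph):
    # Sketch only: mirrors the root-node selection above, using the networkx 2.x API.
    roots = sorted(node for node, degree in dict(graph.in_degree()).items()
                   if degree == 0)
    if not roots:
        cycles = list(nx.simple_cycles(graph))
        if cycles:
            fallback = list(graph.nodes())[0]
            logging.info('Found cycles %s in graph, using %s as root node', cycles, fallback)
            roots = [fallback]
    logging.info('There are %d root node(s)', len(roots))
    return roots

g = nx.DiGraph([('a', 'b'), ('b', 'c')])
print(find_root_nodes(g))  # ['a']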

Example 49

Project: nmn2
Source File: geo.py
View license
    def __init__(self, config, set_name, modules):
        if set_name == VAL:
            self.data = []
            return

        questions = []
        answers = []
        parse_lists = []
        worlds = []

        if config.quant:
            ANSWER_INDEX.index(YES)
            ANSWER_INDEX.index(NO)

        for i_env, environment in enumerate(ENVIRONMENTS):
            if i_env == config.fold and set_name == TRAIN:
                continue
            if i_env != config.fold and set_name == TEST:
                continue

            places = list()
            with open(LOCATION_FILE % environment) as loc_f:
                for line in loc_f:
                    parts = line.strip().split(";")
                    places.append(parts[0])

            cats = {place: np.zeros((len(CATS),)) for place in places}
            rels = {(pl1, pl2): np.zeros((len(RELS),)) for pl1 in places for pl2 in places}

            with open(WORLD_FILE % environment) as world_f:
                for line in world_f:
                    parts = line.strip().split(";")
                    if len(parts) < 2:
                        continue
                    name = parts[0][1:]
                    places_here = parts[1].split(",")
                    if name in CATS:
                        cat_id = CATS.index(name)
                        for place in places_here:
                            cats[place][cat_id] = 1
                    elif name in RELS:
                        rel_id = RELS.index(name)
                        for place_pair in places_here:
                            pl1, pl2 = place_pair.split("#")
                            rels[pl1, pl2][rel_id] = 1
                            rels[pl2, pl1][rel_id] = -1

            clean_places = [p.lower().replace(" ", "_") for p in places]
            place_index = {place: i for (i, place) in enumerate(places)}
            clean_place_index = {place: i for (i, place) in enumerate(clean_places)}
            
            cat_features = np.zeros((len(CATS), DATABASE_SIZE, 1))
            rel_features = np.zeros((len(RELS), DATABASE_SIZE, DATABASE_SIZE))

            for p1, i_p1 in place_index.items():
                cat_features[:, i_p1, 0] = cats[p1]
                for p2, i_p2 in place_index.items():
                    rel_features[:, i_p1, i_p2] = rels[p1, p2]

            world = World(environment, clean_place_index, cat_features, rel_features)

            for place in clean_places:
                ANSWER_INDEX.index(place)

            with open(DATA_FILE % environment) as data_f:
                for line in data_f:
                    line = line.strip()
                    if line == "" or line[0] == "#":
                        continue

                    parts = line.split(";")

                    question = parts[0]
                    if question[-1] != "?":
                        question += " ?"
                    question = question.lower()
                    questions.append(question)

                    answer = parts[1].lower().replace(" ", "_")
                    if config.quant and question.split()[0] in ("is", "are"):
                        answer = YES if answer else NO
                    answers.append(answer)

                    worlds.append(world)

            with open(PARSE_FILE % environment) as parse_f:
                for line in parse_f:
                    parse_strs = line.strip().split(";")
                    trees = [parse_tree(s) for s in parse_strs]
                    if not config.quant:
                        trees = [t for t in trees if t[0] != "exists"]
                    parse_lists.append(trees)

        assert len(questions) == len(parse_lists)

        data = []
        i_datum = 0
        for question, answer, parse_list, world in \
                zip(questions, answers, parse_lists, worlds):
            tokens = ["<s>"] + question.split() + ["</s>"]

            parse_list = parse_list[-config.k_best_parses:]

            indexed_question = [QUESTION_INDEX.index(w) for w in tokens]
            indexed_answer = \
                    tuple(ANSWER_INDEX[a] for a in answer.split(",") if a != "")
            assert all(a is not None for a in indexed_answer)
            layouts = [parse_to_layout(p, world, config, modules) for p in parse_list]

            data.append(GeoDatum(
                    i_datum, indexed_question, parse_list, layouts, indexed_answer, world))
            i_datum += 1

        self.data = data

        logging.info("%s:", set_name)
        logging.info("%s items", len(self.data))
        logging.info("%s words", len(QUESTION_INDEX))
        logging.info("%s functions", len(MODULE_INDEX))
        logging.info("%s answers", len(ANSWER_INDEX))

Example 50

Project: melosynth
Source File: melosynth.py
View license
def melosynth(inputfile, outputfile, fs, nHarmonics, square, useneg):
    '''
    Load pitch sequence from a txt/csv file and synthesize it into a .wav

    :parameters:
    - inputfile : str
    Path to input file containing the pitch sequence.

    - outputfile: str
    Path to output wav file. If outputfile is None a file will be
    created with the same path/name as inputfile but ending with
    "_melosynth.wav"

    - fs : int
    Sampling frequency for the synthesized file.

    - nHarmonics : int
    Number of harmonics (including the fundamental) to use in the synthesis
    (default is 1). As the number is increased the wave will become more
    sawtooth-like.

    - square : bool
    When set to true, the waveform will converge to a square wave instead of
    a sawtooth as the number of harmonics is increased.

    - useneg : bool
    By default, negative frequency values (unvoiced frames) are synthesized as
    silence. If useneg is set to True, these frames will be synthesized using
    their absolute values (i.e. as voiced frames).
    '''

    # Preprocess input parameters
    fs = int(float(fs))
    nHarmonics = int(nHarmonics)
    if outputfile is None:
        outputfile = inputfile[:-4] + "_melosynth.wav"

    # Load pitch sequence
    logging.info('Loading data...')
    times, freqs = loadmel(inputfile)

    # Preprocess pitch sequence
    if useneg:
        freqs = np.abs(freqs)
    else:
        freqs[freqs < 0] = 0
    # Impute silence if start time > 0
    if times[0] > 0:
        estimated_hop = np.median(np.diff(times))
        prev_time = max(times[0] - estimated_hop, 0)
        times = np.insert(times, 0, prev_time)
        freqs = np.insert(freqs, 0, 0)


    logging.info('Generating wave...')
    signal = []

    translen = 0.010 # duration (in seconds) for fade in/out and freq interp
    phase = np.zeros(nHarmonics) # start phase for all harmonics
    f_prev = 0 # previous frequency
    t_prev = 0 # previous timestamp
    for t, f in zip(times, freqs):

        # Compute number of samples to synthesize
        nsamples = np.round((t - t_prev) * fs)

        if nsamples > 0:
            # calculate transition length (in samples)
            translen_sm = float(min(np.round(translen*fs), nsamples))

            # Generate frequency series
            freq_series = np.ones(nsamples) * f_prev

            # Interpolate between non-zero frequencies
            if f_prev > 0 and f > 0:
                freq_series += np.minimum(np.arange(nsamples)/translen_sm, 1) *\
                               (f - f_prev)
            elif f > 0:
                freq_series = np.ones(nsamples) * f

            # Repeat for each harmonic
            samples = np.zeros(nsamples)
            for h in range(nHarmonics):
                # Determine harmonic num (h+1 for sawtooth, 2h+1 for square)
                hnum = 2*h+1 if square else h+1
                # Compute the phase of each sample
                phasors = 2 * np.pi * (hnum) * freq_series / float(fs)
                phases = phase[h] + np.cumsum(phasors)
                # Compute sample values and add
                samples += np.sin(phases) / (hnum)
                # Update phase
                phase[h] = phases[-1]

            # Fade in/out and silence
            if f_prev == 0 and f > 0:
                samples *= np.minimum(np.arange(nsamples)/translen_sm, 1)
            if f_prev > 0 and f == 0:
                samples *= np.maximum(1 - (np.arange(nsamples)/translen_sm), 0)
            if f_prev == 0 and f == 0:
                samples *= 0

            # Append samples
            signal.extend(samples)

        t_prev = t
        f_prev = f

    # Normalize signal
    signal = np.asarray(signal)
    signal *= 0.8 / float(np.max(signal))

    logging.info('Saving wav file...')
    wavwrite(np.asarray(signal), outputfile, fs)
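
The synthesis loop above accumulates phase with a cumulative sum of per-sample phase increments and sums one sinusoid per harmonic, scaled by 1/h. A minimal sketch for a single constant-frequency note (parameter names and defaults are invented):

import logging
import numpy as np

def synth_tone(freq, duration, fs=44100, n_harmonics=3, square=False):
    # Sketch only: additive synthesis of one constant-frequency note.
    n = int(round(duration * fs))
    samples = np.zeros(n)
    for h in range(n_harmonics):
        hnum = 2 * h + 1 if square else h + 1          # odd harmonics approach a square wave
        phases = np.cumsum(np.full(n, 2 * np.pi * hnum * freq / fs))
        samples += np.sin(phases) / hnum               # 1/h amplitude rolloff
    samples *= 0.8 / np.max(np.abs(samples))           # normalize, as in the example above
    logging.info('Generated %d samples at %d Hz', n, fs)
    return samples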